
// TCPW-NJ - TCPW New Jersey with NewReno features
//
//

#include <stdio.h>
#include <stdlib.h>
#include <sys/types.h>
#include <math.h>

#include "packet.h"
#include "ip.h"
#include "tcp.h"
#include "flags.h"
#include "address.h"

#include "tcp-nj.h"

static class NJTcpClass : public TclClass {
public:
	NJTcpClass() : TclClass("Agent/TCP/NJ") {}
	TclObject* create(int, const char*const*) {
		return (new NJTcpAgent());
	}
} class_tcp_nj;

NJTcpAgent::NJTcpAgent() : NewRenoTcpAgent(),
	current_are_(0), t_last_(0), min_rtt_estimate(5.0), wloss_cnt_(0),
	loss_diff_(0), congested_(0)
  
{
	ts_option_ = TRUE;		

	bind("current_are_", &current_are_);
  	bind("min_rtt_estimate", &min_rtt_estimate);
	bind_bool("loss_diff_", &loss_diff_);
	
	bind("newreno_changes_", &newreno_changes_);
	bind("newreno_changes1_", &newreno_changes1_);
	bind("exit_recovery_fix_", &exit_recovery_fix_);
	bind("partial_window_deflation_", &partial_window_deflation_);

	wloss_.uid = -1;
}

void 
NJTcpAgent::ecn(int seqno)
{
	if (seqno > recover_ || 
	      last_cwnd_action_ == CWND_ACTION_TIMEOUT) {
		recover_ =  maxseq_;
		last_cwnd_action_ = CWND_ACTION_ECN;
		if (cwnd_ <= 1.0) {
			if (ecn_backoff_) 
				rtt_backoff();
			else ecn_backoff_ = 1;
		} else ecn_backoff_ = 0;
		reduce_to_est_rate();
		// remember to set the CWR bit
		cong_action_ = 1;
		// slowdown already performed by reduce_to_est_rate()
		//slowdown(CLOSE_CWND_HALF|CLOSE_SSTHRESH_HALF);
		++necnresponses_ ;
		// added by sylvia to count number of ecn responses 
	}
}


void NJTcpAgent::dupack_action()
{

	int recovered = (highest_ack_ > recover_);
        int allowFastRetransmit = allow_fast_retransmit(last_cwnd_action_);
        if (recovered || (!bug_fix_ && !ecn_) || allowFastRetransmit) {
                goto reno_action;
        }

        if (ecn_ && last_cwnd_action_ == CWND_ACTION_ECN) {
                last_cwnd_action_ = CWND_ACTION_DUPACK;
                /*
                 * What if there is a DUPACK action followed closely by ECN
                 * followed closely by a DUPACK action?
                 * The optimal thing to do would be to remember all
                 * congestion actions from the most recent window
                 * of data.  Otherwise "bugfix" might not prevent
                 * all unnecessary Fast Retransmits.
                 */
                reset_rtx_timer(1,0);
                output(last_ack_ + 1, TCP_REASON_DUPACK);
                return;
        }

	if(! marked_) {	// TCPNJ action when the 3rd DUPACK is not marked
                last_cwnd_action_ = CWND_ACTION_DUPACK;
                reset_rtx_timer(1,0);
                output(last_ack_ + 1, TCP_REASON_DUPACK);
		wloss_cnt_++;	/* trigger the trace of wloss_ structure */
                return;
	}

        if (bug_fix_) {
                /*
                 * The line below, for "bug_fix_" true, avoids
                 * problems with multiple fast retransmits in one
                 * window of data.
                 */
                return;
        }

reno_action:

	if(marked_) { // reduce the rate only if loss due to congestion
		reduce_to_est_rate();
	} else {
		wloss_cnt_++;	/* trigger the trace of wloss_ structure */
	}
              
	trace_event("TCPNJ_FAST_RETX");
        recover_ = maxseq_;
        last_cwnd_action_ = CWND_ACTION_DUPACK;
        // The slowdown was already performed by reduce_to_est_rate()
        // slowdown(CLOSE_SSTHRESH_HALF|CLOSE_CWND_HALF);
        reset_rtx_timer(1,0);
        output(last_ack_ + 1, TCP_REASON_DUPACK);
        return;
      
}

void
NJTcpAgent::reduce_to_est_rate()
{
	double are_ = double(current_are_);

        ssthresh_ = (int)((are_/size_/8) * min_rtt_estimate);
      
        // Safety Check: ssthresh should not be < 2
        if (ssthresh_ < 2) {
        	ssthresh_ = 2;
        }
	
	/* our algorithm dictates that cwnd = ssthresh after a 3DUPACK, but */
	/* we should not forcefully increase awnd if it is smaller than   */
	/* ssthresh                                                       */
	
      	if (cwnd_ > ssthresh_) {
      		cwnd_ = ssthresh_;
      	}

}


void NJTcpAgent::timeout(int tno)
{
	/* retransmit timer */
	if (tno == TCP_TIMER_RTX) {
		// These three lines catch the RenoTcpAgent::timeout() behavior
		dupwnd_ = 0;
		dupacks_ = 0;
		if (bug_fix_) recover_ = maxseq_;
				
		// There has been a timeout - will trace this event
		trace_event("TIMEOUT");

	        if (cwnd_ < 1) cwnd_ = 1;
		if (highest_ack_ == maxseq_ && !slow_start_restart_) {
			/*
			 * TCP option:
			 * If no outstanding data, then don't do anything.  
			 */
			 // Should this return be here?
			 // What if CWND_ACTION_ECN and cwnd < 1?
			 // return;
		} else {
			recover_ = maxseq_;
			if (highest_ack_ == -1 && wnd_init_option_ == 2)
				/* 
				 * First packet dropped, so don't use larger
				 * initial windows. 
				 */
				wnd_init_option_ = 1;
			if (highest_ack_ == maxseq_ && restart_bugfix_)
			       /* 
				* if there is no outstanding data, don't cut 
				* down ssthresh_.
				*/
				slowdown(CLOSE_CWND_ONE);
			else if (highest_ack_ < recover_ &&
			  last_cwnd_action_ == CWND_ACTION_ECN) {
			       /*
				* if we are in recovery from a recent ECN,
				* don't cut down ssthresh_.
				*/
				slowdown(CLOSE_CWND_ONE);
			}
			else {
				++nrexmit_;
				last_cwnd_action_ = CWND_ACTION_TIMEOUT;
				//slowdown(CLOSE_SSTHRESH_HALF|CLOSE_CWND_RESTART);
				slowdown(CLOSE_TCPNJ); // TCPNJ action
			}
		}
		/* if there is no outstanding data, don't back off rtx timer */
		if (highest_ack_ == maxseq_ && restart_bugfix_) {
			reset_rtx_timer(0,0);
		}
		else {
			reset_rtx_timer(0,1);
		}
		last_cwnd_action_ = CWND_ACTION_TIMEOUT;
		send_much(0, TCP_REASON_TIMEOUT, maxburst_);
	} 
	else {
		timeout_nonrtx(tno);
	}
}

void 
NJTcpAgent::compute_are(Packet *pkt) 
{
	
	hdr_flags *fh = hdr_flags::access(pkt);
	hdr_tcp  *th = hdr_tcp::access(pkt);
	int use_ts = (!fh->no_ts_ & ts_option_);
	int seq = th->seqno();

	if(!use_ts) {
		fprintf(stderr, "TCPNJ must set ts_option_ to TRUE, %d, %d\n",
		fh->no_ts_, ts_option_);
		abort();
	}
	
	double rtt_estimate = t_rtt_ * tcp_tick_;
		
	if ((rtt_estimate < min_rtt_estimate)&&(rtt_estimate > 0)) {
		min_rtt_estimate = rtt_estimate;
	}

	double now = Scheduler::instance().clock();
	/*
	 * here use the smoothed rtt plus rtt variance
	*/
	double rtt = (int(t_srtt_) >> T_SRTT_BITS) * tcp_tick_ + int(t_rttvar_) * tcp_tick_ / 4.0;
	double delta;
	int size;

	if(use_ts) {
		delta = th->ts() - t_last_;
		t_last_ = th->ts();
	} else {
		delta = now - t_last_;
		t_last_ = now;
	}

	size = (size_ + tcpip_base_hdr_size_);


	current_are_ = (rtt * double(current_are_) + size * 8) / (delta + rtt);

}


void NJTcpAgent::recv(Packet *pkt, Handler* h)
{
	hdr_tcp   *th = hdr_tcp::access(pkt);
	hdr_flags *fh = hdr_flags::access(pkt);
	hdr_cmn   *ch = hdr_cmn::access(pkt);

	Scheduler& s = Scheduler::instance();

	if(loss_diff_) {				// only if configured to diff. wireless losses
		wloss_.time = &s ? s.clock() : 0;
		wloss_.seq = th->seqno() + 1;		// +1 is necessary since NS ack the received seq
		wloss_.uid = ch->uid();			// we actually record every ACK arrived, but
							// the trace output only triggerred in 
							// dupack_action()

		marked_ = 0;
		if(fh->ecnecho() && ecn_) {		// ECN mark
			congested_ = 1;
			if(th->seqno() == last_ack_) {	// DUPACK
				marked_ = 1;
			}
		} else {
			congested_ = 0;
		}
	} else {
		marked_ = 1;
	}

	compute_are(pkt);

	NewRenoTcpAgent::recv(pkt,h);

}


void
NJTcpAgent::slowdown(int how)
{
	double win, halfwin, decreasewin;
	int slowstart = 0;
	// we are in slowstart for sure if cwnd < ssthresh
	if (cwnd_ < ssthresh_)
		slowstart = 1;
	// we are in slowstart - need to trace this event
	trace_event("SLOW_START");

        if (precision_reduce_) {
		halfwin = windowd() / 2;
                if (wnd_option_ == 6) {
                        /* binomial controls */
                        decreasewin = windowd() - (1.0-decrease_num_)*pow(windowd(),l_parameter_);
                } else
	 		decreasewin = decrease_num_ * windowd();
		win = windowd();
	} else  {
		int temp;
		temp = (int)(window() / 2);
		halfwin = (double) temp;
                if (wnd_option_ == 6) {
                        /* binomial controls */
                        temp = (int)(window() - (1.0-decrease_num_)*pow(window(),l_parameter_));
                } else
	 		temp = (int)(decrease_num_ * window());
		decreasewin = (double) temp;
		win = (double) window();
	}
	if (how & CLOSE_SSTHRESH_HALF)
		// For the first decrease, decrease by half
		// even for non-standard values of decrease_num_.
		if (first_decrease_ == 1 || slowstart ||
			last_cwnd_action_ == CWND_ACTION_TIMEOUT) {
			// Do we really want halfwin instead of decreasewin
			// after a timeout?
			ssthresh_ = (int) halfwin;
		} else {
			ssthresh_ = (int) decreasewin;
		}
        else if (how & THREE_QUARTER_SSTHRESH)
		if (ssthresh_ < 3*cwnd_/4)
			ssthresh_  = (int)(3*cwnd_/4);
	if (how & CLOSE_CWND_HALF)
		// For the first decrease, decrease by half
		// even for non-standard values of decrease_num_.
		if (first_decrease_ == 1 || slowstart || decrease_num_ == 0.5) {
			cwnd_ = halfwin;
		} else cwnd_ = decreasewin;
        else if (how & CWND_HALF_WITH_MIN) {
		// We have not thought about how non-standard TCPs, with
		// non-standard values of decrease_num_, should respond
		// after quiescent periods.
                cwnd_ = decreasewin;
                if (cwnd_ < 1)
                        cwnd_ = 1;
	}
	///
	else if (how & CLOSE_TCPNJ) {
		double rtt_estimate = t_rtt_ * tcp_tick_;
		double are_ = double(current_are_);

		if ((rtt_estimate < min_rtt_estimate)&&(rtt_estimate > 0)) {
			min_rtt_estimate = rtt_estimate;
		}

		ssthresh_ = (int)( ((are_/size_/8) * min_rtt_estimate) );
		if(ssthresh_ < 2)
			ssthresh_ = 2;
		if(loss_diff_) {
			if(congested_) {
				cwnd_ = 1;
				congested_ = 0;
			} else {
				if(cwnd_ > ssthresh_)
					cwnd_ = ssthresh_;
			}
		} else {
			cwnd_ = 1;
		}
	}
	///
	else if (how & CLOSE_CWND_RESTART)
		cwnd_ = int(wnd_restart_);
	else if (how & CLOSE_CWND_INIT)
		cwnd_ = int(wnd_init_);
	else if (how & CLOSE_CWND_ONE)
		cwnd_ = 1;
	else if (how & CLOSE_CWND_HALF_WAY) {
		// cwnd_ = win - (win - W_used)/2 ;
		cwnd_ = W_used + decrease_num_ * (win - W_used);
                if (cwnd_ < 1)
                        cwnd_ = 1;
	}
	if (ssthresh_ < 2)
		ssthresh_ = 2;
	if (how & (CLOSE_CWND_HALF|CLOSE_CWND_RESTART|CLOSE_CWND_INIT|CLOSE_CWND_ONE|CLOSE_TCPNJ))
		cong_action_ = TRUE;

	fcnt_ = count_ = 0;
	if (first_decrease_ == 1)
		first_decrease_ = 0;
}

void NJTcpAgent::newack(Packet* pkt)
{
	//call parent newack
	NewRenoTcpAgent::newack(pkt);
}

int
NJTcpAgent::delay_bind_dispatch(const char *varName, const char *localName, TclObject *tracer)
{

        if (delay_bind(varName, localName, "min_rtt_estimate", &min_rtt_estimate, tracer)) return TCL_OK;
	if (delay_bind(varName, localName, "wloss_", &wloss_cnt_, tracer)) return TCL_OK;
	if (delay_bind(varName, localName, "rate_", &current_are_, tracer)) return TCL_OK;
	if (delay_bind_bool(varName, localName, "loss_diff_", &loss_diff_, tracer)) return TCL_OK;
	
	// these where originally in NewRenoTcpAgent()
	if (delay_bind(varName, localName, "newreno_changes_", &newreno_changes_, tracer)) return TCL_OK;
	if (delay_bind(varName, localName, "newreno_changes1_", &newreno_changes1_, tracer)) return TCL_OK;
	if (delay_bind(varName, localName, "exit_recovery_fix_", &exit_recovery_fix_, tracer)) return TCL_OK;
	if (delay_bind(varName, localName, "partial_window_deflation_", &partial_window_deflation_, tracer)) return TCL_OK;
	
        return NewRenoTcpAgent::delay_bind_dispatch(varName, localName, tracer);
}

void
NJTcpAgent::traceVar(TracedVar* v)
{
	char wrk[500];
	int n;

	if (!strcmp(v->name(), "wloss_")) {
		if(wloss_.uid < 0)
			return;
		sprintf(wrk,"%-8.5f %-2d %-2d %-2d %-2d %s %d %d",
			wloss_.time, addr(), port(), daddr(), dport(),
			v->name(), wloss_.seq, wloss_.uid);
		goto out;
	} if(!strcmp(v->name(), "rate_")) {
		sprintf(wrk,"%-8.5f %-2d %-2d %-2d %-2d %s %-6.3f",
			Scheduler::instance().clock(), addr(), port(), daddr(), dport(),
			v->name(), double(*((TracedDouble*) v)));
		goto out;
	}
	else
		goto parent;
out:
	n = strlen(wrk);
	wrk[n] = '\n';
	wrk[n+1] = 0;
	if (channel_)
		(void)Tcl_Write(channel_, wrk, n+1);
	wrk[n] = 0;
	return;
parent:
	NewRenoTcpAgent::traceVar(v);
}



