/*
  The Broad Institute
  SOFTWARE COPYRIGHT NOTICE AGREEMENT
  This software and its documentation are copyright (2005) by the
  Broad Institute/Massachusetts Institute of Technology. All rights are
  reserved.

  This software is supplied without any warranty or guaranteed support
  whatsoever. Neither the Broad Institute nor MIT can be responsible for its
  use, misuse, or functionality.
*/
/* $Id */

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include "defs.h"
#include "segment.h"
#include "node.h"
#include "nodelist.h"
#include "pop.h"
#include "demography.h"
#include "recomb.h"
#include "geneconversion.h"
#include "../cosi_rand/random.h"
#include "../cosi_rand/poisson.h"
#include "sweep.h"
/* Forward declarations: these routines are defined further down in this file
   but are called by sweep_execute, which comes first. */
double sw_get_poisson_rate(popid popname);
void sw_do_poisson(popid sel_popname, double gen);
int sw_coalesce(int nnode, Node **nodes, Pop *popptr, double t);

genid sweep_execute(popid popname, double sel_coeff, genid gen, double sel_pos, int start_nchrom, double final_sel_freq) {
  Pop *popptr;
  int inode, nsel, nunsel, newn, foundit;
  Node **sel_nodes, **unsel_nodes, **rec_nodes, **gc_nodes;
  Node *old_allele_node, *new_allele_node;
  double deltaT, sel_freq, epsilon, t, x, tend, alpha, rate;
  double prob_coal_unsel, prob_coal_sel, prob_recomb, all_prob;
  double tsave, loc, end_shift=0., t_nonsweep, poisson_event_time;
  double prob_gc, loc1, loc2;
  int selsize, unselsize, dorandom=1;

  popptr = dg_get_pop_by_name(popname);
  nunsel = nsel = 0;
  selsize = unselsize = popptr->members.nummembers;
  sel_nodes = malloc(selsize * sizeof(Node*));
  unsel_nodes = malloc(unselsize * sizeof(Node*));
  assert(sel_nodes != NULL && unsel_nodes != NULL);

  epsilon = 1 / (2. * popptr->popsize);
  if (final_sel_freq > 1-epsilon) {
    end_shift = 0.;
    dorandom = 0;
  }
  else {
    end_shift = log( (1-final_sel_freq) / (final_sel_freq * epsilon * (1-epsilon)) ) / sel_coeff;
  }
  deltaT = .01 / sel_coeff;
  tend = gen - 2 * log(epsilon) / sel_coeff;
  alpha = 2 * popptr->popsize * sel_coeff;

  for (inode = 0; inode < popptr->members.nummembers; inode++) {
    if (!dorandom || random_double() < final_sel_freq) {
      sel_nodes[nsel] = popptr->members.nodes[inode];
      nsel++;
    }
    else {
      unsel_nodes[nunsel] = popptr->members.nodes[inode];
      nunsel++;
    }
  }
  sel_freq = 1;
  tsave = gen;
  
  rate = sw_get_poisson_rate(popname); 
  poisson_event_time = 1e50;
  if (rate > 0) {
    poisson_event_time = poisson_get_next(rate);
  }
  t_nonsweep = gen + poisson_event_time;

  for (t = gen; sel_freq >= epsilon; t += deltaT) {
    while (t > t_nonsweep) {
      /* Process event in a nonsweep population, and update time for next nonsweep event */
      sw_do_poisson(popname, t_nonsweep); 
      poisson_event_time = poisson_get_next(sw_get_poisson_rate(popname));
      t_nonsweep += poisson_event_time;  
    }
    sel_freq = epsilon / (epsilon + (1 - epsilon) * exp(sel_coeff*(t + end_shift - tend))); 
    prob_coal_sel = deltaT * nsel * (nsel-1) / 4 / popptr->popsize / sel_freq;
    prob_coal_unsel = deltaT * nunsel * (nunsel-1) / 4 / popptr->popsize / (1 - sel_freq);
    prob_recomb = deltaT * (nsel + nunsel) * recomb_get_r();
    prob_gc = deltaT * (nsel + nunsel) * gc_get_r();
    all_prob = 0; 
    all_prob += prob_coal_unsel;
    all_prob += prob_coal_sel;
    all_prob += prob_recomb;
    all_prob += prob_gc; 

    x = random_double();

    if (all_prob >= x) {
      if (x < prob_recomb / all_prob) {
	/* Recombination.  Pick sel/unsel for new chrom piece */
	int selpop_index = dg_get_pop_index_by_name(popname);
	rec_nodes = recomb_execute(t, selpop_index, &loc);
	if (rec_nodes != NULL) {
	  /* Which segment carries old allele? */
	  if (loc < sel_pos) {
	    old_allele_node = rec_nodes[2];
	    new_allele_node = rec_nodes[1];
	  }
	  else {
	    old_allele_node = rec_nodes[1];
	    new_allele_node = rec_nodes[2];
	  }

	  foundit = 0;
	  for (inode = 0; inode < nsel; inode++) {
	    if (sel_nodes[inode]->name == rec_nodes[0]->name) {
	      sel_nodes[inode] = old_allele_node;
	      foundit = 1;
	      break;
	    }
	  }
	  if (!foundit) {
	    for (inode = 0; inode < nunsel; inode++) {
	      if (unsel_nodes[inode]->name == rec_nodes[0]->name) {
		unsel_nodes[inode] = old_allele_node;
		foundit = 1;
		break;
	      }
	    }
	  }
	  if (foundit != 1) {fprintf(stderr, "Could not find %d\n", rec_nodes[0]->name);}
	  assert(foundit == 1);

	  /* Decide whether the chrom that we've recombined onto is selected or unselected */
	  if (random_double() < sel_freq) {
	    if (nsel == selsize) {
	      selsize *= 2;
	      sel_nodes = realloc(sel_nodes, selsize * sizeof(Node*));
	      assert(sel_nodes != NULL);
	    }
	    sel_nodes[nsel] = new_allele_node;
	    nsel++;
	  }
	  else {
	    if (nunsel == unselsize) {
	      unselsize *= 2;
	      unsel_nodes = realloc(unsel_nodes, unselsize * sizeof(Node*));
	      assert(unsel_nodes != NULL);
	    }
	    unsel_nodes[nunsel] = new_allele_node;
	    nunsel++; 
	  }
	  free(rec_nodes);
	}
      }

      else if (x < (prob_recomb + prob_coal_unsel) / all_prob) {
	/* Coalescence among unselected chroms */
	newn = sw_coalesce(nunsel, unsel_nodes, popptr, t);
	assert(newn == nunsel - 1);
	nunsel = newn;
      }
      else if (x < (prob_recomb + prob_coal_unsel + prob_coal_sel) / all_prob) {
	/* Coalescence among selected chroms */
	newn = sw_coalesce(nsel, sel_nodes, popptr, t);
	assert(newn == nsel - 1);
	nsel = newn;
      }
      else {
	/* Gene conversion. Pick sel/unsel for new chrom piece */
	int selpop_index = dg_get_pop_index_by_name(popname);
	gc_nodes = gc_execute(t, selpop_index, &loc1, &loc2);
	if (gc_nodes != NULL) {
	  /* Which segment carries old allele? */
	  if (loc1 <= sel_pos && loc2 >= sel_pos) {
	    old_allele_node = gc_nodes[2];
	    new_allele_node = gc_nodes[1];
	  }
	  else {
	    old_allele_node = gc_nodes[1];
	    new_allele_node = gc_nodes[2];
	  }

	  foundit = 0;
	  for (inode = 0; inode < nsel; inode++) {
	    if (sel_nodes[inode]->name == gc_nodes[0]->name) {
	      sel_nodes[inode] = old_allele_node;
	      foundit = 1;
	      break;
	    }
	  }
	  if (!foundit) {
	    for (inode = 0; inode < nunsel; inode++) {
	      if (unsel_nodes[inode]->name == gc_nodes[0]->name) {
		unsel_nodes[inode] = old_allele_node;
		foundit = 1;
		break;
	      }
	    }
	  }
	  if (foundit != 1) {fprintf(stderr, "Could not find %d\n", gc_nodes[0]->name);}
	  assert(foundit == 1);

	  /* Decide whether the chrom that we've recombined onto is selected or unselected */
	  if (random_double() < sel_freq) {
	    if (nsel == selsize) {
	      selsize *= 2;
	      sel_nodes = realloc(sel_nodes, selsize * sizeof(Node*));
	      assert(sel_nodes != NULL);
	    }
	    sel_nodes[nsel] = new_allele_node;
	    nsel++;
	  }
	  else {
	    if (nunsel == unselsize) {
	      unselsize *= 2;
	      unsel_nodes = realloc(unsel_nodes, unselsize * sizeof(Node*));
	      assert(unsel_nodes != NULL);
	    }
	    unsel_nodes[nunsel] = new_allele_node;
	    nunsel++; 
	  }
	  free(gc_nodes);
	}
      } 

      if (popptr->members.nummembers != nsel + nunsel) {
	printf("mismatch: pop(n): %d  here: %d\n", popptr->members.nummembers, nsel + nunsel);
      }
    }
    if (nsel == start_nchrom) {break;}
  }
  printf("selective sweep ends at %f generations\n", t);  
  return t;
}

int sw_coalesce(int nnode, Node **nodes, Pop *popptr, double t) {
  /* The difference between this routine and the one in demography.c is that here 
     I have to update both the main population node list and the local list that I'm using. */
  int node1index, node2index, inode, nodesatloc, contains;
  Node *node1, *node2, *newnode;
  struct sitelist * sitetemp = dg_recombsites();
  double loc;
  Segment *tseg;

  /* STEP 1 */
  node1index = (int) (random_double() * nnode);
  node2index = (int) (random_double() * (nnode - 1));
  if (node2index >= node1index) node2index++;
  node1 = nodes[node1index];
  node2 = nodes[node2index];

  /* STEP 2 */
  newnode = node_coalesce (node1, node2, t);
  
  /* STEP 2a */
  while (sitetemp->next != NULL) {
    loc = (sitetemp->site + sitetemp->next->site) / 2;
    nodesatloc = 0;
    tseg = node1->segs;
	  contains = 0;
	  while (tseg != NULL) {
	    if (loc >= tseg->begin) {
	      if (loc <= tseg->end) {
		contains = 1;
		break;
	      }
	      tseg = tseg->next;
	    }
	    else { 
	      contains = 0;
	      break;
	    }
	  }
	  if (contains) {
	    nodesatloc++;
	  }
	  tseg = node2->segs;
	  contains = 0;
	  while (tseg != NULL) {
	    if (loc >= tseg->begin) {
	      if (loc <= tseg->end) {
		contains = 1;
		break;
	      }
	      tseg = tseg->next;
	    }
	    else { 
	      contains = 0;
	      break;
	    }
	  }
	  if (contains) {
	    nodesatloc++;
	  }
	  if (nodesatloc > 1) {
	    sitetemp->nnode--;
	  }
	  sitetemp = sitetemp->next;
	}

  /* STEP 3 */
  pop_remove_node (popptr, node1);
  pop_remove_node (popptr, node2);
  if (node1index < node2index) {
    for (inode = node2index; inode < nnode-1; inode++) {nodes[inode] = nodes[inode+1];}
    nnode--;
    for (inode = node1index; inode < nnode-1; inode++) {nodes[inode] = nodes[inode+1];}
  }
  else {
    for (inode = node1index; inode < nnode-1; inode++) {nodes[inode] = nodes[inode+1];}
    nnode--;
    for (inode = node2index; inode < nnode-1; inode++) {nodes[inode] = nodes[inode+1];}
  }
  nnode--;
  
  /* STEP 4 */
  pop_add_node (popptr, newnode);
  nodes[nnode] = newnode;
  nnode++;

  /* STEP 5 */
  dg_log (COALESCE, t, node1, node2, newnode, popptr);
 
  return nnode;
}

/* Total coalescence rate over every population except sel_popname, which is the
   sweeping population whose coalescences sweep_execute simulates itself.
   A population with n >= 2 lineages and size N contributes n(n-1)/(4N).
   The arithmetic is done in double so that large n or N cannot overflow int. */
double sw_coalesce_get_rate (popid sel_popname) {
  int numpops = dg_get_num_pops();
  int i;
  double rate = 0;
  int numnodes;
  int popsize;

  for (i = 0; i < numpops; i++) {
    if (dg_get_pop_name_by_index(i) == sel_popname) {continue;}
    numnodes = dg_get_num_nodes_in_pop_by_index (i);
    popsize = dg_get_pop_size_by_index (i);
    if (numnodes > 1) {
      rate += (double) numnodes * (numnodes - 1) / (4.0 * popsize);
    }
  }
  return rate;
}

/* Total recombination rate over every population except sel_popname:
   each lineage contributes recomb_get_r(). */
double sw_recomb_get_rate (popid sel_popname) {
  int npops = dg_get_num_pops();
  double per_lineage = recomb_get_r();
  double total = 0;
  int idx;

  for (idx = 0; idx < npops; idx++) {
    if (dg_get_pop_name_by_index(idx) != sel_popname) {
      total += dg_get_num_nodes_in_pop_by_index(idx) * per_lineage;
    }
  }
  return total;
}

/* Total gene-conversion rate over every population except sel_popname:
   each lineage contributes gc_get_r(). */
double sw_gc_get_rate (popid sel_popname) {
  int npops = dg_get_num_pops();
  double per_lineage = gc_get_r();
  double total = 0;
  int idx = 0;

  while (idx < npops) {
    if (dg_get_pop_name_by_index(idx) != sel_popname) {
      total += dg_get_num_nodes_in_pop_by_index(idx) * per_lineage;
    }
    idx++;
  }
  return total;
}

/*
 * Pick a population other than sel_popname with probability proportional to
 * its coalescence rate n(n-1)/(4N), and return its index.
 *
 * Populations with fewer than two lineages are skipped, matching
 * sw_coalesce_get_rate.  Such a population can never be chosen, even when the
 * uniform draw is exactly 0, and a zero popsize cannot inject a 0/0 NaN.
 * The arithmetic is done in double so that large n or N cannot overflow int.
 */
int sw_coalesce_pick_popindex (popid sel_popname) {
  double  randcounter, sumprob, totprob, prob;
  int     popindex = -1,
    numpops = dg_get_num_pops(),
    numnodes, popsize, i;

  totprob = 0;
  for (i = 0; i < numpops; i++) {
    if (dg_get_pop_name_by_index(i) == sel_popname) {continue;}
    numnodes = dg_get_num_nodes_in_pop_by_index (i);
    popsize = dg_get_pop_size_by_index (i);
    if (numnodes > 1) {
      totprob += (double) numnodes * (numnodes - 1) / (4.0 * popsize);
    }
  }
  randcounter = random_double() * totprob;

  /* Walk the same cumulative sum.  The final sumprob equals totprob exactly,
     so some eligible population is always reached when totprob > 0. */
  sumprob = 0;
  for (i = 0; i < numpops; i++) {
    if (dg_get_pop_name_by_index(i) == sel_popname) {continue;}
    numnodes = dg_get_num_nodes_in_pop_by_index (i);
    popsize = dg_get_pop_size_by_index (i);
    if (numnodes < 2) {continue;}
    prob = (double) numnodes * (numnodes - 1) / (4.0 * popsize);
    sumprob += prob;
    if (randcounter <= sumprob) {
      popindex = i;
      break;
    }
  }
  assert(popindex >= 0);
  return popindex;
}

/*
 * Pick a population other than sel_popname with probability proportional to
 * its number of lineages, and return its index.
 *
 * k = floor(totnodes * U) is uniform on 0..totnodes-1.  The population whose
 * half-open cumulative range [before, before + n) contains k is returned, so
 * population i is chosen with probability n_i / totnodes.  Populations with no
 * lineages have an empty range and are never chosen.
 *
 * Rounding the draw instead, (int)(totnodes*U + .5), would give 0..totnodes
 * with the two end values at half weight.  That biases the first and last
 * populations and lets an empty leading population be picked.
 */
int sw_recomb_pick_popindex(popid sel_popname) {
  int popindex = -1, i, randcounter, sumnodes = 0;
  int totnodes = 0;
  int numpops = dg_get_num_pops();

  for (i = 0; i < numpops; i++) {
    if (dg_get_pop_name_by_index(i) == sel_popname) {continue;}
    totnodes += dg_get_num_nodes_in_pop_by_index (i);
  }
  randcounter = (int) (totnodes * random_double());
  /* Defensive clamp in case totnodes * U rounds up to totnodes. */
  if (totnodes > 0 && randcounter >= totnodes) {
    randcounter = totnodes - 1;
  }

  /* weight pops by numnodes. */
  for (i = 0; i < numpops; i++) {
    if (dg_get_pop_name_by_index(i) == sel_popname) {continue;}
    sumnodes += dg_get_num_nodes_in_pop_by_index(i);
    if (randcounter < sumnodes) {
      popindex = i;
      break;
    }
  }
  assert(popindex >= 0);
  return popindex;
}

/* Rates of the non-sweep Poisson process.  sw_get_poisson_rate() caches them
   here and sw_do_poisson() reads them.  migrate_rate is not assigned anywhere
   in the visible code (the migration calls are commented out), so it keeps
   its static initial value of 0. */
static double coalesce_rate, migrate_rate, recombination_rate,
              geneconv_rate, poisson_rate;

/* Refresh the cached per-event rates for every population except popname.
   Returns their sum, the total rate of the non-sweep Poisson process, which
   is also cached as poisson_rate for sw_do_poisson(). */
double sw_get_poisson_rate(popid popname) {
  double total;

  coalesce_rate = sw_coalesce_get_rate(popname);
  geneconv_rate = sw_gc_get_rate(popname);
  recombination_rate = sw_recomb_get_rate(popname);
  /* migrate_rate is not recomputed: migration is disabled during sweeps. */

  total = coalesce_rate + migrate_rate;
  total += recombination_rate;
  total += geneconv_rate;
  poisson_rate = total;
  return total;
}

/*
 * Execute one event of the non-sweep Poisson process at time `gen`, in some
 * population other than sel_popname.
 *
 * The event type is chosen using the rates last cached by sw_get_poisson_rate().
 * The migration branch is a no-op because migration is disabled here.
 *
 * recomb_execute and gc_execute return a heap array that the caller owns
 * (sweep_execute frees the same return values).  The pieces are not needed
 * here, so the arrays are freed immediately instead of being leaked.
 * free(NULL) is a no-op.
 */
void sw_do_poisson(popid sel_popname, double gen) {
  double randdouble = random_double();
  double dum, dum2;

  if (randdouble < recombination_rate / poisson_rate) {
    int recomb_pop = sw_recomb_pick_popindex(sel_popname);
    free(recomb_execute(gen, recomb_pop, &dum));
  }
  else if (randdouble < (recombination_rate + migrate_rate) / poisson_rate) {
    /* Migration disabled during the sweep (migrate_execute is not called). */
  }
  else if (randdouble < (recombination_rate + migrate_rate + coalesce_rate) / poisson_rate) {
    int popindex = sw_coalesce_pick_popindex(sel_popname);
    dg_coalesce_by_index (popindex, gen);
  }
  else {
    /* Gene conversion: target population weighted by lineage count. */
    int popindex = sw_recomb_pick_popindex(sel_popname);
    free(gc_execute(gen, popindex, &dum, &dum2));
  }
}
