
// Rate4site - functions are now in rate4siteGL.cpp
/********************************************************************************************
*********************************************************************************************/
void gainLoss::startRate4Site(){

	LOGnOUT(4,<<"Starting rate4site..."<<endl);
	computeRate4site();
	computeAveAndStd(); // put them in _ave, and _std

	ofstream nonNormalizedOutStream(gainLossOptions::_outFileNotNormalize.c_str());
	printRates(nonNormalizedOutStream,_rates);
	nonNormalizedOutStream.close();

	normalizeRates(); // change also the _ave, the _std the quantiles, etc.

	ofstream normalizedOutStream(gainLossOptions::_outFile.c_str());
	printRates(normalizedOutStream,_normalizedRates);
	normalizedOutStream.close();	
}
/********************************************************************************************
computeRate4site
Estimates the per-site rates into _rates with the method selected by
gainLossOptions::_rateEstimationMethod, logs the elapsed wall-clock time (in minutes) and
returns a copy of _rates.
  ebExp  - sizes _LpostPerCat as [rate category][site] and hands it, together with the
           Bayesian std / bound vectors, to computeEB_EXP_siteSpecificRate.
  mlRate - calls computeML_siteSpecificRate with gainLossOptions::_maxRateForML.
Any other method is reported through errorMsg::reportError.
*********************************************************************************************/
Vdouble  gainLoss::computeRate4site()
{
	time_t startTime;
	time(&startTime);

	switch (gainLossOptions::_rateEstimationMethod) {
		case (gainLossOptions::ebExp): {
			LOGnOUT (4,<<"perform computeEB_EXP_siteSpecificRate... while computing posteriorProb PerCategory PerPosition"<<endl);
			// one row per rate category, one column per site
			_LpostPerCat.resize(_sp->categories());
			for (int category=0 ; category<_sp->categories(); ++category){
				_LpostPerCat[category].resize(_sc.seqLen());
			}
			computeEB_EXP_siteSpecificRate(_rates,_BayesianSTD,_BayesianLowerBound,_BayesianUpperBound,_sc,*_sp,_tr,_alphaConf,&_LpostPerCat);
			break;
		}
		case (gainLossOptions::mlRate): {
			LOGnOUT (4,<<"perform computeML_siteSpecificRate with maxRate= "<<gainLossOptions::_maxRateForML<<endl);
			computeML_siteSpecificRate(_rates,_Lrate,_sc, *_sp,_tr, gainLossOptions::_maxRateForML);
			break;
		}
		default:
			errorMsg::reportError("non such method for rate inference, in function void rate4site::computeRate4site()");
	}

	time_t endTime;
	time(&endTime);
	LOGnOUT(4,<<endl<<"computeRate4site RUNNING TIME = "<<(endTime-startTime)/60.0<<" minutes"<<endl);
	return _rates;
}
/********************************************************************************************
printRates
Writes a complete rate report to `out`:
  1. the mixture parameters, when gainLossOptions::_rateDistributionType is GAMMA_MIXTURE,
  2. the per-site table of rate2print, in the layout matching
     gainLossOptions::_rateEstimationMethod (ebExp -> printRatesBayes, mlRate -> printRatesML;
     nothing is written for any other method),
  3. the average and standard deviation (_ave, _std).
*********************************************************************************************/
void gainLoss::printRates(ostream & out, const Vdouble & rate2print) {
	if (gainLossOptions::_rateDistributionType == gainLossOptions::GAMMA_MIXTURE){
		// the distribution of *_sp is treated as a mixtureDistribution in this mode
		static_cast<mixtureDistribution*>(_sp->distr())->printParams(out);
	}

	if (gainLossOptions::_rateEstimationMethod == gainLossOptions::ebExp) {
		printRatesBayes(out,rate2print);
	}
	else if (gainLossOptions::_rateEstimationMethod == gainLossOptions::mlRate) {
		printRatesML(out,rate2print);
	}

	printAveAndStd(out);
}
/********************************************************************************************
printRatesML
Writes the maximum-likelihood rate table to `out`: '#' comment lines describing the columns,
a column header, and one row per site holding the 1-based position, the reference-sequence
character, rate2print[site] (precision 7) and the number of non-gapped sequences out of the
total number of sequences.
When the `unix` macro is not defined the rows are written with additional left / right /
fixed stream manipulators.
*********************************************************************************************/
void gainLoss::printRatesML(ostream& out, const Vdouble & rate2print) {
	const char* const tableRule = "=========================================================================================================================================================";
	const int numSites = _sc.seqLen();

	out<<"#Rates were calculated using Maximum Likelihood"<<endl;
	out<<"#SEQ: The presence(1) or Absence(0) in the reference sequence."<<"Displayed on sequence "<<_refSeq->name()<<endl;
	out<<"#SCORE: The conservation scores. lower value = higher conservation."<<endl;
	out<<"#MSA DATA: The number of aligned sequences having an amino acid (non-gapped) from the overall number of sequences at each position."<<endl;
	out<<endl;
	out<<tableRule<<endl;
	out<<"#POS"<<"\t"<<"SEQ"<<"\t"<<"SCORE"<<"\t"<<"MSA DATA"<<endl; // positions are reported starting from 1
	out<<tableRule<<endl;

#ifdef unix
	for (int site=0; site < numSites; ++site) {
		out<<site+1<<"\t";
		out<<_refSeq->getAlphabet()->fromInt((*_refSeq)[site])<<"\t";
		out<<setprecision(7)<<rate2print[site]<<"\t";
		out<<_sc.numberOfSequencesWithoutGaps(site)<<"/"<<_sc.numberOfSeqs()<<endl;
	}
#else
	for (int site=0; site < numSites; ++site) {
		out<<left<<site+1;
		out<<left<<"\t"<<_refSeq->getAlphabet()->fromInt((*_refSeq)[site])<<"\t";
		out<<left<<setprecision(7)<<fixed<<rate2print[site]<<"\t";
		out<<right<<_sc.numberOfSequencesWithoutGaps(site)<<"/"<<_sc.numberOfSeqs()<<endl;
	}
#endif
}
/********************************************************************************************
printRatesBayes
Writes the empirical-Bayes rate table to `out`:
  - '#' comment lines describing the columns, the alpha parameter of the rate distribution
    and the rate / probability of every rate category of *_sp,
  - a column header,
  - one row per site: 1-based position, reference-sequence character, rate2print[pos],
    [_BayesianLowerBound[pos],_BayesianUpperBound[pos]], _BayesianSTD[pos] and the number of
    non-gapped sequences out of the total number of sequences.
The alpha parameter is read from the distribution of *_sp when it is one of the three
gamma-like types probed below. For any other distribution type "NA" is written instead of a
number; previously an uninitialized variable was printed in that case.
When the `unix` macro is not defined the rows are written with additional left / right /
fixed stream manipulators.
*********************************************************************************************/
void gainLoss::printRatesBayes(ostream& out, const Vdouble & rate2print) {
	out<<"# Rates were calculated using the expectation of the posterior rate distribution"<<endl;
	out<<"# Prior distribution is Gamma with "<<gainLossOptions::_numberOfRateCategories<<" discrete categories"<<endl;
	out<<"# SEQ: The presence(1) or Absence(0) in the reference sequence."<<"Displayed on sequence "<<_refSeq->name()<<endl;
	out<<"# SCORE: The conservation scores. lower value = higher conservation."<<endl;
	out<<"# QQ-INTERVAL: the confidence interval for the rate estimates. The default interval is 25-75 percentiles"<<endl;
	out<<"# STD: the standard deviation of the posterior rate distribution."<<endl;
	out<<"# MSA DATA: The number of aligned sequences having an amino acid (non-gapped) from the overall number of sequences at each position."<<endl;

	// Look for a distribution type that exposes getAlpha(). The three probes stay independent
	// (not else-if) so that, exactly as before, the last matching type determines the value.
	MDOUBLE AlphaRate = 0.0;
	bool alphaIsKnown = false;
	if (gammaDistribution* gammaDist = dynamic_cast<gammaDistribution*>(_sp->distr())) {
		AlphaRate = gammaDist->getAlpha();
		alphaIsKnown = true;
	}
	if (generalGammaDistributionPlusInvariant* gammaPlusInvDist = dynamic_cast<generalGammaDistributionPlusInvariant*>(_sp->distr())) {
		AlphaRate = gammaPlusInvDist->getAlpha();
		alphaIsKnown = true;
	}
	if (gammaDistributionFixedCategories* gammaFixedCatDist = dynamic_cast<gammaDistributionFixedCategories*>(_sp->distr())) {
		AlphaRate = gammaFixedCatDist->getAlpha();
		alphaIsKnown = true;
	}
	out<<"# The alpha parameter ";
	if (alphaIsKnown) {
		out<<AlphaRate;
	}
	else {
		out<<"NA";	// no alpha available for this distribution type
	}
	out<<endl;

	// rate and probability of every rate category
	for (int k=0; k < _sp->categories(); ++k){
		out<<"# sp.rates(j)  j= " <<k<<"\t"<<_sp->rates(k)<<"\t"<<_sp->ratesProb(k)<<endl;
	}

	out<<endl;
	out<<"========================================================================================================================================================="<<endl;
	out<<"#POS"<<"\t"<<"SEQ"<<"\t"<<"SCORE"<<"\t"<<"QQ-INTERVAL"<<"\t"<<"STD"<<"\t"<<"MSA DATA"<<endl; // note position start from 1.
	out<<"========================================================================================================================================================="<<endl;

#ifdef unix
	for (int pos=0; pos < _sc.seqLen(); ++pos) {
		out<<pos+1<<"\t"<<_refSeq->getAlphabet()->fromInt((*_refSeq)[pos])<<"\t"<<setprecision(7)<<rate2print[pos]<<"\t";
		out<<"["<<setprecision(4)<<_BayesianLowerBound[pos]<<","<<setprecision(4)<<_BayesianUpperBound[pos]<<"]"<<"\t";
		out<<setprecision(4)<<_BayesianSTD[pos]<<"\t";
		out<<_sc.numberOfSequencesWithoutGaps(pos)<<"/"<<_sc.numberOfSeqs()<<endl; // note position start from 1.
	}
#else
	for (int pos=0; pos < _sc.seqLen(); ++pos) {
		out<<left<<pos+1;
		out<<left<<"\t"<<_refSeq->getAlphabet()->fromInt((*_refSeq)[pos])<<"\t";
		out<<left<<setprecision(7)<<fixed<<rate2print[pos]<<"\t";
		out<<right<<"["<<setprecision(4)<<left<<_BayesianLowerBound[pos]<<","<<setprecision(4)<<right<<_BayesianUpperBound[pos]<<"]"<<"\t";
		out<<right<<setprecision(4)<<_BayesianSTD[pos];
		out<<right<<"\t"<<_sc.numberOfSequencesWithoutGaps(pos)<<"/"<<_sc.numberOfSeqs()<<endl; // note position start from 1.
	}
#endif
}
/********************************************************************************************
printAveAndStd
Appends two '#' summary lines to `out`: the current _ave and the current _std.
*********************************************************************************************/
void gainLoss::printAveAndStd(ostream& out) {
	out<<"#Average = "<<_ave<<endl
	   <<"#Standard Deviation = "<<_std<<endl;
}
/********************************************************************************************
computeAveAndStd
*********************************************************************************************/
void gainLoss::computeAveAndStd(){
	MDOUBLE sum = 0;
	MDOUBLE sumSqr=0.0;
	for (int i=0; i < _sc.seqLen(); ++i) {
		sum+=_rates[i];
		sumSqr+=(_rates[i]*_rates[i]);
	}
	_ave = sum/_sc.seqLen();
	_std= sumSqr-(sum*sum/_sc.seqLen());
	_std /= (_sc.seqLen()-1.0);
	_std = sqrt(_std);
	if (((_ave<1e-9)) && (_ave>(-(1e-9)))) _ave=0;
	if ((_std>(1-(1e-9))) && (_std< (1.0+(1e-9)))) _std=1.0;
}
/********************************************************************************************
normalizeRates
Fills _normalizedRates with the z-scores (_rates[i]-_ave)/_std, one entry per site.
For the ebExp estimation method the Bayesian lower / upper bounds are shifted and scaled the
same way, and the Bayesian std is divided by _std.
Afterwards _ave and _std are reset to 0 and 1, i.e. they describe the normalized rates.
A zero _std is reported through errorMsg::reportError.
*********************************************************************************************/
void gainLoss::normalizeRates() {
	if (_std==0) errorMsg::reportError(" std = 0 in function normalizeRates",1);

	const int numSites = _sc.seqLen();

	_normalizedRates.resize(numSites,0.0);
	for (int site=0; site<numSites; ++site) {
		_normalizedRates[site]=(_rates[site]-_ave)/_std;
	}

	const bool hasBayesianEstimates = (gainLossOptions::_rateEstimationMethod == gainLossOptions::ebExp);
	if (hasBayesianEstimates) {
		for (int site=0; site<numSites; ++site) {
			_BayesianLowerBound[site] = (_BayesianLowerBound[site] - _ave)/_std;
			_BayesianUpperBound[site] = (_BayesianUpperBound[site] - _ave)/_std;
			_BayesianSTD[site] = (_BayesianSTD[site])/_std;
		}
	}

	// from here on the stored statistics refer to the normalized rates
	_ave = 0.0;
	_std = 1.0;
}






// gainLoss4site - functions are now in gainLoss4siteGL.cpp
/********************************************************************************************
Computes the Empirical Bayesian expectation of the posterior
estimators for GL values
*********************************************************************************************/
void gainLoss::computeEB_EXP_GL4Site()
{
	time_t t1;
	time(&t1);
	time_t t2;
	LOGnOUT(4,<<"Starting gain4site and loss4site..."<<endl);

	Vdouble  gainV,stdGainV,lowerBoundGainV,upperBoundGainV;							  
	VVdouble posteriorsGainV;	
	computeEB_EXP_siteSpecificGL(gainV, stdGainV, lowerBoundGainV, upperBoundGainV, posteriorsGainV, _sc, _spVVec, _tr, _gainDist,_lossDist,_gainDist);

	ofstream outGain(gainLossOptions::_outFileGain4Site.c_str());
	printGainLossBayes(outGain,gainV,lowerBoundGainV,upperBoundGainV, posteriorsGainV, _gainDist);
	outGain.close();

	Vdouble  lossV,stdLossV,lowerBoundLossV,upperBoundLossV;							  
	VVdouble posteriorsLossV;	
	computeEB_EXP_siteSpecificGL(lossV, stdLossV, lowerBoundLossV, upperBoundLossV, posteriorsLossV, _sc, _spVVec, _tr, _gainDist,_lossDist,_lossDist);

	ofstream outLoss(gainLossOptions::_outFileLoss4Site.c_str());
	printGainLossBayes(outLoss,lossV,lowerBoundLossV,upperBoundLossV, posteriorsLossV, _lossDist);
	outLoss.close();

	time(&t2);
	LOGnOUT(4,<<endl<<"computeEB_EXP_GL4Site RUNNING TIME = "<<(t2-t1)/60.0<<" minutes"<<endl);
}
/********************************************************************************************
printGainLossBayes
Writes a per-site table to `out` with stream precision 7:
  - two '#' title lines (the second names the reference sequence) and a rule,
  - a header row whose trailing columns are the rate of every category of `dist`,
  - one row per site: 1-based position, reference-sequence character, rate2printV[i],
    [lowerBoundV[i],upperBoundV[i]] and posteriorV[i][cat] for every category of `dist`.
*********************************************************************************************/
void gainLoss::printGainLossBayes(ostream& out, const Vdouble& rate2printV, const Vdouble& lowerBoundV, const Vdouble& upperBoundV,const VVdouble& posteriorV, const distribution* dist) 
{
	const int numOfCategories = dist->categories();

	out.precision(7);
	out<<"#gainLoss Bayesian Results"<<endl;
	out<<"#Displayed on sequence "<<_refSeq->name()<<endl;
	out<<"========================================================================================================================================================="<<endl;

	// header row: fixed columns followed by the rate of each category
	out<<"POS\t"<<"1/0\t"<<"Rate\t"<<"[Confidence Interval]\t";
	for (int category = 0; category < numOfCategories; ++category) {
		out<<dist->rates(category)<<"\t";
	}
	out<<endl;

	// one row per site
	for (int site=0; site<_sc.seqLen(); ++site){
		const string refChar = _refSeq->getAlphabet()->fromInt((*_refSeq)[site]);
		out<<site+1<<"\t"<<refChar<<"\t"<<rate2printV[site]<<"\t";
		out<<"["<<lowerBoundV[site]<<","<<upperBoundV[site]<<"]\t";
		for (int category = 0; category < numOfCategories; ++category) {
			out<<posteriorV[site][category]<<"\t";
		}
		out<<endl;
	}
}

// comupreCounts - functions are now in comupreCounts.cpp
/********************************************************************************************
*********************************************************************************************/
void gainLoss::computePosteriorExpectationOfChangePerSite(){
	LOGnOUT(4,<<"Starting calculePosteriorExpectationOfChange..."<<endl);
	if(_LpostPerCat.size()==0){	// to fill _LpostPerCat - run computeRate4site()
		computeRate4site();
	}
	Vdouble expV01(_sc.seqLen(),0.0);
	Vdouble expV10(_sc.seqLen(),0.0);		
	computePosteriorExpectationOfChangePerSite(expV01, expV10);	//		

	// printOut the final results
	ofstream posteriorExpectationStream(gainLossOptions::_outFilePosteriorExpectationOfChange.c_str());
	posteriorExpectationStream<<"POS"<<"\t"<<"exp01"<<"\t"<<"exp10"<<endl;
	for (int pos = 0; pos <_sc.seqLen(); ++pos){
		posteriorExpectationStream<<pos+1<<"\t"<<expV01[pos]<<"\t"<<expV10[pos]<<endl;
	}
}
/********************************************************************************************
computePosteriorExpectationOfChangePerSite
Accumulates into expV01 / expV10 the per-site expectation of 0->1 and 1->0 changes, summed
over the rate categories of *_sp and weighted by _LpostPerCat[rateIndex][pos].
expV01 and expV10 are indexed by site and are added to (+=), so the caller is expected to
pass vectors with at least _sc.seqLen() entries that already hold the desired start value.
In addition the function writes, under gainLossOptions::_outDir:
  gainLossProbabilityPerPosPerBranch.txt and gainLossCountProbPerPos.txt
        (via printGainLossProbabilityPerPosPerBranch, one call per site),
  gainLossExpectationPerBranch.txt
        (via printGainLossExpectationPerBranch, values summed over all sites),
  TreesWithProbabilityValuesAsBP/probTree<pos>.ph
        (only if gainLossOptions::_printTreesWithProbabilityValuesAsBP),
  TreesWithExpectationValuesAsBP/expTree<pos>.ph
        (only if gainLossOptions::_printTreesWithExpectationValuesAsBP).
*********************************************************************************************/
void gainLoss::computePosteriorExpectationOfChangePerSite(Vdouble& expV01, Vdouble& expV10){
	VVVVdouble posteriorsGivenTerminals;	// posteriorsGivenTerminals[pos][nodeID][x][y]
	VVVVdouble probChangesForBranch;		// probChangesForBranch[pos][nodeID][x][y]
	// NOTE(review): the += accumulation below presumably relies on resizeVVVV zero-filling
	// the new entries - confirm against resizeVVVV.
	resizeVVVV(_sc.seqLen(),_tr.getNodesNum(),_sp->alphabetSize(),_sp->alphabetSize(),posteriorsGivenTerminals);
	resizeVVVV(_sc.seqLen(),_tr.getNodesNum(),_sp->alphabetSize(),_sp->alphabetSize(),probChangesForBranch);


	// Per RateCategory -- all the computations are done while looping over the rate categories
	for (int rateIndex=0 ; rateIndex< _sp->categories(); ++rateIndex)
	{
		// work on a copy of the tree whose branches are scaled by the rate of this category;
		// rates below minimumRate are replaced by minimumRate (and a note is logged)
		tree copy_et = _tr;
		MDOUBLE rateVal = _sp->rates(rateIndex);
		MDOUBLE  minimumRate = 0.0000001;
		MDOUBLE rate2multiply = max(rateVal,minimumRate);
		if(rateVal<minimumRate){
			LOGnOUT(4, <<" >>> NOTE: the rate category "<<rateVal<<" is too low for computePosteriorExpectationOfChangePerSite"<<endl);	}
		copy_et.multipleAllBranchesByFactor(rate2multiply);

		// one simulation run per rate category, shared by all sites below
		LOGnOUT(4, <<"running "<<gainLossOptions::_numOfSimulationsForPotExp<<" simulations for rate "<<rate2multiply<<endl);
		gainLossAlphabet alph;	// needed for Alphabet size
		simulateJumps simPerRateCategory(copy_et,*_sp,&alph);	
		simPerRateCategory.runSimulation(gainLossOptions::_numOfSimulationsForPotExp);
		LOGnOUT(4,<<"finished simulations"<<endl);

		// Per POS		
		for (int pos = 0; pos <_sc.seqLen(); ++pos)
		{
			LOG(6,<<"pos "<<pos+1<<endl);
			// results for this (rate category, site) pair; indexed [nodeID][x][y] below
			VVVdouble posteriorsGivenTerminalsPerRateCategoryPerPos;
			VVVdouble probChangesForBranchPerRateCategoryPerPos;
			computePosteriorExpectationOfChange cpecPerRateCategoryPerPos(copy_et,_sc,_sp);	// Per POS,CAT
			cpecPerRateCategoryPerPos.computePosteriorOfChangeGivenTerminals(posteriorsGivenTerminalsPerRateCategoryPerPos,pos);
			MDOUBLE exp01 = cpecPerRateCategoryPerPos.computeExpectationOfChangeAcrossTree(simPerRateCategory,posteriorsGivenTerminalsPerRateCategoryPerPos,0,1);	// Per POS
			MDOUBLE exp10 = cpecPerRateCategoryPerPos.computeExpectationOfChangeAcrossTree(simPerRateCategory,posteriorsGivenTerminalsPerRateCategoryPerPos,1,0);	// Per POS
			// weight the expectations of this category by _LpostPerCat[rateIndex][pos]
			expV01[pos]+=exp01*_LpostPerCat[rateIndex][pos];
			expV10[pos]+=exp10*_LpostPerCat[rateIndex][pos];

			cpecPerRateCategoryPerPos.computePosteriorAcrossTree(simPerRateCategory,posteriorsGivenTerminalsPerRateCategoryPerPos,probChangesForBranchPerRateCategoryPerPos);

			//	Store all information PerCat,PerPOS - the same weight is applied per node and state pair
			for(int i=0;i<posteriorsGivenTerminals[pos].size();++i){
				for(int j=0;j<posteriorsGivenTerminals[pos][i].size();++j){
					for(int k=0;k<posteriorsGivenTerminals[pos][i][j].size();++k){
						posteriorsGivenTerminals[pos][i][j][k] += posteriorsGivenTerminalsPerRateCategoryPerPos[i][j][k]*_LpostPerCat[rateIndex][pos];
						probChangesForBranch[pos][i][j][k] += probChangesForBranchPerRateCategoryPerPos[i][j][k]*_LpostPerCat[rateIndex][pos];
					}
				}
			}
		}
	}
	// end of rateCategories loop
	//////////////////////////////////////////////////////////////////////////


	// ProbabilityPerPosPerBranch - two tab-separated files, filled one site at a time
	string gainLossProbabilityPerPosPerBranch = gainLossOptions::_outDir + "//" + "gainLossProbabilityPerPosPerBranch.txt"; 
	ofstream gainLossProbabilityPerPosPerBranchStream(gainLossProbabilityPerPosPerBranch.c_str());
	gainLossProbabilityPerPosPerBranchStream<<"G/L"<<"\t"<<"POS"<<"\t"<<"branch"<<"\t"<<"branchLength"<<"\t"<<"distance2root"<<"\t"<<"probability"<<endl;
	string gainLossCountProbPerPos = gainLossOptions::_outDir + "//" + "gainLossCountProbPerPos.txt"; 
	ofstream gainLossCountProbPerPosStream(gainLossCountProbPerPos.c_str());
	gainLossCountProbPerPosStream<<"POS"<<"\t"<<"count01"<<"\t"<<"count10"<<endl;	
	for (int pos = 0; pos <_sc.seqLen(); ++pos){
		// pos is handed over as the 0-based site index
		printGainLossProbabilityPerPosPerBranch(pos, gainLossOptions::_probCutOff, probChangesForBranch[pos],gainLossProbabilityPerPosPerBranchStream,gainLossCountProbPerPosStream);
	}

	// ExpectationPerBranch - sum posteriorsGivenTerminals over all sites, per node and state pair
	VVVdouble posteriorsGivenTerminalsTotal;
	resizeVVV(_tr.getNodesNum(),_sp->alphabetSize(),_sp->alphabetSize(),posteriorsGivenTerminalsTotal);
	for (int pos = 0; pos <_sc.seqLen(); ++pos){
		for(int i=0;i<posteriorsGivenTerminals[pos].size();++i){
			for(int j=0;j<posteriorsGivenTerminals[pos][i].size();++j){
				for(int k=0;k<posteriorsGivenTerminals[pos][i][j].size();++k){
					posteriorsGivenTerminalsTotal[i][j][k] += posteriorsGivenTerminals[pos][i][j][k];
				}
			}
		}
	}
	string gainLossExpectationPerBranch = gainLossOptions::_outDir + "//" + "gainLossExpectationPerBranch.txt"; 
	ofstream gainLossExpectationPerBranchStream(gainLossExpectationPerBranch.c_str());
	printGainLossExpectationPerBranch(posteriorsGivenTerminalsTotal,gainLossExpectationPerBranchStream);


	// ProbabilityPerPosPerBranch - Print Trees (one file per site, named by the 1-based position)
	Vstring Vnames;
	fillVnames(Vnames,_tr);
	if(gainLossOptions::_printTreesWithProbabilityValuesAsBP){
		createDir(gainLossOptions::_outDir, "TreesWithProbabilityValuesAsBP");
		for (int pos = 0; pos <_sc.seqLen(); ++pos){
			string strTreeNum = gainLossOptions::_outDir + "//" + "TreesWithProbabilityValuesAsBP"+ "//" + "probTree" + int2string(pos+1) + ".ph";
			ofstream tree_out(strTreeNum.c_str());
			printTreeWithValuesAsBP(tree_out,_tr,Vnames,&probChangesForBranch[pos]);
		}
	}
	// ExpectationPerPosPerBranch - Print Trees (one file per site, named by the 1-based position)
	if(gainLossOptions::_printTreesWithExpectationValuesAsBP){
		createDir(gainLossOptions::_outDir, "TreesWithExpectationValuesAsBP");
		for (int pos = 0; pos <_sc.seqLen(); ++pos){
			string strTreeNum = gainLossOptions::_outDir + "//" + "TreesWithExpectationValuesAsBP" + "//" + "expTree" + int2string(pos+1) + ".ph";
			ofstream tree_out(strTreeNum.c_str());
			printTreeWithValuesAsBP(tree_out,_tr,Vnames,&posteriorsGivenTerminals[pos]);
		}
	}
}
/********************************************************************************************
*********************************************************************************************/
void gainLoss::printGainLossProbabilityPerPosPerBranch(int pos, MDOUBLE probCutOff, VVVdouble& probChanges, ostream& out, ostream& outCount)
{
	MDOUBLE count01 =0;
	MDOUBLE count10 =0;
	treeIterTopDownConst tIt(_tr);
	for (tree::nodeP mynode = tIt.first(); mynode != tIt.end(); mynode = tIt.next()) {
		if (probChanges[mynode->id()][0][1] > probCutOff){
			out<<"gain"<<"\t"<<pos<<"\t"<<mynode->name()<<"\t"<<mynode->dis2father()<<"\t"<<getDistance2ROOT(mynode)<<"\t"<<probChanges[mynode->id()][0][1]<<endl;
			count01+= probChanges[mynode->id()][0][1];
		}
		//}
		//for (tree::nodeP mynode = tIt.first(); mynode != tIt.end(); mynode = tIt.next()) {
		if (probChanges[mynode->id()][1][0] > probCutOff){
			out<<"loss"<<"\t"<<pos<<"\t"<<mynode->name()<<"\t"<<mynode->dis2father()<<"\t"<<getDistance2ROOT(mynode)<<"\t"<<probChanges[mynode->id()][1][0]<<endl;
			count10+= probChanges[mynode->id()][1][0];
		}
	}
	outCount<<pos<<"\t"<<count01<<"\t"<<count10<<endl;
}
/********************************************************************************************
printGainLossExpectationPerBranch
Writes two tables to `out`, "# Gain" and then "# Loss". Each has a header row and one row per
node of _tr (top-down iteration): node name, dis2father, distance to root and
expectChanges[nodeID][0][1] (gain) or expectChanges[nodeID][1][0] (loss).
*********************************************************************************************/
void gainLoss::printGainLossExpectationPerBranch(VVVdouble& expectChanges, ostream& out)
{
	// one table per direction of change: title, state before, state after
	const char* const sectionTitle[2] = {"# Gain", "# Loss"};
	const int fromState[2] = {0, 1};
	const int toState[2] = {1, 0};

	treeIterTopDownConst tIt(_tr);
	for (int section=0; section<2; ++section) {
		out<<sectionTitle[section]<<"\n";
		out<<"branch"<<"\t"<<"branchLength"<<"\t"<<"distance2root"<<"\t"<<"expectation"<<endl;
		for (tree::nodeP mynode = tIt.first(); mynode != tIt.end(); mynode = tIt.next()) {
			out<<mynode->name()<<"\t"<<mynode->dis2father()<<"\t"<<getDistance2ROOT(mynode)<<"\t";
			out<<expectChanges[mynode->id()][fromState[section]][toState[section]]<<endl;
		}
	}
}
