#include "mDefinitions.h"
#include "mCommon.h"
#include "zoeTools.h"
#include "mMatrix.h"
#include <argtable2.h>

#define FORMAT "%lf"
#define EUCLIDEAN  (1)
#define JSD        (2)
#define KLD        (3)
#define NONE      (-1)

#define ROWS       (1)
#define ELEMENTS   (2)

/*
 * Write a tab-separated table of COG counts to `stream`.
 *
 * Rows are driven by `names`, whose keys are visited in sorted order. As
 * called from main2 this is species name -> species id string. Each row is
 * labelled with the looked-up id string. Its counts come from `id_names`
 * (id string -> double[dimensionality]) and are truncated to int.
 *
 * Only columns c whose var->elem[c][0] exceeds the hard-coded cutoff 960.0
 * are written. The header is "cog_0" followed by "\tcog_<c+1>" for each
 * retained column.
 */
void mPrintDist(FILE *stream, zoeHash id_names, zoeHash names, int dimensionality, mMatrix *var) {
	const double cutoff = 960.0;
	zoeTVec sorted_keys = zoeKeysOfHash(names);
	int row, c;

	qsort(sorted_keys->elem, sorted_keys->size, sizeof(char*), zoeTcmp);

	/* header line: fixed leading label, then one label per retained column */
	fputs("cog_0", stream);
	for (c = 0; c < dimensionality; c++) {
		if (!(var->elem[c][0] > cutoff))
			continue;
		fprintf(stream, "\tcog_%d", c + 1);
	}
	fputc('\n', stream);

	/* one line per key of `names`, in sorted key order */
	for (row = 0; row < sorted_keys->size; row++) {
		char   *label  = (char*) zoeGetHash(names, sorted_keys->elem[row]);
		double *counts = (double*) zoeGetHash(id_names, label);

		fprintf(stream, "%s", label);
		for (c = 0; c < dimensionality; c++) {
			if (!(var->elem[c][0] > cutoff))
				continue;
			fprintf(stream, "\t%d", (int) counts[c]);
		}
		fputc('\n', stream);
	}
}

int main(int argc, char *argv[]) {
	int i, j, iter;
	mMatrix *m, **col;
	mVector  *labels   = (mVector*) mMalloc(sizeof(mVector));
	FILE *fp;

	int distance_metric;
	int features;
	int samples;
	int rownorm;
	int colnorm;
	int noise_removal;
	int bootstrapcount;
	int bootstraprand;
	int bootstraptype;

	int              argcount = 0;
	int              nerrors;
	void           **argtable;

	struct arg_int  *arg_features;
	struct arg_int  *arg_samples;
	struct arg_str  *arg_distance;
	struct arg_str  *arg_bootstrap_type;
	struct arg_int  *arg_bootstrap_count;
	struct arg_lit  *arg_bootstrap_rand;
	struct arg_lit  *arg_row_norm;
	struct arg_lit  *arg_col_norm;
	struct arg_lit  *arg_noise_removal;
	struct arg_file *arg_file;
	struct arg_lit  *help;
	struct arg_end  *end;
	arg_features        = arg_int1("f",  "features",        "<n>",         "number of features (required)");
	arg_samples         = arg_int1("s",  "samples",         "<n>",         "number of samples (required)");
	arg_distance        = arg_str1("d",  "distance",        "<name>",      "distance measure: euclidean, kld (Kullback-Leibler) or jsd (Jensen-Shannon) (required)");
	arg_bootstrap_type  = arg_str1("t",  "bootstrap_type",  "<type>",      "bootstrap type: rows or elements (required)");
	arg_row_norm        = arg_lit0("r",  "rownorm",                        "perform row normalization (default is false)");
	arg_col_norm        = arg_lit0("c",  "colnorm",                        "perform column normalization (default is false)");
	arg_noise_removal   = arg_lit0(NULL, "noise",                          "remove very low values (noise) from the matrix (default is false)");
	arg_bootstrap_rand  = arg_lit0(NULL, "randomize",                      "perform randomized bootstrap (default is false)");
	arg_bootstrap_count = arg_int0("n",  "bootstrap_count", "<n>",         "number of replicates for bootstrap (default:100)");
	arg_file            = arg_file1(NULL, NULL,             "<file>",      "file with feature vs sample matrix");
	help                = arg_lit0( "h", "help",                           "print this help and exit");
	end                 = arg_end(8); /* this needs to be even, otherwise each element in end->parent[] crosses an 8-byte boundary */

	argtable          = (void**) mMalloc(12*sizeof(void*));

	argtable[argcount++] = arg_features;
	argtable[argcount++] = arg_samples;
	argtable[argcount++] = arg_distance;
	argtable[argcount++] = arg_row_norm;
	argtable[argcount++] = arg_col_norm;
	argtable[argcount++] = arg_noise_removal;
	argtable[argcount++] = arg_bootstrap_type;
	argtable[argcount++] = arg_bootstrap_rand;
	argtable[argcount++] = arg_bootstrap_count;
	argtable[argcount++] = arg_file;
	argtable[argcount++] = help;
	argtable[argcount++] = end;

	arg_bootstrap_count->ival[0] = 100;

	if (arg_nullcheck(argtable) != 0) {
		mDie("insufficient memory");
	}
	nerrors = arg_parse(argc, argv, argtable);

	if (help->count > 0) {
		fprintf(stdout, "Usage: distance_matrix");
		arg_print_syntax(stdout, argtable, "\n");
		arg_print_glossary(stdout, argtable, "  %-25s %s\n");
		mQuit("");
	}

	if (nerrors > 0) {
		arg_print_errors(stderr, end, "distance_matrix");
		fprintf(stderr, "try using -h\n");
		mQuit("");
	}

	features       = arg_features->ival[0];
	samples        = arg_samples->ival[0];
	rownorm        = arg_row_norm->count;
	colnorm        = arg_col_norm->count;
	noise_removal  = arg_noise_removal->count;
	bootstraprand  = arg_bootstrap_rand->count;
	bootstrapcount = arg_bootstrap_count->ival[0];
	if (strcmp(arg_distance->sval[0], "euclidean") == 0) {
		distance_metric = EUCLIDEAN;
	} else if (strcmp(arg_distance->sval[0], "jsd") == 0) {
		distance_metric = JSD;
	} else if (strcmp(arg_distance->sval[0], "kld") == 0) {
		distance_metric = KLD;
	} else {
		distance_metric = NONE;
		fprintf(stderr, "invalid value for --distance");
		fprintf(stderr, "try using -h\n");
		mQuit("");
	}
	if (strcmp(arg_bootstrap_type->sval[0], "rows") == 0) {
		bootstraptype = ROWS;
	} else if (strcmp(arg_bootstrap_type->sval[0], "elements") == 0) {
		bootstraptype = ELEMENTS;
	} else {
		bootstraptype = NONE;
		fprintf(stderr, "invalid value for --bootstrap_type");
		fprintf(stderr, "try using -h\n");
		mQuit("");
	}

	srand(3510);

	if (strcmp(arg_file->filename[0], "-") == 0) {
		fp = stdin;
	} else if ((fp = fopen(arg_file->filename[0], "r")) == NULL) {
		exit(-1);
	}
/* m is a features x samples matrix
   S1 S2 S3
F1  a  b  c
F2  d  e  f
F3  g  h  i
F4  j  k  l
*/
/* col breaks the columns
col [0]
    S1
F1  a
F2  d
F3  g
F4  j

m->elem[feature][sample]
becomes
col[sample]->elem[feature][0]
*/
	m   = (mMatrix*) mMalloc(sizeof(mMatrix));

	mInitVector(labels, samples);
	mInitMatrix(m, features, samples);
	mReadMatrixWithLabels(fp, m, features, samples, 2, labels);
#ifdef DEBUG
mWriteMatrix(stdout, m);
#endif
	for (iter=0; iter<bootstrapcount; iter++) {
		double global_sum = 0;

		col = (mMatrix**) mMalloc(samples*sizeof(mMatrix*));
		for (i=0; i<samples; i++) {
			col[i] = (mMatrix*) mMalloc(sizeof(mMatrix));
			mInitMatrix(col[i], features, 1);
			for (j=0; j<features; j++) col[i]->elem[j][0] = 0;
		}

		if (bootstraptype == ROWS) {
			int *indices = (int*) mMalloc(features*sizeof(int));
			for (i=0; i<features; i++) {
				if (bootstraprand == 1) 
					indices[i] = (int) (1.0 * features * (rand() / (RAND_MAX + 1.0)));
				else
					indices[i] = i;
			}

			/* split into columns */
			for (i=0; i<samples; i++) {
				for (j=0; j<features; j++) {
					col[i]->elem[j][0] = m->elem[indices[j]][i];
					global_sum += col[i]->elem[j][0];
				}
			}
			mFree(indices);
		} else if (bootstraptype == ELEMENTS) {
			for (i=0; i<samples; i++) {
				int sum = 0;
				int *indices;
				int current;

				for (j=0; j<features; j++) sum += m->elem[j][i];
				indices = (int*) mMalloc(sum*sizeof(int));

				/* make flattened set of all hits */
				current = 0;
				for (j=0; j<features; j++) {
					int k;
					for (k=0; k<m->elem[j][i]; k++) {
						indices[current++] = j;
					}
				}

				/* mFisherYatesShuffle(indices); */

				/* subsample from the flattened set with replacement */
				for (j=0; j<sum; j++) {
					int s = (int) (1.0 * sum * (rand() / (RAND_MAX + 1.0)));
					col[i]->elem[indices[s]][0]++;
				}
				mFree(indices);
			}
		}

#ifdef DEBUG
{
for (i=0; i<samples; i++) {
	fprintf(stdout, "| ");
	for (j=0; j<features; j++) {
		fprintf(stdout, "%.4f\t", col[i]->elem[j][0]);
	}
	fprintf(stdout, " |\n");
}
}
#endif
		/* filter low numbers */

		if (noise_removal > 0) {
			for (i=0; i<features; i++) {
				double row_sum = 0;
				/* get row sum */
				for (j=0; j<samples; j++) {
					row_sum += col[j]->elem[i][0];
				}
				/* if too low, kill it */
				if (row_sum*100/global_sum < 0.01) 
					for (j=0; j<samples; j++) 
						col[j]->elem[i][0] = 0;
			}
		}

		/* normalize the columns on column sum */

		if (colnorm > 0) {
			for (i=0; i<samples; i++) {
				double sum = 0;
				for (j=0; j<features; j++) sum+=col[i]->elem[j][0];
				if (sum != 0.0) {
					for (j=0; j<features; j++) {
						col[i]->elem[j][0] /= sum;
					}
				}
			}
		}
#ifdef DEBUG
{
double sum1 = 0;
for (i=0; i<samples; i++) {
	sum1 = 0;
	for (j=0; j<features; j++) {
		sum1 += col[i]->elem[j][0];
	}
	mWriteMatrix(stdout, col[i]);
	printf("Sample %d sums to "FORMAT"\n", i, sum1);
}
}
#endif

		/* rescale using mean/stdev */
		if (rownorm > 0) {
			for (i=0; i<features; i++) {
				double stdev, mean;
				mMatrix* buffer = (mMatrix*) mMalloc(sizeof(mMatrix));
				mInitMatrix(buffer, 1, samples);
				for (j=0; j<samples; j++) {
					buffer->elem[0][j] = col[j]->elem[i][0];
				}
				stdev = mStdev(samples, buffer->elem[0], &mean);
#ifdef DEBUG
mWriteMatrix(stdout, buffer);
puts("becomes");
printf("MEAN="FORMAT"; STDEV="FORMAT"\n", mean, stdev);
#endif
				if (stdev != 0) {
					for (j=0; j<samples; j++) {
						buffer->elem[0][j]  -= mean;
						buffer->elem[0][j]  /= stdev;
						col[j]->elem[i][0] = buffer->elem[0][j];
					}
				}
#ifdef DEBUG
mWriteMatrix(stdout, buffer);
puts("end");
#endif
				mFreeMatrix(buffer);
				mFree(buffer);
			}

		}
		fprintf(stdout, "%d\n", samples);
		for (i=0; i<samples; i++) {
			fprintf(stdout, "%-10s", (char*)(labels->elem[i]));
			for (j=0; j<samples; j++) {
				if (distance_metric == EUCLIDEAN)
					fprintf(stdout, " %.6f", mEuclideanDistance(col[i], col[j]));
				else if (distance_metric == JSD)
					fprintf(stdout, " %.6f", mJensenShannonDistance(col[i], col[j]));
				else if (distance_metric == KLD)
					fprintf(stdout, " %.6f", mKullbackLeiblerDivergence(col[i], col[j]));
			}
			fprintf(stdout, "\n");
		}
		for (i=0; i<samples; i++) {
			mFreeMatrix(col[i]);
			mFree(col[i]);
		}
		mFree(col);
	}
	mFreeMatrix(m);
	mFree(m);
	for (i=0; i<samples; i++) {
		mFree(labels->elem[i]);
	}
	mFreeVector(labels);
	mFree(labels);
	arg_freetable(argtable, argcount);
	mFree(argtable);
	return 0;
}

/*
 * Exploratory driver, kept separate from main(). It compares externally
 * supplied pairwise genome distances against COG-profile measures.
 *
 * argv[1]  lines "species_id<TAB>species_name<TAB>cog_id<TAB>count". A new
 *          genome starts whenever species_id changes, so lines for one
 *          species are expected to be contiguous. cog_id must be >= 1.
 * argv[2]  lines "id_a id_b distance", where the ids are species_id values.
 *          Lines whose ids are unknown or that fail to parse are skipped.
 *
 * For every genome pair (in either order in argv[2]) that has a distance, it
 * prints one tab-separated line with:
 *   - the distance;
 *   - shared divided by min_genes and by w_genes;
 *   - mSquaredEuclideanDistance divided by min_genes and by w_genes;
 *   - mBinaryEuclideanDistance divided by min_cogs and by w_cogs.
 * Here "shared" sums min(count_a, count_b) over COGs, and w_x is
 * sqrt(2ab)/sqrt(a^2+b^2) of the two genomes' totals. A zero total yields
 * inf/nan in the output.
 * The function then exits with status 0. The mean/variance/covariance code
 * after that exit is currently unreachable.
 */
int main2(int argc, char *argv[]) {
	int i, j;
	FILE *fp, *distances;
	char line[512];
	char *key = NULL;
	char *species_id_str;
	int max_cog_id = 0;
	int dimensionality;
	int prev_species_id;
	double *genes;
	int num_genomes;
	int *total_genes, *total_cogs;
	double **shared, **distance, **euclidean, **binary_euclidean;
	mMatrix *mean;
	mMatrix *var;
	mMatrix *covar;
	mMatrix **set;
	zoeHash id2genes = NULL;
	zoeHash name2id = NULL;
	zoeHash id2index = NULL;
	zoeTVec keys = NULL;
	mFVector *species = (mFVector*) mMalloc(sizeof(mFVector));
	mVector  *names   = (mVector*) mMalloc(sizeof(mVector));
	mFVector *cogs    = (mFVector*) mMalloc(sizeof(mFVector));
	mFVector *counts  = (mFVector*) mMalloc(sizeof(mFVector));

	if (argc < 3) {
		fprintf(stderr, "main2: expected <cog-counts-file> <distances-file>\n");
		exit(-1);
	}
	if ((fp = fopen(argv[1], "r")) == NULL) {
		fprintf(stderr, "main2: cannot open %s\n", argv[1]);
		exit(-1);
	}
	if ((distances = fopen(argv[2], "r")) == NULL) {
		fprintf(stderr, "main2: cannot open %s\n", argv[2]);
		exit(-1);
	}

	mInitVector(names, 4096);
	mInitFVector(species, 4096);
	mInitFVector(cogs, 4096);
	mInitFVector(counts, 4096);

	while (fgets(line, sizeof(line), fp) != NULL) {
		int species_id, cog_id, count;
		char *species_name = (char*) mMalloc(64*sizeof(char));
		/* %63s leaves room for the terminator in the 64-byte name buffer */
		if (sscanf(line, "%d\t%63s\t%d\t%d", &species_id, species_name, &cog_id, &count) != 4) {
			mDie("Cannot scan 4 values");
		}
		/* counts are stored at genes[cog_id-1] below, so ids must start at 1 */
		if (cog_id < 1) {
			mDie("COG ids must be >= 1");
		}
		max_cog_id = MAX(cog_id, max_cog_id);
		mPushVector(names, species_name);
		mPushFVector(species, (double)species_id);
		mPushFVector(cogs,    (double)cog_id);
		mPushFVector(counts,  (double)count);
	}
	fclose(fp);

	dimensionality = max_cog_id;

	/* Group consecutive lines by species_id into one count vector per genome.
	 * A genome is flushed into the hashes when the next id starts. At that
	 * point `key` and `species_id_str` still describe the finished genome. */
	prev_species_id = -1;
	genes = NULL;
	id2genes = zoeNewHash();
	id2index = zoeNewHash();
	name2id = zoeNewHash();
	species_id_str = (char*) mMalloc(32*sizeof(char));
	for (i=0; i<species->size; i++) {
		int species_id, cog_id, count;

		species_id = species->elem[i];
		cog_id     = cogs->elem[i];
		count      = counts->elem[i];
		if (species_id != prev_species_id) {
			if (prev_species_id != -1) {
				zoeSetHash(name2id, key, species_id_str);
				zoeSetHash(id2genes, species_id_str, genes);
				species_id_str = (char*) mMalloc(32*sizeof(char));
			}
			genes = (double*) mMalloc(dimensionality*sizeof(double));
			for (j=0; j<dimensionality; j++) genes[j] = 0;
		}
		genes[cog_id-1] = count;
		prev_species_id = species_id;
		key        = (char*)names->elem[i];
		sprintf(species_id_str, "%d", species_id); /* an int always fits in 32 bytes */
	}
	if (prev_species_id != -1) {
		zoeSetHash(name2id, key, species_id_str);
		zoeSetHash(id2genes, species_id_str, genes);
	}

	keys = zoeKeysOfHash(id2genes);
	num_genomes = keys->size;
	set  = (mMatrix**) mMalloc(num_genomes*sizeof(mMatrix*));

	/* Make a set of column matrices from id2genes */
	for (i=0; i<num_genomes; i++) {
		char *id_key = keys->elem[i];
		double *item = (double*) zoeGetHash(id2genes, id_key);
		int *index = (int*) mMalloc(sizeof(int));
		set[i] = (mMatrix*) mMalloc(sizeof(mMatrix));
		mInitMatrix(set[i], dimensionality, 1);
		for (j=0; j<dimensionality; j++) set[i]->elem[j][0] = item[j];
		*index = i;
		zoeSetHash(id2index, id_key, index);
	}

	/* Now I have:
	set[i]->elem[j][0]        --> instances of COG j+1 in genome i
	id2index(species_id_str)  --> i
	id2genes(species_id_str)  --> double[dimensionality] of COG counts
	name2id(species_name)     --> species_id_str (ncbi id)
	*/

	/* Total number of genes in the COG space, and of non-empty COGs, per genome */
	total_genes = (int*) mCalloc(num_genomes, sizeof(int));
	total_cogs  = (int*) mCalloc(num_genomes, sizeof(int));
	for (i=0; i<num_genomes; i++) {
		for (j=0; j<dimensionality; j++) {
			if (set[i]->elem[j][0] > 0) {
				total_genes[i] += set[i]->elem[j][0];
				total_cogs[i] ++;
			}
		}
	}

	/* Pairwise measures; only the upper triangle (j > i) is filled and read */
	shared = (double**) mMalloc(num_genomes*sizeof(double*));
	euclidean = (double**) mMalloc(num_genomes*sizeof(double*));
	binary_euclidean = (double**) mMalloc(num_genomes*sizeof(double*));
	for (i=0; i<num_genomes; i++) {
		shared[i] = (double*) mMalloc(num_genomes*sizeof(double));
		euclidean[i] = (double*) mMalloc(num_genomes*sizeof(double));
		binary_euclidean[i] = (double*) mMalloc(num_genomes*sizeof(double));
		for (j=i+1; j<num_genomes; j++) {
			int k;
			int sum = 0;
			for (k=0; k<dimensionality; k++) {
				sum += MIN(set[i]->elem[k][0], set[j]->elem[k][0]); /* shared genes inside that COG */
			}
			shared[i][j] = 1.0 * sum;
			euclidean[i][j] = mSquaredEuclideanDistance(set[i], set[j]);
			binary_euclidean[i][j] = mBinaryEuclideanDistance(set[i], set[j]);
		}
	}

	/* Read externally computed distances; -1 marks "no distance given" */
	distance = (double**) mMalloc(num_genomes*sizeof(double*));
	for (i=0; i<num_genomes; i++) {
		distance[i] = (double*) mMalloc(num_genomes*sizeof(double));
		for (j=0; j<num_genomes; j++) {
			distance[i][j] = -1.0;
		}
	}

	while (fgets(line, sizeof(line), distances) != NULL) {
		double dist;
		int index_a, index_b;
		char id_a[32], id_b[32];
		int *buf;
		/* widths match the 32-byte id buffers; skip malformed lines */
		if (sscanf(line, "%31s %31s "FORMAT, id_a, id_b, &dist) != 3) {
			continue;
		}
		buf = (int*)zoeGetHash(id2index, id_a);
		if (buf == NULL) {
			continue;
		}
		index_a = *buf;
		buf = (int*)zoeGetHash(id2index, id_b);
		if (buf == NULL) {
			continue;
		}
		index_b = *buf;
		/* Store both orders. The report below reads only distance[i][j] with
		 * i < j, and hash key order says nothing about which index is smaller. */
		distance[index_a][index_b] = dist;
		distance[index_b][index_a] = dist;
	}
	fclose(distances);

	for (i=0; i<num_genomes; i++) {
		for (j=i+1; j<num_genomes; j++) {
			int normalize = MIN(total_genes[i], total_genes[j]);
			/* products in double: int arithmetic here can overflow for large genomes */
			double wnormalize = sqrt(2.0*total_genes[i]*total_genes[j]) / sqrt((double)total_genes[i]*total_genes[i] + (double)total_genes[j]*total_genes[j]);
			int bnormalize = MIN(total_cogs[i], total_cogs[j]);
			double bwnormalize = sqrt(2.0*total_cogs[i]*total_cogs[j]) / sqrt((double)total_cogs[i]*total_cogs[i] + (double)total_cogs[j]*total_cogs[j]);
			if (distance[i][j] >= 0.0) fprintf(stdout, "%.6f\t%.6f\t%.6f\t%.6f\t%.6f\t%.6f\t%.6f\n", distance[i][j], shared[i][j]/normalize, shared[i][j]/wnormalize, euclidean[i][j]/normalize, euclidean[i][j]/wnormalize, binary_euclidean[i][j]/bnormalize, binary_euclidean[i][j]/bwnormalize);
		}
	}
	exit(0);

	/* NOTE: everything below is unreachable (see exit above); kept for reference. */

	/* Calculate mean */
	mean = (mMatrix*)  mMalloc(sizeof(mMatrix));
	mInitMatrix(mean, dimensionality, 1);
	mZerofyMatrix(mean);
	for (i=0; i<num_genomes; i++) {
		mAddToMatrix(mean, set[i]);
	}
	mDivideMatrixByScalar(mean, num_genomes);

	/* Calculate variance */
	var  = (mMatrix*)  mMalloc(sizeof(mMatrix));
	mInitMatrix(var,  dimensionality, 1);
	mZerofyMatrix(var);
	for (i=0; i<num_genomes; i++) {
		mMatrix *buf, *square;
		buf = mSubtractMatrices(set[i], mean);
		square = mMatrixElementsSquared(buf);
		mAddToMatrix(var, square);
	}

	mPrintDist(stdout, id2genes, name2id, dimensionality, var);
	exit(0);

	/* calculate covariance matrix */

	covar = (mMatrix*) mMalloc(sizeof(mMatrix));
	mInitMatrix(covar, dimensionality, dimensionality);
	mZerofyMatrix(covar);
	for (i=0; i<num_genomes; i++) {
		mMatrix *buf, *buft, *prod;
		buf = mSubtractMatrices(set[i], mean);
		buft = mMatrixTranspose(buf);
		prod = mMultiplyMatrices(buf, buft);
		mAddToMatrix(covar, prod);
		mFreeMatrix(buf);
		mFreeMatrix(buft);
		mFreeMatrix(prod);
	}
	mDivideMatrixByScalar(covar, num_genomes - 1.0);
	mWriteMatrix(stdout, covar);

	for (i=0; i<num_genomes; i++) {
		for (j=i+1; j<num_genomes; j++) {
			mMatrix *mahalanobis = mMahalanobisDistance(set[i], set[j], covar);
			double dist = sqrt(mahalanobis->elem[0][0]);
			dist = 0;
		}
	}
	return 0;
}
