#include "mMatrix.h"

#include <ctype.h>
#include <math.h>

/* BEGIN MATRIX */

/***********
 * memory management
 ***********/

/*
 * Allocates storage for an nrows x ncols matrix: an array of nrows row
 * pointers, each pointing at its own block of ncols doubles. Also records the
 * dimensions in m. Element values are not initialised here.
 * Release the storage with mFreeMatrix.
 */
void mInitMatrix(mMatrix *m, int nrows, int ncols) {
	double **rows = (double**) mMalloc(nrows*sizeof(double*));
	int r = 0;

	m->elem = rows;
	while (r < nrows) {
		rows[r] = (double*) mMalloc(ncols*sizeof(double));
		r++;
	}
	m->nrows = nrows;
	m->ncols = ncols;
}

/*
 * Releases every row buffer and the row-pointer array owned by m.
 * The mMatrix struct itself is not freed, and its nrows/ncols fields are left
 * unchanged.
 */
void mFreeMatrix(mMatrix *m) {
	int r = m->nrows;

	while (r-- > 0) {
		mFree(m->elem[r]);
	}
	mFree(m->elem);
}

/***********
 * Matrix operations
 ***********/

/*
 * Returns a newly allocated (a->ncols x a->nrows) matrix holding the transpose
 * of a. a is not modified. The caller owns the result and should release it
 * with mFreeMatrix followed by mFree.
 */
mMatrix* mMatrixTranspose(mMatrix *a) {
	mMatrix *t = (mMatrix*) mMalloc(sizeof(mMatrix));
	int r, c;

	mInitMatrix(t, a->ncols, a->nrows);
	/* Walk the output in row order; output (r, c) comes from input (c, r). */
	for (r = 0; r < t->nrows; r++) {
		for (c = 0; c < t->ncols; c++) {
			t->elem[r][c] = a->elem[c][r];
		}
	}
	return t;
}

/*
 * Returns a newly allocated matrix of the same shape as a, whose elements are
 * the squares of a's elements. a is left unchanged; the caller owns the result.
 */
mMatrix* mMatrixElementsSquared(mMatrix *a) {
	mMatrix *out = (mMatrix*) mMalloc(sizeof(mMatrix));
	int r, c;

	mInitMatrix(out, a->nrows, a->ncols);
	for (r = 0; r < a->nrows; r++) {
		const double *src = a->elem[r];
		double *dst = out->elem[r];
		for (c = 0; c < a->ncols; c++) {
			dst[c] = src[c]*src[c];
		}
	}
	return out;
}

/* Replaces every element of a with its square, in place. */
void mSquareMatrixElements(mMatrix *a) {
	int r, c;

	for (r = 0; r < a->nrows; r++) {
		double *row = a->elem[r];
		for (c = 0; c < a->ncols; c++) {
			const double v = row[c];
			row[c] = v*v;
		}
	}
}

/*
 * In-place element-wise addition: a += b.
 * Dies via mDie if a and b do not have identical dimensions.
 */
void mAddToMatrix(mMatrix *a, mMatrix *b) {
	int r, c;
	int sameShape = (a->nrows == b->nrows) && (a->ncols == b->ncols);

	if (!sameShape) {
		mDie("Incompatible matrices in addition");
	}
	for (r = 0; r < a->nrows; r++) {
		double *dst = a->elem[r];
		const double *src = b->elem[r];
		for (c = 0; c < a->ncols; c++) {
			dst[c] += src[c];
		}
	}
}

/*
 * Returns a newly allocated matrix equal to a + b (element-wise).
 * Neither input is modified; the caller owns the result.
 * Dies via mDie if a and b do not have identical dimensions.
 */
mMatrix* mAddMatrices(mMatrix *a, mMatrix *b) {
	mMatrix *result;
	int r, c;

	if (!(a->nrows == b->nrows && a->ncols == b->ncols)) {
		mDie("Incompatible matrices in addition");
	}
	result = (mMatrix*) mMalloc(sizeof(mMatrix));
	mInitMatrix(result, a->nrows, a->ncols);
	for (r = 0; r < result->nrows; r++) {
		const double *ra = a->elem[r];
		const double *rb = b->elem[r];
		double *out = result->elem[r];
		for (c = 0; c < result->ncols; c++) {
			out[c] = ra[c] + rb[c];
		}
	}
	return result;
}

/*
 * Returns a newly allocated matrix equal to a - b (element-wise).
 * Neither input is modified; the caller owns the result.
 * Dies via mDie if a and b do not have identical dimensions.
 */
mMatrix* mSubtractMatrices(mMatrix *a, mMatrix *b) {
	mMatrix *result;
	int r, c;

	if (!(a->nrows == b->nrows && a->ncols == b->ncols)) {
		mDie("Incompatible matrices in subtraction");
	}
	result = (mMatrix*) mMalloc(sizeof(mMatrix));
	mInitMatrix(result, a->nrows, a->ncols);
	for (r = 0; r < result->nrows; r++) {
		const double *ra = a->elem[r];
		const double *rb = b->elem[r];
		double *out = result->elem[r];
		for (c = 0; c < result->ncols; c++) {
			out[c] = ra[c] - rb[c];
		}
	}
	return result;
}

/*
 * Returns a newly allocated (a->nrows x b->ncols) matrix equal to the
 * matrix product a*b. The caller owns the result.
 * Dies via mDie if a's column count differs from b's row count.
 */
mMatrix* mMultiplyMatrices(mMatrix *a, mMatrix *b) {
	mMatrix *prod;
	int r, c, k;
	const int inner = a->ncols;

	if (inner != b->nrows) {
		mDie("Incompatible matrices in multiplication");
	}
	prod = (mMatrix*) mMalloc(sizeof(mMatrix));
	mInitMatrix(prod, a->nrows, b->ncols);
	for (r = 0; r < prod->nrows; r++) {
		const double *arow = a->elem[r];
		for (c = 0; c < prod->ncols; c++) {
			/* Dot product of row r of a with column c of b. */
			double acc = 0;
			for (k = 0; k < inner; k++) {
				acc += arow[k]*b->elem[k][c];
			}
			prod->elem[r][c] = acc;
		}
	}
	return prod;
}

/* Divides every element of m by n, in place. No check is made for n == 0. */
void mDivideMatrixByScalar(mMatrix *m, double n) {
	int r = 0;

	while (r < m->nrows) {
		double *row = m->elem[r];
		int c;
		for (c = 0; c < m->ncols; c++) {
			row[c] /= n;
		}
		r++;
	}
}

/* Sets every element of m to the value x. */
void mFillMatrix(mMatrix *m, double x) {
	int r;

	for (r = 0; r < m->nrows; r++) {
		double *row = m->elem[r];
		int c;
		for (c = m->ncols - 1; c >= 0; c--) {
			row[c] = x;
		}
	}
}

/* Sets every element of m to zero (a thin wrapper over mFillMatrix). */
void mZerofyMatrix(mMatrix *m) {
	const double zero = 0.0;
	mFillMatrix(m, zero);
}

/*
 * Replaces every element of m with its natural logarithm (log() from
 * <math.h>), in place. Non-positive elements are not validated.
 */
void mLogTransformMatrix(mMatrix *m) {
	int r, c;

	for (r = 0; r < m->nrows; r++) {
		double *row = m->elem[r];
		for (c = 0; c < m->ncols; c++) {
			row[c] = log(row[c]);
		}
	}
}

/*
 * Returns a newly allocated inverse of the square matrix m, computed by
 * Gauss-Jordan elimination with partial (row) pivoting. m is not modified.
 * The caller owns the result and must release it with mFreeMatrix and then
 * mFree.
 *
 * Dies via mDie if m is not square, or if a zero pivot shows m to be
 * singular. Only an exactly-zero pivot is detected; nearly singular input
 * still returns a (numerically unreliable) result.
 */
mMatrix* mMatrixInverse(mMatrix *m) {
	mMatrix work;
	mMatrix *inv;
	int n, i, j, k;

	if (m->nrows != m->ncols) {
		mDie("Cannot invert a %d X %d matrix", m->nrows, m->ncols);
	}
	n = m->nrows;

	/* Eliminate on a scratch copy so the caller's matrix stays untouched;
	   inv starts as the identity and accumulates the same row operations. */
	mInitMatrix(&work, n, n);
	inv = (mMatrix*) mMalloc(sizeof(mMatrix));
	mInitMatrix(inv, n, n);
	for (i=0; i<n; i++) {
		for (j=0; j<n; j++) {
			work.elem[i][j] = m->elem[i][j];
			inv->elem[i][j] = (i == j) ? 1.0 : 0.0;
		}
	}

	for (k=0; k<n; k++) {
		int pivot = k;
		double scale;
		double *tmp;

		/* Partial pivoting: pick the row with the largest |value| in column k
		   to limit round-off growth. */
		for (i=k+1; i<n; i++) {
			if (fabs(work.elem[i][k]) > fabs(work.elem[pivot][k])) {
				pivot = i;
			}
		}
		if (work.elem[pivot][k] == 0.0) {
			mDie("Cannot invert a singular matrix");
		}

		/* Rows are separate allocations, so swapping row pointers swaps rows. */
		if (pivot != k) {
			tmp = work.elem[k]; work.elem[k] = work.elem[pivot]; work.elem[pivot] = tmp;
			tmp = inv->elem[k]; inv->elem[k] = inv->elem[pivot]; inv->elem[pivot] = tmp;
		}

		/* Normalise the pivot row so the pivot element becomes 1. */
		scale = work.elem[k][k];
		for (j=0; j<n; j++) {
			work.elem[k][j] /= scale;
			inv->elem[k][j] /= scale;
		}

		/* Eliminate column k from every other row. */
		for (i=0; i<n; i++) {
			double factor;
			if (i == k) continue;
			factor = work.elem[i][k];
			if (factor == 0.0) continue;
			for (j=0; j<n; j++) {
				work.elem[i][j] -= factor*work.elem[k][j];
				inv->elem[i][j] -= factor*inv->elem[k][j];
			}
		}
	}

	mFreeMatrix(&work);
	return inv;
}

/***********
 * inter-matrix methods
 ***********/
/*
 * Returns the squared Euclidean distance between two column vectors, i.e.
 * the sum over rows of (a[r][0] - b[r][0])^2.
 * Dies via mDie unless a and b have the same shape with exactly one column.
 */
double mSquaredEuclideanDistance(mMatrix *a, mMatrix *b) {
	double total = 0;
	int r;

	if (!(a->nrows == b->nrows && a->ncols == b->ncols && a->ncols == 1)) {
		mDie("Cannot calculate Euclidean distances for %d X %d matrices", a->nrows, a->ncols);
	}
	for (r = 0; r < a->nrows; r++) {
		const double delta = a->elem[r][0] - b->elem[r][0];
		total += delta*delta;
	}
	return total;
}

/* Euclidean distance between two column vectors: the square root of
   mSquaredEuclideanDistance(a, b), which performs the shape validation. */
double mEuclideanDistance(mMatrix *a, mMatrix *b) {
	const double squared = mSquaredEuclideanDistance(a, b);
	return sqrt(squared);
}

/*
 * Euclidean distance between two column vectors after clamping each element
 * with MIN(1, x) (presumably the usual minimum macro from mMatrix.h — verify),
 * so that values above 1 count as 1.
 * Dies via mDie unless a and b have the same shape with exactly one column.
 */
double mBinaryEuclideanDistance(mMatrix *a, mMatrix *b) {
	double total = 0;
	int r;

	if (!(a->nrows == b->nrows && a->ncols == b->ncols && a->ncols == 1)) {
		mDie("Cannot calculate Euclidean distances for %d X %d matrices", a->nrows, a->ncols);
	}
	for (r = 0; r < a->nrows; r++) {
		const double delta = (MIN(1,a->elem[r][0]) - MIN(1,b->elem[r][0]));
		total += delta*delta;
	}
	return sqrt(total);
}

/*
 * Kullback-Leibler divergence KL(a || b) = sum_i a_i * log(a_i / b_i)
 * (natural log) between two column vectors. The inputs are presumably
 * probability distributions; this is not checked.
 *
 * Terms with a_i == 0 contribute 0, following the standard 0*log(0) = 0
 * convention. If b_i == 0 where a_i > 0, the result is +inf.
 * Dies via mDie unless a and b have the same shape with exactly one column.
 */
double mKullbackLeiblerDivergence(mMatrix *a, mMatrix *b) {
	int i;
	double sum = 0;
	if (a->nrows != b->nrows || a->ncols != b->ncols || a->ncols != 1) {
		mDie("Cannot calculate Kullback-Leibler divergence for %d X %d matrices", a->nrows, a->ncols);
	}
	for (i=0; i<a->nrows; i++) {
		double p = a->elem[i][0];
		/* Evaluating 0*log(0/q) directly gives 0*(-inf) = NaN (or NaN from
		   0/0), which would poison the whole sum; the term's limit is 0. */
		if (p == 0) {
			continue;
		}
		sum += p*log(p/b->elem[i][0]);
	}
	return sum;
}

/*
 * Jensen-Shannon divergence between two column vectors a and b:
 * (KL(a || m) + KL(b || m)) / 2, where m = (a + b) / 2 and KL is
 * mKullbackLeiblerDivergence (natural log).
 * Dies via mDie unless a and b have the same shape with exactly one column.
 */
double mJensenShannonDivergence(mMatrix *a, mMatrix *b) {
	mMatrix *midpoint;
	double klA, klB;

	if (!(a->nrows == b->nrows && a->ncols == b->ncols && a->ncols == 1)) {
		mDie("Cannot calculate Jensen-Shannon divergence for %d X %d matrices", a->nrows, a->ncols);
	}

	/* midpoint = (a + b) / 2, a temporary owned by this function. */
	midpoint = mAddMatrices(a, b);
	mDivideMatrixByScalar(midpoint, 2);

	klA = mKullbackLeiblerDivergence(a, midpoint);
	klB = mKullbackLeiblerDivergence(b, midpoint);

	mFreeMatrix(midpoint);
	mFree(midpoint);
	return (klA + klB)/2;
}

/* Jensen-Shannon distance: the square root of mJensenShannonDivergence(a, b). */
double mJensenShannonDistance(mMatrix *a, mMatrix *b) {
	const double divergence = mJensenShannonDivergence(a, b);
	return sqrt(divergence);
}

/*
 * Computes (a - b)^T * covar^-1 * (a - b) and returns it as a newly
 * allocated matrix that the caller owns.
 *
 * When a and b are n x 1 column vectors and covar is n x n, the result is the
 * 1 x 1 squared Mahalanobis distance; no square root is taken. Shape errors
 * are reported by the called helpers via mDie. covar is never freed.
 */
mMatrix* mMahalanobisDistance(mMatrix *a, mMatrix *b, mMatrix *covar) {
	mMatrix* inv = mMatrixInverse(covar);
	mMatrix* buf = mSubtractMatrices(a, b);
	mMatrix* buft = mMatrixTranspose(buf);
	mMatrix* maha1 = mMultiplyMatrices(buft, inv);
	mMatrix* maha2 = mMultiplyMatrices(maha1, buf);

	/* Free the inverse only if it is a separate allocation. If mMatrixInverse
	   hands back its argument (as a placeholder implementation would), freeing
	   it here would destroy the caller's covariance matrix. */
	if (inv != covar) {
		mFreeMatrix(inv);
		mFree(inv);
	}
	mFreeMatrix(buf); mFree(buf);
	mFreeMatrix(buft); mFree(buft);
	mFreeMatrix(maha1); mFree(maha1);

	return maha2;
}

/***********
 * Matrix read/write methods
 ***********/
/*
 * Copies the leading nrows x ncols elements of array (indexed
 * array[row][col]) into m. m must already be initialised (see mInitMatrix)
 * with at least nrows rows and ncols columns.
 * Dies via mDie if the requested shape does not fit inside m, rather than
 * writing past the end of m's storage.
 */
void mMatrixFromArray(mMatrix *m, int nrows, int ncols, double **array) {
	int i, j;
	if (nrows > m->nrows || ncols > m->ncols) {
		mDie("Cannot copy a %d X %d array into a %d X %d matrix", nrows, ncols, m->nrows, m->ncols);
	}
	for (i=0; i<nrows; i++) {
		for (j=0; j<ncols; j++) {
			m->elem[i][j] = array[i][j];
		}
	}
}

/*
 * Reads nrows x ncols whitespace-separated doubles from stream, in row-major
 * order, into the leading elements of m. m must already be initialised with
 * at least that many rows and columns.
 * Dies via mDie if the shape does not fit in m, or if any value cannot be
 * parsed (including premature end of input). Previously such failures left
 * elements silently uninitialised.
 */
void mReadMatrix(FILE *stream, mMatrix *m, int nrows, int ncols) {
	int i, j;
	if (nrows > m->nrows || ncols > m->ncols) {
		mDie("Cannot read a %d X %d matrix into a %d X %d matrix", nrows, ncols, m->nrows, m->ncols);
	}
	for (i=0; i<nrows; i++) {
		for (j=0; j<ncols; j++) {
			if (fscanf(stream, "%lf", &m->elem[i][j]) != 1) {
				mDie("Failed to read matrix element at row %d, column %d", i, j);
			}
		}
	}
}

/*
 * Reads one whitespace-delimited label from stream into a freshly allocated
 * 64-byte buffer and returns it. Dies via mDie on read failure, or if the
 * token is longer than 63 characters.
 */
static char* mReadLabelToken(FILE *stream) {
	char *label = (char*) mMalloc(64*sizeof(char));
	int next;

	/* %63s bounds the write to the 64-byte buffer (63 chars + terminating NUL). */
	if (fscanf(stream, "%63s", label) != 1) {
		mDie("Failed to read label");
	}
	/* A non-space character right after the token means it was truncated;
	   fail instead of letting the remainder be parsed as the next field. */
	next = getc(stream);
	if (next != EOF) {
		ungetc(next, stream);
		if (!isspace(next)) {
			mDie("Label longer than 63 characters");
		}
	}
	return label;
}

/*
 * Reads an nrows x ncols matrix of doubles from stream into m, pushing any
 * labels encountered onto `labels` via mPushVector.
 *
 * label_source = 1 => each row is preceded by its own label;
 * label_source = 2 => a header line of ncols labels precedes the data;
 * any other value  => no labels are read.
 *
 * Each label is a separate mMalloc'd buffer that is passed to mPushVector and
 * not freed here.
 * Dies via mDie if the shape does not fit in m, or on any malformed or
 * missing field.
 */
void mReadMatrixWithLabels(FILE *stream, mMatrix *m, int nrows, int ncols, int label_source, mVector *labels) {
	int i, j;

	if (nrows > m->nrows || ncols > m->ncols) {
		mDie("Cannot read a %d X %d matrix into a %d X %d matrix", nrows, ncols, m->nrows, m->ncols);
	}
	if (label_source == 2) {
		for (j=0; j<ncols; j++) {
			mPushVector(labels, mReadLabelToken(stream));
		}
	}
	for (i=0; i<nrows; i++) {
		if (label_source == 1) {
			mPushVector(labels, mReadLabelToken(stream));
		}
		for (j=0; j<ncols; j++) {
			if (fscanf(stream, "%lf", &m->elem[i][j]) != 1) {
				mDie("Failed to read matrix element at row %d, column %d", i, j);
			}
		}
	}
}

/*
 * Writes m to stream as text, one row per line, framed as "| ... |".
 * Each value is printed with "%-8.4f " (left-aligned, 4 decimals).
 */
void mWriteMatrix(FILE *stream, mMatrix *m) {
	int r, c;

	for (r = 0; r < m->nrows; r++) {
		const double *row = m->elem[r];
		fputs("| ", stream);
		for (c = 0; c < m->ncols; c++) {
			fprintf(stream, "%-8.4f ", row[c]);
		}
		fputs("|\n", stream);
	}
}

/* END_MATRIX */
