# ------------------------------------------------------------
# Author: Natalie Romanov
# ------------------------------------------------------------

class module_wise_normalization:
	@staticmethod
	def complex_execute(**kwargs):
		'''
		Run the complex-wise normalization pipeline (subunit abundance
		relative to the complex median): gather the normalized tables,
		then export them.

		Keyword arguments:
			folder -- location handed to the gather and the export step;
			          the literal 'PATH' is used when it is not supplied

		Returns nothing; one progress marker per step is printed.
		'''
		cls = module_wise_normalization
		target_folder = kwargs['folder'] if 'folder' in kwargs else 'PATH'

		# step 1: normalize every dataset
		print('gather_complex_stoichiometry_data')
		gathered = cls.gather_complex_stoichiometry_data(target_folder)

		# step 2: hand the result to the exporter
		print('export_complex_stoichiometry_data')
		cls.export_complex_stoichiometry_data(target_folder, gathered)

	@staticmethod
	def pathway_execute(**kwargs):
		'''
		Run the pathway-wise normalization pipeline (member abundance
		relative to the pathway median): gather the normalized tables,
		then export them.

		Keyword arguments:
			folder -- location handed to the gather and the export step;
			          the literal 'PATH' is used when it is not supplied

		Returns nothing; one progress marker per step is printed.
		'''
		folder = kwargs.get('folder', 'PATH')

		# The markers name the pathway steps that actually run (they used to
		# repeat the names of the complex pipeline).
		print('gather_pathway_stoichiometry_data')
		pat_yeast_dict = module_wise_normalization.gather_pathway_stoichiometry_data(folder)

		print('export_pathway_stoichiometry_data')
		module_wise_normalization.export_pathway_stoichiometry_data(folder, pat_yeast_dict)

	@staticmethod
	def gather_complex_stoichiometry_data(folder):
		'''
		Load every yeast dataset and normalize it complex-wise.

		For each dataset the table is read with DataFrameAnalyzer.getFile,
		its quantification columns are picked with utilsFacade.filtering and
		the table is passed to perform_complexBased_normalization.

		Parameters:
			folder -- location passed through to DataFrameAnalyzer.getFile

		Returns:
			dict mapping the dataset name ('yeast3', 'yeast4', ...) to the
			normalized table of that dataset
		'''
		# (dataset, file name, pattern selecting the quantification columns,
		#  whether the last selected column is discarded)
		dataset_specs = (
			('yeast3', 'yeast3_quant_Proteome_carbonSources_MAPPED_complexes_pathways_NORM2.tsv.gz', 'rel. Intensity', False),
			('yeast4', 'yeast4_quant_Proteome_naclStress_MAPPED_complexes_pathways_NORM2.tsv.gz', 'rel.Intensity', False),
			# NOTE(review): only this dataset drops its last matching column;
			# presumably that column is not a sample -- confirm against the file.
			('yeast5', 'yeast5_quant_Proteome_MAPPED_complexes_pathways_NORM2_IBAQ.tsv.gz', 'iBAQ', True),
			('yeast10', 'yeast10_quant_RNA_MAPPED_complexes_pathways_NORM2.tsv.gz', 'quant', False),
			('yeast11', 'yeast11_quant_proteome_MAPPED_complexes_pathways_NORM2.tsv.gz', 'quant', False),
			('yeast14', 'yeast14_quant_proteome_MAPPED_complexes_pathways_NORM2.tsv.gz', 'quant_', False),
			('yeast16', 'yeast16_quant_proteome_MAPPED_complexes_pathways_NORM2.tsv.gz', 'quant_', False),
			('yeast18', 'yeast18_quant_proteome_MAPPED_complexes_pathways_NORM2.tsv.gz', 'quant_', False),
			('yeast19', 'yeast19_quant_transcriptome_MAPPED_complexes_pathways_NORM2.tsv.gz', 'quant_', False),
			('yeast20', 'yeast20_quant_proteome_MAPPED_complexes_pathways_NORM2.tsv.gz', 'quant_', False),
			('yeast21', 'yeast21_quant_proteome_MAPPED_complexes_pathways_NORM2_REP.tsv.gz', 'quant_', False),
		)

		com_yeast_dict = dict()
		for dataset, fname, column_pattern, drop_last in dataset_specs:
			data = DataFrameAnalyzer.getFile(folder, fname)
			quant_cols = utilsFacade.filtering(list(data.columns), column_pattern)
			if drop_last:
				quant_cols = quant_cols[:-1]
			if dataset == 'yeast19':
				# progress marker the script has always printed for this dataset
				print('get_complex_correlation_values:yeast19')
			com_yeast_dict[dataset] = module_wise_normalization.perform_complexBased_normalization(data, quant_cols)

		# Without this return the caller received None and exported nothing useful.
		return com_yeast_dict

	@staticmethod
	def gather_pathway_stoichiometry_data(folder):
		'''
		Load every yeast dataset and normalize it pathway-wise.

		Each table is read with DataFrameAnalyzer.getFile, its quantification
		columns are picked with utilsFacade.filtering and the table is passed
		to perform_pathwayBased_normalization.

		Parameters:
			folder -- location passed through to DataFrameAnalyzer.getFile

		Returns:
			dict mapping the dataset name ('yeast3', 'yeast4', ...) to the
			normalized table of that dataset
		'''
		cls = module_wise_normalization

		# dataset -> (file name, pattern selecting the quantification columns),
		# listed in processing order
		sources = [
			('yeast3', ('yeast3_quant_Proteome_carbonSources_MAPPED_complexes_pathways_NORM2.tsv.gz', 'rel. Intensity')),
			('yeast4', ('yeast4_quant_Proteome_naclStress_MAPPED_complexes_pathways_NORM2.tsv.gz', 'rel.Intensity')),
			('yeast5', ('yeast5_quant_Proteome_MAPPED_complexes_pathways_NORM2_IBAQ.tsv.gz', 'iBAQ')),
			('yeast10', ('yeast10_quant_RNA_MAPPED_complexes_pathways_NORM2.tsv.gz', 'quant')),
			('yeast11', ('yeast11_quant_proteome_MAPPED_complexes_pathways_NORM2.tsv.gz', 'quant')),
			('yeast14', ('yeast14_quant_proteome_MAPPED_complexes_pathways_NORM2.tsv.gz', 'quant_')),
			('yeast16', ('yeast16_quant_proteome_MAPPED_complexes_pathways_NORM2.tsv.gz', 'quant_')),
			('yeast18', ('yeast18_quant_proteome_MAPPED_complexes_pathways_NORM2.tsv.gz', 'quant_')),
			('yeast19', ('yeast19_quant_transcriptome_MAPPED_complexes_pathways_NORM2.tsv.gz', 'quant_')),
			('yeast20', ('yeast20_quant_proteome_MAPPED_complexes_pathways_NORM2.tsv.gz', 'quant_')),
			('yeast21', ('yeast21_quant_proteome_MAPPED_complexes_pathways_NORM2_REP.tsv.gz', 'quant_')),
		]

		normalized_by_dataset = {}
		for name, (table_file, pattern) in sources:
			table = DataFrameAnalyzer.getFile(folder, table_file)
			selected = utilsFacade.filtering(list(table.columns), pattern)
			if name == 'yeast5':
				# the last matching column of this dataset is not used
				selected = selected[:-1]
			elif name == 'yeast19':
				print('get_complex_correlation_values:yeast19')
			normalized_by_dataset[name] = cls.perform_pathwayBased_normalization(table, selected)
		return normalized_by_dataset

	@staticmethod
	def get_complex_dictionary(data):
		complexes = list(set(data.complexId))
		all_complexes = list()
		for com in complexes:
			if str(com)!='nan':
				for c in com.split(';'):
					all_complexes.append(c)
		all_complexes = list(set(all_complexes))

		filtered_complexes = dict()
		for com in all_complexes:
			sub = data[data.complexId.str.contains(com, na = False)]
			sub = sub.drop_duplicates()
			if len(sub) >= 5:
				filtered_complexes.setdefault(com,[])
				filtered_complexes[com] = list(sub.index)
		return filtered_complexes

	@staticmethod
	def get_pathway_dictionary(data):
		complexes = list(set(data.pathway))
		all_complexes = list()
		for com in complexes:
			if str(com)!='nan':
				for c in com.split(';'):
					all_complexes.append(c)
		all_complexes = list(set(all_complexes))

		filtered_complexes = dict()
		for com in all_complexes:
			sub = data[data.pathway.str.contains(com, na = False, regex = False)]
			sub = sub.drop_duplicates()
			if len(sub) >= 5:
				filtered_complexes.setdefault(com,[])
				filtered_complexes[com] = list(sub.index)
		return filtered_complexes

	@staticmethod
	def check_protein_in_otherComplexes(filtered_complexes, protein, complex_analyzed):
		complexMedianList = list()
		returnValue = False
		considered_complexes = list(filtered_complexes.keys())
		for complexID in considered_complexes:
			if complexID!=complex_analyzed:
				members = filtered_complexes[complexID]
				if protein in members:
					print(complexID)
					returnValue = True
					break
		return returnValue

	@staticmethod
	def check_protein_in_otherComplexes_and_getMedian(filtered_complexes, protein, complex_analyzed, data, quant_cols):
		complexMedianList = list()
		considered_complexes = list(filtered_complexes.keys())
		for complexID in considered_complexes:
			#if complexID!=complex_analyzed:
			members = filtered_complexes[complexID]
			if protein in members:
				genes = filtered_complexes[complexID]
				sub = data.T[genes].T[quant_cols]
				medianList = list(sub.median())
				complexMedianList.append(medianList)
		if len(complexMedianList) > 1:
			averageList = module_wise_normalization.getMeanList(complexMedianList)
		else:
			averageList = complexMedianList[0]
		return averageList
	
	@staticmethod
	def getMeanList(lists):
	    """
	    input: lists (nested lists)
	    This function calculates the medianProfile out of several lists.
	    """	
	    meanList = list()
	    for i in xrange(len(lists[0])):
	        tempList = list()
	        for l in lists:
	        	tempList.append(l[i])
	        meanList.append(numpy.mean(filter(lambda a:str(a)!="nan",tempList)))
	    return meanList

	@staticmethod
	def trimmean(arr,**kwargs):
	    """
	    calculates the trimmed mean for a given array (percentage 25% -- kwargs: percent)
	    """
	    percent = kwargs.get("percent",25)
	    n = len(arr)
	    k = int(round(n*(float(percent)/100)/2))
	    return numpy.median(sorted(arr[k+1:n-k]))

	@staticmethod
	def perform_complexBased_normalization(data, quant_cols):
		'''
		Express every complex member relative to a complex-level reference.

		For each complex from get_complex_dictionary and each of its members:
		  * the reference is the per-column median of the complex, or -- when
		    the member also belongs to another complex -- the profile from
		    check_protein_in_otherComplexes_and_getMedian
		  * the reference is subtracted from the member's quantification values

		Parameters:
			data       -- pandas DataFrame with `complexId`, `symbol` and the
			              quantification columns
			quant_cols -- list of quantification column names

		Returns:
			concatenation of one table per complex: the `quant_cols` columns
			hold the differences, plus `relative_variance` (variance of each
			row's differences), `coverage` (share of the member's values kept
			by utilsFacade.finite) and `complexId`
		'''
		cls = module_wise_normalization
		members_by_complex = cls.get_complex_dictionary(data)

		per_complex_tables = list()
		for complex_id in list(members_by_complex.keys()):
			members = members_by_complex[complex_id]

			# rows of the members; 'symbol' only takes part in the duplicate check
			# NOTE(review): the complete table is transposed for every complex,
			# which looks costly for large inputs.
			member_table = data.T[members].T[quant_cols + ['symbol']]
			member_table = member_table.drop_duplicates()
			member_table = member_table.drop('symbol', axis = 1)

			differences = list()
			coverages = list()
			for member in members:
				if cls.check_protein_in_otherComplexes(members_by_complex, member, complex_id):
					reference = cls.check_protein_in_otherComplexes_and_getMedian(members_by_complex,
											member, complex_id, data, quant_cols)
				else:
					reference = list(member_table.median())

				observed = list(member_table.loc[member])
				usable = utilsFacade.finite(observed)
				coverages.append(float(len(usable))/float(len(observed)))

				differences.append(numpy.array(member_table.loc[member])-numpy.array(reference))

			normalized = pd.DataFrame(differences)
			normalized.index = members
			normalized.columns = quant_cols

			# variance is taken before the annotation columns are attached
			row_variance = normalized.T.var()
			normalized['relative_variance'] = pd.Series(row_variance, index = normalized.index)
			normalized['coverage'] = pd.Series(coverages, index = normalized.index)
			normalized['complexId'] = pd.Series([complex_id]*len(coverages), index = normalized.index)
			per_complex_tables.append(normalized)
		return pd.concat(per_complex_tables)

	@staticmethod
	def perform_pathwayBased_normalization(data, quant_cols):
		'''
		Express every pathway member relative to a pathway-level reference.

		Same procedure as the complex-wise normalization, with the groups
		taken from get_pathway_dictionary:
		  * reference = per-column median of the pathway, or the profile from
		    check_protein_in_otherComplexes_and_getMedian when the member is
		    part of another pathway as well
		  * the reference is subtracted from the member's quantification values

		Parameters:
			data       -- pandas DataFrame with `pathway`, `symbol` and the
			              quantification columns
			quant_cols -- list of quantification column names

		Returns:
			concatenation of one table per pathway: the `quant_cols` columns
			hold the differences, plus `relative_variance`, `coverage` and
			`complexId` (which carries the pathway name here)
		'''
		cls = module_wise_normalization
		members_by_pathway = cls.get_pathway_dictionary(data)

		tables = list()
		for pathway_name in list(members_by_pathway.keys()):
			members = members_by_pathway[pathway_name]

			# rows of the members; 'symbol' only takes part in the duplicate check
			pathway_table = data.T[members].T[quant_cols + ['symbol']]
			pathway_table = pathway_table.drop_duplicates()
			pathway_table = pathway_table.drop('symbol', axis = 1)

			shifted_rows = list()
			coverage_per_member = list()
			for member in members:
				shared = cls.check_protein_in_otherComplexes(members_by_pathway, member, pathway_name)
				if shared:
					baseline = cls.check_protein_in_otherComplexes_and_getMedian(members_by_pathway,
											member, pathway_name, data, quant_cols)
				else:
					baseline = list(pathway_table.median())

				raw_values = list(pathway_table.loc[member])
				finite_values = utilsFacade.finite(raw_values)
				coverage_per_member.append(float(len(finite_values))/float(len(raw_values)))

				shifted_rows.append(numpy.array(pathway_table.loc[member])-numpy.array(baseline))

			shifted = pd.DataFrame(shifted_rows)
			shifted.index = members
			shifted.columns = quant_cols

			# variance is taken before the annotation columns are attached
			variance_per_member = shifted.T.var()
			shifted['relative_variance'] = pd.Series(variance_per_member, index = shifted.index)
			shifted['coverage'] = pd.Series(coverage_per_member, index = shifted.index)
			shifted['complexId'] = pd.Series([pathway_name]*len(coverage_per_member), index = shifted.index)
			tables.append(shifted)
		return pd.concat(tables)

	@staticmethod
	def export_complex_stoichiometry_data(folder, com_yeast_dict):
		'''
		Hand the complex-wise result to DataFrameAnalyzer.to_pickle.

		Parameters:
			folder         -- directory that receives the file
			com_yeast_dict -- object to export (dataset name -> table)

		The file is named 'modulewise_norm_complex_stoichiometry.pkl'.
		'''
		import os

		# os.path.join supplies the separator when `folder` has no trailing
		# one; a folder that already ends with it yields the same path as the
		# former plain concatenation.
		target = os.path.join(folder, 'modulewise_norm_complex_stoichiometry.pkl')
		DataFrameAnalyzer.to_pickle(com_yeast_dict, target)

	@staticmethod
	def export_pathway_stoichiometry_data(folder, pat_yeast_dict):
		'''
		Hand the pathway-wise result to DataFrameAnalyzer.to_pickle.

		Parameters:
			folder         -- directory that receives the file
			pat_yeast_dict -- object to export (dataset name -> table)

		The file is named 'modulewise_norm_pathway_stoichiometry.pkl'.
		'''
		import os

		# os.path.join supplies the separator when `folder` has no trailing
		# one; a folder that already ends with it yields the same path as the
		# former plain concatenation.
		target = os.path.join(folder, 'modulewise_norm_pathway_stoichiometry.pkl')
		DataFrameAnalyzer.to_pickle(pat_yeast_dict, target)

if __name__ == "__main__":
	# Imported here so the entry point does not rely on a module-level import.
	import sys

	# The single command line argument is the folder that both pipelines
	# read from and export to.
	if len(sys.argv) < 2:
		sys.exit('usage: ' + sys.argv[0] + ' <folder>')
	data_folder = sys.argv[1]

	# Both calls return nothing, so their results are not bound to names.
	module_wise_normalization.complex_execute(folder = data_folder)
	module_wise_normalization.pathway_execute(folder = data_folder)
