# ------------------------------------------------------------
# Author: Natalie Romanov
# ------------------------------------------------------------

class supp_figure5a:
	"""Supplementary figure 5a: compare the 'aebersold_liu_twin_data.txt' table
	with the sex/diet rows of 'protein_variation.tsv.gz'.

	All methods are static and take their configuration as keyword arguments:

	folder        -- directory handed to the DataFrameAnalyzer loaders; the
	                 PDFs are also written to `folder + <file name>`, so it
	                 needs to end with a path separator (default 'PATH').
	output_folder -- accepted and forwarded between the entry points, but no
	                 path in this class is built from it (default 'PATH').
	                 NOTE(review): the figures may have been meant to go
	                 here instead of `folder` -- confirm before changing.
	"""

	@staticmethod
	def execute(**kwargs):
		"""Entry point: print a progress label and run the 5a comparison."""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		# The label used to read 'SUPP-FIGURE4A' although this is figure 5a.
		print('SUPP-FIGURE5A: main_suppFig5a_aebersold_liu_comparison')
		supp_figure5a.main_suppFig5a_aebersold_liu_comparison(folder = folder, output_folder = output_folder)

	@staticmethod
	def main_suppFig5a_aebersold_liu_comparison(**kwargs):
		"""Load the twin-data table, then run the comparison and plotting step."""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('load_aebersold_liu_data')
		data = supp_figure5a.load_aebersold_liu_data(folder = folder, output_folder = output_folder)

		print('comparison_aebersold_liu_data')
		supp_figure5a.comparison_aebersold_liu_data(data, folder = folder, output_folder = output_folder)

	@staticmethod
	def load_aebersold_liu_data(**kwargs):
		"""Read 'aebersold_liu_twin_data.txt' from `folder`.

		Returns the table re-indexed by its 'Protein Symbol' column.
		"""
		folder = kwargs.get('folder','PATH')

		filename = 'aebersold_liu_twin_data.txt'
		data = DataFrameAnalyzer.getFile(folder, filename)
		data.index = data['Protein Symbol']
		return data

	@staticmethod
	def _shuffle_within_rows(frame):
		"""Return a new frame whose rows are the shuffled rows of `frame`.

		Each row is shuffled independently with np.random.shuffle; the result
		carries default integer row and column labels.
		"""
		shuffled_rows = list()
		for _, row in frame.iterrows():
			values = list(row)
			np.random.shuffle(values)
			shuffled_rows.append(values)
		return pd.DataFrame(shuffled_rows)

	@staticmethod
	def _permuted_correlations(df, rounds = 99):
		"""Collect Spearman correlations of shuffled copies of `df`.

		In every round the values are shuffled first within each row and then
		within each column, the Spearman correlation matrix of the shuffled
		table is computed and passed through utilsFacade.get_correlation_values.
		The per-round results are flattened into one numpy array.
		The default of 99 rounds matches the original `np.arange(1,100)` loop.
		"""
		collected = list()
		for _ in range(rounds):
			# shuffle inside rows, transpose, shuffle inside (former) columns
			permuted = supp_figure5a._shuffle_within_rows(df).T
			permuted = supp_figure5a._shuffle_within_rows(permuted).T
			collected.append(utilsFacade.get_correlation_values(permuted.corr(method = 'spearman')))
		return np.array(utilsFacade.flatten(collected))

	@staticmethod
	def _pvalues_against_background(corr_block, background):
		"""Convert every coefficient of `corr_block` into a p-value.

		Each coefficient is turned into a z-score relative to `background`
		(utilsFacade.getSampleZScore) and then into a p-value; only the first
		element returned by utilsFacade.zscore_to_pvalue is kept.
		Returns one list of p-values per row of `corr_block`.
		"""
		pvalues = list()
		for _, row in corr_block.iterrows():
			zscores = [utilsFacade.getSampleZScore(background, coefficient) for coefficient in list(row)]
			pvalues.append([utilsFacade.zscore_to_pvalue(zscore)[0] for zscore in zscores])
		return pvalues

	@staticmethod
	def comparison_aebersold_liu_data(aedata, **kwargs):
		"""Correlate the twin-data columns with the sex/diet 'r2.all.module' values.

		aedata -- table returned by load_aebersold_liu_data (indexed by protein
		          symbol; columns 'Heritability' and 'Individual environment'
		          are read).

		Writes two PDFs under `folder`:
		'suppFig5a_comparison_aebersold.pdf' (annotated Spearman heatmap) and
		'suppFig5a_comparison_aebersold_heatmap_rankedDF.pdf' (row-clustered
		heatmap of the per-column ranks).
		"""
		folder = kwargs.get('folder','PATH')

		fname = 'protein_variation.tsv.gz'
		data = DataFrameAnalyzer.open_in_chunks(folder, fname)
		# upper-cased subunit names are matched against the twin-data index
		data.index = [su.upper() for su in map(str,list(data['subunit']))]
		sex_all_quant = data[data.covariates=='sex']
		diet_all_quant = data[data.covariates=='diet']

		overlapping_proteins = list(set(aedata.index).intersection(set(sex_all_quant.index)))
		aedata = aedata.T[overlapping_proteins].T
		print(aedata.shape)
		heritability_dict = aedata['Heritability'].to_dict()
		environment_dict = aedata['Individual environment'].to_dict()

		sex_dict = sex_all_quant.loc[overlapping_proteins]['r2.all.module'].to_dict()
		# NOTE(review): the overlap is computed against the 'sex' rows only;
		# this lookup assumes the same proteins also have a 'diet' row.
		diet_dict = diet_all_quant.loc[overlapping_proteins]['r2.all.module'].to_dict()

		# keep proteins whose heritability and environment values are both > 0
		kept_proteins = [protein for protein in overlapping_proteins
			if heritability_dict[protein]>0 and environment_dict[protein]>0]

		# original author's note: 51 proteins compared
		df = pd.DataFrame({'gender':[sex_dict[protein] for protein in kept_proteins],
			'heritability':[heritability_dict[protein] for protein in kept_proteins],
			'environment':[environment_dict[protein] for protein in kept_proteins],
			'diet':[diet_dict[protein] for protein in kept_proteins]})
		df.index = kept_proteins

		rand_corrs = supp_figure5a._permuted_correlations(df)

		# rows: heritability / environment, columns: gender / diet
		corrData = df.corr(method = 'spearman')
		corrData = corrData.loc[['heritability','environment'], ['gender','diet']]

		# NOTE(review): these p-values are computed but neither plotted, saved
		# nor returned -- kept so the calls into utilsFacade are unchanged.
		pvalues = supp_figure5a._pvalues_against_background(corrData, rand_corrs)

		sns.set(context='notebook', style='white', 
			palette='deep', font='Liberation Sans', font_scale=1, 
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = False

		plt.clf()
		fig = plt.figure(figsize = (1,1))
		ax = fig.add_subplot(111)
		sns.heatmap(corrData, cmap = plt.cm.RdBu, linewidth = 0.5, cbar = False, annot = True, ax = ax)
		plt.savefig(folder + 'suppFig5a_comparison_aebersold.pdf', bbox_inches = 'tight', dpi = 400)

		# heatmap of the per-column ranks (supplementary)
		rank_list = [rankdata(list(df[col])) for col in list(df.columns)]
		ranked_df = pd.DataFrame(rank_list).T
		ranked_df.index = df.index
		ranked_df.columns = df.columns
		ranked_df, _ = utilsFacade.recluster_matrix_only_rows(ranked_df)
		ranked_df = ranked_df.T

		plt.clf()
		fig = plt.figure(figsize = (20,3))
		ax = fig.add_subplot(111)
		sns.heatmap(ranked_df, cmap = plt.cm.coolwarm, linewidth = 1.0, ax = ax)
		plt.savefig(folder + 'suppFig5a_comparison_aebersold_heatmap_rankedDF.pdf',
			bbox_inches = 'tight', dpi = 400)

class supp_figure5b:
	"""Supplementary figure 5b: volcano plot and box plots for the rows whose
	`complex_search` contains 'pyruvate'.

	All methods are static and take their configuration as keyword arguments:

	folder        -- directory handed to the DataFrameAnalyzer loaders; the
	                 PDFs are also written to `folder + <file name>`, so it
	                 needs to end with a path separator (default 'PATH').
	output_folder -- accepted and forwarded between the entry points, but no
	                 path in this class is built from it (default 'PATH').
	"""

	@staticmethod
	def execute(**kwargs):
		"""Entry point: produce the volcano plot, then the box plots."""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('SUPP-FIGURE5B: main_supp_figure5b_volcanoPlots')
		supp_figure5b.main_supp_figure5b_volcanoPlots(folder = folder, output_folder = output_folder)

		print('SUPP-FIGURE5B: main_supp_figure5b_boxPlots')
		supp_figure5b.main_supp_figure5b_boxPlots(folder = folder, output_folder = output_folder)

	@staticmethod
	def main_supp_figure5b_volcanoPlots(**kwargs):
		"""Print a progress label and draw the volcano plot."""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('SUPP-FIGURE5B: suppFigure5b_volcanoPlot_pyruvateDH_hsComparison')
		supp_figure5b.suppFigure5b_volcanoPlot_pyruvateDH_hsComparison(folder = folder, output_folder = output_folder)

	@staticmethod
	def main_supp_figure5b_boxPlots(**kwargs):
		"""Print a progress label and draw the two box plots."""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('SUPP-FIGURE5B: suppFigure5b_pyruvateDH_boxplot_logFC_hsComparison')
		supp_figure5b.suppFigure5b_pyruvateDH_boxplot_logFC_hsComparison(folder = folder, output_folder = output_folder)

	@staticmethod
	def _apply_plot_style():
		"""Apply the seaborn/matplotlib style shared by all 5b figures (grid on)."""
		sns.set(context='notebook', style='white', 
			palette='deep', font='Liberation Sans', font_scale=1, 
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = True

	@staticmethod
	def _decorate_boxes(ax, bp, data_list, color_list, centers):
		"""Style the artists of a box plot and overlay the jittered raw points.

		ax         -- axes the box plot was drawn on
		bp         -- dictionary returned by ax.boxplot(..., patch_artist=True)
		data_list  -- the value lists that were plotted, one per box
		color_list -- colour per box (only the first len(boxes) entries are read)
		centers    -- x position of every box; the points are scattered around
		              it with a normal jitter (sd 0.04)
		"""
		plt.setp(bp['medians'], color="black")
		plt.setp(bp['whiskers'], color="black",linestyle="--",alpha=0.8)
		for idx, patch in enumerate(bp['boxes']):
			# NOTE(review): set_color is called after set_edgecolor, exactly as
			# in the original code, so it has the last word on the box colour.
			patch.set_edgecolor("black")
			patch.set_alpha(0.6)
			patch.set_color(color_list[idx])
			values = data_list[idx]
			# `np` is the numpy alias used everywhere else in this file; the
			# original referred to the bare name `numpy` only at this spot.
			jitter_x = np.random.normal(centers[idx], 0.04, size=len(values))
			ax.scatter(jitter_x,values,color='white', alpha=0.5,edgecolor="black",s=5)

	@staticmethod
	def suppFigure5b_volcanoPlot_pyruvateDH_hsComparison(**kwargs):
		"""Volcano plot (logFC vs. -log10 adjusted p-value) of the
		'stoch_highfat-chow' rows, highlighting the 'pyruvate' rows.

		Highlighted points and their labels are red when 'pval.adj' <= 0.01
		and dark blue otherwise. Writes
		'suppFig5b_pyruvateDH_volcano_hs_comparison.pdf' under `folder`.
		"""
		folder = kwargs.get('folder','PATH')

		filename = 'gygi3_complex_hs_merged_perGeneRun.tsv.gz'
		data = DataFrameAnalyzer.open_in_chunks(folder, filename)
		data = data.drop_duplicates()
		data = data[data.analysis_type=='stoch_highfat-chow']
		# three index labels are excluded explicitly (raises if one is absent)
		data = data.drop(['C9','C2','Psd3'], axis = 0)
		complex_data = data[data.complex_search.str.contains('pyruvate')]

		pval_list = -np.log10(np.array(data['pval.adj']))
		fc_list = np.array(data['logFC'])

		complex_pval_list = -np.log10(np.array(complex_data['pval.adj']))
		complex_fc_list = np.array(complex_data['logFC'])
		complex_label_list = list(complex_data.index)
		complex_colors = ['red' if item<=0.01 else 'darkblue' for item in list(complex_data['pval.adj'])]

		supp_figure5b._apply_plot_style()

		plt.clf()
		fig = plt.figure(figsize = (5,5))
		ax = fig.add_subplot(111)
		# background: every row, faint white markers
		ax.scatter(fc_list, pval_list, edgecolor = 'black', s = 30, color = 'white', alpha = 0.2)
		# foreground: the 'pyruvate' rows, coloured and labelled
		ax.scatter(complex_fc_list, complex_pval_list,
			edgecolor = 'black', s = 50, color = complex_colors)
		for label, x, y, colour in zip(complex_label_list, complex_fc_list, complex_pval_list, complex_colors):
			ax.annotate(label, xy = (x,y), color = colour)
		ax.set_xlabel('logFC (highfat/chow)')
		ax.set_ylabel('p.value[-log10]')
		ax.set_xlim(-1.25,1.25)
		ax.set_ylim(-0.01, 14)
		plt.savefig(folder + 'suppFig5b_pyruvateDH_volcano_hs_comparison.pdf', bbox_inches = 'tight', dpi = 400)

	@staticmethod
	def suppFigure5b_pyruvateDH_boxplot_logFC_hsComparison(**kwargs):
		"""Draw two box plots for the 'pyruvate' rows.

		1. 'suppFig5b_pyruvateDH_boxplot_highfat_chow.pdf': one pair of boxes
		   per protein with 'pval.adj' <= 0.01 (ordered by p-value), followed
		   by a grey pair holding the per-column medians of the two column
		   groups.
		2. 'suppFig5b_Abundances_pyruvateDH_boxplot_highfat_chow.pdf': the
		   per-column medians of 'complex_filtered_gygi3.tsv.gz' for the two
		   column groups.

		The two column groups are the 'quant_' columns selected by
		utilsFacade.filtering with 'H' and with 'S'.
		"""
		folder = kwargs.get('folder','PATH')

		filename = 'gygi3_complex_hs_merged_perGeneRun.tsv.gz'
		data = DataFrameAnalyzer.open_in_chunks(folder, filename)
		data = data.drop_duplicates()
		data = data.drop(['C9','C2','Psd3'], axis = 0)
		complex_data = data[data.complex_search.str.contains('pyruvate')]
		stoch_data = complex_data[complex_data.analysis_type=='stoch_highfat-chow']
		pval_dict = stoch_data['pval.adj'].to_dict()

		wpc_data = DataFrameAnalyzer.getFile(folder, 'data/complex_filtered_stoch_gygi3.tsv.gz')
		wpc_data = wpc_data[wpc_data.complex_search.str.contains('pyruvate')]
		quant_list = utilsFacade.filtering(list(wpc_data.columns),'quant_')
		h_columns = utilsFacade.filtering(quant_list, 'H')
		s_columns = utilsFacade.filtering(quant_list, 'S')
		# transposed so that every protein becomes a column
		h_data = wpc_data[h_columns].T
		s_data = wpc_data[s_columns].T

		shared_proteins = list(set(h_data.columns).intersection(set(s_data.columns)))
		# Same order as sorting (p-value, protein) tuples, but also fine when
		# the overlap is empty (the zip(*sorted(zip(...))) form raised there).
		protein_list = sorted(shared_proteins, key = lambda protein: (pval_dict[protein], protein))

		data_list = list()
		color_list = list()
		sig_protein_list = list()
		# positions[0] is the first box; every pair appends the position of
		# its second box plus the position of the first box of the next pair
		positions = [0.5]
		for protein in protein_list:
			if pval_dict[protein]<=0.01:
				data_list.append(utilsFacade.finite(list(h_data[protein])))
				data_list.append(utilsFacade.finite(list(s_data[protein])))
				color_list.extend(['purple','green'])
				sig_protein_list.extend([protein, protein])
				positions.append(positions[-1]+0.4)
				positions.append(positions[-1]+0.6)

		# final grey pair: median over the proteins for every quant column
		data_list.append(list(h_data.T.median()))
		data_list.append(list(s_data.T.median()))
		color_list.extend(['grey','grey'])
		sig_protein_list.extend(['pyruvate-highfat','pyruvate-chow'])
		positions.append(positions[-1]+0.5)

		supp_figure5b._apply_plot_style()

		plt.clf()
		fig = plt.figure(figsize = (5,5))
		ax = fig.add_subplot(111)
		bp = ax.boxplot(data_list,notch=0,sym="",vert=1,patch_artist=True,
			widths=[0.4]*len(data_list), positions = positions)
		supp_figure5b._decorate_boxes(ax, bp, data_list, color_list, positions)
		ax.set_title('pyruvate complex')
		ax.set_xticklabels(sig_protein_list, rotation = 90)
		ax.set_ylim(-1,1)
		ax.set_ylabel('normalized abundances')
		plt.savefig(folder + 'suppFig5b_pyruvateDH_boxplot_highfat_chow.pdf',bbox_inches = 'tight', dpi = 400)

		com_data = DataFrameAnalyzer.getFile(folder, 'complex_filtered_gygi3.tsv.gz')
		com_data = com_data[com_data.complex_search.str.contains('pyruvate')]
		quant_list = utilsFacade.filtering(list(com_data.columns), 'quant_')
		hlist = utilsFacade.filtering(quant_list, 'H')
		clist = utilsFacade.filtering(quant_list, 'S')
		h_median = utilsFacade.finite(list(com_data[hlist].median()))
		c_median = utilsFacade.finite(list(com_data[clist].median()))
		data_list = [h_median, c_median]

		supp_figure5b._apply_plot_style()

		plt.clf()
		fig = plt.figure(figsize = (5,5))
		ax = fig.add_subplot(111)
		bp = ax.boxplot(data_list,notch=0,sym="",vert=1,patch_artist=True,
			widths=[0.4]*len(data_list))
		# NOTE(review): the colours are the first two entries of the colour
		# list built for the previous figure (purple/green when at least one
		# protein passed the 0.01 cut-off, grey/grey otherwise) -- unchanged
		# from the original; confirm that this is intended.
		# Without explicit positions the boxes sit at x = 1, 2.
		supp_figure5b._decorate_boxes(ax, bp, data_list, color_list, [1, 2])
		ax.set_title('pyruvate complex')
		ax.set_ylabel('complex median abundances')
		plt.savefig(folder + 'suppFig5b_Abundances_pyruvateDH_boxplot_highfat_chow.pdf', bbox_inches = 'tight', dpi = 400)


if __name__ == "__main__":
	## EXECUTE SUPPFIGURE5
	# Usage: <script> <folder> <output_folder>
	# `sys` is imported here so the entry point does not depend on an import
	# elsewhere in the file.
	import sys

	# Fail with a readable message instead of a bare IndexError when an
	# argument is missing.
	if len(sys.argv) < 3:
		raise SystemExit('usage: ' + sys.argv[0] + ' <folder> <output_folder>')

	input_folder, figure_folder = sys.argv[1], sys.argv[2]
	supp_figure5a.execute(folder = input_folder, output_folder = figure_folder)
	supp_figure5b.execute(folder = input_folder, output_folder = figure_folder)
	

