# ------------------------------------------------------------
# Author: Natalie Romanov
# ------------------------------------------------------------

class figure5a:
	"""Plotting routines for figure 5a.

	All methods are static and configured through keyword arguments:

	folder        -- location of the input files (default 'PATH')
	output_folder -- location the PDFs are written to; the plotting
	                 methods fall back to ``folder`` when it is not given

	Folders are concatenated directly with file names, so they must end
	with a path separator.

	The class relies on the module-level names ``sns``, ``plt``,
	``gridspec``, ``np``, ``numpy``, ``DataFrameAnalyzer`` and
	``utilsFacade``.
	"""

	@staticmethod
	def execute(**kwargs):
		"""Create both figure 5a panels.

		Keyword arguments:
		folder        -- input folder (default 'PATH')
		output_folder -- output folder (default 'PATH')
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		# Forward the folders explicitly; without this the plotting methods
		# silently fall back to their 'PATH' defaults.
		print('FIGURE5A: figure5a_complex_abundanceChange_demonstration')
		figure5a.figure5a_complex_abundanceChange_demonstration(
			folder = folder, output_folder = output_folder)

		print('FIGURE5A: figure5a_overview_stoichiometry')
		figure5a.figure5a_overview_stoichiometry(
			folder = folder, output_folder = output_folder)

	@staticmethod
	def get_data(**kwargs):
		"""Load the abundance dictionary and the complexes to display.

		Keyword arguments:
		folder -- folder containing 'figure5a_data_dictionary.pkl'

		Returns a tuple ``(data_dict, relevant_complexes)`` where
		``relevant_complexes`` is the fixed list of four complex keys
		that are looked up in ``data_dict``.
		"""
		folder = kwargs.get('folder','PATH')

		# NOTE(review): read_pickle presumably unpickles the file; only point
		# this at trusted data.
		data_dict = DataFrameAnalyzer.read_pickle(folder + 'figure5a_data_dictionary.pkl')
		relevant_complexes = ['HC663:COPII',
							  'HC2402:COPI',
							  'HC1479:retromer complex',
							  'SMC_AM000047:Cohesin complex']
		return data_dict, relevant_complexes

	@staticmethod
	def _draw_paired_boxplot(ax, dataList, positions):
		"""Draw one box-plot panel with jittered raw points on ``ax``."""
		bp = ax.boxplot(dataList,notch=0,sym="",vert=1,
						patch_artist=True, widths=[0.4]*len(dataList),
						positions=positions)
		plt.setp(bp['medians'], color="black")
		plt.setp(bp['whiskers'], color="black",linestyle="--",alpha=0.8)
		for i,patch in enumerate(bp['boxes']):
			# even boxes are green, odd boxes orange
			patch.set_facecolor("#7FBF7F" if i%2==0 else "orange")
			patch.set_edgecolor("black")
			patch.set_alpha(0.9)
			sub_group = dataList[i]
			# horizontal jitter so that individual points do not overlap
			x = numpy.random.normal(positions[i], 0.04, size=len(sub_group))
			ax.scatter(x,sub_group,color='white', alpha=0.5,edgecolor="black",s=5)
		ax.set_ylim(7,14)

	@staticmethod
	def figure5a_complex_abundanceChange_demonstration(**kwargs):
		"""Plot one two-box panel per relevant complex, side by side.

		For every complex returned by ``get_data`` the first two entries
		of its ``data_dict`` value are drawn as box plots.

		Keyword arguments:
		folder        -- input folder (default 'PATH')
		output_folder -- where 'fig5a_complex_abundances.pdf' is written
		                 (defaults to ``folder``)
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder', folder)

		data_dict, relevant_complexes = figure5a.get_data(folder = folder)

		sns.set(context='notebook', style='white', 
			palette='deep', font='Liberation Sans', font_scale=1, 
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = True

		positions = [1,1.5]
		plt.clf()
		fig = plt.figure(figsize = (10,5))
		for panel,complex_name in enumerate(relevant_complexes):
			ax = fig.add_subplot(1, len(relevant_complexes), panel+1)
			dataList = [data_dict[complex_name][0],
						data_dict[complex_name][1]]
			figure5a._draw_paired_boxplot(ax, dataList, positions)
		plt.savefig(output_folder + 'fig5a_complex_abundances.pdf',
					bbox_inches = 'tight',dpi = 400)

	@staticmethod
	def _bar_colors(tt_pval_list, cohen_list, tt1, tt2):
		"""Return one bar colour per complex.

		'purple' when the effect size is >= ``tt1``, 'lightgreen' when it
		is <= ``tt2``, and 'grey' otherwise or when the p-value is NaN.
		"""
		colors = list()
		for pval,cohen in zip(tt_pval_list, cohen_list):
			if str(pval)=='nan':
				colors.append('grey')
			elif cohen>=tt1:
				colors.append('purple')
			elif cohen<=tt2:
				colors.append('lightgreen')
			else:
				colors.append('grey')
		return colors

	@staticmethod
	def _plot_stoichiometry_overview(folder, output_folder, data_filename, figure_filename):
		"""Shared implementation of the two overview figures.

		Reads the columns 'complex', 'cohen', 'tt_pval', 'stable',
		'variable', 'tt1' and 'tt2' from ``data_filename`` and writes a
		two-panel bar chart to ``output_folder + figure_filename``.
		"""
		df = DataFrameAnalyzer.open_in_chunks(folder, data_filename)
		complex_list = list(df['complex'])
		cohen_list = list(df['cohen'])
		tt_pval_list = list(df['tt_pval'])
		stable_fractions = list(df['stable'])
		variable_fractions = list(df['variable'])
		# the thresholds are taken from the first row of their columns
		tt1 = list(df.tt1)[0]
		tt2 = list(df.tt2)[0]

		sns.set(context='notebook', style='white', 
			palette='deep', font='Liberation Sans', font_scale=1, 
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = False

		plt.clf()
		fig = plt.figure(figsize = (15,5))
		gs = gridspec.GridSpec(10,32)
		ind = np.arange(len(complex_list))
		width = 0.85

		# upper panel: effect size per complex with guide lines
		ax = plt.subplot(gs[0:5,0:])
		ax.axhline(0, color = 'k')
		for level in (0.5, 1.0, 1.5):
			ax.axhline(level,linestyle = '--', color = 'k')
			ax.axhline(-level,linestyle = '--', color = 'k')
		ax.axhline(tt1, color = 'red')
		ax.axhline(tt2, color = 'red')
		colors = figure5a._bar_colors(tt_pval_list, cohen_list, tt1, tt2)
		# Ticks are placed at i + 0.5 and the x-axis starts at -0.1, i.e. the
		# layout expects bars whose left edge sits at i. Request that
		# explicitly because bars are centre-aligned by default.
		ax.bar(ind, cohen_list, width, color=colors,
			edgecolor = 'white', align = 'edge')
		ax.set_ylim(-2,2)
		ax.set_xlim(-0.1, len(complex_list)+0.1)
		ax.set_xticklabels([])

		# lower panel: stacked stable / variable fractions
		ax = plt.subplot(gs[5:7,0:])
		ax.bar(ind, stable_fractions, width, color='darkblue',
			   edgecolor = 'white', align = 'edge')
		ax.bar(ind, variable_fractions, width, color='red',
			   bottom =np.array(stable_fractions),
			   edgecolor = 'white', align = 'edge')
		ax.set_ylim(0,100)
		ax.set_xlim(-0.1, len(complex_list)+0.1)
		plt.xticks(list(utilsFacade.frange(0.5,len(complex_list)+0.5,1)))
		ax.set_xticklabels(complex_list,rotation = 90, fontsize = 7)
		plt.savefig(output_folder + figure_filename,
					bbox_inches = 'tight', dpi = 400)

	@staticmethod
	def figure5a_overview_stoichiometry(**kwargs):
		"""Plot the overview from 'figure5a_bigPicture_underlying_data.tsv.gz'.

		Keyword arguments:
		folder        -- input folder (default 'PATH')
		output_folder -- where 'fig5a_abundance_stoichiometry_comparison.pdf'
		                 is written (defaults to ``folder``)
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder', folder)

		figure5a._plot_stoichiometry_overview(
			folder, output_folder,
			'figure5a_bigPicture_underlying_data.tsv.gz',
			'fig5a_abundance_stoichiometry_comparison.pdf')

	@staticmethod
	def figure5a_overview_stoichiometry_diet(**kwargs):
		"""Plot the overview from 'figure5a_bigPicture_diet_underlying_data.tsv.gz'.

		Keyword arguments:
		folder        -- input folder (default 'PATH')
		output_folder -- where
		                 'fig5a_abundance_diet_stoichiometry_comparison.pdf'
		                 is written (defaults to ``folder``)
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder', folder)

		figure5a._plot_stoichiometry_overview(
			folder, output_folder,
			'figure5a_bigPicture_diet_underlying_data.tsv.gz',
			'fig5a_abundance_diet_stoichiometry_comparison.pdf')

class figure5b:
	"""Plotting routines for figure 5b (male/female comparison per complex).

	All methods are static. Input files are read through
	``DataFrameAnalyzer.open_in_chunks(folder, filename)``; output folders
	are concatenated directly with file names, so they must end with a
	path separator.

	The class relies on the module-level names ``sns``, ``plt``, ``np``,
	``numpy``, ``DataFrameAnalyzer`` and ``utilsFacade``.
	"""

	@staticmethod
	def execute(**kwargs):
		"""Create all volcano plots and box plots of figure 5b.

		Keyword arguments:
		folder        -- input folder (default 'PATH')
		output_folder -- output folder (default 'PATH')
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		# Forward the folders explicitly; without this the plotting methods
		# silently fall back to their 'PATH' defaults.
		print('FIGURE5B: main_figure5b_volcanoPlots')
		figure5b.main_figure5b_volcanoPlots(folder = folder, output_folder = output_folder)

		print('FIGURE5B: main_figure5b_boxPlots')
		figure5b.main_figure5b_boxPlots(folder = folder, output_folder = output_folder)

	@staticmethod
	def get_data(name,**kwargs):
		"""Collect the volcano-plot values for one complex.

		Reads 'gygi3_complex_mf_merged_perGeneRun.tsv.gz', keeps the rows
		with analysis_type 'stoch_male-female', drops the index labels
		'C9', 'C2' and 'Psd3', and selects the rows whose
		``complex_search`` contains ``name``.

		Arguments:
		name   -- pattern passed to ``str.contains`` on ``complex_search``

		Keyword arguments:
		folder -- input folder (default 'PATH')

		Returns a dict with the keys 'pval' and 'fc' (all rows, p-values
		as -log10 of 'pval.adj'), 'complex_pval', 'complex_fc' and
		'complex_label' (selected rows) and 'complex_color' ('red' for
		pval.adj <= 0.01, otherwise 'darkblue').
		"""
		folder = kwargs.get('folder','PATH')

		filename = 'gygi3_complex_mf_merged_perGeneRun.tsv.gz'
		data = DataFrameAnalyzer.open_in_chunks(folder, filename)
		data = data.drop_duplicates()
		data = data[data.analysis_type=='stoch_male-female']
		data = data.drop(['C9','C2','Psd3'], axis = 0)
		complex_data = data[data.complex_search.str.contains(name)]

		pval_list = -np.log10(np.array(data['pval.adj']))
		fc_list = np.array(data['logFC'])
		complex_pval_list = -np.log10(np.array(complex_data['pval.adj']))
		complex_fc_list = np.array(complex_data['logFC'])
		complex_label_list = list(complex_data.index)
		complex_colors = ['red' if item<=0.01 else 'darkblue'
						  for item in complex_data['pval.adj']]

		return {'pval':pval_list,
				'fc': fc_list,
				'complex_pval': complex_pval_list,
				'complex_fc':complex_fc_list,
				'complex_label': complex_label_list,
				'complex_color':complex_colors}

	@staticmethod
	def main_figure5b_volcanoPlots(**kwargs):
		"""Draw the volcano plot for each of the four complexes.

		Keyword arguments:
		folder        -- input folder (default 'PATH')
		output_folder -- output folder (default 'PATH')
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		# (label used in the progress message, search pattern, output name)
		panels = (('COPI', 'HC2402:COPI', 'cop1'),
				  ('COPII', 'HC663:COPII', 'cop2'),
				  ('Cohesin', 'Cohesin', 'cohesin'),
				  ('retromer complex', 'retromer', 'retromer'))
		for label, complex_id, output_name in panels:
			print('volcano_plot_stoichiometry:' + label)
			figure5b.figure_volcanoPlot_complex_mfComparison(
				complex_id, output_name,
				folder = output_folder, data_folder = folder)

	@staticmethod
	def main_figure5b_boxPlots(**kwargs):
		"""Draw the male/female box plot for each of the four complexes.

		Keyword arguments:
		folder        -- input folder (default 'PATH')
		output_folder -- output folder (default 'PATH')
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		# (label used in the progress message, search pattern, output name)
		panels = (('COPI', 'HC2402:COPI', 'cop1'),
				  ('COPII', 'HC663:COPII', 'cop2'),
				  ('retromer complex', 'retromer', 'retromer'),
				  ('Cohesin', 'Cohesin', 'cohesin'))
		for label, complex_id, output_name in panels:
			print('boxplot:' + label)
			figure5b.figure_boxplot_complex_logFC_male_female(
				complex_id, output_name,
				folder = output_folder, data_folder = folder)

	@staticmethod
	def figure_volcanoPlot_complex_mfComparison(complex_id, output_name, **kwargs):
		"""Draw a volcano plot highlighting the members of one complex.

		Arguments:
		complex_id  -- pattern handed to ``get_data``
		output_name -- inserted into the output file name

		Keyword arguments:
		folder      -- OUTPUT folder of the PDF (default 'PATH')
		data_folder -- input folder handed to ``get_data`` (default 'PATH')
		"""
		folder = kwargs.get('folder','PATH')
		data_folder = kwargs.get('data_folder','PATH')

		data_dict = figure5b.get_data(complex_id, folder = data_folder)
		pval_list = data_dict['pval']
		fc_list = data_dict['fc']
		complex_pval_list = data_dict['complex_pval']
		complex_fc_list = data_dict['complex_fc']
		complex_label_list = data_dict['complex_label']
		complex_colors = data_dict['complex_color']

		sns.set(context='notebook', style='white', 
			palette='deep', font='Liberation Sans', font_scale=1, 
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = True

		plt.clf()
		fig = plt.figure(figsize = (5,5))
		ax = fig.add_subplot(111)
		# background: every row; foreground: members of the complex
		ax.scatter(fc_list, pval_list, edgecolor = 'black',
				   s = 30, color = 'white', alpha = 0.2)
		ax.scatter(complex_fc_list, complex_pval_list,
				   edgecolor = 'black', s = 50, color = complex_colors)
		for count, label in enumerate(complex_label_list):
			x,y = complex_fc_list[count], complex_pval_list[count]
			ax.annotate(label, xy = (x,y), color = complex_colors[count])
		ax.set_xlabel('logFC (male/female)')
		ax.set_ylabel('p.value[-log10]')
		ax.set_xlim(-1.25,1.25)
		ax.set_ylim(-0.01, 35)
		plt.savefig(folder + 'fig5b_' + output_name + '_volcano_mf_comparison.pdf',
					bbox_inches = 'tight', dpi = 400)

	@staticmethod
	def _summary_labels(name):
		"""Return the two tick labels of the per-complex median boxes.

		Known search patterns map to a short display name; any other
		pattern is used as is, so every box always receives a label.
		"""
		display_names = {'HC2402:COPI': 'COPI',
						 'HC663:COPII': 'COPII',
						 'MAPK': 'MAPK',
						 'Cohesin': 'Cohesin'}
		display = display_names.get(name, name)
		return [display + '-male', display + '-female']

	@staticmethod
	def get_data_boxplot(name,**kwargs):
		"""Collect the box-plot data for one complex.

		For every protein of the complex with pval.adj <= 0.01 (taken from
		the 'stoch_male-female' rows of
		'gygi3_complex_mf_merged_perGeneRun.tsv.gz') the values of the
		male and female 'quant_' columns of
		'complex_filtered_stoch_gygi3.tsv.gz' are added as two boxes,
		ordered by p-value. Two final boxes hold the per-column medians
		over all proteins of the complex.

		Arguments:
		name   -- pattern passed to ``str.contains`` on ``complex_search``

		Keyword arguments:
		folder -- input folder (default 'PATH')

		Returns a dict with the keys 'data', 'color', 'sig_protein'
		(tick labels) and 'positions', all with one entry per box.
		"""
		folder = kwargs.get('folder','PATH')

		filename = 'gygi3_complex_mf_merged_perGeneRun.tsv.gz'
		data = DataFrameAnalyzer.open_in_chunks(folder, filename)
		data = data.drop_duplicates()
		data = data.drop(['C9','C2','Psd3'], axis = 0)
		complex_data = data[data.complex_search.str.contains(name)]
		stoch_maleFemale_data = complex_data[complex_data.analysis_type=='stoch_male-female']
		pval_dict = stoch_maleFemale_data['pval.adj'].to_dict()

		wpc_data = DataFrameAnalyzer.open_in_chunks(folder, 'complex_filtered_stoch_gygi3.tsv.gz')
		wpc_data = wpc_data[wpc_data.complex_search.str.contains(name)]
		quant_list = utilsFacade.filtering(list(wpc_data.columns),'quant_')
		male_list = utilsFacade.filtering(quant_list, 'M')
		female_list = utilsFacade.filtering(quant_list, 'F')
		# transpose: one column per protein, one row per quant column
		male_data = wpc_data[male_list].T
		female_data = wpc_data[female_list].T

		shared_proteins = set(male_data.columns).intersection(set(female_data.columns))
		# Sort by (p-value, protein); also works when no protein is shared.
		protein_list = sorted(shared_proteins,
							  key = lambda protein: (pval_dict[protein], protein))
		data_list = list()
		color_list = list()
		sig_protein_list = list()
		positions = [0.5]
		for protein in protein_list:
			if pval_dict[protein]<=0.01:
				# Extract both groups before touching the result lists so a
				# failure can never leave them half-updated.
				try:
					male_values = utilsFacade.finite(list(male_data[protein]))
					female_values = utilsFacade.finite(list(female_data[protein]))
				except Exception:
					# NOTE(review): presumably reached when the protein label
					# is duplicated and the selection is a DataFrame; the
					# first of the duplicated columns is used then.
					male_values = utilsFacade.finite(list(male_data[protein].T.iloc[0].T))
					female_values = utilsFacade.finite(list(female_data[protein].T.iloc[0].T))
				data_list.append(male_values)
				data_list.append(female_values)
				color_list.append('purple')
				color_list.append('green')
				sig_protein_list.append(protein)
				sig_protein_list.append(protein)
				positions.append(positions[-1]+0.4)
				positions.append(positions[-1]+0.6)

		male_median_data = male_data.T.median().T
		female_median_data = female_data.T.median().T
		data_list.append(list(male_median_data))
		data_list.append(list(female_median_data))
		color_list.append('grey')
		color_list.append('grey')
		# Always label the two median boxes; otherwise there are fewer tick
		# labels than boxes for patterns without a dedicated display name.
		sig_protein_list.extend(figure5b._summary_labels(name))
		positions.append(positions[-1]+0.5)

		return {'data':data_list,
				'color': color_list,
				'sig_protein':sig_protein_list,
				'positions':positions}

	@staticmethod
	def figure_boxplot_complex_logFC_male_female(complex_id, output_name, **kwargs):
		"""Draw the male/female box plot of one complex.

		Arguments:
		complex_id  -- pattern handed to ``get_data_boxplot``
		output_name -- inserted into the output file name

		Keyword arguments:
		folder      -- OUTPUT folder of the PDF (default 'PATH')
		data_folder -- input folder handed to ``get_data_boxplot``
		               (default 'PATH')
		"""
		folder = kwargs.get('folder','PATH')
		data_folder = kwargs.get('data_folder','PATH')

		data_dict = figure5b.get_data_boxplot(complex_id, folder = data_folder)
		data_list = data_dict['data']
		color_list = data_dict['color']
		sig_protein_list = data_dict['sig_protein']
		positions = data_dict['positions']

		sns.set(context='notebook', style='white', 
			palette='deep', font='Liberation Sans', font_scale=1, 
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = True

		plt.clf()
		fig = plt.figure(figsize = (5,5))
		ax = fig.add_subplot(111)
		bp = ax.boxplot(data_list,notch=0,sym="",vert=1,patch_artist=True,
						widths=[0.4]*len(data_list), positions = positions)
		plt.setp(bp['medians'], color="black")
		plt.setp(bp['whiskers'], color="black",linestyle="--",alpha=0.8)
		for i,patch in enumerate(bp['boxes']):
			patch.set_edgecolor("black")
			patch.set_alpha(0.6)
			# NOTE(review): set_color sets face AND edge colour, so it
			# overrides the black edge set above; kept as is so the figure
			# looks the same as before.
			patch.set_color(color_list[i])
			sub_group = data_list[i]
			# horizontal jitter so that individual points do not overlap
			x = numpy.random.normal(positions[i], 0.04, size=len(sub_group))
			ax.scatter(x,sub_group,color='white', alpha=0.5,edgecolor="black",s=5)
		ax.set_xticklabels(sig_protein_list, rotation = 90)
		ax.set_ylim(-1,1)
		ax.set_ylabel('normalized abundances')
		plt.savefig(folder + 'fig5b_' + output_name + '_boxplot_males_females.pdf',
					bbox_inches = 'tight', dpi = 400)

if __name__ == "__main__":
	## EXECUTE FIGURE5
	# Usage: <script> <folder> <output_folder>
	# Both arguments are handed to figure5a.execute and figure5b.execute.
	import sys  # local import keeps this entry point self-contained

	# Fail with a readable message instead of an IndexError on sys.argv.
	if len(sys.argv) < 3:
		sys.exit('usage: ' + sys.argv[0] + ' <folder> <output_folder>')

	figure5a.execute(folder = sys.argv[1], output_folder = sys.argv[2])
	figure5b.execute(folder = sys.argv[1], output_folder = sys.argv[2])



