# ------------------------------------------------------------
# Author: Natalie Romanov
# ------------------------------------------------------------

class figure6a:
	"""Figure 6a: horizontal boxplots of module-level r2 values
	(column 'r2.all.module') for the combined, diet and sex effect models.

	For each model three distributions are drawn: stoichiometry r2 of
	complexes plus non-excluded pathways, quantity r2 of complexes plus
	non-excluded pathways, and quantity r2 of the 'all' module group.
	"""

	@staticmethod
	def execute(**kwargs):
		"""Run the full figure-6a pipeline.

		Keyword args:
			folder: prefix of the input files (file names are appended to it
				by plain string concatenation, so it presumably ends with a
				path separator).
			output_folder: prefix the PDF is written to; falls back to
				``folder`` when not given.
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder', folder)

		print('FIGURE6A: get_considered_pathways')
		considered_pathways = figure6a.get_considered_pathways(folder = folder)

		print('FIGURE6A: plot')
		figure6a.plot(considered_pathways, folder = folder,
			output_folder = output_folder)

	@staticmethod
	def get_considered_pathways(**kwargs):
		"""Load 'corr_classification_pathways.tsv.gz' from ``folder`` and
		return the index labels picked by ``_select_pathways``."""
		folder = kwargs.get('folder','PATH')

		df = DataFrameAnalyzer.getFile(folder,'corr_classification_pathways.tsv.gz')
		return figure6a._select_pathways(df)

	@staticmethod
	def _select_pathways(df):
		"""Return the index labels of rows with ``pval1 > 0.1`` and
		``pval1_complex < 0.1``."""
		mask = (df.pval1 > 0.1) & (df.pval1_complex < 0.1)
		return list(df[mask].index)

	@staticmethod
	def _unique_r2(module_df, excluded = None):
		"""Return the distinct 'r2.all.module' values of ``module_df``.

		If ``excluded`` is given, rows whose 'complex.name' is in it are
		dropped first. The order of the returned list is arbitrary
		(it is built from a set).
		"""
		if excluded is not None:
			module_df = module_df[~module_df['complex.name'].isin(excluded)]
		return list(set(module_df['r2.all.module']))

	@staticmethod
	def _effect_r2_groups(effect_dict, excluded_pathways):
		"""Return ``[stoich, quant, quant_all]`` r2 lists for one effect model.

		``effect_dict[group][measure]`` is indexed with group in
		'all'/'complex'/'pathway' and measure in 'quant'/'stoichiometry';
		each entry must have an 'r2.all.module' column ('pathway' entries
		also need 'complex.name'). Complex and pathway values are
		concatenated, with pathways listed in ``excluded_pathways`` removed.
		"""
		unique_r2 = figure6a._unique_r2
		stoch = (unique_r2(effect_dict['complex']['stoichiometry']) +
			unique_r2(effect_dict['pathway']['stoichiometry'], excluded_pathways))
		quant = (unique_r2(effect_dict['complex']['quant']) +
			unique_r2(effect_dict['pathway']['quant'], excluded_pathways))
		quant_all = unique_r2(effect_dict['all']['quant'])
		return [stoch, quant, quant_all]

	@staticmethod
	def plot(considered_pathways, **kwargs):
		"""Draw the figure-6a boxplots and save them as
		'fig6a_combinedEffect_analysis.pdf'.

		Args:
			considered_pathways: 'complex.name' values to EXCLUDE from the
				pathway-level r2 distributions.
		Keyword args:
			folder: prefix of the input pickles.
			output_folder: prefix of the output PDF; defaults to ``folder``.
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder', folder)

		comb_dict = DataFrameAnalyzer.read_pickle(folder + 'combined_dictionary.pkl')
		sex_dict = DataFrameAnalyzer.read_pickle(folder + 'sex_effect_dictionary.pkl')
		diet_dict = DataFrameAnalyzer.read_pickle(folder + 'diet_effect_dictionary.pkl')

		# Box order (positions 1..9): comb, diet, sex; each as
		# [stoichiometry, quant, quant_all] -- matches label_list below.
		r2_list = list()
		for effect_dict in [comb_dict, diet_dict, sex_dict]:
			r2_list.extend(figure6a._effect_r2_groups(effect_dict, considered_pathways))

		label_list = ['scomb_complex','qcomb_complex','qcomb_all',
			'sdiet_complex','qdiet_complex','qdiet_all',
			'ssex_complex','qsex_complex','qsex_all']

		big_color_list = ['black','grey','grey','green',
			'lightgreen','lightgreen',
			'blue','lightblue','lightblue']

		sns.set(context='notebook', style='white', 
			palette='deep', font='Liberation Sans', font_scale=1, 
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = False

		plt.clf()
		fig = plt.figure(figsize = (10,10))
		gs = gridspec.GridSpec(9,9)

		ax = plt.subplot(gs[0:6,0:])
		ax.set_xlim(-0.01,0.4)
		bp = ax.boxplot(r2_list,notch=0,sym="",vert=0,
			patch_artist=True, widths=[0.8]*len(r2_list))
		plt.setp(bp['medians'], color="black")
		plt.setp(bp['whiskers'], color="black",linestyle="--",alpha=0.8)
		for patch, color in zip(bp['boxes'], big_color_list):
			patch.set_edgecolor("black")
			patch.set_alpha(0.6)
			# set_facecolor, not set_color: set_color would also overwrite
			# the black edge set just above.
			patch.set_facecolor(color)
		# boxplot places the i-th box at position i+1 by default, so the
		# ticks start at 1 for each label to sit on its own box.
		ax.set_yticks(list(range(1, len(label_list) + 1)))
		ax.set_yticklabels(label_list)
		for x in (0.1, 0.2, 0.3):
			ax.axvline(x, color = 'black', linestyle = '--')
		# The lower part of the grid is an empty, hidden axis.
		ax = plt.subplot(gs[6:9,0:])
		ax.axis('off')
		plt.savefig(output_folder + 'fig6a_combinedEffect_analysis.pdf',
			bbox_inches = 'tight', dpi = 400)
		# Release the figure so repeated runs do not accumulate open figures.
		plt.close(fig)

class figure6b:
	"""Figure 6b: stacked bar chart of effect sizes per column of the
	reclustered effect-size matrix (complexes, judging by the input file
	name). The top panel shows the sex effect (quant + stoichiometry); the
	bottom panel shows the diet effect, drawn downward as negated values.
	"""

	@staticmethod
	def execute(**kwargs):
		"""Run the figure-6b pipeline.

		Keyword args:
			folder: prefix of the input matrices.
			output_folder: prefix the PDF is written to; falls back to
				``folder`` when not given.
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder', folder)

		print('FIGURE6B: plot_effectSize_complexDistribution')
		figure6b.plot_effectSize_complexDistribution(folder = folder,
			output_folder = output_folder)

	@staticmethod
	def get_data(**kwargs):
		"""Load the mouse effect-size and p-value matrices from ``folder``.

		Returns:
			``{'mouse': (effect_size_df, pval_df)}``
		"""
		folder = kwargs.get('folder','PATH')

		mouse_df = DataFrameAnalyzer.open_in_chunks(folder,'suppFigure4a_underlyingData_gygi3_complex_effectSizeMatrix.tsv.gz')
		mouse_pval_df = DataFrameAnalyzer.open_in_chunks(folder,'suppFigure4a_underlyingData_gygi3_complex_pvalMatrix.tsv.gz')
		return {'mouse': (mouse_df, mouse_pval_df)}

	@staticmethod
	def _percent_row(mdf, row_name):
		"""Return row ``row_name`` of ``mdf`` as a plain list scaled by 100."""
		return [item*100 for item in list(mdf.T[row_name])]

	@staticmethod
	def _complex_tick_label(key):
		"""Drop the first ':'-separated field of a column key."""
		return ':'.join(key.split(':')[1:])

	@staticmethod
	def plot_effectSize_complexDistribution(**kwargs):
		"""Draw the figure-6b bar charts and save them as
		'fig6b_complex_effectSize_Distribution.pdf'.

		Keyword args:
			folder: prefix of the input matrices.
			output_folder: prefix of the output PDF; defaults to ``folder``.
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder', folder)

		# Pass folder through; without it get_data falls back to 'PATH'.
		data_dict = figure6b.get_data(folder = folder)
		mouse_df = data_dict['mouse'][0]
		mouse_df.columns = ['mouse_'+item for item in list(mouse_df.columns)]

		# NaN is swapped for a -100 sentinel around the reclustering call,
		# presumably because the clustering cannot handle missing values
		# (TODO confirm against utilsFacade), then restored.
		mouse_df = mouse_df.replace(np.nan,-100)
		mdf, _ = utilsFacade.recluster_matrix_only_rows(mouse_df)
		mdf = mdf.replace(-100, np.nan)
		# After transposing, rows are the effect types and columns the
		# matrix's original row keys.
		mdf = mdf.T

		sex_quant_list = figure6b._percent_row(mdf, 'mouse_sex_quant')
		sex_stoch_list = figure6b._percent_row(mdf, 'mouse_sex_stoch')
		sex_sum_list = np.array(sex_quant_list) + np.array(sex_stoch_list)
		diet_quant_list = figure6b._percent_row(mdf, 'mouse_diet_quant')
		diet_stoch_list = figure6b._percent_row(mdf, 'mouse_diet_stoch')
		key_list = list(mdf.columns)

		# presumably sorts every list by the first one (total sex effect),
		# descending -- TODO confirm against utilsFacade.sort_multiple_lists
		lists = [sex_sum_list, sex_quant_list,
			sex_stoch_list, diet_quant_list,
			diet_stoch_list, key_list]
		sorted_lists = utilsFacade.sort_multiple_lists(lists, reverse = True)
		sex_sum_list, sex_quant_list, sex_stoch_list, diet_quant_list, diet_stoch_list, key_list = sorted_lists

		n_keys = len(key_list)
		ind = list(range(n_keys))
		width = 1

		sns.set(context='notebook', style='white', 
			palette='deep', font='Liberation Sans', font_scale=1, 
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = False

		plt.clf()
		fig = plt.figure(figsize = (17,10))
		gs = gridspec.GridSpec(10,32)

		# Top panel: sex effect, stoichiometry stacked on top of quant.
		ax = plt.subplot(gs[0:4,0:])
		for y in (10, 20, 30, 40, 50):
			ax.axhline(y, color = 'k', linestyle = '--')
		ax.bar(ind, sex_quant_list, width, color='lightblue',
			edgecolor = 'white', align = 'center')
		ax.bar(ind, sex_stoch_list, width, color='darkblue',
			edgecolor = 'white', bottom = np.array(sex_quant_list),
			align = 'center')
		ax.set_xlim(-0.5, n_keys + 0.5)
		ax.set_xticklabels([])

		# Bottom panel: diet effect, drawn downward as negated values.
		ax = plt.subplot(gs[4:8,0:])
		ax.set_ylim(-60,0)
		for y in (-10, -20, -30, -40, -50):
			ax.axhline(y, color = 'k', linestyle = '--')
		neg_diet_quant = (-1)*np.array(diet_quant_list)
		ax.bar(ind, neg_diet_quant, width,
			color='lightgreen', edgecolor = 'white', align = 'center')
		ax.bar(ind, (-1)*np.array(diet_stoch_list), width,
			color='darkgreen', edgecolor = 'white',
			bottom = neg_diet_quant, align = 'center')
		ax.set_xlim(-0.5, n_keys + 0.5)
		# Bars are explicitly centred on 0..n-1, so ticks go on the same
		# positions regardless of the matplotlib version's bar default.
		ax.set_xticks(ind)
		ax.set_xticklabels([figure6b._complex_tick_label(item) for item in key_list],
			rotation = 90, fontsize = 5)
		plt.savefig(output_folder + 'fig6b_complex_effectSize_Distribution.pdf',
			bbox_inches = 'tight', dpi = 400)
		# Release the figure so repeated runs do not accumulate open figures.
		plt.close(fig)

if __name__ == "__main__":
	## EXECUTE FIGURE6
	# Command line: <script> INPUT_FOLDER OUTPUT_FOLDER
	input_folder, output_folder = sys.argv[1], sys.argv[2]
	for figure in (figure6a, figure6b):
		figure.execute(folder = input_folder, output_folder = output_folder)


