# ------------------------------------------------------------
# Author: Natalie Romanov
# ------------------------------------------------------------

class step14_preparation:
	"""Preparation half of step 14: build and pickle the effect dictionaries.

	Each ``load_*_data`` method reads ``protein_variation.tsv.gz`` plus the
	module-specific multivariate result tables for one covariate model
	('sex', 'diet', or the combined 'sex_diet' model), nests them in a
	dictionary and pickles it into ``output_folder``.  Folder arguments are
	concatenated directly with file names, so they must end with a path
	separator.
	"""

	@staticmethod
	def execute(**kwargs):
		"""Run every preparation stage in order.

		Keyword Args:
			folder: directory holding the input tables (default ``'PATH'``).
			output_folder: directory receiving the pickles (default ``'PATH'``).
		"""
		# Must be ``**kwargs``: the former ``*kwargs`` bound a tuple, so the
		# ``.get`` calls failed and keyword calls raised TypeError.
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		# 5 = minimum number of subunits a module needs to be kept.
		print('load_combined_data')
		step14_preparation.load_combined_data(5, folder = folder, output_folder = output_folder)
		print('load_sex_data')
		step14_preparation.load_sex_data(5, folder = folder, output_folder = output_folder)
		print('load_diet_data')
		step14_preparation.load_diet_data(5, folder = folder, output_folder = output_folder)

		print('export_considered_pathways')
		step14_preparation.export_considered_pathways(folder = folder)

	@staticmethod
	def load_multivariate_module_specific_data(dataset, module_type, datatype, covariate, **kwargs):
		"""Load a single-covariate module-specific result table.

		Reads ``<dataset>_<module_type>_<datatype>_<covariate>`` +
		``_data_all_multi_covariate_moduleSpecific_empiricalFDR.tsv.gz`` from
		``folder`` (default ``'PATH'``) and keeps rows whose ``n.subunits``
		is at least ``n_subunit`` (default 0).
		"""
		n_subunit = kwargs.get('n_subunit',0)
		folder = kwargs.get('folder','PATH')

		joint_name = '_'.join([dataset, module_type, datatype, covariate])
		filename = joint_name + '_data_all_multi_covariate_moduleSpecific_empiricalFDR.tsv.gz'
		data = DataFrameAnalyzer.open_in_chunks(folder, filename)
		return data[data['n.subunits'] >= n_subunit]

	@staticmethod
	def load_multivariate_combined_module_specific_data(dataset, module_type, datatype, **kwargs):
		"""Load a combined-effect (sex + diet) module-specific result table.

		Reads ``<dataset>_<module_type>_<datatype>`` +
		``_data_all_multi_covariate_moduleSpecific_COMBINEDEFFECT.tsv.gz`` from
		``folder`` (default ``'PATH'``) and keeps rows whose ``n.subunits``
		is at least ``n_subunit`` (default 0).
		"""
		n_subunit = kwargs.get('n_subunit',0)
		folder = kwargs.get('folder','PATH')

		joint_name = '_'.join([dataset, module_type, datatype])
		filename = joint_name + '_data_all_multi_covariate_moduleSpecific_COMBINEDEFFECT.tsv.gz'
		data = DataFrameAnalyzer.open_in_chunks(folder, filename)
		return data[data['n.subunits'] >= n_subunit]

	@staticmethod
	def _load_protein_variation(covariate, folder):
		"""Return the rows of ``protein_variation.tsv.gz`` whose ``covariates`` equals *covariate*."""
		data = DataFrameAnalyzer.open_in_chunks(folder, 'protein_variation.tsv.gz')
		return data[data.covariates == covariate]

	@staticmethod
	def _assemble_effect_dict(load_module, all_quant):
		"""Nest module-level tables and the ``all`` table into one dictionary.

		``load_module(module_type, datatype)`` is called once for each pair
		of 'complex'/'corum' and 'quant'/'stoichiometry'.  The result has the
		keys 'complex', 'corum' and 'all' (the latter holding only 'quant').
		"""
		effect_dict = {}
		for module_type in ('complex', 'corum'):
			effect_dict[module_type] = {datatype: load_module(module_type, datatype) for datatype in ('quant', 'stoichiometry')}
		effect_dict['all'] = {'quant': all_quant}
		return effect_dict

	@staticmethod
	def _load_covariate_data(su, covariate, pickle_name, **kwargs):
		"""Build, pickle and return the effect dictionary for one covariate."""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		all_quant = step14_preparation._load_protein_variation(covariate, folder)

		def load_module(module_type, datatype):
			# Pass ``folder`` explicitly; previously the tables were read from
			# the 'PATH' placeholder.
			return step14_preparation.load_multivariate_module_specific_data('gygi3',
				module_type, datatype, covariate, n_subunit = su, folder = folder)

		effect_dict = step14_preparation._assemble_effect_dict(load_module, all_quant)
		DataFrameAnalyzer.to_pickle(effect_dict, output_folder + pickle_name)
		return effect_dict

	@staticmethod
	def load_combined_data(su, **kwargs):
		"""Build and pickle ``combined_dictionary.pkl`` for the 'sex_diet' model.

		Args:
			su: minimum ``n.subunits`` a module must have to be kept.
		Keyword Args:
			folder: input directory (default ``'PATH'``).
			output_folder: pickle destination (default ``'PATH'``).
		Returns:
			dict with keys 'complex', 'corum' ({'quant', 'stoichiometry'}) and
			'all' ({'quant'}).
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		comb_all_quant = step14_preparation._load_protein_variation('sex_diet', folder)

		def load_module(module_type, datatype):
			# This class's own loader; the former ``figure6a`` reference is
			# not defined in this file.
			return step14_preparation.load_multivariate_combined_module_specific_data('gygi3',
				module_type, datatype, n_subunit = su, folder = folder)

		comb_dict = step14_preparation._assemble_effect_dict(load_module, comb_all_quant)
		DataFrameAnalyzer.to_pickle(comb_dict, output_folder + 'combined_dictionary.pkl')
		return comb_dict

	@staticmethod
	def load_sex_data(su, **kwargs):
		"""Build and pickle ``sex_effect_dictionary.pkl`` (same layout as ``load_combined_data``)."""
		return step14_preparation._load_covariate_data(su, 'sex', 'sex_effect_dictionary.pkl', **kwargs)

	@staticmethod
	def load_diet_data(su, **kwargs):
		"""Build and pickle ``diet_effect_dictionary.pkl`` (same layout as ``load_combined_data``)."""
		return step14_preparation._load_covariate_data(su, 'diet', 'diet_effect_dictionary.pkl', **kwargs)

	@staticmethod
	def _select_considered_pathways(df):
		"""Return index labels of rows with ``pval1 > 0.1`` and ``pval1_complex < 0.1``."""
		sub = df[(df.pval1 > 0.1) & (df.pval1_complex < 0.1)]
		return list(sub.index)

	@staticmethod
	def export_considered_pathways(**kwargs):
		"""Select considered pathways and re-export the table gzip-compressed.

		A pathway is considered when it is not different from random
		(``pval1 > 0.1``) but is different from complexes
		(``pval1_complex < 0.1``).

		Returns:
			list of index labels of the considered pathways.
		"""
		folder = kwargs.get('folder','PATH')

		df = DataFrameAnalyzer.getFile(folder,'corr_classification_pathways.tsv')
		considered_pathways = step14_preparation._select_considered_pathways(df)

		# The full (unfiltered) table is written; the figure step re-applies
		# the same filter when reading it back.
		df.to_csv(folder + 'corr_classification_pathways.tsv.gz', sep = '\t', compression = 'gzip')
		return considered_pathways

class step14_figure:
	"""Figure half of step 14: horizontal boxplot of module R^2 per covariate model.

	Reads the pickled effect dictionaries produced by the preparation step
	and draws three boxes (stoichiometry, quant, all) for each of the
	combined, diet and sex models.  Folder arguments are concatenated
	directly with file names, so they must end with a path separator.
	"""

	@staticmethod
	def execute(**kwargs):
		"""Compute the considered pathways and draw the figure.

		Keyword Args:
			folder: directory holding corr_classification_pathways.tsv.gz
				(default ``'PATH'``).
			output_folder: directory holding the pickled dictionaries and
				receiving the PDF; defaults to ``folder``.
		"""
		folder = kwargs.get('folder','PATH')
		# The preparation step writes its pickles to output_folder, so the
		# plot must read them from there. Falling back to ``folder`` keeps
		# callers that pass only ``folder`` working as before.
		output_folder = kwargs.get('output_folder', folder)

		print('FIGURE6A: get_considered_pathways')
		considered_pathways = step14_figure.get_considered_pathways(folder = folder)

		print('FIGURE6A: plot')
		step14_figure.plot(considered_pathways, folder = folder, output_folder = output_folder)

	@staticmethod
	def _select_considered_pathways(df):
		"""Return index labels of rows with ``pval1 > 0.1`` and ``pval1_complex < 0.1``."""
		sub = df[(df.pval1 > 0.1) & (df.pval1_complex < 0.1)]
		return list(sub.index)

	@staticmethod
	def get_considered_pathways(**kwargs):
		"""Return the considered pathway names from corr_classification_pathways.tsv.gz."""
		folder = kwargs.get('folder','PATH')

		df = DataFrameAnalyzer.getFile(folder,'corr_classification_pathways.tsv.gz')
		return step14_figure._select_considered_pathways(df)

	@staticmethod
	def _unique_r2(frame, exclude=None):
		"""Return the distinct ``r2.all.module`` values of *frame* as a list.

		If *exclude* is given, rows whose ``complex.name`` is in it are dropped first.
		"""
		if exclude is not None:
			frame = frame[~frame['complex.name'].isin(exclude)]
		return list(set(frame['r2.all.module']))

	@staticmethod
	def _effect_r2_lists(effect_dict, considered_pathways):
		"""Return [stoichiometry, quant, all-quant] R^2 lists for one effect dictionary.

		The 'complex' and 'pathway' values are pooled (complex first).
		Pathways listed in *considered_pathways* are excluded from the
		'pathway' part.

		Raises:
			KeyError: if the dictionary has no 'pathway' entry.
		"""
		if 'pathway' not in effect_dict:
			# NOTE(review): the loaders in step14_preparation store only
			# 'complex', 'corum' and 'all'; confirm where 'pathway' comes from.
			raise KeyError("effect dictionary has no 'pathway' entry (keys: %s)" % sorted(effect_dict))
		unique_r2 = step14_figure._unique_r2
		complex_part = effect_dict['complex']
		pathway_part = effect_dict['pathway']
		stoch = unique_r2(complex_part['stoichiometry']) + unique_r2(pathway_part['stoichiometry'], exclude = considered_pathways)
		quant = unique_r2(complex_part['quant']) + unique_r2(pathway_part['quant'], exclude = considered_pathways)
		return [stoch, quant, unique_r2(effect_dict['all']['quant'])]

	@staticmethod
	def plot(considered_pathways, **kwargs):
		"""Draw and save fig6a_combinedEffect_analysis.pdf.

		Args:
			considered_pathways: pathway names excluded from the 'pathway'
				tables before their R^2 values are pooled with the complexes.
		Keyword Args:
			folder: default location (default ``'PATH'``).
			output_folder: where the pickled dictionaries are read and the
				PDF is written; defaults to ``folder``.
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder', folder)

		comb_dict = DataFrameAnalyzer.read_pickle(output_folder + 'combined_dictionary.pkl')
		sex_dict = DataFrameAnalyzer.read_pickle(output_folder + 'sex_effect_dictionary.pkl')
		diet_dict = DataFrameAnalyzer.read_pickle(output_folder + 'diet_effect_dictionary.pkl')

		# Three boxes per model, bottom to top: combined, diet, sex.
		r2_list = []
		for effect_dict in (comb_dict, diet_dict, sex_dict):
			r2_list.extend(step14_figure._effect_r2_lists(effect_dict, considered_pathways))

		label_list = ['scomb_complex','qcomb_complex','qcomb_all',
					  'sdiet_complex','qdiet_complex','qdiet_all',
					  'ssex_complex','qsex_complex','qsex_all']
		big_color_list = ['black','grey','grey','green','lightgreen','lightgreen',
						  'blue','lightblue','lightblue']

		sns.set(context='notebook', style='white', 
			palette='deep', font='Liberation Sans', font_scale=1, 
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = False

		fig = plt.figure(figsize = (10,10))
		gs = gridspec.GridSpec(9,9)

		ax = plt.subplot(gs[0:6,0:])
		ax.set_xlim(-0.01,0.4)
		# boxplot puts boxes at 1..N by default. Using the same positions for
		# the ticks makes each label sit on its own box; the former 0..N-1
		# ticks shifted every label by one.
		positions = list(range(1, len(r2_list) + 1))
		bp = ax.boxplot(r2_list, notch=0, sym="", vert=0, patch_artist=True,
						positions=positions, widths=[0.8]*len(r2_list))
		plt.setp(bp['medians'], color="black")
		plt.setp(bp['whiskers'], color="black", linestyle="--", alpha=0.8)
		for patch, color in zip(bp['boxes'], big_color_list):
			# set_color would also overwrite the black edge; only set the fill.
			patch.set_facecolor(color)
			patch.set_edgecolor("black")
			patch.set_alpha(0.6)
		ax.set_yticks(positions)
		ax.set_yticklabels(label_list)

		# One-way ANOVA p-value per model: pooled stoichiometry + quant R^2
		# versus the 'all' R^2 values.
		for start in range(0, len(r2_list), 3):
			print(scipy.stats.f_oneway(r2_list[start] + r2_list[start + 1], r2_list[start + 2])[1])

		for x in (0.1, 0.2, 0.3):
			ax.axvline(x, color = 'black', linestyle = '--')

		# Empty lower panel, reserving the bottom third of the grid.
		spacer = plt.subplot(gs[6:9,0:])
		spacer.axis('off')
		plt.savefig(output_folder + 'fig6a_combinedEffect_analysis.pdf', bbox_inches = 'tight', dpi = 400)
		# Release the figure so repeated calls do not accumulate open figures.
		plt.close(fig)

if __name__ == "__main__":
	# sys is not imported at the top of this file; import it where it is used.
	import sys

	if len(sys.argv) < 3:
		sys.exit('usage: %s INPUT_FOLDER/ OUTPUT_FOLDER/' % sys.argv[0])
	input_folder, output_folder = sys.argv[1], sys.argv[2]

	## EXECUTE STEP14
	step14_preparation.execute(folder = input_folder, output_folder = output_folder)
	step14_figure.execute(folder = input_folder, output_folder = output_folder)
	