# ------------------------------------------------------------
# Author: Natalie Romanov
# ------------------------------------------------------------

class step15_preparation:
	"""Prepare the per-complex matrices written as the 'suppFigure4a' underlying data.

	Four gygi3 'complex' tables are loaded (sex/diet covariate x
	quant/stoichiometry datatype). Their 'r2.all.module' and
	'empirical.FDR.module' columns are collected into two matrices indexed
	by complex name, which are written as tab-separated files.
	"""

	@staticmethod
	def execute(**kwargs):
		"""Run the preparation step.

		Keyword arguments:
		folder -- location handed to the table loader (default 'PATH')
		output_folder -- string prefix the two result files are written under (default 'PATH')
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('get_complex_underlying_data')
		step15_preparation.get_complex_underlying_data(folder = folder, output_folder = output_folder)

	@staticmethod
	def load_multivariate_subunit_specific_data(dataset, module_type, datatype, covariate, **kwargs):
		"""Load one subunit-specific multi-covariate table.

		The file name is the four positional arguments joined with '_' and
		followed by '_data_all_multi_covariate_subunitSpecific_empiricalFDR.tsv.gz'.

		Keyword arguments:
		folder -- location passed to DataFrameAnalyzer.open_in_chunks (default 'PATH')

		Returns whatever DataFrameAnalyzer.open_in_chunks returns for that file.
		"""
		folder = kwargs.get('folder','PATH')

		joint_name = '_'.join([dataset, module_type, datatype, covariate])
		filename = joint_name + '_data_all_multi_covariate_subunitSpecific_empiricalFDR.tsv.gz'
		return DataFrameAnalyzer.open_in_chunks(folder, filename)

	@staticmethod
	def _module_level_table(data):
		"""Reduce a table to its distinct (complex, r2, FDR) rows, indexed by complex name.

		The complex name takes part in the duplicate check on purpose: two
		different complexes that happen to share the same pair of values
		(for instance both NaN) must both be kept.
		"""
		columns = ['complex.name', 'r2.all.module', 'empirical.FDR.module']
		return data[columns].drop_duplicates().set_index('complex.name')

	@staticmethod
	def _build_effect_and_pval_matrices(sex_data, sex_stochdata, diet_data, diet_stochdata):
		"""Assemble the 'r2.all.module' and 'empirical.FDR.module' matrices.

		Rows are the complexes present in sex_data; a complex missing from
		one of the other three tables gets NaN in the corresponding column.

		Returns (df, pval_df), both with columns
		['sex_quant', 'sex_stoch', 'diet_quant', 'diet_stoch'].
		"""
		columns = ['sex_quant','sex_stoch','diet_quant','diet_stoch']
		# Same order as `columns`.
		sources = (sex_data, sex_stochdata, diet_data, diet_stochdata)
		tables = [step15_preparation._module_level_table(data) for data in sources]

		# If a complex still has several rows, to_dict keeps the last one.
		effect_dicts = [table['r2.all.module'].to_dict() for table in tables]
		pval_dicts = [table['empirical.FDR.module'].to_dict() for table in tables]

		# The sex/quant table defines which complexes are reported.
		key_list = list(effect_dicts[0])
		effect_rows = [[lookup.get(key, np.nan) for lookup in effect_dicts] for key in key_list]
		pval_rows = [[lookup.get(key, np.nan) for lookup in pval_dicts] for key in key_list]

		df = pd.DataFrame(effect_rows, index = key_list, columns = columns)
		pval_df = pd.DataFrame(pval_rows, index = key_list, columns = columns)
		return df, pval_df

	@staticmethod
	def get_complex_underlying_data(**kwargs):
		"""Load the four gygi3 complex tables and write the two result matrices.

		The sex/quant table is restricted to rows with 'n.subunits' >= 5;
		rows whose 'complex.name' contains 'tRNA splicing' are kept
		regardless of subunit count.

		Keyword arguments:
		folder -- location handed to the table loader (default 'PATH')
		output_folder -- string prefix for the output files (default 'PATH');
			it is concatenated as-is, so it needs its own trailing separator

		Returns (df, pval_df) as written to disk.
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		module = 'complex'
		load = step15_preparation.load_multivariate_subunit_specific_data
		sex_data = load('gygi3', module, 'quant', 'sex', folder = folder)
		sex_stochdata = load('gygi3', module, 'stoichiometry', 'sex', folder = folder)
		diet_data = load('gygi3', module, 'quant', 'diet', folder = folder)
		diet_stochdata = load('gygi3', module, 'stoichiometry', 'diet', folder = folder)

		# Select the exempted rows before filtering; rows that also pass the
		# filter appear twice after concat and are de-duplicated later.
		trna_rows = sex_data[sex_data['complex.name'].str.contains('tRNA splicing')]
		sex_data = sex_data[sex_data['n.subunits']>=5]
		sex_data = pd.concat([sex_data, trna_rows])

		df, pval_df = step15_preparation._build_effect_and_pval_matrices(
			sex_data, sex_stochdata, diet_data, diet_stochdata)

		# NOTE(review): these are written as '.tsv' under output_folder, while
		# step15_figure.get_data reads '.tsv.gz' names from folder -- confirm
		# how the files get from one to the other.
		df.to_csv(output_folder + 'suppFigure4a_underlyingData_gygi3_complex_effectSizeMatrix.tsv', sep = '\t')
		pval_df.to_csv(output_folder + 'suppFigure4a_underlyingData_gygi3_complex_pvalMatrix.tsv', sep = '\t')
		return df, pval_df

class step15_figure:
	"""Draw the stacked bar figure saved as 'fig6b_complex_effectSize_Distribution.pdf'."""

	@staticmethod
	def execute(**kwargs):
		"""Run the figure step.

		Keyword arguments:
		folder -- prefix used both to read the input matrices and to write
			the PDF (default 'PATH')
		output_folder -- accepted for symmetry with step15_preparation.execute;
			not used here
		"""
		folder = kwargs.get('folder','PATH')

		print('FIGURE6B: plot_effectSize_complexDistribution')
		step15_figure.plot_effectSize_complexDistribution(folder = folder)

	@staticmethod
	def get_data(**kwargs):
		"""Load the two matrices for the figure.

		Keyword arguments:
		folder -- location passed to DataFrameAnalyzer.open_in_chunks (default 'PATH')

		Returns {'mouse': (mouse_df, mouse_pval_df)}.
		"""
		folder = kwargs.get('folder','PATH')

		# NOTE(review): these '.tsv.gz' names differ from the '.tsv' files
		# step15_preparation writes -- confirm the files are compressed in between.
		mouse_df = DataFrameAnalyzer.open_in_chunks(folder,'suppFigure4a_underlyingData_gygi3_complex_effectSizeMatrix.tsv.gz')
		mouse_pval_df = DataFrameAnalyzer.open_in_chunks(folder,'suppFigure4a_underlyingData_gygi3_complex_pvalMatrix.tsv.gz')
		return {'mouse': (mouse_df, mouse_pval_df)}

	@staticmethod
	def _rank_row_ignoring_nan(values):
		"""Rank the non-missing entries of `values`, leaving missing ones as NaN.

		Ranks come from rankdata applied to the non-missing entries only and
		are put back at the positions those entries came from.
		"""
		values = list(values)
		present = [value for value in values if not pd.isnull(value)]
		# rankdata returns ranks in input order, so they can be consumed
		# positionally while walking the original sequence.
		ranks = iter(rankdata(present))
		return [np.nan if pd.isnull(value) else next(ranks) for value in values]

	@staticmethod
	def plot_effectSize_complexDistribution(**kwargs):
		"""Plot sex (top panel) and diet (bottom panel, mirrored) values per complex.

		Each panel stacks the 'quant' and 'stoch' columns, multiplied by 100,
		as bars. Complexes are ordered by utilsFacade.sort_multiple_lists
		with the summed sex values listed first. The figure is saved to
		folder + 'fig6b_complex_effectSize_Distribution.pdf'.

		Keyword arguments:
		folder -- prefix for both the input matrices and the output PDF (default 'PATH')
		"""
		folder = kwargs.get('folder','PATH')

		# Forward the folder so the matrices are read from where the caller asked.
		data_dict = step15_figure.get_data(folder = folder)
		mouse_df = data_dict['mouse'][0]
		mouse_df.columns = ['mouse_'+item for item in list(mouse_df.columns)]

		# -100 stands in for missing values while the rows are reclustered and
		# is turned back into NaN afterwards.
		mouse_df = mouse_df.replace(np.nan,-100)
		mdf, proteinList = utilsFacade.recluster_matrix_only_rows(mouse_df)
		mdf = mdf.replace(-100, np.nan)
		mdf = mdf.T

		ranked_rows = [step15_figure._rank_row_ignoring_nan(row) for _, row in mdf.iterrows()]
		ranked_df = pd.DataFrame(ranked_rows, index = mdf.index, columns = mdf.columns)

		by_complex = mdf.T
		sex_quant_list = [item*100 for item in list(by_complex['mouse_sex_quant'])]
		sex_stoch_list = [item*100 for item in list(by_complex['mouse_sex_stoch'])]
		sex_sum_list = np.array(sex_quant_list) + np.array(sex_stoch_list)
		diet_quant_list = [item*100 for item in list(by_complex['mouse_diet_quant'])]
		diet_stoch_list = [item*100 for item in list(by_complex['mouse_diet_stoch'])]
		key_list = list(mdf.columns)
		lists = [sex_sum_list, sex_quant_list, sex_stoch_list, diet_quant_list, diet_stoch_list, key_list]
		sorted_lists = utilsFacade.sort_multiple_lists(lists, reverse = True)
		sex_sum_list, sex_quant_list, sex_stoch_list, diet_quant_list, diet_stoch_list, key_list = sorted_lists
		ranked_df = ranked_df[key_list]

		sns.set(context='notebook', style='white',
			palette='deep', font='Liberation Sans', font_scale=1,
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = False

		plt.clf()
		plt.figure(figsize = (17,10))
		gs = gridspec.GridSpec(10,32)
		width = 1
		# range works on both Python 2 and 3.
		ind = list(range(len(sex_quant_list)))

		# Top panel: sex, quant bars with stoch stacked on top.
		ax = plt.subplot(gs[0:4,0:])
		for level in (10, 20, 30, 40, 50):
			ax.axhline(level,color = 'k', linestyle = '--')
		ax.bar(ind, sex_quant_list, width, color='lightblue', edgecolor = 'white')
		ax.bar(ind, sex_stoch_list, width, color='darkblue',
			edgecolor = 'white', bottom = np.array(sex_quant_list))
		ax.set_xlim(-0.5,len(sex_quant_list)+0.5)
		ax.set_xticklabels([])

		# Bottom panel: diet, values negated so the bars hang downwards.
		ax = plt.subplot(gs[4:8,0:])
		ax.set_ylim(-60,0)
		for level in (-10, -20, -30, -40, -50):
			ax.axhline(level,color = 'k', linestyle = '--')
		negated_diet_quant = (-1)*np.array(diet_quant_list)
		ax.bar(ind, negated_diet_quant, width, color='lightgreen', edgecolor = 'white')
		ax.bar(ind, (-1)*np.array(diet_stoch_list), width, color='darkgreen',
			edgecolor = 'white', bottom = negated_diet_quant)
		ax.set_xlim(-0.5,len(diet_quant_list)+0.5)
		plt.xticks(list(utilsFacade.frange(0.5,len(ranked_df.columns)+0.5,1)))
		# Labels drop everything up to and including the first ':' of each column name.
		ax.set_xticklabels([':'.join(item.split(':')[1:]) for item in list(ranked_df.columns)],
			rotation = 90, fontsize = 5)
		plt.savefig(folder + 'fig6b_complex_effectSize_Distribution.pdf',
			bbox_inches = 'tight', dpi = 400)

if __name__ == "__main__":
	# Script entry point for step 15: the first command-line argument is the
	# input folder, the second the output folder. Both steps get the same pair.
	cli_folder, cli_output_folder = sys.argv[1], sys.argv[2]
	step15_preparation.execute(folder = cli_folder, output_folder = cli_output_folder)
	step15_figure.execute(folder = cli_folder, output_folder = cli_output_folder)
