# ------------------------------------------------------------
# Author: Natalie Romanov
# ------------------------------------------------------------

class file_Loader:
	"""Load the per-dataset complex-filtered protein tables used by step 9."""

	# Dataset stems of the 'complex_filtered_<name>.tsv.gz' files, in the
	# order in which the loaded tables are returned by load_data().
	DATASET_NAMES = ['gygi1', 'gygi2', 'gygi3', 'battle_protein',
					 'mann', 'tcga_colo', 'tcga_breast', 'tcga_ovarian']

	@staticmethod
	def load_data(**kwargs):
		"""Open every complex-filtered dataset table found in *folder*.

		Keyword Args:
			folder: directory handed to DataFrameAnalyzer.open_in_chunks
				(defaults to the placeholder 'PATH').

		Returns:
			list of the loaded tables, ordered like file_Loader.DATASET_NAMES.
		"""
		folder = kwargs.get('folder','PATH')
		# Loading and list order come from a single name list, so every table
		# placed in the result is guaranteed to have actually been loaded.
		return [DataFrameAnalyzer.open_in_chunks(folder, 'complex_filtered_%s.tsv.gz' % name)
				for name in file_Loader.DATASET_NAMES]

class step9_preparation:
	"""Prepare the stable/variable complex gene lists that are submitted to DAVID."""

	@staticmethod
	def execute(**kwargs):
		"""Load the dataset tables and write the DAVID input gene lists.

		Keyword Args:
			folder: input directory passed to file_Loader.load_data.
			output_folder: string prefix that holds figure2B_underlying_data.tsv
				and receives the two exported gene lists.
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('load_data')
		data_list = file_Loader.load_data(folder = folder)

		print('load_GO_input_df')
		step9_preparation.load_GO_input_df(data_list, output_folder)

	@staticmethod
	def load_GO_input_df(data_list, output_folder):
		"""Split complexes by mean score and export their member genes as Entrez IDs.

		Complexes whose column mean in figure2B_underlying_data.tsv is > 0 form
		the stable group, and those with mean < 0 form the variable group. A
		mean of exactly 0 puts a complex in neither group. For each group, the
		upper-cased protein identifiers (table index) of its complexes are
		collected across all of *data_list* and mapped from symbol to Entrez
		gene ID.

		Args:
			data_list: tables with a 'ComplexName' column, indexed by protein.
			output_folder: string prefix for input and output files. It is
				concatenated directly, so it presumably ends with a path
				separator.
		"""
		df = DataFrameAnalyzer.getFile(output_folder ,'figure2B_underlying_data.tsv')

		mean_df = df.mean()
		stable_ids = list(mean_df[mean_df > 0].index)
		variable_ids = list(mean_df[mean_df < 0].index)

		stable_proteins = step9_preparation._collect_complex_proteins(stable_ids, data_list)
		variable_proteins = step9_preparation._collect_complex_proteins(variable_ids, data_list)

		#these files are submitted to DAVID v6.8
		step9_preparation._write_entrezgenes(
			stable_proteins, output_folder + 'suppFigure3_stable_protein_list.tsv')
		step9_preparation._write_entrezgenes(
			variable_proteins, output_folder + 'suppFigure3_variable_protein_list.tsv')

	@staticmethod
	def _collect_complex_proteins(complex_ids, data_list):
		"""Return the sorted, de-duplicated upper-cased proteins of *complex_ids*."""
		proteins = set()
		for complexID in complex_ids:
			print(complexID)
			for d in data_list:
				sub = d[d.ComplexName == complexID]
				proteins.update(item.upper() for item in sub.index)
		# Sorting makes the exported list reproducible; set order is arbitrary.
		return sorted(str(protein) for protein in proteins)

	@staticmethod
	def _write_entrezgenes(proteins, path):
		"""Map gene symbols to Entrez gene IDs and write them one per line to *path*."""
		m = Mapper(proteins, input = 'symbol', output = 'entrezgene')
		# finite() presumably drops unmapped (NaN) IDs; int() presumably
		# normalises float-typed IDs before they are written as text.
		entrezgenes = [str(int(gene))
					   for gene in utilsFacade.finite(list(m.trans_df['entrezgene']))]
		# The file is opened only after mapping succeeds, and closed even on error.
		with open(path, 'w') as handle:
			handle.write('\n'.join(entrezgenes))

class step9_figure:
	"""Plot DAVID GO enrichment of stable versus variable complex proteins."""

	@staticmethod
	def execute(**kwargs):
		"""Entry point: draw the supplementary GO enrichment bar plots.

		Keyword Args:
			folder: forwarded for interface symmetry; this step does not use it.
			output_folder: string prefix holding the DAVID result tables and
				receiving the PDF plots.
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('SUPP-FIGURE3: main_supp_figure2c_GOenrichment')
		step9_figure.main_supp_figure2c_GOenrichment(folder = folder, output_folder = output_folder)

	@staticmethod
	def main_supp_figure2c_GOenrichment(**kwargs):
		"""Load the filtered DAVID tables and write one bar plot per GO namespace."""
		output_folder = kwargs.get('output_folder','PATH')

		print('load_data')
		go_dict = step9_figure.load_data(output_folder)

		print('make_go_barplot')
		step9_figure.make_go_barplot(go_dict, output_folder)

	@staticmethod
	def load_data(output_folder):
		"""Read the DAVID 6.8 result tables and keep strongly enriched GO terms.

		A row (indexed by DAVID category) is kept when all of these hold:
		it is a GOTERM_{BP,MF,CC}_DIRECT row, FDR <= 0.01, Count >= 50,
		and '%' >= 10.

		Returns:
			dict mapping 'BP'/'MF'/'CC' to a one-element list that holds the
			(stable, variable) filtered tables. Each table has an added
			'go_term' column containing the term name after '~' in 'Term'.
		"""
		category_list = ['GOTERM_BP_DIRECT','GOTERM_MF_DIRECT','GOTERM_CC_DIRECT']
		stable_data = DataFrameAnalyzer.getFile(output_folder,
					  'wpc_suppFigure2C_goEnrichment_stableProteins_DAVID6_8.txt')
		variable_data = DataFrameAnalyzer.getFile(output_folder,
						'wpc_suppFigure2C_goEnrichment_variableProteins_DAVID6_8.txt')

		#filter GO result table according to strict criteria
		stable_data = stable_data[stable_data.FDR<=0.01]
		variable_data = variable_data[variable_data.FDR<=0.01]

		go_dict = dict((e1,list()) for e1 in ['BP','CC','MF'])
		for category in category_list:
			stable_sub = step9_figure._select_terms(stable_data, category)
			variable_sub = step9_figure._select_terms(variable_data, category)
			go_dict[category.split('_')[1]].append((stable_sub, variable_sub))
		return go_dict

	@staticmethod
	def _select_terms(data, category, min_count=50, min_percent=10):
		"""Return the rows of *category* that pass the size cut-offs, with a 'go_term' column."""
		sub = data[data.index == category]
		sub = sub[sub.Count >= min_count]
		# copy() so the new column is written to an independent frame, not a view.
		sub = sub[sub['%'] >= min_percent].copy()
		# 'Term' is presumably 'GO:ID~name'. Split only at the first '~' so
		# that names which themselves contain '~' stay intact.
		sub['go_term'] = [term.split('~', 1)[1] for term in sub['Term']]
		return sub

	@staticmethod
	def _align_fold_enrichments(stable_data, variable_data):
		"""Return (labels, stable FE, variable FE) aligned row by row.

		Rows list the stable terms by descending fold enrichment, followed by
		terms enriched only in the variable set, in ascending order. A side
		without a value for a row gets 0.
		"""
		# NOTE(review): '<' drops FDR == 0.01 rows that load_data's '<=' kept; confirm intent.
		stable_data = stable_data[stable_data.FDR < 0.01]
		variable_data = variable_data[variable_data.FDR < 0.01]

		stable_data = stable_data.sort_values('Fold Enrichment', ascending = False)
		label_list = list(stable_data.go_term)
		stable_fe_list = list(stable_data['Fold Enrichment'])

		# Keep the first fold enrichment seen per variable term, for O(1) lookup.
		variable_fe = dict()
		for term, fe in zip(variable_data.go_term, variable_data['Fold Enrichment']):
			variable_fe.setdefault(term, fe)
		variable_fe_list = [variable_fe.get(term, 0) for term in label_list]

		variable_only = variable_data[~variable_data.go_term.isin(label_list)]
		variable_only = variable_only.sort_values('Fold Enrichment', ascending = True)
		label_list.extend(variable_only.go_term)
		variable_fe_list.extend(variable_only['Fold Enrichment'])

		# Variable-only rows have no stable bar.
		stable_fe_list.extend([0] * (len(variable_fe_list) - len(stable_fe_list)))
		return label_list, stable_fe_list, variable_fe_list

	@staticmethod
	def make_go_barplot(go_dict, output_folder):
		"""Draw mirrored fold-enrichment bars per GO namespace and save each as a PDF.

		Stable enrichments are drawn to the left (negated) using a plt.cm.Blues
		gradient. Variable enrichments are drawn to the right using a
		plt.cm.Reds gradient. Term names are placed on a twin y-axis.
		"""
		sns.set(context='notebook', style='white',
			palette='deep', font='Liberation Sans', font_scale=1,
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = True

		for category in go_dict:
			stable_data, variable_data = go_dict[category][0]
			label_list, stable_fe_list, variable_fe_list = step9_figure._align_fold_enrichments(
				stable_data, variable_data)

			_, colors_stable = colorFacade.get_specific_color_gradient(plt.cm.Blues, np.array(stable_fe_list))
			_, colors_variable = colorFacade.get_specific_color_gradient(plt.cm.Reds, np.array(variable_fe_list))

			fig = plt.figure()
			ax = fig.add_subplot(111)
			y_pos = np.arange(len(label_list))
			ax.barh(y_pos, list((-1)*np.array(stable_fe_list)), align='center',color=colors_stable, ecolor='black')
			ax.barh(y_pos, list(variable_fe_list), align='center',color=colors_variable, ecolor='black')
			ax.set_yticklabels([])

			ax2 = ax.twinx()
			ax2.set_yticks(list(range(len(label_list))))
			ax2.set_yticklabels(label_list)
			ax.set_xlabel('Fold Enrichment')
			ax.set_title('GO:'+str(category))
			ax.set_ylim(-1,len(label_list) + 1)
			ax2.set_ylim(-1,len(label_list) + 1)
			fig.savefig(output_folder + 'suppFigure3_' + category + '_complexes_DAVID6_8.pdf',
						bbox_inches = 'tight', dpi = 400)
			# Release the figure; otherwise each namespace leaves one open figure behind.
			plt.close(fig)

if __name__ == "__main__":
	## EXECUTE STEP9: build the DAVID input lists, then plot the DAVID results.
	# usage: <script> <input_folder> <output_folder>
	input_folder = sys.argv[1]
	result_folder = sys.argv[2]
	for step in (step9_preparation, step9_figure):
		step.execute(folder = input_folder, output_folder = result_folder)
	