# ------------------------------------------------------------
# Author: Natalie Romanov
# ------------------------------------------------------------

class supp_figure3:
	"""Supplementary Figure 3: GO-enrichment bar plots (DAVID 6.8 output)
	comparing stable and variable proteins for the BP, MF and CC GO categories."""

	@staticmethod
	def execute(**kwargs):
		"""Entry point: build all panels of supplementary figure 3.

		Keyword arguments:
			folder: input directory holding the DAVID result tables.
			output_folder: prefix for the written PDF files.
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('SUPP-FIGURE3: main_supp_figure2c_GOenrichment')
		supp_figure3.main_supp_figure2c_GOenrichment(folder = folder, output_folder = output_folder)

	@staticmethod
	def main_supp_figure2c_GOenrichment(**kwargs):
		"""Load the GO-enrichment tables and draw one bar plot per GO category.

		Keyword arguments are the same as for ``execute``.
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('load_data')
		go_dict = supp_figure3.load_data(folder)

		print('make_go_barplot')
		supp_figure3.make_go_barplot(go_dict, output_folder)

	@staticmethod
	def _filter_category(data, category, min_count, min_percent):
		"""Return a copy of the rows of one DAVID category that pass the size filters, with a parsed ``go_term`` column."""
		# the table index holds the DAVID category label (e.g. 'GOTERM_BP_DIRECT')
		sub = data[data.index == category]
		sub = sub[(sub.Count >= min_count) & (sub['%'] >= min_percent)].copy()
		# DAVID terms look like '<GO id>~<name>'; keep only the name part
		sub['go_term'] = [term.split('~')[1] for term in sub['Term']]
		return sub

	@staticmethod
	def _assemble_bars(stable_data, variable_data):
		"""Align stable/variable fold enrichments bar-by-bar.

		Stable terms come first, sorted by descending fold enrichment. Each of
		them gets its variable fold enrichment, or 0 if the term is absent from
		``variable_data``. Variable-only terms follow, sorted by ascending fold
		enrichment. Stable values are zero-padded for those trailing bars.

		Returns:
			(stable_fe_list, variable_fe_list, label_list), three equally long lists.
		"""
		stable_data = stable_data.sort_values('Fold Enrichment', ascending = False)
		stable_terms = list(stable_data.go_term)
		stable_fe_list = list(stable_data['Fold Enrichment'])
		label_list = list(stable_terms)

		# first fold enrichment seen for each variable term
		variable_fe = {}
		for term, fold in zip(variable_data.go_term, variable_data['Fold Enrichment']):
			variable_fe.setdefault(term, fold)
		variable_fe_list = [variable_fe.get(term, 0) for term in stable_terms]

		variable_other = variable_data[~variable_data.go_term.isin(stable_terms)]
		variable_other = variable_other.sort_values('Fold Enrichment', ascending = True)
		variable_fe_list.extend(variable_other['Fold Enrichment'])
		label_list.extend(variable_other.go_term)

		stable_fe_list = stable_fe_list + [0] * (len(variable_fe_list) - len(stable_fe_list))
		return stable_fe_list, variable_fe_list, label_list

	@staticmethod
	def load_data(folder, min_count=50, min_percent=10):
		"""Load the DAVID 6.8 GO-enrichment tables for stable and variable proteins.

		Parameters:
			folder: directory passed to ``DataFrameAnalyzer.getFile``.
			min_count: keep only terms with ``Count >= min_count`` (default 50).
			min_percent: keep only terms with ``'%' >= min_percent`` (default 10).

		Returns:
			dict mapping 'BP', 'CC' and 'MF' to a one-element list holding a
			``(stable_df, variable_df)`` tuple. Each DataFrame has an added
			``go_term`` column.
		"""
		category_list = ['GOTERM_BP_DIRECT','GOTERM_MF_DIRECT','GOTERM_CC_DIRECT']

		stable_data = DataFrameAnalyzer.getFile(folder,'wpc_suppFigure2C_goEnrichment_stableProteins_DAVID6_8.txt')
		variable_data = DataFrameAnalyzer.getFile(folder,'wpc_suppFigure2C_goEnrichment_variableProteins_DAVID6_8.txt')

		go_dict = dict((key, list()) for key in ['BP','CC','MF'])
		for category in category_list:
			stable_sub = supp_figure3._filter_category(stable_data, category, min_count, min_percent)
			variable_sub = supp_figure3._filter_category(variable_data, category, min_count, min_percent)
			# 'GOTERM_BP_DIRECT' -> 'BP'
			go_dict[category.split('_')[1]].append((stable_sub, variable_sub))
		return go_dict

	@staticmethod
	def make_go_barplot(go_dict, output_folder, fdr_cutoff=0.01):
		"""Draw one diverging horizontal bar plot per GO category and save it as PDF.

		Stable-protein fold enrichments are drawn to the left (negated widths).
		Variable-protein fold enrichments are drawn to the right. Bar colours
		come from colorFacade.get_specific_color_gradient, using the Blues and
		Reds colormaps respectively. Only terms with ``FDR < fdr_cutoff`` are
		plotted. Categories with no such term are skipped.

		Parameters:
			go_dict: output of ``load_data``, mapping category -> [(stable_df, variable_df)].
			output_folder: prefix that is concatenated directly with the file
				name, so it should end with a path separator.
			fdr_cutoff: exclusive FDR threshold (default 0.01).
		"""
		sns.set(context='notebook', style='white',
			palette='deep', font='Liberation Sans', font_scale=1,
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = True

		for category, pairs in go_dict.items():
			stable_data, variable_data = pairs[0]
			stable_data = stable_data[stable_data.FDR < fdr_cutoff]
			variable_data = variable_data[variable_data.FDR < fdr_cutoff]
			stable_fe_list, variable_fe_list, label_list = supp_figure3._assemble_bars(stable_data, variable_data)
			if not label_list:
				print('no terms pass the FDR cutoff for GO:' + str(category) + ', skipping')
				continue

			_, colors_stable = colorFacade.get_specific_color_gradient(plt.cm.Blues, np.array(stable_fe_list))
			_, colors_variable = colorFacade.get_specific_color_gradient(plt.cm.Reds, np.array(variable_fe_list))

			fig = plt.figure()
			ax = fig.add_subplot(111)
			y_pos = np.arange(len(label_list))
			ax.barh(y_pos, -np.array(stable_fe_list), align='center', color=colors_stable, ecolor='black')
			ax.barh(y_pos, variable_fe_list, align='center', color=colors_variable, ecolor='black')
			ax.tick_params(axis='y', labelleft=False)
			# term labels are shown on a twin y-axis on the right-hand side
			ax2 = ax.twinx()
			ax2.set_yticks(list(range(len(label_list))))
			ax2.set_yticklabels(label_list)
			ax.set_xlabel('Fold Enrichment')
			ax.set_title('GO:'+str(category))
			ax.set_ylim(-1,len(label_list)+1)
			ax2.set_ylim(-1,len(label_list)+1)
			fig.savefig(output_folder + 'suppFigure3_' + category + '_complexes_DAVID6_8.pdf',
						bbox_inches = 'tight', dpi = 400)
			# release the figure; otherwise one figure per category stays alive
			plt.close(fig)

if __name__ == "__main__":
	## EXECUTE SUPPFIGURE3
	# expects two command-line arguments: the input folder and the output prefix
	if len(sys.argv) < 3:
		sys.exit('usage: ' + sys.argv[0] + ' <folder> <output_folder>')
	supp_figure3.execute(folder = sys.argv[1], output_folder = sys.argv[2])
	


