# ------------------------------------------------------------
# Author: Natalie Romanov
# ------------------------------------------------------------

class figure2a:
	"""Figure 2a: ROC AUC matrix (datasets x functional categories).

	Loads a precomputed AUC table, draws it as a heatmap flanked by marginal
	bar charts, and exports the per-dataset/per-category input tables as a
	supplementary TSV.

	Output paths are built as ``output_folder + <file name>``, so
	``output_folder`` must end with a path separator.
	"""

	@staticmethod
	def execute(**kwargs):
		"""Run every Figure 2a step.

		Keyword Args:
			folder: input directory handed to DataFrameAnalyzer.open_in_chunks
				(default 'PATH').
			output_folder: prefix for the output files (default 'PATH').
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('FIGURE2A: main_figure2a_rocMatrix')
		figure2a.main_figure2a_rocMatrix(folder = folder, output_folder = output_folder)

	@staticmethod
	def main_figure2a_rocMatrix(**kwargs):
		"""Load the AUC matrix, plot it, and export the underlying data.

		Accepts the same ``folder`` / ``output_folder`` keywords as execute().
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('get_auc_matrix')
		auc_matrix = figure2a.get_auc_matrix(folder = folder)
		print("plot_auc_matrix")
		figure2a.plot_auc_matrix(auc_matrix, output_folder = output_folder)
		print('export_underlying_data')
		figure2a.export_underlying_data(auc_matrix, folder = folder, output_folder = output_folder)

	@staticmethod
	def get_auc_matrix(**kwargs):
		"""Read the tab-separated ``fig2a_auc_matrix.tsv.gz`` from ``folder``.

		plot_auc_matrix indexes the result with category names as row labels
		and dataset names as column labels; presumably open_in_chunks sets the
		category column as the index -- TODO confirm.
		"""
		folder = kwargs.get('folder','PATH')
		return DataFrameAnalyzer.open_in_chunks(folder, 'fig2a_auc_matrix.tsv.gz', sep = '\t')

	@staticmethod
	def get_specific_color_gradient(colormap, inputList, **kwargs):
		"""Map numeric values onto a colormap.

		Args:
			colormap: colormap name or object accepted by plt.get_cmap.
			inputList: a list, or an array-like exposing .min()/.max().
			vmin, vmax (kwargs): normalisation bounds. A bound that is not
				given (or given as None/False, the legacy "unset" marker) is
				taken from the data minimum/maximum.

		Returns:
			(scalarMap, colorList): the ScalarMappable (usable for a colorbar)
			and the RGBA colours for every value in inputList.
		"""
		vmin = kwargs.get("vmin")
		vmax = kwargs.get("vmax")
		is_list = isinstance(inputList, list)
		# Check identity, not equality: `0 == False` is True, and a single
		# supplied bound must not silently turn the other bound into 0.
		if vmin is None or vmin is False:
			vmin = min(inputList) if is_list else inputList.min()
		if vmax is None or vmax is False:
			vmax = max(inputList) if is_list else inputList.max()
		cNorm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
		scalarMap = mpl.cm.ScalarMappable(norm=cNorm, cmap=plt.get_cmap(colormap))
		scalarMap.set_array(inputList)
		colorList = scalarMap.to_rgba(inputList)
		return scalarMap, colorList

	@staticmethod
	def plot_auc_matrix(auc_df, output_folder):
		"""Draw the AUC heatmap and save it as ``fig2a_auc_matrix.pdf``.

		Layout on a 16x11 GridSpec: a per-category bar chart on top, the
		dataset x category heatmap (RdBu, colour range 0.4-0.7) in the middle,
		a per-dataset bar chart to its right, a constant-valued strip at the
		far right, and a horizontal colorbar below.

		Args:
			auc_df: AUC table with category names as row labels and dataset
				names as column labels.
			output_folder: prefix for the output PDF path.
		"""
		datasets = ["tcga_ovarian","battle_protein","colo_cancer","gygi3",
					"tcga_breast","gygi1","mann_all_log2","primatePRO","wu",
					"battle_ribo","battle_rna","gygi2",'bxd_protein',
					"primateRNA","tiannan"]
		categories = ["chromosome","housekeeping","essential","pathway",
					  "compartment","string_700","complex"]
		# Rows = datasets, columns = categories, both in the display order above.
		data = auc_df.loc[categories, datasets].T

		# Display labels, positionally matching `datasets` / `categories`.
		data.index = ["TCGA Ovarian Cancer(P)",'Human Individuals(Battle,P)',
					  'TCGA Colorectal Cancer(P)','DO mouse strains(P)',
					  "TCGA Breast Cancer(P)",'Founder mouse strains (P)',
					  'Human cell lines(P)','Primate cells(P)',
					  'Human Individuals(Wu,P)','Human Individuals(RP)',
					  'Human Individuals(RS)','DO mouse strains(RS)',
					  'BXD80 mouse strains(P)','Primate cells(RS)',
					  'Kidney cancer cells(P)']
		data.columns = ['chromosome','housekeeping','essential',
						'pathway','compartment','STRING','complex']

		sns.set(context='notebook', style='white',
			palette='deep', font='Liberation Sans', font_scale=1,
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = False

		# Bars are drawn from a baseline of 0.5 (chance-level AUC), so the
		# height is the statistic minus 0.5 and each bar ends at the statistic.
		col_means = [np.mean(utilsFacade.finite(list(data[col]))) - 0.5
					 for col in data.columns]
		# NOTE(review): rows show the maximum AUC while columns show the mean;
		# confirm this asymmetry is intended.
		row_maxima = [max(utilsFacade.finite(list(data.loc[row]))) - 0.5
					  for row in data.index]

		plt.clf()
		fig = plt.figure(figsize = (5,8))
		gs = gridspec.GridSpec(16,11)

		# Top panel: mean AUC per category.
		ax1_density = plt.subplot(gs[0:2,0:8])
		ax1_density.set_ylim(0.5,0.8)
		for level in (0.55, 0.6, 0.65, 0.7, 0.75):
			ax1_density.axhline(level, alpha = 0.6, color="grey", linestyle='--', linewidth = 0.2, zorder=1)
		_, colorList_x = figure2a.get_specific_color_gradient(plt.cm.Greys,
															 np.arange(len(col_means)))
		# align='edge' puts bar i on [i, i+0.95], matching heatmap column i and
		# the xlim below; matplotlib >= 2.0 would otherwise centre it on i.
		ax1_density.bar(np.arange(len(col_means)), col_means, 0.95,
			color = colorList_x, bottom = 0.5, edgecolor = "white",
			linewidth = 2, zorder = 3, align = "edge")
		plt.xticks(list(range(len(data.columns))))
		ax1_density.set_xticklabels([])
		ax1_density.set_xlim(0,len(data.columns))

		# Middle panel: the AUC heatmap. The ScalarMappable built with the same
		# cmap/limits drives the colorbar at the bottom.
		ax = plt.subplot(gs[2:10,0:8])
		scalarmap, _ = figure2a.get_specific_color_gradient(plt.cm.RdBu,
			np.array(data), vmin = 0.4, vmax = 0.7)
		sns.heatmap(data, cmap = plt.cm.RdBu, vmin = 0.4, vmax = 0.7,
			linecolor = "white", linewidth = 2, cbar = False, ax = ax)

		# Right panel: max AUC per dataset. barh stacks bottom-up while the
		# heatmap lists rows top-down, so reverse to keep rows aligned.
		row_maxima = row_maxima[::-1]
		ax2_density = plt.subplot(gs[2:10,8:10])
		plt.yticks(list(range(len(data.index))))
		ax2_density.set_ylim(0,len(data.index))
		ax2_density.set_xlim(0.5,0.85)
		_, colorList_y = figure2a.get_specific_color_gradient(plt.cm.Greys,
															 np.arange(len(row_maxima)))
		plt.setp(ax2_density.get_xticklabels(), rotation = 90)
		ax2_density.set_yticklabels([])
		ax2_density.axvline(0.7, color = "red", linestyle = "--", linewidth = 0.5)
		for level in (0.55, 0.6, 0.65, 0.75, 0.8):
			ax2_density.axvline(level, alpha = 0.6, color="grey", linestyle='--', linewidth = 0.2, zorder=1)
		ax2_density.barh(np.arange(len(row_maxima)), row_maxima, 0.95,
			color = colorList_y, left = 0.5, edgecolor = "white",
			linewidth = 2, zorder = 3, align = "edge")

		# Far right: constant-valued strip with one cell per heatmap row, so
		# its white separators line up with the heatmap rows.
		ax_category = plt.subplot(gs[2:10,10:11])
		strip = pd.DataFrame({'color': [1] * len(data.index)})
		sns.heatmap(strip, cbar = False, linewidth = 2, linecolor = "white", ax = ax_category)
		ax_category.axis("off")

		ax_cbar = plt.subplot(gs[13:14,0:8])
		cbar = fig.colorbar(scalarmap, cax = ax_cbar, orientation = "horizontal")
		cbar.set_label("Area under curve (AUC)")
		plt.savefig(output_folder + "fig2a_auc_matrix.pdf", bbox_inches = "tight", dpi=400)
		# Release the figure; pyplot keeps every figure alive until closed.
		plt.close(fig)

	@staticmethod
	def export_underlying_data(auc_matrix, **kwargs):
		"""Concatenate all ``<name>_figure1_data_<CATEGORY>.tsv.gz`` tables.

		Every row is tagged with its dataset (``name`` column) and category
		(``category`` column), and the result is written tab-separated to
		``output_folder + 'suppTable2_fig2a_underlying_data_ROCanalysis_<YYYYMMDD>.tsv'``.

		Args:
			auc_matrix: unused; kept for call compatibility.
			folder (kwarg): input directory (default 'PATH').
			output_folder (kwarg): prefix for the output file (default 'PATH').
		"""
		folder = kwargs.get('folder','PATH')
		# Must read 'output_folder'; reading 'folder' here wrote the table into
		# the input directory and ignored the caller's output location.
		output_folder = kwargs.get('output_folder','PATH')

		category_list = ['chromosome', 'compartment', 'complex',
						 'housekeeping', 'jiyoye_essentiality',
						 'k562_essentiality','kbm7_essentiality',
						 'pathway', 'raji_essentiality',
						 'string_700']
		name_list = ["battle_protein","tiannan","battle_ribo", "battle_rna",
					 "primateRNA", "primatePRO", "wu","mann_all_log2","yibo",
					 "gygi1","gygi2","gygi3","tcga_ovarian",'tcga_breast',
					 'bxd_protein','colo_cancer']

		frames = list()
		names_column = list()
		category_column = list()
		for name in name_list:
			for category in category_list:
				print(name,category)
				file_name = name + '_figure1_data_' + category.upper() + '.tsv.gz'
				# NOTE(review): unlike get_auc_matrix no sep='\t' is passed here;
				# confirm open_in_chunks' default separator suits these files.
				frame = DataFrameAnalyzer.open_in_chunks(folder, file_name)
				frames.append(frame)
				names_column.extend([name] * len(frame))
				category_column.extend([category] * len(frame))
		data = pd.concat(frames)
		# Assign positionally: the concatenated index may contain duplicates.
		data['name'] = names_column
		data['category'] = category_column
		data.to_csv(output_folder + 'suppTable2_fig2a_underlying_data_ROCanalysis_' + time.strftime('%Y%m%d') +'.tsv',
					sep = '\t')

class figure2b:
	"""Figure 2b: correlation distribution of complex-labelled rows vs the rest.

	For each dataset, compares the ``correlations`` of rows whose ``label`` is
	True against rows whose ``label`` is False, and saves a density plot to
	``output_folder + 'fig2b_<name>_complex_density_effect.pdf'``.
	``output_folder`` is a plain prefix, so it must end with a path separator.
	"""

	# Datasets plotted by main_figure2b_vignette_complexEffect, in order.
	DATASETS = ('tcga_ovarian', 'battle_protein', 'gygi1',
				'gygi3', 'tcga_breast', 'colo_cancer')

	@staticmethod
	def execute(**kwargs):
		"""Run every Figure 2b step.

		Keyword Args:
			folder: input directory handed to DataFrameAnalyzer.open_in_chunks
				(default 'PATH').
			output_folder: prefix for the output files (default 'PATH').
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('FIGURE2B: main_figure2b_vignette_complexEffect')
		figure2b.main_figure2b_vignette_complexEffect(folder = folder, output_folder = output_folder)

	@staticmethod
	def main_figure2b_vignette_complexEffect(**kwargs):
		"""Plot the complex-effect density for every dataset in DATASETS."""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		for name in figure2b.DATASETS:
			print('plot_complex_effect:' + name)
			# Forward both paths; without them plot_complex_effect falls back
			# to the 'PATH' placeholder for input and output.
			figure2b.plot_complex_effect(name, folder = folder, output_folder = output_folder)

	@staticmethod
	def get_data(name, **kwargs):
		"""Load ``<name>_figure1_data_COMPLEX.tsv.gz`` from ``folder``."""
		folder = kwargs.get('folder','PATH')
		return DataFrameAnalyzer.open_in_chunks(folder, name + '_figure1_data_COMPLEX.tsv.gz')

	@staticmethod
	def plot_complex_effect(name, **kwargs):
		"""Plot correlation densities of complex (label True) vs other rows.

		Also prints, and shows in the legend, the mean Mann-Whitney U p-value
		over repeated random equal-size subsamples of the two groups.

		Args:
			name: dataset name used for the input and output file names.
			folder (kwarg): input directory (default 'PATH').
			output_folder (kwarg): prefix for the output PDF (default 'PATH').
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		complex_data = figure2b.get_data(name, folder = folder)
		other_correlations = utilsFacade.finite(list(complex_data[complex_data.label==False]["correlations"]))
		complex_correlations = utilsFacade.finite(list(complex_data[complex_data.label==True]["correlations"]))

		# 999 resamples (range starts at 1) of up to 100 values per group.
		# random.sample raises ValueError when asked for more items than the
		# population holds, so cap the size by the smaller group.
		sample_size = min(100, len(complex_correlations), len(other_correlations))
		pval_list = list()
		for _ in range(1, 1000):
			corr1 = random.sample(complex_correlations, sample_size)
			corr2 = random.sample(other_correlations, sample_size)
			_statistic, pval = scipy.stats.mannwhitneyu(corr1, corr2)
			pval_list.append(pval)
		mean_pval = np.mean(pval_list)
		print(mean_pval)

		sns.set(context='notebook', style='white',
			palette='deep', font='Liberation Sans', font_scale=1,
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = True

		plt.clf()
		fig = plt.figure(figsize = (7,5))
		ax = fig.add_subplot(111)
		ax.set_ylabel('Density', fontsize=12)
		ax.set_xlabel('correlation coefficient (pearson)', fontsize=12)
		# density=True replaces `normed`, which matplotlib 3.1 removed.
		ax.hist(other_correlations, bins = 50, color='grey', alpha =0.3, density = True)
		plottingFacade.func_plotDensities_border(ax, other_correlations,
												 linewidth = 2, alpha = 1, facecolor = 'grey')
		ax.hist(complex_correlations, bins = 50, color='#AF2D2D', alpha =0.3, density = True)
		plottingFacade.func_plotDensities_border(ax, complex_correlations,
												 linewidth = 2, alpha = 1, facecolor = '#AF2D2D')
		plottingFacade.make_full_legend(ax,['n(other)='+str(len(other_correlations)),
			'n(complex)='+str(len(complex_correlations)),
			'pvalComplex(Mann-Whitney U)='+str(mean_pval)],['grey']*3, loc = 'upper left')
		ax.set_xlim(-1,1)
		plt.savefig(output_folder + 'fig2b_' + name + '_complex_density_effect.pdf',
					bbox_inches = 'tight', dpi = 400)
		# Release the figure; this runs once per dataset and pyplot keeps
		# every figure alive until it is closed.
		plt.close(fig)


if __name__ == "__main__":
	## EXECUTE FIGURE2
	# Usage: <script> INPUT_FOLDER OUTPUT_FOLDER
	# OUTPUT_FOLDER is used as a plain filename prefix, so end it with '/'.
	import sys
	if len(sys.argv) < 3:
		sys.exit('usage: %s INPUT_FOLDER OUTPUT_FOLDER' % sys.argv[0])
	figure2a.execute(folder = sys.argv[1], output_folder = sys.argv[2])
	figure2b.execute(folder = sys.argv[1], output_folder = sys.argv[2])



