# ------------------------------------------------------------
# Author: Natalie Romanov
# ------------------------------------------------------------

class auc_matrix_yeast:
	"""Draw the yeast AUC matrix figure: a heatmap with marginal bar charts.

	The input table ``aucData_summary_yeast_auc_df.tsv.gz`` is loaded from
	*folder* via ``DataFrameAnalyzer.getFile`` and the figure is saved as
	``folder + "aucData_yeast_auc_matrix.pdf"``.
	"""

	@staticmethod
	def execute(**kwargs):
		"""Run the pipeline: load the AUC table, then plot it.

		Keyword Args:
			folder: directory holding the input table and receiving the PDF
				(default ``'PATH'``).
		"""
		folder = kwargs.get('folder', 'PATH')

		print('read_auc_df')
		auc_df = auc_matrix_yeast.read_auc_df(folder)

		print('plot_auc_matrix')
		auc_matrix_yeast.plot_auc_matrix(folder, auc_df)

	@staticmethod
	def read_auc_df(folder):
		"""Load ``aucData_summary_yeast_auc_df.tsv.gz`` from *folder*.

		Delegates to the project helper ``DataFrameAnalyzer.getFile``; the
		result is used as a pandas DataFrame by ``plot_auc_matrix``.
		"""
		fname = 'aucData_summary_yeast_auc_df.tsv.gz'
		return DataFrameAnalyzer.getFile(folder, fname)

	@staticmethod
	def get_specific_color_gradient(colormap, inputList, **kwargs):
		"""Map numeric values onto RGBA colours of *colormap*.

		Args:
			colormap: colormap name or object accepted by ``plt.get_cmap``.
			inputList: a list or numpy array of numeric values.
			vmin, vmax (keyword, optional): normalisation bounds. A bound that
				is omitted (or passed as None / False) is taken from the data.

		Returns:
			``(scalarMap, colorList)``: the ``ScalarMappable`` (usable for a
			colour bar) and the RGBA colours for *inputList*.
		"""
		vmin = kwargs.get("vmin")
		vmax = kwargs.get("vmax")
		# ``False`` was the historical "not given" marker. Compare by identity
		# so a real bound of 0 is honoured, and resolve each missing bound
		# independently (previously giving only one bound set the other to 0).
		if vmin is None or vmin is False:
			vmin = min(inputList) if isinstance(inputList, list) else inputList.min()
		if vmax is None or vmax is False:
			vmax = max(inputList) if isinstance(inputList, list) else inputList.max()
		cNorm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
		scalarMap = mpl.cm.ScalarMappable(norm=cNorm, cmap=plt.get_cmap(colormap))
		scalarMap.set_array(inputList)
		colorList = scalarMap.to_rgba(inputList)
		return scalarMap, colorList

	@staticmethod
	def _order_matrix(auc_df):
		"""Return the transposed matrix re-ordered for plotting.

		Rows of the result are the columns of *auc_df*, sorted by their
		maximum (descending). Columns of the result are the rows of *auc_df*,
		sorted by their mean over those columns (ascending). *auc_df* itself
		is not modified.
		"""
		by_max = auc_df.T
		by_max['max'] = pd.Series(auc_df.max(), index=by_max.index)
		by_max = by_max.sort_values('max', ascending=False).drop('max', axis=1)

		mean_list = by_max.mean()
		by_mean = by_max.T
		by_mean['mean'] = pd.Series(mean_list, index=by_mean.index)
		by_mean = by_mean.sort_values('mean', ascending=True).drop('mean', axis=1)
		return by_mean.T

	@staticmethod
	def plot_auc_matrix(folder, auc_df):
		"""Draw the AUC heatmap with marginal bars and save it as a PDF.

		Above the heatmap one bar per column shows the column mean; to the
		right one bar per row shows the row maximum (values are passed through
		``utilsFacade.finite`` first). A constant strip sits beside the
		heatmap and a horizontal colour bar (AUC 0.4-0.7) below it.

		Args:
			folder: output prefix; it is prepended verbatim to
				``"aucData_yeast_auc_matrix.pdf"``, so it should end with a path
				separator.
			auc_df: pandas DataFrame of AUC values.
		"""
		data = auc_matrix_yeast._order_matrix(auc_df)
		n_rows = len(data.index)
		n_cols = len(data.columns)

		# Marginal summaries are stored as offsets from 0.5 because the bars
		# are drawn starting at 0.5.
		x_mean_list = [np.mean(utilsFacade.finite(list(data[col]))) - 0.5
			for col in data.columns]
		y_max_list = [max(utilsFacade.finite(list(data.loc[row]))) - 0.5
			for row in data.index]

		sns.set(context='notebook', style='white',
			palette='deep', font='Liberation Sans', font_scale=1,
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = False

		plt.clf()
		fig = plt.figure(figsize=(5, 8))
		gs = gridspec.GridSpec(16, 11)
		grid_style = dict(alpha=0.6, color="grey", linestyle='--', linewidth=0.2, zorder=1)

		# Top marginal: column means, one bar per heatmap column.
		ax1_density = plt.subplot(gs[0:2, 0:8])
		ax1_density.set_ylim(0.5, 0.8)
		for level in (0.55, 0.6, 0.65, 0.7, 0.75):
			ax1_density.axhline(level, **grid_style)
		_, colorList_x = auc_matrix_yeast.get_specific_color_gradient(
			plt.cm.Greys, np.arange(len(x_mean_list)))
		# align='edge' puts bar i on [i, i + 0.95], matching heatmap cell i;
		# matplotlib >= 2.0 would otherwise centre the bar on i.
		ax1_density.bar(np.arange(n_cols), x_mean_list, 0.95, color=colorList_x,
			bottom=0.5, align='edge', edgecolor="white", linewidth=2, zorder=3)
		ax1_density.set_xticks(list(range(n_cols)))
		ax1_density.set_xticklabels([])
		ax1_density.set_xlim(0, n_cols)

		# Main heatmap; the ScalarMappable drives the colour bar further down.
		ax = plt.subplot(gs[2:10, 0:8])
		scalarmap, _ = auc_matrix_yeast.get_specific_color_gradient(
			plt.cm.RdBu, np.array(data), vmin=0.4, vmax=0.7)
		sns.heatmap(data, ax=ax, cmap=plt.cm.RdBu, vmin=0.4, vmax=0.7,
			linecolor="white", linewidths=2, cbar=False)

		# Right marginal: row maxima, one bar per heatmap row. seaborn draws
		# the first row at the top while barh puts index 0 at the bottom,
		# hence the reversal.
		y_max_list = y_max_list[::-1]
		ax2_density = plt.subplot(gs[2:10, 8:10])
		ax2_density.set_yticks(list(range(n_rows)))
		ax2_density.set_ylim(0, n_rows)
		ax2_density.set_xlim(0.5, 0.85)
		_, colorList_y = auc_matrix_yeast.get_specific_color_gradient(
			plt.cm.Greys, np.arange(len(y_max_list)))
		plt.setp(ax2_density.get_xticklabels(), rotation=90)
		ax2_density.set_yticklabels([])
		ax2_density.axvline(0.7, color="red", linestyle="--", linewidth=0.5)
		for level in (0.55, 0.6, 0.65, 0.75, 0.8):
			ax2_density.axvline(level, **grid_style)
		ax2_density.barh(np.arange(n_rows), y_max_list, 0.95, color=colorList_y,
			left=0.5, align='edge', edgecolor="white", linewidth=2, zorder=3)

		# Category strip beside the heatmap: one constant cell per heatmap row.
		# NOTE(review): an unused per-row colour list (green/lightgreen/magenta)
		# used to be assigned here and immediately overwritten; the strip
		# currently encodes no category information.
		ax_category = plt.subplot(gs[2:10, 10:11])
		category_df = pd.DataFrame({'color': n_rows * [1]})
		sns.heatmap(category_df, ax=ax_category, cbar=False, linewidths=2, linecolor="white")
		ax_category.axis("off")

		ax_cbar = plt.subplot(gs[13:14, 0:8])
		cbar = fig.colorbar(scalarmap, cax=ax_cbar, orientation="horizontal")
		cbar.set_label("Area under curve (AUC)")
		plt.savefig(folder + "aucData_yeast_auc_matrix.pdf", bbox_inches="tight", dpi=400)
		# Release the figure so repeated calls do not accumulate open figures.
		plt.close(fig)

if __name__ == "__main__":
	# Usage: python <script> FOLDER  (FOLDER is prepended to the output name,
	# so it should end with a path separator).
	if len(sys.argv) < 2:
		sys.exit("usage: %s FOLDER" % sys.argv[0])
	auc_matrix_yeast.execute(folder = sys.argv[1])
