# ------------------------------------------------------------
# Author: Natalie Romanov
# ------------------------------------------------------------

class figure3:
	@staticmethod
	def execute(**kwargs):
		"""Entry point for figure 3: collect the options and run the landscape plot.

		Keyword arguments (all optional, forwarded to main_figure3_landscape):
		folder -- defaults to 'PATH'
		cmap -- defaults to plt.cm.RdBu
		output_folder -- defaults to 'PATH'
		"""
		options = {
			'folder': kwargs.get('folder','PATH'),
			'cmap': kwargs.get("cmap",plt.cm.RdBu),
			'output_folder': kwargs.get('output_folder','PATH'),
		}

		print('FIGURE3: main_figure3_landscape')
		figure3.main_figure3_landscape(**options)

	@staticmethod
	def main_figure3_landscape(**kwargs):
		"""Load the figure-3 data, build the data frames and draw the landscape.

		Keyword arguments (all optional):
		folder -- passed to get_figure_data (default 'PATH')
		cmap -- colormap handed to plot_landscape (default plt.cm.RdBu)
		output_folder -- passed to plot_landscape (default 'PATH')
		recluster_unbiasedly -- passed to prepare_df_protein (default '')
		"""
		folder = kwargs.get('folder','PATH')
		cmap = kwargs.get("cmap",plt.cm.RdBu)
		output_folder = kwargs.get('output_folder','PATH')
		recluster_unbiasedly = kwargs.get('recluster_unbiasedly','')

		print('get_figure_data')
		# Forward the folder: without it the loader always fell back to its
		# own default and ignored what the caller asked for.
		complex_info_dict = figure3.get_figure_data(folder = folder)

		print("prepare_df")
		# prepare_df_protein takes recluster_unbiasedly as a second positional
		# parameter; omitting it raised a TypeError.
		df, dfList, real_df = figure3.prepare_df_protein(complex_info_dict,
			recluster_unbiasedly)

		print("plot_landscape")
		figure3.plot_landscape(df, dfList, real_df, cmap, output_folder)

	@staticmethod
	def get_figure_data(**kwargs):
		"""Read the pickled complex-information dictionary used by figure 3.

		Keyword arguments:
		folder -- prefix put directly in front of the pickle file name
		          (default 'PATH'); no path separator is inserted.

		Returns whatever DataFrameAnalyzer.read_pickle produces for that path.
		"""
		pickle_path = kwargs.get('folder','PATH') + 'fig3_complex_info_dictionary.pkl'
		return DataFrameAnalyzer.read_pickle(pickle_path)

	@staticmethod
	def prepare_df_protein(complex_info_dict, recluster_unbiasedly=''):
		"""Build the per-dataset landscape matrices from the complex dictionary.

		For every complex ID the value at complex_info_dict[ID][0][dataset][0]
		is read for seven datasets. Columns are labelled with the part of the
		ID before the first "(" and ordered by descending np.median across
		the datasets (ties broken by label, also descending).

		Parameters:
		complex_info_dict -- mapping of complex ID to a sequence whose first
		                     item maps dataset key to a sequence; only element
		                     0 of that inner sequence is used here.
		recluster_unbiasedly -- accepted for interface compatibility; not used.

		Returns (df, dfList, real_df):
		df -- per-row z-scores of the per-row ranks (rows are datasets)
		dfList -- the rows of df as a list of lists
		real_df -- the raw values with the same row and column order

		Raises ValueError when complex_info_dict is empty.
		"""
		if not complex_info_dict:
			raise ValueError("complex_info_dict is empty: nothing to prepare")

		# (key inside complex_info_dict, row label in the resulting frames)
		datasets = [("gygi3", "gygi3"),
			("gygi1", "gygi1"),
			("bp", "battle-protein"),
			("mann", "mann"),
			("tcga_breast", "tcga_breast"),
			("tcga_ovarian", "tcga_ovarian"),
			("tcga_colo", "tcga_colo")]

		complex_ids = list(complex_info_dict)
		value_df = pd.DataFrame([[complex_info_dict[complex_id][0][key][0]
			for complex_id in complex_ids] for key, row_label in datasets])
		value_df.columns = [complex_id.split("(")[0] for complex_id in complex_ids]
		value_df.index = [row_label for key, row_label in datasets]

		# Order the complexes by their median value across datasets, highest first.
		medians = [np.median(value_df[colname]) for colname in value_df.columns]
		ordered = sorted(zip(medians, list(value_df.columns)), reverse=True)
		value_df = value_df[[colname for median, colname in ordered]]
		row_labels = value_df.index
		column_labels = value_df.columns
		# Real lists (not a lazy map object) so the rows can be walked twice.
		value_rows = [list(row) for row in value_df.values]

		# Rank within each dataset and blank out every entry that shares the
		# lowest rank.
		# NOTE(review): this masks the minimum even when the row holds no
		# missing-value placeholder -- confirm that is intended.
		rank_rows = list()
		for values in value_rows:
			ranks = rankdata(values)
			lowest_rank = ranks.min()
			rank_rows.append([np.nan if rank == lowest_rank else rank for rank in ranks])
		rank_df = pd.DataFrame(rank_rows)
		rank_df.index = row_labels
		rank_df.columns = column_labels

		real_df = pd.DataFrame(value_rows)
		real_df.index = row_labels
		real_df.columns = column_labels

		# Z-score every row using the mean/std of the entries kept by
		# utilsFacade.finite; masked entries stay NaN.
		zscore_rows = list()
		for rank_row in np.array(rank_df):
			finite_ranks = list(utilsFacade.finite(list(rank_row)))
			mean = np.mean(finite_ranks)
			std = np.std(finite_ranks)
			zscore_rows.append([(rank - mean)/std for rank in rank_row])

		df = pd.DataFrame(zscore_rows)
		df.columns = list(rank_df.columns)
		df.index = rank_df.index
		dfList = [list(row) for row in df.values]
		return df,dfList, real_df

	@staticmethod
	def get_scalarmap(cmap, dfList):
		"""Create one colour mapping per row of dfList plus one shared mapping.

		Parameters:
		cmap -- colormap handed to colorFacade.get_specific_color_gradient
		dfList -- rows of values; a lazy iterable is materialized first

		Returns (scalarmap_dict, all_values):
		scalarmap_dict -- the first item returned by
		                  colorFacade.get_specific_color_gradient for each row,
		                  under the keys '1', '2', ... in row order, plus 'all'
		                  for the mapping built from every value with
		                  vmin = -2 and vmax = 2
		all_values -- result of utilsFacade.flatten(dfList)
		"""
		# A map object (Python 3) can be neither indexed nor walked twice.
		if not hasattr(dfList, '__len__'):
			dfList = list(dfList)

		scalarmap_dict = dict()
		# One mapping per row, however many rows there are (keys are 1-based).
		for position in range(len(dfList)):
			row_scalarmap, row_colors = colorFacade.get_specific_color_gradient(cmap,
				np.array(utilsFacade.finite(dfList[position])))
			scalarmap_dict[str(position + 1)] = row_scalarmap

		all_values = utilsFacade.flatten(dfList)
		scalarmap,colorList = colorFacade.get_specific_color_gradient(cmap,
			np.array(utilsFacade.finite(all_values)), vmin = -2, vmax = 2)
		scalarmap_dict['all'] = scalarmap
		return scalarmap_dict,all_values

	@staticmethod
	def plot_landscape(df, dfList, real_df, cmap, output_folder, **kwargs):
		"""Draw the landscape figure and its colour-bar legend as PDF files.

		The first figure has a bar panel with the per-complex median of real_df
		(after replacing -2 by NaN), sorted in descending order, and below it
		one row of coloured squares per row of df with the columns in that
		same order. NaN entries are drawn in grey; every other entry is coloured
		by the 'all' mapping from get_scalarmap. The second figure holds a
		horizontal colour bar.

		Parameters:
		df -- values that determine the square colours (rows are datasets)
		dfList -- rows passed to get_scalarmap
		real_df -- values used for the bar panel and the column order
		cmap -- colormap used for the squares and the colour bar
		output_folder -- prefix put directly in front of the output file
		                 names; no path separator is inserted

		Keyword arguments:
		name -- text inserted into both file names (default 'protein')

		Files written:
		output_folder + "fig3_landscape_variability_complexes_" + name + ".pdf"
		output_folder + "fig3_landscape_variability_complexes_" + name + "_LEGEND.pdf"

		Side effects: changes seaborn/matplotlib global settings, leaves
		rcParams["axes.grid"] set to True, and closes both created figures.
		"""
		# `name` used to be referenced without ever being defined, so saving
		# failed with a NameError; it is now an optional keyword argument.
		name = kwargs.get('name', 'protein')
		# A map object (Python 3) cannot be indexed by get_scalarmap.
		if not hasattr(dfList, '__len__'):
			dfList = list(dfList)

		complex_list = list(real_df.columns)
		real_df = real_df.replace(-2, np.nan)
		corr_median_list = list(real_df.median())
		corr_median_list, complex_list = zip(*sorted(zip(corr_median_list, complex_list), reverse = True))
		df = df[list(complex_list)]
		scalarmap_dict, unused_values = figure3.get_scalarmap(cmap, dfList)

		sns.set(context='notebook', style='white',
			palette='deep', font='Liberation Sans', font_scale=1,
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = False

		# --- upper panel: sorted medians as grey-shaded bars ---
		plt.clf()
		fig = plt.figure(figsize=(10,4))
		gs = gridspec.GridSpec(10,10)
		ax = plt.subplot(gs[0:3,0:])
		n_complexes = len(corr_median_list)
		ind = np.arange(n_complexes)
		width = 0.85
		unused_map, color_list = colorFacade.get_specific_color_gradient(plt.cm.Greys_r,
			np.array(list(range(n_complexes))))
		ax.bar(ind, corr_median_list, width, color=color_list, edgecolor = 'white')

		ax.set_xlim(0,n_complexes+0.5)
		ax.set_ylim(min(corr_median_list), max(corr_median_list))
		perc25, perc75 = np.percentile(corr_median_list,[25,75])
		ax.set_xticklabels([])
		ax.axhline(perc25, color = 'red', linewidth = 0.5, linestyle = '--')
		ax.axhline(perc75, color = 'red', linewidth = 0.5, linestyle = '--')

		# --- lower panel: one row of coloured squares per dataset ---
		ax = plt.subplot(gs[3:8,0:])
		plt.rcParams["axes.grid"] = False
		# Real lists so the rows can be walked, counted and flattened later.
		row_values = [list(row) for row in df.values]
		# Opaque grey (RGBA) for entries without a value.
		missing_color = list(np.array(colorFacade.hex_to_rgb('#808080'))/256.0)
		missing_color.append(1)
		for row_position, row in enumerate(row_values):
			colors = [missing_color if np.isnan(value)
				else scalarmap_dict['all'].to_rgba(value) for value in row]
			ax.scatter(list(range(len(row))),[row_position]*len(row), color=colors, s=200, marker="s", edgecolor="white")
			# One extra white square directly after the last column.
			ax.scatter(len(row),[row_position], color="white", s=200, marker="s", edgecolor="white")
		plt.yticks(list(range(len(row_values))))
		ylabelList = ["DO mouse strains(P)", "Founder mouse strains(P)",
			"Human Individuals(P)", "Human cell types(P)",
			"TCGA Breast Cancer (P)", "TCGA Ovarian Cancer(P)",
			'TCGA Colorectal Cancer(P)']
		ax.set_yticklabels(list(ylabelList))
		plt.xticks(list(range(len(df.columns))))
		ax.set_xticklabels(list(df.columns), rotation=45, fontsize=8, ha="right")
		ax.set_xlim(-1,len(df.columns)-0.5)
		plt.savefig(output_folder + "fig3_landscape_variability_complexes_" + name + ".pdf",
			bbox_inches="tight", dpi=400)
		# Release the figure instead of leaving it open.
		plt.close(fig)

		# --- separate figure that only carries the colour bar ---
		all_values = utilsFacade.flatten(row_values)
		scalarmap,colorList = colorFacade.get_specific_color_gradient(cmap,
			np.array(utilsFacade.finite(all_values)), vmin = -2, vmax = 2)

		plt.rcParams["axes.grid"] = True
		fig = plt.figure(figsize=(10,1))
		ax = fig.add_subplot(111)
		ax.set_xticklabels([])
		ax.set_yticklabels([])
		fig.colorbar(scalarmap,orientation="horizontal")
		plt.savefig(output_folder + "fig3_landscape_variability_complexes_" + name + "_LEGEND.pdf",
			bbox_inches="tight", dpi=400)
		plt.close(fig)

if __name__ == "__main__":
	## EXECUTE FIGURE3
	# Command line: <script> FOLDER OUTPUT_FOLDER [CMAP]
	import sys

	if len(sys.argv) < 3:
		# Explain what is expected instead of failing with a bare IndexError.
		sys.exit("usage: %s FOLDER OUTPUT_FOLDER [CMAP]" % sys.argv[0])

	options = dict(folder = sys.argv[1], output_folder = sys.argv[2])
	if len(sys.argv) > 3:
		# execute() defaults to a colormap object (plt.cm.RdBu), so turn the
		# name given on the command line into one as well; an unknown name
		# raises ValueError here.
		options['cmap'] = plt.get_cmap(sys.argv[3])
	figure3.execute(**options)


