# ------------------------------------------------------------
# Author: Natalie Romanov
# ------------------------------------------------------------

class file_Loader:
	"""Static helpers that read the input tables and the complex dictionary from disk."""

	@staticmethod
	def load_data(gold_complexes, **kwargs):
		"""
		Read the ten 'complex_filtered_*' tables and keep the selected complexes.

		Parameters:
			gold_complexes: complex names; only rows whose ComplexName is one
				of them are kept.
			folder (keyword): location handed to DataFrameAnalyzer.open_in_chunks
				(default 'PATH').

		Returns:
			dict mapping a dataset key ('mann', 'gygi1', ...) to its filtered table.
		"""
		folder = kwargs.get('folder','PATH')

		# (dataset key, file name) pairs, listed in the order the files are read
		sources = [('gygi3', 'complex_filtered_gygi3.tsv.gz'),
				   ('gygi2', 'complex_filtered_gygi2.tsv.gz'),
				   ('gygi1', 'complex_filtered_gygi1.tsv.gz'),
				   ('mann', 'complex_filtered_mann.tsv.gz'),
				   ('battle_rna', 'complex_filtered_battle_rna.tsv.gz'),
				   ('battle_ribo', 'complex_filtered_battle_ribo.tsv.gz'),
				   ('battle_protein', 'complex_filtered_battle_protein.tsv.gz'),
				   ('tcga_colo', 'complex_filtered_tcga_colo.tsv.gz'),
				   ('tcga_breast', 'complex_filtered_tcga_breast.tsv.gz'),
				   ('tcga_ovarian', 'complex_filtered_tcga_ovarian.tsv.gz')]
		tables = dict()
		for key, filename in sources:
			tables[key] = DataFrameAnalyzer.open_in_chunks(folder, filename)

		# order in which the filtered tables are stored in the returned dictionary
		key_order = ['mann', 'gygi2', 'gygi3', 'gygi1',
					 'tcga_colo', 'tcga_breast', 'tcga_ovarian',
					 'battle_rna', 'battle_ribo', 'battle_protein']
		data_dict = dict()
		for key in key_order:
			table = tables[key]
			data_dict[key] = table[table.ComplexName.isin(gold_complexes)]
		return data_dict

	@staticmethod
	def load_complex_dictionary(**kwargs):
		"""
		Read 'complex_dictionary.pkl' via DataFrameAnalyzer.read_pickle.

		The file name is appended directly to the 'folder' keyword (default
		'PATH'), so the folder has to end with a path separator.
		"""
		folder = kwargs.get('folder','PATH')
		return DataFrameAnalyzer.read_pickle(folder + 'complex_dictionary.pkl')

class step7_preparation:
	"""
	Computes, per gold complex and dataset, the median pairwise correlation of
	the complex members and stores the result as a pickle for the figure step.
	"""

	# Placeholder returned instead of a correlation when a complex has too few
	# rows in a dataset.
	NO_CORRELATION = -2

	# Minimum number of rows a complex needs in a dataset to get a correlation.
	MIN_MEMBERS = 5

	# (key in the info dictionary, key in data_dict) for every recorded dataset.
	INFO_SOURCES = (("bp", "battle_protein"),
					("gygi1", "gygi1"),
					("gygi3", "gygi3"),
					("gygi2", "gygi2"),
					("brna", "battle_rna"),
					("bribo", "battle_ribo"),
					("tcga_breast", "tcga_breast"),
					("tcga_ovarian", "tcga_ovarian"),
					("mann", "mann"),
					("tcga_colo", "tcga_colo"))

	# Info keys that count towards the coverage threshold below.
	COVERAGE_KEYS = ("gygi3", "gygi1", "bp", "tcga_breast",
					 "tcga_ovarian", "mann", "tcga_colo")

	# A complex is kept when at least this many COVERAGE_KEYS have a correlation.
	MIN_COVERED_DATASETS = 4

	@staticmethod
	def execute(**kwargs):
		"""
		Run the preparation step.

		Keyword arguments:
			folder: location of the input files; the info-dictionary pickle is
				written there as well (default 'PATH').
			output_folder: accepted for symmetry with the other steps, unused.

		Returns:
			the complex info dictionary built by get_complex_info_dict.
		"""
		folder = kwargs.get('folder','PATH')

		print('load_complex_dictionary')
		complexDict = file_Loader.load_complex_dictionary(folder = folder)

		print("get_gold_complexes")
		gold_complexes = step7_preparation.get_gold_complexes(complexDict)

		# 'folder' has to be forwarded explicitly: both callees otherwise fall
		# back to their own default and read/write under 'PATH'.
		print("load_data")
		data_dict = file_Loader.load_data(gold_complexes, folder = folder)

		print("get_complex_info_dict")
		complex_info_dict = step7_preparation.get_complex_info_dict(data_dict, gold_complexes,
																	folder = folder)
		return complex_info_dict

	@staticmethod
	def get_gold_complexes(complexDict):
		"""Return the first 'altName' of every entry whose first 'goldComplex' value is 'yes'."""
		return [complexDict[complexID]["altName"][0]
				for complexID in complexDict
				if complexDict[complexID]["goldComplex"][0] == "yes"]

	@staticmethod
	def get_median_correlations(dat, complexID):
		"""
		Median pairwise correlation between the rows of one complex.

		Parameters:
			dat: table with a ComplexName column and quantification columns.
			complexID: value of ComplexName selecting the rows.

		Returns:
			(median correlation, number of rows). The correlation is
			NO_CORRELATION (-2) when fewer than MIN_MEMBERS rows are present.
		"""
		sub = dat[dat.ComplexName==complexID]
		n_members = len(sub)
		if n_members < step7_preparation.MIN_MEMBERS:
			return step7_preparation.NO_CORRELATION, n_members

		quantCols = step7_preparation.get_quantCols(sub)
		# transpose so that rows (complex members) are correlated with each other
		corrData = sub[quantCols].T.corr()
		corrValues = utilsFacade.get_correlation_values(corrData)
		return np.median(utilsFacade.finite(corrValues)), n_members

	@staticmethod
	def get_quantCols(dataset):
		"""
		Return the quantification columns of a table.

		Columns starting with 'quant_' are preferred; when there are none, the
		columns containing 'Ratio' are returned instead.
		"""
		columns = list(dataset.columns)
		quantCols = [col for col in columns if col.startswith("quant_")]
		if not quantCols:
			quantCols = [col for col in columns if "Ratio" in col]
		return quantCols

	@staticmethod
	def get_complex_info_dict(data_dict, gold_complexes, **kwargs):
		"""
		Collect (median correlation, row count) per dataset for every complex.

		A complex is only recorded when at least MIN_COVERED_DATASETS of the
		COVERAGE_KEYS datasets yielded a correlation above NO_CORRELATION.

		Parameters:
			data_dict: dataset key -> table, as returned by file_Loader.load_data.
			gold_complexes: complex names to evaluate.
			folder (keyword): prefix of the output pickle path (default 'PATH').

		Returns:
			dict mapping complex name to a one-element list holding a dict
			{info key: (median correlation, row count)}. The same dictionary is
			written to folder + 'fig3_complex_info_dictionary.pkl'.
		"""
		folder = kwargs.get('folder','PATH')

		complex_info_dict = dict()
		for complexID in gold_complexes:
			info = dict()
			for info_key, data_key in step7_preparation.INFO_SOURCES:
				info[info_key] = step7_preparation.get_median_correlations(data_dict[data_key],
																		   complexID)

			# a list comprehension (not filter) so that len() also works on Python 3
			covered = [key for key in step7_preparation.COVERAGE_KEYS
					   if info[key][0] > step7_preparation.NO_CORRELATION]
			if len(covered) >= step7_preparation.MIN_COVERED_DATASETS:
				complex_info_dict.setdefault(complexID,[]).append(info)

		DataFrameAnalyzer.to_pickle(complex_info_dict, folder + 'fig3_complex_info_dictionary.pkl')
		return complex_info_dict

class step7_stats_check:
	"""Compares the between-dataset correlations of the z-score table with a shuffled version of it."""

	@staticmethod
	def execute(**kwargs):
		"""Load the exported z-score table from 'folder' (default 'PATH') and run the check."""
		folder = kwargs.get('folder','PATH')

		print('load_data_for_first_check')
		zscores = step7_stats_check.load_data_for_first_check(folder = folder)

		print('check_on_overall_agreement_between_datasets')
		pvalue, med_corr = step7_stats_check.check_on_overall_agreement_between_datasets(zscores)

	@staticmethod
	def load_data_for_first_check(**kwargs):
		"""Read 'figure3_underlying_data_ranked_zscores.tsv' from 'folder' via DataFrameAnalyzer.getFile."""
		folder = kwargs.get('folder','PATH')
		return DataFrameAnalyzer.getFile(folder, 'figure3_underlying_data_ranked_zscores.tsv')

	@staticmethod
	def _shuffle_rows_and_transpose(frame):
		"""Shuffle the values inside every row of 'frame' and return the transposed result."""
		shuffled_rows = list()
		for _, row in frame.iterrows():
			values = list(row)
			np.random.shuffle(values)
			shuffled_rows.append(values)
		return pd.DataFrame(shuffled_rows).T

	@staticmethod
	def check_on_overall_agreement_between_datasets(df):
		"""
		Test the observed column correlations against those of shuffled data.

		Parameters:
			df: table whose columns are correlated with each other.

		Returns:
			(p-value of the independent two-sample t-test between observed and
			random correlation values, median of the observed values).
			The Shapiro-Wilk W statistic of the random values is printed.
		"""
		observed = utilsFacade.get_correlation_values(df.corr())

		# define random distribution: shuffle within the rows first, then,
		# after transposing, within the former columns; the second transpose
		# restores the original orientation
		rand_data = step7_stats_check._shuffle_rows_and_transpose(df)
		rand_data = step7_stats_check._shuffle_rows_and_transpose(rand_data)

		# collect correlation values from random distribution
		randomised = utilsFacade.get_correlation_values(rand_data.corr())

		# check normality (W>0.9) of random distribution
		w_statistic, _ = scipy.stats.shapiro(randomised)
		print(w_statistic)

		_, pvalue = scipy.stats.ttest_ind(observed, randomised)
		return pvalue, np.median(observed)

class step7_figure:
	"""Turns the complex info dictionary into the figure 3 landscape plot and its data table."""

	# Placeholder stored in the info dictionary where no correlation is
	# available; it is masked (NaN) before ranking results and medians are used.
	MISSING_CORRELATION = -2

	# (key in the info dictionary, row label in the resulting tables)
	DATASETS = (("gygi3", "gygi3"),
				("gygi1", "gygi1"),
				("bp", "battle-protein"),
				("mann", "mann"),
				("tcga_breast", "tcga_breast"),
				("tcga_ovarian", "tcga_ovarian"),
				("tcga_colo", "tcga_colo"))

	@staticmethod
	def execute(**kwargs):
		"""Run the figure step; 'folder' and 'output_folder' (default 'PATH') are passed on."""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('FIGURE3: main_figure3_landscape')
		step7_figure.main_figure3_landscape(folder = folder, output_folder = output_folder)

	@staticmethod
	def main_figure3_landscape(**kwargs):
		"""
		Load the info dictionary from 'folder', plot the landscape and export
		the underlying z-scores to 'output_folder'.
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('get_figure_data')
		complex_info_dict = step7_figure.get_figure_data(folder = folder)

		print("prepare_df")
		df, dfList, real_df = step7_figure.prepare_df_protein(complex_info_dict)

		print("plot_landscape")
		step7_figure.plot_landscape(df, dfList, real_df, output_folder = output_folder)

		print('export_underlying_zscore_data')
		step7_figure.export_underlying_zscore_data(df, output_folder = output_folder)

	@staticmethod
	def get_figure_data(**kwargs):
		"""Read folder + 'fig3_complex_info_dictionary.pkl' via DataFrameAnalyzer.read_pickle."""
		folder = kwargs.get('folder','PATH')

		complex_info_dict = DataFrameAnalyzer.read_pickle(folder + 'fig3_complex_info_dictionary.pkl')
		return complex_info_dict

	@staticmethod
	def prepare_df_protein(complex_info_dict):
		"""
		Build the ranked z-score table and the raw correlation table.

		Parameters:
			complex_info_dict: complex name -> one-element list with a dict
				{info key: (median correlation, row count)}; only the keys
				listed in DATASETS are read.

		Returns:
			(df, dfList, real_df):
			df      - datasets x complexes; per dataset the correlations are
			          ranked across complexes and the ranks are z-scored.
			          Entries equal to MISSING_CORRELATION are NaN.
			dfList  - the rows of df as a list of lists.
			real_df - the raw correlations in the same layout (placeholders kept).
			Columns are the complex names up to the first '(' and are sorted by
			descending median of the raw values.

		Raises:
			ValueError: if complex_info_dict is empty.
		"""
		if not complex_info_dict:
			raise ValueError("complex_info_dict is empty: nothing to prepare")

		complex_ids = list(complex_info_dict)
		raw_rows = [[complex_info_dict[complexID][0][info_key][0] for complexID in complex_ids]
					for info_key, _ in step7_figure.DATASETS]
		real_df = pd.DataFrame(raw_rows)
		real_df.columns = [complexID.split("(")[0] for complexID in complex_ids]
		real_df.index = [label for _, label in step7_figure.DATASETS]

		# order the complexes by descending median over the datasets
		medians = [np.median(real_df[colname]) for colname in real_df.columns]
		_, colnames = zip(*sorted(zip(medians, list(real_df.columns)), reverse=True))
		real_df = real_df[list(colnames)]

		zscore_rows = list()
		for values in real_df.values:
			ranks = np.asarray(rankdata(values), dtype=float)
			# mask by value rather than by "lowest rank": a dataset without any
			# placeholder must not lose its genuinely lowest complex
			ranks[values == step7_figure.MISSING_CORRELATION] = np.nan
			observed = ranks[np.isfinite(ranks)]
			if len(observed) > 0:
				# z-score using only the unmasked ranks
				ranks = (ranks - observed.mean()) / observed.std()
			zscore_rows.append(list(ranks))

		df = pd.DataFrame(zscore_rows)
		df.index = real_df.index
		df.columns = real_df.columns

		# a real list (not a map object) so it can be indexed and reused
		dfList = [list(row) for row in df.values]
		return df, dfList, real_df

	@staticmethod
	def get_scalarmap(cmap, dfList):
		"""
		Create colour mappings for the z-score rows.

		Parameters:
			cmap: colormap handed to colorFacade.get_specific_color_gradient.
			dfList: rows of values (one per dataset).

		Returns:
			(scalarmap_dict, all_values): scalarmap_dict holds one mapping per
			row under the keys '1', '2', ... and a mapping over all values,
			created with vmin=-2 and vmax=2, under 'all'; all_values is the
			flattened input.
		"""
		rows = list(dfList)
		scalarmap_dict = dict()
		for position, row in enumerate(rows, 1):
			row_map, _ = colorFacade.get_specific_color_gradient(cmap,
						 np.array(utilsFacade.finite(row)))
			scalarmap_dict[str(position)] = row_map

		all_values = utilsFacade.flatten(rows)
		scalarmap_dict['all'], _ = colorFacade.get_specific_color_gradient(cmap,
								   np.array(utilsFacade.finite(all_values)), vmin = -2, vmax = 2)
		return scalarmap_dict, all_values

	@staticmethod
	def plot_landscape(df, dfList, real_df, **kwargs):
		"""
		Draw the landscape figure and its colour legend.

		Parameters:
			df: z-score table (datasets x complexes) from prepare_df_protein.
			dfList: rows of df, used to set up the colour mappings.
			real_df: raw correlation table with the same columns.
			output_folder (keyword): prefix of both output paths (default 'PATH').
			cmap (keyword): colormap of the tiles and the legend
				(default plt.cm.RdBu).

		Writes 'fig3_landscape_variability_complexes_protein_main.pdf' and
		'fig3_landscape_variability_complexes_protein_main_LEGEND.pdf'.
		"""
		output_folder = kwargs.get('output_folder','PATH')
		cmap = kwargs.get('cmap', plt.cm.RdBu)

		# median raw correlation per complex, ignoring placeholders
		real_df = real_df.replace(step7_figure.MISSING_CORRELATION, np.nan)
		corr_median_list, complex_list = zip(*sorted(zip(list(real_df.median()),
														 list(real_df.columns)),
													 reverse = True))
		df = df[list(complex_list)]
		scalarmap_dict, _ = step7_figure.get_scalarmap(cmap, dfList)

		sns.set(context='notebook', style='white',
			palette='deep', font='Liberation Sans', font_scale=1,
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = False

		fig = plt.figure(figsize=(10,4))
		gs = gridspec.GridSpec(10,10)

		# top panel: bar per complex showing its median correlation
		ax = plt.subplot(gs[0:3,0:])
		n_complexes = len(corr_median_list)
		ind = np.arange(n_complexes)
		width = 0.85
		_, color_list = colorFacade.get_specific_color_gradient(plt.cm.Greys_r,
						np.arange(n_complexes))
		ax.bar(ind, corr_median_list, width, color=color_list, edgecolor = 'white')

		ax.set_xlim(0, n_complexes+0.5)
		ax.set_ylim(min(corr_median_list), max(corr_median_list))
		perc25, perc75 = np.percentile(corr_median_list,[25,75])
		ax.set_xticklabels([])
		ax.axhline(perc25, color = 'red', linewidth = 0.5, linestyle = '--')
		ax.axhline(perc75, color = 'red', linewidth = 0.5, linestyle = '--')

		# bottom panel: one row of square markers per dataset
		ax = plt.subplot(gs[3:8,0:])
		plt.rcParams["axes.grid"] = False
		dfList = [list(row) for row in df.values]
		# grey RGBA tile for masked (NaN) entries
		missing_color = list(np.array(colorFacade.hex_to_rgb('#808080'))/256.0)
		missing_color.append(1)
		for count, item in enumerate(dfList):
			colors = [missing_color if np.isnan(value)
					  else scalarmap_dict['all'].to_rgba(value)
					  for value in item]
			ax.scatter(np.arange(len(item)),[count]*len(item),
					   color=colors,s=200,marker="s",edgecolor="white")
			ax.scatter(len(item),[count],color="white",
					   s=200,marker="s",edgecolor="white")
		ax.set_yticks(list(range(len(dfList))))
		ylabelList = ["DO mouse strains(P)", "Founder mouse strains(P)",
					  "Human Individuals(P)", "Human cell types(P)",
					  "TCGA Breast Cancer (P)", "TCGA Ovarian Cancer(P)",
					  'TCGA Colorectal Cancer(P)']
		ax.set_yticklabels(list(ylabelList))
		ax.set_xticks(list(range(len(df.columns))))
		ax.set_xticklabels(list(df.columns), rotation=45, fontsize=8, ha="right")
		ax.set_xlim(-1,len(df.columns)-0.5)
		fig.savefig(output_folder + "fig3_landscape_variability_complexes_protein_main.pdf",
					bbox_inches="tight", dpi=400)
		plt.close(fig)

		# separate legend figure: horizontal colour bar of the tile colours
		all_values = utilsFacade.flatten(dfList)
		scalarmap, _ = colorFacade.get_specific_color_gradient(cmap,
					   np.array(utilsFacade.finite(all_values)), vmin = -2, vmax = 2)

		plt.rcParams["axes.grid"] = True
		fig = plt.figure(figsize=(10,1))
		ax = fig.add_subplot(111)
		ax.set_xticklabels([])
		ax.set_yticklabels([])
		# 'ax' is given explicitly: the mappable is not attached to any axes
		fig.colorbar(scalarmap, ax=ax, orientation="horizontal")
		fig.savefig(output_folder + "fig3_landscape_variability_complexes_protein_main_LEGEND.pdf",
					bbox_inches="tight", dpi=400)
		plt.close(fig)

	@staticmethod
	def export_underlying_zscore_data(df, **kwargs):
		"""
		Write the transposed z-score table (complexes as rows) as a tab-separated
		file to output_folder + 'figure3_underlying_data_ranked_zscores.tsv',
		with the dataset columns renamed to their display names.
		"""
		output_folder = kwargs.get('output_folder','PATH')

		df = df.T
		df = df.rename(index=str, columns={'gygi3': 'DO mouse strains(P)',
										   'gygi1': 'Founder mouse strains(P)',
										   'battle-protein':'Human Individuals(Battle,P)',
										   'mann':'Human cell lines(P)',
										   'tcga_breast':'TCGA Breast Cancer(P)',
										   'tcga_ovarian':'TCGA Ovarian Cancer(P)',
										   'tcga_colo':'TCGA Colorectal Cancer(P)'})
		df.to_csv(output_folder + 'figure3_underlying_data_ranked_zscores.tsv', sep = '\t')

if __name__ == "__main__":
	## EXECUTE STEP7
	# usage: <script> <folder> <output_folder>; both are used as plain string
	# prefixes of file names, so they have to end with a path separator
	if len(sys.argv) < 3:
		sys.exit("usage: %s <folder> <output_folder>" % sys.argv[0])
	input_folder, output_folder = sys.argv[1], sys.argv[2]

	step7_preparation.execute(folder = input_folder, output_folder = output_folder)
	step7_figure.execute(folder = input_folder, output_folder = output_folder)
	# the z-score table checked here is the one step7_figure just exported,
	# and that export goes to the output folder
	step7_stats_check.execute(folder = output_folder)
