# ------------------------------------------------------------
# Author: Natalie Romanov
# ------------------------------------------------------------

class figure4a:
	"""Figure 4a: per-complex heatmaps of z-scored subunit 'relative_variance' values across datasets."""

	@staticmethod
	def execute(**kwargs):
		"""Entry point: run the figure 4a pipeline.

		Keyword args:
			folder: input directory holding the complex-filtered tables.
			output_folder: prefix for output files; it is joined by plain
				string concatenation, so it should end with a path separator.
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('FIGURE4A: main_figure4a_complex_intern_landscapes')
		figure4a.main_figure4a_complex_intern_landscapes(folder = folder, output_folder = output_folder)

	@staticmethod
	def main_figure4a_complex_intern_landscapes(**kwargs):
		"""Load every dataset and draw one heatmap per relevant complex."""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('get_data')
		data_dict = figure4a.get_data(folder = folder)

		# display labels, paired positionally with data_keys below
		# (the 'mann' table is labelled 'mann_all_log2')
		name_list = ['gygi3','gygi1','battle_protein','mann_all_log2',
					 'tcgaBreast','tcgaOvarian', 'tcgaColoCancer']
		data_keys = ['gygi3','gygi1','battle_protein','mann',
					 'tcgaBreast','tcgaOvarian','tcgaColoCancer']
		data_list = [data_dict[key] for key in data_keys]

		print('iteration_complexes')
		# forward folder so the relevant-complex list is read from the input
		# directory instead of the 'PATH' placeholder
		figure4a.iteration_complexes(data_list, name_list, folder = folder, output_folder = output_folder)

	@staticmethod
	def get_data(**kwargs):
		"""Load the 'complex_filtered_stoch_<key>.tsv.gz' table of every dataset.

		Returns:
			dict mapping dataset key -> table returned by
			DataFrameAnalyzer.open_in_chunks.
		"""
		folder = kwargs.get('folder','PATH')

		keys = ['gygi3', 'gygi1', 'mann', 'battle_protein',
				'tcgaBreast', 'tcgaOvarian', 'tcgaColoCancer']
		data_dict = dict()
		for key in keys:
			data_dict[key] = DataFrameAnalyzer.open_in_chunks(folder, 'complex_filtered_stoch_' + key + '.tsv.gz')
		return data_dict

	@staticmethod
	def load_relevant_complexes(**kwargs):
		"""Return the complex names listed (as index) in 'figure4a_relevant_complexes.tsv'."""
		folder = kwargs.get('folder','PATH')

		com_data = DataFrameAnalyzer.getFile(folder, 'figure4a_relevant_complexes.tsv')
		return list(com_data.index)

	@staticmethod
	def prepare_variance_dataframe(name_list, data_list, complex_df, proteins):
		"""Tabulate each protein's 'relative_variance' per dataset.

		Args:
			name_list: dataset labels, matched against complex_df.fileName.
			data_list: full dataset tables in the same order as name_list;
				only used to look up their 'quant_' columns.
			complex_df: concatenated complex rows tagged with 'fileName'.
			proteins: protein labels to look up in complex_df's index.

		When a protein occurs more than once in one dataset, the row whose
		quant values have the largest share kept by utilsFacade.finite
		("coverage") is used. Proteins with fewer than
		int(len(name_list) / 2) non-missing values are dropped.

		Returns:
			DataFrame indexed by dataset label with one column per kept protein.
		"""
		# the per-dataset slice does not depend on the protein: build it once
		per_dataset = [(d, complex_df[complex_df.fileName == n])
					   for d, n in zip(data_list, name_list)]

		rows = list()
		for protein in proteins:
			row = list()
			for d, sub in per_dataset:
				if protein not in sub.index:
					row.append(np.nan)
					continue
				entry = sub.loc[protein]
				if isinstance(entry, pd.DataFrame):
					# duplicated protein: keep the measurement with the best coverage
					quant_cols = utilsFacade.filtering(d, 'quant_', condition = 'startswith')
					coverage_list = list()
					for values in np.array(entry[quant_cols]):
						values = list(values)
						kept = utilsFacade.finite(values)
						coverage_list.append(float(len(kept))/float(len(values))*100)
					entry = entry.copy()
					entry['coverage'] = coverage_list
					entry = entry.sort_values('coverage', ascending = False)
					row.append(entry.iloc[0]['relative_variance'])
				else:
					row.append(entry['relative_variance'])
			rows.append(row)

		df = pd.DataFrame(rows, index = proteins, columns = name_list)
		df = df.dropna(thresh = int(len(df.columns)/2.0))
		return df.T

	@staticmethod
	def rank_dataset(df, proteins, name_list):
		"""Replace every row's values by their ranks within that row.

		Ties get their average rank and missing values stay NaN (only
		non-missing values are ranked).

		Args:
			df: dataset x protein table.
			proteins / name_list: column / row labels for the result. They
				are applied only when their length matches the frame, because
				prepare_variance_dataframe may already have dropped sparse
				proteins; otherwise df's own labels are kept.

		Returns:
			DataFrame of float ranks with the same shape as df.
		"""
		ranked = df.astype(float).rank(axis = 1, method = 'average', na_option = 'keep')
		if len(proteins) == ranked.shape[1]:
			ranked.columns = proteins
		if len(name_list) == ranked.shape[0]:
			ranked.index = name_list
		return ranked

	@staticmethod
	def get_zscores(df):
		"""Z-score every row of df and reverse the row order.

		Mean and standard deviation are taken over the row values kept by
		utilsFacade.finite. Every entry is then transformed as
		(x - mean) / std, so missing values stay NaN. A row with no usable
		values becomes all-NaN instead of raising.

		Returns:
			(zscored_df, rows): the z-scored frame with its rows in reverse
			order, and the same values as a list of per-row lists.
		"""
		zscored_rows = list()
		for values in np.array(df):
			kept = utilsFacade.finite(list(values))
			mean = np.mean(kept)
			std = np.std(kept)
			zscored_rows.append([(element - mean)/std for element in values])

		new_df = pd.DataFrame(zscored_rows, index = df.index, columns = list(df.columns))
		new_df = new_df.iloc[::-1].copy()
		rows = [list(row) for row in new_df.values]
		return new_df, rows

	@staticmethod
	def manage_proteasome_data(df):
		"""Drop the protein columns excluded from the 26S proteasome heatmap.

		The names PSD3, C9, C2, C6 and RPN1 are removed in both upper-case and
		first-letter-capitalised spelling (e.g. 'Psd3'). Names that are not
		present are ignored.
		"""
		remove_list = ['PSD3','C9','C2','C6','RPN1']
		remove_list = remove_list + [item[0] + item[1:].lower() for item in remove_list]
		return df.loc[:, ~df.columns.isin(remove_list)]

	@staticmethod
	def plot_heatmap(df, complex_id, altName, **kwargs):
		"""Save df as a heatmap to output_folder + 'fig4a_<altName>_heatmap_subunits.pdf'.

		Missing cells are drawn grey. All other cells use the RdBu_r colour
		scale fixed to [-2, 2].
		"""
		output_folder = kwargs.get('output_folder','PATH')

		sns.set(context='notebook', style='white', 
			palette='deep', font='Liberation Sans', font_scale=1, 
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = False

		if '26S' in complex_id or 'NPC' in complex_id:
			fig = plt.figure(figsize = (10,3))
		else:
			fig = plt.figure(figsize = (5,3))
		ax = fig.add_subplot(111)
		mask = df.isnull()
		# bottom layer: a single-colour (grey) map over every cell; the data
		# layer is drawn on top with missing cells masked, so they stay grey
		sns.heatmap(mask.astype(int), cmap = ['grey'], linewidth = 0.2,
					cbar = False, xticklabels = [], yticklabels = [], ax = ax)
		sns.heatmap(df, cmap = plt.cm.RdBu_r, linewidth = 0.2,
					mask = mask, vmin = -2, vmax = 2, ax = ax)
		fig.savefig(output_folder + 'fig4a_' + altName + '_heatmap_subunits.pdf',
					bbox_inches = 'tight', dpi = 400)
		# one figure per complex: close it so figures do not accumulate
		plt.close(fig)

	@staticmethod
	def get_alternative_name(complex_id):
		"""Return complex_id made file-name friendly: spaces -> '_', '/' and ':' removed."""
		altName = '_'.join(complex_id.split(' '))
		altName = altName.replace('/','')
		altName = altName.replace(':','')
		return altName

	@staticmethod
	def _collect_complex_rows(data_list, name_list, complex_id):
		"""Concatenate complex_id's rows from every dataset, tagged with 'fileName'.

		The index is upper-cased. Returns None when no dataset has rows for
		the complex.
		"""
		concat_list = list()
		for d, n in zip(data_list, name_list):
			sub = d[d.ComplexName == complex_id]
			if len(sub) > 0:
				# copy: tagging a filtered slice in place triggers SettingWithCopyWarning
				sub = sub.copy()
				sub['fileName'] = n
				concat_list.append(sub)
		if not concat_list:
			return None
		complex_df = pd.concat(concat_list)
		complex_df.index = [item.upper() for item in list(complex_df.index)]
		return complex_df

	@staticmethod
	def _order_by_mean(df):
		"""Return df's column labels sorted by the mean of their values kept by utilsFacade.finite (ties by label)."""
		values = np.array(df)
		means = [np.mean(utilsFacade.finite(list(values[:, c]))) for c in range(values.shape[1])]
		return [label for _, label in sorted(zip(means, df.columns))]

	@staticmethod
	def iteration_complexes(data_list, name_list, **kwargs):
		"""Build, z-score and plot the subunit table of every relevant complex.

		Keyword args:
			do_ranking: rank variances within each dataset before z-scoring
				(default False).
			folder: directory holding 'figure4a_relevant_complexes.tsv'.
			output_folder: prefix for the heatmap PDFs and the summary TSV.

		Returns:
			DataFrame of all plotted (protein x dataset) z-scores plus a
			'complex_id' column. It is also written to output_folder +
			'fig4a_underlyingData_complexes_RNA_newData_<YYYYMMDD>.tsv'.
		"""
		do_ranking = kwargs.get('do_ranking', False)
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		relevant_complexes = figure4a.load_relevant_complexes(folder = folder)
		concatenated_complex_dfs = list()
		for complex_id in relevant_complexes:
			print('*************************')
			print(complex_id)
			altName = figure4a.get_alternative_name(complex_id)

			complex_df = figure4a._collect_complex_rows(data_list, name_list, complex_id)
			if complex_df is None:
				print('no dataset contains this complex - skipped')
				continue
			proteins = list(set(complex_df.index))

			df = figure4a.prepare_variance_dataframe(name_list, data_list, complex_df, proteins)
			if do_ranking:
				df = figure4a.rank_dataset(df, proteins, name_list)
			df, _ = figure4a.get_zscores(df)
			if '26S' in complex_id:
				df = figure4a.manage_proteasome_data(df)
			if df.shape[1] == 0:
				print('no protein left after filtering - skipped')
				continue

			# subunits left-to-right by increasing mean z-score
			df = df[figure4a._order_by_mean(df)]
			if 'Kornberg' in complex_id:
				complex_id = 'Kornbergs mediator (SRB) complex'

			df1 = df.T.copy()
			df1['complex_id'] = complex_id
			concatenated_complex_dfs.append(df1)
			print('plot_heatmap')
			print('*************************')

			figure4a.plot_heatmap(df, complex_id, altName, output_folder = output_folder)

		complex_data = pd.concat(concatenated_complex_dfs)
		complex_data.to_csv(output_folder + 'fig4a_underlyingData_complexes_RNA_newData_' + time.strftime('%Y%m%d') + '.tsv',
							sep = '\t')
		return complex_data


class figure4b:
	"""Figure 4b: boxplots of 26S proteasome subunit quant values from the gygi3 table."""

	@staticmethod
	def execute(**kwargs):
		"""Entry point: run the figure 4b pipeline.

		Keyword args:
			folder: input directory holding the complex-filtered tables.
			output_folder: prefix for the output PDF; it is joined by plain
				string concatenation, so it should end with a path separator.
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('FIGURE4B: main_figure4b_26SProteasome_analysis')
		# forward both paths so the caller's directories are used, not the 'PATH' placeholders
		figure4b.main_figure4b_26SProteasome_analysis(folder = folder, output_folder = output_folder)

	@staticmethod
	def main_figure4b_26SProteasome_analysis(**kwargs):
		"""Load the 26S proteasome rows and draw the subunit boxplot."""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print("load_boxplot_data")
		# sig_data_bp is loaded but not plotted here
		sig_data_gygi3, sig_data_bp = figure4b.load_boxplot_data(folder = folder)

		print('plot_boxplot_complexVariance')
		figure4b.plot_boxplot_complexVariance(sig_data_gygi3,
											  folder = folder,
											  output_folder = output_folder)

	@staticmethod
	def load_boxplot_data(**kwargs):
		"""Return the "26S Proteasome" rows (by ComplexID) of the gygi3 and battle_protein tables."""
		folder = kwargs.get('folder','PATH')

		data = DataFrameAnalyzer.open_in_chunks(folder, 'complex_filtered_stoch_gygi3.tsv.gz')
		sig_data_gygi3 = data[data.ComplexID=="26S Proteasome"]

		data = DataFrameAnalyzer.open_in_chunks(folder, 'complex_filtered_stoch_battle_protein.tsv.gz')
		sig_data_bp = data[data.ComplexID=="26S Proteasome"]
		return sig_data_gygi3,sig_data_bp

	@staticmethod
	def _draw_guides(ax):
		"""Draw dashed black horizontal reference lines at -1, -0.5, 0, 0.5 and 1."""
		for level in (1, 0.5, 0, -0.5, -1):
			ax.axhline(level, color = 'k', linestyle = '--')

	@staticmethod
	def _blacken_lines(bp):
		"""Colour the medians, whiskers and caps of a boxplot black."""
		plt.setp(bp['medians'], color="black")
		plt.setp(bp['whiskers'], color="black")
		plt.setp(bp['caps'], color="black")

	@staticmethod
	def plot_boxplot_complexVariance(sig_data, **kwargs):
		"""Draw the two-panel 26S proteasome boxplot and save it as a PDF.

		Values are the gygi3 table's quant columns (chosen by
		utilsFacade.get_quantCols(sig_data)) for the proteins in sig_data,
		ordered by 'relative_variance'. Psma8 and Psmd4 are left out. The
		proteins in right_panel are drawn on the right with jittered points;
		all others are drawn on the left.

		Keyword args:
			folder: input directory with 'complex_filtered_stoch_gygi3.tsv.gz'.
			output_folder: prefix for
				'fig4b_stochiometry_data_complexes_26SProteasome_gygi3.pdf'.
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')
		complexID = "26S Proteasome"

		dataset = DataFrameAnalyzer.open_in_chunks(folder, 'complex_filtered_stoch_gygi3.tsv.gz')
		dataset = dataset[dataset.ComplexID==complexID]
		quantCols = utilsFacade.get_quantCols(sig_data)
		dataset = dataset[quantCols].drop_duplicates()

		sub = sig_data[sig_data.ComplexID==complexID].sort_values("relative_variance")
		# one column per protein (in sub's order); identical protein columns dropped
		dsub = dataset.T[sub.index]
		dsub = dsub.T.drop_duplicates().T

		right_panel = ["Psmb7","Psmd10","Psmb6","Psmb5","Psmb8","Psmd9","Psmb9","Psmb10"]
		excluded = ["Psma8","Psmd4"]#not enough datapoints/coverage
		dataList1, proteins1 = list(), list()
		dataList2, proteins2 = list(), list()
		for c,col in enumerate(dsub.columns):
			if col in excluded:
				continue
			values = utilsFacade.finite(dsub.iloc[:,c])
			if col in right_panel:
				dataList2.append(values)
				proteins2.append(col)
			else:
				dataList1.append(values)
				proteins1.append(col)

		immuno_proteasome = ["Psmb5","Psmb6","Psmb7","Psmb8","Psmb9","Psmb10"]

		sns.set(context='notebook', style='white', 
			palette='deep', font='Liberation Sans', font_scale=1, 
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = True

		fig = plt.figure(figsize=(15,6))
		gs = gridspec.GridSpec(3,3)

		# left panel: orange boxes for proteins in immuno_proteasome, grey otherwise
		ax = plt.subplot(gs[0:,0:2])
		figure4b._draw_guides(ax)
		bp = ax.boxplot(dataList1,notch=0,sym="",vert=1,patch_artist=True,widths=[0.8]*len(dataList1))
		figure4b._blacken_lines(bp)
		for protein, patch in zip(proteins1, bp['boxes']):
			color = "orange" if protein in immuno_proteasome else "grey"
			patch.set_facecolor(color)
			patch.set_edgecolor(color)
			patch.set_alpha(0.8)
		ax.set_xticks(list(range(len(proteins1)+1)))
		ax.set_xlim(0,len(proteins1)+1)
		ax.set_xticklabels([""]+proteins1,fontsize=13,rotation=90)
		ax.set_ylabel("Relative Abundance",fontsize=13)
		# on the y axis, bottom/top map to the left/right tick marks; they take bools, not "off"
		ax.tick_params(axis="y",which="both",bottom=False,top=False,labelsize=12)
		ax.set_ylim(-1.5,1.5)

		# right panel: boxes coloured by name pattern, with jittered points overlaid
		ax = plt.subplot(gs[0:,2:])
		figure4b._draw_guides(ax)
		bp = ax.boxplot(dataList2,notch=0,sym="",vert=1,patch_artist=True,widths=[0.8]*len(dataList2))
		figure4b._blacken_lines(bp)
		for i, (protein, patch) in enumerate(zip(proteins2, bp['boxes'])):
			x = np.random.normal(i+1, 0.04, size=len(dataList2[i]))
			if protein in immuno_proteasome:
				face, edge = "orange", "brown"
			elif "a" in protein or "b" in protein:
				face, edge = "#95DCEC", "darkblue"
			elif "c" in protein or "d" in protein:
				face, edge = "grey", "black"
			else:
				face, edge = None, None
			if face is not None:
				patch.set_facecolor(face)
				ax.scatter(x,dataList2[i],color='white', alpha=0.9,edgecolor=edge,s=20)
			patch.set_edgecolor("black")
			patch.set_alpha(0.8)
		ax.set_xticks(list(range(len(proteins2)+1)))
		ax.set_xlim(0,len(proteins2)+1)
		ax.set_xticklabels([""]+[item.upper() for item in proteins2],fontsize=13,rotation=90)
		ax.set_yticklabels([])
		ax.tick_params(axis="y",which="both",bottom=False,top=False,labelsize=12)
		ax.set_ylim(-1.5,1.5)
		fig.savefig(output_folder + "fig4b_stochiometry_data_complexes_26SProteasome_gygi3.pdf",
					bbox_inches="tight",dpi=600)
		plt.close(fig)


if __name__ == "__main__":
	## EXECUTE FIGURE4
	# command line: <input folder> <output folder>; both figures share the same pair
	input_folder, target_folder = sys.argv[1], sys.argv[2]
	for figure_class in (figure4a, figure4b):
		figure_class.execute(folder = input_folder, output_folder = target_folder)








