# ------------------------------------------------------------
# Author: Natalie Romanov
# ------------------------------------------------------------

class file_Loader(object):
	"""Load the gzipped TSV tables consumed by step 3.

	Every table is read with ``DataFrameAnalyzer.open_in_chunks(folder, file_name)``
	and handed back in a plain dict keyed by a short dataset label.
	NOTE(review): the return type of ``open_in_chunks`` is not visible here;
	callers in this file treat the result as a pandas DataFrame -- confirm.
	"""

	# (dataset label, file name) pairs for the preparation step; the label
	# becomes the key of the dict returned by load_step3_prep_data.
	STEP3_PREP_FILES = (
		('mann', 'complex_filtered_mann.tsv.gz'),
		('gygi2', 'complex_filtered_gygi2.tsv.gz'),
		('gygi3', 'complex_filtered_gygi3.tsv.gz'),
		('gygi1', 'complex_filtered_gygi1.tsv.gz'),
		('tcga_colo', 'complex_filtered_tcga_colo.tsv.gz'),
		('battle_rna', 'complex_filtered_battle_rna.tsv.gz'),
		('tcga_breast', 'complex_filtered_tcga_breast.tsv.gz'),
		('tcga_ovarian', 'complex_filtered_tcga_ovarian.tsv.gz'),
		('battle_ribo', 'complex_filtered_battle_ribo.tsv.gz'),
		('battle_protein', 'complex_filtered_battle_protein.tsv.gz'),
	)

	# (dataset label, file name) pairs for the figure step; the label becomes
	# the key of the dict returned by load_step3_figure_data.
	STEP3_FIGURE_FILES = (
		('mann', 'suppFig2a_wpc_mann_all_log2_complex_data.tsv.gz'),
		('gygi2', 'suppFig2a_wpc_gygi2_complex_data.tsv.gz'),
		('gygi3', 'suppFig2a_wpc_gygi3_complex_data.tsv.gz'),
		('gygi1', 'suppFig2a_wpc_gygi1_complex_data.tsv.gz'),
		('bp', 'suppFig2a_wpc_battle_protein_complex_data.tsv.gz'),
		('br', 'suppFig2a_wpc_battleRNA_complex_data.tsv.gz'),
		('bribo', 'suppFig2a_wpc_battleRibo_complex_data.tsv.gz'),
		('colo', 'suppFig2a_wpc_coloCancer_complex_data.tsv.gz'),
		('breast', 'suppFig2a_wpc_breastCancer_complex_data.tsv.gz'),
		('ovarian', 'suppFig2a_wpc_ovarianCancer_complex_data.tsv.gz'),
	)

	@staticmethod
	def _load_tables(folder, file_table):
		"""Read every file of ``file_table`` from ``folder``.

		:param folder: passed unchanged as first argument to
			``DataFrameAnalyzer.open_in_chunks``.
		:param file_table: iterable of ``(label, file_name)`` pairs.
		:return: dict mapping each label to the loaded table.
		"""
		return dict((label, DataFrameAnalyzer.open_in_chunks(folder, file_name))
					for label, file_name in file_table)

	@staticmethod
	def load_step3_prep_data(folder):
		"""Load the ten ``complex_filtered_*.tsv.gz`` tables.

		:param folder: location of the input files.
		:return: dict keyed by 'mann', 'gygi1', 'gygi2', 'gygi3', 'tcga_colo',
			'tcga_breast', 'tcga_ovarian', 'battle_rna', 'battle_ribo' and
			'battle_protein'.
		"""
		return file_Loader._load_tables(folder, file_Loader.STEP3_PREP_FILES)

	@staticmethod
	def load_step3_figure_data(folder):
		"""Load the ten ``suppFig2a_wpc_*_complex_data.tsv.gz`` tables.

		:param folder: location of the per-complex summary files.
		:return: dict keyed by 'mann', 'gygi1', 'gygi2', 'gygi3', 'bp', 'br',
			'bribo', 'colo', 'breast' and 'ovarian'.
		"""
		return file_Loader._load_tables(folder, file_Loader.STEP3_FIGURE_FILES)

class step3_preparation:
	"""Summarise every complex of every dataset and export one table per dataset."""

	# Dataset label (key returned by file_Loader.load_step3_prep_data) -> name
	# embedded in the exported file name. file_Loader.load_step3_figure_data
	# reads 'suppFig2a_wpc_<name>_complex_data.tsv.gz' with exactly these
	# names, so exporting under the raw label would leave the figure step
	# without its input files. Labels not listed here are used unchanged.
	EXPORT_NAMES = {
		'mann': 'mann_all_log2',
		'battle_rna': 'battleRNA',
		'battle_ribo': 'battleRibo',
		'tcga_colo': 'coloCancer',
		'tcga_breast': 'breastCancer',
		'tcga_ovarian': 'ovarianCancer',
	}

	@staticmethod
	def execute(**kwargs):
		"""Load all step-3 input tables, summarise them and write the results.

		:keyword folder: input location handed to
			``file_Loader.load_step3_prep_data`` (default ``'PATH'``).
		:keyword output_folder: prefix handed to ``export_data``
			(default ``'PATH'``).
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('load_step3_prep_data')
		data_dict = file_Loader.load_step3_prep_data(folder)

		for label, data in data_dict.items():
			complex_specific_dict = step3_preparation.get_complex_specific_data(data)
			export_name = step3_preparation.EXPORT_NAMES.get(label, label)
			step3_preparation.export_data(complex_specific_dict, export_name,
										  output_folder = output_folder)

	@staticmethod
	def _zscores(values):
		"""Standardise ``values`` against the mean/std of their finite entries.

		The mean and standard deviation are taken over
		``utilsFacade.finite(values)``; every entry of ``values`` is then
		standardised. When the standard deviation is exactly 0 (e.g. a single
		complex) the z-score is undefined, so NaN is returned for every entry
		instead of raising ZeroDivisionError.
		"""
		finite_values = utilsFacade.finite(values)
		center = np.mean(finite_values)
		spread = np.std(finite_values)
		if spread == 0:
			return [float('nan')] * len(values)
		return [float(value - center)/float(spread) for value in values]

	@staticmethod
	def get_complex_specific_data(data):
		"""Compute per-complex summary statistics.

		:param data: table with a ``ComplexID`` column and quantification
			columns selected by ``utilsFacade.filtering(columns, 'quant_')``.
		:return: dict ``ComplexID -> stats`` where stats holds

			* ``abundance``  -- mean of the per-row medians (finite values
			  only) over the quantification columns; each row is presumably
			  one protein of the complex,
			* ``corr``       -- median of
			  ``utilsFacade.get_correlation_values`` applied to the row-by-row
			  correlation matrix,
			* ``var``        -- standard deviation of the per-row medians,
			* ``num``        -- number of rows belonging to the complex,
			* ``var_zscore`` -- z-score of ``var`` across all complexes,
			* ``zscore``     -- z-score of ``abundance`` across all complexes.
		"""
		complexList = list(set(data["ComplexID"]))
		quantCols = utilsFacade.filtering(list(data.columns), 'quant_')

		per_complex = list()
		abundanceList = list()
		spreadList = list()
		for complexID in complexList:
			sub = data[data.ComplexID == complexID]
			# transpose so that every column is one member of the complex
			quant_sub = sub[quantCols].T
			medianList = [np.median(utilsFacade.finite(list(quant_sub.iloc[:,p])))
						  for p in range(len(quant_sub.columns))]
			corr_values = utilsFacade.get_correlation_values(quant_sub.corr())

			abundance = np.mean(medianList)
			spread = np.std(medianList)
			per_complex.append((abundance, np.median(corr_values), spread, len(sub)))
			abundanceList.append(abundance)
			spreadList.append(spread)

		zscores = step3_preparation._zscores(abundanceList)
		var_zscores = step3_preparation._zscores(spreadList)

		complex_specific_dict = dict()
		for complexID, stats, z, var_z in zip(complexList, per_complex, zscores, var_zscores):
			abundance, correlation, spread, num = stats
			complex_specific_dict[complexID] = {'abundance': abundance,
												'corr': correlation,
												'var': spread,
												'num': num,
												'var_zscore': var_z,
												'zscore': z}
		return complex_specific_dict

	@staticmethod
	def export_data(complex_specific_dict, name, **kwargs):
		"""Write the per-complex statistics as a gzipped TSV and return them.

		:param complex_specific_dict: dict ``ComplexID -> stats`` dict.
		:param name: inserted into the file name
			``suppFig2a_wpc_<name>_complex_data.tsv.gz``.
		:keyword output_folder: string the file name is appended to (plain
			concatenation, so it has to end with a path separator);
			default ``'PATH'``.
		:return: the exported DataFrame (one column per complex).
		"""
		output_folder = kwargs.get('output_folder', 'PATH')

		complex_data = pd.DataFrame(complex_specific_dict)
		complex_data.to_csv(output_folder + 'suppFig2a_wpc_' + name + '_complex_data.tsv.gz',
							sep = '\t', compression = 'gzip')
		return complex_data

class step3_figure:
	"""Draw the supplementary-figure-2a box plots from the exported tables."""

	# (name used in the output PDF file name, key in the dict returned by
	# file_Loader.load_step3_figure_data), in plotting order.
	DATASETS = [('gygi1', 'gygi1'),
				('gygi2', 'gygi2'),
				('gygi3', 'gygi3'),
				('battleRNA', 'br'),
				('battleRibo', 'bribo'),
				('battle_protein', 'bp'),
				('mann_all_log2', 'mann'),
				('coloCancer', 'colo'),
				('breastCancer', 'breast'),
				('ovarianCancer', 'ovarian')]

	# (start, stop) index ranges used to slice the sorted complexes.
	BIN_EDGES = [(0,25), (25,50), (50,75), (75,100)]

	@staticmethod
	def execute(**kwargs):
		"""Entry point of the figure step.

		:keyword folder: read but not used by this step (default ``'PATH'``).
		:keyword output_folder: location of the exported tables and target of
			the PDFs (default ``'PATH'``).
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		print('SUPP-FIGURE2A: main_supp_figure2a_bins_boxplots_effect')
		step3_figure.main_supp_figure2a_bins_boxplots_effect(output_folder = output_folder)

	@staticmethod
	def main_supp_figure2a_bins_boxplots_effect(**kwargs):
		"""Load the per-complex tables from ``output_folder`` and plot them.

		:keyword output_folder: used both as input location for
			``file_Loader.load_step3_figure_data`` and as output prefix for
			``make_boxplots`` (default ``'PATH'``).
		"""
		output_folder = kwargs.get('output_folder','PATH')

		print('load_data')
		data_dict = file_Loader.load_step3_figure_data(output_folder)

		print('make_boxplots')
		step3_figure.make_boxplots(data_dict, output_folder)

	@staticmethod
	def bin_abundances_variances(complex_data, keyword):
		"""Group the ``corr`` values of the complexes after sorting by ``keyword``.

		:param complex_data: table with one column per complex and the rows
			``corr`` and ``keyword`` (it is transposed before use).
		:param keyword: row to sort the complexes by, ascending.
		:return: dict ``'<start>:<stop>' -> [values]`` with one entry per
			element of ``BIN_EDGES``; ``values`` is
			``utilsFacade.finite`` applied to the ``corr`` values at sorted
			positions ``start`` to ``stop - 1``. The one-element list wrapper
			is kept because callers index it with ``[0]``.

		NOTE(review): the bin edges are absolute positions, not percentiles,
		so only the first 100 complexes of the sorted table are ever used --
		confirm that this is intended.
		"""
		complex_data = complex_data.copy().T.sort_values(keyword)
		corr_list = list(complex_data['corr'])

		zscore_dict = dict()
		for c1,c2 in step3_figure.BIN_EDGES:
			zscore_dict[str(c1) + ':' + str(c2)] = [utilsFacade.finite(corr_list[c1:c2])]
		return zscore_dict

	@staticmethod
	def _draw_panel(ax, groups, positions, widths):
		"""Draw one box-plot panel with jittered points and pairwise p-values.

		:param ax: matplotlib axes to draw on.
		:param groups: list of value lists, one per box.
		:param positions: x position of every box.
		:param widths: width of every box.

		The title lists ``scipy.stats.ttest_ind`` p-values for every pair
		returned by ``utilsFacade.getCombinations``; the x tick labels show
		the number of values in each box.
		"""
		bp = ax.boxplot(groups,notch=0,sym="",vert=1, patch_artist=True,widths=widths,
						positions = positions)
		plt.setp(bp['medians'], color="black")
		plt.setp(bp['whiskers'], color="black")

		counts = list()
		for i,patch in enumerate(bp['boxes']):
			patch.set_facecolor("lightgrey")
			patch.set_edgecolor("black")
			patch.set_alpha(0.6)
			group = groups[i]
			# horizontal jitter so that overlapping points stay visible
			x = np.random.normal(positions[i], 0.04, size=len(group))
			ax.scatter(x,group, color='white', alpha=0.9,edgecolor="black",s=40)
			counts.append(len(group))

		pvals = list()
		for c1,c2 in utilsFacade.getCombinations(np.arange(0,len(groups),1)):
			pval = scipy.stats.ttest_ind(groups[c1],groups[c2])[1]
			pvals.append(str(c1) + ':' + str(c2) + '=' + str(pval))
		ax.set_title('\n'.join(pvals), fontsize =5)
		ax.set_ylim(-0.5,1)
		ax.set_xticklabels(counts)

	@staticmethod
	def make_boxplots(data_dict, folder):
		"""Write one two-panel PDF per dataset.

		The left panel bins the complexes by ``zscore``, the right panel by
		``var_zscore``; both show the ``corr`` values of each bin.

		:param data_dict: dict as returned by
			``file_Loader.load_step3_figure_data``.
		:param folder: string the file name
			``suppFig2a_boxplots_overview_<name>.pdf`` is appended to.
		"""
		# resolve every dataset up front so a missing key fails before plotting
		datasets = [(name, data_dict[key]) for name,key in step3_figure.DATASETS]
		bin_list = [str(c1) + ':' + str(c2) for c1,c2 in step3_figure.BIN_EDGES]

		sns.set(context='notebook', style='white', 
			palette='deep', font='Liberation Sans',
			font_scale=1, color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = True

		positions = list(np.arange(1,len(bin_list)/5.0+5,0.5)[0:(len(bin_list))])
		widths = [0.4]*len(bin_list)

		for n,d in datasets:
			# keep only complexes whose 'num' entry is >= 0
			d = d.T
			d = d[d.num >= 0]
			d = d.T

			print(n)
			zscore_dict_abundances = step3_figure.bin_abundances_variances(d,'zscore')
			zscore_dict_variances = step3_figure.bin_abundances_variances(d,'var_zscore')
			dataList = [zscore_dict_abundances[dbin][0] for dbin in bin_list]
			var_dataList = [zscore_dict_variances[dbin][0] for dbin in bin_list]

			fig = plt.figure(figsize = (7,4))
			step3_figure._draw_panel(fig.add_subplot(121), dataList, positions, widths)
			step3_figure._draw_panel(fig.add_subplot(122), var_dataList, positions, widths)

			fig.savefig(folder + 'suppFig2a_boxplots_overview_' + n + '.pdf',
						bbox_inches = 'tight', dpi = 400)
			# release the figure; otherwise one figure per dataset stays open
			plt.close(fig)


def main(argv=None):
	"""Run step 3 (preparation, then figure) from the command line.

	:param argv: argument vector; defaults to ``sys.argv``. ``argv[1]`` is the
		input folder and ``argv[2]`` the output folder. Additional arguments
		are ignored.
	:raises SystemExit: with a usage message when fewer than two folders are
		given.
	"""
	# local import: no module-level import of sys is visible in this file
	import sys

	if argv is None:
		argv = sys.argv
	if len(argv) < 3:
		program = argv[0] if argv else 'script'
		raise SystemExit('usage: ' + str(program) + ' <folder> <output_folder>')
	folder, output_folder = argv[1], argv[2]

	## EXECUTE STEP 3
	step3_preparation.execute(folder = folder, output_folder = output_folder)
	step3_figure.execute(folder = folder, output_folder = output_folder)


if __name__ == "__main__":
	main()
