# ------------------------------------------------------------
# Author: Natalie Romanov
# ------------------------------------------------------------

class supp_figure2a:
	"""Supplementary Figure 2A: per dataset, boxplots of the 'corr' values of
	entries binned by rank of 'zscore' (left panel) and of 'var_zscore'
	(right panel). One PDF is written per dataset."""

	# Positional (start, stop) slices applied to the entries after sorting.
	# Shared by bin_abundances_variances (which produces 'start:stop' keys)
	# and make_boxplots (which looks them up), so the two always agree.
	BIN_EDGES = ((0, 25), (25, 50), (50, 75), (75, 100))

	# (data_dict key, dataset tag) in plotting order. The tag is used both in
	# the input file 'suppFig2a_wpc_<tag>_complex_data.tsv.gz' and in the
	# output file 'suppFig2a_boxplots_overview_<tag>.pdf'.
	DATASETS = (
		('gygi1', 'gygi1'),
		('gygi2', 'gygi2'),
		('gygi3', 'gygi3'),
		('br', 'battleRNA'),
		('bribo', 'battleRibo'),
		('bp', 'battle_protein'),
		('mann', 'mann_all_log2'),
		('colo', 'coloCancer'),
		('breast', 'breastCancer'),
		('ovarian', 'ovarianCancer'),
	)

	@staticmethod
	def execute(**kwargs):
		"""Entry point for Supp. Figure 2A.

		Keyword args:
			folder: accepted for interface symmetry; not used (inputs are
				read from output_folder).
			output_folder: string prefix (should end with a path separator)
				holding the suppFig2a_*.tsv.gz inputs and receiving the PDFs.
		"""
		output_folder = kwargs.get('output_folder','PATH')

		print('SUPP-FIGURE2A: main_supp_figure2a_bins_boxplots_effect')
		supp_figure2a.main_supp_figure2a_bins_boxplots_effect(output_folder = output_folder)

	@staticmethod
	def main_supp_figure2a_bins_boxplots_effect(**kwargs):
		"""Load all dataset tables from output_folder and write the boxplots there."""
		output_folder = kwargs.get('output_folder','PATH')

		print('load_data')
		data_dict = supp_figure2a.load_data(output_folder)

		print('make_boxplots')
		supp_figure2a.make_boxplots(data_dict, output_folder)

	@staticmethod
	def load_data(folder):
		"""Read every dataset table listed in DATASETS.

		Returns:
			dict mapping the data_dict key (e.g. 'mann', 'br') to whatever
			DataFrameAnalyzer.open_in_chunks returns for that file.
		"""
		return dict(
			(key, DataFrameAnalyzer.open_in_chunks(folder, 'suppFig2a_wpc_' + tag + '_complex_data.tsv.gz'))
			for key, tag in supp_figure2a.DATASETS)

	@staticmethod
	def _bin_labels(bin_edges):
		"""Return the 'start:stop' keys for a sequence of (start, stop) bins."""
		return [str(c1) + ':' + str(c2) for c1, c2 in bin_edges]

	@staticmethod
	def _box_positions(n_boxes):
		"""Return x positions 1.0, 1.5, 2.0, ... for n_boxes boxes."""
		return [1.0 + 0.5 * i for i in range(n_boxes)]

	@staticmethod
	def bin_abundances_variances(complex_data, keyword):
		"""Split the 'corr' values into positional bins after ranking by keyword.

		complex_data is transposed so that each row is one entry with (at
		least) `keyword` and 'corr' columns. Rows are sorted ascending by
		`keyword`, and the 'corr' column is cut into the slices in BIN_EDGES.

		Returns:
			dict mapping 'start:stop' -> [finite corr values of that slice]
			(a one-element list; make_boxplots reads element [0]).
		"""
		ordered = complex_data.T.sort_values(keyword)
		corr_values = list(ordered['corr'])
		# NOTE(review): bins are positional over the sorted order, so only the
		# first 100 entries are used, and bins are short or empty when there
		# are fewer than 100 entries. Kept as-is to preserve the figure.
		binned = dict()
		for label, (c1, c2) in zip(supp_figure2a._bin_labels(supp_figure2a.BIN_EDGES),
								   supp_figure2a.BIN_EDGES):
			binned[label] = [utilsFacade.finite(corr_values[c1:c2])]
		return binned

	@staticmethod
	def _draw_binned_boxplot(ax, groups, positions, widths):
		"""Draw one panel on ax.

		The panel has a box per group with jittered points, pairwise t-test
		p-values as the title and each group's size as its x tick label.
		"""
		bp = ax.boxplot(groups, notch=0, sym="", vert=1,
						patch_artist=True, widths=widths, positions=positions)
		plt.setp(bp['medians'], color="black")
		plt.setp(bp['whiskers'], color="black")
		for patch, pos, group in zip(bp['boxes'], positions, groups):
			patch.set_facecolor("lightgrey")
			patch.set_edgecolor("black")
			patch.set_alpha(0.6)
			jitter = np.random.normal(pos, 0.04, size=len(group))
			ax.scatter(jitter, group, color='white', alpha=0.9, edgecolor="black", s=40)
		pvals = list()
		for c1, c2 in utilsFacade.getCombinations(np.arange(0, len(groups), 1)):
			pval = scipy.stats.ttest_ind(groups[c1], groups[c2])[1]
			pvals.append(str(c1) + ':' + str(c2) + '=' + str(pval))
		ax.set_title('\n'.join(pvals), fontsize = 5)
		ax.set_ylim(-0.5, 1)
		# Label each box with the size of the group it actually shows.
		ax.set_xticklabels([len(group) for group in groups])

	@staticmethod
	def make_boxplots(data_dict, folder):
		"""Write 'suppFig2a_boxplots_overview_<tag>.pdf' into folder for every
		dataset in DATASETS (folder is a string prefix, joined by '+')."""
		bin_labels = supp_figure2a._bin_labels(supp_figure2a.BIN_EDGES)
		positions = supp_figure2a._box_positions(len(bin_labels))
		widths = [0.4] * len(bin_labels)

		sns.set(context='notebook', style='white',
			palette='deep', font='Liberation Sans', font_scale=1,
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = True

		for key, tag in supp_figure2a.DATASETS:
			# Keep only the columns whose 'num' row is >= 0 (NaN is dropped too).
			d = data_dict[key].T
			d = d[d.num >= 0].T

			abundance_bins = supp_figure2a.bin_abundances_variances(d, 'zscore')
			variance_bins = supp_figure2a.bin_abundances_variances(d, 'var_zscore')
			abundance_groups = [abundance_bins[label][0] for label in bin_labels]
			variance_groups = [variance_bins[label][0] for label in bin_labels]

			fig = plt.figure(figsize = (7,4))
			supp_figure2a._draw_binned_boxplot(fig.add_subplot(121), abundance_groups, positions, widths)
			supp_figure2a._draw_binned_boxplot(fig.add_subplot(122), variance_groups, positions, widths)
			fig.savefig(folder + 'suppFig2a_boxplots_overview_' + tag + '.pdf',
						bbox_inches = 'tight', dpi = 400)
			# Release the figure; otherwise one open figure accumulates per dataset.
			plt.close(fig)

class supp_figure2b:
	"""Supplementary Figure 2B: for each dataset, boxplots of values from the
	'randomized reshuffled', 'randomized' and 'real' complex-set JSON files,
	drawn side by side in a single PDF."""

	# Dataset prefixes of the '<name>_*_complexSet*_dict.json' inputs, in
	# plotting order (left to right).
	NAME_LIST = ('gygi1', 'gygi2', 'gygi3', 'battle_rna',
				 'battle_ribo', 'battle_protein',
				 'mann', 'tcgaColo', 'tcgaBreast',
				 'tcgaOvarian')

	@staticmethod
	def execute(**kwargs):
		"""Entry point for Supp. Figure 2B.

		Keyword args:
			folder: accepted for interface symmetry; not used (inputs are
				read from output_folder).
			output_folder: string prefix (should end with a path separator)
				holding the JSON inputs and receiving the PDF.
		"""
		output_folder = kwargs.get('output_folder',"PATH")

		print('SUPP-FIGURE2B: main_supp_figure2b_boxplot_randomControl')
		supp_figure2b.main_supp_figure2b_boxplot_randomControl(output_folder = output_folder)

	@staticmethod
	def main_supp_figure2b_boxplot_randomControl(**kwargs):
		"""Collect the per-dataset value lists and draw the combined boxplot."""
		output_folder = kwargs.get('output_folder',"PATH")
		num_subunits = kwargs.get('num_subunits',0)

		print('prepare_data_for_boxplot')
		bp_data_dict = supp_figure2b.prepare_data_for_boxplot(output_folder, num_subunits = num_subunits)

		print('plot_boxplot')
		supp_figure2b.plot_boxplot(bp_data_dict, output_folder)

	@staticmethod
	def _load_json(path):
		"""Load and return the JSON content of path, closing the file afterwards."""
		with open(path, 'rb') as handle:
			return json.load(handle)

	@staticmethod
	def prepare_data_for_boxplot(folder, **kwargs):
		"""Build the three value lists that plot_boxplot draws for each dataset.

		For every name in NAME_LIST, three JSON files (folder + name +
		suffix) are read. Only complex IDs present in all three are used,
		and the 'ALL' key is skipped. The kwarg num_subunits is accepted
		but not used.

		Returns:
			dict name -> [randomized_shuffled, randomized, real]. The first two
			are flattened and filtered with utilsFacade.finite; 'real' keeps
			one value per complex and drops values whose str() is 'nan'.
		"""
		bp_data_dict = dict()
		for n in supp_figure2b.NAME_LIST:
			print(n)
			real_complexDict = supp_figure2b._load_json(folder + n + '_real_complexSet_dict.json')
			randomized_complexDict = supp_figure2b._load_json(folder + n + '_randomized_complexSet_dict.json')
			randomized_shuffled_complexDict = supp_figure2b._load_json(folder + n + '_randomized_complexSet_reshuffledData_dict.json')

			shared_ids = set(real_complexDict) & set(randomized_complexDict) & set(randomized_shuffled_complexDict)
			shared_ids.discard('ALL')
			# Sorted so the value order does not depend on set iteration order.
			complex_ids = sorted(shared_ids)

			randomized_shuffled_list = utilsFacade.finite(utilsFacade.flatten(
				[randomized_shuffled_complexDict[c] for c in complex_ids]))
			randomized_list = utilsFacade.finite(utilsFacade.flatten(
				[randomized_complexDict[c] for c in complex_ids]))
			# A list comprehension (not filter) so this is a list on Python 3 too.
			real_list = [real_complexDict[c] for c in complex_ids
						 if str(real_complexDict[c]) != 'nan']
			bp_data_dict[n] = [randomized_shuffled_list, randomized_list, real_list]
		return bp_data_dict

	@staticmethod
	def plot_boxplot(bp_data_dict, output_folder):
		"""Draw one three-box panel per dataset in NAME_LIST and save
		output_folder + 'suppFigure2B_boxplot_random_complex_variability_all.pdf'.

		Rank-sum p-values of each randomized group against the real group
		are printed to stdout.
		"""
		sns.set(context='notebook', style='white',
			palette='deep', font='Liberation Sans', font_scale=1,
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = True

		names = supp_figure2b.NAME_LIST
		positions = [1,1.5,2]
		fig = plt.figure(figsize = (15,4))
		# Three grid columns per dataset panel.
		gs = gridspec.GridSpec(6, 3 * len(names))
		for n, name in enumerate(names):
			dataList = bp_data_dict[name]
			ax = plt.subplot(gs[0:, (3*n):(3*n)+3])
			bp = ax.boxplot(dataList, notch=0, sym="", vert=1, patch_artist=True,
					widths=[0.45,0.45,0.45], positions = positions)
			plt.setp(bp['medians'], color="black")
			plt.setp(bp['whiskers'], color="black")
			pval1 = scipy.stats.ranksums(dataList[0], dataList[-1])[1]
			pval2 = scipy.stats.ranksums(dataList[1], dataList[-1])[1]
			print(name, pval1, pval2)

			n_real = len(dataList[2])
			for i, patch in enumerate(bp['boxes']):
				patch.set_facecolor("lightgrey")
				patch.set_edgecolor("black")
				patch.set_alpha(0.6)
				# Overlay a random subset of the group, sized to the number of real
				# values, capped at the group size so random.sample cannot raise.
				k = min(n_real, len(dataList[i]))
				x = np.random.normal(positions[i], 0.04, size=k)
				ax.scatter(x, random.sample(list(dataList[i]), k), color='white',
						   alpha=0.9, edgecolor="black", s=10)
			ax.set_ylim(-0.7,1)
			if n > 0:
				ax.set_yticklabels([])
		plt.savefig(output_folder + 'suppFigure2B_boxplot_random_complex_variability_all.pdf',
					bbox_inches = 'tight', dpi = 400)
		plt.close(fig)


if __name__ == "__main__":
	## EXECUTE SUPPFIGURE2
	# Usage: <script> <folder> <output_folder>
	# output_folder is used as a string prefix, so it should end with a path separator.
	if len(sys.argv) < 3:
		sys.exit('usage: %s <folder> <output_folder>' % sys.argv[0])
	supp_figure2a.execute(folder = sys.argv[1], output_folder = sys.argv[2])
	supp_figure2b.execute(folder = sys.argv[1], output_folder = sys.argv[2])
	


