
class step4_preparation:
	"""Compute per-complex median correlations for real and random protein sets.

	For every dataset returned by ``load_data`` three JSON files are written
	(path = folder prefix + dataset name + suffix):

	* ``_real_complexSet_dict.json`` -- median correlation of each complex,
	* ``_randomized_complexSet_dict.json`` -- medians of size-matched random
	  protein sets drawn from proteins outside the complex,
	* ``_randomized_complexSet_reshuffledData_dict.json`` -- as above, but each
	  random protein's values are additionally permuted across samples.

	Each JSON maps complex name -> value(s) and carries an extra ``'ALL'`` entry.
	"""

	# Dataset keys; each is read from ``complex_filtered_<key>.tsv.gz`` and the
	# key is reused as the prefix of the JSON result files.
	DATASET_NAMES = ('mann', 'gygi2', 'gygi3', 'gygi1',
					 'tcga_colo', 'tcga_breast', 'tcga_ovarian',
					 'battle_rna', 'battle_ribo', 'battle_protein')

	# Number of random protein sets drawn per complex.
	N_RANDOMIZATIONS = 10

	@staticmethod
	def execute(**kwargs):
		"""Load all datasets and write the real/randomized correlation JSONs.

		Keyword Args:
			folder: path prefix of the input ``complex_filtered_*.tsv.gz`` files;
				the JSON results are written under this prefix as well.
			output_folder: accepted for symmetry with ``step4_figure.execute``;
				not used by this step.
		"""
		folder = kwargs.get('folder','PATH')

		print('load_data')
		data_dict = step4_preparation.load_data(folder)
		for name, data in data_dict.items():
			# Results are written to ``folder`` (not ``output_folder``) because
			# step4_figure reads them back from its own ``folder`` argument.
			step4_preparation.get_randomized_complex_set(data, name, folder)
			step4_preparation.get_randomized_complex_set_reshuffled_data(data, name, folder)
			step4_preparation.get_real_complex_set(data, name, folder)

	@staticmethod
	def load_data(folder):
		"""Return a dict mapping each name in DATASET_NAMES to its loaded table."""
		data_dict = dict()
		for name in step4_preparation.DATASET_NAMES:
			data_dict[name] = DataFrameAnalyzer.open_in_chunks(
				folder, 'complex_filtered_' + name + '.tsv.gz')
		return data_dict

	@staticmethod
	def _prepare(data):
		"""Return (list of distinct ComplexName values, quant-column sub-frame)."""
		complexSet = list(set(data['ComplexName']))
		# presumably selects the columns whose name starts with 'quant'
		quant_list = utilsFacade.filtering(data, 'quant', condition = 'startswith')
		return complexSet, data[quant_list]

	@staticmethod
	def _median_correlation(quant_sub):
		"""Median of the row-vs-row correlation values of ``quant_sub``."""
		# .T.corr() correlates the rows (proteins) of quant_sub with each other.
		corr_data = quant_sub.T.corr()
		corr_values = utilsFacade.get_correlation_values(corr_data)
		return np.median(corr_values)

	@staticmethod
	def _shuffle_rows(quant_sub):
		"""Return a new frame in which each row's values are independently permuted."""
		shuffled_rows = list()
		for _, row in quant_sub.iterrows():
			values = list(row)
			random.shuffle(values)
			shuffled_rows.append(values)
		return pd.DataFrame(shuffled_rows)

	@staticmethod
	def _write_json(obj, path):
		"""Serialize ``obj`` to ``path`` and close the file deterministically."""
		# Text mode works for json.dump on both Python 2 and 3.
		with open(path, 'w') as handle:
			json.dump(obj, handle)

	@staticmethod
	def _random_complex_control(data, name, shuffle_rows):
		"""Median correlations of size-matched random protein sets per complex.

		For each complex, N_RANDOMIZATIONS sets of proteins that are not in the
		complex (same size as the complex) are drawn; if ``shuffle_rows`` is
		true each protein's values are permuted across samples before
		correlating.

		Returns:
			(dict complex -> list of medians plus an 'ALL' entry,
			 flattened list of all medians)
		"""
		complexSet, quant_data = step4_preparation._prepare(data)
		all_random_corrs = list()
		all_random_corrDict = dict()
		for count, complexID in enumerate(complexSet):
			print(name, count, complexID)
			complex_proteins = list(set(data[data.ComplexName == complexID].index))
			other_proteins = list(data[~data.index.isin(complex_proteins)].index)
			corr_list = list()
			for _ in range(step4_preparation.N_RANDOMIZATIONS):
				random_proteins = random.sample(other_proteins, len(complex_proteins))
				quant_sub = quant_data.T[random_proteins].T.drop_duplicates()
				if shuffle_rows:
					quant_sub = step4_preparation._shuffle_rows(quant_sub)
				corr_list.append(step4_preparation._median_correlation(quant_sub))
			all_random_corrs.append(corr_list)
			all_random_corrDict[complexID] = corr_list
		all_random_corrs = utilsFacade.flatten(all_random_corrs)
		# presumably keeps only finite values (drops NaN from tiny sets)
		all_random_corrDict['ALL'] = utilsFacade.finite(all_random_corrs)
		return all_random_corrDict, all_random_corrs

	@staticmethod
	def get_randomized_complex_set(data, name, output_folder):
		"""Random-protein control; writes ``<prefix><name>_randomized_complexSet_dict.json``.

		Returns (per-complex dict incl. 'ALL', flattened list of medians).
		"""
		all_random_corrDict, all_random_corrs = step4_preparation._random_complex_control(
			data, name, shuffle_rows = False)
		step4_preparation._write_json(all_random_corrDict,
			output_folder + name + '_randomized_complexSet_dict.json')
		return all_random_corrDict, all_random_corrs

	@staticmethod
	def get_randomized_complex_set_reshuffled_data(data, name, output_folder):
		"""Random proteins with per-row shuffled values; writes
		``<prefix><name>_randomized_complexSet_reshuffledData_dict.json``.

		Returns (per-complex dict incl. 'ALL', flattened list of medians).
		"""
		all_random_corrDict, all_random_corrs = step4_preparation._random_complex_control(
			data, name, shuffle_rows = True)
		step4_preparation._write_json(all_random_corrDict,
			output_folder + name + '_randomized_complexSet_reshuffledData_dict.json')
		return all_random_corrDict, all_random_corrs

	@staticmethod
	def get_real_complex_set(data, name, output_folder):
		"""Median correlation of each real complex; writes
		``<prefix><name>_real_complexSet_dict.json``.

		Returns (dict complex -> median plus an 'ALL' entry, list of medians).
		"""
		complexSet, quant_data = step4_preparation._prepare(data)
		all_real_corrs = list()
		all_real_corrDict = dict()
		for count, complexID in enumerate(complexSet):
			print(name, count, complexID)
			complex_proteins = list(set(data[data.ComplexName == complexID].index))
			quant_sub = quant_data.T[complex_proteins].T.drop_duplicates()
			median_corr = step4_preparation._median_correlation(quant_sub)
			all_real_corrs.append(median_corr)
			all_real_corrDict[complexID] = median_corr

		all_real_corrDict['ALL'] = utilsFacade.finite(all_real_corrs)
		step4_preparation._write_json(all_real_corrDict,
			output_folder + name + '_real_complexSet_dict.json')
		return all_real_corrDict, all_real_corrs


class step4_figure:
	"""Supplementary Figure 2B: real vs. randomized complex co-variation boxplots."""

	# Dataset names in plotting order. They must equal the JSON file prefixes
	# written by step4_preparation, whose dataset keys use snake_case
	# ('tcga_colo', not 'tcgaColo').
	NAME_LIST = ['gygi1','gygi2','gygi3','battle_rna','battle_ribo','battle_protein',
				 'mann','tcga_colo','tcga_breast','tcga_ovarian']

	@staticmethod
	def execute(**kwargs):
		"""Entry point: build Supplementary Figure 2B.

		Keyword Args:
			folder: path prefix holding the ``*_complexSet*_dict.json`` inputs.
			num_subunits: forwarded to the data preparation (currently unused).
			output_folder: path prefix for the output PDF.
		"""
		folder = kwargs.get('folder','PATH')
		num_subunits = kwargs.get('num_subunits',0)
		output_folder = kwargs.get('output_folder','PATH')

		print('SUPP-FIGURE2B: main_supp_figure2b_boxplot_randomControl')
		step4_figure.main_supp_figure2b_boxplot_randomControl(output_folder = output_folder,
															  num_subunits = num_subunits,
															  folder = folder)

	@staticmethod
	def main_supp_figure2b_boxplot_randomControl(**kwargs):
		"""Load the JSON results and draw the boxplot figure (same kwargs as execute)."""
		folder = kwargs.get('folder','PATH')
		num_subunits = kwargs.get('num_subunits',0)
		output_folder = kwargs.get('output_folder','PATH')

		print('prepare_data_for_boxplot')
		bp_data_dict = step4_figure.prepare_data_for_boxplot(num_subunits = num_subunits,
					   folder = folder, output_folder = output_folder)

		print('plot_boxplot')
		step4_figure.plot_boxplot(bp_data_dict, output_folder)

	@staticmethod
	def _load_json(path):
		"""Read and return the JSON document at ``path``, closing the file."""
		with open(path, 'r') as handle:
			return json.load(handle)

	@staticmethod
	def prepare_data_for_boxplot(**kwargs):
		"""Collect the three value lists to plot for every dataset.

		Keyword Args:
			folder: path prefix of the JSON files written by step4_preparation.
			num_subunits, output_folder: accepted for interface compatibility;
				not used here.

		Returns:
			dict name -> [shuffled-random values, random values, real values],
			restricted to complexes present in all three JSON files.
		"""
		folder = kwargs.get('folder','PATH')

		bp_data_dict = dict()
		for n in step4_figure.NAME_LIST:
			print(n)
			real_complexDict = step4_figure._load_json(
				folder + n + '_real_complexSet_dict.json')
			randomized_complexDict = step4_figure._load_json(
				folder + n + '_randomized_complexSet_dict.json')
			randomized_shuffled_complexDict = step4_figure._load_json(
				folder + n + '_randomized_complexSet_reshuffledData_dict.json')

			complexList = (set(real_complexDict)
						   & set(randomized_complexDict)
						   & set(randomized_shuffled_complexDict))
			# 'ALL' is the pooled summary entry, not a complex.
			complexList.discard('ALL')

			randomized_shuffled_list = [randomized_shuffled_complexDict[c] for c in complexList]
			randomized_list = [randomized_complexDict[c] for c in complexList]
			real_list = [real_complexDict[c] for c in complexList]

			randomized_shuffled_list = utilsFacade.finite(utilsFacade.flatten(randomized_shuffled_list))
			randomized_list = utilsFacade.finite(utilsFacade.flatten(randomized_list))
			# Drop NaN medians; a real list (not a lazy filter object) is needed
			# later for len() and random.sample on Python 3.
			real_list = [a for a in real_list if str(a) != 'nan']

			bp_data_dict[n] = [randomized_shuffled_list, randomized_list, real_list]
		return bp_data_dict

	@staticmethod
	def plot_boxplot(bp_data_dict, output_folder):
		"""Draw one three-box panel per dataset and save the figure as PDF.

		Boxes per panel: shuffled-random, random, real. Wilcoxon rank-sum
		p-values of each control against the real values are printed.
		"""
		sns.set(context='notebook', style='white',
			palette='deep', font='Liberation Sans', font_scale=1,
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = True

		name_list = step4_figure.NAME_LIST
		panel_width = 3
		positions = [1,1.5,2]
		fig = plt.figure(figsize = (15,4))
		gs = gridspec.GridSpec(6, panel_width * len(name_list))
		for n, name in enumerate(name_list):
			dataList = bp_data_dict[name]
			ax = plt.subplot(gs[0:, (panel_width*n):(panel_width*n) + panel_width])
			bp = ax.boxplot(dataList,notch=0,sym="",vert=1,patch_artist=True,
					widths=[0.45,0.45,0.45], positions = positions)
			plt.setp(bp['medians'], color="black")
			plt.setp(bp['whiskers'], color="black")
			pval1 = scipy.stats.ranksums(dataList[0], dataList[-1])[1]
			pval2 = scipy.stats.ranksums(dataList[1], dataList[-1])[1]
			print(name, pval1, pval2)

			n_points = len(dataList[2])
			for i, patch in enumerate(bp['boxes']):
				patch.set_facecolor("lightgrey")
				patch.set_edgecolor("black")
				patch.set_alpha(0.6)
				# Overlay as many jittered points as there are real complexes so
				# the three scatters are comparable; never request more points
				# than a box holds (random.sample would raise ValueError).
				values = list(dataList[i])
				k = min(n_points, len(values))
				x = np.random.normal(positions[i], 0.04, size=k)
				ax.scatter(x, random.sample(values, k), color='white',
						   alpha=0.9, edgecolor="black", s=10)
			ax.set_ylim(-0.7,1)
			if n > 0:
				ax.set_yticklabels([])
		plt.savefig(output_folder + 'suppFigure2B_boxplot_random_complex_variability_all.pdf',
					bbox_inches = 'tight', dpi = 400)
		# Release the figure so repeated calls do not accumulate open figures.
		plt.close(fig)

if __name__ == "__main__":
	# usage: <script> FOLDER OUTPUT_FOLDER
	# FOLDER holds the inputs and receives the intermediate JSON results;
	# OUTPUT_FOLDER receives the figure. Result paths are built by plain
	# string concatenation, so both should normally end with a path separator.
	if len(sys.argv) < 3:
		sys.exit('usage: %s FOLDER OUTPUT_FOLDER' % sys.argv[0])
	step4_preparation.execute(folder = sys.argv[1], output_folder = sys.argv[2])
	step4_figure.execute(folder = sys.argv[1], output_folder = sys.argv[2])
