# ------------------------------------------------------------
# Author: Natalie Romanov
# ------------------------------------------------------------

class supp_figure1:
	"""Build Supplementary Figure 1: per-dataset correlation-recovery plots.

	For every dataset name, three precomputed correlation lists are loaded
	from JSON (groups '700', 'all', 'other'). They are drawn as overlaid
	density curves plus a horizontal boxplot and saved as one PDF per dataset.
	An optional resampling step (get_significancies) compares the 'all' and
	'700' groups against 'other'.
	"""

	@staticmethod
	def execute(**kwargs):
		"""Run the load + plot pipeline.

		Keyword args:
			folder: accepted for CLI compatibility; no step below reads it.
			output_folder: path prefix, joined by plain string concatenation
				(so it should end with a separator). It holds the input JSON
				files and receives the PDFs. Defaults to the placeholder "PATH".
		"""
		output_folder = kwargs.get('output_folder',"PATH")

		nameList = ["gygi1","gygi3","battle_protein",
					"wu","mann_all_log2","tiannan",
					"primateRNA","primatePRO","gygi2",
					"battle_rna","battle_ribo",
					'coloCa','tcga_breast','tcga_ovarian',
					'bxd_protein']

		print("load_data")
		corr_dict = supp_figure1.load_data(nameList, output_folder)

		print('make_supplementary_plot')
		supp_figure1.make_supplementary_plot(nameList, corr_dict, output_folder)

		# Significance step, disabled in this pipeline. It runs
		# 2 groups x 2 tests x n_iterations resampled tests per dataset.
		# To enable:
		# print("get_significancies")
		# pval_dict, new_pval_dict = supp_figure1.get_significancies(
		# 	nameList, output_folder, corr_dict=corr_dict)

	@staticmethod
	def load_json_file(folder, file_name):
		"""Return the parsed JSON content of the file ``folder + file_name``."""
		with open(folder + file_name) as data_file:
			correlation_values = json.load(data_file)
		return correlation_values

	@staticmethod
	def load_data(nameList, output_folder):
		"""Load the three correlation lists for every dataset.

		Args:
			nameList: dataset names.
			output_folder: path prefix the JSON files are read from.

		Returns:
			dict keyed "<name>:<ty>" for ty in 'all', '700', 'other'.
			'all' and '700' come from ``string_correlations_<ty>_<name>.json``.
			'other' comes from ``other_string_correlations_other_<name>.json``.
		"""
		file_prefixes = (("all", "string_correlations"),
						("700", "string_correlations"),
						("other", "other_string_correlations"))
		corr_dict = dict()
		for name in nameList:
			for ty, prefix in file_prefixes:
				file_name = '_'.join([prefix, ty, name]) + '.json'
				corr_dict[name + ":" + ty] = supp_figure1.load_json_file(output_folder, file_name)
		return corr_dict

	@staticmethod
	def make_supplementary_plot(nameList, corr_dict, output_folder):
		"""Save one density + boxplot PDF per dataset.

		Output path: output_folder + "suppFig1a_string_correlation_recovery_<name>.pdf".
		Top panel: density curves for 'other' (grey), 'all' (orange) and
		'700' (#EE7600), drawn in that order so '700' ends up on top.
		Bottom panel: horizontal boxplots of '700', 'all', 'other', listed
		from the first box to the last.
		"""
		for name in nameList:
			print(name)
			best_string_correlation_values = list(corr_dict[name + ":700"])
			string_correlation_values = list(corr_dict[name + ":all"])
			other_correlation_values = list(corr_dict[name + ":other"])
			dataList = [best_string_correlation_values,
						string_correlation_values,
						other_correlation_values]

			sns.set_style("white")
			plt.rcParams["axes.grid"] = True

			fig = plt.figure(figsize=(5,5))
			gs = gridspec.GridSpec(10,10)

			ax = plt.subplot(gs[0:7,0:])
			plottingFacade.func_plotDensities_border(ax, other_correlation_values, facecolor="grey")
			plottingFacade.func_plotDensities_border(ax, string_correlation_values, facecolor="orange")
			plottingFacade.func_plotDensities_border(ax, best_string_correlation_values, facecolor="#EE7600")
			ax.set_xlim(-1,1)
			ax.set_xticklabels([])
			# real booleans: modern matplotlib does not interpret "off" as False
			ax.tick_params(axis="y",which="both",left=False,right=False,labelsize=10)

			ax = plt.subplot(gs[7:,0:])
			bp = ax.boxplot(dataList,notch=0,sym="",vert=0,
							patch_artist=True,widths=(0.5,0.5,0.5))
			plt.setp(bp['medians'], color="black")
			plt.setp(bp['whiskers'], color="black",linestyle="-")
			box_colors = ("#EE7600", "orange", "#D8D8D8")
			for patch, color in zip(bp['boxes'], box_colors):
				patch.set_facecolor(color)
				patch.set_edgecolor("black")
				patch.set_alpha(1)
			ax.set_xlim(-1,1)
			ax.set_yticklabels([])
			ax.tick_params(axis="y",which="both",left=False,
							right=False,labelsize=15)
			plt.savefig(output_folder + "suppFig1a_string_correlation_recovery_" + name + ".pdf",
						bbox_inches="tight", dpi = 400)
			# release the figure; otherwise every dataset's figure stays open in pyplot
			plt.close(fig)

	@staticmethod
	def _resample_pvalues(test_func, background, foreground, n_iterations, sample_size):
		"""Return n_iterations p-values of test_func on random subsample pairs.

		Each round draws min(sample_size, len(background), len(foreground))
		values without replacement from each list. It calls
		test_func(background_sample, foreground_sample) and keeps element [1]
		of the result (the p-value).

		Raises:
			ValueError: if either list has fewer than two values.
		"""
		k = min(sample_size, len(background), len(foreground))
		if k < 2:
			raise ValueError("need at least 2 values per group to resample, got %d" % k)
		pvals = list()
		for _ in range(n_iterations):
			pvals.append(test_func(random.sample(background, k),
								random.sample(foreground, k))[1])
		return pvals

	@staticmethod
	def get_significancies(nameList, output_folder, corr_dict=None,
						n_iterations=10000, sample_size=1000):
		"""Compare the 'all' and '700' correlations against 'other' by resampling.

		For each dataset and each of 'all' and '700', the method runs
		n_iterations resampled Wilcoxon rank-sum tests ('wc') and independent
		t-tests ('ttest') against 'other'. The raw p-values are then passed
		through utilsFacade.correct_pvalues.

		Args:
			nameList: dataset names.
			output_folder: path prefix for the input JSON (used when corr_dict
				is None) and for the output JSON.
			corr_dict: result of load_data; loaded from output_folder if None.
			n_iterations: resampling rounds per test.
			sample_size: subsample size per draw, capped at the group size.

		Returns:
			(pval_dict, new_pval_dict). pval_dict maps "<name>:<ty>" to a
			one-element list holding {"wc": ..., "ttest": ...}. new_pval_dict
			holds the same values converted with list(); it is also written to
			output_folder + 'suppFig1A_stringRecovery_correlations_pvalues.json'.

		Raises:
			ValueError: if a group has fewer than two values.
		"""
		if corr_dict is None:
			corr_dict = supp_figure1.load_data(nameList, output_folder)

		tests = (("wc", scipy.stats.ranksums), ("ttest", scipy.stats.ttest_ind))
		pval_dict = dict()
		for name in nameList:
			print(name)
			other_correlation_values = list(corr_dict[name + ":other"])
			for ty in ("all", "700"):
				string_values = list(corr_dict[name + ":" + ty])
				entry = dict()
				for test_name, test_func in tests:
					raw_pvals = supp_figure1._resample_pvalues(
						test_func, other_correlation_values, string_values,
						n_iterations, sample_size)
					entry[test_name] = utilsFacade.correct_pvalues(raw_pvals)
				pval_dict.setdefault(name + ":" + ty, []).append(entry)

		new_pval_dict = dict()
		for key, entries in pval_dict.items():
			first = entries[0]
			new_pval_dict[key] = [{"ttest": list(first["ttest"]),
								"wc": list(first["wc"])}]
		with open(output_folder + 'suppFig1A_stringRecovery_correlations_pvalues.json', 'w') as outfile:
			json.dump(new_pval_dict, outfile)
		return pval_dict,new_pval_dict

	@staticmethod
	def read_significancies(output_folder):
		"""Return the p-value dict stored in
		output_folder + 'suppFig1A_stringRecovery_correlations_pvalues.json'.
		"""
		with open(output_folder + 'suppFig1A_stringRecovery_correlations_pvalues.json', 'r') as infile:
			pval_dict = json.load(infile)
		return pval_dict

if __name__ == "__main__":
	## EXECUTE SUPPFIGURE1
	# usage: <script> <folder> <output_folder>; run the pipeline once
	import sys
	if len(sys.argv) < 3:
		sys.exit("usage: %s <folder> <output_folder>" % sys.argv[0])
	supp_figure1.execute(folder = sys.argv[1], output_folder = sys.argv[2])



