# ------------------------------------------------------------
# Author: Natalie Romanov
# ------------------------------------------------------------

class file_Loader:
	"""Static loaders for the step-6 input tables.

	Every loader receives the data folder as a path prefix: file names are
	either appended to it by string concatenation or passed alongside it to
	DataFrameAnalyzer.
	"""

	# (key in the dict returned by load_data, file name inside ``folder``).
	# The key 'mann_all_log2' is the dataset name that the AUC aggregation
	# (make_auc_dict) and the Figure 2A plot look for.
	_DATASET_FILES = (
		('gygi1', "dataset_gygi1_remapped.tsv.gz"),
		('gygi2', "dataset_gygi2_remapped.tsv.gz"),
		('gygi3', "dataset_gygi3_remapped.tsv.gz"),
		('battle_protein', "dataset_battle_protein_remapped.tsv.gz"),
		('battle_ribo', "dataset_battle_ribo_remapped.tsv.gz"),
		('battle_rna', "dataset_battle_rna_remapped.tsv.gz"),
		('wu', "dataset_wu_remapped.tsv.gz"),
		('tiannan', "dataset_tiannan_remapped.tsv.gz"),
		('colo_cancer', "dataset_coloCa_remapped.tsv.gz"),
		('tcga_breast', "dataset_tcga_breast_remapped.tsv.gz"),
		('tcga_ovarian', "dataset_tcga_ovarian_remapped.tsv.gz"),
		('bxd_protein', "dataset_bxdMouse_remapped.tsv.gz"),
		('primateRNA', "dataset_primateRNA_remapped.tsv.gz"),
		('primatePRO', "dataset_primatePRO_remapped.tsv.gz"),
		('mann_all_log2', "dataset_mann_all_log2_remapped.tsv.gz"),
	)

	@staticmethod
	def load_data(**kwargs):
		"""Load every remapped dataset table.

		Keyword ``folder`` (default ``'PATH'``) is the input folder.
		Returns a dict mapping dataset name to the table returned by
		DataFrameAnalyzer.open_in_chunks.
		"""
		folder = kwargs.get('folder', 'PATH')
		return dict((key, DataFrameAnalyzer.open_in_chunks(folder, file_name))
					for key, file_name in file_Loader._DATASET_FILES)

	@staticmethod
	def load_housekeeping_data(folder):
		"""Return the index (gene names) of ``housekeeping_genes.txt``."""
		housekeeping_data = DataFrameAnalyzer.getFile(folder, "housekeeping_genes.txt")
		return housekeeping_data.index

	@staticmethod
	def load_essential_genes(folder):
		"""Return, per cell line, the genes with adjusted p-value < 0.05.

		Returns a 4-tuple of index objects in the order KBM7, K562, Jiyoye, Raji.
		"""
		essentiality_data = DataFrameAnalyzer.getFile(folder, "essentiality_genes.txt")
		return tuple(
			essentiality_data[essentiality_data[cell_line + " adjusted p-value"] < 0.05].index
			for cell_line in ("KBM7", "K562", "Jiyoye", "Raji"))

	@staticmethod
	def load_string_data(folder, species):
		"""Load the three STRING interaction dictionaries for ``species``.

		Reads ``<folder><species>_STRING[_only500|_only700]_geneName_per_protein_allInteractingProteins_dict.json``
		and returns the parsed JSON objects as ``(all, only500, only700)``.
		"""
		import json

		dictionaries = []
		for score_tag in ("", "_only500", "_only700"):
			file_name = (species + "_STRING" + score_tag +
						 "_geneName_per_protein_allInteractingProteins_dict.json")
			# Read relative to the ``folder`` argument (the old code used an
			# undefined ``string_folder`` name).
			with open(folder + file_name) as json_data:
				dictionaries.append(json.load(json_data))
		return tuple(dictionaries)

	@staticmethod
	def get_complex_dictionary(folder = None, **kwargs):
		"""Read the pickled complex dictionary ``<folder>complex_dictionary.pkl``.

		``folder`` may be passed positionally or by keyword (default ``'PATH'``).
		"""
		if folder is None:
			folder = kwargs.get('folder', 'PATH')
		return DataFrameAnalyzer.read_pickle(folder + 'complex_dictionary.pkl')

class step6_preparation(object):
	"""ROC/AUC analysis of protein co-expression for a single dataset.

	Constructing an instance runs the whole pipeline:
	1. Build a protein x protein correlation matrix from the ``quant_`` columns.
	2. Collect correlations of protein pairs that share a feature: complex,
	   Reactome pathway, compartment, chromosome, STRING interaction, or
	   housekeeping/essential status.
	3. Contrast them with correlations among proteins without a Reactome
	   annotation (the "other" background).
	4. Export one ROC table per feature plus an AUC summary.

	Args:
		data: dataset table whose index holds gene names, with ``quant_*``
			columns and ``reactome``, ``location_humanProteinAtlas`` and
			``GENOMIC_POS`` annotation columns.
		name: dataset name, used as the output file prefix.
		species: prefix of the gene-name key in the complex dictionary
			(``species + "GeneNames"``) and of the STRING file names.

	Keyword args:
		folder: input path prefix (default ``'PATH'``).
		output_folder: export path prefix; defaults to ``folder``.
		summarize: if true (default), also combine the AUC tables of all
			datasets in ``_DATASET_NAMES`` (read from ``output_folder``) into
			the Figure 2A dictionary and matrix.
	"""

	# (row name in the AUC summary, attribute holding the ROC table, export file tag)
	_AUC_TABLES = (
		("complex", "complex_df", "COMPLEX"),
		("pathway", "pathway_df", "PATHWAY"),
		("compartment", "compartment_df", "COMPARTMENT"),
		("string_all", "string_all_df", "STRING_ALL"),
		("string_500", "string_500_df", "STRING_500"),
		("string_700", "string_700_df", "STRING_700"),
		("chromosome", "chromosome_df", "CHROMOSOME"),
		("housekeeping", "housekeeping_df", "HOUSEKEEPING"),
		("kbm7_essential", "kbm7_essential_df", "KBM7_ESSENTIALITY"),
		("k562_essential", "k562_essential_df", "K562_ESSENTIALITY"),
		("raji_essential", "raji_essential_df", "RAJI_ESSENTIALITY"),
		("jiyoye_essential", "jiyoye_essential_df", "JIYOYE_ESSENTIALITY"),
	)

	# Datasets whose AUC summaries make_auc_dict combines.
	_DATASET_NAMES = ("battle_protein", "tiannan", "battle_ribo", "battle_rna",
					  "primateRNA", "primatePRO", "wu", "mann_all_log2",
					  "gygi1", "gygi2", "gygi3", "tcga_ovarian", 'tcga_breast',
					  'bxd_protein', 'colo_cancer')

	def __init__(self, data, name, species, **kwargs):
		self.data = data
		self.name = name
		self.species = species

		folder = kwargs.get('folder', 'PATH')
		# Previously the output prefix was read from the 'folder' key; keep that
		# as the fallback so callers passing only ``folder`` behave as before.
		output_folder = kwargs.get('output_folder', folder)
		summarize = kwargs.get('summarize', True)

		# Pass ``folder`` by keyword so it reaches the loader's folder lookup.
		self.complexDict = file_Loader.get_complex_dictionary(folder = folder)

		print('load_modules')
		self.load_modules(folder, data, species)

		print('load_correlations')
		self.load_correlations()

		print('load_features')
		self.load_features()

		print("get_summary_auc")
		self.get_summary_auc()

		print("export_dataframes")
		self.export_dataframes(output_folder)

		if not summarize:
			return

		print("make_auc_dict")
		self.auc_dict = self.make_auc_dict(output_folder)

		print('prepare_auc_matrix')
		self.auc_matrix = self.prepare_auc_matrix(folder)

		print('summarize_essentiality')
		self.summarize_essentiality(folder)

	@staticmethod
	def get_corr_data(data):
		"""Return the protein x protein Pearson correlation of the ``quant_`` columns.

		Rows of ``data`` are proteins and the quant columns are samples.
		``DataFrame.corr`` correlates columns, so the table is transposed first.
		The pathway, compartment and chromosome code below uses the same
		orientation.
		"""
		quant_cols = utilsFacade.filtering(list(data.columns), 'quant_')
		return data[quant_cols].T.corr()

	def load_modules(self, folder, data, species):
		"""Load gene annotations and precompute the correlation matrix.

		Sets housekeeping/essential gene lists, ``corrData``, the
		Reactome-annotated subset (``pathway_data``/``pathway_proteins``),
		the remaining ``other_genes`` and the STRING dictionaries.
		"""
		self.housekeeping_genes = file_Loader.load_housekeeping_data(folder)
		(self.kbm7_essential_genes, self.k562_essential_genes,
		 self.jiyoye_essential_genes, self.raji_essential_genes) = file_Loader.load_essential_genes(folder)

		self.corrData = step6_preparation.get_corr_data(data)

		# Rows with a non-missing, non-empty Reactome annotation.
		reactome = data['reactome']
		self.pathway_data = data[reactome.notnull() & (reactome != "")]
		self.pathway_proteins = list(self.pathway_data.index)

		other_data = data[~data.index.isin(self.pathway_proteins)]
		self.other_genes = [gene for gene in set(other_data.index) if str(gene) != "nan"]

		self.string_dict_all, self.string_dict_500, self.string_dict_700 = file_Loader.load_string_data(folder, species)

	def load_correlations(self):
		"""Compute the correlation values of every feature; complex runs first because later steps use ``all_genes``."""
		print("get_complex_correlation_values")
		self.get_complex_correlation_values()
		print("get_pathway_correlation_values")
		self.get_pathway_correlation_values()
		print("get_housekeeping_correlation_values")
		self.get_housekeeping_correlation_values()
		print("get_compartment_correlation_values")
		self.get_compartment_correlation_values()
		print("get_essentiality_correlation_values")
		self.get_essentiality_correlation_values()
		print("get_chromosome_correlation_values")
		self.get_chromosome_correlation_values()
		print("get_other_correlation_values")
		self.get_other_correlation_values()
		print("get_STRING_correlation_values")
		self.get_STRING_correlation_values()

	def load_features(self):
		"""Build the ROC table of every feature against the background correlations."""
		print("prepare_complex_df")
		self.prepare_complex_df()
		print("prepare_pathway_df")
		self.prepare_pathway_df()

		background = self.pathway_other_correlations

		print("prepare_feature_df: COMPARTMENT")
		self.compartment_df = self.prepare_feature_df(self.compartment_correlation_values, background)

		print("prepare_feature_df: ESSENTIALITY")
		self.kbm7_essential_df = self.prepare_feature_df(self.kbm7_essential_correlation_values, background)
		self.k562_essential_df = self.prepare_feature_df(self.k562_essential_correlation_values, background)
		self.jiyoye_essential_df = self.prepare_feature_df(self.jiyoye_essential_correlation_values, background)
		self.raji_essential_df = self.prepare_feature_df(self.raji_essential_correlation_values, background)

		print("prepare_feature_df: CHROMOSOME")
		self.chromosome_df = self.prepare_feature_df(self.chromosome_correlation_values, background)

		print("prepare_feature_df: HOUSEKEEPING")
		self.housekeeping_df = self.prepare_feature_df(self.housekeeping_correlation_values, background)

		print("prepare_feature_df: STRING")
		self.string_all_df = self.prepare_feature_df(self.string_correlation_values_all, background)
		self.string_500_df = self.prepare_feature_df(self.string_correlation_values_500, background)
		self.string_700_df = self.prepare_feature_df(self.string_correlation_values_700, background)

	def get_housekeeping_correlation_values(self):
		"""Correlations among housekeeping genes that are not complex members."""
		self.housekeeping_correlation_values = self._gene_set_correlations(self.housekeeping_genes)

	def get_complex_correlation_values(self, complexDict = None):
		"""Correlations among members of each complex.

		``complexDict`` defaults to the dictionary loaded in ``__init__``.
		Sets ``complex_correlation_values`` and ``all_genes``. ``all_genes``
		lists every complex member found in the correlation matrix and is used
		to exclude complex members from the other gene sets.
		"""
		if complexDict is None:
			complexDict = self.complexDict
		corrData = self.corrData
		gene_key = self.species + "GeneNames"
		index_genes = set(corrData.index)

		correlation_values = []
		member_genes = []
		for complexID in complexDict:
			overlapping = list(set(complexDict[complexID][gene_key]).intersection(index_genes))
			correlation_values.append(utilsFacade.get_correlation_values(self._square(corrData, overlapping)))
			member_genes.append(overlapping)
		self.complex_correlation_values = utilsFacade.flatten(correlation_values)
		self.all_genes = utilsFacade.flatten(member_genes)

	def get_pathway_correlation_values(self):
		"""Correlations among non-complex proteins sharing a Reactome pathway ID.

		Pathway membership is an exact ID match. The earlier substring test
		also matched e.g. ``R-1`` inside ``R-12``. The result is stored in
		``pathway_correlation_values`` (later steps read it) and returned.
		"""
		pathway_data = self.pathway_data[~self.pathway_data.index.isin(self.all_genes)]

		# pathway ID -> row positions of its members
		members = {}
		for position, annotation in enumerate(pathway_data['reactome']):
			for pathway_id in set(self._parse_pathway_ids(annotation)):
				members.setdefault(pathway_id, []).append(position)

		quant_cols = utilsFacade.get_quantCols(self.data)
		correlation_values = []
		for positions in members.values():
			quant_sub = pathway_data.iloc[positions][quant_cols]
			correlation_values.append(utilsFacade.get_correlation_values(quant_sub.T.corr()))
		self.pathway_correlation_values = utilsFacade.flatten(correlation_values)
		return self.pathway_correlation_values

	def get_other_correlation_values(self):
		"""Background correlations among genes without a Reactome annotation.

		Uses the square gene x gene block, as every other feature does. The
		values are subsampled to the size of the complex and pathway sets
		whenever more are available.
		"""
		overlapping = list(set(self.other_genes).intersection(set(self.corrData.index)))
		other_correlations = list(utilsFacade.get_correlation_values(self._square(self.corrData, overlapping)))
		self.complex_other_correlations = self._subsample(other_correlations, len(self.complex_correlation_values))
		self.pathway_other_correlations = self._subsample(other_correlations, len(self.pathway_correlation_values))

	def get_essentiality_correlation_values(self):
		"""Correlations among essential genes of each cell line (complex members excluded)."""
		self.kbm7_essential_correlation_values = self._gene_set_correlations(self.kbm7_essential_genes)
		self.k562_essential_correlation_values = self._gene_set_correlations(self.k562_essential_genes)
		self.jiyoye_essential_correlation_values = self._gene_set_correlations(self.jiyoye_essential_genes)
		self.raji_essential_correlation_values = self._gene_set_correlations(self.raji_essential_genes)

	def get_compartment_correlation_values(self):
		"""Correlations among proteins sharing a Human Protein Atlas location.

		Annotations are ``;``-separated. A protein joins the group of each
		location it lists; the earlier whole-string equality dropped every
		multi-location protein.
		"""
		locations = self.data["location_humanProteinAtlas"]
		annotated = self.data[locations.notnull() & (locations != "")]

		members = {}
		for position, annotation in enumerate(annotated["location_humanProteinAtlas"]):
			for compartment in set(str(annotation).split(";")):
				if compartment:
					members.setdefault(compartment, []).append(position)

		quant_cols = utilsFacade.get_quantCols(annotated)
		correlation_values = []
		for positions in members.values():
			sub = annotated.iloc[positions][quant_cols]
			correlation_values.append(utilsFacade.get_correlation_values(sub.T.corr()))
		self.compartment_correlation_values = utilsFacade.flatten(correlation_values)

	def get_chromosome_correlation_values(self):
		"""Correlations among proteins on the same chromosome (1-22, X).

		The chromosome is the text between ``chr:`` and ``,start`` in
		``GENOMIC_POS``; unparsable entries get NaN and join no group.
		"""
		def chromosome_of(genomic_pos):
			try:
				return genomic_pos.split("chr:")[1].split(",start")[0]
			except (AttributeError, IndexError):
				return np.nan

		data_chrom = self.data.copy()
		data_chrom["chrom"] = [chromosome_of(pos) for pos in data_chrom["GENOMIC_POS"]]
		quant_cols = utilsFacade.get_quantCols(data_chrom)

		correlation_values = []
		for chrom in [str(number) for number in range(1, 23)] + ["X"]:
			sub = data_chrom[data_chrom["chrom"] == chrom][quant_cols]
			correlation_values.append(utilsFacade.get_correlation_values(sub.T.corr()))
		self.chromosome_correlation_values = utilsFacade.flatten(correlation_values)

	def get_STRING_correlation_values(self):
		"""Correlations between each protein and its partners in the three STRING dictionaries (all, only500, only700)."""
		corrData = self.corrData
		self.string_correlation_values_700 = self._string_correlations(corrData, self.string_dict_700)
		self.string_correlation_values_500 = self._string_correlations(corrData, self.string_dict_500)
		self.string_correlation_values_all = self._string_correlations(corrData, self.string_dict_all)

	def prepare_pathway_df(self):
		"""ROC table of pathway correlations against the (unsubsampled) background."""
		self.pathway_df = self._threshold_roc_df(self.pathway_correlation_values,
												 self.pathway_other_correlations)

	def prepare_complex_df(self):
		"""ROC table of complex correlations, evaluated at every observed value.

		Rows are sorted by correlation, ascending. At row i, rows i.. are
		predicted positive and rows ..i-1 predicted negative. Counts come from
		cumulative sums. These equal calling calculate_tp_fp per split with the
		dataset totals, without its quadratic cost.
		"""
		df = self._labelled_df(self.complex_correlation_values, self.complex_other_correlations)
		labels = df["label"].values.astype(bool)
		n_rows = len(labels)
		all_true_positives = float(labels.sum())
		all_false_positives = float(n_rows) - all_true_positives

		# Positives/negatives strictly before each row are the rejected ones.
		false_negatives = (np.cumsum(labels) - labels).astype(float)
		true_negatives = np.arange(n_rows, dtype = float) - false_negatives
		true_positives = all_true_positives - false_negatives
		false_positives = all_false_positives - true_negatives

		with np.errstate(divide = "ignore", invalid = "ignore"):
			tpr = true_positives / all_true_positives
			fpr = false_positives / all_false_positives
			precision = true_positives / (true_positives + false_positives)
			recall = true_positives / (true_positives + false_negatives)
			accuracy = (true_positives + true_negatives) / float(n_rows)

		df["tpr"] = tpr
		df["fpr"] = fpr
		df["auc"] = [self._auc(fpr, tpr)] * n_rows
		df["recall"] = recall
		df["accuracy"] = accuracy
		df["precision"] = precision
		self.complex_df = df

	def prepare_feature_df(self, correlation_values, other_correlations):
		"""ROC table of ``correlation_values`` against background ``other_correlations``.

		The background is subsampled to the size of the feature set when it
		is larger.
		"""
		background = self._subsample(other_correlations, len(correlation_values))
		return self._threshold_roc_df(correlation_values, background)

	def calculate_tp_fp(self, sub, rejected, all_true_positives = None, all_false_positives = None):
		"""Confusion counts and rates for one split of the labelled correlations.

		Args:
			sub: rows predicted positive.
			rejected: rows predicted negative.
			all_true_positives: total number of positive rows (TPR denominator);
				defaults to the positives found in ``sub`` plus ``rejected``.
			all_false_positives: total number of negative rows (FPR denominator);
				defaults to the negatives found in ``sub`` plus ``rejected``.

		Returns:
			(tp, fp, fn, tn, tpr, fpr, precision, recall, accuracy) as floats;
			rates with a zero denominator are NaN.
		"""
		true_positives = float(sub["label"].eq(True).sum())
		false_positives = float(sub["label"].eq(False).sum())
		# Count the matching rows; len() of the boolean mask counts every row.
		false_negatives = float(rejected["label"].eq(True).sum())
		true_negatives = float(rejected["label"].eq(False).sum())
		if all_true_positives is None:
			all_true_positives = true_positives + false_negatives
		if all_false_positives is None:
			all_false_positives = false_positives + true_negatives

		ratio = step6_preparation._ratio
		tpr = ratio(true_positives, all_true_positives)
		fpr = ratio(false_positives, all_false_positives)
		precision = ratio(true_positives, true_positives + false_positives)
		recall = ratio(true_positives, true_positives + false_negatives)
		accuracy = ratio(true_positives + true_negatives,
						 true_positives + true_negatives + false_positives + false_negatives)

		return (true_positives, false_positives, false_negatives, true_negatives,
				tpr, fpr, precision, recall, accuracy)

	def get_summary_auc(self):
		"""Collect the AUC of every feature's ROC table into ``auc_df`` (columns name, auc)."""
		names = [name for name, _, _ in self._AUC_TABLES]
		aucs = [self._first_auc(getattr(self, attr)) for _, attr, _ in self._AUC_TABLES]
		self.auc_df = pd.DataFrame({"name": names, "auc": aucs})

	def export_dataframes(self, output_folder):
		"""Write the AUC summary and every ROC table as gzipped TSV under ``output_folder``."""
		prefix = output_folder + self.name + "_figure1_data_"
		#creating roc_temporary files
		self.auc_df.to_csv(prefix + "AUC.tsv.gz", sep = "\t", compression = "gzip")
		for _, attr, tag in self._AUC_TABLES:
			getattr(self, attr).to_csv(prefix + tag + ".tsv.gz", sep = "\t", compression = "gzip")

	def make_auc_dict(self, folder):
		"""Combine the exported AUC summaries of all datasets into a nested dict.

		Returns ``{dataset: {feature: auc}}`` without the string_all and
		string_500 rows. It is also pickled to
		``<folder>fig2a_auc_dictionary.pkl``.
		"""
		auc_dict = {}
		for name in self._DATASET_NAMES:
			data = DataFrameAnalyzer.open_in_chunks(folder, name + "_figure1_data_AUC.tsv.gz")
			data.index = data["name"]
			data = data.drop(["string_all", "string_500"], axis = 0)
			auc_dict[name] = data["auc"].to_dict()

		DataFrameAnalyzer.to_pickle(auc_dict, folder + 'fig2a_auc_dictionary.pkl')
		return auc_dict

	def prepare_auc_matrix(self, folder):
		"""Build the feature x dataset AUC matrix, reclustered, with missing AUCs set to 0.5.

		The matrix is written to ``<folder>fig2a_auc_matrix.tsv.gz`` and returned.
		"""
		auc_matrix = pd.DataFrame(self.auc_dict).fillna(0.5)
		auc_matrix = utilsFacade.recluster_matrix(auc_matrix,
					 linkage_method = 'average', metric = 'euclidean')
		auc_matrix.to_csv(folder + 'fig2a_auc_matrix.tsv.gz', sep = '\t', compression = 'gzip')
		return auc_matrix

	def summarize_essentiality(self, folder, auc_matrix = None):
		"""Replace the per-cell-line essentiality rows of the AUC matrix by their mean.

		``auc_matrix`` defaults to the matrix built by prepare_auc_matrix.
		The per-dataset ``auc_df`` has no essentiality columns, so it cannot be
		used here. The result (rows: features incl. ``essential``; columns:
		datasets) overwrites ``<folder>fig2a_auc_matrix.tsv.gz`` and is
		returned.
		"""
		if auc_matrix is None:
			auc_matrix = getattr(self, "auc_matrix", None)
		if auc_matrix is None:
			auc_matrix = self.prepare_auc_matrix(folder)

		by_dataset = auc_matrix.T
		essential_columns = [column for column in by_dataset.columns if "ess" in str(column)]
		by_dataset["essential"] = by_dataset[essential_columns].mean(axis = 1)
		by_dataset = by_dataset.drop(essential_columns, axis = 1)
		summarized = by_dataset.T
		summarized.to_csv(folder + 'fig2a_auc_matrix.tsv.gz', sep = '\t', compression = 'gzip')
		return summarized

	# ------------------------------------------------------------------ helpers

	def _gene_set_correlations(self, genes):
		"""Correlations among ``genes`` that are in the matrix but not complex members."""
		candidates = set(self._normalise_symbols(genes)).difference(set(self.all_genes))
		overlapping = list(set(self.corrData.index).intersection(candidates))
		return utilsFacade.get_correlation_values(self._square(self.corrData, overlapping))

	def _threshold_roc_df(self, positives, negatives):
		"""ROC table evaluated at the cut-offs from ``utilsFacade.frange(-1, 1, 0.01)``.

		Returns one row per cut-off with columns correlations (the cut-off),
		fpr, tpr, recall, accuracy, precision and auc (constant per table).
		"""
		df = self._labelled_df(positives, negatives)
		all_true_positives = int(df["label"].sum())
		all_false_positives = len(df) - all_true_positives

		dec_list = list(utilsFacade.frange(-1, 1, 0.01))
		thresholds = self._roc_thresholds(dec_list, df["correlations"].max())

		fpr_list, tpr_list, recall_list, accuracy_list, precision_list = [], [], [], [], []
		for threshold in thresholds:
			accepted = df[df["correlations"] >= threshold]
			rejected = df[df["correlations"] < threshold]
			total_info = self.calculate_tp_fp(accepted, rejected, all_true_positives, all_false_positives)
			tpr, fpr, precision, recall, accuracy = total_info[4:]
			fpr_list.append(fpr)
			tpr_list.append(tpr)
			recall_list.append(recall)
			accuracy_list.append(accuracy)
			precision_list.append(precision)

		df_small = pd.DataFrame({"correlations": thresholds})
		df_small["fpr"] = fpr_list
		df_small["tpr"] = tpr_list
		df_small["recall"] = recall_list
		df_small["accuracy"] = accuracy_list
		df_small["precision"] = precision_list
		df_small["auc"] = [self._auc(fpr_list, tpr_list)] * len(df_small)
		return df_small

	@staticmethod
	def _square(corrData, genes):
		"""Return the ``genes`` x ``genes`` block of a correlation matrix."""
		return corrData[genes].T[genes].T

	@staticmethod
	def _normalise_symbols(genes):
		"""Strip gene symbols and re-case them as first character + lower-case rest.

		Non-string and empty entries are skipped.
		NOTE(review): the re-casing ('TP53' -> 'Tp53') is applied for every
		species; for datasets indexed by upper-case symbols it likely removes
		the overlap with ``corrData`` - confirm this is intended.
		"""
		normalised = []
		for gene in genes:
			try:
				gene = gene.strip()
			except AttributeError:
				continue
			if gene:
				normalised.append(gene[0] + gene[1:].lower())
		return normalised

	@staticmethod
	def _parse_pathway_ids(annotation):
		"""Extract the IDs from an ``id:<ID>,name:<NAME>,id:...`` annotation string.

		Missing (non-string) annotations give an empty list.
		"""
		try:
			chunks = annotation.split(",name:")
		except AttributeError:
			return []
		pathway_ids = []
		for chunk in chunks:
			parts = chunk.split("id:")
			if len(parts) > 1:
				pathway_ids.append(parts[1])
		return pathway_ids

	@staticmethod
	def _string_correlations(corrData, string_dict):
		"""Flat list of correlations between each key of ``string_dict`` and its partners in ``corrData``.

		Keys without any partner in the matrix contribute nothing.
		"""
		index_genes = set(corrData.index)
		values = []
		for key in index_genes.intersection(string_dict):
			partners = list(index_genes.intersection(set(string_dict[key])))
			if not partners:
				continue
			values.extend(corrData[partners].T[key].T)
		return values

	@staticmethod
	def _subsample(values, size):
		"""Return ``size`` random elements of ``values``, or all of them when there are not more."""
		import random
		values = list(values)
		if size < len(values):
			return random.sample(values, size)
		return values

	@staticmethod
	def _ratio(numerator, denominator):
		"""Return ``numerator / denominator`` as a float, NaN for a zero denominator."""
		if not denominator:
			return float("nan")
		return float(numerator) / denominator

	@staticmethod
	def _roc_thresholds(dec_list, max_corr):
		"""Cut-offs to evaluate: ``dec_list`` up to (excluding) its last value below ``max_corr``.

		This keeps the historical guard, which also drops the largest cut-off
		below the maximum. An empty list results when no value is below
		``max_corr`` (e.g. ``max_corr`` is NaN).
		"""
		for position in range(len(dec_list) - 1, -1, -1):
			if dec_list[position] < max_corr:
				return dec_list[:position]
		return []

	@staticmethod
	def _auc(fpr_list, tpr_list):
		"""Area under the ROC curve, or NaN when it cannot be computed.

		sklearn.metrics.auc raises for fewer than two points and can raise
		when the rates contain NaN. NaN AUCs are later set to 0.5 by
		prepare_auc_matrix.
		"""
		fpr = np.asarray(fpr_list, dtype = float)
		tpr = np.asarray(tpr_list, dtype = float)
		if len(fpr) < 2 or not (np.isfinite(fpr).all() and np.isfinite(tpr).all()):
			return np.nan
		return sklearn.metrics.auc(fpr, tpr)

	@staticmethod
	def _labelled_df(positives, negatives):
		"""Stack positives (label True) and negatives (label False), sorted by correlation ascending."""
		# list() so numpy arrays are concatenated, not added element-wise.
		positives = list(positives)
		negatives = list(negatives)
		df = pd.DataFrame({"correlations": positives + negatives,
						   "label": [True] * len(positives) + [False] * len(negatives)})
		return df.sort_values("correlations", ascending = True)

	@staticmethod
	def _first_auc(table):
		"""Return the (per-table constant) ``auc`` value of a ROC table, NaN if it is empty."""
		if len(table) == 0:
			return np.nan
		return table["auc"].iloc[0]

class step6:
	"""Run the per-dataset step-6 preparation for every loaded dataset."""

	# Dataset names the original routed through a separate branch; they are
	# the mouse-strain datasets in the Figure 2A labels.
	_MOUSE_DATASETS = ('gygi1', 'gygi2', 'gygi3', 'bxd_protein')

	@staticmethod
	def execute(**kwargs):
		"""Instantiate step6_preparation (which runs the analysis) for each dataset.

		Keyword ``folder`` is the input prefix and ``output_folder`` the export
		prefix (both default to ``'PATH'``).
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		data_dict = file_Loader.load_data(folder = folder)
		names = list(data_dict.keys())
		for position, name in enumerate(names):
			# NOTE(review): species builds the complex-dictionary key
			# ``species + "GeneNames"`` and the STRING file prefix; confirm that
			# 'mouse'/'human' match those names.
			species = 'mouse' if name in step6._MOUSE_DATASETS else 'human'
			# The cross-dataset summary reads every dataset's AUC table, so
			# request it only for the last dataset.
			step6_preparation(data_dict[name], name, species,
							  folder = folder, output_folder = output_folder,
							  summarize = (position == len(names) - 1))

class step6_figure2a:
	"""Figure 2A: AUC heat map (datasets x features) with marginal bar plots."""

	@staticmethod
	def _output_folder(kwargs):
		"""Return the ``output_folder`` keyword, falling back to ``folder`` and then ``'PATH'``.

		The output prefix used to be read from the ``folder`` key; the fallback
		keeps callers that pass only ``folder`` working.
		"""
		return kwargs.get('output_folder', kwargs.get('folder', 'PATH'))

	@staticmethod
	def execute(**kwargs):
		"""Entry point: build Figure 2A from the AUC matrix found under ``folder``."""
		folder = kwargs.get('folder','PATH')
		output_folder = step6_figure2a._output_folder(kwargs)

		print('FIGURE2A: main_figure2a_rocMatrix')
		step6_figure2a.main_figure2a_rocMatrix(folder = folder, output_folder = output_folder)

	@staticmethod
	def main_figure2a_rocMatrix(**kwargs):
		"""Load the AUC matrix, plot it and export the plotted data."""
		folder = kwargs.get('folder','PATH')
		output_folder = step6_figure2a._output_folder(kwargs)

		print('get_auc_matrix')
		auc_matrix = step6_figure2a.get_auc_matrix(folder = folder)

		print("plot_auc_matrix")
		step6_figure2a.plot_auc_matrix(auc_matrix, output_folder = output_folder)

		print('export_underlying_data')
		step6_figure2a.export_underlying_data(auc_matrix, folder = folder, output_folder = output_folder)

	@staticmethod
	def get_auc_dict(**kwargs):
		"""Read the pickled ``{dataset: {feature: auc}}`` dictionary from ``folder``."""
		folder = kwargs.get('folder','PATH')

		auc_dict = DataFrameAnalyzer.read_pickle(folder + 'fig2a_auc_dictionary.pkl')
		return auc_dict

	@staticmethod
	def get_auc_matrix(**kwargs):
		"""Read the feature x dataset AUC matrix ``<folder>fig2a_auc_matrix.tsv.gz``.

		This is the file written by step6_preparation (prepare_auc_matrix /
		summarize_essentiality). Its first column holds the feature names and
		becomes the index.
		"""
		folder = kwargs.get('folder', 'PATH')
		return pd.read_csv(folder + 'fig2a_auc_matrix.tsv.gz', sep = '\t',
						   index_col = 0, compression = 'gzip')

	@staticmethod
	def export_underlying_data(auc_matrix, **kwargs):
		"""Write the AUC matrix shown in Figure 2A next to the figure.

		NOTE(review): output file name chosen here; adjust if a different
		naming scheme is expected.
		"""
		output_folder = step6_figure2a._output_folder(kwargs)
		auc_matrix.to_csv(output_folder + 'fig2a_underlying_data.tsv.gz',
						  sep = '\t', compression = 'gzip')

	@staticmethod
	def get_specific_color_gradient(colormap, inputList, **kwargs):
		"""Map ``inputList`` onto ``colormap``.

		Keywords ``vmin``/``vmax`` fix the normalisation bounds. Each missing
		bound falls back independently to the finite data min/max; ``False``
		is also accepted as "unset", for older callers.
		Returns ``(ScalarMappable, rgba colours of inputList)``.
		"""
		vmin = kwargs.get("vmin")
		vmax = kwargs.get("vmax")
		if vmin is False:
			vmin = None
		if vmax is False:
			vmax = None
		values = np.asarray(inputList, dtype = float)
		if vmin is None:
			vmin = np.nanmin(values)
		if vmax is None:
			vmax = np.nanmax(values)

		cNorm = mpl.colors.Normalize(vmin = vmin, vmax = vmax)
		scalarMap = mpl.cm.ScalarMappable(norm = cNorm, cmap = plt.get_cmap(colormap))
		scalarMap.set_array(inputList)
		colorList = scalarMap.to_rgba(inputList)
		return scalarMap, colorList

	@staticmethod
	def plot_auc_matrix(auc_df, **kwargs):
		"""Draw the Figure 2A AUC heat map and save it as ``fig2a_auc_matrix.pdf``.

		auc_df: feature x dataset AUC matrix. Its rows must include the
			entries of ``categories`` below and its columns the entries of
			``datasets``.
		Keyword ``output_folder`` (falling back to ``folder``, then ``'PATH'``)
		is the path prefix of the PDF.
		"""
		output_folder = step6_figure2a._output_folder(kwargs)

		datasets = ["tcga_ovarian","battle_protein",'colo_cancer',"gygi3",
					'tcga_breast',"gygi1","mann_all_log2","primatePRO","wu",
					"battle_ribo","battle_rna","gygi2",'bxd_protein',
					"primateRNA","tiannan"]
		categories = ["chromosome","housekeeping","essential","pathway",
					  "compartment","string_700","complex"]
		data = auc_df.T[categories]
		data = data.T[datasets].T

		data.index = ["TCGA Ovarian Cancer(P)",'Human Individuals(Battle,P)',
					  'TCGA Colorectal Cancer(P)','DO mouse strains(P)',
					  "TCGA Breast Cancer(P)",'Founder mouse strains (P)',
					  'Human cell lines(P)','Primate cells(P)',
					  'Human Individuals(Wu,P)', 'Human Individuals(RP)',
					  'Human Individuals(RS)','DO mouse strains(RS)',
					  'BXD80 mouse strains(P)','Primate cells(RS)',
					  'Kidney cancer cells(P)']
		data.columns = ['chromosome','housekeeping','essential',
						'pathway','compartment','STRING','complex']

		sns.set(context='notebook', style='white',
			palette='deep', font='Liberation Sans', font_scale=1,
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = False

		# Top bars: mean AUC per feature; side bars: best AUC per dataset.
		# Both are drawn as offsets from the 0.5 baseline (bar bottom/left).
		x_mean_list = [np.mean(utilsFacade.finite(list(data[col]))) - 0.5
					   for col in data.columns]
		y_mean_list = [max(utilsFacade.finite(list(data.T[row]))) - 0.5
					   for row in data.T.columns]

		plt.clf()
		fig = plt.figure(figsize = (5,8))
		gs = gridspec.GridSpec(16,11)

		ax1_density = plt.subplot(gs[0:2,0:8])
		ax1_density.set_ylim(0.5,0.8)
		for level in (0.55, 0.6, 0.65, 0.7, 0.75):
			ax1_density.axhline(level, alpha = 0.6, color="grey", linestyle='--', linewidth = 0.2, zorder=1)
		_, colorList_x = step6_figure2a.get_specific_color_gradient(
			plt.cm.Greys, np.array(range(len(x_mean_list))))
		ax1_density.bar(np.arange(len(x_mean_list)),
			x_mean_list, 0.95, color = colorList_x, bottom = 0.5,
			edgecolor = "white", linewidth = 2, zorder = 3)
		plt.xticks(list(range(len(data.columns))))
		ax1_density.set_xticklabels([])
		ax1_density.set_xlim(0,len(data.columns))

		# Heat-map panel; becomes the current axes that sns.heatmap draws into.
		plt.subplot(gs[2:10,0:8])
		scalarmap, _ = step6_figure2a.get_specific_color_gradient(
			plt.cm.RdBu, np.array(data), vmin = 0.4, vmax = 0.7)
		sns.heatmap(data, cmap = plt.cm.RdBu, vmin = 0.4,vmax = 0.7,
			linecolor = "white", linewidth = 2, cbar = False)

		# barh draws bottom-up while the heat map lists datasets top-down.
		y_mean_list = y_mean_list[::-1]
		ax2_density = plt.subplot(gs[2:10,8:10])
		plt.yticks(list(range(len(data.index))))
		ax2_density.set_ylim(0,len(data.index))
		ax2_density.set_xlim(0.5,0.85)
		_, colorList_y = step6_figure2a.get_specific_color_gradient(
			plt.cm.Greys, np.array(range(len(y_mean_list))))
		plt.setp(ax2_density.get_xticklabels(), rotation = 90)
		ax2_density.set_yticklabels([])
		ax2_density.axvline(0.7, color = "red", linestyle = "--", linewidth = 0.5)
		for level in (0.55, 0.6, 0.65, 0.75, 0.8):
			ax2_density.axvline(level, alpha = 0.6, color="grey", linestyle='--', linewidth = 0.2, zorder=1)
		ax2_density.barh(np.arange(len(y_mean_list)), y_mean_list,
						 0.95, color = colorList_y, left = 0.5,
						 edgecolor = "white", linewidth = 2, zorder = 3)

		ax_category = plt.subplot(gs[2:10,10:11])
		# One cell per dataset row so the strip lines up with the heat map
		# (it was hard-coded to 14 rows for 15 datasets).
		category_df = pd.DataFrame({'color': len(data.index) * [1]})
		sns.heatmap(category_df, cbar = False, linewidth = 2, linecolor = "white")
		ax_category.axis("off")

		ax_cbar = plt.subplot(gs[13:14,0:8])
		cbar = fig.colorbar(scalarmap, cax = ax_cbar,
			   orientation = "horizontal")
		cbar.set_label("Area under curve (AUC)")
		plt.savefig(output_folder + "fig2a_auc_matrix.pdf", bbox_inches = "tight", dpi=600)

class step6_figure2b:
	"""Figure 2b: correlation densities of protein-complex pairs vs. all other pairs."""

	# Datasets plotted by main_figure2b_vignette_complexEffect, in plotting order.
	_DATASET_NAMES = ('tcga_ovarian', 'battle_protein', 'gygi1',
		'gygi3', 'tcga_breast', 'colo_cancer')

	@staticmethod
	def execute(**kwargs):
		"""Entry point for figure 2b.

		Keyword args:
			folder: input directory holding the ``*_figure1_data_COMPLEX.tsv.gz`` files.
			output_folder: directory prefix for the PDFs; falls back to ``folder``
				when not given (which matches the previous behaviour for such callers).
		"""
		folder = kwargs.get('folder','PATH')
		# Previously read the 'folder' key here, so a caller's output_folder was ignored.
		output_folder = kwargs.get('output_folder', folder)

		print('FIGURE2B: main_figure2b_vignette_complexEffect')
		step6_figure2b.main_figure2b_vignette_complexEffect(folder = folder, output_folder = output_folder)

	@staticmethod
	def main_figure2b_vignette_complexEffect(**kwargs):
		"""Draw the complex-effect density plot for every dataset in ``_DATASET_NAMES``."""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder', folder)

		for name in step6_figure2b._DATASET_NAMES:
			print('plot_complex_effect:' + name)
			# The folders were previously not forwarded, so every plot used the 'PATH' default.
			step6_figure2b.plot_complex_effect(name, folder = folder, output_folder = output_folder)

	@staticmethod
	def get_data(name, **kwargs):
		"""Load ``<name>_figure1_data_COMPLEX.tsv.gz`` from ``folder``."""
		folder = kwargs.get('folder','PATH')

		complex_data = DataFrameAnalyzer.open_in_chunks(folder, name + '_figure1_data_COMPLEX.tsv.gz')
		return complex_data

	@staticmethod
	def _mean_subsampled_pvalue(sample_a, sample_b, n_iterations = 1000, sample_size = 100, seed = None):
		"""Return the mean Mann-Whitney U p-value over repeated random subsamples.

		Each iteration draws up to ``sample_size`` values without replacement from
		each input and tests them with ``scipy.stats.mannwhitneyu``. The draw size
		is capped at each input's length, so short inputs no longer make
		``random.sample`` raise ValueError.

		Returns NaN when either input is empty or ``n_iterations`` < 1.
		``seed`` (optional) makes the subsampling reproducible.
		"""
		sample_a = list(sample_a)
		sample_b = list(sample_b)
		size_a = min(sample_size, len(sample_a))
		size_b = min(sample_size, len(sample_b))
		if size_a == 0 or size_b == 0 or n_iterations < 1:
			return float('nan')
		rng = random.Random(seed)
		pval_list = list()
		for _ in range(n_iterations):
			# NOTE(review): the sidedness of mannwhitneyu's default test differs
			# between scipy releases; pass `alternative` explicitly once the
			# intended test is confirmed.
			_, pval = scipy.stats.mannwhitneyu(rng.sample(sample_a, size_a),
				rng.sample(sample_b, size_b))
			pval_list.append(pval)
		return np.mean(pval_list)

	@staticmethod
	def plot_complex_effect(name, **kwargs):
		"""Plot correlation densities of complex pairs vs. other pairs for one dataset.

		Splits the ``correlations`` column of the COMPLEX table by its ``label``
		column. It then overlays a histogram and density outline for each group
		and writes ``fig2b_<name>_complex_density_effect.pdf`` under
		``output_folder``.

		Keyword args:
			folder: input directory.
			output_folder: output directory prefix; falls back to ``folder``.
			n_iterations: number of subsampled tests averaged into the legend
				p-value (default 1000; the old loop ran 999).
			seed: optional seed for the subsampling RNG.
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder', folder)
		n_iterations = kwargs.get('n_iterations', 1000)
		seed = kwargs.get('seed')

		complex_data = step6_figure2b.get_data(name, folder = folder)
		other_correlations = utilsFacade.finite(list(complex_data[complex_data.label==False]["correlations"]))
		complex_correlations = utilsFacade.finite(list(complex_data[complex_data.label==True]["correlations"]))

		mean_pval = step6_figure2b._mean_subsampled_pvalue(complex_correlations, other_correlations,
			n_iterations = n_iterations, seed = seed)
		print(mean_pval)

		sns.set(context='notebook', style='white', 
			palette='deep', font='Liberation Sans', font_scale=1, 
			color_codes=False, rc=None)
		plt.rcParams["axes.grid"] = True

		plt.clf()
		fig = plt.figure(figsize = (7,5))
		ax = fig.add_subplot(111)
		ax.set_ylabel('Density', fontsize=12)
		ax.set_xlabel('correlation coefficient (pearson)', fontsize=12)
		# NOTE(review): `normed` was removed in matplotlib 3.1 (use density=True);
		# kept here because this file targets a Python 2-era stack.
		ax.hist(other_correlations, bins = 50, color='grey', alpha =0.3, normed = 1)
		plottingFacade.func_plotDensities_border(ax, other_correlations, 
			linewidth = 2, alpha = 1, facecolor = 'grey')
		ax.hist(complex_correlations, bins = 50, color='#AF2D2D', alpha =0.3, normed =1)
		plottingFacade.func_plotDensities_border(ax, complex_correlations, 
			linewidth = 2, alpha = 1, facecolor = '#AF2D2D')
		plottingFacade.make_full_legend(ax,['n(other)='+str(len(other_correlations)),
			'n(complex)='+str(len(complex_correlations)),
			'pvalComplex(Wilcoxon)='+str(mean_pval)],['grey']*3, loc = 'upper left')
		ax.set_xlim(-1,1)
		plt.savefig(output_folder + 'fig2b_' + name + '_complex_density_effect.pdf',
			bbox_inches = 'tight', dpi = 400)
		# Release the figure; one is created per dataset and was never closed.
		plt.close(fig)

class step6_export:
	"""Export the AUC inputs behind figure 2a as supplementary table 2."""

	# Dataset names as spelled in the *_figure1_data_*.tsv.gz file names.
	_NAME_LIST = ("battle_protein","tiannan","battle_ribo", "battle_rna",
		"primateRNA", "primatePRO", "wu","mann_all_log2",
		"gygi1","gygi2","gygi3","tcga_ovarian",'tcga_breast',
		'bxd_protein','colo_cancer')
	# Categories as spelled in the *_figure1_data_*.tsv.gz file names (upper-cased there).
	_CATEGORY_LIST = ('chromosome', 'compartment', 'complex',
		'housekeeping', 'jiyoye_essentiality', 
		'k562_essentiality','kbm7_essentiality',
		'pathway', 'raji_essentiality','string_700')

	@staticmethod
	def execute(**kwargs):
		"""Entry point: load the AUC dictionary and export the underlying data.

		Keyword args:
			folder: input directory prefix.
			output_folder: output directory prefix; falls back to ``folder``
				when not given (which matches the previous behaviour for such callers).
		"""
		folder = kwargs.get('folder','PATH')
		# Previously read the 'folder' key here, so a caller's output_folder was ignored.
		output_folder = kwargs.get('output_folder', folder)

		print('get_auc_dict')
		auc_matrix = step6_export.get_auc_dict(folder = folder)

		print('export_underlying_data')
		step6_export.export_underlying_data(auc_matrix, folder = folder, output_folder = output_folder)

	@staticmethod
	def get_auc_dict(**kwargs):
		"""Load ``fig2a_auc_dictionary.pkl`` from ``folder`` and return it as a DataFrame."""
		folder = kwargs.get('folder','PATH')

		auc_dict = DataFrameAnalyzer.read_pickle(folder + 'fig2a_auc_dictionary.pkl')
		auc_matrix = pd.DataFrame(auc_dict)
		return auc_matrix

	@staticmethod
	def _num_name(name):
		"""Map a dataset name onto the spelling used by the numInteractions files."""
		return 'coloCa' if name == 'colo_cancer' else name

	@staticmethod
	def _num_category(category):
		"""Map a category onto the spelling used by the numInteractions files."""
		if category.endswith('_essentiality'):
			return category.split('_')[0]
		if category == 'string_700':
			return 'STRING_700'
		return category

	@staticmethod
	def export_underlying_data(auc_matrix, **kwargs):
		"""Concatenate all per-dataset, per-category ROC input tables into one TSV.

		Reads ``<name>_figure1_data_<CATEGORY>.tsv.gz`` for every dataset and
		category and tags each row with ``name`` and ``category`` columns. Writes
		``suppTable2_fig2a_underlying_data_ROCanalysis.tsv`` to ``output_folder``.

		``auc_matrix`` is accepted for interface compatibility but is not used.
		"""
		folder = kwargs.get('folder','PATH')
		output_folder = kwargs.get('output_folder','PATH')

		concat_list = list()
		names_columns = list()
		category_columns = list()
		for name in step6_export._NAME_LIST:
			for category in step6_export._CATEGORY_LIST:
				print(name,category)
				file_name = name + '_figure1_data_' + category.upper() + '.tsv.gz'
				data = DataFrameAnalyzer.open_in_chunks(folder, file_name)
				concat_list.append(data)
				names_columns.extend([name] * len(data))
				category_columns.extend([category] * len(data))

		data = pd.concat(concat_list)
		# Assign positionally: the concatenated index typically repeats across files.
		data['name'] = names_columns
		data['category'] = category_columns

		# Load the numbers of interactions per dataset/category.
		num_dict = dict()
		for name in step6_export._NAME_LIST:
			num_name = step6_export._num_name(name)
			num_data = DataFrameAnalyzer.getFile(folder, 'numInteractions_aurocAnalysis_' + num_name + '_forMethods.tsv')
			temp_dict = num_data.T.to_dict()
			num_dict[num_name] = dict()
			for category in step6_export._CATEGORY_LIST:
				num_category = step6_export._num_category(category)
				num_dict[num_name][num_category] = temp_dict[num_category]

		# NOTE(review): num_dict is never consulted, so the three count columns
		# are always empty strings (as in the original). The intended lookup was
		# presumably num_dict[_num_name(name)][_num_category(category)][<key>]. The
		# key layout of the numInteractions files is not visible here. TODO: confirm and wire up.
		data['pos_num'] = ''
		data['neg_num'] = ''
		data['real_neg_num'] = ''

		data.to_csv(output_folder + 'suppTable2_fig2a_underlying_data_ROCanalysis.tsv',
			sep = '\t')

if __name__ == "__main__":
	## EXECUTE STEP6
	# usage: <script> <input_folder> <output_folder>
	# Both arguments are used as raw string prefixes (e.g. output_folder + 'fig2b_...'),
	# so they presumably need a trailing path separator. TODO: confirm.
	if len(sys.argv) < 3:
		sys.exit('usage: ' + sys.argv[0] + ' <folder> <output_folder>')
	folder, output_folder = sys.argv[1], sys.argv[2]
	for step in (step6, step6_figure2a, step6_figure2b, step6_export):
		step.execute(folder = folder, output_folder = output_folder)
