# ------------------------------------------------------------
# Author: Natalie Romanov
# ------------------------------------------------------------

class yeast_mapping:
	# Manual corrections applied after parsing the mapping file; these
	# identifiers are forced to the values below (presumably missing or
	# mis-parsed in the 2018 file - TODO confirm when refreshing the file).
	_MANUAL_UNIPROT = {
		"SUR4": "P40319",
		"YBL005W-B": "Q12490",
		"YCL019W": "P25384",
		"YCL020W": "P25383",
		"YCL057C-A": "Q96VH5",
	}
	_MANUAL_SGD = {
		"SUR4": "YLR372W",
		"YBL005W-B": "YBL005W-B",
		"YCL019W": "YCL019W",
		"YCL020W": "YCL020W",
		"YCL057C-A": "YCL057C-A",
	}

	@staticmethod
	def execute(**kwargs):
		'''
		crossmaps between IDs derived from SGD, Uniprot and Gene Symbols

		Keyword 'folder': base folder (with trailing separator) that contains
		'data/yeast_mapping.txt'; defaults to 'PATH'.

		Returns (df, symbol_sgd_to_uniprot, symbol_sgd_to_sgd,
		uniprot_to_symbol, sgd_to_symbol).
		'''
		folder = kwargs.get('folder','PATH')

		print('load_mapping_file')
		df = yeast_mapping.load_mapping_file(folder)

		print('get_mapping_dictionary')
		symbol_sgd_to_uniprot, symbol_sgd_to_sgd = yeast_mapping.get_mapping_dictionary(df)

		print('reverse_dictionaries')
		uniprot_to_symbol, sgd_to_symbol = yeast_mapping.reverse_dictionaries(symbol_sgd_to_uniprot, symbol_sgd_to_sgd)
		return df, symbol_sgd_to_uniprot, symbol_sgd_to_sgd, uniprot_to_symbol, sgd_to_symbol

	@staticmethod
	def load_mapping_file(folder):
		'''
		retrieved file from https://www.uniprot.org/docs/yeast (2018)

		Reads '<folder>data/yeast_mapping.txt' tab-separated through
		DataFrameAnalyzer.getFile; folder must end with a path separator
		because the sub-path is appended by plain string concatenation.
		'''
		folder = folder + 'data/'
		yeast = "yeast_mapping.txt"
		df = DataFrameAnalyzer.getFile(folder, yeast, sep = '\t')
		return df

	@staticmethod
	def _parse_row(fields):
		'''
		Split one tokenised row into (names, uniprot, sgd): the names the row
		is indexed under, its Uniprot accession and its SGD systematic name.
		'''
		if fields[0].endswith(";"):
			# Several gene designations ("A; B ...") precede the IDs, so the IDs
			# are located from the end of the row. When the token at -5 is an
			# entry name ("*_YEAST") every ID sits one position further left
			# (presumably an optional trailing column - TODO confirm).
			shift = 1 if "_YEAST" in fields[-5] else 0
			uniprot = fields[-5 - shift]
			sgd = fields[-6 - shift]
			names = [token.split(";")[0] for token in fields[:3]]
		else:
			# Single designation: symbol, SGD systematic name, Uniprot accession.
			uniprot = fields[2]
			sgd = fields[1]
			names = [fields[0], fields[1]]
		return names, uniprot, sgd

	@staticmethod
	def get_mapping_dictionary(df, max_rows=6726):
		'''
		Build symbol/SGD-name -> Uniprot and symbol/SGD-name -> SGD-name maps.

		df: table whose first column holds one whitespace-aligned line of the
		Uniprot yeast listing per row.
		max_rows: number of leading rows to parse (default 6726, the limit the
		2018 file was processed with; later rows are ignored - presumably the
		file footer, TODO confirm). None parses every row.

		Returns (symbol_sgd_to_uniprot, symbol_sgd_to_sgd). Every value is a
		list that rows append to in file order, followed by the manual
		overrides in _MANUAL_UNIPROT / _MANUAL_SGD.
		'''
		symbol_sgd_to_uniprot = dict()
		symbol_sgd_to_sgd = dict()
		for row in numpy.array(df)[:max_rows]:
			# columns are separated by runs of spaces; drop the empty tokens
			fields = [token for token in row[0].split(" ") if token]
			names, uniprot, sgd = yeast_mapping._parse_row(fields)
			for name in names:
				symbol_sgd_to_uniprot.setdefault(name, []).append(uniprot)
				symbol_sgd_to_sgd.setdefault(name, []).append(sgd)

		for name, accession in yeast_mapping._MANUAL_UNIPROT.items():
			symbol_sgd_to_uniprot[name] = [accession]
		for name, orf in yeast_mapping._MANUAL_SGD.items():
			symbol_sgd_to_sgd[name] = [orf]
		return symbol_sgd_to_uniprot, symbol_sgd_to_sgd

	@staticmethod
	def _invert(mapping):
		'''
		Reverse a key -> list-of-values dict into value -> list-of-keys.

		If a value's key list contains names other than the value itself, the
		value is dropped from that list and the list is de-duplicated (order
		not preserved). Otherwise the list is kept unchanged.
		'''
		inverted = dict()
		for key, items in mapping.items():
			for item in items:
				inverted.setdefault(item, []).append(key)
		for item in inverted:
			others = set(inverted[item]) - set([item])
			# empty means the identifier only maps to itself: keep the list
			if others:
				inverted[item] = list(others)
		return inverted

	@staticmethod
	def reverse_dictionaries(symbol_sgd_to_uniprot, symbol_sgd_to_sgd):
		'''
		Invert both forward dictionaries produced by get_mapping_dictionary.

		Returns (uniprot_to_symbol, sgd_to_symbol); see _invert for how
		self-references are removed.
		'''
		uniprot_to_symbol = yeast_mapping._invert(symbol_sgd_to_uniprot)
		sgd_to_symbol = yeast_mapping._invert(symbol_sgd_to_sgd)
		return uniprot_to_symbol, sgd_to_symbol

class yeast_complexes:
	@staticmethod
	def execute(symbol_sgd_to_uniprot, **kwargs):
		'''
		prepares dictionaries with proteins mapped to complexes and vice versa

		symbol_sgd_to_uniprot: gene symbol / SGD name -> list of Uniprot
		accessions, used to translate complex members.
		Keyword 'folder': base folder (with trailing separator) containing
		'data/'; defaults to 'PATH'.

		Returns (complex -> SGD names, complex -> symbols, complex -> Uniprot
		lists, SGD name -> complexes, symbol -> complexes, Uniprot -> complexes).
		'''
		folder = kwargs.get('folder', 'PATH')

		print('get_complexes')
		yeast_complex_sgd_dict, yeast_complex_symbol_dict, yeast_complex_uniprot_dict = yeast_complexes.get_complexes(folder, symbol_sgd_to_uniprot)
		print('reverse_complex_dictionaries')
		rev_sgd_dict, rev_symbol_dict, rev_uniprot_dict = yeast_complexes.reverse_complex_dictionaries(yeast_complex_sgd_dict, yeast_complex_symbol_dict, yeast_complex_uniprot_dict)
		return yeast_complex_sgd_dict, yeast_complex_symbol_dict, yeast_complex_uniprot_dict, rev_sgd_dict, rev_symbol_dict, rev_uniprot_dict

	@staticmethod
	def get_complexes(folder, symbol_sgd_to_uniprot):
		'''
		List of yeast complexes derived from https://www.ncbi.nlm.nih.gov/pubmed/20620961

		Reads 'data/yeast_benschop_complexes.txt' via DataFrameAnalyzer.getFile
		(indexed by complex ID) and splits the '; '-separated member columns.

		Every member in the returned Uniprot lists is itself a list of
		accessions, including manually mapped members, so downstream code
		can treat all members uniformly.
		Raises KeyError when a member symbol is missing from
		symbol_sgd_to_uniprot.
		'''
		fname = 'data/yeast_benschop_complexes.txt'
		data = DataFrameAnalyzer.getFile(folder, fname)

		yeast_complex_uniprot_dict = dict((e1,list()) for e1 in list(data.index))
		yeast_complex_symbol_dict = dict((e1,list()) for e1 in list(data.index))
		yeast_complex_sgd_dict = dict((e1,list()) for e1 in list(data.index))

		for complexId,row in data.iterrows():
			symbols = row['Complex members (standard gene name)'].split('; ')
			sgds = row['Complex members (systematic name)'].split('; ')
			uniprots = list()
			for s in symbols:
				if s=='DUF1':
					# not resolvable through the mapping; wrapped in a list
					# like every other member so value[0] yields the accession
					uniprots.append(['Q99247'])
				else:
					if s=='LRC5':
						# listed under this name in the complex file; looked up as GEP3
						s = 'GEP3'
					uniprots.append(symbol_sgd_to_uniprot[s])

			yeast_complex_uniprot_dict[complexId] = uniprots
			yeast_complex_symbol_dict[complexId] = symbols
			yeast_complex_sgd_dict[complexId] = sgds
		return yeast_complex_sgd_dict, yeast_complex_symbol_dict, yeast_complex_uniprot_dict

	@staticmethod
	def _invert(mapping, key_of=lambda member: member):
		'''Reverse complex -> members into key_of(member) -> complexes.'''
		reversed_dict = dict()
		for complex_id, members in mapping.items():
			for member in members:
				reversed_dict.setdefault(key_of(member), []).append(complex_id)
		return reversed_dict

	@staticmethod
	def _first_accession(member):
		'''
		First Uniprot accession of a member. A bare string is used whole,
		because indexing it would return only its first character.
		'''
		if isinstance(member, str):
			return member
		return member[0]

	@staticmethod
	def reverse_complex_dictionaries(sgd_dict, symbol_dict, uniprot_dict):
		'''
		Invert the complex -> member dictionaries into member -> complexes.

		Uniprot members are keyed by their first accession only.
		Returns (rev_sgd_dict, rev_symbol_dict, rev_uniprot_dict).
		'''
		rev_sgd_dict = yeast_complexes._invert(sgd_dict)
		rev_symbol_dict = yeast_complexes._invert(symbol_dict)
		rev_uniprot_dict = yeast_complexes._invert(uniprot_dict, yeast_complexes._first_accession)
		return rev_sgd_dict, rev_symbol_dict, rev_uniprot_dict

class yeast_pathways:
	@staticmethod
	def execute(**kwargs):
		'''
		prepares dictionaries with proteins mapped to pathways and vice versa

		Keyword 'folder': base folder (with trailing separator) containing
		'data/biochemical_pathways.tab'; defaults to 'PATH'.
		Returns (protein -> pathways, pathway -> proteins).
		'''
		print('define_pathway_dictionaries')
		return yeast_pathways.define_pathway_dictionaries(kwargs.get('folder', 'PATH'))

	@staticmethod
	def define_pathway_dictionaries(folder):
		'''
		List of yeast pathways derived from https://pathway.yeastgenome.org/

		Reads the tab-separated '<folder>data/biochemical_pathways.tab' with
		pandas and pairs column '3' (protein) with column '0' (pathway) row by
		row. Each protein's pathway list is de-duplicated and kept in first-seen
		order; the reverse map lists every protein under each of its pathways.
		'''
		table = pd.read_csv(folder + 'data/biochemical_pathways.tab', sep = '\t')
		protein_column = list(table['3'])
		pathway_column = list(table['0'])

		# one initially empty entry per distinct protein
		pathway_dict = {protein: [] for protein in set(protein_column)}
		for protein, pathway in zip(protein_column, pathway_column):
			seen = pathway_dict[protein]
			if pathway not in seen:
				seen.append(pathway)

		rev_pathway_dict = dict()
		for protein, pathways in pathway_dict.items():
			for pathway in pathways:
				rev_pathway_dict.setdefault(pathway, []).append(protein)
		return pathway_dict, rev_pathway_dict

class YeastMappings(object):
	'''
	Runs the mapping, complex and pathway preparation steps once and keeps
	every resulting dictionary as an attribute of the instance.

	Keyword 'folder': base folder (with trailing separator) holding 'data/';
	defaults to 'PATH'. A header line is printed before each step.
	'''
	def __init__(self, **kwargs):
		base_folder = kwargs.get('folder', 'PATH')

		print('YEAST MAPPING')
		mapping_results = yeast_mapping.execute(folder = base_folder)
		(self.df,
			self.symbol_sgd_to_uniprot,
			self.symbol_sgd_to_sgd,
			self.uniprot_to_symbol,
			self.sgd_to_symbol) = mapping_results

		print('YEAST COMPLEX RETRIEVAL')
		complex_results = yeast_complexes.execute(self.symbol_sgd_to_uniprot, folder = base_folder)
		(self.sgd_dict,
			self.symbol_dict,
			self.uniprot_dict,
			self.rev_sgd_dict,
			self.rev_symbol_dict,
			self.rev_uniprot_dict) = complex_results

		print('YEAST PATHWAY RETRIEVAL')
		pathway_results = yeast_pathways.execute(folder = base_folder)
		self.pathway_dict, self.rev_pathway_dict = pathway_results

class export_object:
	'''
	Stores a prepared mapping object as '<folder>yeast_mapping.obj'.
	'''
	@staticmethod
	def execute(yeast_mapping, **kwargs):
		'''
		Hand *yeast_mapping* to DataFrameAnalyzer.to_pickle.

		Keyword 'folder': output folder with trailing separator (default
		'PATH'). The parameter name shadows the yeast_mapping class inside
		this method only.
		'''
		output_path = kwargs.get('folder', 'PATH') + "yeast_mapping.obj"
		DataFrameAnalyzer.to_pickle(yeast_mapping, output_path)

if __name__ == "__main__":
	## EXECUTE STEP16
	# usage: python <this script> <folder>
	# <folder> is the base folder (with trailing separator) holding 'data/';
	# the result is handed to DataFrameAnalyzer.to_pickle as
	# '<folder>yeast_mapping.obj'.
	import sys
	if len(sys.argv) < 2:
		sys.exit('usage: %s <folder>' % sys.argv[0])
	# use a distinct name so the module-level class 'yeast_mapping' is not
	# rebound to an instance (which would break any later YeastMappings())
	mappings = YeastMappings(folder = sys.argv[1])
	export_object.execute(mappings, folder = sys.argv[1])
