#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import gc
from tqdm import tqdm
import os
os.environ["TOKENIZERS_PARALLELISM"] = 'false'

import h5py
import numpy as np
import pandas as pd
import pysam

from transformers import AutoTokenizer
import logging
logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)

tqdm.pandas()



pd.options.display.max_colwidth = 1000
################################################################
def split_info_fields(info_record):
    """Parse a GFF attributes string (``key=value;key=value``) into a dict.

    Args:
        info_record: the raw text of the GFF ``attributes`` column.

    Returns:
        dict mapping each attribute key to its value (both ``str``). When a
        key occurs more than once the last occurrence wins.

    Blank records (for example the one produced by a trailing ``;``) are
    skipped. A record without ``=`` is stored with an empty-string value.
    Only the first ``=`` separates key from value, so values that themselves
    contain ``=`` are kept intact.
    """
    result = {}
    for record in info_record.split(";"):
        if not record.strip():
            # Empty piece between/after separators: nothing to store.
            continue
        key, _, value = record.partition("=")
        result[key] = value
    return result


###

def extract_mane_names(att, name, separator):
    """Pull one value out of a ``;``-separated attribute string.

    ``att`` is converted with ``str()`` and split on ``;``. The first field
    that contains ``name`` as a substring is split on ``separator`` and its
    last piece is returned as ``str``. Returns ``None`` when no field
    contains ``name``.
    """
    matching_field = next(
        (field for field in str(att).split(';') if name in field),
        None,
    )
    if matching_field is None:
        return None
    pieces = matching_field.split(separator)
    return str(pieces[-1])

###

def find_mane_genes_list(gff_mane):
    """List the ``gene_name`` attribute of every ``gene`` row in a MANE GFF table.

    Rows whose attributes contain no ``gene_name`` field are dropped from the
    result. Returns a plain Python list of strings.
    """
    gene_rows = gff_mane[gff_mane['type'] == 'gene']
    extracted = gene_rows["attributes"].apply(
        extract_mane_names, name='gene_name', separator='='
    )
    return extracted.dropna().tolist()
  
###
'''
def find_mane_transcripts(gff_mane):
    df = gff_mane
    transcript_df = df[(df['type'] == 'transcript') & (df['attributes'].str.contains("tag=MANE_Select"))]
    transcripts_names = transcript_df["attributes"].apply(extract_mane_names, name='Dbxref=', separator=':').dropna().tolist()
    
    return transcripts_names
'''
###

def find_mane_transcripts_gff(mane_gff):
    """Return the ``transcript`` rows of a MANE GFF table.

    Args:
        mane_gff: DataFrame with at least a ``type`` column.

    Returns:
        The rows whose ``type`` equals ``'transcript'``, with all columns
        and the original index labels preserved.
    """
    return mane_gff[mane_gff['type'] == 'transcript']
    
###

def find_gff_transcripts_gff(gff):
    """Select the ``mRNA`` and ``lnc_RNA`` rows of a GFF table.

    The result keeps only the columns ``seqid``, ``type``, ``start``,
    ``end``, ``strand``, ``attributes`` and ``lens`` (so ``lens`` must
    already exist on the input frame).
    """
    kept_columns = ['seqid', 'type', 'start', 'end', 'strand', 'attributes', 'lens']
    is_transcript = gff['type'].isin(['mRNA', 'lnc_RNA'])
    transcript_rows = gff[is_transcript]
    return transcript_rows[kept_columns]

###

def find_gff_genes_gff(gff):
    """Select protein-coding and lncRNA ``gene`` rows of a GFF table.

    A row is kept when its ``type`` is ``gene`` and its attributes contain
    ``gene_biotype=protein_coding`` or ``gene_biotype=lncRNA``. The result
    keeps the columns ``seqid``, ``start``, ``end``, ``strand``,
    ``attributes`` and ``lens``.
    """
    attributes = gff['attributes']
    is_gene = gff['type'] == 'gene'
    is_protein_coding = attributes.str.contains('gene_biotype=protein_coding')
    is_lncrna = attributes.str.contains('gene_biotype=lncRNA')

    gene_rows = gff[is_gene & (is_protein_coding | is_lncrna)]
    return gene_rows[['seqid', 'start', 'end', 'strand', 'attributes', 'lens']]
###

def find_gff_exon_CDS_gff(gff):
    """Select the ``exon`` and ``CDS`` rows of a GFF table.

    The result keeps the columns ``seqid``, ``type``, ``start``, ``end``,
    ``strand``, ``attributes`` and ``lens``.
    """
    kept_columns = ['seqid', 'type', 'start', 'end', 'strand', 'attributes', 'lens']
    is_exon_or_cds = gff['type'].isin(['exon', 'CDS'])
    part_rows = gff[is_exon_or_cds]
    return part_rows[kept_columns]
        
###

def find_longest_cds(transcript_gff_row, df_content, part_list, part='CDS'):
    """Append the total length of one transcript's ``part`` features to ``part_list``.

    Despite the name, nothing is compared here: the function only measures
    one transcript. It is written to be used with ``DataFrame.apply(axis=1)``
    and communicates through the ``part_list`` accumulator.

    Args:
        transcript_gff_row: a row (mapping with an ``attributes`` entry) that
            describes the transcript; its ``ID=`` attribute is the lookup key.
        df_content: DataFrame of exon/CDS rows with ``type``, ``attributes``
            and ``lens`` columns.
        part_list: list that receives the summed ``lens`` (mutated in place).
        part: feature type to sum, e.g. ``'CDS'`` or ``'exon'``.

    Returns:
        None.
    """
    transcript_name = extract_mane_names(
        transcript_gff_row['attributes'], name='ID=', separator='='
    )

    # Literal (non-regex) match: transcript IDs usually carry a version suffix
    # such as ".1", and as a regex the dot would match any character. The
    # trailing ';' keeps an ID from matching a longer ID that starts with it;
    # it also means a Parent attribute at the very end of the record (no
    # trailing ';') is not matched.
    is_child = df_content['attributes'].str.contains(
        f'Parent={transcript_name};', regex=False
    )
    part_df = df_content[(df_content['type'] == part) & is_child]

    part_list.append(part_df['lens'].sum())

###
def choose_transcript(att, mane_transcripts_gff, gff_transcripts_gff, mane_genes, list_transcripts, gff_exon_CDS_gff):
    """Pick one transcript for a gene and append its name to ``list_transcripts``.

    Args:
        att: the gene's GFF ``attributes`` string.
        mane_transcripts_gff: DataFrame of MANE rows (``type``/``attributes`` are read).
        gff_transcripts_gff: DataFrame of transcript rows from the main GFF.
        mane_genes: container of gene names; only membership (``in``) is used.
        list_transcripts: list mutated in place with the chosen transcript name.
        gff_exon_CDS_gff: DataFrame of exon/CDS rows passed on to ``find_longest_cds``.

    Returns:
        None; the result is the element appended to ``list_transcripts``.

    If the gene name is in ``mane_genes`` the name comes from the
    ``Dbxref=RefSeq`` field of the matching MANE transcript rows (``None`` is
    appended when no such field is found). Otherwise the transcript with the
    largest summed CDS length (when ``gene_biotype=protein_coding`` is one of
    the attribute fields) or exon length (any other gene) is chosen, ties
    going to the first one, and the part of its ``ID`` after the last ``-``
    is appended.

    NOTE(review): ``gene_name`` / ``gene_id`` are only assigned when ``att``
    has a field containing ``Name=`` / ``Dbxref=``; without them the later
    uses raise.
    NOTE(review): in the non-MANE branch ``max()`` is called on an empty list
    (ValueError) when no transcript row matches ``gene_id``.
    """
    att_list = att.split(';')
    for i in att_list:
      if 'Name=' in i:
        gene_name = i.split('=')[-1] + ';'  # 'Name=SOX12' -> 'SOX12;' (the ';' terminates the pattern used below)
      elif 'Dbxref=' in i:
        gene_id = i.split(',')[0] + ','  # 'Dbxref=GeneID:6666,...' -> 'Dbxref=GeneID:6666,'



    df_mane = mane_transcripts_gff
    df_gff = gff_transcripts_gff

#    print(gene_name)

    if gene_name.split(';')[0] in mane_genes:

        # str.contains treats the pattern as a regular expression by default.
        transcript = df_mane[(df_mane['type'] == 'transcript') & (df_mane['attributes'].str.contains(f'gene_name={gene_name}'))]

        # The whole attributes column is rendered to one text blob and searched.
        # NOTE(review): to_string() also emits index labels and one line per
        # row, so with several matching rows the extracted value may carry
        # text from the following row - confirm against real MANE data.
        list_transcripts.append(extract_mane_names(transcript["attributes"].to_string(), name='Dbxref=RefSeq', separator=':'))

    else:

        # Transcripts are linked to the gene through the shared GeneID prefix.
        all_transcripts = df_gff[df_gff['attributes'].str.contains(gene_id)]
#        transcript = all_transcripts.loc[all_transcripts['lens'].idxmax()]

        # Exact element match: the field must be precisely 'gene_biotype=protein_coding'.
        if 'gene_biotype=protein_coding' in att_list:
          cds_list = []
          # find_longest_cds appends one total per row; assumes apply visits each
          # row once and in order, so list position i matches row position i - TODO confirm
          all_transcripts.apply(find_longest_cds, df_content=gff_exon_CDS_gff, part_list=cds_list, part='CDS', axis=1)
          transcript = all_transcripts.iloc[cds_list.index(max(cds_list))]
        else:
          exon_list = []
          all_transcripts.apply(find_longest_cds, df_content=gff_exon_CDS_gff, part_list=exon_list, part='exon', axis=1)
          transcript = all_transcripts.iloc[exon_list.index(max(exon_list))]

        # Keep only the part of the ID after the last '-'.
        list_transcripts.append(extract_mane_names(transcript["attributes"], name='ID=', separator='=').split('-')[-1])
    
###

    
# Get part of DNA sequence
def tokenize_sequence(
    start_for_tokenize_d,
    end_for_tokenize_d,
    chr_name,
    reference,
    tokenizer,
):
    """Fetch a reference interval and tokenize it.

    The interval ``[start_for_tokenize_d, end_for_tokenize_d)`` of
    ``chr_name`` is read through ``reference.fetch`` and upper-cased.

    Returns:
        The string ``'N'`` if the sequence contains any character other than
        A, C, T or G; otherwise whatever ``tokenizer.encode_plus`` returns
        when called without special tokens and with offset mapping enabled.
    """
    sequence = reference.fetch(
        reference=chr_name,
        start=start_for_tokenize_d,
        end=end_for_tokenize_d,
    ).upper()

    # Any symbol outside the four bases disqualifies the whole interval.
    if not set(sequence) <= set('ACTG'):
        return 'N'

    return tokenizer.encode_plus(
        sequence,
        add_special_tokens=False,
        return_offsets_mapping=True,
    )

def search_exons_end_cds(trans_content, strand):
    """Return the outer bounds of a transcript's exons and of its CDS.

    Args:
        trans_content: DataFrame of the transcript's child features with
            ``type``, ``start`` and ``end`` columns.
        strand: accepted for interface compatibility; not used here.

    Returns:
        ``(exon_start, exon_end, cds_start, cds_end)``: the smallest start
        and largest end among ``exon`` rows, then among ``CDS`` rows. A
        feature type with no rows is reported as ``-1000`` for both values.
    """
    def outer_bounds(feature_type):
        rows = trans_content[trans_content["type"] == feature_type]
        if rows.empty:
            return -1000, -1000
        return rows["start"].values.min(), rows["end"].values.max()

    # Exon bounds delimit the region where introns can occur.
    exon_bounds = outer_bounds("exon")
    # CDS bounds delimit the UTRs.
    cds_bounds = outer_bounds("CDS")

    return (*exon_bounds, *cds_bounds)


#classification
def classification(
    tokens_for_classing,
    trans_content,
    exon_start_d,
    exon_end_d,
    cds_start_d,
    cds_end_d,
    start_for_tokenize_d,
    strand,
    mode,
    end_for_tokenize_d
):
    """Build per-base and per-token class masks for one transcript.

    Class rows, in order: ``5UTR``, ``exon``, ``intron``, ``3UTR``, ``CDS``.

    Args:
        tokens_for_classing: mapping whose ``"offset_mapping"`` entry is a
            sequence of ``(start, end)`` character offsets, one per token.
        trans_content: DataFrame of child features (``type``, ``start``, ``end``).
        exon_start_d, exon_end_d: outer exon bounds in genome coordinates.
        cds_start_d, cds_end_d: outer CDS bounds; values ``<= 0`` disable the
            corresponding UTR labelling.
        start_for_tokenize_d: genome coordinate of the first tokenized base.
        strand: ``"+"`` or ``"-"``; decides which side is 5' and which is 3'.
        mode: ``'reverse'`` flips the per-base masks before token aggregation.
        end_for_tokenize_d: accepted for interface compatibility; not used.

    Returns:
        ``(classes, classes_t)``: an int8 array of shape (5, n_tokens) where a
        token carries a class if any of its bases does, and the int8 per-base
        array of shape (5, n_bases).
    """
    labels = ["5UTR", "exon", "intron", "3UTR", "CDS"]
    utr5_row, exon_row, intron_row, utr3_row, cds_row = range(len(labels))

    assert exon_start_d >= 0
    assert exon_end_d >= 0
    assert start_for_tokenize_d >= 0

    # Genome coordinate -> index inside the tokenized window, and back.
    to_window = - start_for_tokenize_d
    assert to_window <= 0
    to_genome = - to_window
    assert to_genome >= 0

    raw_offsets = tokens_for_classing["offset_mapping"]
    window_first = raw_offsets[0][0] + to_genome
    window_last = raw_offsets[-1][-1] + to_genome
    assert window_last > window_first

    token_spans = np.array(raw_offsets)
    per_base = np.zeros(
        shape=(len(labels), window_last - window_first), dtype=np.int8
    )

    exon_lo = exon_start_d + to_window
    exon_hi = exon_end_d + to_window

    # Paint exon and CDS rows straight from the annotation.
    painted_rows = {"exon": exon_row, "CDS": cds_row}
    for _, feature in trans_content.iterrows():
        for feature_type, row in painted_rows.items():
            if feature["type"] == feature_type:
                per_base[row, feature["start"] + to_window : feature["end"] + to_window] = 1

    # Between the outermost exons, everything that is not exon is intron.
    intron_view = per_base[intron_row, exon_lo:exon_hi]
    intron_view[per_base[exon_row, exon_lo:exon_hi] != 1] = 1

    # Exonic bases left of the CDS are 5'UTR on '+', 3'UTR on '-'; mirrored on the right.
    left_utr_row = {"+": utr5_row, "-": utr3_row}.get(strand)
    right_utr_row = {"+": utr3_row, "-": utr5_row}.get(strand)

    if cds_start_d > 0:
        cds_lo = cds_start_d + to_window
        assert exon_lo <= cds_lo, "CDS start is not within exon"
        if exon_lo < cds_lo and left_utr_row is not None:
            per_base[left_utr_row, exon_lo:cds_lo] = per_base[exon_row, exon_lo:cds_lo]

    if cds_end_d > 0:
        cds_hi = cds_end_d + to_window
        assert cds_hi <= exon_hi, "CDS end is not within exon"
        if cds_hi < exon_hi and right_utr_row is not None:
            per_base[right_utr_row, cds_hi:exon_hi] = per_base[exon_row, cds_hi:exon_hi]

    if mode == 'reverse':
        per_base = per_base[:, ::-1]

    # A token inherits every class present on at least one of its bases.
    per_token = np.zeros(shape=(len(labels), len(token_spans)), dtype=np.int8)
    for column, (span_start, span_end) in enumerate(token_spans):
        per_token[:, column] = per_base[:, span_start:span_end].max(axis=1)

    return per_token, per_base


# Main process
def process_transcript(transcript, index_transcript, list_transcripts):
        """Tokenize and label one transcript and append it to the output HDF5 file.

        Reads the module-level names ``data_grouped``, ``transcripts_index``,
        ``ref``, ``tokenizer`` and ``args`` set up by the script body.

        Args:
            transcript: transcript ID, used as the key into ``data_grouped``
                and ``transcripts_index``.
            index_transcript: running counter used to name the HDF5 group.
            list_transcripts: names of the transcripts selected for output;
                only membership is tested.

        Returns:
            ``index_transcript`` unchanged when the transcript is skipped (not
            selected, or its sequence was rejected by ``tokenize_sequence``),
            otherwise ``index_transcript + 1``.
        """
        logging.debug(f"------------new transcript start------------")
        transcript_name = transcript

        # The selection list holds the part of the ID after the last '-'.
        transcript_name_short = transcript_name.split('-')[-1]
        if transcript_name_short not in list_transcripts:
            return index_transcript

        # Child features (rows whose Parent is this transcript) and the transcript's own row.
        transcript_content = data_grouped.get_group(transcript)
        transcript_info = transcripts_index.loc[transcript]
        gene = transcript_info["Parent"]
        transcript_strand = transcript_info["strand"]
        transcript_chr = transcript_info["seqid"]
        transcript_class = transcript_info["type"]


        start_for_tokenize_d = transcript_info["start"]
        end_for_tokenize_d = transcript_info["end"]

        # Cap the tokenized window at 250000 bases from the transcript start.
        if end_for_tokenize_d - start_for_tokenize_d > 250000:
            end_for_tokenize_d = start_for_tokenize_d + 250000

        assert (
            transcript_strand == "+" or transcript_strand == "-"
        ), "Not identifited strand"

        (
            exon_start_d,
            exon_end_d,
            cds_start_d,
            cds_end_d,
        ) = search_exons_end_cds(transcript_content, transcript_strand)



        logging.debug(f"tokenize start")
        transcript_tokens = tokenize_sequence(
            start_for_tokenize_d,
            end_for_tokenize_d,
            transcript_chr,
            ref,
            tokenizer,
        )
        # 'N' is the sentinel tokenize_sequence returns for sequences with non-ACGT characters.
        if transcript_tokens == 'N':
            return index_transcript

        # classification() only tests for 'reverse', so any other value
        # (including this one) leaves the masks in forward orientation.
        mode = 'forwrad'
        logging.debug(f"classification start")
        classes_token, classes_letter = classification(
            transcript_tokens,
            transcript_content,
            exon_start_d,
            exon_end_d,
            cds_start_d,
            cds_end_d,
            start_for_tokenize_d,
            transcript_strand,
            mode,
            end_for_tokenize_d
        )

        coordinates = [int(start_for_tokenize_d), int(end_for_tokenize_d)]

        # Rebuild the character sequence by concatenating the token strings.
        chars = np.array(tokenizer.convert_ids_to_tokens(transcript_tokens["input_ids"]))
        seq = ''.join(chars)

        # Per-character integer codes: A -> 6, T -> 15, C -> 8, G -> 9.
        # NOTE(review): presumably these are the tokenizer's ids for the
        # single-letter tokens - confirm. The final int conversion raises if
        # any character other than A/T/C/G is left in the array.
        d_seq = np.array(list(seq))
        d_seq = np.where(d_seq == 'A', int(6), d_seq)
        d_seq = np.where(d_seq == 'T', int(15), d_seq)
        d_seq = np.where(d_seq == 'C', int(8), d_seq)
        d_seq = np.where(d_seq == 'G', int(9), d_seq)

        d_seq = np.array(d_seq, dtype=int)



        # Append mode: every transcript becomes its own group in the same file.
        with h5py.File(args.output, "a") as file:
            group = file.create_group(f"transcript_{index_transcript}")
            group.attrs['strand'] = transcript_strand
            group.attrs['type'] = transcript_class
            group.attrs['gene_name'] = gene
            group.attrs['transcript_name'] = transcript_name
            group.attrs['genome'] = 'GCF_009914755.1_T2T-CHM13v2.0_genomic'

            token_atcg = d_seq
            labels_atcg = classes_letter.T
            # One row of -100 is added at each end so labels line up with
            # input_ids, which gets [CLS] in front and [SEP] at the end.
            labels = np.concatenate([np.array([[-100]*5]),
                                    classes_token.T,
                                    np.array([[-100]*5])], axis=0)

            input_ids = np.concatenate([[tokenizer.convert_tokens_to_ids("[CLS]")],
                                        transcript_tokens["input_ids"],
                                        [tokenizer.convert_tokens_to_ids("[SEP]")],])

#            print(len(labels), len(token_atcg))
            assert len(labels) == len(input_ids)
            assert len(labels_atcg) == len(token_atcg)

            group.create_dataset("coordinates", data=coordinates)
            group.create_dataset("input_ids", data=input_ids, compression="gzip")
            group.create_dataset("labels", data=labels, compression="gzip")
            group.create_dataset("token_atcg", data=token_atcg, compression="gzip")
            group.create_dataset("labels_atcg", data=labels_atcg, compression="gzip")

            index_transcript += 1
            print(f"labels.shape = {labels.shape}")
            print(f"input_ids.shape = {input_ids.shape}")
            print(f"token_atcg.shape = {token_atcg.shape}")
            print(f"labels_atcg.shape = {labels_atcg.shape}")
            print('--------------------------------------')

            # Drop the large per-transcript objects before the next call.
            del classes_token
            del classes_letter
            del transcript_tokens



        return index_transcript

################################################################
# Let's do it
if __name__ == '__main__':
    # Script body. For every selected chromosome: choose one transcript per
    # gene (MANE transcript when the gene is in MANE, otherwise the one with
    # the longest CDS/exons), then tokenize and label each chosen transcript
    # and append it to the output HDF5 file.
    #
    # process_transcript() reads `args`, `tokenizer`, `ref`, `data_grouped`
    # and `transcripts_index` as module-level names, so they must stay
    # assigned at this level under exactly these names.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--chrs", type=str, default="all", help="Path to TXT file with chromosomes' names"
    )
    parser.add_argument("--gff", type=str, required=True, help="Path to GFF file")
    parser.add_argument("--mane", type=str, required=True, help="Path to MANE GFF file")
    parser.add_argument("--fasta", type=str, required=True, help="Path to FASTA file")
    parser.add_argument("--fai", type=str, required=True, help="Path to FAI file")
    parser.add_argument("--output", type=str, required=True, help="Path to output HDF5 file")
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default="AIRI-Institute/gena-lm-bigbird-base-t2t",
        help="tokenizer",
    )
    args = parser.parse_args()

    col_names = [
        "seqid",
        "source",
        "type",
        "start",
        "end",
        "score",
        "strand",
        "phase",
        "attributes",
    ]
    data_gff = pd.read_csv(args.gff, sep="\t", names=col_names, header=None, comment="#")
    # GFF coordinates are 1-based closed intervals. Shifting only the start
    # turns them into 0-based half-open intervals; the end stays as it is.
    data_gff["start"] = data_gff["start"] - 1
    data_gff['lens'] = data_gff['end'] - data_gff['start']

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
    chromsizes = pd.read_csv(args.fai, sep="\t", header=None, names=['chr', 'length', 'th1', 'th2', 'th3'])
    if args.chrs != "all":
        selected_chrs = pd.read_csv(args.chrs, header=None)[0].tolist()
    else:
        # "all": every sequence of the FASTA index that has GFF records.
        annotated_seqids = set(data_gff["seqid"].unique())
        selected_chrs = [
            name for name in chromsizes["chr"].tolist() if name in annotated_seqids
        ]

    # The MANE tables do not depend on the chromosome, so load them once.
    data_mane = pd.read_csv(args.mane, sep="\t", names=col_names, header=None, comment="#")
    mane_genes_list = find_mane_genes_list(data_mane)
    mane_gene_names = set(mane_genes_list)  # only membership is tested downstream
    mane_transcripts_gff = find_mane_transcripts_gff(data_mane)

    index_transcript = 0
    index_limit = 100
    ref = pysam.Fastafile(args.fasta)

    try:
        for selected_chromosome in selected_chrs:
            print(f"selected_chromosome = {selected_chromosome}")
            # A fresh 0..n-1 index is required for the positional concat below.
            data = data_gff[data_gff["seqid"] == selected_chromosome].reset_index(drop=True)
            if data.empty:
                logging.warning("No GFF records for %s, skipping", selected_chromosome)
                continue

            # Decide which transcript represents each gene.
            gff_exon_CDS_gff = find_gff_exon_CDS_gff(data)
            gff_transcripts_gff = find_gff_transcripts_gff(data)
            gff_genes_gff = find_gff_genes_gff(data)

            list_transcripts = []
            gff_genes_gff['attributes'].progress_apply(
                choose_transcript,
                mane_transcripts_gff=mane_transcripts_gff,
                gff_transcripts_gff=gff_transcripts_gff,
                mane_genes=mane_gene_names,
                list_transcripts=list_transcripts,
                gff_exon_CDS_gff=gff_exon_CDS_gff,
            )
            selected_transcripts = set(list_transcripts)

            # Expand the attributes column into one column per attribute key.
            new_cols = pd.DataFrame.from_records(data["attributes"].apply(split_info_fields))
            data = pd.concat((data, new_cols), axis=1)

            data_grouped = data.groupby("Parent")

            data_lnc_RNA = data["type"] == "lnc_RNA"
            data_mRNA = data["type"] == "mRNA"
            transcript_name_mask = data_mRNA | data_lnc_RNA

            valid_transcripts = data[transcript_name_mask]["ID"].unique()

            # No child of a transcript may itself be the parent of another row.
            assert (
                len(
                    np.intersect1d(
                        data.query("Parent in @valid_transcripts")["ID"].unique(),
                        data["Parent"].dropna().unique(),
                    )
                )
                == 0
            )
            transcripts_index = data[transcript_name_mask].set_index("ID")

            for transcript in tqdm(valid_transcripts):
                index_transcript = process_transcript(
                    transcript, index_transcript, selected_transcripts
                )
                print(index_transcript)
                if index_transcript >= index_limit:
                    # Rebuild the tokenizer every 100 written transcripts.
                    # NOTE(review): presumably a workaround for memory growth
                    # inside the tokenizer - confirm it is still needed.
                    del tokenizer
                    gc.collect()
                    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
                    index_limit += 100
    finally:
        ref.close()

    print(index_transcript)


