#!/usr/bin/env python
# -*- coding: utf-8 -*-
import h5py
import argparse
from tqdm import tqdm
import numpy as np
import pandas as pd
import pysam
import logging
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel
from collections import defaultdict

# Print NumPy arrays in full (no "..." summarisation) whenever they are printed.
np.set_printoptions(threshold=np.inf)
# Timestamped log lines at INFO level; logging.debug() calls in this file are
# therefore suppressed unless the level is lowered here.
logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
#from safetensors.torch import load_file

def encode_sequences(sequences):
    """
    One-hot encode a nucleotide sequence or a list of sequences.

    Channel order is A, C, G, T (upper or lower case). 'N'/'n' is encoded as
    an all-zero row and any other symbol as a uniform 0.25 row.

    Parameters
    ----------
    sequences : str or list of str
        A single sequence, or a list of equal-length sequences.

    Returns
    -------
    torch.Tensor
        Shape (len(seq), 4) for a single sequence, or
        (len(sequences), len(seq), 4) for a list. An empty string yields a
        (0, 4) tensor.

    Raises
    ------
    RuntimeError
        Raised by torch.stack when the list is empty or its sequences have
        different lengths.
    """
    # One row per distinct encoding. Characters are translated to row numbers
    # so that a whole sequence is encoded by a single indexing operation
    # instead of allocating (and stacking) one tensor per character.
    encoding_rows = torch.tensor([
        [1., 0., 0., 0.],          # A / a
        [0., 1., 0., 0.],          # C / c
        [0., 0., 1., 0.],          # G / g
        [0., 0., 0., 1.],          # T / t
        [0., 0., 0., 0.],          # N / n
        [0.25, 0.25, 0.25, 0.25],  # any other symbol
    ])
    unknown_row = encoding_rows.shape[0] - 1
    row_of_base = {
        base: row
        for row, spellings in enumerate(("Aa", "Cc", "Gg", "Tt", "Nn"))
        for base in spellings
    }

    def encode_sequence(seq_str):
        row_ids = torch.tensor(
            [row_of_base.get(char, unknown_row) for char in seq_str],
            dtype=torch.long,
        )
        # Integer-tensor indexing returns a fresh (len(seq_str), 4) tensor.
        return encoding_rows[row_ids]

    if isinstance(sequences, list):
        return torch.stack([encode_sequence(seq) for seq in sequences])
    return encode_sequence(sequences)

def model_setup(model_name):
    """
    Load the pretrained letter-level model and place it on the compute device.

    Parameters
    ----------
    model_name : str
        Identifier or path handed to ``AutoModel.from_pretrained``.

    Returns
    -------
    The loaded model, moved to the current CUDA device when CUDA is
    available and left on the CPU otherwise.
    """
    # NOTE: trust_remote_code=True executes Python code shipped with the model
    # repository; only use it with model sources you trust.
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

    # Previously a "cuda:1" device object was built but never used, and
    # .cuda() was called unconditionally, which fails on CPU-only hosts.
    # Keep the same placement on GPU machines (current CUDA device) and fall
    # back to the CPU when no GPU is present.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model.to(device)

def prediction(seq, model, model_len):
    """
    Run the model on one sequence chunk and return per-position probabilities
    for the 5'UTR, exon, intron and 3'UTR features (in that column order).

    Parameters
    ----------
    seq : str
        Nucleotide chunk. Chunks shorter than ``model_len`` are right-padded
        with 'N' before being passed to the model.
    model : callable
        Model returning a mapping with a ``'logits'`` entry.
        NOTE(review): logits are assumed to be laid out as
        (batch, position, feature, class) with the positive class last --
        confirm against the model documentation.
    model_len : int
        Input length expected by the model.

    Returns
    -------
    numpy.ndarray
        Probabilities of the four selected features; rows that correspond to
        padded positions are removed.
    """
    seq_len = len(seq)
    if seq_len < model_len:
        seq = seq.ljust(model_len, 'N')

    # Same placement as before on GPU hosts (current CUDA device); falls back
    # to the CPU instead of crashing when CUDA is unavailable.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    one_hot_encoding = encode_sequences([seq]).to(device)

    # Inference only: skip autograd bookkeeping so that no graph is kept
    # alive for every chunk.
    with torch.no_grad():
        preds = model(one_hot_encoding)
        # Softmax over the last axis, keep only its last entry.
        probabilities = torch.softmax(preds['logits'], dim=-1)[..., -1]
    p_labels = probabilities.detach().cpu().numpy()[0]  # first (only) batch item

    # Order of the feature axis of the model output.
    fs = [
        "protein_coding_gene",
        "lncRNA",
        "exon",
        "intron",
        "splice_donor",
        "splice_acceptor",
        "5UTR",
        "3UTR",
        "CTCF-bound",
        "polyA_signal",
        "enhancer_Tissue_specific",
        "enhancer_Tissue_invariant",
        "promoter_Tissue_specific",
        "promoter_Tissue_invariant",
    ]
    selected = [fs.index(name) for name in ("5UTR", "exon", "intron", "3UTR")]
    p_labels = p_labels[:, selected]

    # Padding always introduces 'N', so this removes the rows predicted for
    # padded positions. The condition is kept exactly as before: an unpadded
    # chunk that merely contains 'N' is sliced to its own length.
    if 'N' in seq:
        p_labels = p_labels[:seq_len]
    return p_labels


def split_info_fields(info_record, list_records):
    """
    Parse one GFF attributes string into a dict and append it to list_records.

    Parameters
    ----------
    info_record : str
        Semicolon-separated ``key=value`` pairs, e.g. ``"ID=x;Parent=y"``.
    list_records : list
        Output list; exactly one dict is appended per call (mutated in place).

    Notes
    -----
    * Only the first '=' separates key and value, so values that themselves
      contain '=' are kept whole.
    * Blank fragments (for example after a trailing ';') are skipped.
    * A fragment without '=' is stored with an empty-string value instead of
      raising IndexError.
    """
    result = {}
    for record in info_record.split(";"):
        if not record.strip():
            continue
        key, _, value = record.partition("=")
        result[key] = value

    list_records.append(result)

def prepare_gff(args, chrom=None):
    """
    Load the GFF annotation and keep the records of the requested sequences.

    Parameters
    ----------
    args : object
        Must expose ``args.gff``, the path of the GFF file.
    chrom : list of str or str, optional
        Sequence ID(s) (GFF column 1) to keep. Defaults to ['NC_060944.1'].

    Returns
    -------
    data : pandas.DataFrame
        Rows of type region/gene/mRNA/lnc_RNA/exon/CDS on ``chrom`` with
        0-based half-open coordinates, a ``lens`` column and one extra column
        per attribute key.
    id_indexed : pandas.DataFrame
        ``data`` indexed by the ``ID`` attribute.
    chr_len : int
        Length of the first 'region' record.

    When no record matches, ``(empty frame, empty frame, 0)`` is returned so
    that the caller can test ``data.empty``.

    Raises
    ------
    IndexError
        If records were selected but none of them is a 'region' record.
    """
    if chrom is None:
        chrom = ['NC_060944.1']
    elif isinstance(chrom, str):
        # Series.isin() rejects a bare string; treat it as a single ID.
        chrom = [chrom]

    col_names = [
        "seqid",
        "source",
        "type",
        "start",
        "end",
        "score",
        "strand",
        "phase",
        "attributes",
    ]

    # NOTE(review): comment="#" also truncates a data line at any '#' that
    # appears inside it (e.g. in the attributes column) -- confirm that the
    # input never contains one.
    data_gff = pd.read_csv(args.gff, sep="\t", names=col_names, header=None, comment="#")

    # GFF coordinates are 1-based and closed. Shifting only the start gives
    # 0-based half-open intervals, so `end` is deliberately left untouched.
    data_gff["start"] = data_gff["start"] - 1
    data_gff['lens'] = data_gff['end'] - data_gff['start']

    wanted_types = ['region', 'gene', 'mRNA', 'lnc_RNA', 'exon', 'CDS']
    selected = data_gff["type"].isin(wanted_types) & data_gff["seqid"].isin(chrom)
    # reset_index() returns a new frame (no in-place edit of a filtered slice)
    # and gives the 0..n-1 index that the column-wise concat below relies on.
    data = data_gff[selected].reset_index(drop=True)

    if data.empty:
        return data, data.copy(), 0

    new_col_list = []
    for attributes in data["attributes"]:
        split_info_fields(attributes, new_col_list)
    new_cols = pd.DataFrame(new_col_list)
    data = pd.concat((data, new_cols), axis=1)

    regions = data[data["type"] == 'region']
    if regions.empty:
        raise IndexError(f"no 'region' record found for {chrom} in {args.gff}")
    chr_len = regions.iloc[0]['lens']
    id_indexed = data.set_index("ID")

    return data, id_indexed, chr_len


def find_segments_ones(array):
    """
    Locate the runs of consecutive ones in an array.

    Returns a list of (start, end) index pairs, one per run, where ``end`` is
    one past the last index of the run. An array without ones gives [].
    """
    positions = np.where(array == 1)[0]
    if positions.size == 0:
        return []

    # Index (within `positions`) of the last element of every run that is
    # followed by another run.
    last_before_gap = np.where(np.diff(positions) > 1)[0]

    first_of_run = np.concatenate(([0], last_before_gap + 1))
    last_of_run = np.concatenate((last_before_gap, [positions.size - 1]))

    starts = positions[first_of_run]
    ends = positions[last_of_run] + 1
    return list(zip(starts, ends))


def get_sequence(start_d, end_d, chr_name, reference):
    """
    Fetch a stretch of DNA from the reference as an upper-case string.

    The single-character string 'N' is returned instead of the sequence when
    the fetched stretch contains any symbol other than A, C, G or T.
    """
    fetched = reference.fetch(reference=chr_name, start=start_d, end=end_d).upper()

    # Any non-ACGT symbol marks the whole stretch as unusable.
    if not set(fetched).issubset('ACTG'):
        return 'N'
    return fetched

def search_parts(transcript_content, part, transcript_start):
    """
    Return the (start, end) intervals of every record of type ``part``,
    expressed relative to ``transcript_start``.
    """
    is_part = transcript_content["type"] == part
    rows = transcript_content[is_part]

    relative_starts = np.array(rows['start'].tolist()) - transcript_start
    relative_ends = np.array(rows['end'].tolist()) - transcript_start

    return [(begin, finish) for begin, finish in zip(relative_starts, relative_ends)]

def process_seq_prediction(transcript_seq, model, model_len, mode='forward'):
    """
    Predict labels for a whole sequence by running the model on consecutive
    windows of ``model_len`` characters and joining the results along axis 0.

    With ``mode='reverse'`` the joined predictions are returned in reversed
    row order. An empty input sequence triggers an AssertionError.
    """
    assert transcript_seq, "Input transcript sequence is empty"

    window_predictions = []
    for offset in range(0, len(transcript_seq), model_len):
        window = transcript_seq[offset:offset + model_len]
        window_predictions.append(prediction(window, model, model_len))

    p_labels = np.concatenate(window_predictions, axis=0)

    return p_labels[::-1] if mode == 'reverse' else p_labels
            
def get_h5_data(seq_info, transcript_chr, ref, chr_len, model, model_len, shift_ir, run_model=False):
    """
    Fetch the sequence of a feature (plus flanks) and optionally predict on it.

    Parameters
    ----------
    seq_info : mapping
        Record with 0-based half-open ``"start"`` / ``"end"`` coordinates.
    transcript_chr : str
        Sequence name handed to the reference.
    ref : object
        Reference handle accepted by ``get_sequence``.
    chr_len : int
        Chromosome length; the right flank is clipped to it.
    model, model_len :
        Model and its input length; only used when ``run_model`` is true.
    shift_ir : int
        Number of flanking bases added on each side (clipped to the chromosome).
    run_model : bool, optional
        Defaults to False, which reproduces the previous behaviour: only the
        sequence is returned. The prediction code used to sit after an
        if/else whose two branches both returned, so it could never run; it
        is now reachable by passing ``run_model=True``.

    Returns
    -------
    (predictions, sequence, coordinates)
        * ``run_model`` false, or the sequence is unusable (``'N'``):
          ``(np.array([]), sequence, np.array([]))``.
        * otherwise: the transposed output of ``process_seq_prediction``
          (second axis has one entry per base), the sequence, and
          ``[start, end]`` of the fetched window.
    """
    start_for_tokenize_d = max(0, seq_info["start"] - shift_ir)
    end_for_tokenize_d = min(chr_len, seq_info["end"] + shift_ir)

    seq_forward = get_sequence(
        start_for_tokenize_d,
        end_for_tokenize_d,
        transcript_chr,
        ref,
    )

    # get_sequence() returns 'N' for a window with non-ACGT symbols; such a
    # window is never sent to the model.
    if seq_forward == 'N' or not run_model:
        return np.array([]), seq_forward, np.array([])

    p_labels_forward = process_seq_prediction(seq_forward, model, model_len, mode='forward').T
    # Internal sanity check: one prediction column per base.
    assert p_labels_forward.shape[1] == len(seq_forward)

    coordinates = [start_for_tokenize_d, end_for_tokenize_d]
    return p_labels_forward, seq_forward, coordinates
    
def make_targets(transcript_content, transcript_class, transcript_seq, start_for_tokenize_d):
    """
    Build a (2, len(transcript_seq)) 0/1 mask: row 0 marks exon positions and
    row 1 marks CDS positions.

    The CDS row is only filled for transcripts of class 'mRNA'; for every
    other class it stays all zeros. Intervals are taken relative to
    ``start_for_tokenize_d``.
    """
    all_labels = ['exon', 'CDS']
    active_labels = all_labels if transcript_class == 'mRNA' else ['exon']

    intervals_per_label = [
        search_parts(transcript_content, label, start_for_tokenize_d)
        for label in active_labels
    ]

    targets = np.zeros((len(all_labels), len(transcript_seq)))
    # Row number equals the label's position in active_labels (exon -> 0, CDS -> 1).
    for row, intervals in enumerate(intervals_per_label):
        for begin, finish in intervals:
            targets[row, begin:finish] = 1

    return targets

def process_transcript(transcript, data, id_indexed, chr_len, model, model_len, ref, output_path, h5_file, index_transcript, shift_ir=0):
    """
    Store the sequence of one transcript and of its parent gene in the HDF5 file.

    A group ``transcript_<n>`` is created in ``<output_path>/<h5_file>.hdf5``
    (opened in append mode), where ``n`` is ``index_transcript['index']``;
    the counter is incremented after the group has been written. The group
    carries the attributes chromosome, transcript_seq, gene_seq, strand, type,
    Parent and ID. A sequence that contains non-ACGT symbols is stored as the
    single character 'N'.

    Parameters
    ----------
    transcript : str
        Transcript ID; must be a label of ``id_indexed``.
    data : pandas.DataFrame
        Full annotation table. Currently unused; kept so the call signature
        stays the same.
    id_indexed : pandas.DataFrame
        Annotation indexed by ID; provides the transcript and gene records.
    chr_len, model, model_len, ref, shift_ir :
        Forwarded to ``get_h5_data``.
    output_path, h5_file : str
        Directory and base name (without extension) of the HDF5 file.
    index_transcript : dict
        Mutable counter holding the next group number under key 'index'.

    Raises
    ------
    KeyError
        If the transcript or its parent gene is missing from ``id_indexed``.
    AssertionError
        If the strand is neither '+' nor '-'.
    """
    logging.debug("------------new transcript start------------")

    # NOTE: the per-transcript filter `data[data['Parent'] == transcript]` was
    # removed. Its result was only needed by the (disabled) target generation,
    # yet it scanned the whole annotation table for every transcript.
    transcript_info = id_indexed.loc[transcript]

    gene = transcript_info["Parent"]
    gene_info = id_indexed.loc[gene]

    transcript_strand = transcript_info["strand"]
    transcript_chr = transcript_info["seqid"]
    transcript_class = transcript_info["type"]

    assert (transcript_strand == "+" or transcript_strand == "-"), "Not identifited strand"

    logging.debug(transcript)

    # Only the sequences are stored for now; predictions and coordinates
    # returned by get_h5_data() are ignored.
    _, transcript_seq, _ = get_h5_data(transcript_info, transcript_chr,
                                       ref, chr_len, model, model_len, shift_ir)
    _, gene_seq, _ = get_h5_data(gene_info, transcript_chr,
                                 ref, chr_len, model, model_len, shift_ir)

    # Usable and unusable ('N') sequences used to be written by two branches
    # with identical bodies; a single code path is equivalent.
    with h5py.File(f"{output_path}/{h5_file}.hdf5", "a") as file:
        group = file.create_group(f"transcript_{index_transcript['index']}")
        group.attrs['chromosome'] = transcript_chr
        group.attrs['transcript_seq'] = transcript_seq
        group.attrs['gene_seq'] = gene_seq
        group.attrs['strand'] = transcript_strand
        group.attrs['type'] = transcript_class
        group.attrs['Parent'] = gene
        group.attrs['ID'] = transcript

    index_transcript['index'] += 1

def main_process(args):
    """
    Run the whole pipeline: load the model, read the transcript list and write
    one HDF5 group per transcript for the selected chromosome.

    Parameters
    ----------
    args : argparse.Namespace
        Needs ``model``, ``model_len``, ``fasta``, ``gff``,
        ``transcripts_list``, ``chrs``, ``output`` and ``h5``.
        ``args.chrs`` is used as a single sequence ID, not as a file of names.
    """
    model = model_setup(args.model)
    model_len = int(args.model_len)

    # First column of a header-less CSV: one transcript ID per line.
    transcripts_list = pd.read_csv(args.transcripts_list, header=None)[0].tolist()

    # Mutable counter shared with process_transcript(); key 'index' starts at 0.
    index_transcript = defaultdict(int)

    selected_chrs = [args.chrs]

    ref = pysam.Fastafile(args.fasta)
    try:
        for selected_chromosome in selected_chrs:
            logging.info(f"selected_chromosome = {selected_chromosome}")

            data, id_indexed, chr_len = prepare_gff(args, chrom=[selected_chromosome])
            if data.empty:
                logging.warning("no GFF records selected for %s, skipping", selected_chromosome)
                continue

            for transcript in tqdm(transcripts_list):
                process_transcript(transcript, data, id_indexed, chr_len, model,
                                   model_len, ref, args.output, args.h5, index_transcript)
    finally:
        # The FASTA handle was never released before; close it even when a
        # transcript fails.
        ref.close()

                   
# Command-line interface. Building the parser is cheap and has no side
# effects, so it stays at module level; arguments are parsed and the pipeline
# is started only when this file is executed as a script, which makes the
# module importable without running anything.
parser = argparse.ArgumentParser()
# main_process() uses this value as ONE sequence ID (GFF column 1); the old
# help text described a file of names, which the code does not read.
parser.add_argument("--chrs", type=str, default="all", help="Name of the chromosome (GFF seqid) to process")
# The three input paths have no usable default, so they are required: a
# missing one is reported by argparse instead of failing later on None.
parser.add_argument("--gff", type=str, required=True, help="Path to GFF file")
parser.add_argument("--fasta", type=str, required=True, help="Path to FASTA file")

#parser.add_argument("--ckpt", type=str, help="Path to model checkpoint")
parser.add_argument("--model_len", type=int, default=49992, help="Lenght of input sequence")

parser.add_argument("--transcripts_list", type=str, required=True, help="Path to list of transcripts file")
parser.add_argument("--output", type=str, default='/mnt/10tb/home/shadskiy/nt_pred/segment_borzoi', help="Path to output")
parser.add_argument("--h5", type=str, default='sequences', help="Name of hdf5 file")
parser.add_argument(
    "--model",
    type=str,
    default="InstaDeepAI/segment_enformer",
)

if __name__ == "__main__":
    args = parser.parse_args()
    main_process(args)