#!/usr/bin/env python
# -*- coding: utf-8 -*-
import h5py
import argparse
from tqdm import tqdm
import numpy as np
import pandas as pd
import pysam
import logging
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoConfig, AutoModel
from collections import defaultdict

# Print NumPy arrays in full instead of abbreviating long ones with "...".
np.set_printoptions(threshold=np.inf)
# Prefix every log line with a timestamp. With the level set to INFO, the
# logging.debug(...) calls in this file are suppressed.
logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
from safetensors.torch import load_file



def model_setup(model_name):
    """
    Load a pretrained model together with its tokenizer.

    Both objects are obtained through ``from_pretrained`` with
    ``trust_remote_code=True``, i.e. code shipped with the checkpoint is
    allowed to run.

    :param model_name: model identifier or path accepted by ``from_pretrained``
    :return: ``(model, tokenizer)`` tuple
    """
    pretrained_kwargs = {"trust_remote_code": True}

    # The model is loaded first, then the tokenizer.
    loaded_model = AutoModel.from_pretrained(model_name, **pretrained_kwargs)
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name, **pretrained_kwargs)

    return loaded_model, loaded_tokenizer

def prediction(seq, model, tokenizer):
    """
    Run the model on one DNA chunk and return the selected class probabilities.

    The chunk is tokenized as a single-element batch padded to
    ``len(seq) // 6`` tokens, the model is evaluated without gradients, and a
    softmax over the last logits axis is taken; only the last entry of that
    axis is kept.

    The model is moved to the first CUDA device when one is available and
    stays on the CPU otherwise.

    :param seq: DNA sequence chunk (string)
    :param model: model exposing ``__call__(tokens, attention_mask=...)`` and
        returning a mapping with a ``"logits"`` entry; for long chunks it must
        also expose ``esm.encoder.layer[*].attention.self.rotary_embeddings``
    :param tokenizer: tokenizer exposing ``batch_encode_plus`` and
        ``pad_token_id``
    :return: NumPy array whose rows follow the second logits axis and whose
        four columns are features 6, 2, 3 and 7 (in that order) of the third
        logits axis
    """
    max_num_dna_tokens = len(seq) // 6

    # For long inputs the rotary embeddings get a length-dependent rescaling
    # factor. The previous factors are remembered and restored afterwards so
    # that a long chunk does not silently change how later, shorter chunks
    # are processed.
    # NOTE(review): the 5001 / 2048 constants presumably come from the model's
    # training setup - confirm against the model card.
    saved_factors = []
    if max_num_dna_tokens + 1 > 5001:
        inference_rescaling_factor = (max_num_dna_tokens + 1) / 2048
        for layer in model.esm.encoder.layer:
            rotary = layer.attention.self.rotary_embeddings
            saved_factors.append((rotary, getattr(rotary, "rescaling_factor", None)))
            rotary.rescaling_factor = inference_rescaling_factor

    try:
        tokens = tokenizer.batch_encode_plus(
            [seq],
            return_tensors="pt",
            padding="max_length",
            max_length=max_num_dna_tokens,
        )["input_ids"]

        # Use the GPU when present, otherwise run on the CPU instead of failing.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model.to(device)
        tokens = tokens.to(device)
        attention_mask = tokens != tokenizer.pad_token_id

        logging.debug(
            "sequence length: %d, tokens shape: %s", len(seq), tuple(tokens.shape)
        )

        with torch.no_grad():
            outs = model(
                tokens,
                attention_mask=attention_mask,
            )
            # Keep only the last entry of the softmax-ed final axis.
            probabilities = torch.nn.functional.softmax(outs["logits"], dim=-1)[..., -1]
    finally:
        for rotary, previous_factor in saved_factors:
            rotary.rescaling_factor = previous_factor

    probabilities = probabilities.cpu().numpy()

    # First (and only) batch element; pick the four feature columns of interest.
    # NOTE(review): indices 6, 2, 3, 7 presumably index model.config.features -
    # confirm which annotations they correspond to.
    return probabilities[0][:, [6, 2, 3, 7]]

def split_info_fields(info_record, list_records):
    """
    Parse one GFF attributes string and append the result to ``list_records``.

    The string is split on ``;`` into ``key=value`` fields. Only the first
    ``=`` separates key from value, so values containing ``=`` are kept whole.
    Empty fields (e.g. the one produced by a trailing ``;``) are skipped, and
    a field without ``=`` is stored with an empty string as its value.

    :param info_record: attributes column of one GFF row
    :param list_records: list that receives the parsed ``dict`` (mutated in place)
    :return: None
    """
    result = {}
    for record in info_record.split(";"):
        if not record.strip():
            # Nothing between two separators / after a trailing ";".
            continue
        key, _, value = record.partition("=")
        result[key] = value

    list_records.append(result)

def prepare_gff(args, chrom=None):
    """
    Load a GFF file and keep the annotation rows of the requested sequences.

    Start coordinates are converted from 1-based to 0-based, a ``lens`` column
    (``end - start``) is added, rows are filtered by feature type and
    ``seqid``, and every ``key=value`` pair of the attributes column becomes
    its own column.

    :param args: object with a ``gff`` attribute holding the GFF file path
    :param chrom: list-like of ``seqid`` values to keep; defaults to
        ``['NC_060944.1']``
    :return: ``(data, id_indexed, chr_len)`` where ``data`` is the filtered
        table, ``id_indexed`` is the same table indexed by the ``ID``
        attribute and ``chr_len`` is the ``lens`` of the first ``region`` row.
        When no row matches, ``(empty DataFrame, empty DataFrame, 0)`` is
        returned so that callers can test ``data.empty``.
    :raises ValueError: if rows were selected but none of them is a ``region``
    """
    if chrom is None:
        chrom = ['NC_060944.1']

    col_names = [
        "seqid",
        "source",
        "type",
        "start",
        "end",
        "score",
        "strand",
        "phase",
        "attributes",
    ]
    feature_types = ['region', 'gene', 'mRNA', 'lnc_RNA', 'exon', 'CDS']

    data_gff = pd.read_csv(args.gff, sep="\t", names=col_names, header=None, comment="#")

    # GFF coordinates are 1-based closed intervals. Shifting only the start
    # turns them into 0-based half-open intervals, so the end is left as is.
    data_gff["start"] = data_gff["start"] - 1
    data_gff['lens'] = data_gff['end'] - data_gff['start']

    selected = data_gff["type"].isin(feature_types) & data_gff["seqid"].isin(chrom)
    # reset_index returns a new frame, so nothing is modified on a slice.
    data = data_gff[selected].reset_index(drop=True)

    if data.empty:
        # Nothing annotated for the requested sequences: report it through an
        # empty table rather than failing on the lookups below.
        return data, pd.DataFrame(), 0

    new_col_list = []
    for attributes in data["attributes"]:
        split_info_fields(attributes, new_col_list)
    new_cols = pd.DataFrame(new_col_list)
    data = pd.concat((data, new_cols), axis=1)

    regions = data[data["type"] == 'region']
    if regions.empty:
        raise ValueError(f"No 'region' record found in {args.gff} for {list(chrom)}")
    chr_len = regions.iloc[0]['lens']

    id_indexed = data.set_index("ID")

    return data, id_indexed, chr_len


def find_segments_ones(array):
    """
    Locate the runs of consecutive ones in ``array``.

    :param array: NumPy array compared element-wise against 1
    :return: list of ``(start, end)`` pairs, one per run, with ``end``
        exclusive; an empty list when there is no 1 at all
    """
    positions = np.where(array == 1)[0]
    if positions.size == 0:
        return []

    # A new run begins wherever two neighbouring hits are more than one apart.
    breaks = np.where(np.diff(positions) > 1)[0] + 1

    first_of_run = np.concatenate(([0], breaks))
    last_of_run = np.concatenate((breaks - 1, [positions.size - 1]))

    run_starts = positions[first_of_run]
    run_ends = positions[last_of_run] + 1

    return list(zip(run_starts, run_ends))


def get_sequence(
    start_d,
    end_d,
    chr_name,
    reference,
):
    """
    Fetch a stretch of DNA and build its reverse complement.

    The sequence is read with ``reference.fetch(reference=chr_name,
    start=start_d, end=end_d)`` and upper-cased.

    :return: ``(forward, reverse_complement)`` strings, or ``('N', 'N')`` as
        soon as the fetched sequence contains anything other than A, C, G, T
    """
    forward = reference.fetch(reference=chr_name, start=start_d, end=end_d).upper()

    # Reject sequences with any symbol outside the four bases.
    if not set(forward).issubset("ACGT"):
        return 'N', 'N'

    # Complement every base in one pass, reading the strand backwards.
    complement_table = str.maketrans("ACGT", "TGCA")
    reverse_complement = forward[::-1].translate(complement_table)

    return forward, reverse_complement

def search_parts(transcript_content, part, transcript_start):
    """
    Collect the intervals of one feature type relative to ``transcript_start``.

    :param transcript_content: table with ``type``, ``start`` and ``end`` columns
    :param part: value of ``type`` to select (e.g. ``'exon'``)
    :param transcript_start: offset subtracted from every start and end
    :return: list of ``(start, end)`` pairs, in row order
    """
    is_requested = transcript_content["type"] == part
    selected_rows = transcript_content[is_requested]

    relative_starts = np.array(selected_rows['start'].tolist()) - transcript_start
    relative_ends = np.array(selected_rows['end'].tolist()) - transcript_start

    return list(zip(relative_starts, relative_ends))

def process_seq_prediction(transcript_seq, model, tokenizer, model_len, mode='forward'):
    """
    Predict a sequence chunk by chunk and stitch the results together.

    The sequence is cut into consecutive pieces of at most ``model_len``
    characters; each piece goes through ``prediction`` and the outputs are
    concatenated along the first axis.

    :param mode: with ``'reverse'`` the concatenated rows are returned in
        reversed order; any other value leaves them untouched
    :return: concatenated prediction array
    """
    assert transcript_seq, "Input transcript sequence is empty"

    per_chunk = []
    for offset in range(0, len(transcript_seq), model_len):
        chunk = transcript_seq[offset:offset + model_len]
        per_chunk.append(prediction(chunk, model, tokenizer))

    p_labels = np.concatenate(per_chunk, axis=0)

    return p_labels[::-1] if mode == 'reverse' else p_labels
            
def get_h5_data(seq_info, transcript_chr, ref, chr_len, model, tokenizer, model_len, shift_ir):
    """
    Fetch the sequence around a feature and compute its forward predictions.

    The window spans ``seq_info["start"] - shift_ir`` to
    ``seq_info["end"] + shift_ir``, clipped to ``[0, chr_len]``, and its end
    is pulled back so that the window length is a multiple of 24.

    :return: ``(forward_predictions, reverse_predictions, sequence, coordinates)``.
        ``reverse_predictions`` is always an empty array (the reverse strand
        is not predicted here). ``coordinates`` is
        ``[window_start, window_end, feature_start, feature_end]``. When the
        fetched sequence is rejected by ``get_sequence`` the result is
        ``(empty array, empty array, 'N', empty array)``.
    """
    feature_start = seq_info["start"]
    feature_end = seq_info["end"]

    window_start = max(0, feature_start - shift_ir)
    window_end = min(chr_len, feature_end + shift_ir)

    # Drop the trailing bases that do not fill a complete block of 24.
    # NOTE(review): 24 is presumably 6 bases per token times a token-count
    # multiple required by the model - confirm.
    window_end = window_end - (window_end - window_start) % 24
    assert (window_end - window_start) % 24 ==0 , "Not divisible by 24 seq lenght"

    coordinates = [window_start, window_end, feature_start, feature_end]

    seq_forward, _ = get_sequence(window_start, window_end, transcript_chr, ref)
    if seq_forward == 'N':
        return np.array([]), np.array([]), 'N', np.array([])

    print(len(seq_forward))
    forward_rows = process_seq_prediction(seq_forward, model, tokenizer, model_len, mode='forward')
    # Transpose so that the sequence position runs along the second axis.
    p_labels_forward = forward_rows.T

    assert p_labels_forward.shape[1] == len(seq_forward)

    return p_labels_forward, np.array([]), seq_forward, coordinates
    
def make_targets(transcript_content, transcript_class, transcript_seq, start_for_tokenize_d, end_for_tokenize_d):
    """
    Build the 0/1 target matrix for one transcript.

    Row 0 marks exon positions and row 1 marks CDS positions; the CDS row is
    only filled for transcripts of class ``'mRNA'`` and stays all-zero
    otherwise. Intervals are taken from ``search_parts`` relative to
    ``start_for_tokenize_d``. ``end_for_tokenize_d`` is not used.

    :return: array of shape ``(2, len(transcript_seq))``
    """
    all_labels = ['exon', 'CDS']
    active_labels = all_labels if transcript_class == 'mRNA' else ['exon']

    intervals_per_label = [
        search_parts(transcript_content, label, start_for_tokenize_d)
        for label in active_labels
    ]

    targets = np.zeros((len(all_labels), len(transcript_seq)))
    for row, intervals in enumerate(intervals_per_label):
        for st, ed in intervals:
            targets[row, st : ed] = 1

    return targets

def process_transcript(transcript, data, id_indexed, chr_len, model, model_len, tokenizer, ref, output_path, h5_file, index_transcript, shift_ir=0):
    """
    Predict one transcript and its parent gene and append them to the HDF5 file.

    A group ``transcript_<n>`` (``n`` taken from ``index_transcript['index']``)
    is created in ``{output_path}/{h5_file}.hdf5`` with the transcript
    metadata as attributes. Unless the transcript or gene sequence came back
    as ``'N'``, the coordinates, predictions and targets are stored as
    datasets as well. The counter in ``index_transcript`` is incremented in
    either case.

    :param transcript: ID of the transcript, looked up in ``id_indexed``
    :param data: annotation table with a ``Parent`` column
    :param id_indexed: the same annotation indexed by ``ID``
    :param index_transcript: mutable mapping holding the running ``'index'``
    """
    logging.debug("------------new transcript start------------")

    transcript_content = data[data['Parent'] == transcript]
    transcript_info = id_indexed.loc[transcript]

    gene = transcript_info["Parent"]
    gene_info = id_indexed.loc[gene]

    transcript_strand = transcript_info["strand"]
    transcript_chr = transcript_info["seqid"]
    transcript_class = transcript_info["type"]

    assert transcript_strand in ("+", "-"), "Not identifited strand"

    logging.debug(transcript)
    print('transcript_pred')
    shared_args = (transcript_chr, ref, chr_len, model, tokenizer, model_len, shift_ir)
    p_labels_forward, p_labels_reverse, transcript_seq, coordinates = get_h5_data(transcript_info, *shared_args)
    gene_p_labels_forward, gene_p_labels_reverse, gene_seq, gene_coordinates = get_h5_data(gene_info, *shared_args)

    # 'N' marks a sequence that could not be used; only metadata is stored then.
    metadata_only = transcript_seq == 'N' or gene_seq == 'N'

    targets = None
    if not metadata_only:
        targets = make_targets(transcript_content, transcript_class, transcript_seq, coordinates[0], coordinates[1])

    group_attributes = {
        'chromosome': transcript_chr,
        'transcript_seq': transcript_seq,
        'gene_seq': gene_seq,
        'strand': transcript_strand,
        'type': transcript_class,
        'Parent': gene,
        'ID': transcript,
    }

    with h5py.File(f"{output_path}/{h5_file}.hdf5", "a") as file:
        group = file.create_group(f"transcript_{index_transcript['index']}")
        for attribute_name, attribute_value in group_attributes.items():
            group.attrs[attribute_name] = attribute_value

        if not metadata_only:
            group.create_dataset("transcript_coordinates", data=coordinates)
            group.create_dataset("transcript_predictions_forward", data=p_labels_forward, compression="gzip")
            group.create_dataset("transcript_predictions_reverse", data=p_labels_reverse, compression="gzip")
            group.create_dataset("targets", data=targets, compression="gzip")

            group.create_dataset("gene_coordinates", data=gene_coordinates)
            group.create_dataset("gene_predictions_forward", data=gene_p_labels_forward, compression="gzip")
            group.create_dataset("gene_predictions_reverse", data=gene_p_labels_reverse, compression="gzip")

    index_transcript['index'] += 1

def main_process(args):
    """
    Run the whole pipeline: load the model, read the inputs, process transcripts.

    Every transcript ID from ``args.transcripts_list`` (first column of a
    header-less CSV) is processed against the annotation of the chromosome
    named by ``args.chrs`` and appended to the HDF5 output.

    :param args: parsed command-line arguments; the attributes read here are
        ``model``, ``model_len``, ``fasta``, ``transcripts_list``, ``chrs``,
        ``output`` and ``h5`` (``gff`` is read by ``prepare_gff``)
    """
    model, tokenizer = model_setup(args.model)
    model_len = int(args.model_len)

    ref = pysam.Fastafile(args.fasta)
    try:
        transcripts_list = pd.read_csv(args.transcripts_list, header=None)[0].tolist()

        # Running group counter shared by all transcripts; starts at 0.
        index_transcript = defaultdict(int)

        # A single chromosome name is taken straight from the command line.
        selected_chrs = [args.chrs]
        for selected_chromosome in selected_chrs:
            logging.info("selected_chromosome = %s", selected_chromosome)

            data, id_indexed, chr_len = prepare_gff(args, chrom=[selected_chromosome])
            if data.empty:
                continue

            for transcript in tqdm(transcripts_list):
                process_transcript(transcript, data, id_indexed, chr_len, model,
                                   model_len, tokenizer, ref, args.output, args.h5, index_transcript)
    finally:
        # Release the FASTA handle even when processing fails part-way.
        ref.close()

                   
# Command-line interface. The parser is built at import time, but arguments are
# parsed and the pipeline is started only when the file is run as a script.
parser = argparse.ArgumentParser()
# The value is used as-is as a single GFF seqid (it is not read as a file).
parser.add_argument(
    "--chrs",
    type=str,
    default="all",
    help="Name of the chromosome to process (must match a seqid in the GFF file)",
)
parser.add_argument("--gff", type=str, required=True, help="Path to GFF file")
parser.add_argument("--fasta", type=str, required=True, help="Path to FASTA file")

parser.add_argument("--model_len", type=int, default=49992, help="Lenght of input sequence")

parser.add_argument("--transcripts_list", type=str, required=True, help="Path to list of transcripts file")
parser.add_argument("--output", type=str, required=True, help="Path to output")
parser.add_argument("--h5", type=str, required=True, help="Name of hdf5 file")
parser.add_argument(
    "--model",
    type=str,
    default="InstaDeepAI/segment_nt",
)

if __name__ == "__main__":
    args = parser.parse_args()
    main_process(args)