#!/usr/bin/env python
# -*- coding: utf-8 -*-
import h5py
import argparse
from tqdm import tqdm
import numpy as np
import pandas as pd
import pysam
import logging
from transformers import AutoTokenizer, AutoConfig
from collections import defaultdict

# Timestamped, INFO-level logging for the whole script.
_LOG_FORMAT = "%(asctime)s %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
# Never summarize arrays when printing them: always show every element.
np.set_printoptions(threshold=np.inf)

def search_parts(transcript_content, part, transcript_start, pred_end):
    """Collect the intervals of one feature type relative to the transcript start.

    Rows of ``transcript_content`` whose ``type`` equals ``part`` are selected.
    Each row's ``end`` is clipped to ``pred_end``, and both coordinates are then
    shifted by ``-transcript_start``.

    Returns:
        set of ``(start, end)`` tuples in transcript-relative coordinates.
    """
    selected = transcript_content[transcript_content["type"] == part]

    rel_starts = np.array(selected["start"].tolist()) - transcript_start

    # Ends reaching past pred_end are clipped to it before shifting.
    # NOTE(review): starts are not clipped, so a feature lying entirely at or
    # beyond pred_end yields an interval with start >= end - confirm intended.
    abs_ends = np.array(selected["end"].tolist())
    abs_ends = np.where(abs_ends > pred_end, pred_end, abs_ends)
    rel_ends = abs_ends - transcript_start

    return set(zip(rel_starts, rel_ends))
  
                    
def split_info_fields(info_record, list_records):
    """Parse a GFF3 attribute string into a dict and append it to ``list_records``.

    ``info_record`` holds ``key=value`` pairs separated by ``;``, for example
    ``"ID=gene-1;Name=ABC"``. The resulting ``{key: value}`` dict is appended
    to ``list_records``, which is mutated in place; nothing is returned.

    Empty or blank fields, such as the one produced by a trailing ``;``, are
    skipped. Each field is split on its first ``=`` only, so a value that itself
    contains ``=`` is kept whole. A field with no ``=`` is stored with an
    empty-string value.
    """
    result = {}
    for record in info_record.split(";"):
        if not record.strip():
            continue
        key, _, value = record.partition("=")
        result[key] = value

    list_records.append(result)
                   
def get_sequence(
    start_d,
    end_d,
    chr_name,
    reference,
):
    """
    Fetch a stretch of DNA and its reverse complement as upper-case strings.

    The raw sequence comes from
    ``reference.fetch(reference=chr_name, start=start_d, end=end_d)``.
    If, after upper-casing, it contains any character other than A/C/G/T,
    the pair ``('N', 'N')`` is returned instead.
    """
    forward = reference.fetch(reference=chr_name, start=start_d, end=end_d).upper()

    if not set(forward) <= set('ACTG'):
        return 'N', 'N'

    # Complement each base in one pass, then reverse.
    complement = forward.translate(str.maketrans('ACGT', 'TGCA'))
    return forward, complement[::-1]

def find_segments_ones(array):
    """Return the runs of consecutive 1s in ``array`` as half-open index pairs.

    Each run is reported as ``(first_index, last_index + 1)``. An array
    containing no 1s gives an empty list.
    """
    positions = np.nonzero(array == 1)[0]
    if not positions.size:
        return []

    # A jump of more than one between neighbouring positions starts a new run.
    run_starts = np.nonzero(np.diff(positions) > 1)[0] + 1

    segments = []
    for run in np.split(positions, run_starts):
        segments.append((run[0], run[-1] + 1))
    return segments

def prepare_gff(args, chrom=('NC_060944.1',)):
    """Load a GFF3 file and keep the gene/transcript/exon/CDS rows.

    Args:
        args: namespace whose ``gff`` attribute is the path to the GFF3 file.
        chrom: sequence IDs (GFF column 1) to keep. A single string is treated
            as one ID. ``None``, or a collection containing ``"all"``, keeps
            every sequence. The default keeps only ``NC_060944.1``.

    Returns:
        DataFrame with the nine GFF columns and ``start`` shifted to 0-based,
        so ``[start, end)`` is a half-open interval. It also has a ``lens``
        column (``end - start``) and one extra column per attribute key parsed
        from ``attributes`` (e.g. ``ID``, ``Parent``). Only rows of type gene,
        mRNA, lnc_RNA, exon or CDS are kept, and the index is ``0..n-1``.
    """
    col_names = [
        "seqid",
        "source",
        "type",
        "start",
        "end",
        "score",
        "strand",
        "phase",
        "attributes",
    ]

    # NOTE: comment="#" drops header lines, but it also truncates any line at a
    # literal '#' inside a field.
    data_gff = pd.read_csv(args.gff, sep="\t", names=col_names, header=None, comment="#")
    # GFF coordinates are 1-based and closed. Shifting only the start to 0-based
    # lets [start, end) be read as a half-open interval over the same bases.
    data_gff["start"] = data_gff["start"] - 1
    data_gff["lens"] = data_gff["end"] - data_gff["start"]

    if isinstance(chrom, str):
        chrom = [chrom]

    keep = data_gff["type"].isin(["gene", "mRNA", "lnc_RNA", "exon", "CDS"])
    # "all" is the command-line default for --chrs, so treat it as "no filter".
    if chrom is not None and "all" not in chrom:
        keep &= data_gff["seqid"].isin(chrom)

    # reset_index (not inplace) returns a fresh frame rather than modifying a
    # slice of data_gff.
    data = data_gff[keep].reset_index(drop=True)

    new_col_list = []
    for attributes in data["attributes"]:
        split_info_fields(attributes, new_col_list)
    new_cols = pd.DataFrame(new_col_list)

    # Both frames are indexed 0..n-1, so rows line up column-wise.
    return pd.concat((data, new_cols), axis=1)

def max_pred(p_labels, part, labels):
    """Turn per-label scores into a 0/1 mask for ``part``.

    Args:
        p_labels: NumPy array with one row per entry of ``labels``. Rows are
            picked by fancy indexing and the maximum is taken over rows.
        part: ``'exon'`` compares the exon row with the intron row. ``'CDS'``
            compares the exon row with the intron, 5UTR and 3UTR rows.
        labels: list of label names giving the row order of ``p_labels``.

    Returns:
        float array with one entry per column: 1.0 where the exon row equals
        the maximum of the compared rows (ties count as 1.0), otherwise 0.0.

    Raises:
        ValueError: if ``part`` is neither ``'exon'`` nor ``'CDS'``.
    """
    competing = {
        "exon": ("exon", "intron"),
        "CDS": ("exon", "intron", "5UTR", "3UTR"),
    }
    if part not in competing:
        raise ValueError(f"Unsupported part {part!r}; expected one of {sorted(competing)}")

    chosen_p_labels = p_labels[[labels.index(name) for name in competing[part]]]

    max_p_labels = np.max(chosen_p_labels, axis=0)
    preds = np.zeros(max_p_labels.shape)
    # Row 0 of chosen_p_labels is always the exon row.
    preds[chosen_p_labels[0] == max_p_labels] = 1

    return preds
    
def check_transcript(transcript, p_labels, gene_coordinates, gff_indexed, part, labels, metrics, gene_name):
    """Check whether the predicted ``part`` segments reproduce one transcript exactly.

    The transcript's child rows (``Parent == transcript``) of type ``part`` are
    turned into transcript-relative intervals by ``search_parts``, with ends
    clipped to ``gene_coordinates[1]``. These are compared, as a set, with the
    runs of 1s that ``max_pred`` yields on the ``p_labels`` columns covering
    the transcript. Column 0 of ``p_labels`` corresponds to
    ``gene_coordinates[0]``.

    On an exact match, ``gene_name`` is appended to
    ``metrics[f"{transcript_class}_{part}"]``.

    Returns:
        1 if the interval sets match. Also 1 when ``part`` is ``'CDS'`` and the
        transcript is not an mRNA (no comparison is made). 0 otherwise.
    """
    children = gff_indexed[gff_indexed['Parent'] == transcript]
    record = gff_indexed.loc[transcript]
    transcript_class = record['type']

    # CDS is only evaluated for mRNA transcripts.
    if transcript_class != "mRNA" and part == 'CDS':
        return 1

    transcript_start = record['start']
    transcript_end = record['end']

    part_set = search_parts(children, part, transcript_start, gene_coordinates[1])

    # Prediction columns spanning [transcript_start, transcript_end).
    first_col = transcript_start - gene_coordinates[0]
    last_col = first_col + transcript_end - transcript_start
    preds = max_pred(p_labels[:, first_col:last_col], part, labels)

    if set(find_segments_ones(preds)) != part_set:
        return 0

    metrics[f"{transcript_class}_{part}"].append(gene_name)
    return 1


def process_gene(idx, data, labels, transcript_level_labels, gff, gff_indexed, direction, metrics, n_genes):
    """
    Score the predictions stored for one HDF5 sample against its gene's transcripts.

    Reads group ``f"transcript_{idx}"`` of ``data``:
    * attributes ``gene_seq``, ``transcript_seq`` and ``Parent`` (gene ID);
    * datasets ``gene_predictions_forward`` and ``gene_coordinates``.

    Samples whose stored gene or transcript sequence is ``'N'`` are counted in
    ``n_genes['N in gene']`` and skipped. Otherwise, for each part in
    ``transcript_level_labels``, the gene's transcripts (``gff`` rows with
    ``Parent == gene``) are passed to ``check_transcript`` until one returns 1.

    Raises:
        ValueError: if ``direction`` is not ``'forward'``, the only direction
            whose predictions are read here.
    """
    # Reject unsupported directions up front. Any other value would otherwise
    # leave the prediction array unset and fail obscurely later.
    if direction != 'forward':
        raise ValueError(
            f"Unsupported direction {direction!r}; only 'forward' is implemented"
        )

    logging.debug("------------new transcript start------------")

    sample = data[f"transcript_{idx}"]
    gene_seq = sample.attrs['gene_seq']
    transcript_seq = sample.attrs['transcript_seq']

    if gene_seq == 'N' or transcript_seq == 'N':
        n_genes['N in gene'] += 1
        return

    p_labels = np.array(sample["gene_predictions_forward"])
    gene_coordinates = np.array(sample["gene_coordinates"])
    gene_name = sample.attrs['Parent']

    transcripts_list = gff[gff['Parent'] == gene_name]['ID'].to_list()

    for part in transcript_level_labels:
        for transcript in transcripts_list:
            status = check_transcript(transcript, p_labels, gene_coordinates,
                                      gff_indexed, part, labels, metrics, gene_name)
            # One matching transcript is enough for this gene and part.
            if status == 1:
                break

    

def main_process(args):
    """Evaluate every sample of the HDF5 predictions file against the GFF annotation.

    Two result files are written, both opened in append mode:
    * ``{args.output}_genes.txt``: one ``gene<TAB>transcript_class<TAB>part``
      line per exact match, for metric keys containing ``"RNA"``;
    * ``{args.output}_count.txt``: one ``<metric key> <count>`` line per
      metric, followed by the ``n_genes`` counters.
    """
    direction = str(args.direction)

    # NOTE(review): assumed to match the row order of the stored prediction
    # arrays - confirm against the model output.
    labels = ["5UTR", "exon", "intron", "3UTR"]
    # Feature types compared per transcript.
    transcript_level_labels = ["exon"]

    metrics = defaultdict(list)
    n_genes = defaultdict(int)

    # The context manager guarantees the HDF5 file is closed.
    with h5py.File(args.input, "r") as data:
        # Samples are addressed as transcript_0 .. transcript_{n-1}.
        n_samples = len(data)

        # --chrs carries a single sequence ID.
        selected_chrs = [args.chrs]
        gff = prepare_gff(args, selected_chrs)
        gff_indexed = gff.set_index("ID")

        for idx in tqdm(range(n_samples)):
            process_gene(idx, data, labels, transcript_level_labels, gff, gff_indexed,
                         direction, metrics, n_genes)

    with open(f"{args.output}_genes.txt", 'a', encoding="utf-8") as genes_file:
        for genes_type, gene_names in metrics.items():
            if "RNA" not in genes_type:
                continue
            transcript_class, part = genes_type.rsplit('_', maxsplit=1)
            for gene in gene_names:
                genes_file.write(f"{gene}\t{transcript_class}\t{part}\n")

    with open(f"{args.output}_count.txt", 'a', encoding="utf-8") as count_file:
        for genes_type, gene_names in metrics.items():
            count_file.write(f"{genes_type} {len(gene_names)}\n")
        count_file.write(f"{n_genes}\n")



def build_parser():
    """Return the command-line parser for this evaluation script."""
    parser = argparse.ArgumentParser(
        description="Compare transcript-part predictions stored in HDF5 with a GFF3 annotation."
    )
    parser.add_argument("--input", type=str, required=True,
                        help="Path to input HDF5 file with predictions")
    parser.add_argument("--gff", type=str, required=True,
                        help="Path to input GFF3 annotation file")
    parser.add_argument("--chrs", type=str, default="all",
                        help="Sequence ID (GFF column 1) to evaluate")
    parser.add_argument("--direction", type=str, default="forward",
                        help="Which predictions to use (only 'forward' is read)")
    parser.add_argument("--output", type=str, required=True,
                        help="Output path prefix; results are appended to "
                             "<prefix>_genes.txt and <prefix>_count.txt")
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default="AIRI-Institute/gena-lm-bigbird-base-t2t",
        help="tokenizer",
    )
    return parser


# Guarded so that importing this module does not parse argv or run the pipeline.
if __name__ == "__main__":
    args = build_parser().parse_args()
    main_process(args)