#!/usr/bin/env python
# -*- coding: utf-8 -*-
import h5py
import argparse
from tqdm import tqdm
import numpy as np
import pandas as pd
import pysam
import logging
import torch
from transformers import AutoTokenizer, AutoConfig
from collections import defaultdict

def split_info_fields(info_record, list_records):
    """Parse one GFF ``attributes`` string and append the result to a list.

    The string is split on ``;`` into ``key=value`` fragments, which are
    collected into a dict. Only the first ``=`` of a fragment separates the
    key from the value, so values that themselves contain ``=`` are kept
    intact. Empty fragments (e.g. produced by a trailing ``;``) are skipped,
    and a fragment without ``=`` is stored with an empty-string value.

    Args:
        info_record: Attributes string, e.g. ``"ID=gene1;Name=abc"``.
        list_records: List that receives the parsed dict (mutated in place).

    Returns:
        None. The result is delivered through ``list_records``.
    """
    result = {}
    for record in info_record.split(";"):
        # A trailing ";" yields an empty fragment; there is nothing to parse.
        if not record.strip():
            continue
        # partition() splits on the first "=" only, preserving "=" in values.
        key, _, value = record.partition("=")
        result[key] = value

    list_records.append(result)

def find_segments_ones(array):
    """Locate runs of consecutive elements equal to 1.

    Args:
        array: NumPy array compared element-wise against 1.

    Returns:
        A list of ``(start, stop)`` index pairs, one per run, where ``start``
        is the index of the first 1 in the run and ``stop`` is one past the
        index of the last 1. An empty list when no element equals 1.
    """
    positions = np.where(array == 1)[0]
    if not positions.size:
        return []

    # A gap sits between positions[i] and positions[i + 1] whenever they are
    # not adjacent; every gap closes one run and opens the next.
    gaps = np.where(np.diff(positions) > 1)[0]
    first_of_run = np.concatenate(([0], gaps + 1))
    last_of_run = np.concatenate((gaps, [positions.size - 1]))

    return list(zip(positions[first_of_run], positions[last_of_run] + 1))

def prepare_gff(gff, chrom=('NC_060944.1',), mode='full'):
    """Load a GFF file into a DataFrame and list the genes to evaluate.

    Start coordinates are converted from 1-based to 0-based while end
    coordinates are left untouched, so every ``[start, end)`` pair can be
    treated as a half-open interval. A ``lens`` column (``end - start``) is
    added.

    Args:
        gff: Path or file-like object accepted by ``pandas.read_csv``.
        chrom: Sequence IDs to keep in ``'full'`` mode. May be a single
            string or any iterable of strings. ``None``, or a collection
            containing the string ``"all"``, disables the seqid filter.
        mode: ``'short'`` returns the whole table without attribute parsing;
            any other value runs the full pipeline.

    Returns:
        ``(data, gene_list)``.

        * ``'short'`` mode: the unfiltered table and the ``seqid`` value of
          every row whose type is ``gene``.
        * otherwise: rows of type gene/mRNA/lnc_RNA/exon/CDS on the selected
          sequences, with one extra column per attribute key, and the ``ID``
          of every gene whose ``gene_biotype`` is ``protein_coding`` or
          ``lncRNA``.

    Raises:
        ValueError: In full mode, when no record survives the filters.
    """
    col_names = [
        "seqid",
        "source",
        "type",
        "start",
        "end",
        "score",
        "strand",
        "phase",
        "attributes",
    ]

    data_gff = pd.read_csv(gff, sep="\t", names=col_names, header=None, comment="#")
    # GFF coordinates are 1-based and closed. Shifting only the start turns
    # each record into a 0-based half-open interval, so the end must NOT be
    # decremented.
    data_gff["start"] = data_gff["start"] - 1
    data_gff["lens"] = data_gff["end"] - data_gff["start"]

    if mode == 'short':
        gene_list = data_gff[data_gff['type'] == 'gene']['seqid'].to_list()
        return data_gff, gene_list

    # Accept a bare sequence name as well as a collection of names.
    if isinstance(chrom, str):
        chrom = [chrom]

    keep = data_gff["type"].isin(['gene', 'mRNA', 'lnc_RNA', 'exon', 'CDS'])
    # "all" is the command-line default of --chrs and means "no seqid filter".
    if chrom is not None and "all" not in chrom:
        keep = keep & data_gff["seqid"].isin(list(chrom))

    # reset_index() returns a fresh frame; this avoids modifying a slice of
    # data_gff in place.
    data = data_gff[keep].reset_index(drop=True)
    if data.empty:
        raise ValueError(
            f"No gene/mRNA/lnc_RNA/exon/CDS records found for sequence(s) {chrom!r}"
        )

    # One dict of parsed attributes per row, in row order.
    new_col_list = []
    for attributes in data["attributes"]:
        split_info_fields(attributes, new_col_list)

    # Both frames carry a 0..n-1 index, so concat aligns them row by row.
    new_cols = pd.DataFrame(new_col_list)
    data = pd.concat((data, new_cols), axis=1)

    is_gene = data['type'] == 'gene'
    wanted_biotype = data['gene_biotype'].isin(['protein_coding', 'lncRNA'])
    gene_list = data[is_gene & wanted_biotype]['ID'].to_list()

    return data, gene_list

def process_exons(exons, transcript_coor):
    """Clip intervals to a transcript and drop those outside of it.

    All coordinates are treated as half-open ``[start, end)`` intervals, so
    an interval that merely touches the transcript boundary (it ends where
    the transcript starts, or starts where the transcript ends) shares no
    position with it and is discarded rather than being clipped to a
    zero-length interval.

    Args:
        exons: Iterable of ``(start, end)`` pairs.
        transcript_coor: Two-element sequence ``[start, end]`` of the
            transcript.

    Returns:
        List of ``(start, end)`` tuples, in input order, restricted to the
        transcript span.
    """
    transcript_start, transcript_end = transcript_coor[0], transcript_coor[1]

    processed_exons = []
    for exon_start, exon_end in exons:
        # Half-open intervals: equality at either edge means no overlap.
        if exon_start >= transcript_end or exon_end <= transcript_start:
            continue

        processed_exons.append(
            (max(exon_start, transcript_start), min(exon_end, transcript_end))
        )

    return processed_exons



def search_parts(transcript_content, part, transcript_coor, gene_start, mode='transcript'):
    """Collect the ``(start, end)`` pairs of all records of a given type.

    Args:
        transcript_content: DataFrame with ``type``, ``start`` and ``end``
            columns.
        part: Value of the ``type`` column to select (e.g. ``"exon"``).
        transcript_coor: ``[start, end]`` of the transcript; only used when
            ``mode == 'gene'``.
        gene_start: Offset added to every coordinate when ``mode == 'gene'``.
        mode: ``'gene'`` shifts the coordinates by ``gene_start`` and passes
            them through ``process_exons``; any other value returns the
            coordinates unchanged.

    Returns:
        List of ``(start, end)`` tuples.
    """
    selected = transcript_content[transcript_content["type"] == part]

    starts = np.asarray(selected['start'].tolist())
    ends = np.asarray(selected['end'].tolist())

    if mode != 'gene':
        return list(zip(starts, ends))

    # Shift by the gene start, then clip to the transcript span.
    shifted = list(zip(starts + gene_start, ends + gene_start))
    return process_exons(shifted, transcript_coor)

def check_transcript(transcript, pred_content, gene_start, gff_indexed, part, labels, metrics, gene_name):
    """Compare the predicted ``part`` intervals with one annotated transcript.

    Args:
        transcript: ID of the transcript; must be a label of ``gff_indexed``.
        pred_content: DataFrame with the predicted records of the gene.
        gene_start: Offset added to predicted coordinates before comparison.
        gff_indexed: Annotation DataFrame indexed by ``ID``.
        part: Record type to compare (e.g. ``"exon"`` or ``"CDS"``).
        labels: Not used by this function.
        metrics: Mapping from ``"<transcript type>_<part>"`` to a list; the
            gene name is appended on a match.
        gene_name: Name recorded in ``metrics`` on a match.

    Returns:
        1 when the set of predicted intervals equals the set of annotated
        intervals, or when ``part`` is ``"CDS"`` and the transcript type is
        not ``"mRNA"`` (nothing is recorded in that case); 0 otherwise.
    """
    children = gff_indexed[gff_indexed['Parent'] == transcript]
    record = gff_indexed.loc[transcript]
    transcript_class = record['type']

    # CDS comparison is only carried out for mRNA transcripts.
    if transcript_class != "mRNA" and part == 'CDS':
        return 1

    span = [record['start'], record['end']]

    annotated = search_parts(children, part, [], gene_start)
    predicted = search_parts(pred_content, part, span, gene_start, mode='gene')

    if set(predicted) != set(annotated):
        return 0

    metrics[f"{transcript_class}_{part}"].append(gene_name)
    return 1


def process_gene(name, data, labels, transcript_level_labels, gff, gff_indexed, metrics, n_genes, data_names):
    """Evaluate the predictions for one gene against its annotated transcripts.

    For every label in ``transcript_level_labels`` the transcripts of the
    gene are checked in order with ``check_transcript``; checking stops at
    the first transcript for which it returns 1.

    Args:
        name: Gene ID; must be a label of ``gff_indexed``.
        data: DataFrame of predictions with ``seqid`` and ``strand`` columns.
        labels: Forwarded to ``check_transcript``.
        transcript_level_labels: Record types to compare.
        gff: Annotation DataFrame with ``Parent`` and ``ID`` columns.
        gff_indexed: The annotation indexed by ``ID``.
        metrics: Forwarded to ``check_transcript``, which records matches.
        n_genes: Not used by this function.
        data_names: Container of gene names that have predictions; genes
            missing from it are skipped.

    Returns:
        None.
    """
    logging.debug("------------new transcript start------------")

    if name not in data_names:
        return

    gene_row = gff_indexed.loc[name]
    strand = gene_row['strand']
    gene_start = gene_row['start']

    # Predictions are looked up by gene name (stored in seqid) on the strand
    # of the annotated gene.
    on_gene = data['seqid'] == name
    on_strand = data['strand'] == strand
    pred_content = data[on_gene & on_strand]

    transcript_ids = gff.loc[gff['Parent'] == name, 'ID'].to_list()

    for part in transcript_level_labels:
        for transcript_id in transcript_ids:
            matched = check_transcript(
                transcript_id, pred_content, gene_start,
                gff_indexed, part, labels, metrics, name,
            )
            if matched == 1:
                break


def main_process(args):
    """Compare predicted gene structures with a reference GFF and save results.

    Reads the predictions (``args.input``) and the reference annotation
    (``args.gff``, restricted to ``args.chrs``), runs ``process_gene`` for
    every selected reference gene and appends the results to two files:

    * ``<args.output>_genes.txt``: one tab-separated line
      ``name<TAB>transcript type<TAB>part`` per recorded match, for metric
      keys containing ``"RNA"``.
    * ``<args.output>_count.txt``: one ``"<key> <count>"`` line per metric
      key, followed by the ``n_genes`` mapping.

    Both files are opened in append mode, so existing content is kept.

    Args:
        args: Namespace with ``input``, ``gff``, ``chrs`` and ``output``.
    """
    labels = ["5UTR", "exon", "intron", "3UTR", "CDS"]
    transcript_level_labels = ["exon", "CDS"]

    data, data_names = prepare_gff(args.input, mode='short')
    # process_gene only tests membership, once per reference gene; a set
    # makes each test O(1) instead of scanning a list.
    data_names = set(data_names)

    metrics = defaultdict(list)
    # NOTE(review): nothing in the visible code updates n_genes, so it is
    # written to the count file empty — confirm whether it is still needed.
    n_genes = defaultdict(int)

    selected_chrs = [args.chrs]

    gff, gene_list = prepare_gff(args.gff, chrom=selected_chrs)
    gff_indexed = gff.set_index("ID")

    for name in tqdm(gene_list):
        process_gene(name, data, labels, transcript_level_labels, gff, gff_indexed, metrics, n_genes, data_names)

    with open(f"{args.output}_genes.txt", 'a') as metrics_file:
        for genes_type, matched_names in metrics.items():
            if "RNA" not in genes_type:
                continue
            # Keys look like "<transcript type>_<part>"; the transcript type
            # may itself contain "_" (e.g. "lnc_RNA"), hence the right split.
            transcript_class, part = genes_type.rsplit('_', maxsplit=1)
            for matched_name in matched_names:
                metrics_file.write(f"{matched_name}\t{transcript_class}\t{part}\n")

    with open(f"{args.output}_count.txt", 'a') as metrics_file:
        for genes_type, matched_names in metrics.items():
            metrics_file.write(f"{genes_type} {len(matched_names)}\n")
        metrics_file.write(f"{n_genes}\n")


# Command-line interface. The parser is built at import time (no side
# effects); arguments are parsed and the evaluation is run only when the file
# is executed as a script.
parser = argparse.ArgumentParser(
    description="Compare predicted gene structures with a reference GFF annotation."
)
parser.add_argument("--input", type=str, required=True, help="Path to the GFF file with predictions")
parser.add_argument("--gff", type=str, required=True, help="Path to the reference GFF annotation file")
parser.add_argument("--chrs", type=str, default="all", help="Sequence ID (chromosome) of the reference annotation to evaluate")
#parser.add_argument("--gene_lens", type=str, help="Type of predictions")
parser.add_argument("--output", type=str, help="Prefix of the output files (<output>_genes.txt and <output>_count.txt)")
parser.add_argument(
    "--tokenizer_name",
    type=str,
    default="AIRI-Institute/gena-lm-bigbird-base-t2t",
    help="tokenizer",
)

if __name__ == "__main__":
    args = parser.parse_args()
    main_process(args)