#!/usr/bin/env python
# -*- coding: utf-8 -*-
import h5py
import argparse
from tqdm import tqdm
import numpy as np
import pandas as pd
import pysam
import logging
import torch
from transformers import AutoTokenizer, AutoConfig
from collections import defaultdict

logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
np.set_printoptions(threshold=np.inf)

def split_info_fields(info_record, list_records):
    """Parse one ``key=value;key=value`` string and append the result.

    Args:
        info_record: Semicolon-separated ``key=value`` pairs (one value of
            the GFF ``attributes`` column).
        list_records: List that receives the parsed ``dict``; it is mutated
            in place.

    Returns:
        None. The parsed mapping is appended to *list_records*.
    """
    parsed = {}
    for field in info_record.split(";"):
        if not field.strip():
            # A trailing or doubled ';' yields an empty field; there is
            # nothing to record and indexing into it used to raise.
            continue
        # Split on the first '=' only so values containing '=' stay intact.
        key, _, value = field.partition("=")
        parsed[key] = value

    list_records.append(parsed)

def find_segments_ones(array):
    """Locate maximal runs of consecutive positions equal to 1.

    Args:
        array: Array compared element-wise against 1.

    Returns:
        A list of ``(start, stop)`` pairs, one per run, where ``start`` is the
        index of the first 1 in the run and ``stop`` is one past the last.
        An empty list when no element equals 1.
    """
    hits = np.where(array == 1)[0]
    if hits.size == 0:
        return []

    # Positions inside `hits` where the next index is not adjacent to the
    # previous one, i.e. where a new run begins.
    cut_points = np.flatnonzero(np.diff(hits) > 1) + 1

    run_starts = np.concatenate(([0], cut_points))
    run_stops = np.concatenate((cut_points, [hits.size]))

    return [(hits[lo], hits[hi - 1] + 1) for lo, hi in zip(run_starts, run_stops)]

def prepare_gff(gff, chrom=('NC_060944.1',), mode='full'):
    """Load a GFF table and return it together with a list of names.

    Args:
        gff: Path or file-like object accepted by ``pandas.read_csv``;
            tab-separated, nine GFF columns, lines starting with ``#`` are
            treated as comments.
        chrom: Collection of ``seqid`` values to keep in ``'full'`` mode. A
            single string is treated as a one-element collection. ``None``
            or a collection containing ``"all"`` disables the seqid filter.
        mode: ``'short'`` returns the whole table without filtering or
            attribute parsing; any other value runs the full pipeline.

    Returns:
        A ``(data, names)`` tuple.

        * ``'short'`` mode: the whole table and the ``seqid`` values of the
          rows whose ``type`` is ``'transcript'``.
        * otherwise: rows of type gene/mRNA/lnc_RNA/exon/CDS on the selected
          seqids, extended with one column per attribute key, and the ``ID``
          values of ``gene`` rows whose ``gene_biotype`` is
          ``protein_coding`` or ``lncRNA``.
    """
    col_names = [
        "seqid",
        "source",
        "type",
        "start",
        "end",
        "score",
        "strand",
        "phase",
        "attributes",
    ]

    data_gff = pd.read_csv(gff, sep="\t", names=col_names, header=None, comment="#")
    # GFF coordinates are 1-based closed intervals. Shifting only the start
    # turns them into 0-based half-open intervals, so the end stays as is.
    data_gff["start"] = data_gff["start"] - 1
    data_gff["lens"] = data_gff["end"] - data_gff["start"]

    if mode == 'short':
        gene_list = data_gff.loc[data_gff["type"] == 'transcript', "seqid"].to_list()
        return data_gff, gene_list

    if isinstance(chrom, str):
        chrom = [chrom]

    keep = data_gff["type"].isin(['gene', 'mRNA', 'lnc_RNA', 'exon', 'CDS'])
    if chrom is not None and "all" not in list(chrom):
        keep = keep & data_gff["seqid"].isin(chrom)
    # reset_index returns a fresh frame, so the positional concat below lines
    # up and nothing is modified in place on a slice of data_gff.
    data = data_gff[keep].reset_index(drop=True)

    new_col_list = []
    for attributes in data["attributes"]:
        split_info_fields(attributes, list_records=new_col_list)

    new_cols = pd.DataFrame.from_dict(new_col_list)
    data = pd.concat((data, new_cols), axis=1)

    is_gene = data["type"] == 'gene'
    is_selected_biotype = data["gene_biotype"].isin(['protein_coding', 'lncRNA'])
    gene_list = data[is_gene & is_selected_biotype]["ID"].to_list()

    return data, gene_list

def compute_metrics(transcript_classes, part, metrics):
    """Pool TP/FP/FN counters over several classes and derive P/R/F1.

    Args:
        transcript_classes: Class names whose counters are summed.
        part: Feature name used as the last component of the counter key.
        metrics: Mapping indexed with keys of the form
            ``'<TP|FP|FN>_<class>_<part>'``.

    Returns:
        ``(precision, recall, f1)``. Any value whose denominator is not
        positive is returned as ``0``.
    """
    def _pooled(kind):
        # Plain indexing (not .get) is kept on purpose: the lookup order and
        # any key insertion by a defaultdict stay exactly as before.
        return sum(metrics[f'{kind}_{transcript_class}_{part}'] for transcript_class in transcript_classes)

    true_pos = _pooled('TP')
    false_pos = _pooled('FP')
    false_neg = _pooled('FN')

    recall = 0
    if true_pos + false_neg > 0:
        recall = true_pos / (true_pos + false_neg)

    precision = 0
    if true_pos + false_pos > 0:
        precision = true_pos / (true_pos + false_pos)

    f1 = 0
    if recall + precision > 0:
        f1 = 2 * recall * precision / (recall + precision)

    return precision, recall, f1

def search_parts(transcript_content, part, gene_start, mode='transcript'):
    """Collect ``(start, end)`` pairs of all rows of a given feature type.

    Args:
        transcript_content: DataFrame with ``type``, ``start`` and ``end``
            columns.
        part: Value of the ``type`` column to select.
        gene_start: Offset added to both coordinates when *mode* is
            ``'gene'``; unused otherwise.
        mode: ``'gene'`` shifts the coordinates by *gene_start*; any other
            value returns them unchanged.

    Returns:
        A list of ``(start, end)`` tuples in row order.
    """
    selected_rows = transcript_content[transcript_content["type"] == part]

    begins = np.array(selected_rows['start'].tolist())
    finishes = np.array(selected_rows['end'].tolist())

    if mode == 'gene':
        begins = begins + gene_start
        finishes = finishes + gene_start

    return list(zip(begins, finishes))

def run_exon_level(pred_set, part_set, part, transcript_class, metrics):
    """Update TP/FP/FN counters from exact interval matches.

    Both inputs are converted to sets, so duplicates count once. An interval
    counts as a true positive only when it is present in both collections.

    Args:
        pred_set: Iterable of predicted hashable items (interval tuples).
        part_set: Iterable of reference hashable items.
        part: Feature name used in the counter keys.
        transcript_class: Class name used in the counter keys.
        metrics: Mapping supporting ``+=`` on keys
            ``'<TP|FP|FN>_<transcript_class>_<part>'``; updated in place.
    """
    predicted = set(pred_set)
    reference = set(part_set)

    matched = reference & predicted
    spurious = predicted - reference
    missed = reference - predicted

    metrics[f'TP_{transcript_class}_{part}'] += len(matched)
    metrics[f'FP_{transcript_class}_{part}'] += len(spurious)
    metrics[f'FN_{transcript_class}_{part}'] += len(missed)



def process_transcript(name, data, transcript_classes, part_list, gff, gff_indexed, metrics, n_genes, data_names):
    """Score the predictions for one reference transcript.

    Args:
        name: Transcript identifier; looked up in the index of *gff_indexed*
            and matched against ``seqid`` in *data*.
        data: Prediction table with ``seqid``, ``strand``, ``type``,
            ``start`` and ``end`` columns.
        transcript_classes: Not used by this function.
        part_list: Ignored; the parts to score are chosen from the reference
            transcript type (``exon`` and ``CDS`` for ``mRNA``, otherwise
            ``exon`` only).
        gff: Not used by this function.
        gff_indexed: Reference table indexed by ID, with ``strand``,
            ``type``, ``start`` and ``Parent`` columns.
        metrics: Counter mapping updated in place via ``run_exon_level``.
        n_genes: Not used by this function.
        data_names: Container of names that have predictions; only ``in``
            is applied to it.

    Raises:
        KeyError: If *name* is not present in the index of *gff_indexed*.
    """
    logging.debug("------------new transcript start------------")

    target_info = gff_indexed.loc[name]
    target_strand = target_info['strand']
    target_class = target_info['type']
    target_start = target_info['start']

    part_list = ['exon', 'CDS'] if target_class == 'mRNA' else ['exon']

    # Whether predictions exist does not depend on the part, so decide once
    # and only filter the prediction table when it will actually be used.
    has_prediction = name in data_names
    if has_prediction:
        pred_content = data[(data['seqid'] == name) & (data['strand'] == target_strand)]
    target_content = gff_indexed[gff_indexed['Parent'] == name]

    for part in part_list:
        if has_prediction:
            # Predicted coordinates are shifted by the reference transcript
            # start before being compared with the reference intervals.
            pred_set = search_parts(pred_content, part, target_start, mode='gene')
        else:
            pred_set = set()

        part_set = search_parts(target_content, part, target_start)
        run_exon_level(pred_set, part_set, part, target_class, metrics)

 
        
def main_process(args):
    """Run the evaluation and append the resulting metrics to a file.

    Args:
        args: Namespace with the attributes ``input`` (predicted GFF),
            ``gff`` (reference GFF), ``chrs`` (seqid passed to
            ``prepare_gff`` as a one-element list), ``transcripts_list``
            (headerless file whose first column holds the transcript IDs to
            score) and ``output`` (metrics file, opened in append mode).
    """
    part_list = ['exon', 'CDS']
    transcript_classes = ['mRNA', 'lnc_RNA']

    data, data_names = prepare_gff(args.input, mode='short')
    # Membership in data_names is tested for every transcript; a set makes
    # each test O(1) instead of a scan over the whole list.
    data_names = set(data_names)

    metrics = defaultdict(int)
    n_genes = defaultdict(int)

    selected_chrs = [args.chrs]

    gff, gene_list = prepare_gff(args.gff, chrom=selected_chrs)
    gff_indexed = gff.set_index("ID")

    transcripts_list = pd.read_csv(args.transcripts_list, header=None)[0].tolist()

    for name in tqdm(transcripts_list):
        process_transcript(name, data, transcript_classes, part_list, gff, gff_indexed, metrics, n_genes, data_names)

    with open(args.output, 'a') as metrics_file:
        metrics_file.write("precision\trecall\tf1\n")
        metrics_file.write("mRNA+lnc_RNA_exon\tmRNA_exon\tlnc_RNA_exon\tmRNA_CDS\n")
        metrics_file.write("precision\trecall\tf1\n")

        # Row order matches the label line above: pooled exons, per-class
        # exons, then mRNA CDS.
        precision, recall, f1 = compute_metrics(['mRNA', 'lnc_RNA'], 'exon', metrics)
        metrics_file.write(f"{precision}\t{recall}\t{f1}\n")

        for transcript_class in ['mRNA', 'lnc_RNA']:
            precision, recall, f1 = compute_metrics([transcript_class], 'exon', metrics)
            metrics_file.write(f"{precision}\t{recall}\t{f1}\n")

        precision, recall, f1 = compute_metrics(['mRNA'], 'CDS', metrics)
        metrics_file.write(f"{precision}\t{recall}\t{f1}\n")

        # Raw counters are dumped after the table for inspection.
        metrics_file.write(f"\n\n{metrics}")
        metrics_file.write(f"\n\n{n_genes}")
        
# Command-line interface. The parser is built at import time (cheap and free
# of side effects); arguments are parsed and the evaluation is run only when
# the file is executed as a script.
parser = argparse.ArgumentParser(
    description="Exon/CDS-level precision, recall and F1 of a predicted GFF against a reference GFF."
)
parser.add_argument("--input", type=str, required=True, help="Path to the GFF file with predictions")
parser.add_argument("--gff", type=str, required=True, help="Path to the reference GFF annotation file")
parser.add_argument("--chrs", type=str, default="all", help="Sequence name (GFF seqid) selected from the reference annotation")
parser.add_argument("--transcripts_list", type=str, required=True, help="Path to a headerless file listing the transcript IDs to evaluate (first column)")
parser.add_argument("--output", type=str, required=True, help="Path to the output metrics file (appended to)")

if __name__ == "__main__":
    args = parser.parse_args()
    main_process(args)