#!/usr/bin/env python
# -*- coding: utf-8 -*-
import h5py
import argparse
from tqdm import tqdm
import numpy as np
import pandas as pd
import pysam
import logging
from transformers import AutoTokenizer, AutoConfig
from collections import defaultdict
from sklearn.metrics import average_precision_score

logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
np.set_printoptions(threshold=np.inf)

def split_info_fields(info_record, list_records):
    """Parse a ``;``-separated ``KEY=VALUE`` string and append it as a dict.

    Each field is split on its first ``=`` only, so a value that itself
    contains ``=`` is kept whole. A field without ``=`` is stored as a flag
    with the value ``True``, and empty fields (for example the one produced
    by a trailing ``;``) are skipped instead of raising ``IndexError``.

    Args:
        info_record: String such as ``"ID=tx1;Parent=gene1"``.
        list_records: List that receives the parsed dict. It is mutated in
            place; the function itself returns ``None``.
    """
    result = {}
    for record in info_record.split(";"):
        if not record:
            continue
        key, separator, value = record.partition("=")
        # ``separator`` is empty when the field carried no "=" at all.
        result[key] = value if separator else True

    list_records.append(result)

def compute_metrics(direction, transcript_classes, part, metrics):
    """Derive precision, recall and F1 from accumulated TP/FP/FN counters.

    Counters are read from ``metrics`` under keys of the form
    ``"{TP|FP|FN}_{direction}_{transcript_class}_{part}"`` and summed over
    ``transcript_classes``.

    Args:
        direction: Direction component of the counter keys.
        transcript_classes: Iterable of transcript class names to pool.
        part: Part component of the counter keys.
        metrics: Mapping holding the counters; it is indexed with
            ``metrics[key]``.

    Returns:
        Tuple ``(precision, recall, f1)``. A score whose denominator is zero
        is reported as ``0``.
    """
    totals = {}
    # All TP keys are looked up first, then FP, then FN.
    for kind in ('TP', 'FP', 'FN'):
        running = 0
        for transcript_class in transcript_classes:
            running += metrics[f'{kind}_{direction}_{transcript_class}_{part}']
        totals[kind] = running

    hits, false_alarms, misses = totals['TP'], totals['FP'], totals['FN']

    if hits + misses > 0:
        recall = hits / (hits + misses)
    else:
        recall = 0

    if hits + false_alarms > 0:
        precision = hits / (hits + false_alarms)
    else:
        precision = 0

    if recall + precision > 0:
        f1 = 2 * recall * precision / (recall + precision)
    else:
        f1 = 0

    return precision, recall, f1

def find_segments_ones(array):
    """Locate every maximal run of ones in ``array``.

    Args:
        array: NumPy array compared element-wise against ``1``; positions
            are taken along its first axis.

    Returns:
        List of ``(start, end)`` tuples, one per run, where ``start`` is the
        index of the first one and ``end`` is one past the last one. An
        empty list is returned when ``array`` contains no ones.
    """
    positions = np.nonzero(array == 1)[0]
    if positions.size == 0:
        return []

    # A jump larger than one between consecutive positions closes a run.
    last_before_gap = np.nonzero(np.diff(positions) > 1)[0]
    starts = positions[np.concatenate(([0], last_before_gap + 1))]
    stops = positions[np.concatenate((last_before_gap, [positions.size - 1]))] + 1

    return list(zip(starts, stops))

def max_pred(p_labels, part, labels):
    """Binarise per-label scores for one evaluated part.

    The rows of ``p_labels`` belonging to a fixed set of labels are compared
    position by position; a position is marked ``1`` when the ``'exon'`` row
    holds the maximum among them (ties included) and ``0`` otherwise.

    Args:
        p_labels: NumPy array whose first axis is ordered like ``labels``.
        part: ``'exon'`` (exon vs. intron) or ``'CDS'`` (exon vs. intron,
            5UTR and 3UTR).
        labels: List of label names giving the row order of ``p_labels``.

    Returns:
        Float array of zeros and ones with the shape of one row of
        ``p_labels``.

    Raises:
        ValueError: If ``part`` is not one of the supported parts. The
            original code failed with ``UnboundLocalError`` in that case.
    """
    # NOTE(review): for 'CDS' the positive class is still the 'exon' row,
    # not the 'CDS' row -- confirm this is the intended definition.
    competitors_by_part = {
        'exon': ('exon', 'intron'),
        'CDS': ('exon', 'intron', '5UTR', '3UTR'),
    }
    if part not in competitors_by_part:
        raise ValueError(
            f"unsupported part {part!r}; expected one of {sorted(competitors_by_part)}"
        )

    rows = [labels.index(name) for name in competitors_by_part[part]]
    chosen_p_labels = p_labels[rows]

    max_p_labels = np.max(chosen_p_labels, axis=0)
    preds = np.zeros(max_p_labels.shape)
    # Row 0 of the selection is always 'exon'; equality keeps ties positive.
    preds[chosen_p_labels[0] == max_p_labels] = 1

    return preds

def run_exon_level(p_labels, targets, direction, part, transcript_class, labels, metrics):
    """Update segment-level TP/FP/FN counters for one transcript and part.

    Predicted segments come from ``max_pred`` followed by
    ``find_segments_ones``; target segments come from ``find_segments_ones``
    applied to ``targets``. A predicted segment counts as a true positive
    only when an identical ``(start, end)`` pair exists among the targets.

    Args:
        p_labels: Per-label scores handed to ``max_pred``.
        targets: Array of target values for ``part``.
        direction: Direction component of the counter keys.
        part: Part name handed to ``max_pred`` and used in the counter keys.
        transcript_class: Transcript class component of the counter keys.
        labels: Label names handed to ``max_pred``.
        metrics: Mapping of counters, incremented in place under
            ``"{TP|FP|FN}_{direction}_{transcript_class}_{part}"``.
    """
    predicted = set(find_segments_ones(max_pred(p_labels, part, labels)))
    expected = set(find_segments_ones(targets))

    suffix = f'{direction}_{transcript_class}_{part}'
    metrics[f'TP_{suffix}'] += len(expected & predicted)
    metrics[f'FP_{suffix}'] += len(predicted - expected)
    metrics[f'FN_{suffix}'] += len(expected - predicted)

def find_start_end(arr):
    """Return the span covering the first through the last ``1`` in ``arr``.

    Args:
        arr: NumPy array compared element-wise against ``1``.

    Returns:
        Tuple ``(first_index, last_index)`` where ``first_index`` is the
        position of the first ``1`` and ``last_index`` is one past the
        position of the last ``1``. When ``arr`` holds no ``1`` both
        ``argmax`` calls yield ``0``, so the result is ``(0, len(arr))``.
    """
    hits_from_left = arr == 1
    hits_from_right = arr[::-1] == 1

    leading = np.argmax(hits_from_left)
    # Offset of the last hit counted from the right-hand end.
    trailing = np.argmax(hits_from_right)

    return leading, len(arr) - trailing

def process_transcript(transcript, data, directions, labels, targets_labels, metrics, n_genes):
    """Score one transcript and fold the result into the shared counters.

    A transcript whose ``gene_seq`` or ``transcript_seq`` attribute equals
    ``'N'`` is only counted in ``n_genes`` and otherwise skipped. For every
    other transcript the exon-level counters are updated through
    ``run_exon_level``: always for ``'exon'``, and additionally for ``'CDS'``
    when the transcript type is ``'mRNA'``.

    Args:
        transcript: Object exposing ``attrs`` (``gene_seq``,
            ``transcript_seq``, ``type``) and item access to
            ``"transcript_predictions_forward"`` and ``"targets"``.
        data: Unused; kept so existing callers keep working.
        directions: Directions to evaluate. Only ``'forward'`` is supported.
        labels: Label names giving the row order of the predictions.
        targets_labels: Part names giving the row order of ``"targets"``.
        metrics: Counter mapping updated in place.
        n_genes: Counter mapping of skipped transcripts, updated in place.

    Raises:
        ValueError: If ``directions`` contains anything but ``'forward'``.
            Previously such a direction either failed with an unbound local
            or silently reused the forward predictions.
    """
    logging.debug("------------new transcript start------------")

    attrs = transcript.attrs
    gene_seq = attrs['gene_seq']
    transcript_seq = attrs['transcript_seq']
    transcript_class = attrs['type']

    if gene_seq == 'N' or transcript_seq == 'N':
        n_genes[f'N in {transcript_class} gene'] += 1
        return

    unsupported = [direction for direction in directions if direction != 'forward']
    if unsupported:
        raise ValueError(f"unsupported direction(s) {unsupported}; only 'forward' is available")

    # Only the forward predictions are loaded by this function.
    transcript_predictions_forward = np.array(transcript["transcript_predictions_forward"])
    targets = transcript["targets"]

    # The evaluated parts depend on the transcript only, not on the direction.
    part_list = ['exon', 'CDS'] if transcript_class == 'mRNA' else ['exon']

    for direction in directions:
        p_labels = transcript_predictions_forward
        for part in part_list:
            run_exon_level(p_labels, targets[targets_labels.index(part)], direction, part, transcript_class, labels, metrics)
           
def main_process(args):
    """Evaluate every transcript of an HDF5 file and append a text report.

    The input file is expected to contain groups named ``transcript_0`` to
    ``transcript_{N-1}``; each one is passed to ``process_transcript``. The
    report appended to the output file holds a dump of the raw counters,
    the precision/recall/F1 rows per direction, and final dumps of the
    counters and of the skipped-transcript counts.

    Args:
        args: Namespace with ``input`` (path of the HDF5 file to read) and
            ``output`` (path of the text file to append to).
    """
    directions = ['forward']
    labels = ["5UTR", "exon", "intron", "3UTR", "CDS"]
    targets_labels = ["exon", "CDS"]

    metrics = defaultdict(int)
    n_genes = defaultdict(int)

    # The context manager closes the HDF5 handle even if a transcript fails.
    with h5py.File(args.input, "r") as data:
        length_data = len(data)
        for idx in tqdm(range(length_data)):
            transcript = data[f'transcript_{idx}']
            process_transcript(transcript, data, directions, labels, targets_labels, metrics, n_genes)

    # NOTE(review): no code visible in this file accumulates the pr_auc_*
    # entries, so they stay at zero; the keys are still created because they
    # are part of the dumped counters.
    for direction in directions:
        for part in labels:
            key = f'pr_auc_{direction}_{part}'
            # Guard against an input file without any transcript.
            metrics[key] = metrics[key] / length_data if length_data else 0.0

    # (transcript classes, part) for each score row, in output order.
    score_rows = [
        (['mRNA', 'lnc_RNA'], 'exon'),
        (['mRNA'], 'exon'),
        (['lnc_RNA'], 'exon'),
        (['mRNA'], 'CDS'),
    ]

    with open(args.output, 'a') as metrics_file:
        # NOTE(review): this first dump is written without a trailing
        # newline, so the header below lands on the same line; kept as is to
        # leave the output format unchanged.
        metrics_file.write(f"{metrics}")

        metrics_file.write("precision\trecall\tf1\n")
        for direction in directions:
            metrics_file.write(f"{direction}\n")
            metrics_file.write("mRNA+lnc_RNA_exon\tmRNA_exon\tlnc_RNA_exon\tmRNA_CDS\n")
            metrics_file.write("precision\trecall\tf1\n")

            for transcript_classes, part in score_rows:
                precision, recall, f1 = compute_metrics(direction, transcript_classes, part, metrics)
                metrics_file.write(f"{precision}\t{recall}\t{f1}\n")

        metrics_file.write(f"\n\n{metrics}")
        metrics_file.write(f"\n\n{n_genes}")
        
# Run the evaluation only when executed as a script, so that importing this
# module no longer parses the importer's command line or touches any file.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Both paths are mandatory; without them main_process would fail on None.
    parser.add_argument("--input", type=str, required=True, help="Path to input file")
    parser.add_argument("--output", type=str, required=True, help="Path to output file")
    args = parser.parse_args()

    main_process(args)