#!/usr/bin/env python
# -*- coding: utf-8 -*-
import h5py
import argparse
from tqdm import tqdm
import numpy as np
import pandas as pd
import pysam
import logging
from transformers import AutoTokenizer, AutoConfig
from collections import defaultdict
from Bio import Seq

# Configure the root logger: timestamped messages at INFO level and above.
# logging.debug(...) calls elsewhere in this file are therefore not emitted.
logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
# Never summarise arrays with "..." when printing; always show every element.
np.set_printoptions(threshold=np.inf)

def split_info_fields(info_record, list_records):
    """Parse a ``key=value;key=value`` string and append the result as a dict.

    Each entry is split on its first ``=`` only, so a value that itself
    contains ``=`` is kept whole. An entry without ``=`` (a flag-style key)
    is stored with an empty-string value, and empty entries (for example the
    one produced by a trailing ``;``) are skipped instead of raising.

    Args:
        info_record: Semicolon-separated string of ``key=value`` entries.
        list_records: List that receives the parsed dict; mutated in place.

    Returns:
        None. The parsed dict is appended to ``list_records``.
    """
    result = {}
    for record in info_record.split(";"):
        if not record:
            # Nothing between two separators (or after a trailing one).
            continue
        key, _, value = record.partition("=")
        result[key] = value

    list_records.append(result)

def compute_metrics(direction, transcript_classes, part, metrics):
    """Derive precision, recall and F1 from pooled TP/FP/FN counters.

    Counters are looked up in ``metrics`` under keys of the form
    ``{TP|FP|FN}_{direction}_{transcript_class}_{part}`` and summed over
    every entry of ``transcript_classes``.

    Args:
        direction: Direction component of the counter keys.
        transcript_classes: Iterable of transcript class names to pool.
        part: Part component of the counter keys.
        metrics: Mapping from counter key to count.

    Returns:
        Tuple ``(precision, recall, f1)``. A score whose denominator is not
        positive is reported as ``0``.
    """
    def pooled(kind):
        # Sum one counter kind over all requested transcript classes.
        return sum(
            metrics[f'{kind}_{direction}_{transcript_class}_{part}']
            for transcript_class in transcript_classes
        )

    true_pos = pooled('TP')
    false_pos = pooled('FP')
    false_neg = pooled('FN')

    if true_pos + false_neg > 0:
        recall = true_pos / (true_pos + false_neg)
    else:
        recall = 0

    if true_pos + false_pos > 0:
        precision = true_pos / (true_pos + false_pos)
    else:
        precision = 0

    if recall + precision > 0:
        f1 = 2 * recall * precision / (recall + precision)
    else:
        f1 = 0

    return precision, recall, f1

def find_segments_ones(array):
    """Locate maximal runs of consecutive elements equal to 1.

    Args:
        array: Array compared element-wise against 1.

    Returns:
        List of ``(start, end)`` tuples, one per run, where ``start`` is the
        index of the first 1 in the run and ``end`` is one past the last
        (half-open interval). An empty list when no element equals 1.
    """
    hit_positions = np.where(array == 1)[0]
    if hit_positions.size == 0:
        return []

    # A jump larger than 1 between neighbouring hits marks the first index
    # (within hit_positions) of a new run.
    run_heads = np.where(np.diff(hit_positions) > 1)[0] + 1

    first_of_run = np.concatenate(([0], run_heads))
    last_of_run = np.concatenate((run_heads - 1, [hit_positions.size - 1]))

    starts = hit_positions[first_of_run]
    ends = hit_positions[last_of_run] + 1

    return list(zip(starts, ends))

def max_pred(p_labels, part, labels):
    """Build a 0/1 mask of positions where the 'exon' row wins.

    The 'exon' row of ``p_labels`` is compared against a set of competing
    rows that depends on ``part``:

    * ``'exon'``: exon vs. intron
    * ``'CDS'``:  exon vs. intron, 5UTR and 3UTR

    A position is set to 1 when the exon value equals the maximum over the
    compared rows (ties therefore count as exon), otherwise 0.

    Args:
        p_labels: NumPy array whose first axis is indexed by label, in the
            order given by ``labels``.
        part: ``'exon'`` or ``'CDS'``.
        labels: List of label names giving the row order of ``p_labels``.

    Returns:
        Float array of zeros and ones with the shape of one row of
        ``p_labels``.

    Raises:
        ValueError: If ``part`` is not ``'exon'`` or ``'CDS'`` (previously
            this surfaced as an UnboundLocalError), or if a required label
            is missing from ``labels``.
    """
    # The exon row must stay first: it is the one tested against the maximum.
    rows_by_part = {
        'exon': ('exon', 'intron'),
        'CDS': ('exon', 'intron', '5UTR', '3UTR'),
    }
    if part not in rows_by_part:
        raise ValueError(
            f"Unsupported part {part!r}; expected one of {sorted(rows_by_part)}"
        )

    row_indices = [labels.index(name) for name in rows_by_part[part]]
    chosen_p_labels = p_labels[row_indices]

    max_p_labels = np.max(chosen_p_labels, axis=0)
    preds = np.zeros(max_p_labels.shape)
    preds[chosen_p_labels[0] == max_p_labels] = 1

    return preds


def find_strand(p_labels, labels):
    """Guess the strand from 5'UTR/3'UTR signal at both ends of the sequence.

    The summed 5UTR-minus-3UTR signal over the first 50 positions is
    compared with the same quantity over the last 50 positions.

    Args:
        p_labels: 2-D NumPy array indexed as ``[label_row, position]``.
        labels: List of label names giving the row order of ``p_labels``;
            must contain ``'5UTR'`` and ``'3UTR'``.

    Returns:
        ``'+'`` when the left-end balance is at least the right-end balance,
        ``'-'`` when it is smaller, and ``''`` when neither comparison holds
        (e.g. the balance is NaN).
    """
    window = 50
    utr5_row = labels.index('5UTR')
    utr3_row = labels.index('3UTR')

    head_5utr = np.sum(p_labels[utr5_row, :window])
    head_3utr = np.sum(p_labels[utr3_row, :window])
    tail_5utr = np.sum(p_labels[utr5_row, -window:])
    tail_3utr = np.sum(p_labels[utr3_row, -window:])

    balance = (head_5utr - head_3utr) - (tail_5utr - tail_3utr)

    if balance >= 0:
        return '+'
    if balance < 0:
        return '-'
    # Reached only when both comparisons are false.
    return ''
    
def gene_process(seq, strand):
    """Translate a sequence and return its longest stop-free protein fragments.

    For strand ``'+'`` the sequence is used as given; for ``'-'`` its reverse
    complement is used. The sequence is translated in reading frames 0, 1
    and 2 (without stopping at stop codons) and every translation is split
    at ``'*'``.

    Args:
        seq: Nucleotide sequence accepted by ``Bio.Seq.translate``.
        strand: ``'+'`` or ``'-'``.

    Returns:
        Tuple ``(longest_frame0, longest_all_frames)``:
        the fragments tied for maximum length in frame 0 only, and the
        fragments tied for maximum length across all three frames. For any
        other ``strand`` value both lists are empty (previously this raised
        ``ValueError`` from ``max()`` on an empty sequence).
    """
    if strand == "+":
        coding_seq = seq
    elif strand == "-":
        coding_seq = Seq.reverse_complement(seq)
    else:
        # Unknown strand: nothing to translate.
        return [], []

    protein_vars_1_frames = []
    protein_vars_3_frames = []
    for orf in range(3):
        fragments = Seq.translate(coding_seq[orf:], to_stop=False).split('*')
        protein_vars_3_frames.extend(fragments)
        if orf == 0:
            protein_vars_1_frames.extend(fragments)

    # split() always yields at least one fragment, so both lists are
    # non-empty here and max() is safe.
    max_len_3_frames = max(len(s) for s in protein_vars_3_frames)
    max_proteins_3_frames = [s for s in protein_vars_3_frames if len(s) == max_len_3_frames]

    max_len_1_frames = max(len(s) for s in protein_vars_1_frames)
    max_proteins_1_frames = [s for s in protein_vars_1_frames if len(s) == max_len_1_frames]

    return max_proteins_1_frames, max_proteins_3_frames


def _append_proteins_fasta(fasta_path, transcript_name, proteins, min_length=10):
    """Append proteins of at least ``min_length`` residues to a FASTA file.

    Records are named ``{transcript_name}.{n}`` with ``n`` counting written
    records from 1. The file is opened (and thus created) only when at least
    one protein passes the length filter.
    """
    kept = [p for p in proteins if p is not None and len(p) >= min_length]
    if not kept:
        return

    with open(fasta_path, "a") as mtext:
        for protein_number, protein in enumerate(kept, start=1):
            mtext.write(f">{transcript_name}.{protein_number}\n{protein}\n")


def run_protein(p_labels, direction, part, labels, transcript_strand, transcript_seq, transcript_name, args):
    """Extract the predicted gene of one transcript and write its proteins.

    The strand is predicted with ``find_strand``; when it disagrees with
    ``transcript_strand`` nothing is written. Otherwise the bases at
    positions where ``max_pred`` yields 1 are joined, translated by
    ``gene_process``, and the longest fragments of length >= 10 are appended
    to ``{args.output}/{part}_{args.h5}_1.faa`` (frame 0 only) and
    ``{args.output}/{part}_{args.h5}_3.faa`` (all three frames).

    Args:
        p_labels: Per-label prediction array for the transcript.
        direction: Unused here; kept for interface compatibility.
        part: ``'exon'`` or ``'CDS'``; forwarded to ``max_pred`` and used in
            the output file names.
        labels: Label names giving the row order of ``p_labels``.
        transcript_strand: Annotated strand to compare the prediction with.
        transcript_seq: Transcript sequence; one character per position.
        transcript_name: Identifier used in the FASTA headers.
        args: Namespace providing ``output`` (directory) and ``h5`` (tag
            used in the file names).
    """
    pred_strand = find_strand(p_labels, labels)
    if pred_strand != transcript_strand:
        return

    preds = max_pred(p_labels, part, labels)

    # Keep only the bases at predicted positions, in sequence order.
    transcript_seq_arr = np.array(list(transcript_seq))
    found_gene = "".join(transcript_seq_arr[preds == 1])

    max_proteins_1_frames, max_proteins_3_frames = gene_process(found_gene, pred_strand)

    _append_proteins_fasta(
        f'{args.output}/{part}_{args.h5}_1.faa', transcript_name, max_proteins_1_frames
    )
    _append_proteins_fasta(
        f'{args.output}/{part}_{args.h5}_3.faa', transcript_name, max_proteins_3_frames
    )







def run_exon_level(p_labels, targets, direction, part, transcript_class, labels, metrics):
    """Update exact-match segment counters for one transcript.

    Predicted segments come from ``find_segments_ones(max_pred(...))`` and
    reference segments from ``find_segments_ones(targets)``. A segment counts
    as a true positive only when its ``(start, end)`` pair appears in both.

    Args:
        p_labels: Per-label prediction array for the transcript.
        targets: Reference array whose elements equal to 1 mark segments.
        direction: Direction component of the counter keys.
        part: ``'exon'`` or ``'CDS'``; forwarded to ``max_pred`` and used in
            the counter keys.
        transcript_class: Transcript class component of the counter keys.
        labels: Label names giving the row order of ``p_labels``.
        metrics: Mapping of counters, incremented in place under
            ``TP_/FP_/FN_{direction}_{transcript_class}_{part}``.
    """
    predicted_segments = set(find_segments_ones(max_pred(p_labels, part, labels)))
    reference_segments = set(find_segments_ones(targets))

    matched = reference_segments & predicted_segments
    spurious = predicted_segments - reference_segments
    missed = reference_segments - predicted_segments

    metrics[f'TP_{direction}_{transcript_class}_{part}'] += len(matched)
    metrics[f'FP_{direction}_{transcript_class}_{part}'] += len(spurious)
    metrics[f'FN_{direction}_{transcript_class}_{part}'] += len(missed)

def process_transcript(transcript, data, direction, args, shift_ir=0):
    """Run protein extraction for one transcript if it is an mRNA.

    Reads ``ID``, ``type``, ``strand`` and ``transcript_seq`` from
    ``transcript.attrs``. Transcripts whose type is not ``'mRNA'`` are
    skipped. For mRNAs the ``transcript_predictions_forward`` dataset is
    loaded and ``run_protein`` is called for the parts ``'exon'`` and
    ``'CDS'``.

    Args:
        transcript: Object exposing ``attrs`` and item access to the
            ``transcript_predictions_forward`` dataset.
        data: Unused here; kept for interface compatibility.
        direction: Prediction direction. Only ``'forward'`` is supported;
            the reverse/mean variants are currently disabled.
        args: Namespace forwarded to ``run_protein``.
        shift_ir: Unused here; kept for interface compatibility.

    Raises:
        ValueError: If an mRNA transcript is processed with a direction
            other than ``'forward'`` (previously an UnboundLocalError).
    """
    logging.debug("------------new transcript start------------")

    transcript_name = transcript.attrs['ID']
    transcript_class = transcript.attrs['type']
    transcript_strand = transcript.attrs['strand']
    transcript_seq = transcript.attrs['transcript_seq']

    if transcript_class != 'mRNA':
        return

    if direction != 'forward':
        raise ValueError(
            f"Unsupported direction {direction!r}; only 'forward' is implemented"
        )

    p_labels = np.array(transcript["transcript_predictions_forward"])

    # Row order of the prediction array.
    labels = ['5UTR', 'exon', 'intron', '3UTR', 'CDS']

    for part in ('exon', 'CDS'):
        run_protein(p_labels, direction, part, labels, transcript_strand, transcript_seq, transcript_name, args)
        
def main_process(args):
    """Process every transcript stored in the input HDF5 file.

    Opens ``args.input`` read-only and calls ``process_transcript`` on the
    members named ``transcript_0`` ... ``transcript_{N-1}``, where ``N`` is
    the number of top-level members of the file. The file is closed when
    processing finishes or fails.

    Args:
        args: Namespace providing ``input``, ``transcripts_list`` and
            ``direction`` (plus whatever ``process_transcript`` needs).
    """
    # NOTE(review): this list is loaded but never used below. The read is
    # kept so a missing or unreadable list still fails up front as before.
    transcripts_list = pd.read_csv(args.transcripts_list, header=None)[0].tolist()

    direction = str(args.direction)

    with h5py.File(args.input, "r") as data:
        length_data = len(data)

        for idx in tqdm(range(length_data)):
            transcript = data[f'transcript_{idx}']
            process_transcript(transcript, data, direction, args)

        
# Command-line interface: parse the options and run the pipeline.
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, help="Path to the input HDF5 file with per-transcript predictions")
parser.add_argument("--transcripts_list", type=str, help="Path to list of transcripts names file")
parser.add_argument("--direction", type=str, help="Prediction direction to use (only 'forward' is currently supported)")
parser.add_argument("--output", type=str, help="Path to the output directory for the .faa files")
parser.add_argument("--h5", type=str, help="Tag embedded in the output file names ({part}_{h5}_1.faa, {part}_{h5}_3.faa)")
args = parser.parse_args()

main_process(args)