#!/usr/bin/env python
# -*- coding: utf-8 -*-
import h5py
import argparse
from tqdm import tqdm
import numpy as np
import pandas as pd
import pysam
import logging
import torch
from transformers import AutoTokenizer, AutoConfig
from collections import defaultdict
from Bio import Seq

# Root logger: INFO level, each message prefixed with a timestamp.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(message)s",
)
# Never abbreviate long NumPy arrays with "..." when they are printed.
np.set_printoptions(threshold=np.inf)

def split_info_fields(info_record, list_records):
    """Parse one ``;``-separated ``key=value`` string and collect the result.

    The text is split on ``;`` and every record is split on its FIRST ``=``
    only, so values that themselves contain ``=`` are kept intact. Records
    that are empty or whitespace-only (e.g. produced by a trailing ``;``)
    are skipped. A record without any ``=`` is stored with an empty-string
    value.

    Args:
        info_record: Attributes text, e.g. ``"ID=gene1;Name=abc"``.
        list_records: List mutated in place; one ``{key: value}`` dict is
            appended per call.

    Returns:
        None. The parsed mapping is delivered through ``list_records``.
    """
    result = {}
    for record in info_record.split(";"):
        if not record.strip():
            # Trailing or doubled ';' yields an empty record; there is
            # nothing to parse, so skip it instead of failing on indexing.
            continue
        key, _, value = record.partition("=")
        result[key] = value

    list_records.append(result)

def find_segments_ones(array):
    """Locate runs of consecutive positions whose value equals 1.

    Args:
        array: NumPy array that is compared element-wise with 1.

    Returns:
        A list of ``(start, end)`` pairs, one per run of adjacent indices
        holding 1, where ``end`` is one past the last index of the run.
        An empty list is returned when no element equals 1.
    """
    hits = np.nonzero(array == 1)[0]
    if not hits.size:
        return []

    # A run ends wherever two neighbouring hit indices are more than 1 apart.
    breaks = np.flatnonzero(np.diff(hits) > 1) + 1
    run_begin = np.concatenate(([0], breaks))
    run_stop = np.concatenate((breaks, [hits.size]))

    return [(hits[lo], hits[hi - 1] + 1) for lo, hi in zip(run_begin, run_stop)]

def prepare_gff(gff, chrom=('NC_060944.1',), mode='full'):
    """Load a GFF table and convert its coordinates to 0-based, half-open.

    Args:
        gff: Path or file-like object handed to ``pandas.read_csv``; read as
            tab-separated text with nine unnamed columns, lines starting with
            ``#`` being treated as comments.
        chrom: List-like of ``seqid`` values to keep (used in full mode only).
        mode: ``'short'`` returns the unfiltered table; any other value runs
            the full processing described below.

    Returns:
        A ``(table, names)`` tuple.

        * ``mode == 'short'``: the complete table and the ``seqid`` values of
          all rows whose ``type`` is ``'transcript'``.
        * otherwise: rows of type gene / mRNA / lnc_RNA / exon / CDS located
          on ``chrom``, extended with one column per attribute key, and the
          ``ID`` values of ``gene`` rows whose ``gene_biotype`` is
          ``protein_coding`` or ``lncRNA``.
    """
    col_names = [
        "seqid",
        "source",
        "type",
        "start",
        "end",
        "score",
        "strand",
        "phase",
        "attributes",
    ]

    data_gff = pd.read_csv(gff, sep="\t", names=col_names, header=None, comment="#")

    # GFF coordinates are 1-based closed intervals. Shifting only the start
    # turns them into 0-based half-open intervals, so "end" stays untouched.
    data_gff["start"] = data_gff["start"] - 1
    data_gff["lens"] = data_gff["end"] - data_gff["start"]

    if mode == 'short':
        is_transcript = data_gff['type'] == 'transcript'
        gene_list = data_gff.loc[is_transcript, 'seqid'].to_list()
        return data_gff, gene_list

    wanted_types = ['gene', 'mRNA', 'lnc_RNA', 'exon', 'CDS']
    keep = data_gff["type"].isin(wanted_types) & data_gff["seqid"].isin(chrom)
    # Re-index from 0 so that the positional concat below lines up row by row.
    data = data_gff[keep].reset_index(drop=True)

    new_col_list = []
    for attributes in data["attributes"]:
        split_info_fields(attributes, new_col_list)

    new_cols = pd.DataFrame(new_col_list)
    data = pd.concat((data, new_cols), axis=1)

    is_gene = data['type'] == 'gene'
    has_biotype = data['gene_biotype'].isin(['protein_coding', 'lncRNA'])
    gene_list = data.loc[is_gene & has_biotype, 'ID'].to_list()

    return data, gene_list

def compute_metrics(transcript_classes, part, metrics):
    """Aggregate TP/FP/FN counters over classes and derive summary scores.

    Args:
        transcript_classes: Iterable of class names used in the counter keys.
        part: Feature name used in the counter keys.
        metrics: Mapping with keys ``'{TP|FP|FN}_{class}_{part}'``.

    Returns:
        ``(precision, recall, f1)``; each value is 0 when its denominator
        is not positive.
    """
    totals = {}
    for kind in ("TP", "FP", "FN"):
        totals[kind] = sum(metrics[f'{kind}_{name}_{part}'] for name in transcript_classes)

    def ratio(numerator, denominator):
        if denominator > 0:
            return numerator / denominator
        return 0

    hits, false_pos, misses = totals["TP"], totals["FP"], totals["FN"]

    recall = ratio(hits, hits + misses)
    precision = ratio(hits, hits + false_pos)
    f1 = ratio(2 * recall * precision, recall + precision)

    return precision, recall, f1

def search_parts(transcript_content, part, gene_start, mode='transcript'):
    """Collect the ``(start, end)`` pairs of all rows of one feature type.

    Args:
        transcript_content: DataFrame with ``type``, ``start`` and ``end``
            columns.
        part: Value of ``type`` to select.
        gene_start: Offset added to both coordinates when ``mode == 'gene'``.
        mode: ``'gene'`` shifts the coordinates by ``gene_start``; any other
            value returns them unchanged.

    Returns:
        List of ``(start, end)`` tuples in row order.
    """
    selected = transcript_content[transcript_content["type"] == part]

    begins = np.array(selected['start'].tolist())
    finishes = np.array(selected['end'].tolist())

    if mode != 'gene':
        return list(zip(begins, finishes))

    return list(zip(begins + gene_start, finishes + gene_start))

def run_exon_level(pred_set, part_set, part, transcript_class, metrics):
    """Update exact-match TP/FP/FN counters for one set of intervals.

    Predicted and reference items are compared as set members: an item
    present in both counts as TP, only in the prediction as FP, only in the
    reference as FN.

    Args:
        pred_set: Iterable of hashable predicted items.
        part_set: Iterable of hashable reference items.
        part: Feature name used in the counter keys.
        transcript_class: Class name used in the counter keys.
        metrics: Mapping updated in place under the keys
            ``'{TP|FP|FN}_{transcript_class}_{part}'``.
    """
    predicted = set(pred_set)
    expected = set(part_set)
    suffix = f'{transcript_class}_{part}'

    outcome = (
        ('TP', expected & predicted),
        ('FP', predicted - expected),
        ('FN', expected - predicted),
    )
    for tag, members in outcome:
        metrics[f'{tag}_{suffix}'] += len(members)


def gene_process(seq, strand):
    """Translate a nucleotide sequence and return its longest stop-free segments.

    The sequence (reverse-complemented first for the ``-`` strand) is
    translated in the three forward reading frames. Every translation is
    split at stop codons (``*``) and the longest resulting segments are kept.

    Args:
        seq: Nucleotide sequence as a string.
        strand: ``"+"`` or ``"-"``.

    Returns:
        ``(longest_frame_0, longest_all_frames)``: two lists holding every
        segment of maximal length found in reading frame 0 only and in all
        three frames, respectively. Both lists are empty when ``strand`` is
        neither ``"+"`` nor ``"-"``.
    """
    if strand == "+":
        template = seq
    elif strand == "-":
        template = Seq.reverse_complement(seq)
    else:
        # Unknown strand (e.g. "."): nothing can be translated. Return empty
        # results instead of failing later on max() of an empty sequence.
        return [], []

    protein_vars_1_frames = []
    protein_vars_3_frames = []
    for orf in range(3):
        frame = template[orf:]
        # Drop the trailing partial codon explicitly; translate() would
        # ignore it anyway but emits a "Partial codon" warning for it.
        frame = frame[:len(frame) - len(frame) % 3]
        segments = Seq.translate(frame, to_stop=False).split('*')
        protein_vars_3_frames.extend(segments)
        if orf == 0:
            protein_vars_1_frames.extend(segments)

    def longest(candidates):
        top = max(len(item) for item in candidates)
        return [item for item in candidates if len(item) == top]

    return longest(protein_vars_1_frames), longest(protein_vars_3_frames)


def run_protein(p_labels, part, transcript_strand, transcript_seq, transcript_name, pred_strand, args):
    """Translate the labelled positions of a transcript and append proteins to a FASTA file.

    Bases of ``transcript_seq`` whose label equals 1 are concatenated, passed
    to ``gene_process`` together with ``pred_strand``, and the longest
    segments over all three reading frames that are at least 10 residues
    long are appended to ``{args.output}/{part}_{args.h5}_3.faa``.

    Args:
        p_labels: NumPy array with one label per base of ``transcript_seq``.
        part: Feature name; used in the output file name.
        transcript_strand: Unused; kept for interface compatibility.
        transcript_seq: Nucleotide sequence of the transcript.
        transcript_name: Used as the FASTA record name, suffixed with
            ``.<running number>``.
        pred_strand: Strand handed to ``gene_process``.
        args: Namespace providing ``output`` and ``h5``.
    """
    transcript_seq_arr = np.array(list(transcript_seq))
    found_gene = "".join(transcript_seq_arr[p_labels == 1])
    _, max_proteins_3_frames = gene_process(found_gene, pred_strand)

    # Check for None before taking the length, and drop short segments.
    kept = [
        protein
        for protein in max_proteins_3_frames
        if protein is not None and len(protein) >= 10
    ]
    if not kept:
        # Nothing to write: do not create or touch the output file.
        return

    # Open the output once per call instead of once per protein.
    with open(f'{args.output}/{part}_{args.h5}_3.faa', "a") as mtext:
        for protein_number, protein in enumerate(kept, start=1):
            mtext.write(f">{transcript_name}.{protein_number}\n{protein}\n")







def process_transcript(name, data, transcript_classes, part_list, gff, gff_indexed, metrics, n_genes, data_names, ref, args):
    """Write the proteins encoded by the predicted exon and CDS parts of one transcript.

    Only records whose reference ``type`` is ``'mRNA'`` and whose name occurs
    in ``data_names`` are processed; everything else returns immediately.
    For the parts ``'exon'`` and ``'CDS'`` the predicted intervals on the
    reference strand are turned into a per-base 0/1 label array, which is
    passed to ``run_protein``.

    Args:
        name: Transcript identifier; looked up in ``gff_indexed``, in
            ``data['seqid']`` and in ``ref``.
        data: DataFrame of predicted features.
        transcript_classes: Unused; kept for interface compatibility.
        part_list: Unused; ``'exon'`` and ``'CDS'`` are always processed.
        gff: Unused; kept for interface compatibility.
        gff_indexed: Reference annotation indexed by transcript ID.
        metrics: Unused; kept for interface compatibility.
        n_genes: Unused; kept for interface compatibility.
        data_names: Container of names that have predictions.
        ref: Object whose ``fetch(reference=...)`` returns the sequence.
        args: Namespace forwarded to ``run_protein``.
    """
    logging.debug("------------new transcript start------------")

    target_info = gff_indexed.loc[name]
    if target_info['type'] != 'mRNA':
        return
    if name not in data_names:
        return

    target_strand = target_info['strand']
    target_start = target_info['start']

    same_name = data['seqid'] == name
    same_strand = data['strand'] == target_strand
    pred_content = data[same_name & same_strand]

    # The sequence does not depend on the part, so fetch it only once.
    transcript_seq = ref.fetch(reference=f"{name}").upper()

    for part in ('exon', 'CDS'):
        pred_set = search_parts(pred_content, part, target_start, mode='transcript')

        p_labels = np.zeros(len(transcript_seq), dtype=int)
        for start, end in pred_set:
            p_labels[start:end] = 1

        # The predictions were filtered to the reference strand above, so
        # that strand is also passed as the prediction strand.
        run_protein(p_labels, part, target_strand, transcript_seq, name, target_strand, args)

 
        
def main_process(args):
    """Run the protein extraction for every transcript in the transcripts list.

    Loads the predictions (``args.input``) and the reference annotation
    (``args.gff``), opens the FASTA file (``args.fasta``) and calls
    ``process_transcript`` for each name in the first column of
    ``args.transcripts_list``.

    Args:
        args: Namespace with ``input``, ``gff``, ``fasta``, ``chrs``,
            ``transcripts_list``, ``output`` and ``h5``.
    """
    part_list = ['exon', 'CDS']
    transcript_classes = ['mRNA', 'lnc_RNA']

    data, data_names = prepare_gff(args.input, mode='short')
    # Only membership is tested downstream; a set keeps that O(1) per transcript.
    data_names = set(data_names)

    metrics = defaultdict(int)
    n_genes = defaultdict(int)

    # NOTE(review): args.chrs is used literally as a single seqid, so the
    # default "all" selects only records whose seqid is "all" - confirm intent.
    selected_chrs = [args.chrs]

    gff, _ = prepare_gff(args.gff, chrom=selected_chrs)
    gff_indexed = gff.set_index("ID")

    transcripts_list = pd.read_csv(args.transcripts_list, header=None)[0].tolist()

    ref = pysam.Fastafile(args.fasta)
    try:
        for name in tqdm(transcripts_list):
            process_transcript(name, data, transcript_classes, part_list, gff, gff_indexed, metrics, n_genes, data_names, ref, args)
    finally:
        # Release the FASTA handle even if a transcript fails.
        ref.close()

        
def build_parser():
    """Create the command-line parser of this script.

    Returns:
        argparse.ArgumentParser: Parser defining ``--input``, ``--gff``,
        ``--fasta``, ``--chrs``, ``--transcripts_list``, ``--output`` and
        ``--h5``.
    """
    parser = argparse.ArgumentParser(
        description="Translate predicted exon/CDS parts of transcripts into protein FASTA records."
    )
    parser.add_argument("--input", type=str, required=True, help="Path to the GFF file with predicted features")
    parser.add_argument("--gff", type=str, required=True, help="Path to the reference annotation GFF file")
    parser.add_argument("--fasta", type=str, required=True, help="Path to the FASTA file holding the transcript sequences")
    # NOTE(review): the value is used literally as one seqid; the default
    # "all" is not treated specially anywhere in this file - confirm intent.
    parser.add_argument("--chrs", type=str, default="all", help="Sequence ID (first GFF column) of the reference records to use")
    parser.add_argument("--transcripts_list", type=str, required=True, help="Path to a text file whose first column lists transcript IDs")
    parser.add_argument("--output", type=str, required=True, help="Path to the output directory")
    parser.add_argument("--h5", type=str, help="Label inserted into the output FASTA file names")
    return parser


# Guarded so that importing this module does not parse sys.argv or start the run.
if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()

    main_process(args)