#!/usr/bin/env python

import argparse
import pandas as pd
from Bio import SeqIO
from project.constants import AA_SCALES, PADDING_VALUE, eisenberg_scale, kyle_doolittle_scale
from project.metrics import Diversity, DiversityPredictedPositives, FitnessScore, NovelAMP, AMPProbability, Precision, PropertyKSDistance, Recall, FrechetAminoacidEmbeddingDistance, Uniqueness
from project.wrappers import HydrophobicScaleWrapper, combine_scales
from random import sample


def load_fasta_to_list(path_to_fasta, max_length):
    """Read a FASTA file and return its sequences as plain strings.

    Records are parsed with Biopython's ``SeqIO`` and kept in file order;
    any record whose sequence is longer than ``max_length`` is dropped.

    Args:
        path_to_fasta: Path of the FASTA file to read.
        max_length: Inclusive upper bound on sequence length.

    Returns:
        list[str]: The sequences whose length is at most ``max_length``.
    """
    with open(path_to_fasta, 'r') as handle:
        as_strings = (str(rec.seq) for rec in SeqIO.parse(handle, "fasta"))
        # The list is fully built before the file is closed.
        return [s for s in as_strings if len(s) <= max_length]


def build_arg_parser():
    """Build the command-line parser for the metrics script.

    Returns:
        argparse.ArgumentParser: Parser with all script options and defaults.
    """
    parser = argparse.ArgumentParser(description='Compute metrics for generated sequences.')
    parser.add_argument('--path_to_curated_fasta', type=str, default='data/activity-data/curated-AMPs.fasta', help='Path to the input Fasta file with AMP sequences')
    parser.add_argument('--path_to_all_fasta', type=str, default='data/generative-model-data/AMPs.fasta', help='Path to the input Fasta file with all sequences')
    parser.add_argument('--path_to_fasta', type=str, default='results/generative-model-results/omegamp-generated-samples.fasta', help='Path to the FASTA file with generated sequences')
    parser.add_argument('--path_to_classifier', type=str, default='models/broad-classifier.json', help='Path to the AMP classifier model')
    parser.add_argument('--no_generated_samples', type=int, default=50000, help='Number of generated samples to evaluate')
    parser.add_argument('--min_length', type=int, default=1, help='Minimum length for sequences')
    parser.add_argument('--max_length', type=int, default=100, help='Maximum length for sequences')
    # Previously hardcoded in main(); the default keeps the original property set.
    parser.add_argument('--ks_properties', type=str, nargs='+', default=["length", "charge", "hydrophobicity_eisenberg"], help='Properties compared with PropertyKSDistance')
    parser.add_argument('--seed', type=int, default=None, help='Random seed for subsampling generated sequences (default: unseeded)')
    parser.add_argument('--output_csv', type=str, default=None, help='Path to save metrics results as CSV file')
    return parser


def select_generated_samples(sequences, min_length, max_length, n_samples, seed=None):
    """Filter generated sequences by length and subsample to at most ``n_samples``.

    Args:
        sequences: Candidate sequences (strings).
        min_length: Inclusive lower bound on sequence length.
        max_length: Inclusive upper bound on sequence length.
        n_samples: Maximum number of sequences to return.
        seed: Optional seed for a private RNG. With ``None`` the module-level
            ``random.sample`` is used, as before.

    Returns:
        list[str]: A random subset of size ``n_samples`` when more sequences
        pass the filter. Otherwise all passing sequences in their original
        order; a warning is printed when there are fewer than ``n_samples``.
    """
    import random

    kept = [seq for seq in sequences if min_length <= len(seq) <= max_length]
    if len(kept) > n_samples:
        sampler = sample if seed is None else random.Random(seed).sample
        return sampler(kept, n_samples)
    if len(kept) < n_samples:
        print(f"Warning: Number of generated sequences is less than the specified number of samples. Using all {len(kept)} generated sequences.")
    return kept


def main():
    """Compute evaluation metrics for generated sequences and report them.

    Steps:
    1. Parse the command-line options.
    2. Load the curated AMP set, the full AMP set and the generated sequences,
       keeping only sequences of at most ``--max_length``.
    3. Length-filter and subsample the generated set.
    4. Build every metric and evaluate it on the generated set.
    5. Print each result and, if ``--output_csv`` is set, write all results
       as a one-row CSV file.

    Raises:
        SystemExit: If no generated sequences remain after filtering.
    """
    args = build_arg_parser().parse_args()

    # Reference sets. The curated AMPs feed the distribution-based metrics;
    # the full AMP set is passed to NovelAMP.
    curated_amp_sequences = load_fasta_to_list(args.path_to_curated_fasta, args.max_length)
    all_amp_sequences = load_fasta_to_list(args.path_to_all_fasta, args.max_length)

    generated_sequences = select_generated_samples(
        load_fasta_to_list(args.path_to_fasta, args.max_length),
        args.min_length,
        args.max_length,
        args.no_generated_samples,
        seed=args.seed,
    )
    if not generated_sequences:
        # Fail clearly here instead of deep inside a metric.
        raise SystemExit("Error: no generated sequences left after length filtering; nothing to evaluate.")

    # Build the metrics in the original order, since constructors may load
    # models or data.
    amp_probability_metric = AMPProbability(args.path_to_classifier)
    novel_amp_metric = NovelAMP(all_amp_sequences)
    # Precision, Recall and FAED share one wrapper built from the combined
    # eisenberg_scale and kyle_doolittle_scale.
    evaluation_scale = combine_scales([eisenberg_scale, kyle_doolittle_scale])
    wrapper = HydrophobicScaleWrapper(evaluation_scale, args.max_length, PADDING_VALUE)
    precision_metric = Precision(wrapper, curated_amp_sequences)
    recall_metric = Recall(wrapper, curated_amp_sequences)
    faed_metric = FrechetAminoacidEmbeddingDistance(wrapper, curated_amp_sequences)
    properties_ks_distance_metric = PropertyKSDistance(args.ks_properties, curated_amp_sequences)
    diversity_metric = Diversity()
    diversity_predicted_positives_metric = DiversityPredictedPositives(args.path_to_classifier)
    uniqueness_metric = Uniqueness()
    fitness_score_metric = FitnessScore()

    # Evaluation and report order; each name is both the printed label and
    # the CSV column.
    report_plan = [
        ("AMPProbability", amp_probability_metric),
        ("NovelAMP", novel_amp_metric),
        ("Precision", precision_metric),
        ("Recall", recall_metric),
        ("FrechetAminoacidEmbeddingDistance", faed_metric),
        ("PropertyKSDistance", properties_ks_distance_metric),
        ("Uniqueness", uniqueness_metric),
        ("FitnessScore", fitness_score_metric),
        ("Diversity", diversity_metric),
        ("DiversityPredictedPositives", diversity_predicted_positives_metric),
    ]

    metrics_results = {}
    for name, metric in report_plan:
        value = metric(generated_sequences)
        if name == "PropertyKSDistance":
            # Keyed by property name; report each distance as its own column.
            for prop_name in value:
                metrics_results[f'PropertyKSDistance_{prop_name}'] = value[prop_name]
                print(f"PropertyKSDistance ({prop_name}): {value[prop_name]}")
        else:
            metrics_results[name] = value
            print(f"{name}: {value}")

    # Save metrics to CSV if an output path is provided.
    if args.output_csv:
        df = pd.DataFrame([metrics_results])
        df.to_csv(args.output_csv, index=False)
        print(f"Metrics saved to {args.output_csv}")


if __name__ == "__main__":
    main()
