#!/usr/bin/env python

import argparse
import pandas as pd
from Bio import SeqIO
import torch
from hydra import compose, initialize
from project.conditioning import ConditioningMasking
from project.config import load_conditioning
from project.metrics import ConditioningPropertiesMetricsNormalizedMAE



def load_fasta_to_list(path_to_fasta, max_length):
    """Read a FASTA file and return its sequences as plain strings.

    Args:
        path_to_fasta: Path of the FASTA file to read.
        max_length: Upper bound (inclusive) on sequence length; records with
            a longer sequence are skipped.

    Returns:
        A list of sequence strings, in the order the parser yields them.
    """
    with open(path_to_fasta, 'r') as handle:
        # Convert lazily so each record is stringified once, then filtered.
        as_text = (str(entry.seq) for entry in SeqIO.parse(handle, "fasta"))
        # The list is fully built before the file handle is closed.
        return [text for text in as_text if len(text) <= max_length]


def main():
    """Compute the normalized-MAE conditioning metric for generated sequences.

    Steps:
      1. Parse command-line arguments.
      2. Load sequences from the FASTA file, keeping only those whose length
         lies in ``[min_length, max_length]``.
      3. Build the conditioning objects from the Hydra ``train`` config.
      4. Load the conditioning vectors from ``--path_to_conditioning``.
      5. Evaluate ``ConditioningPropertiesMetricsNormalizedMAE`` and print
         the result; when ``--output_csv`` is given, also write it as a
         one-row CSV file.
    """
    parser = argparse.ArgumentParser(description='Compute metrics for generated sequences.')
    parser.add_argument('--path_to_fasta', type=str, default='results/generative-model-results/subset-hq-conditional-samples.fasta', help='Path to the FASTA file with generated sequences')
    parser.add_argument('--path_to_conditioning', type=str, default='results/generative-model-results/subset-hq-conditional-conditioning.pt', help='Path to the file with conditioning vectors')
    # Fixed: the help text previously said "Maximum" for the minimum bound.
    parser.add_argument('--min_length', type=int, default=1, help='Minimum length for sequences')
    parser.add_argument('--max_length', type=int, default=100, help='Maximum length for sequences')
    parser.add_argument('--output_csv', type=str, default=None, help='Path to save metrics results as CSV file')

    args = parser.parse_args()

    # Load generated sequences from FASTA. The loader already drops sequences
    # longer than max_length, so only the lower bound is applied here.
    generated_sequences = load_fasta_to_list(args.path_to_fasta, args.max_length)
    generated_sequences = [seq for seq in generated_sequences if len(seq) >= args.min_length]
    # NOTE(review): sequences are filtered by length but the conditioning
    # vectors loaded below are not. If the two are expected to be aligned
    # one-to-one, filtering may misalign them -- confirm against the metric.

    # Hydra resolves config_path relative to this file, not the working
    # directory.
    with initialize(version_base=None, config_path="../../../config"):
        config = compose(config_name="train")
        conditioning = load_conditioning(config)
        conditioning_masking = ConditioningMasking(conditioning.computable_names, conditioning.uncomputable_names)

    # Load conditioning vectors onto the CPU.
    # NOTE(review): depending on the torch version / weights_only setting,
    # torch.load may unpickle arbitrary objects -- only load trusted files.
    conditioning_used = torch.load(args.path_to_conditioning, map_location=torch.device('cpu'))

    # Initialize metrics
    conditioning_properties_metrics_normalized_mae = ConditioningPropertiesMetricsNormalizedMAE(conditioning, conditioning_masking)

    # Compute metrics
    metrics_results = {}
    metrics_results['ConditioningPropertiesMetricsNormalizedMAE'] = conditioning_properties_metrics_normalized_mae(generated_sequences, conditioning_used)
    print(metrics_results)

    # Save metrics to CSV if output path is provided
    if args.output_csv:
        df = pd.DataFrame([metrics_results])
        df.to_csv(args.output_csv, index=False)
        print(f"Metrics saved to {args.output_csv}")
    
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
