#!/usr/bin/env python3

import matplotlib.pyplot as plt
import os
import numpy as np
import seaborn as sns
from collections import Counter
from Bio import SeqIO


def load_fasta_sequences(fasta_path):
    """Read every record of a FASTA file and return the sequences as strings.

    Args:
        fasta_path: path to the FASTA file to parse.

    Returns:
        list of ``str``, one per record, in file order.
    """
    with open(fasta_path, "r") as fasta_handle:
        return [str(entry.seq) for entry in SeqIO.parse(fasta_handle, "fasta")]


def compute_aa_frequencies(sequences):
    """Return the relative frequency of every character pooled over *sequences*.

    Each character is counted as-is (no case folding or filtering), and the
    frequencies are relative to the combined length of all sequences, so the
    values sum to 1.0 whenever at least one character is present.

    Args:
        sequences: iterable of sequence strings.

    Returns:
        dict mapping character -> frequency, ordered by first occurrence.
        Empty when the input holds no characters.
    """
    residue_counts = Counter(residue for seq in sequences for residue in seq)
    # The sum of all counts equals the combined sequence length.
    grand_total = sum(residue_counts.values())
    frequencies = {}
    for residue, count in residue_counts.items():
        frequencies[residue] = count / grand_total
    return frequencies


def _load_aa_frequencies(labelled_paths):
    """Compute amino-acid frequencies for every FASTA file that exists.

    Args:
        labelled_paths: iterable of ``(label, fasta_path)`` pairs.

    Returns:
        dict mapping label -> {amino_acid: frequency}, in the order of
        *labelled_paths*. Missing files are reported and skipped so the plot
        can still be drawn from whatever data is available.
    """
    aa_frequencies = {}
    for label, fasta_path in labelled_paths:
        if not os.path.exists(fasta_path):
            print(f"Warning: skipping {label}, file not found: {fasta_path}")
            continue
        sequences = load_fasta_sequences(fasta_path)
        aa_frequencies[label] = compute_aa_frequencies(sequences)
    return aa_frequencies


def _plot_grouped_bars(aa_frequencies, max_bar_width=0.15, max_group_width=0.9):
    """Draw one bar series per label, grouped by amino acid, on the current figure.

    The bar width is ``max_bar_width``, but it shrinks when there are so many
    series that one group would exceed ``max_group_width``. This keeps
    neighbouring groups from overlapping.

    Args:
        aa_frequencies: non-empty dict label -> {amino_acid: frequency}.
        max_bar_width: preferred width of a single bar.
        max_group_width: upper bound on the total width of one group of bars.
    """
    all_aas = sorted(set().union(*(freqs.keys() for freqs in aa_frequencies.values())))
    x = np.arange(len(all_aas))
    n_series = len(aa_frequencies)
    width = min(max_bar_width, max_group_width / n_series)
    # Shift so each group of n_series bars is centred on its tick.
    first_offset = (n_series - 1) * width / 2

    max_y = 0.0
    for i, (label, freqs) in enumerate(aa_frequencies.items()):
        frequencies = [freqs.get(aa, 0) for aa in all_aas]
        max_y = max(max_y, max(frequencies, default=0))
        plt.bar(x + i * width - first_offset,
                frequencies,
                width,
                label=label,
                alpha=0.8)

    # 5% headroom above the tallest bar; skipped when every value is zero
    # because ylim(0, 0) would give a singular axis.
    if max_y > 0:
        plt.ylim(0, max_y * 1.05)

    plt.xlabel("Amino Acid", fontsize=20, labelpad=10)
    plt.ylabel("Frequency", fontsize=20, labelpad=10)
    plt.xticks(x, all_aas, fontsize=20)
    plt.yticks(fontsize=20)

    # Stable sort: the "AMPs" reference series goes first when present, and
    # the remaining series keep their plotting order. Unlike labels.index(),
    # this does not fail when the reference data is missing.
    handles, labels = plt.gca().get_legend_handles_labels()
    order = sorted(range(len(labels)), key=lambda idx: labels[idx] != "AMPs")
    plt.legend([handles[idx] for idx in order],
               [labels[idx] for idx in order],
               loc='upper left', fontsize=14, frameon=True)


def main():
    """Plot amino-acid frequencies of reference AMPs next to generated samples.

    All input and output paths are relative to the current working directory.
    Missing input files are skipped with a warning. The grouped bar chart is
    saved to ``experiments/results/plots/amino_acid_frequencies.svg`` and then
    shown.
    """
    # Set style
    sns.set_style("whitegrid")
    sns.set_palette("Set1")  # Use Set1 color palette for distinct colors

    # Increase font sizes globally
    plt.rcParams.update({
        'font.size': 18,
        'axes.labelsize': 18,
        'axes.titlesize': 18,
        'xtick.labelsize': 18,
        'ytick.labelsize': 18,
        'legend.fontsize': 18
    })

    data_path = "experiments/results/files/generative-samples/"
    general_amps_data_path = "data/generative-model-data/AMPs.fasta"

    # The reference set is listed first so it is plotted (and legended) first.
    labelled_paths = [
        ("AMPs", general_amps_data_path),
        ("OmegAMP", os.path.join(data_path, "OmegAMP/omegamp-generated-samples.fasta")),
        ("HydrAMP", os.path.join(data_path, "HydrAMP/hydramp-generated-samples.fasta")),
        ("AMPGAN", os.path.join(data_path, "amp-gan/amp-gan-samples.fasta")),
        ("AMP-Diffusion", os.path.join(data_path, "AMP-Diffusion/amp-diffusion.fasta")),
        ("Diff-AMP", os.path.join(data_path, "Diff-AMP/diff-amp-samples.fasta")),
    ]

    aa_frequencies = _load_aa_frequencies(labelled_paths)
    if not aa_frequencies:
        print("No input FASTA files found; nothing to plot.")
        return

    plt.figure(figsize=(20, 8))
    _plot_grouped_bars(aa_frequencies)

    # Adjust layout to prevent text cutoff
    plt.tight_layout()

    output_dir = "experiments/results/plots"
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "amino_acid_frequencies.svg")
    plt.savefig(output_path, format="svg", bbox_inches="tight")

    plt.show()


# Run the plotting script only when executed directly, not when imported.
if __name__ == "__main__":
    main()