#!/usr/bin/env python

import argparse
import json
import os

import pandas as pd
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqIO.FastaIO import FastaWriter
from Bio.SeqRecord import SeqRecord

from project.classifiers import AMPClassifier
from project.constants import CLASSIFIER_MODELS
from project.expert_filtering import amino_based_filtering

def load_fasta_to_list(path_to_fasta):
    """Read every record of a FASTA file and return its sequences as strings.

    Record ids and descriptions are discarded; the returned list keeps the
    order in which ``SeqIO.parse`` yields the records.
    """
    with open(path_to_fasta, 'r') as handle:
        # The list is fully built before the ``with`` block closes the file.
        return [str(record.seq) for record in SeqIO.parse(handle, "fasta")]

def remove_duplicates(sequences):
    """Return *sequences* with duplicates removed, keeping first occurrences in order.

    ``dict.fromkeys`` is used instead of ``set`` because set iteration order
    for strings depends on hash randomisation (PYTHONHASHSEED). With a set, the
    same input would come out in a different order on every run, which would
    change the downstream ``seq_N`` FASTA ids and the order of tied ranks.
    """
    return list(dict.fromkeys(sequences))

def remove_training_sequences(sequences, training_sequences):
    """Return the sequences not present in *training_sequences*, preserving order.

    *training_sequences* may be any iterable of strings (list, set, generator).
    It is materialised into a set once, so each membership test is O(1)
    instead of a linear scan. This also makes the function correct for
    one-shot iterators, which repeated ``in`` checks would otherwise consume.
    """
    training_set = set(training_sequences)
    return [seq for seq in sequences if seq not in training_set]

def save_sequences_to_fasta(sequences, output_path):
    """Write *sequences* to *output_path* as FASTA, one record per sequence.

    Records are named ``seq_1``, ``seq_2``, ... in input order and carry an
    empty description. ``wrap=None`` keeps each sequence on a single line.
    The file is overwritten if it already exists.
    """
    with open(output_path, 'w') as handle:
        writer = FastaWriter(handle, wrap=None)
        records = [
            SeqRecord(Seq(sequence), id=f"seq_{index}", description="")
            for index, sequence in enumerate(sequences, start=1)
        ]
        writer.write_file(records)

def _stage_path(output_fasta, stage_suffix):
    """Return the intermediate-stage FASTA path derived from *output_fasta*.

    For example ``out.fasta`` with ``stage1_input`` -> ``out_stage1_input.fasta``.
    """
    # Insert the suffix before the extension rather than using
    # str.replace('.fasta', ...). That form leaves non-.fasta names unchanged,
    # so every stage would overwrite the final output file, and it also
    # rewrites '.fasta' occurring inside directory names.
    root, ext = os.path.splitext(output_fasta)
    return f"{root}_{stage_suffix}{ext or '.fasta'}"


def _save_stage(sequences, output_fasta, stage_suffix, label):
    """Write one intermediate filtering stage to a FASTA named after *output_fasta*."""
    if not output_fasta:
        # Without a final output path there is nothing to name the stage file
        # after; skip it instead of crashing on a None path.
        print(f"Skipping {label} stage output: no output_fasta given")
        return
    save_sequences_to_fasta(sequences, _stage_path(output_fasta, stage_suffix))
    print(f"Saved {len(sequences)} {label} sequences")


def _rank_unconditional(sequences, output_fasta, output_strain_specific_csv, save_intermediate):
    """Unconditional mode: broad-classifier filter, then rank by the other classifiers.

    Returns ``(ranked_sequences, results_df)``, or ``([], None)`` when no
    sequence reaches the ranking step.
    """
    if not sequences:
        # Don't load the broad model just to call it on an empty batch.
        print("No sequences to classify for unconditional")
        return [], None

    broad_classifier = AMPClassifier(model_path=CLASSIFIER_MODELS["broad-classifier"])
    broad_predictions = broad_classifier(sequences)
    # Keep only sequences the broad classifier labels exactly 1.
    amp_sequences = [seq for seq, is_amp in zip(sequences, broad_predictions) if is_amp == 1]
    print(f"After classifier filtering: {len(amp_sequences)} sequences")

    if save_intermediate:
        _save_stage(amp_sequences, output_fasta, 'stage3_classifier_filtered', 'classifier-filtered')

    if not amp_sequences:
        print("No sequences to classify for unconditional")
        return [], None

    results_df = pd.DataFrame({'sequence': amp_sequences})
    score_columns = []
    for classifier_name in CLASSIFIER_MODELS:
        # The broad model was applied above; the hemolytic model is excluded
        # from the rank.
        if classifier_name in ("broad-classifier", "hemolytic-classifier"):
            continue
        specific_classifier = AMPClassifier(model_path=CLASSIFIER_MODELS[classifier_name])
        results_df[classifier_name] = specific_classifier(amp_sequences)
        score_columns.append(classifier_name)

    # rank = sum of the additional classifiers' predictions. This is a count of
    # positive calls, assuming they return 0/1 labels like the broad model.
    results_df['rank'] = results_df[score_columns].sum(axis=1)
    results_df = results_df[results_df['rank'] > 0]
    print(f"After filtering out rank 0 sequences: {len(results_df)} sequences")
    # Stable sort so sequences with equal rank keep their input order.
    results_df = results_df.sort_values(by='rank', ascending=False, kind='mergesort')

    ranked_sequences = results_df['sequence'].tolist()
    if save_intermediate:
        _save_stage(ranked_sequences, output_fasta, 'stage4_rank_filtered', 'rank-filtered')

    if output_strain_specific_csv:
        results_df.to_csv(output_strain_specific_csv, index=False)
        print(f"Ranked sequences saved to {output_strain_specific_csv}")

    if output_fasta:
        save_sequences_to_fasta(ranked_sequences, output_fasta)
        print(f"Filtered and saved {len(ranked_sequences)} sequences to {output_fasta}")

    return ranked_sequences, results_df


def _score_with_specific_classifier(sequences, strain_species, output_fasta, output_strain_specific_csv):
    """Strain/species mode: keep sequences whose ``predict_proba`` score exceeds 0.5.

    Returns ``(amp_sequences, results_df)``. ``amp_sequences`` keeps input
    order. ``results_df`` holds every input sequence with its score, sorted by
    descending score. Returns ``([], None)`` when *sequences* is empty.
    """
    # Resolve the model path first, so an unknown strain/species still fails
    # on lookup even when there is nothing to score.
    model_path = CLASSIFIER_MODELS[f"{strain_species}-classifier"]
    if not sequences:
        print(f"No sequences to classify for {strain_species}")
        return [], None

    specific_classifier = AMPClassifier(model_path=model_path)
    probability_scores = specific_classifier.predict_proba(sequences)

    amp_sequences = [seq for seq, prob in zip(sequences, probability_scores) if prob > 0.5]
    print(f"After classifier filtering: {len(amp_sequences)} sequences")

    results_df = pd.DataFrame({
        'sequence': sequences,
        'prediction_score': probability_scores
    })
    # Stable sort so equal scores keep their input order.
    results_df = results_df.sort_values(by='prediction_score', ascending=False, kind='mergesort')

    if output_strain_specific_csv:
        results_df.to_csv(output_strain_specific_csv, index=False)
        print(f"Predictions and sequences saved to {output_strain_specific_csv}")

    if output_fasta:
        save_sequences_to_fasta(amp_sequences, output_fasta)
        print(f"Filtered and saved {len(amp_sequences)} sequences to {output_fasta}")

    return amp_sequences, results_df


def main(path_to_fasta, path_to_training_set, max_length, min_length, strain_species, output_fasta, output_strain_specific_csv, input_sequences=None, save_intermediate=False):
    """Filter generated peptide sequences through sequence checks and classifiers.

    Pipeline:
    1. Drop sequences containing 'X'.
    2. De-duplicate, keeping first occurrences.
    3. Drop exact matches to the training set.
    4. Keep lengths in ``[min_length, max_length]``.
    5. Apply classifiers:

       * ``strain_species == 'unconditional'``: keep sequences the broad
         classifier labels 1. Score them with every other classifier in
         CLASSIFIER_MODELS except the broad and hemolytic ones, sum those
         predictions into ``rank``, and keep rank > 0 sorted by descending rank.
       * otherwise: score each sequence with the model at
         ``CLASSIFIER_MODELS[f"{strain_species}-classifier"]`` via
         ``predict_proba`` and keep those scoring above 0.5.

    Args:
        path_to_fasta: FASTA of generated sequences; read only when
            ``input_sequences`` is None.
        path_to_training_set: FASTA of training sequences to exclude.
        max_length: Inclusive upper length bound.
        min_length: Inclusive lower length bound.
        strain_species: 'unconditional' or a strain/species classifier prefix.
        output_fasta: Final FASTA path (falsy to skip writing). It is also the
            base name for intermediate stage files.
        output_strain_specific_csv: CSV path for the ranked/scored table
            (falsy to skip writing).
        input_sequences: Optional in-memory sequences used instead of the FASTA.
        save_intermediate: Also write a FASTA after each filtering stage.

    Returns:
        ``(sequences, results_df)``, or ``([], None)`` when nothing is left to
        classify.
    """
    if input_sequences is None:
        input_sequences = load_fasta_to_list(path_to_fasta)
        print(f"Loaded {len(input_sequences)} sequences from {path_to_fasta}")

    # Stage 1 - raw input sequences.
    if save_intermediate:
        _save_stage(input_sequences, output_fasta, 'stage1_input', 'input')

    # Remove sequences containing the unknown amino acid 'X'.
    valid_sequences = [seq for seq in input_sequences if 'X' not in seq]
    print(f"After removing sequences with X: {len(valid_sequences)} sequences")

    unique_sequences = remove_duplicates(valid_sequences)
    print(f"After removing duplicates: {len(unique_sequences)} sequences")

    training_sequences = load_fasta_to_list(path_to_training_set)
    print(f"Loaded {len(training_sequences)} sequences from {path_to_training_set}")

    # Pass a set so each membership test is O(1).
    novel_sequences = remove_training_sequences(unique_sequences, set(training_sequences))
    print(f"After removing sequences in the training set: {len(novel_sequences)} sequences")

    # A plain comprehension instead of a pandas .str.len() mask, which can fail
    # when the frame is empty (its column may not be string-typed).
    length_filtered_sequences = [
        seq for seq in novel_sequences if min_length <= len(seq) <= max_length
    ]
    print(f"After length filtering: {len(length_filtered_sequences)} sequences")

    # Stage 2 - after length filtering.
    if save_intermediate:
        _save_stage(length_filtered_sequences, output_fasta, 'stage2_length_filtered', 'length-filtered')

    if strain_species == 'unconditional':
        return _rank_unconditional(
            length_filtered_sequences, output_fasta, output_strain_specific_csv, save_intermediate
        )
    return _score_with_specific_classifier(
        length_filtered_sequences, strain_species, output_fasta, output_strain_specific_csv
    )

if __name__ == "__main__":
    # Every choice other than 'unconditional' is looked up by main() as
    # CLASSIFIER_MODELS[f"{choice}-classifier"].
    strain_species_choices = [
        'unconditional',
        'species-acinetobacterbaumannii',
        'species-escherichiacoli',
        'species-klebsiellapneumoniae',
        'species-pseudomonasaeruginosa',
        'species-staphylococcusaureus',
        'strains-acinetobacterbaumannii-atcc19606',
        'strains-escherichiacoli-atcc25922',
        'strains-klebsiellapneumoniae-atcc700603',
        'strains-pseudomonasaeruginosa-atcc27853',
        'strains-staphylococcusaureus-atcc25923',
        'strains-staphylococcusaureus-atcc33591',
        'strains-staphylococcusaureus-atcc43300',
    ]

    cli = argparse.ArgumentParser(
        description='Filter generated sequences based on classifiers and save to a FASTA file.'
    )
    cli.add_argument(
        '--path_to_fasta', type=str,
        default='results/generative-model-results/omegamp-generated-samples.fasta',
        help='Path to the FASTA file with generated sequences',
    )
    cli.add_argument(
        '--path_to_training_set', type=str,
        default='data/generative-model-data/AMPs.fasta',
        help='Path to the training set FASTA file (data/generative-model-data/AMPs.fasta)',
    )
    cli.add_argument('--max_length', type=int, default=30, help='Maximum length of the sequences')
    cli.add_argument('--min_length', type=int, default=6, help='Minimum length of the sequences')
    cli.add_argument(
        '--strain_species', type=str, choices=strain_species_choices,
        default='unconditional', help='Strain or species to condition on',
    )
    cli.add_argument(
        '--output_fasta', type=str,
        default='results/framework-results/script-filtered-sequences.fasta',
        help='Path to save the filtered sequences in FASTA format',
    )
    cli.add_argument(
        '--output_strain_specific_csv', type=str,
        default='results/framework-results/script-ranked-sequences.csv',
        help='Path to save the rank sequences in csv format',
    )
    cli.add_argument('--save_intermediate', action='store_true', help='Save sequences after each filtering stage')

    parsed = cli.parse_args()

    # Option destinations match main()'s parameter names one-to-one.
    main(**vars(parsed))
