#!/usr/bin/env python

import argparse
import os
import pandas as pd
from project.data import load_fasta_to_df
from project.classifiers import AMPClassifier, HemolyticClassifier
from project.synthetic_data import generate_synthetic_sequences
from project.constants import HQ_AMPs_FILE, HEMOLYTICS_FILE, CLASSIFIER_MODELS

def main(classifier_name, number_of_evaluated_sequences, mutations, additions):
    """Print the fraction of predicted positives on synthetic sequence sets.

    For each selected classifier, four synthetic sets (random, shuffled,
    mutated, added/deleted) are generated from a reference FASTA file. Each set
    is scored, and the mean prediction is printed per set.

    Args:
        classifier_name: A key of ``CLASSIFIER_MODELS``, or ``"all"`` to
            evaluate every configured classifier.
        number_of_evaluated_sequences: Forwarded to
            ``generate_synthetic_sequences``.
        mutations: Forwarded to ``generate_synthetic_sequences``.
        additions: Forwarded to ``generate_synthetic_sequences``.

    Raises:
        ValueError: If ``classifier_name`` is neither ``"all"`` nor a key of
            ``CLASSIFIER_MODELS``.
    """
    # Validate before doing any (potentially expensive) sequence generation.
    if classifier_name == "all":
        classifiers_to_evaluate = list(CLASSIFIER_MODELS.keys())
    elif classifier_name in CLASSIFIER_MODELS:
        classifiers_to_evaluate = [classifier_name]
    else:
        available = ", ".join(sorted(CLASSIFIER_MODELS))
        raise ValueError(
            f"Unknown classifier {classifier_name!r}; "
            f"expected 'all' or one of: {available}"
        )

    for name in classifiers_to_evaluate:
        print(f"\n=== Evaluating {name} ===")

        # The hemolytic classifier is probed with variants of the hemolytic
        # dataset; every other configured model is treated as an AMP classifier
        # and probed with variants of the HQ AMP dataset.
        if name == "hemolytic-classifier":
            source_file, classifier_cls = HEMOLYTICS_FILE, HemolyticClassifier
        else:
            source_file, classifier_cls = HQ_AMPs_FILE, AMPClassifier

        (
            random_sequences,
            shuffled_sequences,
            mutated_sequences,
            added_deleted_sequences,
        ) = generate_synthetic_sequences(
            source_file,
            number_of_evaluated_sequences,
            mutations,
            additions,
        )
        classifier = classifier_cls(model_path=CLASSIFIER_MODELS[name])

        labelled_sets = [
            ("random sequences", random_sequences),
            ("shuffled sequences", shuffled_sequences),
            ("mutated sequences", mutated_sequences),
            ("sequences with addition and deletion", added_deleted_sequences),
        ]
        # Score every set before printing, matching the original call order.
        # NOTE(review): assumes the classifier returns an array-like of 0/1
        # labels, so .mean() is the positive fraction — confirm.
        scored = [(label, classifier(seqs)) for label, seqs in labelled_sets]
        for label, predictions in scored:
            print(f"Fraction of predicted positives for {label}: {predictions.mean():.3f}")

if __name__ == "__main__":
    # Command-line entry point: parse options and run the evaluation.
    parser = argparse.ArgumentParser(
        description=(
            'Evaluate AMP/hemolytic classifiers by reporting the fraction of '
            'predicted positives on synthetic (random, shuffled, mutated, '
            'added/deleted) sequences.'
        )
    )
    parser.add_argument('--classifier_name', type=str, default='broad-classifier',
                        # Reject unknown names at parse time instead of failing later.
                        choices=['all', *CLASSIFIER_MODELS],
                        help='Classifier to evaluate. Use "all" to evaluate all available classifiers')
    parser.add_argument('--number_of_evaluated_sequences', type=int, default=5000, help='Number of sequences to evaluate')
    parser.add_argument('--mutations', type=int, default=5, help='Number of mutations per sequence (only for mutated mode)')
    parser.add_argument('--additions', type=int, default=5, help='Number of additions per sequence (only for added-deleted mode)')
    args = parser.parse_args()

    main(args.classifier_name, args.number_of_evaluated_sequences, args.mutations, args.additions)
