#!/usr/bin/env python

import argparse
import os
import pandas as pd
from project.classifiers import AMPClassifier, HemolyticClassifier
from project.data import load_fasta_to_df
from project.constants import CLASSIFIER_MODELS

def _build_classifier(name, model_path):
    """Instantiate the classifier registered as *name* and switch it to eval mode.

    ``'hemolytic-classifier'`` is loaded as a :class:`HemolyticClassifier`;
    every other entry is loaded as an :class:`AMPClassifier`.
    """
    model_cls = HemolyticClassifier if name == 'hemolytic-classifier' else AMPClassifier
    model = model_cls(model_path=model_path)
    model.eval()
    return model


def _save_predictions(df, output_csv):
    """Write *df* to *output_csv* (no index), creating the parent directory if needed."""
    out_dir = os.path.dirname(output_csv)
    # The CLI default points into a nested results/ directory; to_csv does not
    # create missing directories, so make sure it exists first.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(output_csv, index=False)


def main(fasta_file, classifier_choice, output_csv, predict_proba):
    """Run one classifier (or all of them) on the sequences in a FASTA file.

    Args:
        fasta_file: Path to the FASTA file; it is loaded with ``load_fasta_to_df``
            and its ``"Sequence"`` column is used as model input.
        classifier_choice: A key of ``CLASSIFIER_MODELS``, or ``"all"`` to run
            every registered classifier.
        output_csv: Destination CSV path. Missing parent directories are created.
        predict_proba: If true, request probabilities instead of predictions.

    Returns:
        The loaded DataFrame with the predictions added. A single classifier
        adds a ``"Prediction"`` column. ``"all"`` returns a copy with one
        column per classifier name.

    Raises:
        ValueError: If *classifier_choice* is neither a known classifier nor ``"all"``.
    """
    # Validate before parsing the FASTA so a typo fails fast.
    if classifier_choice not in CLASSIFIER_MODELS and classifier_choice != "all":
        raise ValueError(f"Classifier {classifier_choice} not found. Available classifiers: {', '.join(CLASSIFIER_MODELS.keys())}")

    sequences = load_fasta_to_df(fasta_file)
    sequence_list = sequences["Sequence"].to_list()

    # A registered name takes precedence over the special "all" value.
    if classifier_choice in CLASSIFIER_MODELS:
        model = _build_classifier(classifier_choice, CLASSIFIER_MODELS[classifier_choice])
        sequences["Prediction"] = (
            model.predict_proba(sequence_list) if predict_proba else model(sequence_list)
        )

        _save_predictions(sequences, output_csv)
        print(f"Predictions for {len(sequences)} sequences using classifier {classifier_choice} saved to {output_csv}")
        return sequences

    # classifier_choice == "all": collect one prediction column per classifier.
    all_predictions = sequences.copy()

    # AMP features are computed once and reused by every non-hemolytic model.
    amp_features = AMPClassifier(model_path=None).get_input_features(sequence_list)

    for classifier, model_path in CLASSIFIER_MODELS.items():
        print(f"Running classifier: {classifier}")
        model = _build_classifier(classifier, model_path)

        if classifier == 'hemolytic-classifier':
            predictions = model.predict_proba(sequence_list) if predict_proba else model(sequence_list)
        else:
            predictions = model.predict_from_features(amp_features, proba=predict_proba)
        all_predictions[classifier] = predictions

    _save_predictions(all_predictions, output_csv)
    print(f"Predictions for {len(sequences)} sequences using all classifiers saved to {output_csv}")
    return all_predictions


if __name__ == "__main__":
    # Command-line entry point: parse options and hand them to main().
    cli = argparse.ArgumentParser(
        description='Load a classifier model or ensemble and run inference on protein sequences.'
    )
    cli.add_argument(
        'fasta_file',
        type=str,
        help='Path to the fasta file with sequences to predict',
    )
    cli.add_argument(
        '--classifier',
        type=str,
        default='broad-classifier',
        help='Specify a classifier or "all" to run all classifiers',
    )
    cli.add_argument(
        '--output_csv',
        type=str,
        default='results/classifier-results/script-classifier-results.csv',
        help='Path to the output CSV file for predictions',
    )
    cli.add_argument(
        '--predict_proba',
        action='store_true',
        help='Output probabilities instead of predictions',
    )

    opts = cli.parse_args()
    main(
        fasta_file=opts.fasta_file,
        classifier_choice=opts.classifier,
        output_csv=opts.output_csv,
        predict_proba=opts.predict_proba,
    )
