#!/usr/bin/env python

import argparse
import os
import pandas as pd
from project.classifiers import AMPClassifier, HemolyticClassifier, PeptideClassifier
from project.constants import CLASSIFIER_MODELS, CLASSIFIER_DATASETS, HEMOLYTICS_FILE, HQ_AMPs_FILE
from project.data import get_dataset_for_activity_classifier, get_input_features_labels_mask_high_quality_idxs

def _drop_ids_containing(df, substrings):
    """Return ``df`` without the rows whose ``Id`` contains any of ``substrings``.

    The match is a literal substring test (``regex=False``). Rows whose ``Id``
    is missing are kept: ``na=False`` prevents a NaN-valued boolean mask,
    which pandas cannot use for row selection.
    """
    for substring in substrings:
        df = df[~df['Id'].str.contains(substring, regex=False, na=False)]
    return df


def _evaluate_classifier(classifier_name, activity_df, toxicity_df, secret_data_df, weight_balancing):
    """Build the dataset for ``classifier_name`` and run k-fold cross-validation on it.

    ``'hemolytic-classifier'`` is trained on ``toxicity_df`` with a
    ``HemolyticClassifier`` and ``HEMOLYTICS_FILE`` as reference. Every other
    name is treated as an activity classifier: its dataset is derived from
    ``activity_df`` (and ``secret_data_df`` when not None), it uses an
    ``AMPClassifier``, and ``HQ_AMPs_FILE`` is its reference.
    """
    if classifier_name == 'hemolytic-classifier':
        dataset = toxicity_df
        model = HemolyticClassifier(model_path=None)
        reference_file = HEMOLYTICS_FILE
    else:
        dataset = get_dataset_for_activity_classifier(classifier_name, activity_df, secret_data_df=secret_data_df)
        model = AMPClassifier(model_path=None)
        reference_file = HQ_AMPs_FILE
    input_features, labels, mask_high_quality_idxs = get_input_features_labels_mask_high_quality_idxs(dataset)
    model.eval_with_k_fold_cross_validation(
        input_features,
        labels,
        weight_balancing=weight_balancing,
        mask_high_quality_idxs=mask_high_quality_idxs,
        reference_file=reference_file,
    )


def main(no_random_data, no_shuffled_data, no_mutated_data, with_secret_data, weight_balancing, classifier):
    """Run k-fold cross-validation for one classifier, or for every classifier.

    Args:
        no_random_data: Drop rows whose ``Id`` contains ``'random'``.
        no_shuffled_data: Drop rows whose ``Id`` contains ``'shuffled'``.
        no_mutated_data: Drop rows whose ``Id`` contains ``'mutated'``.
        with_secret_data: Also load ``CLASSIFIER_DATASETS["activity-secret"]``
            and pass it to the activity-classifier dataset builder.
        weight_balancing: Weighting scheme forwarded to
            ``eval_with_k_fold_cross_validation`` for every classifier run.
        classifier: A classifier name, ``'hemolytic-classifier'``, or
            ``'all'`` to iterate over ``CLASSIFIER_MODELS``.

    The Id filters apply to every dataset that is loaded, including the
    toxicity dataset when it is reached through ``'all'``.
    """
    # Substrings whose matching Ids are excluded, selected by the CLI flags.
    id_filters = [
        substring
        for substring, drop in (
            ('random', no_random_data),
            ('shuffled', no_shuffled_data),
            ('mutated', no_mutated_data),
        )
        if drop
    ]
    run_all = classifier == 'all'
    classifier_names = list(CLASSIFIER_MODELS) if run_all else [classifier]

    def load_toxicity():
        # Ids are forced to str so the .str accessor is usable on this dataset.
        raw = pd.read_csv(CLASSIFIER_DATASETS["toxicity"], dtype={'Id': str})
        return _drop_ids_containing(raw, id_filters)

    if classifier == 'hemolytic-classifier':
        activity_df = None
        secret_data_df = None  # only activity classifiers consume secret data
        toxicity_df = load_toxicity()
        print(f"Loaded dataset with {len(toxicity_df)} sequences")
    else:
        activity_df = _drop_ids_containing(pd.read_csv(CLASSIFIER_DATASETS["activity"]), id_filters)
        secret_data_df = pd.read_csv(CLASSIFIER_DATASETS["activity-secret"]) if with_secret_data else None
        # In 'all' mode the hemolytic classifier may also be run; load its data once.
        toxicity_df = load_toxicity() if 'hemolytic-classifier' in classifier_names else None
        print(f"Loaded dataset with {len(activity_df)} sequences")

    for classifier_name in classifier_names:
        if run_all:
            print(f"Running classifier: {classifier_name}")
        _evaluate_classifier(classifier_name, activity_df, toxicity_df, secret_data_df, weight_balancing)
        if run_all:
            print(f"Finished running classifier: {classifier_name}")

if __name__ == "__main__":
    cli = argparse.ArgumentParser(description='Train and validate an XGBoost classifier for AMPs using provided datasets.')
    # On/off switches controlling which data subsets are used for training.
    for option, help_text in (
        ('--no_random_data', 'Do not use random data for training'),
        ('--no_shuffled_data', 'Do not use shuffled data for training'),
        ('--no_mutated_data', 'Do not use mutated data for training'),
        ('--with_secret_data', 'Use secret data for training'),
    ):
        cli.add_argument(option, action='store_true', help=help_text)
    cli.add_argument('--weight_balancing', type=str, default="balanced_with_adjustment_for_high_quality", help='Specify weights for training')
    cli.add_argument('--classifier', type=str, default='broad-classifier', help='Specify a classifier or "all" to run all classifiers')

    # Each parsed option's dest matches a main() parameter name exactly,
    # so the namespace can be splatted as keyword arguments.
    main(**vars(cli.parse_args()))
