import os
import sys
import pandas as pd
from evaluate_classifiers import evaluate_classifiers
import prepare_imdb
import argparse

# Put the directory the script was launched from on the import path, so that
# modules imported later at runtime can be resolved relative to it, and echo
# that directory for the run log.
env_path = os.getcwd()
sys.path.append(env_path)
print(f"Environment path: {env_path}")

def main(dataset_name=None):
    """Perturb the training text with each embedding model and log classifier scores.

    For every configured embedding model the training text is perturbed
    ``iteration`` times, the classifiers are evaluated on the perturbed data,
    and one row per iteration is appended to
    ``<db_name>_classifier_evaluation_results.csv`` in the working directory.

    Args:
        dataset_name: ``"imdb"`` or ``"1billion"``. When ``None`` the user is
            prompted interactively to pick one of the two.

    Returns:
        None. An invalid dataset name or menu choice prints a message and
        returns early without doing any work.

    Raises:
        ValueError: if ``embedding_models`` contains a name that has no
            matching ``embeddings.<name>_perturb`` module configured here.
    """
    # Local imports keep this function self-contained; the perturbation
    # modules themselves are imported lazily further down.
    import importlib
    from datetime import datetime

    if dataset_name is None:
        print("Select the dataset:")
        print("1. IMDB")
        print("2. One Billion Word")

        choice = input("Enter the number corresponding to your choice: ")
        menu = {'1': 'imdb', '2': '1billion'}
        if choice not in menu:
            print("Invalid choice. Please select a valid dataset.")
            return
        db_name = menu[choice]
    elif dataset_name in ("imdb", "1billion"):
        db_name = dataset_name
    else:
        print("Invalid dataset name. Please select either 'imdb' or '1billion'.")
        return

    csv_path = f'{db_name}_classifier_evaluation_results.csv'
    model_base_path = os.path.join("data", db_name)

    # For quick testing use: vocabulary = 5, percent_to_change = 1,
    # iteration = 1, tm_epochs = 1
    vocabulary = 20000
    percent_to_change = 5
    iteration = 5
    tm_epochs = 10

    # NOTE(review): the data always comes from prepare_imdb, whichever dataset
    # was selected above; db_name only affects csv_path and model_base_path.
    # Confirm this is intended for '1billion'.
    y_train, y_test, X_train_text, X_test_text = prepare_imdb.get(vocabulary)

    # Full set: ["word2vec", "fasttext", "glove", "omnitm", "bert", "elmo"]
    embedding_models = ["bert"]

    # Models that are given a "<model_base_path>/<name>.model" path; the
    # remaining supported models are passed an empty model path.
    file_backed_models = {"word2vec", "fasttext", "glove", "omnitm"}
    supported_models = file_backed_models | {"bert", "elmo"}

    for embedding_model in embedding_models:
        print(f"Embedding model: {embedding_model}")

        # Fail loudly on an unknown name instead of hitting an unbound `model`
        # (or silently reusing the module loaded for the previous model).
        if embedding_model not in supported_models:
            raise ValueError(f"Unsupported embedding model: {embedding_model!r}")

        # Import lazily so that only the selected model's module is loaded.
        model = importlib.import_module(f"embeddings.{embedding_model}_perturb")
        if embedding_model in file_backed_models:
            model_path = f"{model_base_path}/{embedding_model}.model"
        else:
            model_path = ""

        for i in range(iteration):
            X_train_augmented = model.perturb(
                X_train_text, y_train, model_path, percent_to_change=percent_to_change
            )
            data = (X_train_augmented, y_train, X_test_text, y_test)
            classifier_results = evaluate_classifiers(data, tm_epochs=tm_epochs)

            row = {
                "Timestamp": datetime.now().isoformat(),
                "EmbeddingModel": embedding_model,
                "Iteration": i,
            }
            row.update(classifier_results)

            # Append after every iteration so partial results survive an
            # interrupted run; the header is written only when creating the file.
            results_df = pd.DataFrame([row])
            results_df.to_csv(csv_path, mode='a', header=not os.path.exists(csv_path), index=False)

if __name__ == "__main__":
    # Command-line entry point: an optional --dataset_name selects the dataset
    # non-interactively; without it main() falls back to its interactive menu.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_name', type=str, required=False, help='Name of the dataset to use')
    args = parser.parse_args()

    # main() only calls input() when dataset_name is None, so there is no need
    # to override (and shadow) the builtin input when a name was supplied.
    main(args.dataset_name)