import os
import sys
import pandas as pd
import prepare_datasets
import prepare_datasets
import argparse
import pickle
import evaluation_similarity as eval
from config import EMBEDDING_PARAMS
import tqdm

# Directory the script is launched from. ``main`` resolves ``data/`` beneath
# it, and it is appended to ``sys.path`` so that modules living there (e.g. the
# ``embeddings`` package imported lazily inside ``main``) can be imported.
# NOTE(review): the append runs after this file's own top-level imports, so it
# has no effect on them.
env_path = os.getcwd()
sys.path.append(env_path)
print(f"Environment path: {env_path}")

def _resolve_dataset_name(dataset_name):
    """Return a validated dataset name, or None after printing why it is invalid.

    When ``dataset_name`` is None the user is prompted to pick one.
    """
    if dataset_name is None:
        print("Select the dataset:")
        print("1. IMDB")
        print("2. One Billion Word")

        choice = input("Enter the number corresponding to your choice: ")
        db_name = {"1": "imdb", "2": "1billion"}.get(choice)
        if db_name is None:
            print("Invalid choice. Please select a valid dataset.")
        return db_name
    if dataset_name not in ("imdb", "1billion"):
        print("Invalid dataset name. Please select either 'imdb' or '1billion'.")
        return None
    return dataset_name


def _load_tokenized_sentences(X, vectorizer_X, data_dir):
    """Return X's rows as lists of vocabulary words, cached as a pickle in data_dir."""
    tokenized_sentences_path = os.path.join(data_dir, "tokenized_sentences.pickle")
    if os.path.exists(tokenized_sentences_path):
        print("Loading tokenized sentences...")
        with open(tokenized_sentences_path, "rb") as f:
            return pickle.load(f)

    print("Tokenizing sentences...")
    reverse_vocab = {index: word for word, index in vectorizer_X.vocabulary_.items()}
    # NOTE(review): assumes X is a SciPy CSR matrix. ``row.indices`` holds the
    # column indices of the row's non-zero entries, so each "sentence" is the
    # set of vocabulary words present; word order and repeats are not recovered.
    tokenized_sentences = [
        [reverse_vocab[index] for index in row.indices]
        for row in tqdm.tqdm(X, desc="Tokenizing sentences", unit="sentence", total=X.shape[0])
    ]
    with open(tokenized_sentences_path, "wb") as f:
        pickle.dump(tokenized_sentences, f, protocol=4)
    return tokenized_sentences


def _collect_dataset_words(similarity_datasets, data_base_path):
    """Concatenate the words listed in each similarity dataset CSV."""
    words = []
    for dataset in similarity_datasets:
        dataset_path = os.path.join(data_base_path, dataset + ".csv")
        words.extend(prepare_datasets.get_dataset_words(dataset_path))
    return words


def _build_word_vectors(embedding_model, use_pretrained, data_dir, tokenized_sentences,
                        X, vectorizer_X, omni_parameters, omni_words):
    """Import the module for ``embedding_model`` and build its word vectors.

    Returns:
        ``(module, parameters, word_vectors)``.

    Raises:
        ValueError: if ``embedding_model`` is not a supported model name.
    """
    if embedding_model == "word2vec":
        import embeddings.word2vec_similarity as model
    elif embedding_model == "fasttext":
        import embeddings.fasttext_similarity as model
    elif embedding_model == "glove":
        import embeddings.glove_similarity as model
    elif embedding_model == "omnitm":
        import embeddings.omnitm_similarity as model
        # Pass a copy so a callee that mutates the list cannot affect the
        # word list reused on later iterations.
        word_vectors = model.build_embedding(
            X, vectorizer_X, omni_parameters, list(omni_words), use_pretrained, data_dir
        )
        return model, omni_parameters, word_vectors
    else:
        raise ValueError(f"Unknown embedding model: {embedding_model}")

    parameters = EMBEDDING_PARAMS[embedding_model]
    word_vectors = model.build_embedding(tokenized_sentences, parameters, use_pretrained, data_dir)
    return model, parameters, word_vectors


def _evaluate_on_datasets(model, embedding_model, word_vectors, parameters,
                          vectorizer_X, similarity_datasets, data_base_path):
    """Score ``word_vectors`` on each similarity dataset.

    Returns:
        ``{dataset: {"spearman": ..., "kendal": ...}}`` holding the two values
        returned by ``eval.calculate``.

    Raises:
        FileNotFoundError: if a dataset CSV is missing.
        ValueError: if a dataset yields no pairs, no active words, or no
            similarities.
    """
    dataset_results = {}
    for dataset in similarity_datasets:
        dataset_path = os.path.join(data_base_path, dataset + ".csv")
        if not os.path.exists(dataset_path):
            raise FileNotFoundError(f"Dataset path does not exist: {dataset_path}")
        pair_list = prepare_datasets.get_dataset_pairs(dataset_path)
        if not pair_list:
            raise ValueError(f"Pair list is empty for dataset: {dataset}")
        output_active, target_words = prepare_datasets.get_active_target(vectorizer_X, pair_list)
        if not output_active.size:
            raise ValueError(f"Output active is empty for dataset: {dataset}")
        print(f"Dataset: {dataset}, Number of words: {len(target_words)}")

        # omnitm's get() takes the vectorizer as its second argument; the other
        # embedding modules take their parameter dict instead.
        if embedding_model == "omnitm":
            target_similarity = model.get(word_vectors, vectorizer_X, target_words)
        else:
            target_similarity = model.get(word_vectors, parameters, target_words)
        # NOTE(review): assumes get() returns a container with unambiguous
        # truthiness (e.g. dict/list); a NumPy array here would raise.
        if not target_similarity:
            raise ValueError(f"Target similarity is empty for dataset: {dataset}")
        spearman, kendal = eval.calculate(target_similarity, pair_list)
        dataset_results[dataset] = {"spearman": spearman, "kendal": kendal}
    return dataset_results


def main(dataset_name=None):
    """Build word embeddings and evaluate them on word-similarity benchmarks.

    The work is repeated ``iteration`` times. In each repetition, every model in
    ``embedding_models`` builds word vectors from the data in
    ``<cwd>/data/<db_name>``, with ``use_pretrained`` passed through to each
    embedding module. The vectors are then scored on every similarity dataset.
    One row of Spearman/Kendall results per (iteration, model) is appended to
    ``<db_name>_similarities_evaluation_results.csv`` in the working directory.

    Args:
        dataset_name: ``"imdb"`` or ``"1billion"``. When None, the user is
            prompted to choose interactively.

    Returns:
        None. Returns early after printing a message on invalid input or when
        a pickle loads as None.

    Raises:
        FileNotFoundError: if a similarity dataset CSV is missing.
        ValueError: if a dataset yields no pairs, no active words or no
            similarities, or if an embedding model name is unknown.
    """
    from datetime import datetime

    db_name = _resolve_dataset_name(dataset_name)
    if db_name is None:
        return
    # Resolved for both the interactive and the argument-driven path; the
    # "omnitm" model needs it either way.
    omni_parameters = EMBEDDING_PARAMS[f"omni_{db_name}"]

    data_dir = os.path.join(env_path, "data", db_name)
    dataset_path = os.path.join(data_dir, "X.pickle")
    vectorizer_path = os.path.join(data_dir, "vectorizer_X.pickle")

    # NOTE(review): unpickling can execute arbitrary code; only load trusted
    # local files here.
    with open(vectorizer_path, "rb") as f:
        vectorizer_X = pickle.load(f)
    if vectorizer_X is None:
        print("Error: Unable to load vectorizer.")
        return
    print(f"Vectorizer contains number of features:{len(vectorizer_X.get_feature_names_out())}")

    # Relative path: results accumulate in the current working directory.
    csv_path = f"{db_name}_similarities_evaluation_results.csv"
    data_base_path = os.path.join(env_path, "data")

    # Experiment configuration.
    iteration = 3
    use_pretrained = False
    embedding_models = ["word2vec", "fasttext", "glove", "omnitm"]
    similarity_datasets = ["rg-65", "wordsim353-sim", "mturk-287", "mturk-771", "men", "simlex999"]

    if use_pretrained:
        X = None
        tokenized_sentences = None
    else:
        with open(dataset_path, "rb") as f:
            X = pickle.load(f)
        if X is None:
            print("Error: Unable to load dataset.")
            return
        tokenized_sentences = _load_tokenized_sentences(X, vectorizer_X, data_dir)

    # The omnitm word list depends only on the similarity datasets, so build
    # it once rather than on every iteration.
    omni_words = (
        _collect_dataset_words(similarity_datasets, data_base_path)
        if "omnitm" in embedding_models
        else []
    )

    for i in range(iteration):
        for embedding_model in embedding_models:
            print(f"Embedding model: {embedding_model}")
            model, parameters, word_vectors = _build_word_vectors(
                embedding_model, use_pretrained, data_dir, tokenized_sentences,
                X, vectorizer_X, omni_parameters, omni_words,
            )
            dataset_results = _evaluate_on_datasets(
                model, embedding_model, word_vectors, parameters,
                vectorizer_X, similarity_datasets, data_base_path,
            )

            row = {
                "Timestamp": datetime.now().isoformat(),
                "EmbeddingModel": embedding_model,
                "Iteration": i,
            }
            for dataset in similarity_datasets:
                row[f"{dataset}_Spearman"] = dataset_results[dataset]["spearman"]
                row[f"{dataset}_Kendal"] = dataset_results[dataset]["kendal"]

            # Write the header only when the file is first created so that
            # repeated runs append rows under a single header.
            pd.DataFrame([row]).to_csv(
                csv_path, mode="a", header=not os.path.exists(csv_path), index=False
            )

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_name', type=str, required=False, help='Name of the dataset to use')
    args = parser.parse_args()

    # main() only calls input() when no dataset name is given, so there is no
    # need to override the input builtin here.
    main(args.dataset_name)