import os
import sys
import pickle
import embeddings.generate_word2vec as generate_word2vec
import embeddings.generate_fasttext as generate_fasttext
import embeddings.generate_glove as generate_glove
import embeddings.generate_omni as generate_omni
from config import EMBEDDING_PARAMS
import argparse

# The current working directory is treated as the project root: main() builds
# its data/<dataset> paths from it, and it is made importable via sys.path.
# NOTE(review): this runs after the project imports at the top of the file, so
# it cannot help those imports resolve — presumably the script is always
# launched from the project root; confirm.
env_path = os.getcwd()
sys.path.append(env_path)
print(f"Environment path: {env_path}")

def main(dataset_name=None):
    """Load a preprocessed dataset and generate word embeddings for it.

    The dataset is taken from ``dataset_name`` or, when that is ``None``,
    chosen from an interactive menu. The pickled ``X`` and ``vectorizer_X``
    objects are loaded from ``<env_path>/data/<dataset>/``. The user then
    picks an embedding method, whose module-level ``generate`` function is
    called with the matching ``EMBEDDING_PARAMS`` entry.

    Args:
        dataset_name: ``'imdb'``, ``'1billion'``, or ``None`` to prompt.

    Returns:
        None. Invalid input or missing data files print an error message and
        return early instead of raising.
    """
    # Omni TM-AE parameters are dataset-specific, so each supported dataset
    # maps to its own EMBEDDING_PARAMS key. Resolving it from db_name (rather
    # than only in the interactive branch) keeps option 4 working when the
    # dataset is passed in by the caller.
    omni_param_keys = {'imdb': 'omni_imdb', '1billion': 'omni_1billion'}

    if dataset_name is None:
        print("Select the dataset:")
        print("1. IMDB")
        print("2. One Billion Word")

        choice = input("Enter the number corresponding to your choice: ").strip()
        db_name = {'1': 'imdb', '2': '1billion'}.get(choice)
        if db_name is None:
            print("Invalid choice. Please select a valid dataset.")
            return
    elif dataset_name in omni_param_keys:
        db_name = dataset_name
    else:
        print("Invalid dataset name. Please select either 'imdb' or '1billion'.")
        return

    data_dir = os.path.join(env_path, 'data', db_name)
    dataset_path = os.path.join(data_dir, 'X.pickle')
    vectorizer_path = os.path.join(data_dir, 'vectorizer_X.pickle')

    # Load the dataset and vectorizer.
    # NOTE(review): pickle.load can execute arbitrary code; these files must
    # come only from the project's own preprocessing step, never from
    # untrusted sources.
    try:
        with open(dataset_path, 'rb') as f:
            X = pickle.load(f)
        with open(vectorizer_path, 'rb') as f:
            vectorizer_X = pickle.load(f)
    except FileNotFoundError as err:
        print(f"Error: Unable to load dataset or vectorizer ({err}).")
        return
    if vectorizer_X is None or X is None:
        print("Error: Unable to load dataset or vectorizer.")
        return
    print(f"Vectorizer contains number of features:{len(vectorizer_X.get_feature_names_out())}")

    print("Select the embedding method:")
    print("1. Word2Vec")
    print("2. FastText")
    print("3. GloVe")
    print("4. Omni TM-AE")

    # Menu choice -> (EMBEDDING_PARAMS key, generator module).
    generators = {
        '1': ('word2vec', generate_word2vec),
        '2': ('fasttext', generate_fasttext),
        '3': ('glove', generate_glove),
        '4': (omni_param_keys[db_name], generate_omni),
    }

    choice = input("Enter the number corresponding to your choice: ").strip()
    if choice not in generators:
        print("Invalid choice. Please select a valid embedding method.")
        return

    param_key, generator = generators[choice]
    # Parameters are looked up only for the chosen method, so a config missing
    # an unrelated entry does not break the others.
    generator.generate(X, vectorizer_X, EMBEDDING_PARAMS[param_key], db_name)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataset_name',
        type=str,
        required=False,
        help="Name of the dataset to use ('imdb' or '1billion'); "
             "prompts interactively if omitted",
    )
    args = parser.parse_args()

    # main() already skips its dataset prompt when a name is given, so the
    # argument is passed straight through. The builtin input() is deliberately
    # NOT overridden: main() also calls input() for the embedding-method menu,
    # and a stub returning the dataset name would answer that prompt with e.g.
    # "imdb" and make every --dataset_name run end in "Invalid choice".
    main(args.dataset_name)