import os
import sys
import pandas as pd
import prepare_datasets
import prepare_clusterizer
import argparse

# Directory the script was launched from; it is echoed and then added to the end
# of the module search path so modules located there can be imported at runtime.
# NOTE(review): this runs *after* `prepare_datasets` and `prepare_clusterizer` are
# imported at the top of the file, so it cannot influence how those two modules
# are found — confirm whether it is still needed or should move above them.
env_path = os.getcwd()
print("Environment path: " + env_path)
sys.path.extend((env_path,))

def main(dataset_name = None):
    if dataset_name == None:
        print("Select the dataset:")
        print("1. IMDB")
        print("2. One Billion Word")
        
        choice = input("Enter the number corresponding to your choice: ")
        if choice == '1':
            db_name = 'imdb'
        elif choice == '2':
            db_name = '1billion'
        else:
            print("Invalid choice. Please select a valid dataset.")
            return
    else:
        if dataset_name != "imdb" and dataset_name != "1billion":
            print("Invalid dataset name. Please select either 'imdb' or '1billion'.")
            return
        db_name = dataset_name
        
    csv_path = f'{db_name}_clustering_evaluation_results.csv'
    model_base_path = os.path.join("data", db_name)
    
    iteration = 5

    embedding_models = ["word2vec", "fasttext", "glove", "omnitm"]
    # embedding_models = ["omnitm"]
    labeled_datasets = ["20newsgroups", "reuters", "yelp", "amazon", "ag_news"]

    for embedding_model in embedding_models:
        clustering = prepare_clusterizer.get(model_base_path, embedding_model)
        
        for i in range(iteration):
            dataset_results = {}
            for dataset in labeled_datasets:
                print(f"Dataset: {dataset}")
                documents, labels = prepare_datasets.get(dataset, "data")
                document_embeddings = clustering.compute_embeddings(documents)
                
                if embedding_model == "omnitm":
                    from sklearn.decomposition import PCA
                    from sklearn.preprocessing import normalize
                    
                    pca = PCA(n_components=100) 
                    reduced_embeddings = pca.fit_transform(document_embeddings)
                    print(f"Reduced shape: {reduced_embeddings.shape}")
                    normalized_embeddings = normalize(reduced_embeddings, axis=1)
                    print(f"Normalized document embeddings shape: {normalized_embeddings.shape}")
                else:
                    normalized_embeddings = document_embeddings
                
                nmi, ari = clustering.cluster_documents_gpu(normalized_embeddings, labels)
                dataset_results[dataset] = {"NMI": nmi, "ARI": ari}

            from datetime import datetime
            row = {
                "Timestamp": datetime.now().isoformat(),
                "EmbeddingModel": embedding_model,
                "Iteration": i,
            }
            for dataset in labeled_datasets:
                row[f"{dataset}_NMI"] = dataset_results[dataset]["NMI"]
                row[f"{dataset}_ARI"] = dataset_results[dataset]["ARI"]

            results_df = pd.DataFrame([row])
            results_df.to_csv(csv_path, mode='a', header=not os.path.exists(csv_path), index=False)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_name', type=str, required=False, help='Name of the dataset to use')
    args = parser.parse_args()

    # When --dataset_name is omitted, args.dataset_name is None and main() prompts
    # interactively via the builtin input(). When it is given, main() never
    # prompts, so the builtin does not need to be overridden.
    main(args.dataset_name)