import pandas as pd
import numpy as np
from tqdm import tqdm
import random
import string
import os
import lda
import argparse
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_selection import SelectPercentile, f_classif
from scipy.sparse import vstack

import utils_split as utils
import torch
import json


# Module-level placeholders. main() binds local variables with these same
# names from the command-line arguments; it does not update these globals.
PATH = NAME = ""


def lda_eval(X_train, labels_train, X_val, labels_val, X_test, labels_test,
            n_topics,
            cluster_membership_matrix_train,
            cluster_membership_matrix_val,
            cluster_membership_matrix_test, n_iter=2000, test_X=None, test_labels=None, classes=None):
    """Fit LDA on all splits and score classifiers built on its topic mixtures.

    LDA is fit once on the row-wise stack of the train, val and test
    document-term matrices (labels are not used for fitting). The resulting
    per-document topic distributions are split back into train/val/test by
    row count and used as classifier features, alone and concatenated with
    the cluster-membership features.

    Args:
        X_train, X_val, X_test: Document-term matrices (one row per
            document) accepted by ``scipy.sparse.vstack``.
        labels_train, labels_val, labels_test: Labels for each split,
            forwarded to ``utils.train_model_lrsearch``.
        n_topics: Number of LDA topics.
        cluster_membership_matrix_train, cluster_membership_matrix_val,
        cluster_membership_matrix_test: Torch tensors with one row per
            document of the matching split; concatenated column-wise with
            the topic features.
        n_iter: Number of LDA iterations (passed to ``lda.LDA``).
        test_X, test_labels: Unused; accepted for backward compatibility.
        classes: Class labels forwarded to the classifier trainer. Required
            despite the ``None`` default.

    Returns:
        ``(topic_acc, model, topic_cluster_acc, topic_val_acc)``:
        ``topic_acc`` is the first value returned by
        ``utils.train_model_lrsearch`` for the topic-only classifier (callers
        treat it as test accuracy), ``topic_val_acc`` is that classifier's
        validation accuracy, ``topic_cluster_acc`` is the first value for the
        topic + cluster classifier, and ``model`` is the fitted ``lda.LDA``.

    Raises:
        ValueError: if ``classes`` is None. This is checked before the
            (expensive) LDA fit instead of failing afterwards at ``len(classes)``.
    """
    if classes is None:
        raise ValueError("lda_eval requires `classes` (the list of class labels)")

    # Fit LDA on every split; fixed hyperparameters and seed for reproducibility.
    model = lda.LDA(n_topics=n_topics, n_iter=n_iter, random_state=42, alpha=0.05, eta=0.005)
    model.fit(vstack([X_train, X_val, X_test]))

    # Split the stacked document-topic matrix back into its original splits.
    n_train = X_train.shape[0]
    n_train_val = n_train + X_val.shape[0]
    inputs = torch.from_numpy(model.doc_topic_).float()
    inputs_train = inputs[:n_train]
    inputs_val = inputs[n_train:n_train_val]
    inputs_test = inputs[n_train_val:]

    # Classifier on topic features only.
    topic_acc, _, topic_val_acc = utils.train_model_lrsearch(
        inputs_train, labels_train, inputs_val, labels_val, inputs_test, labels_test,
        n_topics, len(classes), classes, return_val_acc=True)

    # Classifier on topic features concatenated with cluster memberships.
    input_clusters_train = torch.cat((inputs_train, cluster_membership_matrix_train), 1)
    input_clusters_val = torch.cat((inputs_val, cluster_membership_matrix_val), 1)
    input_clusters_test = torch.cat((inputs_test, cluster_membership_matrix_test), 1)
    topic_cluster_acc, _ = utils.train_model_lrsearch(
        input_clusters_train, labels_train, input_clusters_val, labels_val,
        input_clusters_test, labels_test,
        n_topics + cluster_membership_matrix_train.shape[1], len(classes), classes)

    return topic_acc, model, topic_cluster_acc, topic_val_acc

def lda_gridsearch(df_train, df_val, df_test, n,
                   cluster_membership_matrix_train,
                   cluster_membership_matrix_val,
                   cluster_membership_matrix_test,
                   classes, max_doc_fraction=0.5):
    """Evaluate cluster-only, topic-only and topic + cluster classifiers for one topic count.

    Args:
        df_train, df_val, df_test: DataFrames with at least a "message"
            (document text) and a "label" column.
        n: Number of LDA topics.
        cluster_membership_matrix_train, cluster_membership_matrix_val,
        cluster_membership_matrix_test: NumPy arrays of per-document
            cluster-membership features for each split (converted to float
            tensors here).
        classes: Class labels forwarded to the classifier trainer.
        max_doc_fraction: Terms whose document frequency (fraction of all
            documents that contain the term at least once) is strictly
            greater than this are removed as stop words before fitting LDA.

    Returns:
        ``(topic_test_acc, model, topic_cluster_acc, vectorizer,
        cluster_only_acc, topic_val_acc)``; ``vectorizer`` is the
        stop-word-filtered ``CountVectorizer`` whose vocabulary indexes the
        columns of ``model.topic_word_``.
    """
    messages_train, labels_train = df_train["message"], df_train["label"]
    messages_val, labels_val = df_val["message"], df_val["label"]
    messages_test, labels_test = df_test["message"], df_test["label"]

    # Classifier on cluster-membership features alone.
    cluster_membership_matrix_train = torch.from_numpy(cluster_membership_matrix_train).float()
    cluster_membership_matrix_val = torch.from_numpy(cluster_membership_matrix_val).float()
    cluster_membership_matrix_test = torch.from_numpy(cluster_membership_matrix_test).float()
    cluster_only_acc, _ = utils.train_model_lrsearch(cluster_membership_matrix_train,
                                                     labels_train,
                                                     cluster_membership_matrix_val,
                                                     labels_val,
                                                     cluster_membership_matrix_test,
                                                     labels_test,
                                                     cluster_membership_matrix_train.shape[1],
                                                     len(classes),
                                                     classes)

    # Term-document matrix over all splits, rows ordered train, val, test.
    messages = pd.concat([messages_train, messages_val, messages_test], ignore_index=True)
    vectorizer = CountVectorizer()
    X = vectorizer.fit_transform(messages)

    # Document frequency = fraction of documents containing the term at least
    # once. (Summing raw counts would measure average occurrences per
    # document instead, which over-counts words repeated within a document.)
    doc_freq = np.asarray((X > 0).sum(axis=0)).ravel() / X.shape[0]
    words = vectorizer.get_feature_names_out()  # ordered by column index of X
    stopwords = [str(w) for w, freq in zip(words, doc_freq) if freq > max_doc_fraction]
    print("Stopwords: ", str(len(stopwords)), stopwords)

    # Refit without the frequent terms and split rows back into the three sets.
    vectorizer = CountVectorizer(stop_words=stopwords)
    X = vectorizer.fit_transform(messages)
    n_train = len(df_train)
    n_train_val = n_train + len(df_val)
    X_train = X[:n_train]
    X_val = X[n_train:n_train_val]
    X_test = X[n_train_val:]

    topic_test_acc, model, topic_cluster_acc, topic_val_acc = lda_eval(
        X_train, labels_train, X_val, labels_val, X_test, labels_test,
        n_topics=n,
        cluster_membership_matrix_train=cluster_membership_matrix_train,
        cluster_membership_matrix_val=cluster_membership_matrix_val,
        cluster_membership_matrix_test=cluster_membership_matrix_test,
        classes=classes)

    return topic_test_acc, model, topic_cluster_acc, vectorizer, cluster_only_acc, topic_val_acc


def _topic_word_string(scores, vocabulary, min_score=0.00001):
    """Return ``"word,score,word,score,..."`` for one topic, highest score first.

    Args:
        scores: Per-word scores for one topic, indexable by the column
            indices stored in ``vocabulary``.
        vocabulary: Mapping of word -> column index (a fitted
            ``CountVectorizer.vocabulary_``).
        min_score: Words whose score is not strictly greater than this are
            omitted.
    """
    # sorted() is stable even with reverse=True, so tied scores keep the
    # vocabulary's iteration order.
    ranked = sorted(((word, scores[idx]) for word, idx in vocabulary.items()),
                    key=lambda item: item[1], reverse=True)
    return ",".join(word + "," + str(score) for word, score in ranked if score > min_score)


def main():
    """Grid-search the LDA topic count, then write topic and accuracy reports.

    The topic count is selected by the topic-only classifier's validation
    accuracy. Outputs (relative to the working directory):
    ``ldas/topic_score_dict_<name>.csv``,
    ``ldas/document_topic_scores_<name>.csv``,
    ``ldas/lda_stats_<name>.csv`` and ``results_clusters/<name>.json``.
    """
    # read in command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, help='path to cluster data')
    parser.add_argument('--test-path', type=str, help='path to splits')
    parser.add_argument('--classes-path', type=str, help='path to classes')
    parser.add_argument('--name', type=str, help='name of experiment')
    parser.add_argument('--dataset', type=str,
                        choices=['aircraft', 'birdsnap', 'caltech101', 'caltech256', 'cifar10', 'cifar100',
                                 'cars', 'dtd', 'flowers', 'food', 'pets', 'sun'],
                        help='name of dataset')
    parser.add_argument('--n-clusters', type=int, help='number of clusters')
    parser.add_argument('--n-topics', type=int, help='number of topics')
    args = parser.parse_args()
    if args.path is None or args.name is None:
        # Prints usage plus the message to stderr and exits with a non-zero status.
        parser.error("Please specify path to cluster tsv file (--path) and name of experiment (--name).")
    PATH = args.path
    TEST_PATH = args.test_path
    NAME = args.name
    CLASSES_PATH = args.classes_path
    DATASET = args.dataset
    N_CLUSTERS = args.n_clusters

    # Messages and image ids come from the base cluster file; cluster
    # memberships for each candidate topic count are reloaded in the loop.
    df_train, df_val, df_test, classes = utils.load_data(PATH, NAME, TEST_PATH, CLASSES_PATH, save_message=True)

    if args.n_topics is None:
        n_topic_range = list(range(30, 330, 30))
    else:
        n_topic_range = [args.n_topics]

    best_acc = 0  # validation accuracy of the selected topic-only classifier
    best_test_acc = 0
    best_cluster_acc = 0
    best_topic_cluster_acc = 0
    best_model = None
    best_vectorizer = None
    best_n_topics = 0
    all_topic_accs = []
    all_cluster_accs = []
    all_topic_cluster_accs = []
    for n_topic in n_topic_range:
        # NOTE(review): assumes PATH embeds the cluster count as a substring and
        # that a file for `n_topic` clusters exists at the rewritten path;
        # str.replace rewrites every occurrence — confirm the path layout.
        path_i = PATH.replace(str(args.n_clusters), str(n_topic))
        df_train_i, df_val_i, df_test_i, classes = utils.load_data(path_i, NAME, TEST_PATH, CLASSES_PATH)
        (cluster_membership_matrix_train,
         cluster_membership_matrix_val,
         cluster_membership_matrix_test) = utils.get_cluster_memberships(df_train_i, df_val_i, df_test_i)

        (topic_test_acc, model, topic_cluster_acc, vectorizer,
         cluster_only_acc, topic_val_acc) = lda_gridsearch(df_train, df_val, df_test, n_topic,
                                                           cluster_membership_matrix_train,
                                                           cluster_membership_matrix_val,
                                                           cluster_membership_matrix_test,
                                                           classes)
        all_topic_accs.append(topic_test_acc)
        all_cluster_accs.append(cluster_only_acc)
        all_topic_cluster_accs.append(topic_cluster_acc)
        # Select on validation accuracy. Always accept the first candidate so a
        # model exists even if every validation accuracy is 0.
        if best_model is None or topic_val_acc > best_acc:
            best_acc = topic_val_acc
            best_test_acc = topic_test_acc
            best_model = model
            best_vectorizer = vectorizer
            best_cluster_acc = cluster_only_acc
            best_n_topics = n_topic
            best_topic_cluster_acc = topic_cluster_acc

    # Same document order as the LDA input: train, val, test.
    df_messages = pd.concat([df_train, df_val, df_test], ignore_index=True)

    topic_word_scores = best_model.topic_word_
    document_topic_scores = best_model.doc_topic_
    # Use the vocabulary of the vectorizer that produced the selected model.
    vocabulary = best_vectorizer.vocabulary_
    topic_score_dict = {topic_id: _topic_word_string(topic_word_scores[topic_id], vocabulary)
                        for topic_id in range(best_n_topics)}

    os.makedirs("ldas", exist_ok=True)

    # write topic_score_dict to csv file
    with open(os.path.join("ldas", "topic_score_dict_" + NAME + ".csv"), 'w') as f:
        f.write("topic_id,word1,word1_score,word2,word2_score,...\n")
        for topic_id, word_scores in topic_score_dict.items():
            f.write("%s,%s\n" % (topic_id, word_scores))

    # write document_topic_scores to csv file
    with open(os.path.join("ldas", "document_topic_scores_" + NAME + ".csv"), 'w') as f:
        f.write("image_id,topic1_score,topic2_score,...\n")
        for image_id, doc_scores in zip(df_messages["image_id"], document_topic_scores):
            f.write("%s,%s\n" % (image_id, ",".join(str(x) for x in doc_scores)))

    # write stats to csv file
    with open(os.path.join("ldas", "lda_stats_" + NAME + ".csv"), 'w') as f:
        f.write("best topic accuracy,%s\n" % (best_test_acc))
        f.write("best n topics,%s\n" % (best_n_topics))
        f.write("cluster accuracy,%s\n" % (best_cluster_acc))
        f.write("cluster + topic accuracy,%s\n" % (best_topic_cluster_acc))

        f.write("searched num_topic range,%s\n" % (n_topic_range))

        f.write("\n ALL ACCURACIES \n")
        for n_topic, acc in zip(n_topic_range, all_topic_accs):
            f.write("%s,%s\n" % (n_topic, acc))

        f.write("\n\n ALL CLUSTER ACCURACIES \n")
        for n_topic, acc in zip(n_topic_range, all_cluster_accs):
            f.write("%s,%s\n" % (n_topic, acc))

        f.write("\n\n ALL TOPIC + CLUSTER ACCURACIES \n")
        for n_topic, acc in zip(n_topic_range, all_topic_cluster_accs):
            f.write("%s,%s\n" % (n_topic, acc))

    # write stats to json file; every "best_*" value refers to the model
    # selected by validation accuracy.
    os.makedirs('results_clusters', exist_ok=True)
    with open("results_clusters/" + NAME + ".json", 'w') as f:
        json.dump({
            'dataset': DATASET,
            'n_clusters': N_CLUSTERS,
            # test accuracy, matching "best topic accuracy" in the stats CSV
            'best_topic_acc': best_test_acc,
            'best_topic_val_acc': best_acc,
            'best_n_topics': best_n_topics,
            'best_cluster_only_acc': best_cluster_acc,
            'best_cluster_topic_acc': best_topic_cluster_acc,
            'n_topic_range': n_topic_range,
            'all_topic_accuracies': dict(zip(n_topic_range, all_topic_accs)),
            'all_cluster_accuracies': dict(zip(n_topic_range, all_cluster_accs)),
            'all_topic_cluster_accuracies': dict(zip(n_topic_range, all_topic_cluster_accs)),
        }, f, indent=4)
    
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()