import pandas as pd
import numpy as np
from tqdm import tqdm
import random
import string
import lda
import argparse
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_selection import SelectPercentile, f_classif

import utils as utils
import torch


# Placeholder defaults for the data path and the experiment name. No code
# visible in this file reads these module-level values: main() takes both
# from the command line.
PATH = ""
NAME = ""


def lda_eval(X, labels, n_topics, cluster_membership_matrix, n_iter=2000,
             *, alpha=0.05, eta=0.005, random_state=42):
    """Fit one LDA model and score its topic features with a classifier.

    Two classifiers are trained through ``utils.train_model``: one on the
    document-topic distribution alone, and one on that distribution with the
    cluster membership columns appended.

    Args:
        X: Document-term matrix passed unchanged to ``lda.LDA.fit``.
        labels: One class label per document (row of ``X``).
        n_topics: Number of LDA topics.
        cluster_membership_matrix: 2-D torch tensor with one row per
            document; it is concatenated column-wise with the topic
            features, so it must be compatible with a float tensor.
        n_iter: Number of LDA sampling iterations.
        alpha: Dirichlet prior on the document-topic distributions.
        eta: Dirichlet prior on the topic-word distributions.
        random_state: Seed handed to ``lda.LDA`` for reproducible fits.

    Returns:
        Tuple ``(acc, model, acc_clusters)``: the first element returned by
        ``utils.train_model`` for the topic-only features (used as the
        accuracy by the callers in this file), the fitted LDA model, and the
        same score for the topic + cluster features.
    """
    # Run LDA.
    model = lda.LDA(n_topics=n_topics, n_iter=n_iter, random_state=random_state,
                    alpha=alpha, eta=eta)
    model.fit(X)

    n_classes = len(set(labels))

    # Train a classifier on the topic distribution alone.
    topic_features = torch.from_numpy(model.doc_topic_).float()
    acc, _ = utils.train_model(topic_features, n_topics, n_classes, labels)

    # Train a second classifier on topics plus cluster memberships.
    combined_features = torch.cat((topic_features, cluster_membership_matrix), 1)
    n_combined = n_topics + cluster_membership_matrix.shape[1]
    acc_clusters, _ = utils.train_model(combined_features, n_combined, n_classes, labels)

    return acc, model, acc_clusters

def lda_gridsearch(df_messages, n_topic_range, cluster_membership_matrix, max_doc_freq=0.5):
    """Grid-search the number of LDA topics and keep the best-scoring model.

    Args:
        df_messages: DataFrame with a ``"message"`` column (raw text) and a
            ``"label"`` column (class label per message).
        n_topic_range: Iterable of topic counts to evaluate.
        cluster_membership_matrix: 2-D NumPy array with one row per message;
            it is converted to a float torch tensor here.
        max_doc_freq: Words that occur in more than this fraction of the
            documents are treated as corpus-specific stop words and removed
            before fitting LDA.

    Returns:
        Tuple ``(best_acc, best_model, best_cluster_acc, vectorizer,
        best_n_topics, accuracies, cluster_accuracies, cluster_only_acc)``.
        ``best_*`` describe the topic count with the highest topic-only
        score; ``accuracies`` and ``cluster_accuracies`` hold the scores for
        every entry of ``n_topic_range`` in order; ``vectorizer`` is the
        fitted stop-word-filtered ``CountVectorizer``; ``cluster_only_acc``
        is the score of a classifier trained on the cluster memberships
        alone.
    """
    messages = df_messages["message"]
    labels = df_messages["label"]
    n_classes = len(set(labels))

    # Baseline: classifier on the cluster memberships only.
    cluster_membership_matrix = torch.from_numpy(cluster_membership_matrix).float()
    cluster_only_acc, _ = utils.train_model(
        cluster_membership_matrix, cluster_membership_matrix.shape[1], n_classes, labels
    )

    # First pass over the corpus, used only to find over-frequent words.
    vectorizer = CountVectorizer()
    X = vectorizer.fit_transform(messages)
    words = vectorizer.get_feature_names_out()

    # Document frequency is the number of documents that CONTAIN a word, not
    # the number of times the word occurs. In the CSR matrix returned by
    # CountVectorizer every stored entry is one (document, word) pair, so
    # counting column indices gives the per-word document count in one pass.
    doc_counts = np.bincount(X.indices, minlength=X.shape[1])
    doc_freq = doc_counts / X.shape[0]
    stopwords = words[doc_freq > max_doc_freq].tolist()
    print("Stopwords: ", str(len(stopwords)), stopwords)

    # Second pass: rebuild the term-document matrix without the stop words.
    vectorizer = CountVectorizer(stop_words=stopwords)
    X = vectorizer.fit_transform(messages)

    best_acc = 0
    best_cluster_acc = 0
    best_model = None
    best_n_topics = 0

    accuracies = []
    cluster_accuracies = []

    for n in tqdm(n_topic_range):
        acc, model, cluster_acc = lda_eval(
            X, labels=labels, n_topics=n, cluster_membership_matrix=cluster_membership_matrix
        )
        accuracies.append(acc)
        cluster_accuracies.append(cluster_acc)
        # Always keep the first model so a run whose scores are all 0 still
        # returns a usable model instead of None.
        if best_model is None or acc > best_acc:
            best_acc = acc
            best_model = model
            best_cluster_acc = cluster_acc
            best_n_topics = n

    return (best_acc, best_model, best_cluster_acc, vectorizer, best_n_topics,
            accuracies, cluster_accuracies, cluster_only_acc)


def main():
    """Run the LDA topic-count grid search and export the results as CSV.

    Command-line arguments:
        --path        path to the cluster data (required)
        --name        experiment name, used in the output file names (required)
        --split_path  path to the dataset splits (optional)

    Writes three files to the current working directory:
    ``topic_score_dict_<name>.csv`` (words and scores per topic),
    ``document_topic_scores_<name>.csv`` (topic scores per document) and
    ``lda_stats_<name>.csv`` (accuracies for every searched topic count).

    Raises:
        RuntimeError: if the grid search returned no model.
    """
    # Read in command line arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, help='path to cluster data')
    parser.add_argument('--name', type=str, help='name of experiment')
    parser.add_argument('--split_path', type=str, help='path to dataset splits')
    args = parser.parse_args()
    if args.path is None or args.name is None:
        # parser.error prints the usage text and exits with a non-zero status.
        parser.error("Please specify path to cluster tsv file (--path) and name of experiment (--name).")
    name = args.name

    # Load data.
    df_train, df_val, df_test = utils.load_data(args.path, name, args.split_path)
    # NOTE(review): the rest of this function works on one `df_messages`
    # frame, which was never defined after load_data started returning three
    # splits. The splits are concatenated here so every message is used;
    # confirm this matches what utils.train_model expects.
    df_messages = pd.concat([df_train, df_val, df_test], ignore_index=True)
    cluster_membership_matrix = utils.get_cluster_memberships(df_messages)

    n_topic_range = list(range(20, 500, 20))
    # n_topic_range = list(range(100, 120, 20))
    (best_acc, best_model, best_cluster_acc, vectorizer, best_n_topics,
     accuracies, cluster_accuracies, cluster_only_acc) = lda_gridsearch(
        df_messages, n_topic_range, cluster_membership_matrix)

    if best_model is None:
        raise RuntimeError("LDA grid search did not produce a model; nothing to export.")

    topic_word_scores = best_model.topic_word_
    document_topic_scores = best_model.doc_topic_
    vocabulary = vectorizer.vocabulary_  # word -> column index

    # For every topic build "word1,score1,word2,score2,..." sorted by
    # descending score, dropping words with a negligible score.
    topic_score_dict = {}
    for topic_id in range(best_n_topics):
        scores = topic_word_scores[topic_id]
        ranked = sorted(
            ((word, scores[column]) for word, column in vocabulary.items()),
            key=lambda item: item[1],
            reverse=True,
        )
        topic_score_dict[topic_id] = ",".join(
            word + "," + str(score) for word, score in ranked if score > 0.00001
        )

    # Write topic_score_dict to csv file. The vocabulary may contain
    # non-ASCII words, so the encoding is fixed instead of platform-dependent.
    with open("topic_score_dict_" + name + ".csv", 'w', encoding="utf-8") as f:
        f.write("topic_id,word1,word1_score,word2,word2_score,...\n")
        for topic_id, word_scores in topic_score_dict.items():
            f.write("%s,%s\n" % (topic_id, word_scores))

    # Write document_topic_scores to csv file. Rows of doc_topic_ are paired
    # with image ids by position, independent of the DataFrame index labels.
    with open("document_topic_scores_" + name + ".csv", 'w', encoding="utf-8") as f:
        f.write("image_id,topic1_score,topic2_score,...\n")
        for image_id, topic_scores in zip(df_messages["image_id"], document_topic_scores):
            f.write("%s,%s\n" % (image_id, ",".join(str(x) for x in topic_scores)))

    # Write stats to csv file.
    with open("lda_stats_" + name + ".csv", 'w', encoding="utf-8") as f:
        f.write("best topic accuracy,%s\n" % (best_acc))
        f.write("best n topics,%s\n" % (best_n_topics))
        f.write("cluster only accuracy,%s\n" % (cluster_only_acc))
        f.write("cluster + topic accuracy,%s\n" % (best_cluster_acc))

        f.write("searched num_topic range,%s\n" % (n_topic_range))

        f.write("\n ALL ACCURACIES \n")
        for n_topics, acc in zip(n_topic_range, accuracies):
            f.write("%s,%s\n" % (n_topics, acc))

        f.write("\n\n ALL TOPIC + CLUSTER ACCURACIES \n")
        for n_topics, acc in zip(n_topic_range, cluster_accuracies):
            f.write("%s,%s\n" % (n_topics, acc))
    

# Run the experiment only when executed as a script, so importing this module
# (e.g. to reuse lda_eval / lda_gridsearch) has no side effects.
if __name__ == "__main__":
    main()