import pandas as pd
import numpy as np
from tqdm import tqdm
import random
import string
import lda
import argparse
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_selection import SelectPercentile, f_classif
import torch
import json
import os

import utils_split as utils

# Module-level defaults for the run configuration. main() assigns local
# variables with these same names (there is no ``global`` statement), so these
# module-level values are never updated at runtime.
CAPTION_PATH = SPLITS_PATH = NAME = ""
N = 0


def load_data(CAPTION_PATH, SPLITS_PATH):
    """Load caption data and partition it into train, validation and test splits.

    Args:
        CAPTION_PATH: Path to a single tab-separated caption file, or to a
            directory whose files are all tab-separated caption files. Those
            files are read in sorted filename order and concatenated. The
            caption data must contain an ``image_path`` column.
        SPLITS_PATH: Directory holding tab-separated split files that each
            have an ``image_path`` column. It must contain either
            ``trainval.txt`` (split 80/20 into train/val with a fixed seed of
            42) or both ``train.txt`` and ``val.txt``. ``test.txt`` is
            required in both cases.

    Returns:
        ``(df_train, df_val, df_test)``: the caption rows whose
        ``image_path`` appears in the corresponding split.

    Raises:
        ValueError: if CAPTION_PATH is a directory that contains no files.
    """
    if os.path.isdir(CAPTION_PATH):
        # os.path.join works whether or not the directory path ends with a
        # separator. Sorting makes the concatenation order reproducible,
        # because os.listdir order is arbitrary.
        frames = [
            pd.read_csv(os.path.join(CAPTION_PATH, file_name), sep="\t")
            for file_name in sorted(os.listdir(CAPTION_PATH))
            if os.path.isfile(os.path.join(CAPTION_PATH, file_name))
        ]
        if not frames:
            raise ValueError("no caption files found in directory %r" % CAPTION_PATH)
        df_captions = pd.concat(frames, ignore_index=True)
    else:
        df_captions = pd.read_csv(CAPTION_PATH, sep="\t")
    print("caption data size: ", df_captions.shape)
    print(df_captions.head())

    trainval_file = os.path.join(SPLITS_PATH, "trainval.txt")
    if os.path.exists(trainval_file):
        trainval_examples = pd.read_csv(trainval_file, sep='\t')['image_path']
        train_size = int(0.8 * len(trainval_examples))
        # The seeded generator keeps the train/val partition identical across runs.
        train_dataset, val_dataset = torch.utils.data.random_split(
            trainval_examples,
            [train_size, len(trainval_examples) - train_size],
            generator=torch.Generator().manual_seed(42),
        )
        # Subset.indices are positions into trainval_examples, so index positionally.
        trainval_paths = trainval_examples.to_numpy()
        train_image_paths = trainval_paths[train_dataset.indices]
        val_image_paths = trainval_paths[val_dataset.indices]
    else:
        train_image_paths = pd.read_csv(os.path.join(SPLITS_PATH, "train.txt"), sep='\t')['image_path'].values
        val_image_paths = pd.read_csv(os.path.join(SPLITS_PATH, "val.txt"), sep='\t')['image_path'].values
    test_image_paths = pd.read_csv(os.path.join(SPLITS_PATH, "test.txt"), sep='\t')['image_path'].values

    # Select the caption rows that belong to each split.
    in_split = df_captions["image_path"].isin
    df_train = df_captions[in_split(train_image_paths)]
    df_val = df_captions[in_split(val_image_paths)]
    df_test = df_captions[in_split(test_image_paths)]

    print("train data size: ", df_train.shape)
    print("val data size: ", df_val.shape)
    print("test data size: ", df_test.shape)

    return df_train, df_val, df_test

def lda_eval(X, labels, n_topics, len_train, len_val, vectorizer, n_iter=400, n_top_words=10):
    """Fit an LDA topic model on X and evaluate a classifier on its topic mixtures.

    The rows of X, and the entries of ``labels``, must be ordered as all
    training documents, then all validation documents, then all test
    documents. ``len_train`` and ``len_val`` give the sizes of the first two
    parts; the remainder is the test split.

    Args:
        X: Document-term count matrix whose columns follow ``vectorizer``'s
            feature order.
        labels: Per-document class labels aligned with the rows of X. This may
            be a list, an array or a pandas Series; it is sliced by position.
        n_topics: Number of LDA topics. It is also passed to
            ``utils.train_model_lrsearch``, presumably as the classifier's
            input width.
        len_train: Number of leading rows that form the training split.
        len_val: Number of rows after the training split that form the
            validation split.
        vectorizer: Fitted vectorizer used to map topic-word columns back to words.
        n_iter: Number of iterations passed to ``lda.LDA``.
        n_top_words: How many highest-weighted words to print for each topic.

    Returns:
        ``(acc, model)``: the accuracy reported by
        ``utils.train_model_lrsearch`` and the fitted ``lda.LDA`` model.
    """
    model = lda.LDA(n_topics=n_topics, n_iter=n_iter, random_state=42, alpha=0.05, eta=0.005)
    model.fit(X)

    # Print the top words of every topic for inspection. Feature names are
    # fetched once rather than once per topic.
    feature_names = np.asarray(vectorizer.get_feature_names_out())
    for i, topic_dist in enumerate(model.topic_word_):
        topic_words = feature_names[np.argsort(topic_dist)][:-(n_top_words + 1):-1]
        print('Topic {}: {}'.format(i, ' '.join(topic_words)))

    # Train a classifier on the per-document topic distributions.
    doc_topic = torch.from_numpy(model.doc_topic_).float()
    val_end = len_train + len_val
    # Convert to an array so slicing is positional whatever the Series index
    # is. .tolist() yields plain Python scalars, as list(series) did.
    label_array = np.asarray(labels)
    labels_train = label_array[:len_train].tolist()
    labels_val = label_array[len_train:val_end].tolist()
    labels_test = label_array[val_end:].tolist()
    # Sort the class list so its order does not depend on set/hash ordering,
    # which varies between runs for string labels. This assumes the labels
    # are mutually comparable.
    classes = sorted(set(labels_train))
    acc, _ = utils.train_model_lrsearch(doc_topic[:len_train], labels_train,
                                        doc_topic[len_train:val_end], labels_val,
                                        doc_topic[val_end:], labels_test,
                                        n_topics,
                                        len(classes),
                                        classes
                                        )
    return acc, model


def _bag_of_words(captions, vocab):
    """Return a float tensor with one row per caption counting each word of ``vocab``.

    Column j holds the number of whitespace-separated occurrences of
    ``vocab[j]``. Words that are not in ``vocab`` are ignored.
    """
    column_of = {word: j for j, word in enumerate(vocab)}
    counts = np.zeros((len(captions), len(vocab)))
    for row, caption in enumerate(captions):
        for word in caption.split():
            col = column_of.get(word)
            if col is not None:
                counts[row, col] += 1
    return torch.from_numpy(counts).float()


def _frequent_terms(X, feature_names, max_doc_fraction=0.5):
    """Return the terms that occur in more than ``max_doc_fraction`` of the documents (rows) of X."""
    # Document frequency is the number of rows with a non-zero count for the
    # term. Summing the raw counts instead would over-count terms that repeat
    # within a document.
    doc_fraction = np.asarray((X > 0).sum(axis=0)).ravel() / X.shape[0]
    return [term for term, frac in zip(feature_names, doc_fraction) if frac > max_doc_fraction]


def caption_baseline(df_train, df_val, df_test, n_topic_range):
    """Run bag-of-words and LDA caption baselines over a range of feature sizes.

    The frames must have ``caption`` (whitespace-tokenizable text) and
    ``label`` columns. For each ``n`` in ``n_topic_range``:

    * BOW: captions are encoded as counts of the ``n`` most frequent words
      across all splits, and a classifier is trained through
      ``utils.train_model_lrsearch``.
    * LDA: an ``n``-topic LDA model is fitted through ``lda_eval``. Terms that
      occur in more than half of all documents are removed first.

    Returns:
        ``(best_LDA_acc, best_LDA_model, best_BOW_acc, vectorizer,
        best_n_topics, LDA_accuracies, BOW_accuracies)``. The LDA model's
        documents are ordered train, then val, then test. If
        ``n_topic_range`` is non-empty, ``best_LDA_model`` is never ``None``.
    """
    # Train, val, test order; lda_eval slices the splits back out by position.
    df_all = pd.concat([df_train, df_val, df_test], ignore_index=True)

    # ---- Bag-of-words baseline ----
    # Count every word across all captions in a single pass. np.unique
    # returns the words sorted; the counts are kept as float64 for argsort.
    tokenized = df_all["caption"].apply(lambda x: x.split()).values
    all_words, raw_counts = np.unique(np.concatenate(tokenized), return_counts=True)
    order = np.argsort(raw_counts.astype(float))

    train_classes = sorted(set(df_train["label"]))
    best_BOW_acc = 0
    BOW_accuracies = []

    for n in tqdm(n_topic_range):
        # Take the n most frequent words. Clamp the start index so that n == 0
        # selects no words (a bare [-0:] would select them all) and n larger
        # than the vocabulary selects all of them.
        top_n_words = all_words[order[max(order.size - n, 0):]]
        acc, _ = utils.train_model_lrsearch(_bag_of_words(df_train["caption"], top_n_words),
                                            df_train["label"].values,
                                            _bag_of_words(df_val["caption"], top_n_words),
                                            df_val["label"].values,
                                            _bag_of_words(df_test["caption"], top_n_words),
                                            df_test["label"].values,
                                            len(top_n_words),
                                            len(train_classes),
                                            train_classes
                                            )
        BOW_accuracies.append(acc)
        if acc > best_BOW_acc:
            best_BOW_acc = acc

    # ---- LDA baseline ----
    messages = df_all["caption"]
    labels = df_all["label"]

    # Build the term-document matrix once to find corpus-specific stopwords,
    # then rebuild it without them.
    vectorizer = CountVectorizer()
    X = vectorizer.fit_transform(messages)
    stopwords = _frequent_terms(X, vectorizer.get_feature_names_out())
    print("Stopwords: ", str(len(stopwords)), stopwords)

    vectorizer = CountVectorizer(stop_words=stopwords)
    X = vectorizer.fit_transform(messages)

    best_LDA_acc = 0
    best_LDA_model = None
    best_n_topics = 0
    LDA_accuracies = []

    for n in tqdm(n_topic_range):
        acc, model = lda_eval(X, labels=labels, n_topics=n, len_train=df_train.shape[0],
                              len_val=df_val.shape[0],
                              vectorizer=vectorizer)
        # Always keep the first model, so that a run in which every accuracy is
        # 0 still returns a usable model instead of None.
        if best_LDA_model is None or acc > best_LDA_acc:
            best_LDA_acc = acc
            best_LDA_model = model
            best_n_topics = n
        LDA_accuracies.append(acc)
    return best_LDA_acc, best_LDA_model, best_BOW_acc, vectorizer, best_n_topics, LDA_accuracies, BOW_accuracies

    

def main():
    """Command-line entry point: run the caption baselines and write the results.

    Reads ``--caption_path``, ``--name``, ``--splits_path`` and ``--n``, all
    of which are required. It runs ``caption_baseline`` with ``n`` as the
    only searched size and writes:

    * ``ldas/caption_topic_score_dict_<name>.csv``: each topic's word weights above 1e-5
    * ``ldas/caption_document_topic_scores_<name>.csv``: each document's topic scores
    * ``ldas/caption_baseline_stats_<name>.csv``: best and per-size accuracies
    * ``results_captions/<name>.json``: the same accuracies as JSON
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--caption_path', type=str, help='path to caption data')
    parser.add_argument('--name', type=str, help='name of experiment')
    parser.add_argument('--splits_path', type=str, help='path to test data')
    # type=int makes argparse report a non-integer value as a usage error.
    parser.add_argument('--n', type=int, help='number of topics')
    args = parser.parse_args()
    if any(v is None for v in (args.caption_path, args.name, args.splits_path, args.n)):
        # parser.error prints usage to stderr and exits with a non-zero status,
        # whereas the bare exit() builtin would report success (status 0).
        parser.error("Please specify path to caption tsv file (--caption_path) and name of experiment (--name) and path to test data (--splits_path) and number of topics (--n).")
    caption_path = args.caption_path
    name = args.name
    splits_path = args.splits_path
    n_topic_range = [args.n]

    df_train, df_val, df_test = load_data(caption_path, splits_path)
    # Same order (train, val, test) as the documents inside caption_baseline,
    # so row i of the LDA doc-topic matrix corresponds to row i here.
    df_messages = pd.concat([df_train, df_val, df_test], ignore_index=True)

    (best_LDA_acc, best_LDA_model, best_BOW_acc, vectorizer, best_n_topics,
     LDA_accuracies, BOW_accuracies) = caption_baseline(df_train, df_val, df_test, n_topic_range)

    topic_word_scores = best_LDA_model.topic_word_
    document_topic_scores = best_LDA_model.doc_topic_
    vocabulary = vectorizer.vocabulary_  # word -> column index

    # topic_id -> "word1,score1,word2,score2,...", sorted by descending score.
    # Words scoring 1e-5 or less are dropped.
    topic_score_dict = {}
    for topic_id in range(best_n_topics):
        scores = topic_word_scores[topic_id]
        ranked = sorted(((word, scores[col]) for word, col in vocabulary.items()),
                        key=lambda item: item[1], reverse=True)
        topic_score_dict[topic_id] = ",".join(
            f"{word},{score!s}" for word, score in ranked if score > 0.00001
        )

    # The original never created this directory, so every write below failed.
    os.makedirs("ldas", exist_ok=True)

    with open("ldas/caption_topic_score_dict_" + name + ".csv", 'w', encoding="utf-8") as f:
        f.write("topic_id,word1,word1_score,word2,word2_score,...\n")
        for topic_id, word_scores in topic_score_dict.items():
            f.write("%s,%s\n" % (topic_id, word_scores))

    with open("ldas/caption_document_topic_scores_" + name + ".csv", 'w', encoding="utf-8") as f:
        f.write("image_id,topic1_score,topic2_score,...\n")
        for image_path, topic_scores in zip(df_messages["image_path"], document_topic_scores):
            f.write("%s,%s\n" % (image_path, ",".join(str(x) for x in topic_scores)))

    with open("ldas/caption_baseline_stats_" + name + ".csv", 'w', encoding="utf-8") as f:
        f.write("caption LDA: best accuracy,%s\n" % (best_LDA_acc))
        f.write("best n topics,%s\n" % (best_n_topics))
        f.write("caption BOW: best accuracy,%s\n" % (best_BOW_acc))

        f.write("searched num dimensions range,%s\n" % (n_topic_range))

        f.write("\n ALL CAPTION LDA ACCURACIES \n")
        for n, acc in zip(n_topic_range, LDA_accuracies):
            f.write("%s,%s\n" % (n, acc))

        f.write("\n\n ALL CAPTION BOW ACCURACIES \n")
        for n, acc in zip(n_topic_range, BOW_accuracies):
            f.write("%s,%s\n" % (n, acc))
    print("written to caption_baseline_stats_" + name + ".csv")

    os.makedirs('results_captions', exist_ok=True)
    with open("results_captions/" + name + ".json", 'w', encoding="utf-8") as f:
        json.dump({
            'dataset': name,
            'best_LDA_acc': best_LDA_acc,
            'best_n_topics': best_n_topics,
            'best_BOW_acc': best_BOW_acc,
            'LDA_accuracies': LDA_accuracies,
            'BOW_accuracies': BOW_accuracies
        }, f, indent=4)

# Run only when executed as a script. Previously, importing this module ran the
# whole pipeline and parsed the importer's command line.
if __name__ == "__main__":
    main()