import pandas as pd
import numpy as np
from tqdm import tqdm
import random
import string
import lda
import argparse
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_selection import SelectPercentile, f_classif
import torch
import os

import caption_utils as utils

# Module-level placeholders for the run configuration.
# NAME is the only one rebound at runtime: main() declares it `global` and
# stores the experiment name in it, and load_data() reads it to name the
# exported messages CSV.
# CAPTION_PATH, SPLITS_PATH and N are never rebound at module level in the
# code visible here; the functions below receive those values as
# parameters/locals instead, so these keep their initial values.
CAPTION_PATH = ""
SPLITS_PATH = ""
NAME = ""
N = 0


def load_data(CAPTION_PATH, SPLITS_PATH):
    """Load caption data, split it into train/val and test, and export it.

    Args:
        CAPTION_PATH: Either a single tab-separated caption file, or a
            directory whose entries are all tab-separated caption files.
            A trailing path separator on the directory is optional.
        SPLITS_PATH: Directory holding the split lists. If ``train.txt``
            exists there, ``train.txt`` and ``val.txt`` are concatenated;
            otherwise ``trainval.txt`` is used. ``test.txt`` is always read.
            Split files have no header; the first column is the image path.

    Returns:
        ``(df_trainval, df_test)``: the caption rows whose ``image_path`` is
        listed in the train/val split and in the test split respectively.

    Raises:
        ValueError: if CAPTION_PATH is a directory with no entries.

    Side effects:
        Writes all selected rows (train/val followed by test), with an extra
        ``message`` column copied from ``caption``, to
        ``text_data/messages_<NAME>.csv`` relative to the working directory,
        where NAME is the module-level experiment name.
    """
    # Decide file vs. directory from the filesystem instead of looking for
    # the substring ".txt" in the path.
    if os.path.isdir(CAPTION_PATH):
        # Sorted so the row order does not depend on os.listdir() order.
        shard_names = sorted(os.listdir(CAPTION_PATH))
        if not shard_names:
            raise ValueError("no caption files found in directory: " + CAPTION_PATH)
        # Read every shard, then concatenate once (avoids repeated copying
        # and concatenating onto an empty DataFrame).
        df_captions = pd.concat(
            [pd.read_csv(os.path.join(CAPTION_PATH, name), sep="\t") for name in shard_names]
        )
    else:
        df_captions = pd.read_csv(CAPTION_PATH, sep="\t")
    print("caption data size: ", df_captions.shape)

    def read_split(filename):
        # os.path.join makes this work with or without a trailing separator.
        return pd.read_csv(os.path.join(SPLITS_PATH, filename), header=None)

    # get train/val and test splits
    if os.path.exists(os.path.join(SPLITS_PATH, "train.txt")):
        trainval_examples = pd.concat([read_split("train.txt"), read_split("val.txt")])
    else:
        trainval_examples = read_split("trainval.txt")
    test_examples = read_split("test.txt")

    # Keep everything up to and including the first ".png", dropping any
    # suffix that follows it in the split file entry.
    trainval_image_paths = [path.split(".png")[0] + ".png" for path in trainval_examples[0].values]
    test_image_paths = [path.split(".png")[0] + ".png" for path in test_examples[0].values]

    # get captions for train/val and test splits
    df_trainval = df_captions[df_captions["image_path"].isin(trainval_image_paths)]
    df_test = df_captions[df_captions["image_path"].isin(test_image_paths)]

    print("trainval data size: ", df_trainval.shape)
    print("test data size: ", df_test.shape)

    # Export the selected captions under a "message" column as well.
    df_messages = pd.concat([df_trainval, df_test])
    df_messages["message"] = df_messages["caption"]
    # Create the output directory so a fresh checkout does not fail here.
    os.makedirs("text_data", exist_ok=True)
    df_messages.to_csv("text_data/messages_" + NAME + ".csv", index=False)
    print("SAVED: " + NAME + ", len: " + str(df_messages.shape[0]))

    return df_trainval, df_test

def lda_eval(X, labels, n_topics, len_train, vectorizer, n_iter=400):
    """Fit an LDA model on X and evaluate a classifier on its topic mixtures.

    Args:
        X: Document-term matrix passed to ``lda.LDA.fit``.
        labels: Per-document labels, in the same row order as X.
        n_topics: Number of LDA topics; also passed to ``utils.train_model``
            as the input dimensionality.
        len_train: Number of leading rows of X (and labels) used for
            training; the remaining rows are used for testing.
        vectorizer: Fitted vectorizer providing ``get_feature_names_out()``,
            used only to print the top words of each topic.
        n_iter: Number of LDA sampling iterations.

    Returns:
        ``(acc, model)``: the first value returned by ``utils.train_model``
        (callers treat it as an accuracy) and the fitted LDA model.
    """
    topic_model = lda.LDA(n_topics=n_topics, n_iter=n_iter, random_state=42, alpha=0.05, eta=0.005)
    topic_model.fit(X)

    # Show the ten highest-weighted words of every topic, best first.
    n_top_words = 10
    vocabulary = np.array(vectorizer.get_feature_names_out())
    for topic_idx, word_weights in enumerate(topic_model.topic_word_):
        ascending = vocabulary[np.argsort(word_weights)]
        best_words = ascending[:-(n_top_words + 1):-1]
        print('Topic {}: {}'.format(topic_idx, ' '.join(best_words)))

    # The per-document topic mixtures are the classifier features; the first
    # len_train rows are the training set, the rest the test set.
    doc_topics = torch.from_numpy(topic_model.doc_topic_).float()
    train_features = doc_topics[:len_train]
    test_features = doc_topics[len_train:]
    train_labels = list(labels[:len_train])
    test_labels = list(labels[len_train:])

    n_classes = len(set(train_labels))
    acc, _ = utils.train_model(
        train_features,
        n_topics,
        n_classes,
        train_labels,
        test_input=test_features,
        test_labels=test_labels,
    )
    return acc, topic_model


def _bow_features(captions, vocab_index):
    """Return a float tensor of per-caption counts for the words in vocab_index."""
    counts = np.zeros((len(captions), len(vocab_index)))
    for row, caption in enumerate(captions):
        for word in caption.split():
            col = vocab_index.get(word)
            if col is not None:
                counts[row, col] += 1
    return torch.from_numpy(counts).float()


def caption_baseline(df_trainval, df_test, n_topic_range):
    """Run the bag-of-words and LDA caption baselines for each size in n_topic_range.

    Both baselines build their vocabulary / topic model from the train/val
    and test captions combined, then train on the train/val rows and score
    on the test rows via ``utils.train_model``.

    Args:
        df_trainval: DataFrame with ``caption`` and ``label`` columns (training rows).
        df_test: DataFrame with ``caption`` and ``label`` columns (test rows).
        n_topic_range: Iterable of sizes; each value is used as the number of
            most frequent words for BOW and as the number of LDA topics.

    Returns:
        Tuple ``(best_LDA_acc, best_LDA_model, best_BOW_acc, vectorizer,
        best_n_topics, LDA_accuracies, BOW_accuracies)``. ``vectorizer`` is
        the stop-word-filtered CountVectorizer used for LDA; the two
        accuracy lists are in the order of n_topic_range.
    """
    df_all = pd.concat([df_trainval, df_test])

    # ---- BOW baseline ----
    # Count every whitespace-separated token once with a dict, instead of
    # scanning the whole vocabulary with np.where for each token.
    corpus_counts = {}
    for caption in df_all["caption"]:
        for word in caption.split():
            corpus_counts[word] = corpus_counts.get(word, 0) + 1
    all_words = np.array(sorted(corpus_counts))
    word_counts = np.array([corpus_counts[w] for w in all_words], dtype=float)
    # Ascending by frequency; the last n entries are the n most frequent words.
    frequency_order = np.argsort(word_counts)

    n_classes = len(set(df_trainval["label"]))
    best_BOW_acc = 0
    BOW_accuracies = []

    for n in tqdm(n_topic_range):
        top_n_words = all_words[frequency_order[-n:]]
        vocab_index = {word: idx for idx, word in enumerate(top_n_words)}
        acc, _ = utils.train_model(
            _bow_features(df_trainval["caption"], vocab_index),
            len(top_n_words),
            n_classes,
            df_trainval["label"].values,
            test_input=_bow_features(df_test["caption"], vocab_index),
            test_labels=df_test["label"].values,
        )
        BOW_accuracies.append(acc)
        if acc > best_BOW_acc:
            best_BOW_acc = acc

    # ---- LDA baseline ----
    messages = df_all["caption"]
    labels = df_all["label"]

    # make term document matrix
    vectorizer = CountVectorizer()
    X = vectorizer.fit_transform(messages)

    # Treat as stop words the terms that appear in more than 50% of the
    # documents. This uses document frequency (number of captions containing
    # the term); summing raw counts would over-count terms repeated within a
    # single caption.
    words = vectorizer.get_feature_names_out()
    doc_fraction = np.asarray((X > 0).sum(axis=0)).ravel() / X.shape[0]
    stopwords = words[doc_fraction > 0.5].tolist()
    print("Stopwords: ", str(len(stopwords)), stopwords)

    # remove stop words
    vectorizer = CountVectorizer(stop_words=stopwords)
    X = vectorizer.fit_transform(messages)

    best_LDA_acc = 0
    best_LDA_model = None
    best_n_topics = 0
    LDA_accuracies = []

    for n in tqdm(n_topic_range):
        acc, model = lda_eval(X, labels=labels, n_topics=n, len_train=df_trainval.shape[0], vectorizer=vectorizer)
        if acc > best_LDA_acc:
            best_LDA_acc = acc
            best_LDA_model = model
            best_n_topics = n
        LDA_accuracies.append(acc)

    return best_LDA_acc, best_LDA_model, best_BOW_acc, vectorizer, best_n_topics, LDA_accuracies, BOW_accuracies

    

def main():
    """Command-line entry point: export caption messages, optionally run the baselines.

    Required arguments: ``--caption_path``, ``--name``, ``--splits_path`` and
    ``--n``. A missing argument is reported through ``parser.error``, which
    prints the usage to stderr and exits with a non-zero status.

    The experiment name is stored in the module-level ``NAME`` because
    ``load_data`` reads it when naming the exported CSV.

    By default the function stops after ``load_data`` has exported the
    messages CSV. With ``--run_baseline`` it also runs ``caption_baseline``
    for ``[n]`` and writes ``caption_baseline_stats_<NAME>.csv`` to the
    working directory.
    """
    global NAME

    # read in command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--caption_path', type=str, help='path to caption data')
    parser.add_argument('--name', type=str, help='name of experiment')
    parser.add_argument('--splits_path', type=str, help='path to test data')
    parser.add_argument('--n', type=int, help='number of topics')
    parser.add_argument('--run_baseline', action='store_true',
                        help='also run the BOW/LDA baselines and write the stats csv')
    args = parser.parse_args()
    required = (args.caption_path, args.name, args.splits_path, args.n)
    if any(value is None for value in required):
        parser.error("Please specify path to caption tsv file (--caption_path) and name of experiment (--name) and path to test data (--splits_path) and number of topics (--n).")

    NAME = args.name

    df_trainval, df_test = load_data(args.caption_path, args.splits_path)
    if not args.run_baseline:
        # NOTE(review): the previous version returned unconditionally here,
        # which made everything below unreachable. That export-only behavior
        # is kept as the default; confirm whether the baseline should run by
        # default instead.
        return

    n_topic_range = [args.n]

    best_LDA_acc, _, best_BOW_acc, _, best_n_topics, LDA_accuracies, BOW_accuracies = caption_baseline(df_trainval, df_test, n_topic_range)

    # write stats to csv file
    stats_path = "caption_baseline_stats_" + NAME + ".csv"
    with open(stats_path, 'w') as f:
        f.write("caption LDA: best accuracy,%s\n" % (best_LDA_acc,))
        f.write("best n topics,%s\n" % (best_n_topics,))
        f.write("caption BOW: best accuracy,%s\n" % (best_BOW_acc,))

        f.write("searched num dimensions range,%s\n" % (n_topic_range,))

        f.write("\n ALL CAPTION LDA ACCURACIES \n")
        for n_topics, acc in zip(n_topic_range, LDA_accuracies):
            f.write("%s,%s\n" % (n_topics, acc))

        f.write("\n\n ALL CAPTION BOW ACCURACIES \n")
        for n_topics, acc in zip(n_topic_range, BOW_accuracies):
            f.write("%s,%s\n" % (n_topics, acc))
    print("written to " + stats_path)

# Run only when executed as a script, so importing this module (e.g. to reuse
# load_data or caption_baseline) does not parse sys.argv or start a run.
if __name__ == "__main__":
    main()