import os
import pickle
import torch
import numpy as np
import pandas as pd
import argparse
import spacy
import matplotlib.pyplot as plt
from collections import defaultdict
from tqdm.auto import tqdm
from scipy.stats import entropy
from sklearn.metrics import pairwise_distances

# Make NumPy raise FloatingPointError for every floating-point error category
# (divide, overflow, underflow, invalid) instead of warning or returning inf/nan.
np.seterr('raise')

# Module-level defaults. Nothing in the code visible here rebinds them via a
# `global` statement, so they keep these empty values for the module's lifetime.
DATA_PATH = ""
NAME = ""
# spaCy pipeline, loaded once at import time; get_coherence() calls it to
# compute word-to-word similarity.
SPACY_MODEL = spacy.load('en_core_web_md')

def get_word_given_topic(topic_filename):
    """Load per-topic word scores from a CSV file.

    Each data row has the form
    ``topic_id,word1,word1_score,word2,word2_score,...``.
    The first line of the file is treated as a header and skipped.

    Blank lines are ignored, and empty ``word,score`` pairs (as produced by
    trailing commas or by padding ragged rows to a common width) are skipped
    instead of aborting the load.

    Args:
        topic_filename: Path to the topic CSV file.

    Returns:
        dict mapping int topic id -> dict mapping word -> float score.

    Raises:
        ValueError: if a topic id or score is not numeric, or if a row ends
            with a non-empty word that has no score.
    """
    word_given_topic = {}
    with open(topic_filename, 'r') as f:
        # ignore first line (header)
        f.readline()
        for line in f:
            line = line.strip()
            if not line:
                continue
            fields = line.split(',')
            topic_id = int(fields[0])
            words = fields[1::2]
            scores = fields[2::2]
            # An odd number of trailing fields leaves a word without a score.
            # Tolerate it only when that field is empty (a trailing comma).
            if len(words) > len(scores) and words[-1].strip():
                raise ValueError(
                    "topic %d: word %r has no score" % (topic_id, words[-1])
                )
            word_scores = {}
            for word, score in zip(words, scores):
                if not word.strip() and not score.strip():
                    # padding pair from a ragged CSV row
                    continue
                word_scores[word] = float(score)
            word_given_topic[topic_id] = word_scores
    return word_given_topic

def get_documents(documents_filename):
    """Read the documents CSV and return its messages, lower-cased.

    Rows whose ``message`` cell is missing are dropped: pandas reads an empty
    cell as a float NaN, which has no ``.lower()``. Remaining values are
    coerced to ``str`` so that purely numeric messages do not crash either.

    Args:
        documents_filename: Path to a CSV file with a ``message`` column.

    Returns:
        list of lower-cased message strings, in file order.

    Raises:
        KeyError: if the file has no ``message`` column.
    """
    documents_df = pd.read_csv(documents_filename)
    messages = documents_df["message"].dropna().astype(str)
    # make docs lowercase
    return [text.lower() for text in messages]



def get_diversity(documents, word_given_topic, N, permutation = False):
    """Return the mean fraction of each topic's top words found in no other topic.

    Args:
        documents: list of document strings. Only used when ``permutation`` is
            true, where their whitespace-separated tokens form the sampling pool.
        word_given_topic: dict mapping topic id -> dict mapping word -> score.
        N: number of top-scoring words to take per topic.
        permutation: if true, ignore the topic scores and give every topic N
            tokens drawn at random (without replacement) from the documents.
            This is a random baseline. The pool keeps repeated tokens, so
            frequent words are more likely to be drawn.

    Returns:
        Mean over topics of (words unique to the topic) / (words in the
        topic). Topics with no words are skipped rather than dividing by zero.

    Raises:
        ValueError: in permutation mode, if the documents contain fewer than
            N tokens (raised by ``np.random.choice``).
    """
    if permutation:
        # randomize words across topics
        all_words = [token for doc in documents for token in doc.split()]
        top_words_per_topic = [
            np.random.choice(all_words, N, replace=False).tolist()
            for _ in range(len(word_given_topic))
        ]
    else:
        top_words_per_topic = [
            sorted(word_probs, key=word_probs.get, reverse=True)[:N]
            for word_probs in word_given_topic.values()
        ]

    # Count how many topics contain each word. A word is unique to its topic
    # exactly when that count is 1. This replaces rebuilding and scanning an
    # "all other topics" list for every topic.
    topic_count = defaultdict(int)
    for top_words in top_words_per_topic:
        for word in set(top_words):
            topic_count[word] += 1

    # for each topic, calculate what % of top words are unique to only that topic
    diversity_scores = [
        sum(topic_count[word] == 1 for word in top_words) / len(top_words)
        for top_words in top_words_per_topic
        if top_words
    ]

    return np.mean(diversity_scores)



def get_npmi(word1, word2, documents):
    """Return the normalized pointwise mutual information of two words.

    Probabilities are document frequencies: the share of documents that
    contain the word (or both words). A small epsilon keeps the logarithm
    finite when the words never co-occur.

    NOTE(review): ``word in doc`` is a substring test when ``doc`` is a
    string, so "cat" also matches "concatenate". This is kept as-is because
    multi-word topic terms may rely on it; confirm this is intended.

    Args:
        word1: first word.
        word2: second word.
        documents: sequence of documents supporting ``word in doc``.

    Returns:
        NPMI score. Two boundary cases are handled explicitly because the
        formula breaks down there:
        * -1.0 if either word occurs in no document (previously a
          ZeroDivisionError);
        * 1.0 if both words occur in every document (previously about -1,
          because the epsilon flipped the sign of the denominator).

    Raises:
        ZeroDivisionError: if ``documents`` is empty.
    """
    eps = 1e-8
    n_docs = len(documents)

    # Single pass over the corpus instead of three separate scans.
    n_w1 = n_w2 = n_both = 0
    for doc in documents:
        has_w1 = word1 in doc
        has_w2 = word2 in doc
        n_w1 += has_w1
        n_w2 += has_w2
        n_both += has_w1 and has_w2

    #p_w = num(docs with word)/num(docs)
    p_w1 = n_w1 / n_docs
    p_w2 = n_w2 / n_docs

    #p_w1_w2 = num(docs with word1 and word2)/num(docs)
    p_w1_w2 = n_both / n_docs

    if n_w1 == 0 or n_w2 == 0:
        # A word that never occurs cannot co-occur with anything.
        return -1.0
    if n_both == n_docs:
        # Perfect co-occurrence: -log(1 + eps) would be a tiny negative number.
        return 1.0

    npmi = np.log((p_w1_w2 + eps)/(p_w1*p_w2))/(-np.log(p_w1_w2 + eps))
    return npmi

def get_relatedness(documents, word_given_topic, N, permutation = False):
    """Return the mean within-topic NPMI over all topics.

    For every topic the N highest-scoring words are taken, NPMI is computed
    for each unordered pair of them against ``documents``, and the pair scores
    are averaged. The result is the mean of those per-topic averages.

    Args:
        documents: documents passed through to ``get_npmi``.
        word_given_topic: dict mapping topic id -> dict mapping word -> score.
        N: number of top-scoring words to take per topic.
        permutation: if true, pool the top words of all topics, shuffle the
            pool in place with ``np.random.shuffle`` and deal it back out in
            consecutive chunks of N, one chunk per topic (random baseline).

    Returns:
        Mean of the per-topic mean NPMI. Topics that yield no word pair are
        left out of the average.
    """
    # top N words of every topic, best score first
    ranked = [
        sorted(scores, key=scores.get, reverse=True)[:N]
        for scores in word_given_topic.values()
    ]

    if permutation:
        # pool every topic's top words, shuffle, then re-deal N per topic
        pooled = [word for words in ranked for word in words]
        np.random.shuffle(pooled)
        ranked = [
            pooled[t * N:(t + 1) * N]
            for t in range(len(word_given_topic))
        ]

    # average NPMI over all unordered word pairs of each topic
    per_topic_means = []
    for words in ranked:
        pair_scores = [
            get_npmi(words[a], words[b], documents)
            for a in range(len(words))
            for b in range(a + 1, len(words))
        ]
        if not pair_scores:
            continue
        per_topic_means.append(np.mean(pair_scores))

    return np.mean(per_topic_means)



def get_coherence(documents, word_given_topic, N):
    """Return the mean within-topic embedding similarity over all topics.

    For every topic the N highest-scoring words are run through
    ``SPACY_MODEL`` and the ``similarity`` of each unordered pair is averaged.
    The result is the mean of those per-topic averages.

    Each word is parsed once per topic rather than once per pair, which cuts
    the number of spaCy pipeline calls from N*(N-1) to N per topic without
    changing any score.

    Args:
        documents: unused; kept so the signature parallels the other metrics.
        word_given_topic: dict mapping topic id -> dict mapping word -> score.
        N: number of top-scoring words to take per topic.

    Returns:
        Mean of the per-topic mean similarity. Topics with fewer than two
        words have no pair to score and are skipped, matching
        ``get_relatedness``. Previously they made ``np.mean`` fail on an
        empty list.
    """
    coherence_scores = []

    for word_probs in word_given_topic.values():
        top_words = sorted(word_probs, key=word_probs.get, reverse=True)[:N]

        # parse each word once and reuse the result for every pair it is in
        parsed = [SPACY_MODEL(word) for word in top_words]

        #calculate similarity between all pairs of top words
        sim_scores = [
            parsed[j].similarity(parsed[k])
            for j in range(len(parsed))
            for k in range(j + 1, len(parsed))
        ]
        if not sim_scores:
            continue
        coherence_scores.append(np.mean(sim_scores))

    return np.mean(coherence_scores)


def get_all_metrics(documents, word_given_topic, NAME, text):
    """Compute, print and save the evaluation metrics for one experiment.

    Diversity (top 25 words) and relatedness (top 10 words) are always
    computed. When ``text`` is false, a random baseline of 100 permuted runs
    is also reported for each as mean and standard deviation. When ``text``
    is true, the baselines are skipped and coherence (top 10 words) is
    computed instead.

    Results are printed and written to
    ``metrics/metrics_<NAME>[_text].csv``. The ``metrics`` directory is
    created up front if missing, so a missing directory can no longer discard
    the results after the expensive computation has finished.

    Args:
        documents: list of document strings.
        word_given_topic: dict mapping topic id -> dict mapping word -> score.
        NAME: experiment name, used in the banner and the output file name.
        text: truthy to compute coherence instead of the random baselines.
    """
    baseline_num = 100

    # Fail early (or create the directory) before the long-running metrics.
    os.makedirs("metrics", exist_ok=True)

    print("\n------ %s ------\n"%NAME)
    diversity = get_diversity(documents, word_given_topic, 25)
    print("diversity:", diversity)

    if not text:
        diversity_baseline = [
            get_diversity(documents, word_given_topic, 25, permutation=True)
            for _ in tqdm(range(baseline_num))
        ]
        diversity_baseline_mean = np.mean(diversity_baseline)
        diversity_baseline_std = np.std(diversity_baseline)
        print("diversity baseline:", diversity_baseline_mean, diversity_baseline_std)

    relatedness = get_relatedness(documents, word_given_topic, 10)
    print("relatedness:", relatedness)

    if not text:
        relatedness_baseline = [
            get_relatedness(documents, word_given_topic, 10, permutation=True)
            for _ in tqdm(range(baseline_num))
        ]
        relatedness_baseline_mean = np.mean(relatedness_baseline)
        relatedness_baseline_std = np.std(relatedness_baseline)
        print("relatedness baseline:", relatedness_baseline_mean, relatedness_baseline_std)

    if text:
        coherence = get_coherence(documents, word_given_topic, 10)
        print("coherence:", coherence)

    # print to csv
    with open("metrics/metrics_" + NAME + ("_text" if text else "") + ".csv", 'w') as f:
        f.write("diversity: %s\n"%diversity)
        f.write("relatedness: %s\n"%relatedness)
        if not text:
            f.write("diversity baseline: %s, %s\n"%(diversity_baseline_mean, diversity_baseline_std))
            f.write("relatedness baseline: %s, %s\n"%(relatedness_baseline_mean, relatedness_baseline_std))
        if text:
            f.write("coherence: %s\n"%coherence)


def main():
    """Parse the command line, load the inputs and report all metrics.

    Command-line arguments:
        --data_path: path fragment of the LDA data (required).
        --name: name of the experiment (required).
        --text: boolean flag. Accepts a bare ``--text`` or an explicit value
            such as ``--text True`` / ``--text False``. The previous
            ``type=bool`` treated every non-empty value, including "False",
            as true.

    Exits with a usage error (status 2) if a required argument is missing.

    NOTE(review): the ``'PATH'`` prefixes look like placeholders for real
    directory prefixes. They are kept verbatim; confirm the intended paths.
    """
    def _str_to_bool(value):
        # Convert a command-line string into a real boolean.
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n", ""):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, help='path to lda data')
    parser.add_argument('--name', type=str, help='name of experiment')
    parser.add_argument('--text', type=_str_to_bool, nargs='?', const=True, default=False,
                        help='compute coherence instead of the random baselines')
    args = parser.parse_args()

    if args.data_path is None or args.name is None:
        # parser.error prints the usage plus this message and exits with status 2
        parser.error("Please specify path to lda data (--data_path) and name of experiment (--name).")

    data_path = args.data_path
    name = args.name
    text = args.text

    # The topic file always derives from data_path. The documents file
    # derives from the experiment name in text mode and from data_path otherwise.
    topic_filename = 'PATH' + data_path + '.csv'
    documents_filename = 'PATH' + (name if text else data_path) + '.csv'

    documents = get_documents(documents_filename)
    word_given_topic = get_word_given_topic(topic_filename)

    get_all_metrics(documents, word_given_topic, name, text=text)


# Run the command-line entry point only when executed as a script, so that
# importing this module does not parse sys.argv or start the evaluation.
if __name__ == "__main__":
    main()