import argparse
import csv
import os
import pickle
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import torch
from sklearn.metrics import pairwise_distances
from tqdm.auto import tqdm

# Module-level placeholders for the data path and the experiment name.
# Nothing in this file reads them: main() keeps the parsed command-line
# values in local variables and get_entropy() receives the name as a parameter.
DATA_PATH, NAME = "", ""

#csv of form image_id,topic1_score,topic2_score,...
def get_topic_given_document(topic_filename):
    """Load the per-document topic scores from a CSV file.

    The first line of the file is treated as a header and skipped. Every
    following row contributes one entry: the first field is the document id
    (kept as a string) and each remaining field is parsed as a float score.

    Args:
        topic_filename: path to the CSV file to read.

    Returns:
        A dict mapping doc_id -> {topic_index: score}, where topic_index is
        the zero-based position of the score column after the id column.

    Raises:
        ValueError: if a score field cannot be parsed as a float.
    """
    topic_given_document = {}
    # newline='' is what the csv module asks for so it can handle line
    # endings (and quoted fields) itself.
    with open(topic_filename, 'r', newline='') as f:
        reader = csv.reader(f)
        # Ignore the header line (tolerating a completely empty file).
        next(reader, None)
        for row in reader:
            # Blank lines (e.g. a trailing newline at the end of the file)
            # must not create a bogus document with no scores.
            if not any(field.strip() for field in row):
                continue
            doc_id = str(row[0])
            topic_given_document[doc_id] = {
                topic_index: float(score)
                for topic_index, score in enumerate(row[1:])
            }

    return topic_given_document

def get_documents(documents_filename):
    """Load the image_id -> label mapping from a CSV file.

    The file must contain the columns "image_id" and "label". Ids are read
    as strings so that they compare equal to the string ids produced by
    get_topic_given_document(); without this, pandas would turn purely
    numeric ids into ints (dropping leading zeros) and later lookups by
    string id would fail.

    Args:
        documents_filename: path to the CSV file to read.

    Returns:
        A dict mapping each image id (str) to its label. If an id occurs
        more than once, the last row wins.

    Raises:
        KeyError: if one of the required columns is missing.
    """
    documents_df = pd.read_csv(documents_filename, dtype={"image_id": str})
    image_ids = documents_df["image_id"].tolist()
    labels = documents_df["label"].tolist()
    return dict(zip(image_ids, labels))



def get_entropy(documents, topic_given_document, NAME, N):
    """Compute the label entropy of the top-N documents of every topic.

    For each topic, the N documents with the highest score for that topic
    are selected and the distribution of their labels is measured. The
    entropy of that distribution is computed with the number of distinct
    labels in ``documents`` as logarithm base, so the result lies in [0, 1].
    Each value is printed, and the full list is written to
    "metrics/entropy_results/entropy_<NAME>.csv" (the directory is created
    if it does not exist).

    Args:
        documents: dict mapping doc_id -> label.
        topic_given_document: dict mapping doc_id -> {topic_index: score}.
            The number of topics is taken from the first entry, and every
            entry is indexed with 0 .. num_topics - 1.
        NAME: experiment name, used in the output file name.
        N: number of top-scoring documents to consider per topic.

    Returns:
        A list with one float entropy per topic. The entropy is 0.0 when
        ``documents`` holds fewer than two distinct labels, and NaN when no
        document is selected for a topic (e.g. N <= 0).

    Raises:
        KeyError: if a selected document has no label in ``documents`` or
            lacks a score for one of the topics.
    """
    # The label set does not depend on the topic, so compute it once.
    all_labels = list(set(documents.values()))
    num_labels = len(all_labels)

    # Empty input simply yields zero topics instead of an IndexError.
    first_scores = next(iter(topic_given_document.values()), {})
    num_topics = len(first_scores)

    entropies = []
    for i in range(num_topics):
        # Rank documents by their numeric score for topic i. Scores must be
        # compared as numbers: packing (score, doc_id) pairs into one numpy
        # array would convert the scores to strings and rank them
        # lexicographically (e.g. "2e-05" above "0.9").
        ranked_doc_ids = sorted(
            topic_given_document,
            key=lambda doc_id: topic_given_document[doc_id][i],
            reverse=True,
        )
        topic_labels = [documents[doc_id] for doc_id in ranked_doc_ids[:N]]

        if not topic_labels:
            # No documents selected: the distribution is undefined.
            topic_entropy = float("nan")
        elif num_labels < 2:
            # A single possible label carries no uncertainty; log base 1
            # would otherwise produce a division by zero.
            topic_entropy = 0.0
        else:
            #get distribution of labels
            counts = np.array([topic_labels.count(label) for label in all_labels])
            probabilities = counts / len(topic_labels)
            # Plain float keeps the written/returned values free of numpy
            # scalar reprs.
            topic_entropy = float(scipy.stats.entropy(probabilities, base=num_labels))

        entropies.append(topic_entropy)
        print("Topic %d entropy: %f"%(i, topic_entropy))

    os.makedirs("metrics/entropy_results", exist_ok=True)
    with open("metrics/entropy_results/entropy_" + NAME + ".csv", 'w') as f:
        f.write("entropy: %s\n"%entropies)

    return entropies

def main():
    """Run the entropy metric from the command line.

    Parses --data_path and --name, loads the document labels and the
    per-document topic scores, and computes the per-topic entropy over the
    top 50 documents of each topic (get_entropy prints the values and
    writes them to disk).

    If either argument is missing, argparse prints the usage together with
    an error message and exits with a non-zero status.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, help='path to lda data')
    parser.add_argument('--name', type=str, help='name of experiment')
    args = parser.parse_args()

    if args.data_path is None or args.name is None:
        # parser.error() reports the problem on stderr and exits with
        # status 2, so a missing argument is not mistaken for success.
        parser.error("Please specify path to lda data (--data_path) and name of experiment (--name).")

    data_path = args.data_path
    name = args.name

    # NOTE(review): both file names are built from the same expression, so
    # they resolve to the same path, and 'PATH' looks like a placeholder
    # prefix -- confirm the intended locations of the two input files.
    documents_filename = 'PATH' + data_path + '.csv'
    topic_filename = 'PATH' + data_path + '.csv'

    documents = get_documents(documents_filename)
    topic_given_document = get_topic_given_document(topic_filename)

    get_entropy(documents, topic_given_document, name, 50)

# Only run the command-line entry point when executed as a script, so that
# importing this module does not trigger argument parsing.
if __name__ == "__main__":
    main()