import os
from tqdm import tqdm
import torch
import numpy as np
from utils import *
from dataset import tokenize_line
from scipy.stats import spearmanr
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression


def read_twenty_ng_class(folder_path, emb_dataset, model, reduction="mean"):
    """
    Embed every document of one 20 Newsgroups class folder.

    Each file in ``folder_path`` is tokenized line by line with
    ``tokenize_line``, mapped to vocabulary indices with
    ``emb_dataset.word_to_idx`` and collapsed into a single vector by
    ``model.embed_sentence(..., reduction=reduction)``.

    Files are visited in sorted name order so the returned row order (and
    therefore the seeded train/val/test split built from it) does not
    depend on the filesystem's directory ordering.

    A file that fails UTF-8 decoding keeps the tokens read before the
    failure; a file that yields no tokens at all is skipped.

    :param folder_path: directory whose files are the documents of one class
    :param emb_dataset: vocabulary object exposing ``word_to_idx(token)``
    :param model: embedding model exposing ``embeddings[0].weight`` (used to
        choose the device) and ``embed_sentence(indices, reduction=...)``
    :param reduction: forwarded to ``model.embed_sentence`` (callers in this
        file pass "sum" or "mean")
    :return: numpy array with one row per kept document
    """
    device = model.embeddings[0].weight.device
    result = []
    for file in tqdm(sorted(os.listdir(folder_path))):
        token_data = []
        with open(f"{folder_path}/{file}", "r", encoding="utf-8") as f:
            try:
                for line in f:
                    token_data.extend(tokenize_line(line))
            except UnicodeDecodeError:
                # Some posts are not valid UTF-8: keep the prefix decoded so far.
                # Only decode errors are caught so Ctrl-C and real bugs propagate.
                pass
        if not token_data:
            # Nothing usable. An empty index list would become an empty float
            # tensor, which presumably cannot go through the embedding lookup
            # (and a mean over zero tokens would be NaN) -- skip the document.
            continue
        # convert tokens to vocabulary indices, then to one document vector
        token_ids = list(map(emb_dataset.word_to_idx, token_data))
        token_tensor = torch.tensor(token_ids).to(device)
        sent_emb = model.embed_sentence(token_tensor, reduction=reduction).cpu().numpy()
        result.append(sent_emb)
    result = np.array(result)
    print(f"dataset:{folder_path}, data shape: {result.shape}")
    return result

def read_movie_sentiment_dataset(emb_dataset, model, reduction="mean"):
    """
    Load the binary Stanford Sentiment Treebank task and embed every sentence.

    The sentences are joined to their labels and splits in four steps:
      - Each sentence (datasetSentences.txt) is matched by exact text to a
        phrase in dictionary.txt, which gives its phrase id.
      - The phrase id is joined to its sentiment score (sentiment_labels.txt).
      - The sentence is joined to its official split (datasetSplit.txt).
      - Neutral sentences (0.4 < score <= 0.6) are dropped. The rest are
        labelled 0.0 (score <= 0.4) or 1.0 (score > 0.6).

    :param emb_dataset: vocabulary object exposing ``word_to_idx(token)``
    :param model: embedding model exposing ``embeddings[0].weight`` (used to
        choose the device) and ``embed_sentence(indices, reduction=...)``
    :param reduction: forwarded to ``model.embed_sentence``
    :return: ``(data, labels)`` dicts keyed "train"/"test"/"val"; ``data``
        holds lists of sentence embeddings, ``labels`` lists of 0.0/1.0
    """
    import pandas as pd
    folder_path = "classification_datasets/stanfordSentimentTreebank"

    # dictionary.txt maps phrase text -> phrase id; matching sentence text
    # against it yields each sentence's phrase id and thus its sentiment score
    df_sentences = pd.read_table(f"{folder_path}/datasetSentences.txt")
    df_dictionary = pd.read_table(f"{folder_path}/dictionary.txt", sep="|", names=["phrase", "phrase_index"])
    df_sentence_phrase = pd.merge(left=df_sentences, right=df_dictionary, left_on="sentence", right_on="phrase")
    del df_sentences, df_dictionary
    df_sentence_phrase = df_sentence_phrase.drop(columns=["sentence"])
    df_sentiment = pd.read_table(f"{folder_path}/sentiment_labels.txt", sep="|")
    df_sentence_sentiment = pd.merge(left=df_sentence_phrase, right=df_sentiment, left_on="phrase_index", right_on="phrase ids")
    del df_sentence_phrase, df_sentiment
    # DataFrame.drop is not in-place: assign the result so the join keys are removed
    df_sentence_sentiment = df_sentence_sentiment.drop(columns=["phrase ids", "phrase_index"])
    df_split = pd.read_table(f"{folder_path}/datasetSplit.txt", sep=",")
    df_sentence_sentiment = pd.merge(left=df_sentence_sentiment, right=df_split, left_on="sentence_index", right_on="sentence_index")

    # remove "neutral" examples; .copy() so the column writes below modify an
    # independent frame rather than a slice (no SettingWithCopyWarning)
    scores = df_sentence_sentiment["sentiment values"]
    df_sentence_sentiment = df_sentence_sentiment[(scores > 0.6) | (scores <= 0.4)].copy()
    # binarize: every remaining score is either <= 0.4 (-> 0.0) or > 0.6 (-> 1.0)
    df_sentence_sentiment["sentiment values"] = (df_sentence_sentiment["sentiment values"] > 0.6).astype(float)

    device = model.embeddings[0].weight.device
    sent_embs = []
    for line in tqdm(list(df_sentence_sentiment["phrase"])):
        token_ids = list(map(emb_dataset.word_to_idx, tokenize_line(line)))
        token_tensor = torch.tensor(token_ids).to(device)
        sent_embs.append(model.embed_sentence(token_tensor, reduction=reduction).cpu().numpy())
    # replace each sentence's text with its embedding
    df_sentence_sentiment["phrase"] = sent_embs
    print(f"dataset: movie sentiment. number of data points: {len(df_sentence_sentiment)}")

    # splitset_label: 1 = train, 2 = test, 3 = val
    data, labels = {}, {}
    for split_name, split_id in (("train", 1), ("test", 2), ("val", 3)):
        subset = df_sentence_sentiment[df_sentence_sentiment["splitset_label"] == split_id]
        data[split_name] = list(subset["phrase"])
        labels[split_name] = list(subset["sentiment values"])
    return data, labels

def train_val_classifier(data, labels, test):
    """
    Fit an L2-regularized logistic regression on sentence vectors and report
    accuracy on the held-out splits (setup of Tsvetkov et al. 2015,
    https://www.aclweb.org/anthology/D15-1243.pdf).

    :param data: dict with "train"/"test"/"val" feature matrices of shape N x E
    :param labels: dict with "train"/"test"/"val" sequences of 0/1 labels
    :param test: meant to gate test-set evaluation (final paper report only).
        NOTE(review): currently unused -- both accuracies are always computed.
    :return: dict with "test_accuracy" and "val_accuracy" (fraction correct)
    """
    classifier = LogisticRegression(penalty="l2", solver="lbfgs", max_iter=5000)
    classifier.fit(data["train"], labels["train"])

    # TODO: add a hyper-parameter loop (e.g. choose regularization on "val")
    scores = {}
    for split, key in (("test", "test_accuracy"), ("val", "val_accuracy")):
        predicted = classifier.predict(data[split])
        scores[key] = (predicted == labels[split]).sum() / len(predicted)
    return scores


def eval_twenty_ng_dataset(topic, emb_dataset, model, test, reduction):
    """
    Evaluate embeddings on one binary 20 Newsgroups task, following
    Yogatama & Smith 2014
    (https://homes.cs.washington.edu/~nasmith/papers/yogatama+smith.acl14.pdf).

    Four tasks are available (``topic`` key -> newsgroup pair):
        - "comp":  comp.sys.ibm.pc.hardware  v  comp.sys.mac.hardware
        - "sport": rec.sport.baseball  v  rec.sport.hockey
        - "sci":   sci.med  v  sci.space
        - "reli":  alt.atheism  v  soc.religion.christian

    Documents of the first group are labelled 0, those of the second 1.
    They are then shuffled and split into approximately 48% train,
    12% val and 40% test, with a fixed seed (42).
    Dataset source: http://qwone.com/~jason/20Newsgroups/

    :param topic: one of the task keys above (unknown keys raise KeyError)
    :param test: forwarded to ``train_val_classifier``
    :param reduction: forwarded to ``read_twenty_ng_class``
    :return: accuracy dict produced by ``train_val_classifier``
    """
    topics = {
        "comp": ("comp.sys.ibm.pc.hardware", "comp.sys.mac.hardware"),
        "sport": ("rec.sport.baseball", "rec.sport.hockey"),
        "sci": ("sci.med", "sci.space"),
        "reli": ("alt.atheism", "soc.religion.christian"),
    }
    group_a, group_b = topics[topic]

    docs_a = read_twenty_ng_class(f"classification_datasets/20_newsgroups/{group_a}", emb_dataset, model, reduction)
    docs_b = read_twenty_ng_class(f"classification_datasets/20_newsgroups/{group_b}", emb_dataset, model, reduction)
    targets = [0] * len(docs_a) + [1] * len(docs_b)
    features = np.append(docs_a, docs_b, axis=0)
    del docs_a, docs_b

    # hold out 52%, then send 77% of that to test:
    # 0.48 train / 0.52*0.23 ~ 0.12 val / 0.52*0.77 ~ 0.40 test
    x_train, x_held, y_train, y_held = train_test_split(features, targets, test_size=0.52, random_state=42)
    x_val, x_test, y_val, y_test = train_test_split(x_held, y_held, test_size=0.77, random_state=42)

    split_data = {"train": x_train, "val": x_val, "test": x_test}
    split_labels = {"train": y_train, "val": y_val, "test": y_test}
    return train_val_classifier(split_data, split_labels, test)


def eval_movie_sentiment_dataset(emb_dataset, model, test, reduction):
    """
    Evaluate embeddings on binary Stanford Sentiment Treebank classification,
    following Yogatama & Smith 2014
    (https://homes.cs.washington.edu/~nasmith/papers/yogatama+smith.acl14.pdf).

    Uses the official split: approximately train 72%, val 9%, test 19%.
    After discarding the neutral sentences, about 80% of the original data
    remains.
    Dataset source: https://nlp.stanford.edu/sentiment/

    :param test: forwarded to ``train_val_classifier``
    :param reduction: forwarded to ``read_movie_sentiment_dataset``
    :return: accuracy dict produced by ``train_val_classifier``
    """
    split_data, split_labels = read_movie_sentiment_dataset(emb_dataset, model, reduction)
    return train_val_classifier(split_data, split_labels, test)


def eval_classification_dataset(dataset_name, emb_dataset, model, test=False):
    """
    Run a classification benchmark once per sentence reduction.

    ``dataset_name`` containing "twenty_ng" selects a 20 Newsgroups task.
    The topic key is whatever remains after removing "twenty_ng" and
    surrounding underscores, e.g. "twenty_ng_sci" -> "sci".
    ``dataset_name`` containing "movie_sentiment" selects the SST task.
    A name matching neither returns an empty dict.

    :return: dict mapping "sum" and "mean" to that run's accuracy dict
    """
    scores = {}
    if "twenty_ng" in dataset_name:
        topic = dataset_name.replace("twenty_ng", "").strip("_")
        for reduction in ("sum", "mean"):
            scores[reduction] = eval_twenty_ng_dataset(topic, emb_dataset, model, test, reduction)
    if "movie_sentiment" in dataset_name:
        for reduction in ("sum", "mean"):
            scores[reduction] = eval_movie_sentiment_dataset(emb_dataset, model, test, reduction)
    return scores

def calculate_20ng_statistics():
    """
    Print document-length statistics for each binary 20 Newsgroups task.

    For every task, the token counts of all documents in both newsgroup
    folders are collected. The function then prints the raw counts, their
    shape, and their mean and median.

    A file that fails UTF-8 decoding contributes the tokens read before the
    failure, or is skipped if none were read.
    """
    topics = {
        "comp": ("comp.sys.ibm.pc.hardware", "comp.sys.mac.hardware"),
        "sport": ("rec.sport.baseball", "rec.sport.hockey"),
        "sci": ("sci.med", "sci.space"),
        "reli": ("alt.atheism", "soc.religion.christian"),
    }

    for category, folders in tqdm(topics.items()):
        result = []
        # both newsgroups of a task are measured with the same loop
        for folder in folders:
            folder_path = f"classification_datasets/20_newsgroups/{folder}"
            for file in os.listdir(folder_path):
                token_data = []
                with open(f"{folder_path}/{file}", "r", encoding="utf-8") as f:
                    try:
                        for line in f:
                            token_data.extend(tokenize_line(line))
                    except UnicodeDecodeError:
                        # only decode failures are tolerated; Ctrl-C and real
                        # bugs now propagate instead of being swallowed
                        if not token_data:
                            continue  # skip this document
                result.append(len(token_data))
        result = np.array(result)
        print(result)
        print(result.shape)
        print(f"category: {category}, average: {result.mean()}, median: {np.median(result)}")

def calculate_movie_statistics():
    """
    Print token-length statistics (mean and median) for the binary Stanford
    Sentiment Treebank sentences.

    Sentences are joined to their phrase sentiment scores and split labels,
    the same way as in read_movie_sentiment_dataset. Neutral sentences
    (0.4 < score <= 0.6) are discarded before counting.
    """
    import pandas as pd
    folder_path = "classification_datasets/stanfordSentimentTreebank"

    # dictionary.txt maps phrase text -> phrase id; matching sentence text
    # against it yields each sentence's phrase id and thus its sentiment score
    df_sentences = pd.read_table(f"{folder_path}/datasetSentences.txt")
    df_dictionary = pd.read_table(f"{folder_path}/dictionary.txt", sep="|", names=["phrase", "phrase_index"])
    df_sentence_phrase = pd.merge(left=df_sentences, right=df_dictionary, left_on="sentence", right_on="phrase")
    del df_sentences, df_dictionary
    df_sentence_phrase = df_sentence_phrase.drop(columns=["sentence"])
    df_sentiment = pd.read_table(f"{folder_path}/sentiment_labels.txt", sep="|")
    df_sentence_sentiment = pd.merge(left=df_sentence_phrase, right=df_sentiment, left_on="phrase_index", right_on="phrase ids")
    del df_sentence_phrase, df_sentiment
    # DataFrame.drop is not in-place: assign the result so the join keys are removed
    df_sentence_sentiment = df_sentence_sentiment.drop(columns=["phrase ids", "phrase_index"])
    df_split = pd.read_table(f"{folder_path}/datasetSplit.txt", sep=",")
    df_sentence_sentiment = pd.merge(left=df_sentence_sentiment, right=df_split, left_on="sentence_index", right_on="sentence_index")

    # remove "neutral" examples. Only the surviving sentences' lengths are
    # measured, so the labels themselves are not binarized here.
    scores = df_sentence_sentiment["sentiment values"]
    df_sentence_sentiment = df_sentence_sentiment[(scores > 0.6) | (scores <= 0.4)]

    lens = np.array([len(tokenize_line(line)) for line in tqdm(list(df_sentence_sentiment["phrase"]))])

    print(f"movie sentiment data, avg len:{lens.mean()}, median: {np.median(lens)}")


if __name__ == "__main__":
    # Prints SST sentence-length statistics; call calculate_20ng_statistics()
    # instead to get the 20 Newsgroups document-length statistics.
    calculate_movie_statistics()

