import os
import pickle, time
import json, math
import numpy as np
import pandas as pd
from itertools import combinations
from sklearn.metrics import normalized_mutual_info_score
from scipy.stats import entropy
from src.data.dataset.cluster_misc import lexicon, get_names, genre_list, vidn_parse
from src.data.dataset.loader import AISTDataset
from src.data.dataset.utils import save_paired_keypoints3d_as_video, rigid_align, rigid_align_sequence
from src.data.distance.nndtw import DTW

# Root directory of the dataset (presumably AIST++, going by the folder name).
# The path is relative, so it is resolved against the current working directory;
# inter_cluster_purity joins it with "annotations" to build the AISTDataset loader.
data_dir = "../aistplusplus"


def preprocess(df):
    """Attach a ``label`` column to every snippet row.

    Args:
        df: DataFrame with columns ``idx``, ``word``, ``length``, ``y``, ``name``.

    Returns:
        A new DataFrame (fresh 0..n-1 index) with columns
        ``idx, word, length, y, label, name`` followed by any extra columns of
        ``df``.  ``label`` is the integer parsed from characters 2-3 of the
        ``choreo`` field that ``vidn_parse`` extracts from the video name.

    Raises:
        NotImplementedError: if a row's ``name`` does not contain ``"sBM"``.
            Splitting or dropping such videos (presumably the advanced-dance
            ones) is not implemented.
    """
    base_columns = ["idx", "word", "length", "y", "label", "name"]
    records = []
    for _, row in df.iterrows():
        if "sBM" not in row["name"]:
            raise NotImplementedError(
                f"only videos whose name contains 'sBM' are supported, got {row['name']!r}"
            )
        parsed = vidn_parse(row["name"])
        record = dict(row)
        record["label"] = int(parsed["choreo"][2:4])
        records.append(record)

    # DataFrame.append was removed in pandas 2.0 (and copied the frame on every
    # call); build the result once from the collected records instead.  Extra
    # input columns are kept after the base ones, as append used to do.
    extra_columns = [c for c in df.columns if c not in base_columns]
    return pd.DataFrame(records, columns=base_columns + extra_columns)


def metric_nmi(df):
    """Frame-level normalized mutual information between labels and words.

    Args:
        df: DataFrame with columns ``idx``, ``word``, ``length``, ``y``,
            ``name``, ``label``.

    Returns:
        ``normalized_mutual_info_score`` of two sequences in which every row is
        repeated ``length`` times: the row's ``label`` (ground truth) and the
        position of the row's ``word`` in ``lexicon`` (prediction).
    """
    # Named *_labels rather than gt/pd so the pandas alias is not shadowed.
    true_labels, pred_labels = [], []
    for _, row in df.iterrows():
        true_labels.extend([row["label"]] * row["length"])
        pred_labels.extend([lexicon.index(row["word"])] * row["length"])
    return normalized_mutual_info_score(true_labels, pred_labels)


def metric_nmi_composed(df):
    """Frame-level NMI for composed videos, with labels read from text files.

    Args:
        df: DataFrame with at least the columns ``word``, ``length``, ``name``.

    Returns:
        ``normalized_mutual_info_score`` between the ground-truth labels read
        from ``./composed_vids/<name>.txt`` and the per-row lexicon index of
        ``word`` repeated ``length`` times.
    """
    predicted = []
    video_names = []
    for _, row in df.iterrows():
        video = row["name"]
        # Only collapse *consecutive* rows of the same video; a name that
        # reappears later is listed (and its file read) again.
        if not video_names or video_names[-1] != video:
            video_names.append(video)
        predicted.extend([lexicon.index(row["word"])] * row["length"])

    ground_truth = []
    for video in video_names:
        with open(os.path.join("./composed_vids", f"{video}.txt"), "r") as f:
            for line in f.readlines():
                # One label per line: characters 2-3 of the last "_"-separated
                # field (presumably one line per frame -- confirm against the
                # code that writes these files).
                last_field = line.split("_")[-1]
                ground_truth.append(int(last_field[2:4]))

    return normalized_mutual_info_score(ground_truth, predicted)


def ngram_ent(df, n=4, lb=1):
    """Count word n-grams per sequence (pre-processing for the entropy metrics).

    Despite its name this does not compute an entropy; it only builds the
    n-gram histogram consumed by ``nge``, ``nge_serious`` and ``purity``.

    Args:
        df: DataFrame with columns ``idx``, ``word``, ``length``, ``y``,
            ``name``.  It is not expected to have gone through ``preprocess``.
        n: number of consecutive words per n-gram.
        lb: rows with ``length <= lb`` are dropped before n-grams are formed.

    Returns:
        dict mapping each n-gram (its ``n`` words concatenated with no
        separator) to its number of occurrences, summed over all ``y`` groups.
    """
    bins = {}
    for _, group in df.groupby('y'):
        # Groups consisting of a single row are skipped entirely (the check is
        # made before the length filter, as in the original code).
        if len(group) <= 1:
            continue
        words = list(group.loc[group["length"] > lb, "word"])
        # A sequence of L words holds L - n + 1 n-grams.  The previous bound
        # (L - n) silently dropped the last n-gram of every sequence and
        # ignored sequences of exactly n words.
        for i in range(len(words) - n + 1):
            pattern = "".join(words[i:i + n])
            bins[pattern] = bins.get(pattern, 0) + 1
    return bins


def nge(df, n=2, lb=5):
    """n-gram entropy of the word sequences in ``df``.

    Args:
        df: DataFrame accepted by ``ngram_ent``.
        n: n-gram order, forwarded to ``ngram_ent``.
        lb: minimum-length filter, forwarded to ``ngram_ent``.

    Returns:
        ``scipy.stats.entropy`` (natural log) of the empirical distribution
        over the n-grams that were observed.
    """
    counts = np.array(list(ngram_ent(df, n, lb).values()))
    probabilities = counts / counts.sum()
    return entropy(probabilities)


def nge_serious(df, K, n=2, lb=5):
    """n-gram entropy over the full ``K ** n`` n-gram vocabulary.

    Args:
        df: DataFrame accepted by ``ngram_ent``.
        K: vocabulary size; the distribution is padded to ``K ** n`` entries.
        n: n-gram order, forwarded to ``ngram_ent``.
        lb: minimum-length filter, forwarded to ``ngram_ent``.

    Returns:
        ``scipy.stats.entropy`` of the observed n-gram counts padded with one
        zero per n-gram that never appeared.

    NOTE(review): scipy treats ``0 * log(0)`` as 0, so for a non-empty count
    table the padding does not change the value and this equals ``nge``.
    If a normalized entropy was intended, the padding alone does not give it.
    """
    observed = list(ngram_ent(df, n, lb).values())
    unseen = K ** n - len(observed)  # n-grams that never occurred
    dist = observed + [0] * unseen
    probabilities = np.array(dist) / sum(dist)
    return entropy(probabilities)


def purity(df, K=-1):
    """Frequency-weighted entropy of the next word given the current word.

    The original author describes this as a "wrong implementation of 2-gram
    entropy": what it computes is ``sum_i P(i) * H(next word | current = i)``
    from the 2-gram counts of ``ngram_ent(df, n=2, lb=5)``.

    Args:
        df: DataFrame accepted by ``ngram_ent``.  Each word is assumed to be
            exactly two characters long, because 2-gram keys are split at
            index 2.
        K: vocabulary size.  With the default ``-1`` it is guessed as the
            largest lexicon index used in ``df["word"]`` rounded up to the
            next multiple of 5 (always strictly above that index).

    Returns:
        The weighted sum of per-row entropies as a numpy float.
    """
    bigram_counts = ngram_ent(df, n=2, lb=5)
    if K == -1:  # guess a reasonable vocabulary size
        largest_index = max(lexicon.index(w) for w in set(df["word"]))
        K = 5 * (largest_index // 5 + 1)

    # transitions[i, j] = number of times word j directly follows word i
    transitions = np.zeros((K, K))
    for bigram, count in bigram_counts.items():
        current = lexicon.index(bigram[:2])
        following = lexicon.index(bigram[2:])
        transitions[current, following] += count

    row_totals = transitions.sum(axis=-1)  # [K]
    row_weights = row_totals / row_totals.sum()

    row_entropies = np.zeros((K,))
    for i, row in enumerate(transitions):
        # entropy() normalizes the counts itself; zero cells are left out
        row_entropies[i] = entropy(row[row != 0])

    return np.sum(row_weights * row_entropies)


def _collect_snippets_by_word(df, load_keypoints, min_length=5):
    """Slice every video's keypoints into per-row snippets, grouped by word.

    Args:
        df: DataFrame with columns ``idx``, ``word``, ``length``, ``y``, ``name``.
        load_keypoints: callable mapping a video name to its keypoint array,
            indexable along the first (frame) axis.
        min_length: only rows with ``length > min_length`` are kept.

    Returns:
        dict mapping each word to the list of keypoint slices labelled with it.
    """
    snippets_by_word = {}
    for name, group in df.groupby('name'):
        if len(group) <= 1:  # videos with a single row are skipped
            continue
        keypoints = load_keypoints(name)
        start = 0
        for _, row in group.iterrows():  # rows of one video, in order
            end = start + row["length"]
            if row["length"] > min_length:  # filter snippets too small
                snippets_by_word.setdefault(row["word"], []).append(keypoints[start:end])
            # Short rows are skipped but still advance the frame cursor.
            start = end
    return snippets_by_word


def _mean_pairwise_dtw_error(snippets_by_word):
    """Average, over words, of the mean pairwise DTW error between snippets.

    Words with fewer than two snippets contribute nothing.  Returns NaN when
    no word has at least two snippets, because the mean is then undefined.
    """
    per_word_errors = []
    for snippets in snippets_by_word.values():
        if len(snippets) < 2:  # combinations([x], 2) yields nothing
            continue
        pair_errors = []
        for a, b in combinations(snippets, 2):
            t1, t2 = a.shape[0], b.shape[0]
            b = rigid_align_sequence(b, a)
            # Frames are flattened to 51 values (presumably 17 joints x 3
            # coordinates, which would also explain the division by 17 below).
            diff = a.reshape(-1, 51)[:, None, :] - b.reshape(-1, 51)[None, :, :]
            l2_map = np.linalg.norm(diff, ord=2, axis=-1)  # [t1, t2]
            # DTW(l2_map) is used as a weight over l2_map (presumably the
            # alignment-path mask), normalized by the mean snippet length.
            path = 2 * DTW(l2_map) / (t1 + t2) / 17
            pair_errors.append(np.sum(path * l2_map))
        per_word_errors.append(sum(pair_errors) / len(pair_errors))

    if not per_word_errors:
        # Previously this fell through to a ZeroDivisionError.
        return float("nan")
    return sum(per_word_errors) / len(per_word_errors)


def inter_cluster_purity(df):  # purity means mDD
    """Mean DTW distance between snippets that were assigned the same word.

    Keypoints are loaded per video through ``AISTDataset.load_keypoint3d``
    from ``<data_dir>/annotations``.

    Args:
        df: DataFrame with columns ``idx``, ``word``, ``length``, ``y``,
            ``name``.  It is not expected to have gone through ``preprocess``.

    Returns:
        The mean over words of the average pairwise DTW error, or NaN if no
        word has at least two snippets longer than 5 frames.
    """
    official_loader = AISTDataset(os.path.join(data_dir, "annotations"))
    snippets_by_word = _collect_snippets_by_word(df, official_loader.load_keypoint3d)
    return _mean_pairwise_dtw_error(snippets_by_word)


def inter_cluster_purity_composed(df):  # purity means mDD
    """Mean DTW distance for composed videos stored under ``./composed_vids``.

    Same metric as ``inter_cluster_purity``, but each video's keypoints are
    read from ``./composed_vids/<name>.npy``.  The AIST annotation loader is
    not needed here and is no longer constructed.

    Args:
        df: DataFrame with columns ``idx``, ``word``, ``length``, ``y``,
            ``name``.  It is not expected to have gone through ``preprocess``.

    Returns:
        The mean over words of the average pairwise DTW error, or NaN if no
        word has at least two snippets longer than 5 frames.
    """
    def load_composed(name):
        return np.load(f"./composed_vids/{name}.npy")

    snippets_by_word = _collect_snippets_by_word(df, load_composed)
    return _mean_pairwise_dtw_error(snippets_by_word)
