import itertools
import os
import numpy as np
from pathlib import Path

import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

from src.data.dataset.loader import AISTDataset
from src.data.distance.nndtw import DTW
from plb.datamodules.data_transform import SkeletonTransform

from src.data.dataset.utils import save_paired_keypoints3d_as_video, rigid_align

# Root of the AIST++ dataset, relative to the working directory. ReconstructionValidator
# reads annotations from "<data_dir>/annotations" and creates "<data_dir>/gt_dists".
data_dir = "../aistplusplus"

# NOTE(review): not referenced by the code visible in this file — confirm it is still used.
spatial_align = True

# Same-genre validation pairs, one per genre. Each entry is (name1, e1, name2, e2): two AIST++
# sequence names, each followed by the number of leading frames ReconstructionValidator keeps.
less_names = [  # one pair per genre
    # gBR
    ("gBR_sBM_cAll_d05_mBR4_ch04", 116, "gBR_sBM_cAll_d06_mBR4_ch04", 120),
    # gHO
    ('gHO_sBM_cAll_d21_mHO4_ch04', 52, 'gHO_sBM_cAll_d21_mHO2_ch04', 64),
    # gJB
    ('gJB_sBM_cAll_d07_mJB0_ch04', 208, 'gJB_sBM_cAll_d07_mJB3_ch04', 180),
    # gJS
    ('gJS_sBM_cAll_d02_mJS4_ch01', 470, 'gJS_sBM_cAll_d02_mJS5_ch01', 444),
    # gKR
    ('gKR_sBM_cAll_d30_mKR4_ch05', 98, 'gKR_sBM_cAll_d30_mKR2_ch05', 48),
    # gLH
    ('gLH_sBM_cAll_d17_mLH1_ch06', 156, 'gLH_sBM_cAll_d17_mLH5_ch06', 112),
    # gLO
    ('gLO_sBM_cAll_d13_mLO0_ch01', 164, 'gLO_sBM_cAll_d13_mLO2_ch01', 136),
    # gMH
    ('gMH_sBM_cAll_d23_mMH4_ch07', 108, 'gMH_sBM_cAll_d22_mMH3_ch07', 108),
    # gPO
    ('gPO_sBM_cAll_d12_mPO4_ch03', 168, 'gPO_sBM_cAll_d10_mPO2_ch03', 200),
    # gWA (excluded: the visual difference between the two clips is too large)
    # ('gWA_sBM_cAll_d27_mWA4_ch09', 132, 'gWA_sBM_cAll_d27_mWA2_ch09', 100),  # visual difference too large
]

# Cross-genre validation pairs, in the same (name1, e1, name2, e2) format as less_names. Every
# sequence is truncated to its first 100 frames. The genre comment above each entry labels the
# FIRST sequence; the second sequence always comes from a different genre.
weirder_names = [  # one pair per genre; the partner is drawn from another genre
    # gBR
    ('gBR_sBM_cAll_d05_mBR4_ch04', 100, 'gPO_sBM_cAll_d10_mPO2_ch03', 100),
    # gHO
    ('gHO_sBM_cAll_d21_mHO4_ch04', 100, 'gMH_sBM_cAll_d22_mMH3_ch07', 100),
    # gJB
    ('gJB_sBM_cAll_d07_mJB0_ch04', 100, 'gLO_sBM_cAll_d13_mLO2_ch01', 100),
    # gJS
    ('gJS_sBM_cAll_d02_mJS4_ch01', 100, 'gKR_sBM_cAll_d30_mKR2_ch05', 100),
    # gKR
    ('gKR_sBM_cAll_d30_mKR4_ch05', 100, 'gLH_sBM_cAll_d17_mLH5_ch06', 100),
    # gLH
    ('gLH_sBM_cAll_d17_mLH1_ch06', 100, 'gJS_sBM_cAll_d02_mJS5_ch01', 100),
    # gLO
    ('gLO_sBM_cAll_d13_mLO0_ch01', 100, 'gJB_sBM_cAll_d07_mJB3_ch04', 100),
    # gMH
    ('gMH_sBM_cAll_d23_mMH4_ch07', 100, 'gHO_sBM_cAll_d21_mHO2_ch04', 100),
    # gPO
    ('gPO_sBM_cAll_d12_mPO4_ch03', 100, 'gBR_sBM_cAll_d06_mBR4_ch04', 100),
    # gWA (no entry; the line below is a leftover copy of the excluded same-genre gWA pair)
    # ('gWA_sBM_cAll_d27_mWA4_ch09', 132, 'gWA_sBM_cAll_d27_mWA2_ch09', 100),  # visual difference too large
]

# Sequence-name pairs without frame limits. Each genre contributes four rows: three rows drawn
# from a triplet of sequences, followed by one additional pair.
# NOTE(review): not referenced by the code visible in this file — confirm it is still used.
# NOTE(review): for every genre except gBR, the third "triplet" row is the second row reversed
# (e.g. the gHO rows 2-3) rather than pairing the first and third sequences — confirm intended.
more_names = [  # per genre: 3 pairs from a triplet + 1 extra pair
    # gBR
    ("gBR_sBM_cAll_d05_mBR0_ch02", "gBR_sBM_cAll_d06_mBR3_ch02"),
    ("gBR_sBM_cAll_d05_mBR4_ch02", "gBR_sBM_cAll_d06_mBR3_ch02"),
    ("gBR_sBM_cAll_d05_mBR0_ch02", "gBR_sBM_cAll_d05_mBR4_ch02"),
    ("gBR_sBM_cAll_d05_mBR4_ch04", "gBR_sBM_cAll_d06_mBR4_ch04"),
    # gHO
    ('gHO_sBM_cAll_d20_mHO4_ch01', 'gHO_sBM_cAll_d20_mHO1_ch01'),
    ('gHO_sBM_cAll_d20_mHO1_ch01', 'gHO_sBM_cAll_d19_mHO0_ch01'),
    ('gHO_sBM_cAll_d19_mHO0_ch01', 'gHO_sBM_cAll_d20_mHO1_ch01'),
    ('gHO_sBM_cAll_d21_mHO4_ch04', 'gHO_sBM_cAll_d21_mHO2_ch04'),
    # gJB
    ('gJB_sBM_cAll_d07_mJB0_ch03', 'gJB_sBM_cAll_d08_mJB1_ch03'),
    ('gJB_sBM_cAll_d08_mJB1_ch03', 'gJB_sBM_cAll_d08_mJB4_ch03'),
    ('gJB_sBM_cAll_d08_mJB4_ch03', 'gJB_sBM_cAll_d08_mJB1_ch03'),
    ('gJB_sBM_cAll_d07_mJB0_ch04', 'gJB_sBM_cAll_d09_mJB2_ch04'),
    # gJS
    ('gJS_sBM_cAll_d03_mJS3_ch05', 'gJS_sBM_cAll_d02_mJS1_ch05'),
    ('gJS_sBM_cAll_d02_mJS1_ch05', 'gJS_sBM_cAll_d03_mJS2_ch05'),
    ('gJS_sBM_cAll_d03_mJS2_ch05', 'gJS_sBM_cAll_d02_mJS1_ch05'),
    ('gJS_sBM_cAll_d02_mJS4_ch01', 'gJS_sBM_cAll_d02_mJS5_ch01'),
    # gKR
    ('gKR_sBM_cAll_d28_mKR2_ch04', 'gKR_sBM_cAll_d29_mKR4_ch04'),
    ('gKR_sBM_cAll_d29_mKR4_ch04', 'gKR_sBM_cAll_d28_mKR0_ch04'),
    ('gKR_sBM_cAll_d28_mKR0_ch04', 'gKR_sBM_cAll_d29_mKR4_ch04'),
    ('gKR_sBM_cAll_d30_mKR4_ch05', 'gKR_sBM_cAll_d29_mKR1_ch05'),
    # gLH
    ('gLH_sBM_cAll_d17_mLH5_ch05', 'gLH_sBM_cAll_d17_mLH1_ch05'),
    ('gLH_sBM_cAll_d17_mLH1_ch05', 'gLH_sBM_cAll_d18_mLH2_ch05'),
    ('gLH_sBM_cAll_d18_mLH2_ch05', 'gLH_sBM_cAll_d17_mLH1_ch05'),
    ('gLH_sBM_cAll_d17_mLH1_ch06', 'gLH_sBM_cAll_d17_mLH5_ch06'),
    # gLO
    ('gLO_sBM_cAll_d15_mLO2_ch02', 'gLO_sBM_cAll_d13_mLO2_ch02'),
    ('gLO_sBM_cAll_d13_mLO2_ch02', 'gLO_sBM_cAll_d15_mLO3_ch02'),
    ('gLO_sBM_cAll_d15_mLO3_ch02', 'gLO_sBM_cAll_d13_mLO2_ch02'),
    ('gLO_sBM_cAll_d13_mLO0_ch01', 'gLO_sBM_cAll_d13_mLO2_ch01'),
    # gMH
    ('gMH_sBM_cAll_d23_mMH4_ch06', 'gMH_sBM_cAll_d22_mMH1_ch06'),
    ('gMH_sBM_cAll_d22_mMH1_ch06', 'gMH_sBM_cAll_d23_mMH0_ch06'),
    ('gMH_sBM_cAll_d23_mMH0_ch06', 'gMH_sBM_cAll_d22_mMH1_ch06'),
    ('gMH_sBM_cAll_d23_mMH4_ch07', 'gMH_sBM_cAll_d22_mMH3_ch07'),
    # gPO
    ('gPO_sBM_cAll_d10_mPO3_ch10', 'gPO_sBM_cAll_d11_mPO5_ch10'),
    ('gPO_sBM_cAll_d11_mPO5_ch10', 'gPO_sBM_cAll_d11_mPO1_ch10'),
    ('gPO_sBM_cAll_d11_mPO1_ch10', 'gPO_sBM_cAll_d11_mPO5_ch10'),
    ('gPO_sBM_cAll_d12_mPO4_ch03', 'gPO_sBM_cAll_d10_mPO2_ch03'),
    # gWA
    ('gWA_sBM_cAll_d26_mWA0_ch10', 'gWA_sBM_cAll_d27_mWA4_ch10'),
    ('gWA_sBM_cAll_d27_mWA4_ch10', 'gWA_sBM_cAll_d27_mWA3_ch10'),
    ('gWA_sBM_cAll_d27_mWA3_ch10', 'gWA_sBM_cAll_d27_mWA4_ch10'),
    ('gWA_sBM_cAll_d27_mWA4_ch09', 'gWA_sBM_cAll_d27_mWA2_ch09'),
]


def KendallsTau(dist):
    """Kendall's tau between the row order and the order of each row's nearest column.

    For every row ``i`` of ``dist``, the column with the smallest value is taken as its match.
    The score is tau-a over those matches. Every pair of rows ``s < l`` contributes:

    - +1 when the matches preserve the row order (``match[s] < match[l]``),
    - -1 when they reverse it,
    - 0 when both rows match the same column (a tie).

    The sum is divided by the number of row pairs. 1.0 means a perfectly monotone alignment
    and -1.0 a fully reversed one.

    Args:
        dist: 2-D array-like of distances, shape ``[M, N]`` (lower means closer).

    Returns:
        float in [-1, 1]. Returns 0.0 when fewer than two rows are given, since no pair can
        be ranked (previously this raised ZeroDivisionError).
    """
    matches = np.argmin(np.asarray(dist), axis=-1).astype(np.int64)  # [M]
    n = matches.shape[0]
    if n < 2:
        return 0.0
    # order[s, l] = sign(matches[l] - matches[s]). The strict upper triangle holds exactly the
    # pairs s < l. Ties give 0 here; the previous loop scored them -1 (discordant), which
    # biased the score negative (e.g. random inputs averaged below 0).
    order = np.sign(matches[None, :] - matches[:, None])
    concordance = np.triu(order, k=1).sum()
    return float(concordance) / (n * (n - 1) / 2)


class ReconstructionValidator(object):
    """Score embedding models by how well they align paired AIST++ dance clips.

    At construction time, every (name1, e1, name2, e2) entry of ``less_names`` (same-genre
    pairs) and ``weirder_names`` (cross-genre pairs) assigned to this rank is processed:

    - each sequence is loaded and truncated to its first ``e`` frames,
    - it is scaled by 1/100 and passed through ``SkeletonTransform``,
    - the two sequences are compared frame-by-frame with ``rigid_align`` to build a
      ground-truth distance matrix.

    Calling the validator with a model returns the mean Kendall's tau of the model's
    nearest-neighbour alignment on those pairs.

    Args:
        DISTANCE: accepted for configuration compatibility; not used by this class.
        log_dir: root log directory. ``<log_dir>/saved_videos`` is stored as ``video_dir``
            (``None`` when ``log_dir`` is None).
        rank: index of this process; selects its contiguous shard of pairs.
        world_size: number of processes sharing the pairs.
    """

    def __init__(self, DISTANCE, log_dir, rank, world_size):
        # video_dir used to be left undefined when log_dir was None.
        self.video_dir = os.path.join(log_dir, "saved_videos") if log_dir is not None else None
        self.data_dir = data_dir
        official_loader = AISTDataset(os.path.join(data_dir, "annotations"))
        self.rank = rank
        self.world_size = world_size
        fast_dir = Path(data_dir) / "gt_dists"
        fast_dir.mkdir(parents=True, exist_ok=True)

        self.axis = np.array([0, 0, 1])
        self.norm_frame = 0
        self.dim = 3
        data_transform = SkeletonTransform(aug_shift_prob=1., aug_shift_range=6, aug_rot_prob=1., aug_rot_range=0.3,
                                           aug_time_prob=1., aug_time_rate=1.5, min_length=64, max_length=64)

        # Same-genre pairs, transformed with a fixed seed.
        self.names = self._shard(less_names)  # list of (name1, e1, name2, e2)
        self.skeletons_cpu, self.lengths, self.gt_dists, kt_metrics = self._prepare_pairs(
            self.names, official_loader, data_transform, seed=6)
        self.golden_paths = {}  # pair_idx -> ground-truth alignment, filled lazily by _golden_path
        print(f"ReconstructionValidator initialization done with Kendall's Tau {self._mean(kt_metrics)}")

        # Cross-genre pairs.
        # NOTE(review): unlike the pairs above, these are transformed without a seed, so their
        # augmentation may differ between runs — confirm this is intended.
        self.weirder_names = self._shard(weirder_names)
        (self.weirder_skeletons_cpu, self.weirder_lengths,
         self.weirder_gt_dists, kt_metrics) = self._prepare_pairs(
            self.weirder_names, official_loader, data_transform, seed=None)
        self.weirder_golden_paths = []
        print(f"Weirder: ReconstructionValidator initialization done with Kendall's Tau {self._mean(kt_metrics)}")
        self.device = -1  # sentinel: tensors not yet moved to the model's device

    def _shard(self, items):
        """Return this rank's contiguous, equally sized slice of ``items``.

        NOTE(review): the trailing ``len(items) % world_size`` items are assigned to no rank,
        and every rank gets nothing when ``world_size > len(items)``.
        """
        bulk = len(items) // self.world_size
        return items[bulk * self.rank: bulk * (self.rank + 1)]

    @staticmethod
    def _mean(values):
        """Arithmetic mean of ``values``; NaN for an empty list instead of ZeroDivisionError."""
        return sum(values) / len(values) if values else float("nan")

    def _prepare_pairs(self, pairs, loader, transform, seed=None):
        """Load, truncate, scale and transform every pair, then build ground-truth distances.

        Args:
            pairs: list of (name1, e1, name2, e2) entries.
            loader: object providing ``load_keypoint3d(name)``.
            transform: callable whose first returned element is the transformed sequence.
            seed: passed to ``transform`` as ``seed=`` when not None. When None, ``transform``
                is called without a seed argument.

        Returns:
            ``(skeletons, lengths, gt_dists, kts)``, one list entry per pair:
            - skeletons: tuple of two CPU tensors,
            - lengths: tuple of two shape-[1] length tensors,
            - gt_dists: the full ground-truth distance matrix,
            - kts: the Kendall's tau of that matrix.
        """
        skeletons, lengths, gt_dists, kts = [], [], [], []
        for (name1, e1, name2, e2) in tqdm(pairs):
            pair_skeletons = tuple()
            pair_lengths = tuple()
            for name, e in ((name1, e1), (name2, e2)):
                raw = loader.load_keypoint3d(name)[:e]  # keep only the first e frames
                x = torch.Tensor(raw).flatten(1, -1) / 100  # [T, 51]
                x = transform(x)[0] if seed is None else transform(x, seed=seed)[0]
                pair_skeletons += (x,)
                pair_lengths += (torch.tensor([x.shape[0]]),)
            skeletons.append(pair_skeletons)
            lengths.append(pair_lengths)

            # The matrix used to be cropped to [:e1, :e2]. But e counts raw frames, while rows
            # and columns here index transformed frames. The crop could therefore drop rows and
            # break shape agreement with model distance matrices, so the full matrix is kept.
            gt_dist = self._pairwise_gt_dist(pair_skeletons[0], pair_skeletons[1])
            gt_dists.append(gt_dist)
            kts.append(KendallsTau(gt_dist))
        return skeletons, lengths, gt_dists, kts

    @staticmethod
    def _pairwise_gt_dist(skeleton1, skeleton2):
        """Distance ``[T1, T2]`` between every frame pair, measured after rigid alignment."""
        kp1 = skeleton1.numpy().reshape(-1, 17, 3)
        kp2 = skeleton2.numpy().reshape(-1, 17, 3)
        gt_dist = np.zeros((kp1.shape[0], kp2.shape[0]))
        for i in range(kp1.shape[0]):
            for j in range(kp2.shape[0]):
                aligned = rigid_align(kp2[j], kp1[i])
                # NOTE(review): ord=2 on a (17, 3) matrix is the spectral norm (largest
                # singular value), not the Frobenius/Euclidean norm — confirm intended.
                gt_dist[i, j] = np.linalg.norm(aligned - kp1[i], ord=2)
        return gt_dist

    def move_to_gpu(self):
        """Copy every skeleton pair to ``self.device``. Lengths stay on the CPU."""
        self.skeletons = [tuple(s.to(self.device) for s in pair) for pair in self.skeletons_cpu]
        self.weirder_skeletons = [tuple(s.to(self.device) for s in pair) for pair in self.weirder_skeletons_cpu]

    def _ensure_device(self, model):
        """On first use, adopt ``model``'s device index and move the cached skeletons there."""
        if self.device == -1:
            self.device = model.device.index
            self.move_to_gpu()

    def _embed(self, model, skeleton, length):
        """Per-frame features ``[T, f]`` for one sequence, computed without an autograd graph."""
        with torch.no_grad():
            return model(skeleton.unsqueeze(0), length.to(self.device)).squeeze(1)

    @staticmethod
    def _neg_similarity(features1, features2):
        """Negated dot-product similarity, as a NumPy distance matrix ``[T1, T2]``."""
        return -(features1 @ features2.t()).detach().cpu().numpy()

    def __call__(self, model, save=-1):
        """Evaluate ``model`` on this rank's pairs. ``save`` is accepted and ignored.

        Returns:
            dict of mean Kendall's tau values (NaN when this rank holds no pairs):
            - "Kendall's Tau": whole-sequence features on the same-genre pairs.
            - "Weird Kendall's Tau": features computed separately on each half of every
              sequence, then concatenated.
            - "Weirder Kendall's Tau": whole-sequence features on the cross-genre pairs.
        """
        self._ensure_device(model)
        kt, weird_kt, weirder_kt = [], [], []
        for skeletons, lengths in zip(self.skeletons, self.lengths):
            features = [self._embed(model, ske, l) for ske, l in zip(skeletons, lengths)]
            kt.append(KendallsTau(self._neg_similarity(*features)))

            # Another way to get features during inference: embed the two halves separately.
            weird_features = []
            for ske, l in zip(skeletons, lengths):
                half = l // 2
                first = self._embed(model, ske[:half], half)
                second = self._embed(model, ske[half:], l - half)
                weird_features.append(torch.cat([first, second], dim=0))  # [T, f]
            weird_kt.append(KendallsTau(self._neg_similarity(*weird_features)))

        for skeletons, lengths in zip(self.weirder_skeletons, self.weirder_lengths):
            features = [self._embed(model, ske, l) for ske, l in zip(skeletons, lengths)]
            weirder_kt.append(KendallsTau(self._neg_similarity(*features)))

        return {
            "Kendall's Tau": self._mean(kt),
            "Weird Kendall's Tau": self._mean(weird_kt),
            "Weirder Kendall's Tau": self._mean(weirder_kt),
        }

    def _golden_path(self, pair_idx):
        """Ground-truth alignment matrix for same-genre pair ``pair_idx``, cached after first use.

        NOTE(review): the previous code read ``self.golden_paths[pair_idx]`` but never populated
        it. It is assumed here to be the DTW path over the ground-truth distance matrix (lower
        means closer) — confirm against the intended definition.
        """
        if pair_idx not in self.golden_paths:
            self.golden_paths[pair_idx] = DTW(self.gt_dists[pair_idx])
        return self.golden_paths[pair_idx]

    def _save_video(self, model, video_path, pair_idx, align, metric_key):
        """Align one same-genre pair with ``align``, print metrics and save a paired video.

        Args:
            model: embedding model; its device is adopted on first use.
            video_path: output path forwarded to ``save_paired_keypoints3d_as_video``.
            pair_idx: index into this rank's ``self.names``.
            align: maps a similarity matrix ``[T1, T2]`` to a matrix that is non-zero where
                frames are matched.
            metric_key: name under which the ground-truth alignment cost is reported.
        """
        # Entries are (name1, e1, name2, e2); unpacking them into two names used to raise.
        name1, _, name2, _ = self.names[pair_idx]
        self._ensure_device(model)
        skeletons = self.skeletons[pair_idx]
        lengths = self.lengths[pair_idx]
        gt_dist = self.gt_dists[pair_idx]
        golden_path = self._golden_path(pair_idx)

        features = [self._embed(model, ske, l) for ske, l in zip(skeletons, lengths)]
        similarity = (features[0] @ features[1].t()).detach().cpu().numpy()
        path = align(similarity)

        res = {metric_key: float(np.sum(gt_dist * path)) / (sum(similarity.shape) / 2)}
        hit = np.sum(path * golden_path)
        recall = hit / golden_path.sum()
        precision = hit / path.sum()
        denom = precision + recall
        f_alpha = 2 * precision * recall / denom if denom > 0 else 0.0
        res.update({"F_alpha": f_alpha, "Precision": precision, "Recall": recall})
        print(f"For {name1} and {name2} under test, {res}")

        # Visualize the frames the model actually compared. Path indices refer to the
        # transformed sequences, not to the raw recordings. The * 100 undoes the /100 scaling
        # from _prepare_pairs; the augmentation itself is not undone.
        kp1 = self.skeletons_cpu[pair_idx][0].numpy().reshape(-1, 17, 3) * 100
        kp2 = self.skeletons_cpu[pair_idx][1].numpy().reshape(-1, 17, 3) * 100
        tbc1, tbc2, captions1, captions2 = [], [], [], []
        for first_idx, second_idx in zip(*np.nonzero(path)):
            tbc1.append(kp1[first_idx])
            captions1.append(f"{name1}: {first_idx} frame")
            tbc2.append(rigid_align(kp2[second_idx], kp1[first_idx]))
            captions2.append(f"{name2}: {second_idx} frame")
        save_paired_keypoints3d_as_video(np.stack(tbc1, axis=0), np.stack(tbc2, axis=0),
                                         captions1, captions2,
                                         self.data_dir, video_path,
                                         align=True)

    def save_video_dtw(self, model, video_path, pair_idx=0):
        """Align pair ``pair_idx`` by DTW over the model's similarities; report and save a video."""
        self._save_video(model, video_path, pair_idx,
                         align=lambda sim: DTW(-sim),  # DTW is a minimization algorithm
                         metric_key="recons_dtw")

    def save_video_max(self, model, video_path, pair_idx=0):
        """Match each frame of the first clip to its most similar frame; report and save a video."""
        self._save_video(model, video_path, pair_idx,
                         align=lambda sim: np.max(sim, axis=-1)[:, None] == sim,
                         metric_key="recons_max")


# Registry of validator classes, keyed by the name construct_validator accepts.
repertoire = {"RECONSTRUCTION": ReconstructionValidator}


def construct_validator(type, configs):
    """Instantiate the validator registered under ``type``.

    Args:
        type: key into ``repertoire`` (e.g. "RECONSTRUCTION").
        configs: mapping of keyword arguments for the validator's constructor.

    Raises:
        KeyError: if ``type`` is not a registered key.
    """
    validator_cls = repertoire[type]
    return validator_cls(**configs)


if __name__ == "__main__":
    # Smoke test: print Kendall's tau for 100 random 12x10 distance matrices.
    for _ in range(100):
        print("Kendall's Tau:", KendallsTau(np.random.rand(12, 10)))
