""" pretrain
- kNN Precision
"""

""" finetune
- Recall / Recall@3s / AP
"""

from collections import defaultdict
from typing import Dict

import torch
import torchmetrics
from sklearn.metrics import (
    average_precision_score,
    precision_recall_curve,
    recall_score,
    roc_auc_score,
    roc_curve,
)

import numpy as np
import os
import json

from torchmetrics.utilities import dim_zero_cat

class KnnPrecisionMetric(torchmetrics.Metric):
    """Per-video kNN precision of features with respect to scene labels.

    Every :meth:`update` call registers ONE sample: a video id, the id of the
    scene the sample belongs to inside that video, and its feature vector.
    :meth:`compute` groups the samples by video, ranks the other samples of
    the same video by dot-product similarity, and counts a retrieved neighbour
    as correct when it belongs to the query's scene.

    Args:
        top_k_list: iterable of ``k`` values to report precision@k for.
    """

    def __init__(self, top_k_list):
        super().__init__(dist_sync_on_step=True)
        # One entry per sample; no reduction so compute() sees every sample.
        self.add_state("feat_data", default=[], dist_reduce_fx=None)
        self.add_state("vids_data", default=[], dist_reduce_fx=None)
        self.add_state("scene_data", default=[], dist_reduce_fx=None)
        self.top_k_list = set(top_k_list)
        self.max_k = max(self.top_k_list)

    def update(self, vid, invideo_scene_id, feat):
        """Register one sample.

        Args:
            vid: single-element tensor holding the video id.
            invideo_scene_id: single-element tensor holding the scene id
                within that video.
            feat: 1-D feature tensor (samples are stacked into a matrix).
        """
        assert isinstance(invideo_scene_id, torch.Tensor)
        assert isinstance(vid, torch.Tensor)
        assert isinstance(feat, torch.Tensor)
        self.feat_data.append(feat)
        self.vids_data.append(vid)
        self.scene_data.append(invideo_scene_id)

    def compute(self) -> dict:
        """Return ``{k: {"correct", "total", "precision"}}`` for each k.

        ``precision`` is in percent. For a given query, precision@k is only
        counted when its scene holds at least ``k`` samples (query included).

        Raises:
            AssertionError: if some k never received a counted query, or a
                video has fewer than ``max(top_k_list)`` samples.
        """
        pool_feats = defaultdict(list)
        pool_scene_ids = defaultdict(list)
        pool_gts = defaultdict(dict)  # vid -> scene id -> set of pool indices
        for vid, scene_id, feat in zip(
            self.vids_data, self.scene_data, self.feat_data
        ):
            vid = vid.item()
            scene_id = scene_id.item()
            # len(pool_feats[vid]) is this sample's row index in its video.
            pool_gts[vid].setdefault(scene_id, set()).add(len(pool_feats[vid]))
            pool_scene_ids[vid].append(scene_id)
            pool_feats[vid].append(feat)

        score = defaultdict(dict)
        for top_k in self.top_k_list:
            score[top_k] = {"correct": 0, "total": 0}

        for vid, gt in pool_gts.items():
            self._accumulate_video(pool_feats[vid], pool_scene_ids[vid], gt, score)

        for top_k in self.top_k_list:
            assert score[top_k]["total"] > 0
            score[top_k]["precision"] = (
                100.0 * score[top_k]["correct"] / score[top_k]["total"]
            )
        del pool_feats, pool_scene_ids, pool_gts
        torch.cuda.empty_cache()
        return score

    def _accumulate_video(self, feats, scene_ids, gt, score):
        """Add the kNN hit counts of every sample of one video into ``score``."""
        X = torch.stack(feats)
        sim = torch.matmul(X, X.t())
        # -inf always ranks a sample's self-similarity last; a finite offset
        # would fail once feature dot products exceed it.
        sim.fill_diagonal_(float("-inf"))
        indices = torch.argsort(sim, descending=True)
        assert indices.shape[1] >= self.max_k, f"{indices.shape[1]} >= {self.max_k}"
        neighbours = indices[:, : self.max_k].tolist()

        for query_scene_id, row in zip(scene_ids, neighbours):
            positives = gt[query_scene_id]
            correct = 0
            for rank, neighbour in enumerate(row, start=1):
                if neighbour in positives:
                    correct += 1
                # Only score k when the scene has at least k samples.
                if rank in self.top_k_list and len(positives) >= rank:
                    score[rank]["correct"] += correct
                    score[rank]["total"] += rank


class F1ScoreMetric(torchmetrics.classification.F1Score):
    """Subclass of torchmetrics' ``F1Score`` constructed with its defaults.

    NOTE(review): keyword arguments given to the constructor are accepted but
    NOT forwarded to the parent, so they currently have no effect. Newer
    torchmetrics releases may require a ``task`` argument for ``F1Score``;
    confirm against the pinned torchmetrics version.
    """

    def __init__(self, **metric_args):
        del metric_args  # not forwarded; see class docstring
        super().__init__()


class AveragePrecisionMetric(torchmetrics.classification.AveragePrecision):
    """Subclass of torchmetrics' ``AveragePrecision`` built with its defaults.

    NOTE(review): constructor keyword arguments are accepted but NOT
    forwarded to the parent, so they currently have no effect.

    ref:
        - https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/classification/average_precision.py
    """

    def __init__(self, **metric_args):
        del metric_args  # not forwarded; see class docstring
        super().__init__()


class SklearnAPMetric(torchmetrics.Metric):
    """Average precision and precision-recall curve via scikit-learn.

    Probabilities and labels are collected across ``update`` calls (the
    states are concatenated on distributed sync) and scored in one pass.
    NaNs are replaced with ``np.nan_to_num`` before scoring.
    """

    def __init__(self, **metric_args):
        # NOTE(review): ``metric_args`` is accepted but not used.
        super().__init__()
        for state_name in ("prob", "gts"):
            self.add_state(state_name, default=[], dist_reduce_fx="cat")

    def update(self, prob, gts):
        """Store a batch of float32 scores ``prob`` and int64 labels ``gts``."""
        assert isinstance(prob, (torch.FloatTensor, torch.cuda.FloatTensor))
        assert isinstance(gts, (torch.LongTensor, torch.cuda.LongTensor))
        self.prob.append(prob)
        self.gts.append(gts)

    def compute(self) -> torch.Tensor:
        """Return the tuple ``(ap, precision, recall)``.

        Each element is a tensor with a leading dimension of 1, cast to the
        dtype/device of the stored probabilities. (The annotation is kept for
        interface compatibility; the actual return value is a 3-tuple.)
        """
        y_score = np.nan_to_num(dim_zero_cat(self.prob).cpu().numpy())
        y_true = np.nan_to_num(dim_zero_cat(self.gts).cpu().numpy())
        ap = average_precision_score(y_true, y_score)
        precision, recall, _ = precision_recall_curve(y_true, y_score)

        reference = self.prob[0]

        def to_tensor(value):
            return torch.Tensor(np.array([value])).type_as(reference)

        return to_tensor(ap), to_tensor(precision), to_tensor(recall)


class SklearnAUCROCMetric(torchmetrics.Metric):
    """ROC-AUC and ROC curve via scikit-learn.

    Probabilities and labels are collected across ``update`` calls (the
    states are concatenated on distributed sync) and scored in one pass.
    NaNs are replaced with ``np.nan_to_num`` before scoring.
    """

    def __init__(self, **metric_args):
        # NOTE(review): ``metric_args`` is accepted but not used.
        super().__init__()
        for state_name in ("prob", "gts"):
            self.add_state(state_name, default=[], dist_reduce_fx="cat")

    def update(self, prob, gts):
        """Store a batch of float32 scores ``prob`` and int64 labels ``gts``."""
        assert isinstance(prob, (torch.FloatTensor, torch.cuda.FloatTensor))
        assert isinstance(gts, (torch.LongTensor, torch.cuda.LongTensor))
        self.prob.append(prob)
        self.gts.append(gts)

    def compute(self) -> torch.Tensor:
        """Return the tuple ``(auc, fpr, tpr)``.

        Each element is a tensor with a leading dimension of 1, cast to the
        dtype/device of the stored probabilities. (The annotation is kept for
        interface compatibility; the actual return value is a 3-tuple.)
        """
        scores = np.nan_to_num(dim_zero_cat(self.prob).cpu().numpy())
        labels = np.nan_to_num(dim_zero_cat(self.gts).cpu().numpy())
        auc = roc_auc_score(labels, scores)
        fpr, tpr, _ = roc_curve(labels, scores)

        reference = self.prob[0]
        return tuple(
            torch.Tensor(np.array([value])).type_as(reference)
            for value in (auc, fpr, tpr)
        )


class AccuracyMetric(torchmetrics.Metric):
    """Binary accuracy of thresholded probabilities, overall and per class.

    A probability ``>= threshold`` (0.5) is a positive prediction. Confusion
    counts are summed across ``update`` calls (and across processes on sync).
    Labels other than 0 and 1 are not counted.

    ref:
        - https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/classification/accuracy.py
        - https://github.com/PyTorchLightning/metrics/blob/f61317ca17e3facc16e09c0e6cef0680961fc4ff/torchmetrics/functional/classification/accuracy.py#L72
    """

    def __init__(self, **metric_args):
        # NOTE(review): ``metric_args`` is accepted but not used.
        super().__init__()

        self.eps = 1e-5  # keeps the divisions in compute() finite
        self.threshold = 0.5
        for counter in ("tp", "fp", "tn", "fn"):
            self.add_state(counter, default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, prob, labels):
        """Add the confusion counts of one batch (float32 ``prob``, int64 ``labels``)."""
        assert isinstance(prob, (torch.FloatTensor, torch.cuda.FloatTensor))
        assert isinstance(labels, (torch.LongTensor, torch.cuda.LongTensor))

        is_positive = labels == 1
        is_negative = labels == 0
        predicted_positive = prob >= self.threshold
        predicted_negative = prob < self.threshold

        self.tp += (is_positive & predicted_positive).sum()
        self.fp += (is_negative & predicted_positive).sum()
        self.tn += (is_negative & predicted_negative).sum()
        self.fn += (is_positive & predicted_negative).sum()

    def compute(self) -> Dict[str, torch.Tensor]:
        """Return accuracies in percent.

        Keys: ``acc1`` (recall of class 1), ``acc0`` (recall of class 0) and
        ``acc`` (overall). Both classes must have been seen at least once.
        """
        tp, fp, tn, fn = self.tp, self.fp, self.tn, self.fn
        positives = tp + fn
        negatives = fp + tn

        assert positives > 0
        assert negatives > 0

        return {
            "acc1": 100.0 * tp / (positives + self.eps),
            "acc0": 100.0 * tn / (negatives + self.eps),
            "acc": 100.0 * (tp + tn) / (positives + negatives + self.eps),
        }


class MovieNetMetric(torchmetrics.Metric):
    """Scene-boundary metrics on MovieNet: Recall, Recall@3s and mIoU.

    One :meth:`update` call stores the prediction and ground truth of one
    shot, identified by its video id and shot id. :meth:`compute` groups the
    shots by video and returns ``(recall, recall_at_second, miou)``.

    Args:
        cfg: configuration object; only ``cfg.DATA_PATH`` is read. It must
            contain ``anno/vid2idx.json`` (video id -> integer index) and one
            ``<vid>.txt`` shot file per video under ``scene318/shot_movie318``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.add_state("vidx_data", default=[], dist_reduce_fx="cat")
        self.add_state("sid_data", default=[], dist_reduce_fx="cat")
        self.add_state("pred_data", default=[], dist_reduce_fx="cat")
        self.add_state("gt_data", default=[], dist_reduce_fx="cat")

        self.cfg = cfg

        # Video ids are stored as integer indices so they fit in tensor states.
        with open(os.path.join(cfg.DATA_PATH, "anno/vid2idx.json"), "r") as f:
            self.vid2idx = json.load(f)
        self.idx2vid = {idx: vid for vid, idx in self.vid2idx.items()}
        self.shot_path = os.path.join(cfg.DATA_PATH, "scene318/shot_movie318")

    def update(self, vid, sid, pred, gt):
        """Store one shot.

        Args:
            vid: video id; must be a key of ``anno/vid2idx.json``.
            sid: shot id within the video (anything ``int()`` accepts).
            pred: single-element tensor with the predicted label
                (presumably 0/1 boundary flag — confirm with caller).
            gt: single-element tensor with the ground-truth label.
        """
        self.vidx_data.append(pred.new_tensor(self.vid2idx[vid], dtype=torch.long))
        self.sid_data.append(pred.new_tensor(int(sid), dtype=torch.long))
        self.pred_data.append(pred)
        self.gt_data.append(gt)

    def compute(self) -> tuple:
        """Return ``(recall, recall_at_second, miou)``.

        Each value is a float tensor on the device of the stored predictions;
        ``miou`` may be None (see ``_compute_mIoU``).
        """
        result = defaultdict(dict)  # vid -> shot id -> {"pred", "gt"}
        for vidx, sid, pred, gt in zip(
            self.vidx_data, self.sid_data, self.pred_data, self.gt_data
        ):
            result[self.idx2vid[vidx.item()]][sid.item()] = {
                "pred": pred.item(),
                "gt": gt.item(),
            }

        recall = self._compute_exact_recall(result)
        recall_at_second = self._compute_recall_at_second(result)
        miou = self._compute_mIoU(result)

        del result
        torch.cuda.empty_cache()
        return recall, recall_at_second, miou

    def _read_shot_list(self, vid):
        """Return the lines of the shot file of ``vid``, one line per shot."""
        shot_fn = "{}/{}.txt".format(self.shot_path, vid)
        with open(shot_fn, "r") as f:
            return f.read().splitlines()

    def _compute_exact_recall(self, result):
        """Mean over videos of sklearn's binary recall of the predictions."""
        recall = []
        for result_dict_one in result.values():
            items = list(result_dict_one.values())
            gts = [int(item.get("gt")) for item in items]
            preds = [int(item.get("pred")) for item in items]
            recall.append(recall_score(gts, preds, average="binary"))
        return self.pred_data[0].new_tensor(np.mean(recall), dtype=torch.float)

    def _compute_recall_at_second(
        self, result, num_neighbor_shot=5, threshold=3, fps=24
    ):
        """Mean over videos of the fraction of ground-truth boundaries hit.

        A ground-truth boundary shot (gt == 1) counts as hit when some shot
        within ``num_neighbor_shot`` shots of it has the same predicted label
        and their field-1 frame numbers (the shot end frames) are less than
        ``threshold`` seconds apart, using ``fps`` frames per second.
        """
        recall = []
        for vid, result_dict_one in result.items():
            shot_list = self._read_shot_list(vid)

            cont_one, total_one = 0, 0
            for shotid, item in result_dict_one.items():
                gt = int(item.get("gt"))
                if gt != 1:
                    continue

                total_one += 1
                shot_time = int(shot_list[int(shotid)].split(" ")[1])
                for offset in range(-num_neighbor_shot, num_neighbor_shot + 1):
                    shotid_cp = shotid + offset
                    if shotid_cp < 0 or shotid_cp >= len(shot_list):
                        continue
                    item_cp = result_dict_one.get(shotid_cp)
                    if item_cp is None:
                        continue
                    shot_time_cp = int(shot_list[shotid_cp].split(" ")[1])
                    gap_time = np.abs(shot_time_cp - shot_time) / fps
                    if gt == item_cp.get("pred") and gap_time < threshold:
                        cont_one += 1
                        break

            # The epsilon makes a video without ground-truth boundaries score 0.
            recall.append(cont_one / (total_one + 1e-5))

        return self.pred_data[0].new_tensor(np.mean(recall), dtype=torch.float)

    def _compute_mIoU(self, result):
        """Mean over videos of the symmetric scene mIoU.

        For each video the shots are grouped into scenes (``get_pair_list``)
        for ground truth and prediction, and the two directed mIoUs
        (gt -> pred and pred -> gt) are averaged. A video whose predictions
        form no scene scores 0.

        Returns:
            A float tensor, or None when some video's ground truth forms no
            scene.
        """
        mious = []
        for vid, result_dict_one in result.items():
            shot_list = self._read_shot_list(vid)

            gt_dict_one = {sid: item.get("gt") for sid, item in result_dict_one.items()}
            pred_dict_one = {
                sid: item.get("pred") for sid, item in result_dict_one.items()
            }
            gt_pair_list = self.get_pair_list(gt_dict_one)
            pred_pair_list = self.get_pair_list(pred_dict_one)
            if pred_pair_list is None:
                mious.append(0)
                continue

            gt_scene_list = self.get_scene_list(gt_pair_list, shot_list)
            pred_scene_list = self.get_scene_list(pred_pair_list, shot_list)
            if gt_scene_list is None or pred_scene_list is None:
                # NOTE(review): one video without a ground-truth scene aborts
                # the whole computation; confirm this is intended.
                return None

            miou1 = self.cal_miou(gt_scene_list, pred_scene_list)
            miou2 = self.cal_miou(pred_scene_list, gt_scene_list)
            mious.append(np.mean([miou1, miou2]))

        return self.pred_data[0].new_tensor(np.mean(mious), dtype=torch.float)

    def get_scene_list(self, pair_list, shot_list):
        """Convert shot-id groups into ``(start_frame, end_frame)`` intervals.

        The start is field 0 of the group's first shot line and the end is
        field 1 of its last shot line. Returns None if ``pair_list`` is None.
        """
        if pair_list is None:
            return None
        scene_list = []
        for shot_ids in pair_list:
            start = int(shot_list[int(shot_ids[0])].split(" ")[0])
            end = int(shot_list[int(shot_ids[-1])].split(" ")[1])
            scene_list.append((start, end))
        return scene_list

    def cal_miou(self, gt_scene_list, pred_scene_list):
        """Mean over ``gt_scene_list`` of the best IoU with any predicted scene."""
        best_ious = [
            np.max([self.getRatio(pred_item, gt_item) for pred_item in pred_scene_list])
            for gt_item in gt_scene_list
        ]
        return np.mean(best_ious)

    def getRatio(self, interval_1, interval_2):
        """IoU of two ``(start, end)`` intervals; 0 when they do not overlap."""
        intersection = self.getIntersection(interval_1, interval_2)
        if intersection == 0:
            return 0
        return intersection / self.getUnion(interval_1, interval_2)

    def getIntersection(self, interval_1, interval_2):
        """Length of the overlap of two ``(start, end)`` intervals (0 if none)."""
        assert interval_1[0] < interval_1[1], "start frame must be smaller than end frame."
        assert interval_2[0] < interval_2[1], "start frame must be smaller than end frame."
        start = max(interval_1[0], interval_2[0])
        end = min(interval_1[1], interval_2[1])
        return end - start if start < end else 0

    def getUnion(self, interval_1, interval_2):
        """Length of the span covering both ``(start, end)`` intervals."""
        assert interval_1[0] < interval_1[1], "start frame must be smaller than end frame."
        assert interval_2[0] < interval_2[1], "start frame must be smaller than end frame."
        start = min(interval_1[0], interval_2[0])
        end = max(interval_1[1], interval_2[1])
        return end - start

    def get_pair_list(self, anno_dict):
        """Group the shots of one video into consecutive scenes.

        Shots are visited in ascending id order and collected until the
        running label sum reaches exactly 1; that shot closes the group (for
        0/1 labels: a group ends at each boundary shot). Shots after the last
        closing shot belong to no group.

        Returns:
            A list of shot-id lists (one per scene), or None if no group was
            closed.
        """
        pair_list = []
        current, running = [], 0
        for shot_id in sorted(anno_dict):
            running += anno_dict[shot_id]
            current.append(shot_id)
            if running == 1:
                pair_list.append(current)
                current, running = [], 0
        return pair_list or None