from typing import List

import torch
from torch import Tensor
from torchmetrics import Metric

from .utils import *


# motion reconstruction metric
class MRMetrics(Metric):
    """Accumulate motion-reconstruction errors over batches of joint positions.

    Three running sums are kept as metric states (``MPJPE``, ``PAMPJPE`` and
    ``ACCEL``) together with the number of frames (``count``) and the number
    of sequences (``count_seq``) seen so far. All states are reduced with
    ``"sum"`` across processes, so the averages produced by :meth:`compute`
    are taken over everything seen by every process.

    The per-frame errors themselves come from ``calc_mpjpe``,
    ``calc_pampjpe`` and ``calc_accel``, which are imported from ``.utils``.
    """

    def __init__(self,
                 njoints,
                 jointstype: str = "mmm",
                 force_in_meter: bool = True,
                 align_root: bool = True,
                 dist_sync_on_step=True,
                 **kwargs):
        """Create the metric and register its accumulator states.

        Args:
            njoints: Accepted for interface compatibility; not used here.
            jointstype: Skeleton convention. Root alignment is only applied
                for ``"mmm"`` and ``"humanml3d"``.
            force_in_meter: When true, the averaged errors are multiplied by
                1000 in :meth:`compute` (presumably metres -> millimetres;
                confirm against the units of the incoming joints).
            align_root: Whether joint index 0 is passed to ``calc_mpjpe`` as
                the alignment joint (subject to ``jointstype``).
            dist_sync_on_step: Forwarded to ``Metric.__init__``.
            **kwargs: Accepted and ignored.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.name = 'Motion Reconstructions'
        self.jointstype = jointstype
        self.align_root = align_root
        self.force_in_meter = force_in_meter

        # Total number of valid frames and number of sequences seen so far.
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("count_seq",
                       default=torch.tensor(0),
                       dist_reduce_fx="sum")

        # Running sums of the per-frame errors; averaged in compute().
        self.add_state("MPJPE",
                       default=torch.tensor([0.0]),
                       dist_reduce_fx="sum")
        self.add_state("PAMPJPE",
                       default=torch.tensor([0.0]),
                       dist_reduce_fx="sum")
        self.add_state("ACCEL",
                       default=torch.tensor([0.0]),
                       dist_reduce_fx="sum")

        self.MR_metrics = ["MPJPE", "PAMPJPE", "ACCEL"]

        # All metric
        self.metrics = self.MR_metrics

    def compute(self, sanity_flag=False):
        """Return the averaged metrics and reset the accumulated state.

        Args:
            sanity_flag: Accepted for interface compatibility with callers
                that pass it; the value is not used by this metric.

        Returns:
            A dict with keys ``"MPJPE"``, ``"PAMPJPE"`` and ``"ACCEL"``, each
            mapping to a tensor of shape ``(1,)``.
        """
        factor = 1000.0 if self.force_in_meter else 1.0

        count = self.count
        count_seq = self.count_seq
        mr_metrics = {}
        mr_metrics["MPJPE"] = self.MPJPE / count * factor
        mr_metrics["PAMPJPE"] = self.PAMPJPE / count * factor
        # The acceleration error is normalised by n-2 terms for each
        # sequence of n valid frames.
        mr_metrics["ACCEL"] = self.ACCEL / (count - 2 * count_seq) * factor

        # Reset
        self.reset()

        return mr_metrics

    def update(self, joints_rst: Tensor, joints_ref: Tensor,
               lengths: List[int]):
        """Add one batch of reconstructed/reference joints to the sums.

        Args:
            joints_rst: Reconstructed joints, a 4-D tensor laid out as
                ``(bs, seq, njoint, 3)``.
            joints_ref: Reference joints with the same shape as
                ``joints_rst``.
            lengths: Number of valid frames of each sequence in the batch.
                Frames past ``lengths[i]`` are treated as padding and do not
                contribute to the sums.
        """
        assert joints_rst.shape == joints_ref.shape
        assert joints_rst.dim() == 4
        # (bs, seq, njoint=22, 3)

        self.count += sum(lengths)
        self.count_seq += len(lengths)

        # avoid cuda error of DDP in pampjpe
        rst = joints_rst.detach().cpu()
        ref = joints_ref.detach().cpu()

        # align root joints index
        if self.align_root and self.jointstype in ['mmm', 'humanml3d']:
            align_inds = [0]
        else:
            align_inds = None

        for i, length in enumerate(lengths):
            n_frames = int(length)
            if n_frames <= 0:
                # Nothing valid to score for this sample.
                continue

            # The denominators in compute() count only valid frames, so the
            # numerators must not include padded frames either.
            rst_i = rst[i, :n_frames]
            ref_i = ref[i, :n_frames]

            self.MPJPE += torch.sum(
                calc_mpjpe(rst_i, ref_i, align_inds=align_inds))
            self.PAMPJPE += torch.sum(calc_pampjpe(rst_i, ref_i))
            self.ACCEL += torch.sum(calc_accel(rst_i, ref_i))
