from typing import List

import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.functional import pairwise_euclidean_distance
from .utils import *
import os
from motGPT.config import instantiate_from_config

class MMMetrics(Metric):
    full_state_update = True

    def __init__(self, cfg, dataname='humanml3d', 
                 mm_num_times=10, dist_sync_on_step=True, 
                 njoints=22,
                 **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.name = "MultiModality scores"
        self.cfg = cfg
        self.dataname = dataname
        self.mm_num_times = mm_num_times
        self.njoints = njoints

        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("count_seq",
                       default=torch.tensor(0),
                       dist_reduce_fx="sum")

        self.metrics = ["MultiModality"]
        self.add_state("MultiModality",
                       default=torch.tensor(0.),
                       dist_reduce_fx="sum")

        # chached batches
        self.add_state("mm_motion_embeddings", default=[], dist_reduce_fx=None)

        # T2M Evaluator
        self._get_t2m_evaluator(cfg)

    def _get_t2m_evaluator(self, cfg):
        """
        load T2M text encoder and motion encoder for evaluating
        """
        # init module
        self.t2m_textencoder = instantiate_from_config(cfg.METRIC.TM2T.t2m_textencoder)
        self.t2m_moveencoder = instantiate_from_config(cfg.METRIC.TM2T.t2m_moveencoder)
        self.t2m_motionencoder = instantiate_from_config(cfg.METRIC.TM2T.t2m_motionencoder)

        # load pretrianed
        if self.dataname == "kit":
            dataname = "kit"
        elif self.dataname == 'tomato':
            dataname ='motionx/vector_313'
        elif self.dataname == 'motionx':
            # dataname = "t2m"
            cfg.METRIC.TM2T.t2m_moveencoder.params['input_size'] = 263-4
            # dataname ='motionx/vector_623'
            dataname ='motionx/x-vector_263'
        else:
            dataname = "t2m"
        t2m_checkpoint = torch.load(os.path.join(
            cfg.METRIC.TM2T.t2m_path, dataname,
            "text_mot_match/model/finest.tar"),
                                    map_location="cpu")

        self.t2m_textencoder.load_state_dict(t2m_checkpoint["text_encoder"])
        self.t2m_moveencoder.load_state_dict(
            t2m_checkpoint["movement_encoder"])
        self.t2m_motionencoder.load_state_dict(
            t2m_checkpoint["motion_encoder"])

        # freeze params
        self.t2m_textencoder.eval()
        self.t2m_moveencoder.eval()
        self.t2m_motionencoder.eval()
        for p in self.t2m_textencoder.parameters():
            p.requires_grad = False
        for p in self.t2m_moveencoder.parameters():
            p.requires_grad = False
        for p in self.t2m_motionencoder.parameters():
            p.requires_grad = False

    def compute(self, sanity_flag):
        count = self.count.item()
        count_seq = self.count_seq.item()

        # init metrics
        metrics = {metric: getattr(self, metric) for metric in self.metrics}

        # if in sanity check stage then jump
        if sanity_flag:
            return metrics

        # cat all embeddings
        all_mm_motions = torch.cat(self.mm_motion_embeddings,
                                   axis=0).cpu().numpy()
        metrics['MultiModality'] = calculate_multimodality_np(
            all_mm_motions, self.mm_num_times)

        # Reset
        self.reset()

        return {**metrics}

    def update(
        self,
        feats_rst: Tensor,
        lengths_rst: List[int],
    ):
        self.count += sum(lengths_rst)
        self.count_seq += len(lengths_rst)

        align_idx = np.argsort(lengths_rst)[::-1].copy()
        feats_rst = feats_rst[align_idx]
        lengths_rst = np.array(lengths_rst)[align_idx]
        recmotion_embeddings = self.get_motion_embeddings(
            feats_rst, lengths_rst)
        cache = [0] * len(lengths_rst)
        for i in range(len(lengths_rst)):
            cache[align_idx[i]] = recmotion_embeddings[i:i + 1]

        mm_motion_embeddings = torch.cat(cache, axis=0).unsqueeze(0)
        # # store all mm motion embeddings
        self.mm_motion_embeddings.append(mm_motion_embeddings)

    def get_motion_embeddings(self, feats: Tensor, lengths: List[int]):
        m_lens = torch.tensor(lengths)
        m_lens = torch.div(m_lens,
                           self.cfg.DATASET.HUMANML3D.UNIT_LEN,
                           rounding_mode="floor")

        if self.dataname == 'tomato':
            mov = self.t2m_moveencoder(feats[...,4:]).detach()
        elif self.dataname == 'motionx' and self.njoints==52:
            njoints=52
            feats_22 = torch.cat((feats[..., :4+21*3], feats[..., 4+(njoints-1)*3:4+(njoints-1)*3+21*6],
                               feats[..., 4+(njoints-1)*9:4+(njoints-1)*9+22*3]), dim=-1).to(feats)
            mov = self.t2m_moveencoder(feats_22).detach()
        else:
            mov = self.t2m_moveencoder(feats[...,:-4]).detach()
        emb = self.t2m_motionencoder(mov, m_lens)

        # [bs, nlatent*ndim] <= [bs, nlatent, ndim]
        return torch.flatten(emb, start_dim=1).detach()
