import os
from typing import List

import numpy as np
import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.functional import pairwise_euclidean_distance

from .utils import *
from motGPT.config import instantiate_from_config

class TM2TMetrics(Metric):
    """Text-to-motion evaluation metrics computed with the pretrained T2M evaluator.

    ``update`` encodes reference ("gt") motions, generated ("rec") motions and,
    when ``self.text`` is enabled, the paired texts, and caches the embeddings.
    ``compute`` then reports:

    * Matching score and R-precision (top-1 .. top-k) for generated and gt
      motions (only when ``self.text`` is enabled),
    * FID between generated and gt motion embeddings,
    * Diversity of generated and gt motion embeddings.
    """

    def __init__(self,
                 cfg,
                 dataname='humanml3d',
                 top_k=3,
                 R_size=32,
                 diversity_times=300,
                 dist_sync_on_step=True,
                 njoints=22,
                 **kwargs):
        """
        Args:
            cfg: experiment config. Reads ``cfg.TRAIN.STAGE``,
                ``cfg.model.params.task``, ``cfg.METRIC.TM2T.*`` and
                ``cfg.DATASET.HUMANML3D.UNIT_LEN``.
            dataname: dataset name. ``"kit"`` selects the KIT evaluator
                checkpoint and anything else selects ``t2m``. ``"tomato"`` and
                ``"motionx"`` also change how features are sliced before
                encoding.
            top_k: largest k reported for R-precision.
            R_size: number of text/motion pairs per R-precision group.
            diversity_times: forwarded to ``calculate_diversity_np``; must be
                smaller than the number of cached sequences.
            dist_sync_on_step: forwarded to ``torchmetrics.Metric``.
            njoints: joint count of the motion features. 52 together with
                ``dataname == "motionx"`` triggers the 52 -> 22 joint slicing.
            **kwargs: accepted for config compatibility and ignored.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.cfg = cfg
        self.dataname = dataname
        self.njoints = njoints
        self.name = "matching, fid, and diversity scores"
        self.top_k = top_k
        self.R_size = R_size
        # Text-related metrics are only computed when the training stage
        # contains 'lm' and the task is text-to-motion.
        self.text = 'lm' in cfg.TRAIN.STAGE and cfg.model.params.task == 't2m'
        self.diversity_times = diversity_times

        self.motion_factor = 1.0

        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("count_seq",
                       default=torch.tensor(0),
                       dist_reduce_fx="sum")

        # Names of the scalar states returned by compute(), in output order.
        self.metrics = []

        # Matching scores and R-precision (generated first, then gt).
        if self.text:
            self.Matching_metrics = ["Matching_score", "gt_Matching_score"]
            self.Matching_metrics += [
                f"R_precision_top_{k}" for k in range(1, top_k + 1)
            ]
            self.Matching_metrics += [
                f"gt_R_precision_top_{k}" for k in range(1, top_k + 1)
            ]
            for metric_name in self.Matching_metrics:
                self.add_state(metric_name,
                               default=torch.tensor(0.0),
                               dist_reduce_fx="sum")
            self.metrics.extend(self.Matching_metrics)

        # FID and diversity.
        for metric_name in ("FID", "Diversity", "gt_Diversity"):
            self.add_state(metric_name,
                           default=torch.tensor(0.0),
                           dist_reduce_fx="sum")
        self.metrics.extend(["FID", "Diversity", "gt_Diversity"])

        # Cached per-batch embeddings, concatenated across processes on sync.
        self.add_state("text_embeddings", default=[], dist_reduce_fx='cat')
        self.add_state("recmotion_embeddings", default=[], dist_reduce_fx='cat')
        self.add_state("gtmotion_embeddings", default=[], dist_reduce_fx='cat')

        # T2M evaluator
        self._get_t2m_evaluator(cfg)

    def _get_t2m_evaluator(self, cfg):
        """Build and freeze the T2M text, movement and motion encoders.

        The encoders are instantiated from ``cfg.METRIC.TM2T`` and loaded from
        ``<t2m_path>/<kit|t2m>/text_mot_match/model/finest.tar``.
        """
        dataname = "kit" if self.dataname == "kit" else "t2m"

        self.t2m_textencoder = instantiate_from_config(cfg.METRIC.TM2T.t2m_textencoder)
        self.t2m_moveencoder = instantiate_from_config(cfg.METRIC.TM2T.t2m_moveencoder)
        self.t2m_motionencoder = instantiate_from_config(cfg.METRIC.TM2T.t2m_motionencoder)

        # NOTE: torch.load unpickles the file; t2m_path must point at trusted data.
        ckpt_path = os.path.join(cfg.METRIC.TM2T.t2m_path, dataname,
                                 "text_mot_match/model/finest.tar")
        t2m_checkpoint = torch.load(ckpt_path, map_location="cpu")

        encoders = (
            (self.t2m_textencoder, "text_encoder"),
            (self.t2m_moveencoder, "movement_encoder"),
            (self.t2m_motionencoder, "motion_encoder"),
        )
        for encoder, ckpt_key in encoders:
            encoder.load_state_dict(t2m_checkpoint[ckpt_key])

        # Freeze: evaluation-only networks.
        for encoder, _ in encoders:
            encoder.eval()
            for p in encoder.parameters():
                p.requires_grad = False

    @staticmethod
    def _gather_embeddings(cached):
        """Return a cached embedding state as a single CPU tensor.

        The state is a list of tensors on the local process. It may already be
        one tensor, e.g. after a distributed ``cat`` sync.
        """
        if isinstance(cached, list):
            cached = torch.cat(cached, dim=0)
        return cached.cpu()

    def _matching_and_top_k(self, all_texts, all_motions, count_seq):
        """Accumulate matching distances and top-k hits over R_size groups.

        The pairs are split into ``count_seq // R_size`` consecutive groups of
        ``R_size``. Within each group, each text is ranked against all motions
        of the group. Pairs left over after the last full group are ignored.

        Returns:
            ``(matching_sum, top_k_mat)``. ``matching_sum`` is a 0-dim tensor
            holding the summed diagonal (paired) distances. ``top_k_mat`` is a
            ``(top_k,)`` tensor of accumulated hit counts.
        """
        matching_sum = torch.tensor(0.0)
        top_k_mat = torch.zeros((self.top_k, ))
        for i in range(count_seq // self.R_size):
            group = slice(i * self.R_size, (i + 1) * self.R_size)
            # [R_size, R_size] text-to-motion distances
            dist_mat = euclidean_distance_matrix(
                all_texts[group], all_motions[group]).nan_to_num()
            matching_sum += dist_mat.trace()
            argsmax = torch.argsort(dist_mat, dim=1)
            top_k_mat += calculate_top_k(argsmax,
                                         top_k=self.top_k).sum(axis=0)
        return matching_sum, top_k_mat

    @torch.no_grad()
    def compute(self, sanity_flag):
        """Turn the cached embeddings into final metric values, then reset.

        Args:
            sanity_flag: when truthy (sanity-check stage), return the raw
                state tensors immediately, without computing or resetting.

        Returns:
            dict mapping every name in ``self.metrics`` to its value.
        """
        count_seq = self.count_seq.item()

        # Init metrics dict
        metrics = {metric: getattr(self, metric) for metric in self.metrics}

        # Jump in sanity check stage
        if sanity_flag:
            return metrics

        # One shared permutation keeps text / generated / gt rows paired.
        shuffle_idx = torch.randperm(count_seq)
        all_genmotions = self._gather_embeddings(
            self.recmotion_embeddings)[shuffle_idx, :]
        all_gtmotions = self._gather_embeddings(
            self.gtmotion_embeddings)[shuffle_idx, :]

        # Compute text related metrics
        if self.text:
            all_texts = self._gather_embeddings(
                self.text_embeddings)[shuffle_idx, :]
            # At least one full R_size group is required.
            assert count_seq >= self.R_size
            R_count = count_seq // self.R_size * self.R_size

            for prefix, motions in (("", all_genmotions),
                                    ("gt_", all_gtmotions)):
                matching_sum, top_k_mat = self._matching_and_top_k(
                    all_texts, motions, count_seq)
                state_name = f"{prefix}Matching_score"
                setattr(self, state_name,
                        getattr(self, state_name) + matching_sum)
                metrics[state_name] = getattr(self, state_name) / R_count
                for k in range(self.top_k):
                    metrics[f"{prefix}R_precision_top_{k + 1}"] = \
                        top_k_mat[k] / R_count

        # tensor -> numpy for FID
        all_genmotions = all_genmotions.numpy()
        all_gtmotions = all_gtmotions.numpy()

        # Compute fid
        mu, cov = calculate_activation_statistics_np(all_genmotions)
        gt_mu, gt_cov = calculate_activation_statistics_np(all_gtmotions)
        metrics["FID"] = calculate_frechet_distance_np(gt_mu, gt_cov, mu, cov)

        # Compute diversity
        assert count_seq > self.diversity_times
        metrics["Diversity"] = calculate_diversity_np(all_genmotions,
                                                      self.diversity_times)
        metrics["gt_Diversity"] = calculate_diversity_np(
            all_gtmotions, self.diversity_times)

        # Reset
        self.reset()

        return metrics

    def _embed_in_input_order(self, feats, lengths):
        """Embed a batch of motions and return the results in input order.

        The motion encoder is fed the batch sorted by descending length. This
        ordering is kept from the original implementation, presumably because
        the encoder packs sequences and needs them sorted (TODO confirm). The
        embeddings are then scattered back, so element ``i`` of the returned
        list belongs to ``feats[i]``.

        Returns:
            list with one ``(1, D)`` embedding tensor per sample.
        """
        align_idx = np.argsort(lengths)[::-1].copy()
        sorted_emb = self.get_motion_embeddings(feats[align_idx],
                                                np.array(lengths)[align_idx])
        cache = [None] * len(align_idx)
        for sorted_pos, input_pos in enumerate(align_idx):
            cache[input_pos] = sorted_emb[sorted_pos:sorted_pos + 1]
        return cache

    @torch.no_grad()
    def update(self,
               feats_ref: Tensor,
               feats_rst: Tensor,
               lengths_ref: List[int],
               lengths_rst: List[int],
               word_embs: Tensor = None,
               pos_ohot: Tensor = None,
               text_lengths: Tensor = None):
        """Encode one batch and cache its embeddings.

        Args:
            feats_ref: reference (ground-truth) motion features.
            feats_rst: generated motion features.
            lengths_ref: frame count of each reference motion.
            lengths_rst: frame count of each generated motion.
            word_embs, pos_ohot, text_lengths: inputs to the T2M text encoder.
                Only used when ``self.text`` is enabled.
        """
        self.count += sum(lengths_ref)
        self.count_seq += len(lengths_ref)

        # T2m motion encoder
        self.gtmotion_embeddings.extend(
            self._embed_in_input_order(feats_ref, lengths_ref))
        self.recmotion_embeddings.extend(
            self._embed_in_input_order(feats_rst, lengths_rst))

        # T2m text encoder
        if self.text:
            text_emb = self.t2m_textencoder(word_embs, pos_ohot, text_lengths)
            text_embeddings = torch.flatten(text_emb, start_dim=1).detach()
            self.text_embeddings.append(text_embeddings)

    def get_motion_embeddings(self, feats: Tensor, lengths: List[int]):
        """Encode motion features into flattened T2M motion-encoder embeddings.

        Args:
            feats: motion features, presumably ``(bs, T, D)`` in a
                HumanML3D-style layout (TODO confirm).
            lengths: per-sample lengths in frames.

        Returns:
            ``(bs, nlatent*ndim)`` tensor.
        """
        # The motion encoder consumes the movement encoder's output, whose
        # lengths are measured in UNIT_LEN-frame units. This follows the
        # reference T2M evaluator (TODO confirm the downsampling factor).
        # Divide exactly once: a second floor-division would scale lengths by
        # UNIT_LEN**2 and truncate most of every sequence.
        m_lens = torch.div(torch.tensor(lengths),
                           self.cfg.DATASET.HUMANML3D.UNIT_LEN,
                           rounding_mode="floor")

        if self.dataname == 'tomato':
            move_input = feats[..., 4:]
        elif self.dataname == 'motionx' and self.njoints == 52:
            nj = self.njoints
            # Reduce 52-joint features to the 22-joint layout the evaluator
            # expects. This assumes the blocks [4 | (nj-1)*3 | (nj-1)*6 |
            # nj*3 | ...] (TODO confirm) and keeps the first 4 + 21*3,
            # 21*6 and 22*3 channels of them.
            move_input = torch.cat((
                feats[..., :4 + 21 * 3],
                feats[..., 4 + (nj - 1) * 3:4 + (nj - 1) * 3 + 21 * 6],
                feats[..., 4 + (nj - 1) * 9:4 + (nj - 1) * 9 + 22 * 3],
            ), dim=-1)
        else:
            move_input = feats[..., :-4]

        mov = self.t2m_moveencoder(move_input).detach()
        emb = self.t2m_motionencoder(mov, m_lens)

        # [bs, nlatent*ndim] <= [bs, nlatent, ndim]
        return torch.flatten(emb, start_dim=1).detach()
