from typing import Dict, Optional, List, Any, Union

import torch
from mmengine.evaluator import BaseMetric

from tqdm import tqdm
from .multimodal_mapping import (
    euclidean_distance_matrix,
    calculate_top_k,
)
from mmhug.registry import METRICS
from mmengine.structures import BaseDataElement
import torch.nn.functional as F
from torch import nn
from accelerate.utils import send_to_device


@METRICS.register_module(force=True)
class AudioMotionClipMetric(BaseMetric):
    """Evaluation metric for audio-motion contrastive learning.

    Reports two families of values, computed per sample over its frames:

    * Multimodal distance: mean euclidean distance between the L2-normalised
      audio code and motion code of the same frame.
    * R-precision (top-k): retrieval accuracy in both directions
      (audio -> motion and motion -> audio).

    Args:
        top_k: Largest ``k`` for which R-precision is reported; results are
            produced for every ``k`` in ``1..top_k``.
        r_precision_batch: If positive, at most this many randomly chosen
            frames of each sample are evaluated. If non-positive, all frames
            of each sample are evaluated.
        audio_latent_key: Key under which the audio codes are stored in the
            data samples.
        motion_latent_key: Key under which the motion codes are stored in the
            data samples.
        collect_device: Forwarded to ``BaseMetric`` and also used as the
            target device for the stored codes.
        prefix: Forwarded to ``BaseMetric``.
    """

    results: List[Any]

    def __init__(
        self,
        top_k: int = 3,
        r_precision_batch: int = -1,
        audio_latent_key: str = "audio_codes",
        motion_latent_key: str = "motion_codes",
        collect_device: str = "cpu",
        prefix: Optional[str] = None,
    ):
        super().__init__(collect_device, prefix)
        self.top_k = top_k
        self.r_precision_batch = r_precision_batch
        self.audio_latent_key = audio_latent_key
        self.motion_latent_key = motion_latent_key

    def compute_metrics(self, results: List) -> Dict[str, Any]:
        """Compute multimodal distance and R-precision over all samples.

        1. Multimodal distance
            After contrastive learning, the audio and motion features of the
            same frame should be close. This is the mean euclidean distance
            between matching (normalised) audio and motion codes.
        2. R-precision (top-k)
            For each audio code, the motion codes of the same sample are
            ranked by distance. A hit is counted when the matching motion
            code is among the ``k`` nearest. The same is done in the
            motion -> audio direction.

        All averages are taken over the frames that were actually evaluated,
        i.e. after the optional ``r_precision_batch`` sub-sampling.

        Args:
            results (List): The self.results which has been updated by
                self.process(). Each entry maps the two latent keys to a
                ``[T, D]`` tensor; ``T`` may vary between samples.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results. ``num_frames`` is the
            total number of frames seen (before sub-sampling). If no frame
            was evaluated, the averaged metrics are NaN.
        """
        num_test_samples = len(results)
        num_frames = 0  # total frames seen, before any sub-sampling
        num_eval_frames = 0  # frames that actually contribute to the sums
        mm_dist = 0.0
        top_k_mat_a2m = torch.zeros((self.top_k,))
        top_k_mat_m2a = torch.zeros((self.top_k,))

        # Each code tensor is [T, D]; T may vary between samples. The cast to
        # float is done lazily so only one sample is duplicated at a time.
        code_pairs = (
            (
                result[self.audio_latent_key].float(),
                result[self.motion_latent_key].float(),
            )
            for result in results
        )

        for audio_codes, motion_codes in tqdm(
            code_pairs,
            desc="Evaluating Audio Motion Clip Metric",
            total=num_test_samples,
        ):
            # audio_codes, motion_codes: T, D
            assert len(audio_codes) == len(
                motion_codes
            ), f"{len(audio_codes)} != {len(motion_codes)}"
            num_frames += len(audio_codes)
            if len(audio_codes) == 0:
                # Nothing to rank in an empty sample; it contributes zero.
                continue

            if self.r_precision_batch > 0:
                # Randomly keep at most `r_precision_batch` frames, using the
                # same indices for both modalities so pairs stay aligned.
                r_precision_bs = min(len(audio_codes), self.r_precision_batch)
                indices = torch.randperm(len(audio_codes))[:r_precision_bs]
                audio_codes = audio_codes[indices]
                motion_codes = motion_codes[indices]

            # Count frames AFTER sub-sampling: the sums below only cover the
            # retained frames, so they must be normalised by this count.
            num_eval_frames += len(audio_codes)

            audio_codes = F.normalize(audio_codes)
            motion_codes = F.normalize(motion_codes)
            dist_mat = euclidean_distance_matrix(audio_codes, motion_codes)
            # Diagonal entries are the distances of the matching pairs.
            mm_dist += dist_mat.trace()

            # Ascending sort: nearest candidates come first.
            ranking_a2m = torch.argsort(dist_mat, dim=1)
            # [T, top_k] -> [top_k]
            hits_a2m = calculate_top_k(ranking_a2m, top_k=self.top_k).sum(axis=0)
            assert len(hits_a2m) == self.top_k, hits_a2m.shape
            top_k_mat_a2m += hits_a2m

            ranking_m2a = torch.argsort(dist_mat.T, dim=1)
            # [T, top_k] -> [top_k]
            hits_m2a = calculate_top_k(ranking_m2a, top_k=self.top_k).sum(axis=0)
            top_k_mat_m2a += hits_m2a

        # Avoid ZeroDivisionError when nothing was evaluated; NaN makes the
        # "no data" situation visible instead of reporting a fake 0.
        denom = num_eval_frames if num_eval_frames > 0 else float("nan")

        res = {
            "num_test_samples": num_test_samples,
            "num_frames": num_frames,
            "mm_dist": mm_dist / denom,
        }
        for k in range(self.top_k):
            res[f"a2m_r_precision_top_{k + 1}"] = top_k_mat_a2m[k] / denom
            res[f"m2a_r_precision_top_{k + 1}"] = top_k_mat_m2a[k] / denom

        return res

    def process(self, data_batch, data_samples: List[Union[BaseDataElement, Dict]]):
        """Store the audio and motion codes of one batch in ``self.results``.

        Args:
            data_batch: Model input (from the test dataloader). Unused here.
            data_samples: Output of ``model.forward_predict``; one entry per
                sample, each providing the two latent keys via ``.get``.
        """
        # N * [T, C]
        # NOTE(review): `collect_device` is passed straight to the device
        # transfer; confirm it is always a valid torch device string.
        for data_sample in data_samples:
            sample = send_to_device(
                {
                    self.audio_latent_key: data_sample.get(self.audio_latent_key),
                    self.motion_latent_key: data_sample.get(self.motion_latent_key),
                },
                self.collect_device,
                # Blocking copy: an asynchronous device-to-host transfer may
                # be read before it has completed unless explicitly synced.
                non_blocking=False,
            )
            self.results.append(sample)
