from typing import List, Dict, Optional, Union
from mmhug.registry import METRICS
from mmengine.evaluator import BaseMetric
from accelerate.utils import send_to_device
import torch


@METRICS.register_module(force=True)
class EvaluatorPSNR(BaseMetric):
    """Mean peak signal-to-noise ratio (PSNR) between predicted and GT videos.

    ``process`` pulls a ground-truth and a predicted video out of every data
    sample, rescales both by 1/255 when the ground truth exceeds 1.0, and
    stores the pair. ``compute_metrics`` then returns the mean per-sample
    PSNR using a peak value of 1.0.

    Args:
        gt_key: Name of the ground-truth video entry in each data sample.
        pred_key: Name of the predicted video entry in each data sample.
        collect_device: Forwarded to ``BaseMetric``; also the device the
            stored tensors are moved to. mmengine's ``"gpu"`` spelling is
            translated to torch's ``"cuda"`` for that move.
        prefix: Forwarded to ``BaseMetric``.
    """

    def __init__(
        self,
        gt_key: str = "gt_video",
        pred_key: str = "pred_video",
        collect_device: str = "cpu",
        prefix: Optional[str] = None,
    ):
        super().__init__(collect_device, prefix)
        self.gt_key = gt_key
        self.pred_key = pred_key

    def compute_metrics(self, results: List[Dict]) -> Dict[str, float]:
        """Average the per-sample PSNR over ``results``.

        PSNR = 10 * log10(MAX^2 / MSE) with MAX = 1.0, so inputs are expected
        to lie in [0, 1] (``process`` rescales 0-255 data).

        Args:
            results: Items produced by ``process``; each maps ``gt_key`` and
                ``pred_key`` to tensors of identical shape.

        Returns:
            ``{"psnr": mean}``. A sample with zero MSE contributes ``inf``, so
            the mean becomes ``inf``. An empty ``results`` yields ``0.0``.

        Raises:
            ValueError: If a ground-truth / prediction pair differs in shape.
        """
        psnr_values = []
        for index, result in enumerate(results):
            gt: torch.Tensor = result[self.gt_key].float()
            pred: torch.Tensor = result[self.pred_key].float()
            # Without this check, a mismatched pair (e.g. an extra batch dim)
            # would silently broadcast and yield a meaningless MSE.
            if gt.shape != pred.shape:
                raise ValueError(
                    f"PSNR sample {index}: shape mismatch between "
                    f"'{self.gt_key}' {tuple(gt.shape)} and "
                    f"'{self.pred_key}' {tuple(pred.shape)}"
                )
            mse = torch.mean((gt - pred) ** 2)
            if mse == 0:
                psnr_val = float("inf")
            else:
                psnr_val = (10.0 * torch.log10(1.0 / mse)).item()
            psnr_values.append(psnr_val)

        # Drop the stored per-sample tensors once they have been consumed.
        self.results.clear()
        avg_psnr = sum(psnr_values) / len(psnr_values) if psnr_values else 0.0
        return {"psnr": avg_psnr}

    def process(self, data_batch, data_samples: List[Union[Dict, object]]):
        """Store the ground-truth / predicted videos of each sample.

        Args:
            data_batch: Unused; part of the ``BaseMetric.process`` signature.
            data_samples: Model outputs. Each is either dict-like (has a
                ``.get`` method) or exposes the videos as attributes named
                ``gt_key`` / ``pred_key``.

        Raises:
            KeyError: If a sample lacks either video.
        """
        device = self._result_device()
        for sample in data_samples:
            # detach() so stored results do not keep an autograd graph alive.
            gt = self._fetch(sample, self.gt_key).detach()
            pred = self._fetch(sample, self.pred_key).detach()
            # The ground truth decides the scale for both tensors so they stay
            # in the same range (a [0, 1] prediction may overshoot 1.0 slightly).
            if gt.max() > 1.0:
                gt = gt / 255.0
                pred = pred / 255.0
            # Blocking copy: a non_blocking device-to-host copy may be read
            # before it has completed unless the stream is synchronized.
            batch_item = send_to_device(
                {
                    self.gt_key: gt,
                    self.pred_key: pred,
                },
                device,
            )
            self.results.append(batch_item)

    def _result_device(self) -> str:
        """Translate ``collect_device`` into a device string torch accepts."""
        device = str(self.collect_device)
        # torch.device("gpu") is invalid; mmengine uses "gpu" for CUDA.
        return "cuda" if device.lower() == "gpu" else device

    @staticmethod
    def _fetch(sample, key: str) -> torch.Tensor:
        """Read ``key`` from a dict-like sample or an attribute-style object."""
        getter = getattr(sample, "get", None)
        value = getter(key) if callable(getter) else getattr(sample, key, None)
        if value is None:
            raise KeyError(f"data sample has no '{key}' entry")
        return value


if __name__ == "__main__":
    # Smoke test: score a random clip against a lightly noised copy of itself.
    num_frames, channels, height, width = 4, 3, 64, 64
    clip_shape = (num_frames, channels, height, width)

    # Ground-truth clip with values in [0, 1), layout [T, C, H, W].
    reference_clip = torch.rand(*clip_shape)
    # Prediction: the ground truth plus a small amount of Gaussian noise.
    noisy_clip = reference_clip + 0.01 * torch.randn(*clip_shape)

    metric = EvaluatorPSNR(collect_device="cpu", prefix="test_")
    # process() stores the (gt, pred) pair in metric.results.
    metric.process(None, [{"gt_video": reference_clip, "pred_video": noisy_clip}])
    # Reduce the stored pairs to the averaged metric dict and report it.
    scores = metric.compute_metrics(metric.results)
    print("Computed Metrics:", scores)
