from typing import Union, Dict, List, Sequence

import torch
from mmengine import LOOPS
from mmengine.runner import autocast
from mmengine.runner import ValLoop as _ValLoop
from mmengine.runner import TestLoop as _TestLoop
from torch.utils.data import DataLoader

from mmengine.evaluator import Evaluator
from mmhug.utils import dtype_from_str


@LOOPS.register_module(force=True)
class ValLoop(_ValLoop):
    """Validation loop with a configurable autocast dtype.

    The original ``mmengine.runner.ValLoop`` only supports fp32 and fp16.
    This implementation additionally lets you choose the autocast dtype
    (e.g. bf16). Everything else is kept consistent with the original
    ``mmengine.runner.ValLoop``.

    Note:
        ``dtype`` only takes effect when ``fp16=True``. With ``fp16=False``
        autocast is disabled and ``dtype`` is not forwarded to it.

    Args:
        runner: The runner that owns this loop.
        dataloader (DataLoader | dict): Validation dataloader or its config.
        evaluator (Evaluator | dict | list): Evaluator or its config.
        fp16 (bool): Whether to run ``val_step`` under autocast.
            Defaults to False.
        dtype (str | torch.dtype | None): Autocast dtype. Strings are
            converted with ``dtype_from_str``; a ``torch.dtype`` or ``None``
            is used as-is. Defaults to ``"fp16"``.

    Raises:
        NotImplementedError: If ``dtype`` is not a str, ``torch.dtype``
            or ``None``.
    """

    def __init__(
        self,
        runner,
        dataloader: Union[DataLoader, Dict],
        evaluator: Union[Evaluator, Dict, List],
        fp16: bool = False,
        dtype: Union[str, torch.dtype, None] = "fp16",
    ) -> None:
        if isinstance(dtype, str):
            dtype = dtype_from_str(dtype)
        elif dtype is not None and not isinstance(dtype, torch.dtype):
            # Same contract (and exception type) as TestLoop.
            raise NotImplementedError(f"Unsupported dtype {dtype}")
        self.dtype = dtype
        super().__init__(runner, dataloader, evaluator, fp16)

    @torch.no_grad()
    def run_iter(self, idx, data_batch: Sequence[dict]):
        """Iterate one mini-batch.

        Args:
            idx: Index of the current mini-batch (forwarded to hooks as
                ``batch_idx``).
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        self.runner.call_hook("before_val_iter", batch_idx=idx, data_batch=data_batch)
        # Only forward the dtype when autocast is enabled: a disabled autocast
        # does not use it, and passing it anyway can trip device-specific dtype
        # checks. NOTE(review): mmengine's autocast appears to validate dtype
        # per device (e.g. CPU only accepts bf16) even when enabled=False --
        # confirm against the installed mmengine version.
        amp_dtype = self.dtype if self.fp16 else None
        # outputs should be sequence of BaseDataElement
        with autocast(enabled=self.fp16, dtype=amp_dtype):
            outputs = self.runner.model.val_step(data_batch)
        self.evaluator.process(data_samples=outputs, data_batch=data_batch)
        self.runner.call_hook(
            "after_val_iter", batch_idx=idx, data_batch=data_batch, outputs=outputs
        )


@LOOPS.register_module(force=True)
class TestLoop(_TestLoop):
    """Test loop with a configurable autocast dtype.

    The original ``mmengine.runner.TestLoop`` only supports fp32 and fp16.
    This implementation additionally lets you choose the autocast dtype
    (e.g. bf16). Everything else is kept consistent with the original
    ``mmengine.runner.TestLoop``.

    Note:
        ``dtype`` only takes effect when ``fp16=True``. With ``fp16=False``
        autocast is disabled and ``dtype`` is not forwarded to it.

    Args:
        runner: The runner that owns this loop.
        dataloader (DataLoader | dict): Test dataloader or its config.
        evaluator (Evaluator | dict | list): Evaluator or its config.
        fp16 (bool): Whether to run ``test_step`` under autocast.
            Defaults to False.
        dtype (str | torch.dtype | None): Autocast dtype. Strings are
            converted with ``dtype_from_str``; a ``torch.dtype`` or ``None``
            is used as-is. Defaults to ``"fp16"``.

    Raises:
        NotImplementedError: If ``dtype`` is not a str, ``torch.dtype``
            or ``None``.
    """

    def __init__(
        self,
        runner,
        dataloader: Union[DataLoader, Dict],
        evaluator: Union[Evaluator, Dict, List],
        fp16: bool = False,
        dtype: Union[str, torch.dtype, None] = "fp16",
    ) -> None:
        if isinstance(dtype, str):
            dtype = dtype_from_str(dtype)
        elif dtype is not None and not isinstance(dtype, torch.dtype):
            # Previously *every* non-str value (including torch.bfloat16)
            # was rejected; now only genuinely unsupported types are.
            raise NotImplementedError(f"Unsupported dtype {dtype}")
        self.dtype = dtype
        super().__init__(runner, dataloader, evaluator, fp16)

    @torch.no_grad()
    def run_iter(self, idx, data_batch: Sequence[dict]):
        """Iterate one mini-batch.

        Args:
            idx: Index of the current mini-batch (forwarded to hooks as
                ``batch_idx``).
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        self.runner.call_hook("before_test_iter", batch_idx=idx, data_batch=data_batch)
        # Only forward the dtype when autocast is enabled: a disabled autocast
        # does not use it, and passing it anyway can trip device-specific dtype
        # checks. NOTE(review): mmengine's autocast appears to validate dtype
        # per device (e.g. CPU only accepts bf16) even when enabled=False --
        # confirm against the installed mmengine version.
        amp_dtype = self.dtype if self.fp16 else None
        # outputs should be sequence of BaseDataElement
        with autocast(enabled=self.fp16, dtype=amp_dtype):
            outputs = self.runner.model.test_step(data_batch)
        self.evaluator.process(data_samples=outputs, data_batch=data_batch)
        self.runner.call_hook(
            "after_test_iter", batch_idx=idx, data_batch=data_batch, outputs=outputs
        )
