from abc import abstractmethod
import logging
import torch
from mmhug.registry import TRAINERS
from mmengine.model import BaseModel
from typing import Dict, Any, Union
from torch import Tensor
from torch import nn
from mmengine import print_log


def count_parameters(model, trainable_only: bool = False) -> int:
    """Return the number of scalar elements across ``model.parameters()``.

    Args:
        model: Any object exposing a ``parameters()`` iterable whose items
            provide ``numel()`` and ``requires_grad`` (e.g. an ``nn.Module``).
        trainable_only: When ``True``, parameters whose ``requires_grad`` is
            falsy are left out of the count. When ``False`` every parameter
            is counted and ``requires_grad`` is not consulted.

    Returns:
        The summed ``numel()`` of the selected parameters (``0`` if none).
    """
    total = 0
    for param in model.parameters():
        # Skip frozen parameters only when the caller asked for trainable ones.
        if trainable_only and not param.requires_grad:
            continue
        total += param.numel()
    return total


@TRAINERS.register_module()
class BaseTrainerModel(BaseModel):
    """Base class for trainer models registered in ``TRAINERS``.

    ``forward`` dispatches a batch to ``forward_tensor``, ``forward_loss`` or
    ``forward_predict`` according to ``mode``. The three hooks raise
    ``NotImplementedError`` here; subclasses override the ones they support.
    """

    def _collect_trainable_params(self) -> None:
        """Log the parameter counts of every direct child module.

        For each ``(name, module)`` pair yielded by ``self.named_children()``
        one INFO line is emitted through ``print_log``, showing the number of
        parameters with ``requires_grad=True`` next to the module's total
        parameter count.

        Note: despite its name, this method does not store anything on
        ``self``; it only reports the counts.
        """
        for name, module in self.named_children():
            count_trainable = count_parameters(module, trainable_only=True)
            count = count_parameters(module, trainable_only=False)
            print_log(
                f"Module '{name}': {count_trainable:,}/{count:,} (trainable/all) parameters",
                logger="current",
                level=logging.INFO,
            )

    def train(self, mode: bool = True):
        """Set the training mode of the model.

        Override this function to determine whether each module is in
        training mode here.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.
                Defaults to ``True`` so that a bare ``model.train()`` works,
                matching ``torch.nn.Module.train``.

        Returns:
            ``self``. ``torch.nn.Module.eval`` returns whatever ``train``
            returns, so this keeps ``model = model.eval()`` and chained calls
            such as ``model.train().cuda()`` working.
        """
        super().train(mode)
        return self

    def forward(self, batch: Dict[str, Any], mode: str = "tensor"):
        """Dispatch ``batch`` to the hook selected by ``mode``.

        Args:
            batch: The batch, handed to the selected hook unchanged.
            mode: One of ``"tensor"``, ``"loss"`` or ``"predict"``.

        Returns:
            Whatever ``forward_tensor``, ``forward_loss`` or
            ``forward_predict`` returns for ``batch``.

        Raises:
            RuntimeError: If ``mode`` is not one of the supported values.
        """
        if mode == "tensor":
            return self.forward_tensor(batch)

        if mode == "loss":
            return self.forward_loss(batch)

        if mode == "predict":
            return self.forward_predict(batch)

        # Fail loudly instead of silently returning None for a mistyped mode.
        raise RuntimeError(
            f'Invalid mode "{mode}". Only supports "tensor", "loss" and "predict" modes.'
        )

    def _run_forward(
        self, data: dict | tuple | list, mode: str
    ) -> Dict[str, torch.Tensor] | list:
        """
        Modified from mmengine.BaseModel._run_forward. We pass the batch data
        to the model directly rather than unpacking it as **data or *data.
        """
        return self(data, mode)

    def forward_tensor(self, batch: Dict[str, Any]) -> Any:
        """Hook for ``mode="tensor"``; subclasses must override it.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement forward_tensor()"
        )

    def forward_loss(self, batch: Dict[str, Dict[str, Tensor]]) -> Dict[str, Tensor]:
        """Hook for ``mode="loss"``; subclasses must override it.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement forward_loss()"
        )

    def forward_predict(self, batch: Dict[str, Dict[str, Tensor]]) -> Dict[str, Tensor]:
        """Hook for ``mode="predict"``; subclasses must override it.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement forward_predict()"
        )
