import logging
import types
from abc import ABC, abstractmethod
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict

import torch
from torch import nn

from .abstract_task import AbstractTask, validate_task


@contextmanager
def force_half_batchnorm(model):
    """Temporarily replace the forward pass of every ``nn.BatchNorm2d`` in ``model``.

    While the context is active, each ``nn.BatchNorm2d`` found via
    ``model.modules()`` normalizes its input with statistics computed from the
    current batch (mean and biased variance over dims 0, 2 and 3), cast to the
    input's dtype. The replacement never reads or updates the running
    statistics and does not look at the module's train/eval mode.

    On exit (including when the body raises) every patched module gets back
    exactly the ``forward`` it had before: a pre-existing instance-level
    override is reinstated, otherwise the temporary override is removed so the
    class-level ``forward`` is used again.

    Args:
    ----
        model (nn.Module):
            Model whose ``nn.BatchNorm2d`` sub-modules are patched in place.

    """
    # Marks modules that had no instance-level `forward` before patching.
    absent = object()
    # (module, previous instance-level forward or `absent`) for every patched module.
    patched = []

    def half_forward(self, x):
        # Per-channel batch statistics; the [0, 2, 3] reduction and the
        # [None, :, None, None] broadcast require a 4-D (N, C, H, W) input.
        mean = x.mean([0, 2, 3])
        var = x.var([0, 2, 3], unbiased=False)
        mean = mean.to(x.dtype)
        var = var.to(x.dtype)
        out = (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + self.eps)
        # `weight` / `bias` are None for BatchNorm2d(affine=False): skip the affine step.
        if self.weight is not None:
            out = out * self.weight[None, :, None, None]
        if self.bias is not None:
            out = out + self.bias[None, :, None, None]
        return out

    try:
        # Patch inside the `try` so a failure half-way through still restores
        # the modules that were already replaced.
        for module in model.modules():
            if isinstance(module, nn.BatchNorm2d):
                patched.append((module, module.__dict__.get("forward", absent)))
                module.forward = types.MethodType(half_forward, module)
        yield
    finally:
        for module, previous in patched:
            if previous is absent:
                # Drop the override rather than pinning a bound method on the instance.
                module.__dict__.pop("forward", None)
            else:
                module.forward = previous

class AbstractComputer(ABC):
    """An abstract base class for Computers."""

    # Specifies the dtype for storing TDA scores.
    score_dtype: torch.dtype = torch.float32

    # Specifies the dtype for storing gradients.
    grads_dtype: torch.dtype = torch.float32

    # Specifies the dtype for storing statistics (only applies to influence functions).
    stats_dtype: torch.dtype = torch.float32

    # Specifies the dtype for performing eigendecomposition (only applies to influence functions).
    eig_dtype: torch.dtype = torch.float64

    @abstractmethod
    def __init__(
        self,
        model: nn.Module,
        task: AbstractTask,
        logger_name: str,
        logging_level: int = logging.INFO,
        force_half_precision: bool = False,
    ) -> None:
        """Initialize the class AbstractComputer.

        Args:
        ----
            model (nn.Module):
                PyTorch model for which scores are computed.
            task (AbstractTask):
                Specifies the task for the pipeline. For details, see `AbstractTask` in
                `src/abstract_task.py`.
            logger_name (str):
                Name of the logger.
            logging_level (int, optional):
                The logging level. Defaults to `logging.INFO`.
            force_half_precision (bool, optional):
                Stored on the instance. When true, the full double Jacobian is
                evaluated under `force_half_batchnorm`, and a warning is logged
                if any of `score_dtype`, `grads_dtype` or `stats_dtype` is
                `torch.float32`. Defaults to `False`.

        """
        self.model = model
        self.task = task

        # Setup logging configurations.
        # TODO: REMOVE THIS
        logging.basicConfig()
        self.logger = logging.getLogger(logger_name)
        self.logger.setLevel(logging_level)

        self.force_half_precision = force_half_precision
        # Look the dtypes up through `self` so that values overridden on a
        # subclass are the ones checked; `__class__` would always resolve to
        # this base class and ignore such overrides.
        storage_dtypes = (self.score_dtype, self.grads_dtype, self.stats_dtype)
        if self.force_half_precision and any(dtype == torch.float32 for dtype in storage_dtypes):
            self.logger.warning("Asked for half precision but task dtypes are full precision.")

        validate_task(model=self.model, task=self.task, logger=self.logger)

    def _compute_train_loss(
        self,
        params: Dict[str, torch.Tensor],
        buffers: Dict[str, torch.Tensor],
        batch: Any,
    ) -> torch.Tensor:
        """Compute the cumulative training loss for a given batch.

        Args:
        ----
            params (dict):
                Model parameters to be used for computation.
            buffers (dict):
                Model buffers to be used for computation.
            batch (Any):
                The batch of data on which the loss will be computed.

        Returns:
        -------
            The value returned by `task.get_train_loss` called with
            `sample=False` and `reduction="sum"`.

        """
        return self.task.get_train_loss(
            model=self.model,
            batch=batch,
            parameter_and_buffer_dicts=(params, buffers),
            sample=False,
            reduction="sum",
        )

    def _compute_train_loss_grad(self) -> Callable:
        """Return the function that computes gradients of loss w.r.t. parameters."""
        return torch.func.grad(self._compute_train_loss, argnums=0, has_aux=False)

    def _compute_train_loss_double_jac(self, num_samples: int | None, *kargs: Any, **kwargs: Any) -> Callable:
        """Select the Jacobian estimator: full, or stochastic.

        Args:
        ----
            num_samples (int | None):
                `None` selects the full Jacobian; any other value is forwarded
                as `num_samples` to the stochastic estimator.
            *kargs, **kwargs:
                Forwarded unchanged to the selected factory.

        """
        if num_samples is None:
            return self._compute_train_loss_double_jac_full(*kargs, **kwargs)
        return self._compute_train_loss_double_jac_stochastic(*kargs, num_samples=num_samples, **kwargs)

    def _compute_train_loss_double_jac_full(self) -> Callable:
        """Return the function that computes jacobian of loss gradients w.r.t. inputs."""

        # `jacfwd` differentiates w.r.t. a positional argument, so the batch is
        # split into (inputs, labels) to make `inputs` addressable as argnums=2.
        def _deconstructed_func(params, buffers, inputs, labels) -> torch.Tensor:
            return self._compute_train_loss(params, buffers, (inputs, labels))

        grad_func = torch.func.grad(_deconstructed_func, argnums=0, has_aux=False)
        double_grad = torch.func.jacfwd(grad_func, argnums=2, has_aux=False)

        def _reconstructed_double_grad(params, buffers, batch) -> torch.Tensor:
            inputs, labels = batch
            # Swap in the batch-statistics BatchNorm forward only when half
            # precision was requested; otherwise run the model untouched.
            with force_half_batchnorm(self.model) if self.force_half_precision else nullcontext():
                return double_grad(params, buffers, inputs, labels)

        return _reconstructed_double_grad

    def _compute_measurement(
        self,
        params: Dict[str, torch.Tensor],
        buffers: Dict[str, torch.Tensor],
        batch: Any,
    ) -> torch.Tensor:
        """Compute the cumulative measurement for a given batch.

        Args:
        ----
            params (dict):
                Model parameters to be used for computation.
            buffers (dict):
                Model buffers to be used for computation.
            batch (Any):
                The batch of data on which the measurement will be computed.

        Returns:
        -------
            The value returned by `task.get_measurement` called with
            `sample=False` and `reduction="sum"`.

        """
        return self.task.get_measurement(
            model=self.model,
            batch=batch,
            parameter_and_buffer_dicts=(params, buffers),
            sample=False,
            reduction="sum",
        )

    def _compute_measurement_grad(self) -> Callable:
        """Return the function that computes gradients of measurement w.r.t. parameters."""
        return torch.func.grad(self._compute_measurement, argnums=0, has_aux=False)

    # ──────────────────────────────────────────────────────────────
    # helpers
    # ──────────────────────────────────────────────────────────────
    @staticmethod
    def _zeros_like_param(p: torch.Tensor,
                        inp_shape: tuple[int, int, int]) -> torch.Tensor:
        """Return zeros of shape `p.shape + inp_shape` with `p`'s dtype and device."""
        return torch.zeros(p.shape + inp_shape, dtype=p.dtype, device=p.device)

    @staticmethod
    def _to_fp32_if_float(t: torch.Tensor) -> torch.Tensor:
        """Cast floating-point tensors to FP32; return a clone of any other tensor."""
        return t.to(torch.float32) if t.is_floating_point() else t.clone()

    @staticmethod
    def _cast_tree_fp32(tree: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Apply `_to_fp32_if_float` to every value of a flat name -> tensor dict."""
        return {k: AbstractComputer._to_fp32_if_float(v) for k, v in tree.items()}  # noqa: SLF001

    # ──────────────────────────────────────────────────────────────
    # factory
    # ──────────────────────────────────────────────────────────────
    def _compute_train_loss_double_jac_stochastic(
            self,
            *,
            num_samples: int = 256,
            generator: torch.Generator | None = None,
    ) -> Callable:
        """Memory-friendly estimator of  d/dx[∇_θ L(θ,x)]  for a single image
        shaped (C,H,W).  Keeps exact Jacobian shape
            param_leaf.shape  +  (C,H,W).

        • Floats are safely up-cast to FP32 for differentiability;
        integer buffers stay integer, so indexing ops keep working.
        • `num_samples` input coordinates are drawn with `torch.randint` (with
        replacement); for each one a forward-mode JVP along the matching unit
        vector yields one Jacobian column, which is written into the output.
        Columns of coordinates that were never drawn remain zero.

        Args:
        ----
            num_samples (int, optional):
                Number of input coordinates to probe. Defaults to 256.
            generator (torch.Generator | None, optional):
                Random generator handed to `torch.randint`.

        Returns:
        -------
            A callable `(params, buffers, batch) -> dict` mapping each parameter
            name to its (partially filled) Jacobian in the parameter's dtype.
        """

        # ---------- loss & ∇_θ loss in FP32 -----------------------
        # NOTE(review): unlike `_compute_train_loss`, this relies on the task's
        # default `sample` / `reduction` arguments — confirm that is intended.
        def _loss(params32, buffers_mixed, x32, y):
            return self.task.get_train_loss(
                self.model,
                (x32, y),
                parameter_and_buffer_dicts=(params32, buffers_mixed),
            )

        grad_params32 = torch.func.grad(_loss, argnums=0, has_aux=False)
        # ---------- returned closure -----------------------------
        def _double_jac(
                params: dict[str, torch.Tensor],        # FP16
                buffers: dict[str, torch.Tensor],       # int64 + FP16
                batch: tuple[torch.Tensor, torch.Tensor],   # (C,H,W) FP16, label
        ) -> dict[str, torch.Tensor]:

            x_fp16, y = batch
            assert x_fp16.ndim == 3, "Expect a single image shaped (C,H,W)."

            # shapes
            C, H, W  = x_fp16.shape
            inp_shape, flat_dim = (C, H, W), C * H * W

            # 1) cast params & *floating* buffers to FP32
            params32      = self._cast_tree_fp32(params)
            buffers_mixed = self._cast_tree_fp32(buffers)   # ints remain ints
            x32           = x_fp16.to(torch.float32)

            # helper: ∇_θ L with closed-over params/buffers
            def g(z):
                return grad_params32(params32, buffers_mixed, z, y)

            # 2) prepare output pytree in original dtypes
            jac = {n: self._zeros_like_param(p, inp_shape) for n, p in params.items()}

            # coordinates to probe
            idxs = torch.randint(0, flat_dim, (num_samples,),
                                generator=generator, device=x_fp16.device)

            for idx_t in idxs:
                idx_t_view = idx_t.view(1)  # shape (1,)

                # one-hot tangent selecting the probed input coordinate
                v32 = torch.zeros_like(x32).view(-1)
                v32 = v32.scatter(0, idx_t_view, 1.0)
                v32 = v32.view_as(x32)

                _, hvp32 = torch.func.jvp(g, (x32,), (v32,))

                # 3) store this column for every parameter. This must run once
                #    per probed coordinate, i.e. inside the sampling loop;
                #    otherwise only the last probe would ever be recorded.
                for name, leaf32 in hvp32.items():
                    flat_jac = jac[name].view(*leaf32.shape, -1)

                    # expand index and src to match (..., 1)
                    index = idx_t.expand(*leaf32.shape, 1)
                    src   = leaf32.to(flat_jac.dtype).unsqueeze(-1)

                    # out-of-place scatter, kept as in the original write-back
                    flat_jac = flat_jac.scatter(dim=-1, index=index, src=src)
                    jac[name] = flat_jac.view_as(jac[name])

            return jac

        return _double_jac
