from __future__ import annotations

import pathlib
from typing import TYPE_CHECKING

import laplace
import numpy as np
import torch
from torch import optim
from torch.nn.utils import vector_to_parameters

from .._ood_model import _OODModel

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class _LaplaceModel(_OODModel):
    """Base class for Laplace models in the OOD detection experiment.

    Exposes the predictive of ``self.laplace_approximation`` as *logits*.
    Its methods read the attributes ``laplace_approximation``, ``pred_type``
    (``"glm"`` or ``"nn"``), ``link_approx`` (``"probit"`` or ``"mc"``) and
    ``num_samples_test``, none of which are assigned in this class.
    NOTE(review): presumably set by a subclass or by ``_OODModel`` — confirm.
    """

    @staticmethod
    def get_checkpoints(
        dataset_name: str,
        model_name: str,
        parametrization_name: str,
        root_dir: str,
        seed: int,
    ) -> list[pathlib.Path]:
        """Returns the list of checkpoints available to use for the Laplace approximation.

        Searches recursively for ``*.ckpt`` files under
        ``<root_dir>/training_logs/<dataset_name>/<model_name>/<parametrization_name>/seed_<seed>``.

        Args:
            dataset_name: Dataset name used in the training-log directory layout.
            model_name: Model name used in the training-log directory layout.
            parametrization_name: Parametrization name used in the directory layout.
            root_dir: Directory that contains the ``training_logs`` folder.
            seed: Seed whose checkpoints are requested.

        Returns:
            The matching checkpoint paths, sorted so that the order does not
            depend on the filesystem's directory-listing order.

        Raises:
            ValueError: If ``root_dir`` cannot be turned into a path, or if no
                checkpoint exists for the requested configuration.
        """
        try:
            seed_dir = (
                pathlib.Path(root_dir)
                / f"training_logs/{dataset_name}/{model_name}/{parametrization_name}/seed_{seed}"
            )
        except TypeError as e:
            # e.g. root_dir is None.
            raise ValueError(f"No checkpoint found for seed: {seed}.") from e

        # `rglob` does not raise when nothing matches (or the directory is
        # missing) -- it simply yields nothing -- so detect that explicitly.
        checkpoints = sorted(seed_dir.rglob("*.ckpt"))
        if not checkpoints:
            raise ValueError(f"No checkpoint found for seed: {seed}.")

        return checkpoints

    def forward(
        self, inputs: Float[Tensor, "batch *in_feature"]
    ) -> Float[Tensor, "batch *out_feature"]:
        """Returns logits for ``inputs`` under the Laplace predictive.

        If ``_check_H_init`` raises ``AttributeError`` (the approximation has
        not been computed yet), the MAP network's output is returned instead.

        Raises:
            NotImplementedError: For an unsupported ``pred_type``.
        """
        likelihood = "classification"
        joint = False
        diagonal_output = False

        # If we haven't computed the Laplace approximation yet, simply return the MAP.
        try:
            self.laplace_approximation._check_H_init()
        except AttributeError:
            return self.laplace_approximation.model(inputs)

        if self.pred_type == "glm":
            return self._glm_logits_forward_call(
                inputs,
                likelihood=likelihood,
                joint=joint,
                diagonal_output=diagonal_output,
            )
        if self.pred_type == "nn":
            if likelihood == "classification":
                return self._nn_logit_samples(inputs, num_samples=self.num_samples_test)
            raise NotImplementedError(
                f"pred_type='nn' is only implemented for classification, got {likelihood!r}."
            )
        raise NotImplementedError(f"Unsupported pred_type: {self.pred_type!r}.")

    def _glm_logits_forward_call(
        self,
        inputs: Float[Tensor, "batch *in_feature"],
        likelihood: str,
        joint: bool = False,
        diagonal_output: bool = False,
    ) -> Float[Tensor, "batch *out_feature"]:
        """Forward pass of the Laplace approximation using the GLM predictive.

        Implements ParametricLaplace._glm_forward_call but for logits rather than
        probabilities.

        Args:
            inputs: Batch of inputs.
            likelihood: Likelihood name; a joint predictive is only requested
                when this equals ``Likelihood.REGRESSION``.
            joint: Whether to request a joint predictive (regression only).
            diagonal_output: Currently unused -- it is not forwarded to
                ``_glm_functional_samples``. NOTE(review): confirm whether the
                installed laplace version accepts it before wiring it through.

        Raises:
            NotImplementedError: For an unsupported ``link_approx``.
        """
        f_mu, f_var = self.laplace_approximation._glm_predictive_distribution(
            inputs,
            joint=joint and likelihood == laplace.utils.enums.Likelihood.REGRESSION,
        )
        if self.link_approx == "probit":
            # Probit-style scaling of the mean logits:
            # kappa = 1 / sqrt(1 + pi/8 * var), using the diagonal of f_var
            # over dims 1 and 2 (assumes f_var is (batch, out, out) -- TODO confirm).
            kappa = 1 / torch.sqrt(1.0 + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
            return kappa * f_mu
        if self.link_approx == "mc":
            f_samples = self.laplace_approximation._glm_functional_samples(
                f_mu,
                f_var,
                self.num_samples_test,
            )
            # Average over the leading sample dimension.
            return f_samples.mean(dim=0)
        raise NotImplementedError(f"Unsupported link_approx: {self.link_approx!r}.")

    def _nn_logit_samples(
        self,
        inputs: Float[Tensor, "batch *in_feature"],
        num_samples: int,
    ) -> Float[Tensor, "batch *out_feature"]:
        """Returns logits averaged over ``num_samples`` weight samples.

        Each parameter vector drawn by ``laplace_approximation.sample`` is loaded
        into the network with ``vector_to_parameters`` and a forward pass is run.
        The resulting logits are averaged over the sample axis, so the output has
        the same shape as a single forward pass. The parameters are reset to
        ``laplace_approximation.mean`` afterwards, even if a forward pass raises.
        """
        params = self.laplace_approximation.params
        f_samples = []

        try:
            # Outputs are detached anyway, so skip building autograd graphs.
            with torch.no_grad():
                for sample in self.laplace_approximation.sample(num_samples):
                    vector_to_parameters(sample, params)
                    logits = self.laplace_approximation.model(inputs)
                    f_samples.append(logits.detach())
        finally:
            # Always restore the mean weights so the model is not left holding
            # a random posterior sample.
            vector_to_parameters(self.laplace_approximation.mean, params)

        # Stack on a new leading sample axis and average over that same axis.
        # (Stacking on dim=1 and averaging dim=0 would average over the batch.)
        return torch.stack(f_samples, dim=0).mean(dim=0)

    def configure_optimizers(self) -> optim.Optimizer | None:
        """Returns ``None``: this class defines no optimizer.

        NOTE(review): presumably the Laplace approximation is fitted outside the
        optimizer-driven training loop; confirm the trainer accepts ``None`` here.
        """
        return None
