import math
from typing import Any, Dict

import torch
import torch.nn as nn
from torch.quasirandom import SobolEngine


class MLP(torch.nn.Module):
    """Two linear layers with a SiLU in between and an optional skip connection.

    The skip connection adds the hidden activation (the SiLU output) to the
    output of the second linear layer. It is applied only when
    ``hidden_size == output_size``, i.e. when the two tensors have the same
    trailing dimension.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Sizes are kept on the instance; ``forward`` reads two of them to
        # decide whether the skip connection applies.
        self.input_size, self.hidden_size, self.output_size = (
            input_size,
            hidden_size,
            output_size,
        )
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return ``fc2(silu(fc1(x)))``, plus the hidden activation when sizes match."""
        hidden = nn.functional.silu(self.fc1(x))
        projected = self.fc2(hidden)

        if self.hidden_size != self.output_size:
            return projected
        return projected + hidden


class Basic_Epinet(torch.nn.Module):
    """Epistemic-index head: a trainable epinet plus a frozen, randomly initialised prior net.

    Both parts consume the same (detached) latents and an epistemic index
    ``z`` of size ``epistemic_dim``:

    * epinet: ``[z, LayerNorm(latents)]`` is passed through ``cat_layer``, the
      result is multiplied elementwise by ``z`` and projected to one scalar.
    * priornet: ``latents`` are mapped to an ``epistemic_dim`` vector whose dot
      product with ``z`` is scaled by ``prior_scale`` and added to that scalar.
      Its parameters have ``requires_grad`` disabled at construction and it is
      always evaluated under ``torch.no_grad()``.

    Args:
        feature_dim: Size of the last dimension of ``latents``.
        epistemic_dim: Size of the epistemic index ``z``; also the hidden
            width of the epinet.
        n_epinet_layers: Stored on the instance; not otherwise read by the
            code in this class.
        n_epinet_marg_layers: Stored on the instance; not otherwise read by
            the code in this class.
        n_prior_layers: Stored on the instance; not otherwise read by the
            code in this class.
        prior_scale: Multiplier applied to the prior net contribution.
        prior_oscale: Stored on the instance; not otherwise read by the code
            in this class.
        max_samples: Number of Sobol points precomputed at construction. This
            is the upper bound for ``n_samples`` on the Sobol path of
            :meth:`epistemic_index`.
    """

    def __init__(
        self,
        feature_dim: int,
        epistemic_dim: int = 128,
        n_epinet_layers: int = 1,
        n_epinet_marg_layers: int = 2,
        n_prior_layers: int = 1,
        prior_scale: float = 2.0,
        prior_oscale: float = 0.75,
        max_samples: int = 10_000,
    ):
        super().__init__()
        self.feature_dim = feature_dim
        self.epistemic_dim = epistemic_dim
        self.n_epinet_layers = n_epinet_layers
        self.n_epinet_marg_layers = n_epinet_marg_layers
        self.n_prior_layers = n_prior_layers
        self.prior_scale = prior_scale
        self.prior_oscale = prior_oscale

        # Quasi-random Sobol points, fixed at init. The first `burn_in` points
        # of the sequence are discarded, and the rest are clamped to
        # [eps, 1 - eps] so the later erfinv transform never sees exactly 0 or 1.
        self.max_samples = max_samples
        burn_in = 100
        eps = torch.finfo(torch.float32).eps
        sobol_samples = (
            SobolEngine(self.epistemic_dim)
            .draw(self.max_samples + burn_in)
            .clamp(eps, 1 - eps)[burn_in:]
        )
        # Registered as a non-trainable Parameter so it moves with the module
        # and is saved in the state dict under "sobol_samples".
        self.sobol_samples = nn.Parameter(sobol_samples, requires_grad=False)

        # epinet (trainable)
        self.epinet = torch.nn.ModuleDict(
            {
                "layer": torch.nn.Sequential(
                    nn.LayerNorm(self.feature_dim, elementwise_affine=False),
                ),
                "cat_layer": torch.nn.Sequential(
                    nn.Linear(
                        self.feature_dim + self.epistemic_dim, self.epistemic_dim
                    ),
                    MLP(self.epistemic_dim, self.epistemic_dim, self.epistemic_dim),
                    MLP(self.epistemic_dim, self.epistemic_dim, self.epistemic_dim),
                    MLP(self.epistemic_dim, self.epistemic_dim, self.epistemic_dim),
                ),
                "out": torch.nn.Sequential(
                    nn.Linear(self.epistemic_dim, 1),
                ),
            }
        )

        # priornet (frozen at its random initialisation)
        self.priornet = torch.nn.ModuleDict(
            {
                "layer": torch.nn.Sequential(
                    nn.LayerNorm(self.feature_dim),
                    MLP(self.feature_dim, self.feature_dim // 2, self.feature_dim // 2),
                    MLP(
                        self.feature_dim // 2, self.feature_dim // 2, self.epistemic_dim
                    ),
                ),
            }
        )
        for p in self.priornet.parameters():
            p.requires_grad = False

    def epistemic_index(
        self, n_samples: int = 1, use_sobol: bool = True
    ) -> torch.Tensor:
        """Return ``n_samples`` epistemic indices of shape ``(n_samples, epistemic_dim)``.

        When ``n_samples > 1`` and ``use_sobol`` is true, the first
        ``n_samples`` precomputed Sobol points are mapped through the inverse
        standard-normal CDF, giving deterministic quasi-Gaussian samples.
        Otherwise (including ``n_samples == 1``) fresh ``torch.randn`` samples
        are drawn.

        Raises:
            ValueError: If the Sobol path is requested with
                ``n_samples > max_samples``.
        """
        if n_samples > 1 and use_sobol:
            # Explicit check instead of `assert`: asserts vanish under
            # `python -O`, after which the slice below would silently return
            # fewer rows than requested.
            if n_samples > self.max_samples:
                raise ValueError(
                    f"n_samples must be in [1, {self.max_samples}] when "
                    f"use_sobol is True, got {n_samples}"
                )
            uniform = self.sobol_samples[:n_samples]
            # Inverse normal CDF: sqrt(2) * erfinv(2u - 1).
            return math.sqrt(2.0) * torch.special.erfinv(2 * uniform - 1)

        # Plain Gaussian samples.
        return torch.randn(
            n_samples, self.epistemic_dim, device=self.sobol_samples.device
        )

    def _indexed_output(self, latents: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Evaluate epinet + scaled prior for detached ``latents`` and broadcast ``z``.

        ``latents`` is ``(batch, feature_dim)``; ``z`` is ``(..., batch,
        epistemic_dim)`` with any leading sample dimensions. The result has
        the shape of ``z`` without its last dimension.
        """
        features = self.epinet["layer"](latents)
        # View the per-row features across the leading sample dims of z.
        features = features.expand(*z.shape[:-1], -1)
        hidden = self.epinet["cat_layer"](torch.cat([z, features], dim=-1))
        learned = self.epinet["out"](hidden * z).squeeze(-1)

        # The prior is a fixed function: never build a graph through it.
        with torch.no_grad():
            prior = (self.priornet["layer"](latents) * z).sum(-1)

        return learned + self.prior_scale * prior

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        """Return one value per row of ``latents`` using a single random index.

        ``latents`` is ``(batch, feature_dim)`` and is detached first, so no
        gradient flows back into it. One Gaussian index is drawn per call and
        shared by every row. The result has shape ``(batch,)``.
        """
        latents = latents.detach()
        z = self.epistemic_index().to(latents.device).expand(latents.shape[0], -1)
        return self._indexed_output(latents, z)

    def sample_n(
        self,
        latents: torch.Tensor,
        n_samples: int = 100,
        residual_scale: float = 1.0,
        use_sobol: bool = True,
    ) -> torch.Tensor:
        """Evaluate the head for ``n_samples`` epistemic indices at once.

        Args:
            latents: ``(batch, feature_dim)`` features; detached before use.
            n_samples: Number of epistemic indices to evaluate.
            residual_scale: Multiplier applied to the final result.
            use_sobol: Forwarded to :meth:`epistemic_index`.

        Returns:
            Tensor of shape ``(n_samples, batch)``.

        Raises:
            ValueError: Propagated from :meth:`epistemic_index` when the
                Sobol path is used with ``n_samples > max_samples``.
        """
        latents = latents.detach()
        z = (
            self.epistemic_index(n_samples, use_sobol)
            .to(latents.device)
            .unsqueeze(1)
            .expand(-1, latents.shape[0], -1)
        )
        return self._indexed_output(latents, z) * residual_scale
