from __future__ import annotations

from typing import Callable, Dict, Optional, Tuple

import torch
from torch import nn


# Lookup table of the activation functions this module supports. Keys are
# lowercase names; every value maps a tensor to a tensor.
_ACTS: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = dict(
    relu=torch.relu,
    gelu=torch.nn.functional.gelu,
    swish=torch.nn.functional.silu,  # "swish" resolves to torch's SiLU
    sigmoid=torch.sigmoid,
    tanh=torch.tanh,
    sin=torch.sin,
    identity=lambda t: t,  # pass-through: hands back its argument unchanged
)


def get_activation(name: str) -> Callable[[torch.Tensor], torch.Tensor]:
    """Look up an activation function by name, ignoring case.

    Args:
        name: Activation name. It is passed through ``str()`` and lowercased
            before the lookup, so ``"ReLU"`` and ``"relu"`` are equivalent.

    Returns:
        The callable registered under that name in ``_ACTS``.

    Raises:
        NotImplementedError: If no activation is registered under the name.
    """
    # No registered value is None, so None reliably signals a missing key.
    fn = _ACTS.get(str(name).lower())
    if fn is None:
        raise NotImplementedError(f"Activation '{name}' not supported.")
    return fn


class PeriodicEmbedding(nn.Module):
    """Replace selected input coordinates with a ``(cos, sin)`` pair.

    For every column ``i`` of the last input dimension that is listed in
    ``axis``, the column ``x_i`` is replaced by the two columns
    ``cos(p * x_i)`` and ``sin(p * x_i)``, where ``p`` is the matching entry
    of ``period`` (it multiplies the coordinate directly). Columns that are
    not listed are passed through unchanged, so the output is wider than the
    input by one column per embedded coordinate.

    Args:
        period: One multiplier per embedded coordinate.
        axis: Column index (along the last dimension) of each embedded
            coordinate. If an index is repeated, its first entry is used.
        trainable: Whether the corresponding multiplier receives gradients.

    Raises:
        ValueError: If the three tuples differ in length.

    Attributes:
        axis: The column indices, converted to ``int``.
        periods: ``nn.ParameterList`` with one entry per embedded coordinate;
            non-trainable entries have ``requires_grad=False``.
        registered: Per-entry flags, ``True`` where the entry is trainable.
    """

    def __init__(
        self,
        period: Tuple[float, ...],
        axis: Tuple[int, ...],
        trainable: Tuple[bool, ...],
    ):
        super().__init__()
        if not (len(period) == len(axis) == len(trainable)):
            raise ValueError("period, axis, trainable must have same length")

        self.axis = tuple(int(a) for a in axis)
        self.periods = nn.ParameterList()
        self.registered = []

        for i, (p, tr) in enumerate(zip(period, trainable)):
            p_t = torch.tensor(float(p))
            if tr:
                self.periods.append(nn.Parameter(p_t))
                self.registered.append(True)
            else:
                # The buffer is kept so the module's state_dict keys stay
                # the same as before.
                self.register_buffer(f"period_{i}", p_t)
                # Appending a bare tensor to a ParameterList makes PyTorch
                # wrap it in a Parameter with requires_grad=True (or raise
                # on older releases), which would silently make a
                # "non-trainable" period trainable. Freeze it explicitly.
                self.periods.append(
                    nn.Parameter(p_t.clone(), requires_grad=False)
                )
                self.registered.append(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed the configured columns of ``x``.

        Args:
            x: Input tensor; columns are indexed along the last dimension.
                A 1-D tensor is treated as a single sample.

        Returns:
            Tensor with the same leading dimensions as ``x`` whose last
            dimension holds, in input-column order, either the untouched
            column or its ``cos``/``sin`` pair.
        """
        squeeze = x.ndim == 1
        if squeeze:
            x = x.unsqueeze(0)

        # Column index -> position in self.periods. setdefault keeps the
        # first occurrence, matching tuple.index semantics.
        slot_of = {}
        for slot, ax in enumerate(self.axis):
            slot_of.setdefault(ax, slot)

        pieces = []
        for i in range(x.shape[-1]):
            xi = x[..., i : i + 1]
            slot = slot_of.get(i)
            if slot is None:
                pieces.append(xi)
                continue
            # Follow the input's device/dtype so mixed setups still work.
            p = self.periods[slot].to(x.device, x.dtype)
            phase = p * xi
            pieces.append(torch.cos(phase))
            pieces.append(torch.sin(phase))

        y = torch.cat(pieces, dim=-1)
        return y.squeeze(0) if squeeze else y


class FourierEmbedding(nn.Module):
    """Random Fourier feature embedding with a learnable projection kernel.

    The input is projected with a kernel of shape ``(in_dim, dims // 2)``
    and the output is ``cat([cos(x @ kernel), sin(x @ kernel)], dim=-1)``,
    so the last dimension of the output always has size ``dims``.

    Args:
        scale: Multiplier applied to the standard-normal draw that
            initialises the kernel.
        dims: Output feature size; must be even.
        in_dim: Optional input feature size. When given, the kernel is
            created in the constructor so it is already visible to
            ``parameters()`` / ``state_dict()`` (e.g. for an optimizer built,
            or a checkpoint loaded, before the first forward pass). When
            omitted, the kernel is created on the first call to ``forward``
            from the input's last dimension, device and dtype.

    Raises:
        ValueError: If ``dims`` is odd.
    """

    def __init__(self, scale: float, dims: int, in_dim: Optional[int] = None):
        super().__init__()
        if dims % 2 != 0:
            raise ValueError("dims must be even")
        self.scale = float(scale)
        self.dims = int(dims)
        self.kernel: Optional[nn.Parameter] = None
        if in_dim is not None:
            # Eager creation on the default device/dtype.
            self._init_kernel(int(in_dim), None, None)

    def _init_kernel(self, in_dim: int, device, dtype):
        """Create the kernel parameter once; later calls are no-ops."""
        if self.kernel is None:
            w = torch.randn(in_dim, self.dims // 2, device=device, dtype=dtype)
            self.kernel = nn.Parameter(self.scale * w)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and return its cosine/sine features.

        Args:
            x: Input tensor; features lie along the last dimension. A 1-D
                tensor is treated as a single sample.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``dims``.
        """
        squeeze = x.ndim == 1
        if squeeze:
            x = x.unsqueeze(0)

        self._init_kernel(x.shape[-1], x.device, x.dtype)
        # The kernel keeps the dtype/device it was created with; cast per
        # call so an input of a different dtype/device does not break the
        # matmul. ``.to`` is a no-op when nothing differs and stays
        # differentiable otherwise.
        kernel = self.kernel.to(x.device, x.dtype)
        proj = x @ kernel
        y = torch.cat([torch.cos(proj), torch.sin(proj)], dim=-1)
        return y.squeeze(0) if squeeze else y


class Embedding(nn.Module):
    """Optional periodic embedding followed by an optional Fourier embedding.

    Args:
        periodicity: Keyword arguments for ``PeriodicEmbedding``; ``None``
            disables that stage.
        fourier_embeddings: Keyword arguments for ``FourierEmbedding``;
            ``None`` disables that stage.

    Attributes:
        periodic: The ``PeriodicEmbedding`` instance, or ``None``.
        fourier: The ``FourierEmbedding`` instance, or ``None``.
    """

    def __init__(
        self,
        periodicity: Optional[Dict] = None,
        fourier_embeddings: Optional[Dict] = None,
    ):
        super().__init__()

        periodic_stage = None
        if periodicity is not None:
            periodic_stage = PeriodicEmbedding(**periodicity)

        fourier_stage = None
        if fourier_embeddings is not None:
            fourier_stage = FourierEmbedding(**fourier_embeddings)

        self.periodic = periodic_stage
        self.fourier = fourier_stage

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the enabled stages, periodic first."""
        for stage in (self.periodic, self.fourier):
            if stage is not None:
                x = stage(x)
        return x


class MLP(nn.Module):
    """Fully connected network with optional input embeddings.

    The input is first passed through ``Embedding`` (or ``nn.Identity`` when
    neither embedding is configured), then through ``num_layers`` hidden
    ``nn.Linear`` layers, each followed by the activation, and finally
    through a linear output layer with no activation.

    Args:
        input_dim: Feature size expected by the first linear layer when no
            Fourier embedding is configured. With only a periodic embedding,
            that stage widens the input (each embedded coordinate becomes a
            cos/sin pair), so this must be the width *after* the embedding.
            Ignored when ``fourier_embeddings`` is given, because the
            Fourier stage always outputs ``dims`` features.
        num_layers: Number of hidden linear layers.
        hidden_dim: Width of every hidden layer.
        out_dim: Width of the output layer.
        activation: Name resolved through ``get_activation``.
        periodicity: Keyword arguments for ``PeriodicEmbedding``, or ``None``.
        fourier_embeddings: Keyword arguments for ``FourierEmbedding``, or
            ``None``.

    Raises:
        NotImplementedError: If ``activation`` is not a supported name.
    """

    def __init__(
        self,
        input_dim: int = 2,
        num_layers: int = 4,
        hidden_dim: int = 256,
        out_dim: int = 1,
        activation: str = "tanh",
        periodicity: Optional[Dict] = None,
        fourier_embeddings: Optional[Dict] = None,
    ):
        super().__init__()

        self.embed = (
            Embedding(
                periodicity=periodicity,
                fourier_embeddings=fourier_embeddings,
            )
            if fourier_embeddings or periodicity
            else nn.Identity()
        )

        self.act = get_activation(activation)

        # The Fourier stage emits exactly ``dims`` features regardless of its
        # input width, so the first Linear must be sized from ``dims``;
        # sizing it from ``input_dim`` fails in forward unless the caller
        # happened to pass input_dim == dims.
        in_dim = input_dim
        if fourier_embeddings:
            in_dim = int(fourier_embeddings["dims"])

        widths = [in_dim] + [hidden_dim] * int(num_layers)
        self.hidden = nn.ModuleList(
            [nn.Linear(a, b) for a, b in zip(widths, widths[1:])]
        )
        self.out = nn.Linear(widths[-1], out_dim)

        # Xavier-normal weights and zero biases for every linear layer.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x``, apply the hidden layers, and return the output."""
        x = self.embed(x)
        for layer in self.hidden:
            x = self.act(layer(x))
        return self.out(x)
