"""Efficient GP data sampler for Neural Process training."""

import math
from typing import List, Optional, Tuple

import torch
from einops import rearrange, repeat

from src.data.utils import sample_sobol
from src.utils import DataAttr


def rbf_kernel(x1, x2, length_scale, sigma_f):
    """Radial Basis Function (RBF / squared-exponential) kernel.

    k(a, b) = sigma_f**2 * exp(-0.5 * ||(a - b) / length_scale||**2)

    Args:
        x1: Inputs of shape [..., N, D].
        x2: Inputs of shape [..., M, D].
        length_scale: Lengthscale(s), broadcastable against the last (D) axis
            of the pairwise differences (scalar or per-dimension).
        sigma_f: Signal standard deviation; a Python number or a tensor
            broadcastable against the [..., N, M] output.

    Returns:
        Kernel matrix of shape [..., N, M].
    """
    # [..., N, 1, D] - [..., 1, M, D] -> [..., N, M, D] scaled differences.
    diff = (x1.unsqueeze(-2) - x2.unsqueeze(-3)) / length_scale
    # ``**`` (rather than ``.pow``) also works when sigma_f is a plain float.
    return sigma_f**2 * torch.exp(-0.5 * diff.pow(2).sum(-1))


def matern12_kernel(x1, x2, length_scale, sigma_f):
    """Matern 1/2 (exponential) kernel.

    k(a, b) = sigma_f**2 * exp(-r),  with r = ||(a - b) / length_scale||

    Args:
        x1: Inputs of shape [..., N, D].
        x2: Inputs of shape [..., M, D].
        length_scale: Lengthscale(s), broadcastable against the last (D) axis
            of the pairwise differences (scalar or per-dimension).
        sigma_f: Signal standard deviation; a Python number or a tensor
            broadcastable against the [..., N, M] output.

    Returns:
        Kernel matrix of shape [..., N, M].
    """
    # Scaled Euclidean distance between every pair of rows: [..., N, M].
    r = torch.norm((x1.unsqueeze(-2) - x2.unsqueeze(-3)) / length_scale, dim=-1)
    # ``**`` (rather than ``.pow``) also works when sigma_f is a plain float.
    return sigma_f**2 * torch.exp(-r)


def matern32_kernel(x1, x2, length_scale, sigma_f):
    """Matern 3/2 kernel.

    k(a, b) = sigma_f**2 * (1 + sqrt(3) r) * exp(-sqrt(3) r),
    with r = ||(a - b) / length_scale||

    Args:
        x1: Inputs of shape [..., N, D].
        x2: Inputs of shape [..., M, D].
        length_scale: Lengthscale(s), broadcastable against the last (D) axis
            of the pairwise differences (scalar or per-dimension).
        sigma_f: Signal standard deviation; a Python number or a tensor
            broadcastable against the [..., N, M] output.

    Returns:
        Kernel matrix of shape [..., N, M].
    """
    # Scaled Euclidean distance between every pair of rows: [..., N, M].
    r = torch.norm((x1.unsqueeze(-2) - x2.unsqueeze(-3)) / length_scale, dim=-1)
    scaled = math.sqrt(3.0) * r
    # ``**`` (rather than ``.pow``) also works when sigma_f is a plain float.
    return sigma_f**2 * (1.0 + scaled) * torch.exp(-scaled)


def matern52_kernel(x1, x2, length_scale, sigma_f):
    """Matern 5/2 kernel.

    k(a, b) = sigma_f**2 * (1 + sqrt(5) r + 5 r**2 / 3) * exp(-sqrt(5) r),
    with r = ||(a - b) / length_scale||

    Args:
        x1: Inputs of shape [..., N, D].
        x2: Inputs of shape [..., M, D].
        length_scale: Lengthscale(s), broadcastable against the last (D) axis
            of the pairwise differences (scalar or per-dimension).
        sigma_f: Signal standard deviation; a Python number or a tensor
            broadcastable against the [..., N, M] output.

    Returns:
        Kernel matrix of shape [..., N, M].
    """
    # Scaled Euclidean distance between every pair of rows: [..., N, M].
    r = torch.norm((x1.unsqueeze(-2) - x2.unsqueeze(-3)) / length_scale, dim=-1)
    sqrt5 = math.sqrt(5.0)
    polynomial = 1.0 + sqrt5 * r + 5.0 * r.pow(2) / 3.0
    # ``**`` (rather than ``.pow``) also works when sigma_f is a plain float.
    return sigma_f**2 * polynomial * torch.exp(-sqrt5 * r)


# Name -> kernel-function registry. GPSampler looks up the names listed in its
# ``kernel_list`` here when drawing a kernel for a batch.
KERNEL_DICT = dict(
    rbf=rbf_kernel,
    matern12=matern12_kernel,
    matern32=matern32_kernel,
    matern52=matern52_kernel,
)


class GPSampler:
    """
    Batched GP sampler that generates data efficiently.

    Each call to ``generate_batch`` draws a single kernel type for the whole
    batch. It then draws independent hyperparameters for every function in
    the batch: per-dimension lengthscales, signal variance and observation
    noise.

    Args:
        kernel_list: List of kernel names to sample from (keys of
            KERNEL_DICT). Defaults to ["matern12", "matern32", "matern52", "rbf"].
        kernel_weights: Sampling weights for each kernel, aligned with
            ``kernel_list``. Defaults to [0.1, 0.2, 0.35, 0.35].
        x_range: Input domain bounds ``[lower, upper]``, each a list of
            ``x_dim`` floats (shape [2, x_dim]). A flat ``[lower, upper]``
            pair is accepted for 1-D inputs. Defaults to [[-2.0], [2.0]].
        lengthscale_range: (min, max) range for sampling lengthscales. The
            sampled values are then multiplied by sqrt(x_dim).
        variance_range: (min, max) range for sampling the function variance
            sigma_f**2.
        noise_range: (min, max) range for sampling the standard deviation of
            the additive Gaussian observation noise.
        jitter: Small value added to diagonal for numerical stability.
        device: Computation device.
        dtype: Tensor data type.
        max_cholesky_retries: Number of retries for a failed Cholesky
            factorization. Each retry uses 10x more jitter than the last.

    Raises:
        ValueError: If a kernel name is unknown, or ``kernel_weights`` does not
            match ``kernel_list`` in length or is not a valid weight vector.
            Also raised if ``x_range`` cannot be read as shape [2, x_dim], or
            ``max_cholesky_retries`` is negative.
    """

    # Candidate context sizes, drawn uniformly by ``generate_batch`` when
    # neither ``num_context`` nor ``context_range`` is given.
    DEFAULT_CONTEXT_SIZES: Tuple[int, ...] = (4, 8, 16, 32, 48, 64, 128, 192)

    def __init__(
        self,
        kernel_list: Optional[List[str]] = None,
        kernel_weights: Optional[List[float]] = None,
        x_range: Optional[List[List[float]]] = None,
        lengthscale_range: Tuple[float, float] = (0.05, 2.0),
        variance_range: Tuple[float, float] = (0.1, 2.0),
        noise_range: Tuple[float, float] = (0.02, 0.4),
        jitter: float = 1e-5,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
        max_cholesky_retries: int = 3,
    ):
        if kernel_list is None:
            kernel_list = ["matern12", "matern32", "matern52", "rbf"]
        if kernel_weights is None:
            kernel_weights = [0.1, 0.2, 0.35, 0.35]
        if x_range is None:
            x_range = [[-2.0], [2.0]]  # Default 1D

        # Fail fast at construction instead of with a KeyError / bad index
        # the first time a batch is generated.
        unknown = [name for name in kernel_list if name not in KERNEL_DICT]
        if unknown:
            raise ValueError(
                f"Unknown kernel(s) {unknown}; expected names from {sorted(KERNEL_DICT)}"
            )
        if len(kernel_weights) != len(kernel_list):
            raise ValueError(
                f"kernel_weights has {len(kernel_weights)} entries but "
                f"kernel_list has {len(kernel_list)}"
            )
        if max_cholesky_retries < 0:
            raise ValueError("max_cholesky_retries must be non-negative")

        # Copy so later mutation of the caller's list cannot desync weights.
        self.kernel_list = list(kernel_list)
        self.kernel_weights = torch.tensor(kernel_weights, device=device, dtype=dtype)
        if (self.kernel_weights < 0).any() or self.kernel_weights.sum() <= 0:
            raise ValueError("kernel_weights must be non-negative with a positive sum")

        bounds = torch.tensor(x_range, device=device, dtype=dtype)
        if bounds.dim() == 1:
            # Flat [lower, upper] pair -> [[lower], [upper]] for 1-D inputs.
            bounds = bounds.unsqueeze(-1)
        if bounds.dim() != 2 or bounds.shape[0] != 2:
            raise ValueError(
                f"x_range must have shape [2, x_dim], got {tuple(bounds.shape)}"
            )
        # Shape: [2, x_dim]
        self.x_range = bounds
        self.x_dim = self.x_range.shape[1]
        self.lengthscale_range = torch.tensor(
            lengthscale_range, device=device, dtype=dtype
        )
        self.variance_range = torch.tensor(variance_range, device=device, dtype=dtype)
        self.noise_range = torch.tensor(noise_range, device=device, dtype=dtype)
        self.jitter = jitter
        self.max_cholesky_retries = max_cholesky_retries
        self.device = device
        self.dtype = dtype

    def _compute_batch_kernel(self, x, lengthscales, sigma_f, kernel_fn):
        """Compute one kernel matrix per batch element with its own hyperparameters.

        Args:
            x: Inputs of shape [B, N, D].
            lengthscales: Per-element lengthscales of shape [B, D].
            sigma_f: Per-element signal standard deviations of shape [B].
            kernel_fn: One of the KERNEL_DICT functions.

        Returns:
            Kernel matrices of shape [B, N, N].
        """

        def compute_single_kernel(x_single, ls_single, sf_single):
            return kernel_fn(x_single, x_single, ls_single, sf_single)

        # vmap maps over dim 0 of every argument, pairing x[b] with its own
        # lengthscales[b] and sigma_f[b].
        return torch.vmap(compute_single_kernel)(x, lengthscales, sigma_f)

    def _sample_uniform(self, bounds, *shape):
        """Return ``bounds[0] + (bounds[1] - bounds[0]) * U`` with U ~ Uniform[0, 1) of ``shape``."""
        low, high = bounds[0], bounds[1]
        return low + (high - low) * torch.rand(
            *shape, device=self.device, dtype=self.dtype
        )

    def _sample_num_context(self, context_range):
        """Draw a context size from ``context_range`` (inclusive) or DEFAULT_CONTEXT_SIZES.

        Raises:
            ValueError: If ``context_range`` is not a valid (min, max) pair.
        """
        if context_range is None:
            sizes = self.DEFAULT_CONTEXT_SIZES
            return sizes[torch.randint(0, len(sizes), (1,)).item()]
        low, high = int(context_range[0]), int(context_range[1])
        if low < 0 or low > high:
            raise ValueError(
                f"context_range must satisfy 0 <= min <= max, got {tuple(context_range)}"
            )
        return torch.randint(low, high + 1, (1,)).item()

    def _cholesky_with_jitter(self, K):
        """Batched Cholesky factor of ``K + jitter * I`` that retries with more jitter on failure.

        The first attempt uses ``self.jitter``. Only the matrices that are not
        numerically positive definite are re-factorized, with jitter multiplied
        by 10 on each of up to ``self.max_cholesky_retries`` retries. Matrices
        that succeed on the first attempt therefore get exactly
        ``cholesky(K + self.jitter * I)``.

        Args:
            K: Batch of symmetric matrices of shape [B, N, N] (without jitter).

        Returns:
            Lower-triangular factors of shape [B, N, N].

        Raises:
            torch.linalg.LinAlgError: If some matrix still fails after all retries.
        """
        # Broadcast the identity instead of materialising a [B, N, N] copy.
        eye = torch.eye(K.shape[-1], device=K.device, dtype=K.dtype)
        L, info = torch.linalg.cholesky_ex(K + self.jitter * eye)
        for attempt in range(1, self.max_cholesky_retries + 1):
            failed = info != 0
            if not failed.any():
                break
            jitter = self.jitter * 10.0**attempt
            L[failed], info[failed] = torch.linalg.cholesky_ex(K[failed] + jitter * eye)
        num_failed = int((info != 0).sum())
        if num_failed:
            max_jitter = self.jitter * 10.0**self.max_cholesky_retries
            raise torch.linalg.LinAlgError(
                f"Cholesky failed for {num_failed} kernel matrices even with jitter {max_jitter:g}"
            )
        return L

    def generate_batch(
        self,
        batch_size: int,
        num_context: Optional[int] = None,
        num_buffer: int = 50,
        num_target: int = 50,
        context_range: Optional[Tuple[int, int]] = None,
    ) -> "DataAttr":
        """
        Generate a batch of GP function samples split into context/buffer/target sets.

        All functions in the batch share the kernel type, the context size and
        the random context/buffer/target split of point indices. Input
        locations and hyperparameters differ per function.

        Args:
            batch_size: Number of functions B.
            num_context: Number of context points Nc. If None, it is drawn
                once per call (see ``context_range``).
            num_buffer: Number of buffer points Nb.
            num_target: Number of target points Nt.
            context_range: Inclusive (min, max) bounds used to draw Nc
                uniformly when ``num_context`` is None. If None (default), Nc is
                drawn uniformly from ``DEFAULT_CONTEXT_SIZES``. Ignored when
                ``num_context`` is given.

        Returns:
            DataAttr with fields: xc, yc, xb, yb, xt, yt. The x-fields have
            shape [B, N, x_dim] and the y-fields [B, N, 1], with N = Nc, Nb, Nt.

        Raises:
            ValueError: If ``context_range`` is invalid.
            torch.linalg.LinAlgError: If a kernel matrix is still not positive
                definite after all jitter retries.
        """
        # NOTE: random draws happen in the same order as before the refactor
        # (context size, kernel, lengthscales, variance, noise, inputs,
        # latent normals, noise normals, permutation) so seeded runs reproduce.
        if num_context is None:
            num_context = self._sample_num_context(context_range)
        num_total = num_context + num_buffer + num_target

        # Sample kernel for entire batch (same kernel type within batch).
        kernel_idx = torch.multinomial(self.kernel_weights, 1).item()
        kernel_fn = KERNEL_DICT[self.kernel_list[kernel_idx]]

        # Per-function hyperparameters; lengthscales are per input dimension.
        lengthscales = self._sample_uniform(
            self.lengthscale_range, batch_size, self.x_dim
        )
        # Scale by sqrt(x_dim) as in the original implementation (presumably
        # because pairwise distances grow roughly like sqrt(x_dim)).
        lengthscales = lengthscales * (self.x_dim**0.5)
        sigma_f = torch.sqrt(self._sample_uniform(self.variance_range, batch_size))
        noises = self._sample_uniform(self.noise_range, batch_size)

        # One Sobol draw for the whole batch, split into B consecutive chunks.
        x_all = sample_sobol(
            batch_size * num_total, self.x_range[0], self.x_range[1]
        ).to(self.device, self.dtype)
        x = rearrange(x_all, "(b n) d -> b n d", b=batch_size, n=num_total)

        # Sample from the GP prior: y = L z with L L^T = K + jitter * I.
        K = self._compute_batch_kernel(x, lengthscales, sigma_f, kernel_fn)
        L = self._cholesky_with_jitter(K)
        z = torch.randn(batch_size, num_total, 1, device=self.device, dtype=self.dtype)
        y = torch.bmm(L, z)

        # Add i.i.d. Gaussian observation noise with per-function std.
        y = y + rearrange(noises, "b -> b 1 1") * torch.randn_like(y)

        # Random permutation shared by the whole batch for the split.
        perm = torch.randperm(num_total, device=self.device)
        ctx_idx = perm[:num_context]
        buf_idx = perm[num_context : num_context + num_buffer]
        tar_idx = perm[num_context + num_buffer :]

        return DataAttr(
            xc=x[:, ctx_idx],
            yc=y[:, ctx_idx],
            xb=x[:, buf_idx],
            yb=y[:, buf_idx],
            xt=x[:, tar_idx],
            yt=y[:, tar_idx],
        )
