from __future__ import annotations

import math
from typing import Any, Dict, Optional

import torch
from torch import nn

from phijax.torch.data import UniformSampler, MeshSampler
from phijax.torch.equations.base import IVP
from phijax.torch.equations.registry import register_pde



def _grad(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    (g,) = torch.autograd.grad(
        y,
        x,
        grad_outputs=torch.ones_like(y),
        create_graph=True,
        retain_graph=True,
        allow_unused=False,
    )
    return g

@register_pde(
    "convection", aliases=["1d_convection", "conv"],
    defaults={"epsilon": 50.0, "num_points_per_dim": 256}
)
class Convection(IVP):
    """1D convection problem ``u_t + epsilon * u_x = 0``.

    The problem is set on ``t in [0, 1]`` and ``x in [0, 2*pi]`` with the
    initial condition ``u(0, x) = sin(x)``. The two spatial edges are tied
    together by penalising ``u(t, x0) - u(t, x1)``. The stored reference
    solution is ``sin(x - epsilon * t)`` on a ``(num_pts, num_pts)`` grid
    indexed as ``[t, x]``.

    Configuration read from ``config.pde_config``:
        epsilon: Convection speed. Defaults to 50.0 when missing or ``None``.
        num_points_per_dim: Number of grid points per axis. Defaults to 256
            when missing, ``None`` or 0.

    Configuration read from ``config.training``:
        sampler: ``"uniform"`` (default when missing/``None``) or ``"fixed"``.
        batch_size: Forwarded to the sampler.
    """

    def __init__(self, config: Any, model: nn.Module, device: Optional[torch.device] = None):
        """Build the reference grids, the domain bounds and the sampler.

        Args:
            config: Experiment configuration exposing ``pde_config`` and
                ``training`` sections.
            model: Network mapping stacked ``(t, x)`` inputs to ``u``.
            device: Forwarded to the base class; ``self.device`` is used for
                every tensor created here.

        Raises:
            ValueError: If ``config.training.sampler`` is neither
                ``"uniform"`` nor ``"fixed"``.
        """
        super().__init__(config, model, device=device)

        pcfg = self.config.pde_config

        # A zero-point grid is unusable, so any falsy value (None, 0) falls
        # back to the default resolution.
        num_pts = int(getattr(pcfg, "num_points_per_dim", 256) or 256)

        # Only a missing/None epsilon falls back to the default. An explicit
        # 0.0 is a valid speed and must not be replaced by 50.0.
        eps_cfg = getattr(pcfg, "epsilon", None)
        epsilon = 50.0 if eps_cfg is None else float(eps_cfg)
        self.epsilon = epsilon

        x = torch.linspace(0.0, 2.0 * math.pi, num_pts, device=self.device)
        t = torch.linspace(0.0, 1.0, num_pts, device=self.device)
        # "ij" indexing: first axis follows t, second axis follows x.
        tt, xx = torch.meshgrid(t, x, indexing="ij")

        # Reference solution on the full (t, x) grid and the initial profile.
        self.u_ref = torch.sin(xx - epsilon * tt)
        self.u0 = torch.sin(x)

        self.t_star = t
        self.x_star = x
        # Domain endpoints are kept as 0-dim tensors so they can be expanded
        # against the grid lines in ``residuals``.
        self.t0, self.t1 = t[0], t[-1]
        self.x0, self.x1 = x[0], x[-1]

        # Rows are [t_min, t_max] and [x_min, x_max].
        self.dom = torch.tensor([[self.t0, self.t1], [self.x0, self.x1]], device=self.device)

        # The resolved default is written back onto the config.
        if getattr(self.config.training, "sampler", None) is None:
            self.config.training.sampler = "uniform"

        if self.config.training.sampler == "uniform":
            self.sampler = UniformSampler(self.dom, batch_size=self.config.training.batch_size, device=self.device)
            print("Using random/uniform sampler")
        elif self.config.training.sampler == "fixed":
            # NOTE(review): the mesh resolution is hard-coded to 51x51 while
            # the message below reports num_pts (the reference-grid size);
            # confirm which of the two is intended.
            self.sampler = MeshSampler(self.dom, res=[51, 51], batch_size=self.config.training.batch_size, device=self.device)
            print("Using fixed mesh sampler", num_pts)
        else:
            raise ValueError(f"Unknown sampler: {self.config.training.sampler}")

    def u_net(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the model at ``(t, x)`` and return its first output channel.

        ``t`` and ``x`` are stacked along a new trailing axis before being
        passed to the model; index 0 of the model output's last axis is
        returned.
        """
        z = torch.stack([t, x], dim=-1)
        out = self.model(z)
        return out[..., 0]

    def r_net(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return the PDE residual ``u_t + epsilon * u_x`` at ``(t, x)``.

        The inputs are detached and re-marked as requiring grad so the
        derivatives are taken with respect to fresh leaf tensors.
        """
        t_req = t.detach().requires_grad_(True)
        x_req = x.detach().requires_grad_(True)
        u = self.u_net(t_req, x_req)
        u_t = _grad(u, t_req)
        u_x = _grad(u, x_req)
        return u_t + self.epsilon * u_x

    def residuals(self, batch: torch.Tensor, *args) -> Dict[str, torch.Tensor]:
        """Compute the initial-condition, boundary and PDE residuals.

        Args:
            batch: Collocation points; column 0 is read as ``t`` and column 1
                as ``x``.
            *args: Accepted but not used here.

        Returns:
            A dict with:
                ``"ics"``: ``u(t0, x) - sin(x)`` along the stored ``x`` grid.
                ``"bcs"``: ``u(t, x0) - u(t, x1)`` along the stored ``t`` grid.
                ``"res"``: PDE residual at the batch points.
        """
        x_line = self.x_star
        t_line = self.t_star

        # Initial and boundary terms are evaluated on the full grid lines,
        # independently of the sampled batch.
        u_ic = self.u_net(self.t0.expand_as(x_line), x_line)
        u_bc1 = self.u_net(t_line, self.x0.expand_as(t_line))
        u_bc2 = self.u_net(t_line, self.x1.expand_as(t_line))

        t_b = batch[:, 0]
        x_b = batch[:, 1]
        r = self.r_net(t_b, x_b)

        return {
            "ics": u_ic - self.u0,
            "bcs": u_bc1 - u_bc2,
            "res": r,
        }

@register_pde(
    "spectral_convection",
    aliases=["sconv", "spectral_conv"],
    defaults={"epsilon": 50.0, "num_points_per_dim": 256},
)
class SpectralConvection(Convection):
    """Convection problem with an extra ``"spec"`` term computed from model features.

    The additional term is produced by ``spectral_reg`` on the features the
    model exposes for the current batch.
    """

    loss_keys = ("ics", "bcs", "res", "spec")

    def _features(self, z: torch.Tensor) -> torch.Tensor:
        """Return the model's feature tensor for input ``z``.

        Accessors are tried in this order: a ``features`` method, a
        ``forward_features`` method, then a forward call with
        ``return_features=True`` (used when the model has a
        ``return_features`` attribute; the first of the two returned values
        is taken).

        Raises:
            AttributeError: If the model offers none of these.
        """
        net = self.model
        for accessor in ("features", "forward_features"):
            if hasattr(net, accessor):
                return getattr(net, accessor)(z)
        if hasattr(net, "return_features"):
            features, _ = net(z, return_features=True)
            return features
        raise AttributeError("Model must expose features for spectral regularization.")

    def spectral_reg(self, Z: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
        """Return ``sum(C * C) / (trace(C)**2 + eps)`` for the feature matrix ``C``.

        Each feature row is divided by its L2 norm (plus ``eps``), and
        ``C = F^T F / n`` with ``n`` the number of rows.

        Args:
            Z: Model input passed to ``_features``.
            eps: Small constant guarding both divisions.

        Returns:
            A 0-dim tensor.
        """
        # assumes the features are 2-D (rows, feature_dim) — TODO confirm
        feats = self._features(Z)
        row_norms = torch.linalg.norm(feats, dim=1, keepdim=True)
        unit_rows = feats / (row_norms + eps)
        num_rows = unit_rows.shape[0]
        gram = torch.matmul(unit_rows.T, unit_rows) / float(num_rows)
        trace = torch.trace(gram)
        squared_sum = (gram * gram).sum()
        return squared_sum / (trace * trace + eps)

    def residuals(self, batch: torch.Tensor, *args) -> Dict[str, torch.Tensor]:
        """Extend the parent residuals with a one-element ``"spec"`` entry.

        The batch itself is used as the model input for the feature
        computation.
        """
        terms = super().residuals(batch, *args)
        penalty = self.spectral_reg(batch)
        terms["spec"] = penalty.reshape(1)
        return terms
