from __future__ import annotations

from typing import Any, Dict, Optional

import torch

from phijax.torch.data import UniformSampler, MeshSampler
from phijax.torch.equations.base import IVP
from phijax.torch.equations.registry import register_pde


def _grad(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Return d(sum(y))/dx as a tensor shaped like ``x``.

    ``y`` is contracted against an all-ones tensor of its own shape, which for
    an elementwise relationship between ``x`` and ``y`` gives the pointwise
    derivative. The graph is both retained and extended (``create_graph``),
    so the result can itself be differentiated or used in a loss.

    Raises:
        RuntimeError: if ``x`` did not take part in computing ``y``
            (``allow_unused`` is disabled).
    """
    seed = torch.ones_like(y)
    (dy_dx,) = torch.autograd.grad(
        outputs=y,
        inputs=x,
        grad_outputs=seed,
        retain_graph=True,
        create_graph=True,
        allow_unused=False,
    )
    return dy_dx


@register_pde(
    "reaction",
    aliases=["reaction1d"],
    defaults={"epsilon": 5.0, "num_points_per_dim": 256},
)
class Reaction1D(IVP):
    """1D reaction problem ``u_t = epsilon * u * (1 - u)`` on ``[0, 1] x [0, 2*pi]``.

    The constructor builds a ``(t, x)`` grid, a Gaussian initial profile
    ``h(x) = exp(-(x - pi)^2 / (2 * sigma^2))`` with ``sigma = pi / 4``, and the
    reference field ``u_ref = h * e^{eps t} / (h * (e^{eps t} - 1) + 1)``.

    ``residuals`` returns three terms keyed by ``loss_keys``: the mismatch to
    the initial profile, the difference between the values at the two ends of
    the ``x`` interval, and the PDE residual on the sampled batch.
    """

    loss_keys = ("ics", "bcs", "res")

    def __init__(self, config: Any, model: torch.nn.Module, *, device: Optional[torch.device] = None):
        """Build the grid, reference solution, domain bounds and sampler.

        Args:
            config: Configuration object; ``config.pde_config`` and
                ``config.training`` (``sampler``, ``batch_size``) are read here.
            model: Network called on stacked ``(t, x)`` inputs.
            device: Optional device forwarded to the base class.

        Raises:
            ValueError: if ``config.training.sampler`` is neither
                ``"uniform"`` nor ``"fixed"``.
        """
        super().__init__(config, model, device=device)

        pcfg = self.config.pde_config
        # Missing, None or zero settings fall back to the literals given here.
        num_pts = int(getattr(pcfg, "num_points_per_dim", 101) or 101)
        self.epsilon = float(getattr(pcfg, "epsilon", 5.0) or 5.0)

        x = torch.linspace(0.0, 2.0 * torch.pi, num_pts, device=self.device)
        t = torch.linspace(0.0, 1.0, num_pts, device=self.device)
        # "ij" indexing: first axis follows t, second axis follows x.
        tt, xx = torch.meshgrid(t, x, indexing="ij")

        # torch.pi is a plain Python float, so sigma is a float as well and has
        # no Tensor methods such as .square(); square it with ** instead.
        sigma = torch.pi / 4.0
        two_sigma_sq = 2.0 * sigma ** 2

        h = torch.exp(-1.0 * (xx - torch.pi).square() / two_sigma_sq)
        exp_et = torch.exp(self.epsilon * tt)
        u_ref = h * exp_et / (h * (exp_et - 1.0) + 1.0)

        # Initial profile on the 1D x grid (same Gaussian as h, without the t axis).
        u0 = torch.exp(-1.0 * (x - torch.pi).square() / two_sigma_sq)

        self.u_ref = u_ref
        self.u0 = u0
        self.t_star = t
        self.x_star = x

        # Domain bounds as Python floats, read back from the grid end points.
        self.t0 = float(self.t_star[0].item())
        self.t1 = float(self.t_star[-1].item())
        self.x0 = float(self.x_star[0].item())
        self.x1 = float(self.x_star[-1].item())

        # Row 0 holds the t bounds, row 1 the x bounds.
        self.dom = torch.tensor([[self.t0, self.t1], [self.x0, self.x1]], dtype=torch.float32, device=self.device)

        # Default to the uniform sampler when none is configured; note this
        # writes the choice back into the shared config object.
        if getattr(self.config.training, "sampler", None) is None:
            self.config.training.sampler = "uniform"

        bs = int(self.config.training.batch_size)

        if self.config.training.sampler == "uniform":
            self.sampler = UniformSampler(self.dom, batch_size=bs, device=self.device)
        elif self.config.training.sampler == "fixed":
            self.sampler = MeshSampler(self.dom, res=[num_pts, num_pts], batch_size=bs, device=self.device)
        else:
            raise ValueError(f"Unknown sampler: {self.config.training.sampler}")

    def u_net(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the model at ``(t, x)`` and return its first output channel.

        ``t`` and ``x`` are stacked along a new trailing axis before being
        passed to the model, so they must have matching shapes.
        """
        z = torch.stack([t, x], dim=-1)
        return self.model(z)[..., 0]

    def r_net(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return the PDE residual ``u_t - epsilon * u * (1 - u)`` at ``(t, x)``."""
        # Fresh leaf copies: cut any upstream graph and enable differentiation
        # with respect to the inputs without mutating the caller's tensors.
        t_req = t.clone().detach().requires_grad_(True)
        x_req = x.clone().detach().requires_grad_(True)
        u = self.u_net(t_req, x_req)
        u_t = _grad(u, t_req)
        return u_t - self.epsilon * u * (1.0 - u)

    def residuals(self, batch: torch.Tensor, *args) -> Dict[str, torch.Tensor]:
        """Compute the initial-condition, boundary and PDE residuals.

        Args:
            batch: Collocation points; column 0 is read as ``t`` and column 1
                as ``x``.
            *args: Accepted but not used by this method.

        Returns:
            Dict with keys ``"ics"``, ``"bcs"`` and ``"res"``.
        """
        # Initial condition: model at t = t0 on the full x grid.
        x_ic = self.x_star
        t_ic = torch.full_like(x_ic, float(self.t0))
        u_ic = self.u_net(t_ic, x_ic)

        # Boundary term: model at both ends of the x interval for every grid time.
        t_bc = self.t_star
        x_left = torch.full_like(t_bc, float(self.x0))
        x_right = torch.full_like(t_bc, float(self.x1))
        u_bc_left = self.u_net(t_bc, x_left)
        u_bc_right = self.u_net(t_bc, x_right)

        t_b = batch[:, 0]
        x_b = batch[:, 1]
        r_pred = self.r_net(t_b, x_b)

        return {
            "ics": u_ic - self.u0,
            "bcs": u_bc_left - u_bc_right,
            "res": r_pred,
        }
