from __future__ import annotations

from typing import Any, Dict, Optional

import torch

from phijax.torch.data import UniformSampler, MeshSampler
from phijax.torch.equations.base import IVP
from phijax.torch.equations.registry import register_pde


def _grad(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Differentiate ``y`` with respect to ``x`` using reverse-mode autograd.

    ``y`` is contracted with an all-ones cotangent, so when each entry of ``y``
    depends only on the matching entry of ``x`` the result is the point-wise
    derivative dy/dx. The graph is built (``create_graph=True``) so the result
    can be differentiated again, and retained so ``y`` stays usable afterwards.

    Raises:
        RuntimeError: (from autograd) if ``x`` is not part of the graph of ``y``.
    """
    cotangent = torch.ones_like(y)
    grads = torch.autograd.grad(
        outputs=y,
        inputs=x,
        grad_outputs=cotangent,
        retain_graph=True,
        create_graph=True,
        allow_unused=False,
    )
    return grads[0]


@register_pde(
    "wave",
    aliases=["wave1d", "1d_wave"],
    defaults={"epsilon": 3.0, "num_points_per_dim": 256},
)
class Wave1D(IVP):
    """One-dimensional wave equation ``u_tt - 4 u_xx = 0`` on ``[0, 1] x [0, 1]``.

    The reference solution stored in ``self.u_ref`` is::

        u(t, x) = sin(pi x) cos(2 pi t) + 0.5 sin(eps pi x) cos(2 pi eps t)

    It satisfies the PDE for any ``eps`` (wave speed 2, hence the factor 4 in
    the residual). The initial displacement, the initial velocity (zero for
    this solution) and the Dirichlet boundary values used as training targets
    are all taken from it.

    Loss terms returned by :meth:`residuals` (keys listed in ``loss_keys``):
        ``ics_u`` -- u(t0, x) minus the initial displacement on the x-grid.
        ``ics_v`` -- u_t(t0, x) on the x-grid (target velocity is zero).
        ``bcs_r`` -- u(t, x1) minus the reference value on the right edge.
        ``bcs_l`` -- u(t, x0) minus the reference value on the left edge.
        ``res``   -- PDE residual at the sampled collocation points.
    """

    loss_keys = ("ics_u", "ics_v", "bcs_r", "bcs_l", "res")

    def __init__(self, config: Any, model: torch.nn.Module, *, device: Optional[torch.device] = None):
        """Build the reference grid/solution and the collocation sampler.

        Reads ``config.pde_config.num_points_per_dim`` and
        ``config.pde_config.epsilon`` (falling back to 256 / 3.0), plus
        ``config.training.batch_size`` and ``config.training.sampler``. If no
        sampler is configured, ``config.training.sampler`` is set to
        ``"uniform"`` in place.

        Raises:
            ValueError: if ``config.training.sampler`` is neither ``"uniform"``
                nor ``"fixed"``.
        """
        super().__init__(config, model, device=device)

        pcfg = self.config.pde_config
        # A missing, None or zero grid size falls back to the default.
        num_pts = int(getattr(pcfg, "num_points_per_dim", 256) or 256)
        # Only a missing/None epsilon falls back to the default; an explicit
        # 0.0 is a valid (single-mode) choice and must not be overridden.
        eps_cfg = getattr(pcfg, "epsilon", None)
        self.epsilon = 3.0 if eps_cfg is None else float(eps_cfg)

        x = torch.linspace(0.0, 1.0, num_pts, device=self.device)
        t = torch.linspace(0.0, 1.0, num_pts, device=self.device)
        # "ij" indexing: rows follow t, columns follow x, i.e. u_ref[i, j] = u(t[i], x[j]).
        tt, xx = torch.meshgrid(t, x, indexing="ij")

        u_ref = torch.sin(torch.pi * xx) * torch.cos(2.0 * torch.pi * tt) + 0.5 * torch.sin(
            self.epsilon * torch.pi * xx
        ) * torch.cos(2.0 * torch.pi * self.epsilon * tt)

        # Initial displacement u(0, x): the reference solution with cos(0) = 1.
        u0 = torch.sin(torch.pi * x) + 0.5 * torch.sin(self.epsilon * torch.pi * x)

        self.u_ref = u_ref
        self.u0 = u0
        self.t_star = t
        self.x_star = x
        # Dirichlet data along the spatial edges, indexed like t_star. The left
        # edge is exactly zero; the right edge is only (numerically) zero when
        # epsilon is an integer, so the targets must come from u_ref.
        self.u_left_ref = u_ref[:, 0]
        self.u_right_ref = u_ref[:, -1]

        self.t0 = float(self.t_star[0].item())
        self.t1 = float(self.t_star[-1].item())
        self.x0 = float(self.x_star[0].item())
        self.x1 = float(self.x_star[-1].item())

        # Domain bounds as [[t_min, t_max], [x_min, x_max]]; column 0 of sampled
        # batches is therefore t and column 1 is x.
        self.dom = torch.tensor([[self.t0, self.t1], [self.x0, self.x1]], dtype=torch.float32, device=self.device)

        if getattr(self.config.training, "sampler", None) is None:
            self.config.training.sampler = "uniform"

        bs = int(self.config.training.batch_size)

        if self.config.training.sampler == "uniform":
            self.sampler = UniformSampler(self.dom, batch_size=bs, device=self.device)
        elif self.config.training.sampler == "fixed":
            self.sampler = MeshSampler(self.dom, res=[51, 51], batch_size=bs, device=self.device)
        else:
            raise ValueError(f"Unknown sampler: {self.config.training.sampler}")

    def u_net(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the network at points ``(t, x)`` and return its first output channel.

        ``t`` and ``x`` are stacked along a new last axis before the model call.
        NOTE(review): assumes the model maps ``(..., 2)`` inputs to
        ``(..., k)`` outputs with k >= 1 -- confirm against the model.
        """
        z = torch.stack([t, x], dim=-1)
        return self.model(z)[..., 0]

    def u_t(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return du/dt of the network at ``(t, x)``.

        The inputs are detached copies, so gradients still flow to the model
        parameters but not back into ``t``/``x``.
        """
        t_req = t.clone().detach().requires_grad_(True)
        x_req = x.clone().detach().requires_grad_(True)
        u = self.u_net(t_req, x_req)
        return _grad(u, t_req)

    def r_net(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return the PDE residual ``u_tt - 4 u_xx`` of the network at ``(t, x)``."""
        t_req = t.clone().detach().requires_grad_(True)
        x_req = x.clone().detach().requires_grad_(True)

        u = self.u_net(t_req, x_req)
        u_t = _grad(u, t_req)
        u_tt = _grad(u_t, t_req)

        u_x = _grad(u, x_req)
        u_xx = _grad(u_x, x_req)

        return u_tt - 4.0 * u_xx

    def residuals(self, batch: torch.Tensor, *args) -> Dict[str, torch.Tensor]:
        """Compute all point-wise residuals, keyed as in ``loss_keys``.

        Args:
            batch: collocation points whose column 0 is t and column 1 is x
                (presumably shape ``(N, 2)`` from ``self.sampler`` -- confirm).
            *args: accepted for interface compatibility and ignored.

        Returns:
            Mapping from loss key to an unreduced residual tensor.
        """
        # Initial conditions at t = t0 on the reference x-grid. A single forward
        # pass with grad-enabled t gives both u(t0, x) and u_t(t0, x).
        t_ic = torch.full_like(self.x_star, self.t0, requires_grad=True)
        u_ic = self.u_net(t_ic, self.x_star)
        v_ic = _grad(u_ic, t_ic)

        # Dirichlet boundaries at x = x0 (left) and x = x1 (right) over t_star.
        t_bc = self.t_star
        u_bc_left = self.u_net(t_bc, torch.full_like(t_bc, self.x0))
        u_bc_right = self.u_net(t_bc, torch.full_like(t_bc, self.x1))

        r_pred = self.r_net(batch[:, 0], batch[:, 1])

        return {
            "ics_u": u_ic - self.u0,
            "ics_v": v_ic,
            "bcs_r": u_bc_right - self.u_right_ref,
            "bcs_l": u_bc_left - self.u_left_ref,
            "res": r_pred,
        }
