from __future__ import annotations

from typing import Any, Dict, Optional

import torch

from phijax.torch.data import UniformSampler, MeshSampler
from phijax.torch.equations.base import IVP
from phijax.torch.equations.registry import register_pde


def get_dataset(ref_path: str):
    """Load the reference solution and coordinate grids from a MATLAB file.

    The file at ``ref_path`` is read with ``scipy.io.loadmat`` and must
    contain the variables ``"usol"``, ``"t"`` and ``"x"``.

    Args:
        ref_path: Path to the ``.mat`` file.

    Returns:
        A tuple ``(u_ref, t_star, x_star)`` where ``u_ref`` is the ``"usol"``
        array exactly as stored (expected layout ``(Nt, Nx)``) and
        ``t_star`` / ``x_star`` are the ``"t"`` / ``"x"`` arrays flattened
        to one dimension.

    Raises:
        KeyError: If one of the required variables is missing from the file.
    """
    # Imported lazily so scipy is only needed when a dataset is actually read.
    from scipy.io import loadmat

    mat = loadmat(ref_path)
    solution = mat["usol"]  # (Nt, Nx)
    # loadmat returns 2-D arrays even for vectors; flatten both grids.
    t_grid, x_grid = (mat[name].reshape(-1) for name in ("t", "x"))
    return solution, t_grid, x_grid


def _grad(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Differentiate ``y`` with respect to ``x`` using autograd.

    ``y`` is contracted with an all-ones tensor of its own shape, so for an
    element-wise relationship between ``x`` and ``y`` the result is the
    point-wise derivative. The graph is both created and retained, so the
    returned tensor can itself be differentiated and ``y`` can be
    differentiated again afterwards.

    Args:
        y: Output tensor that depends on ``x``.
        x: Input tensor that requires grad.

    Returns:
        The gradient tensor, with the same shape as ``x``.

    Raises:
        RuntimeError: If ``x`` was not used to compute ``y``
            (``allow_unused=False``).
    """
    seed = torch.ones_like(y)
    grads = torch.autograd.grad(
        outputs=y,
        inputs=x,
        grad_outputs=seed,
        create_graph=True,
        retain_graph=True,
        allow_unused=False,
    )
    # A single input always yields a one-element tuple.
    (dy_dx,) = grads
    return dy_dx


@register_pde("burgers", aliases=["burgers1d"])
class Burgers(IVP):
    """Burgers initial/boundary value problem for a neural-network solver.

    The PDE residual implemented by :meth:`r_net` is

        u_t + u * u_x - nu * u_xx

    evaluated on points ``(t, x)``. Reference data (solution and grids) are
    loaded from ``config.pde_config.ref_path`` and provide the initial
    condition, the domain bounds, and the points at which the initial and
    boundary residuals are evaluated.

    Registered under the name ``"burgers"`` (alias ``"burgers1d"``).
    """

    # Names of the residual terms returned by ``residuals``.
    loss_keys = ("ics", "bcs_l", "bcs_r", "res")

    def __init__(self, config: Any, model: torch.nn.Module, *, device: Optional[torch.device] = None):
        """Load the reference data and build the collocation sampler.

        Args:
            config: Experiment configuration. Reads ``pde_config.ref_path``,
                optional ``pde_config.nu`` / ``pde_config.visc``,
                ``training.batch_size`` and optional ``training.sampler``
                (``"uniform"`` or ``"fixed"``). If ``training.sampler`` is
                missing or ``None`` it is set to ``"uniform"`` on the config
                object itself.
            model: Network mapping stacked ``(t, x)`` inputs to the solution.
            device: Forwarded to the base class; ``self.device`` is then used
                for every tensor created here.

        Raises:
            ValueError: If ``training.sampler`` is not a known sampler name.
        """
        # The base class is expected to set self.config / self.model / self.device.
        super().__init__(config, model, device=device)

        pcfg = self.config.pde_config
        u_ref, t_star, x_star = get_dataset(pcfg.ref_path)

        self.u_ref = torch.as_tensor(u_ref, dtype=torch.float32, device=self.device)
        self.t_star = torch.as_tensor(t_star, dtype=torch.float32, device=self.device)
        self.x_star = torch.as_tensor(x_star, dtype=torch.float32, device=self.device)

        # Initial condition: first time slice of the reference solution.
        self.u0 = self.u_ref[0, :]  # (Nx,)

        # Domain bounds are taken from the first/last entries of the grids.
        self.t0 = float(self.t_star[0].item())
        self.t1 = float(self.t_star[-1].item())
        self.x0 = float(self.x_star[0].item())
        self.x1 = float(self.x_star[-1].item())

        # Rows are [min, max] for t and x respectively.
        self.dom = torch.tensor([[self.t0, self.t1], [self.x0, self.x1]], dtype=torch.float32, device=self.device)

        if getattr(self.config.training, "sampler", None) is None:
            self.config.training.sampler = "uniform"

        bs = int(self.config.training.batch_size)

        if self.config.training.sampler == "uniform":
            self.sampler = UniformSampler(self.dom, batch_size=bs, device=self.device)
        elif self.config.training.sampler == "fixed":
            self.sampler = MeshSampler(self.dom, res=[100, 200], batch_size=bs, device=self.device)
        else:
            raise ValueError(f"Unknown sampler: {self.config.training.sampler}")

        # Viscosity coefficient used by the PDE residual; overridable via
        # ``pde_config.nu``.
        self.nu = float(getattr(pcfg, "nu", 0.01 / torch.pi))
        # NOTE(review): ``_visc`` is not read anywhere in this class; it is
        # kept unchanged in case external code relies on the attribute.
        self._visc = float(getattr(pcfg, "visc", 0.01 / 3.141592653589793))

    def u_net(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the network at ``(t, x)`` and return its first output channel.

        ``t`` and ``x`` are stacked along a new last dimension before being
        passed to the model, so they must have matching shapes.
        """
        z = torch.stack([t, x], dim=-1)
        return self.model(z)[..., 0]

    def ux_net(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``du/dx`` of the network output at ``(t, x)``.

        ``x`` is cloned and detached, so the derivative is taken with respect
        to a fresh leaf tensor and does not propagate into whatever produced
        the incoming ``x``.
        """
        t_req = t if t.requires_grad else t.detach()
        x_req = x.clone().detach().requires_grad_(True)
        u = self.u_net(t_req, x_req)
        return _grad(u, x_req)

    def r_net(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return the PDE residual ``u_t + u * u_x - nu * u_xx`` at ``(t, x)``.

        Both inputs are cloned and detached into fresh leaf tensors so the
        derivatives are taken with respect to the coordinates only. The
        viscosity is ``self.nu``, so a configured ``pde_config.nu`` takes
        effect here.
        """
        t_req = t.clone().detach().requires_grad_(True)
        x_req = x.clone().detach().requires_grad_(True)

        u = self.u_net(t_req, x_req)
        u_t = _grad(u, t_req)
        u_x = _grad(u, x_req)
        # Reuse u_x rather than differentiating u a second time.
        u_xx = _grad(u_x, x_req)

        return u_t + u * u_x - self.nu * u_xx

    def residuals(self, batch: torch.Tensor, *args) -> Dict[str, torch.Tensor]:
        """Compute all residual terms for one training step.

        Args:
            batch: Collocation points; column 0 is read as ``t`` and
                column 1 as ``x``.
            *args: Accepted and ignored.

        Returns:
            A dict keyed by ``loss_keys``:

            * ``"ics"``: prediction at ``t0`` over ``x_star`` minus ``u0``.
            * ``"bcs_l"`` / ``"bcs_r"``: raw predictions at ``x0`` / ``x1``
              over ``t_star`` (i.e. the boundary target is zero).
            * ``"res"``: PDE residual on the batch points.
        """
        # Initial condition on the full spatial reference grid.
        x_ic = self.x_star
        t_ic = torch.full_like(x_ic, float(self.t0))
        u_pred_ic = self.u_net(t_ic, x_ic)

        # PDE residual on the sampled collocation points.
        t_b = batch[:, 0]
        x_b = batch[:, 1]
        r_pred = self.r_net(t_b, x_b)

        # Boundary values on the full temporal reference grid.
        t_bc = self.t_star
        x0 = torch.full_like(t_bc, float(self.x0))
        x1 = torch.full_like(t_bc, float(self.x1))
        u_bc_l = self.u_net(t_bc, x0)
        u_bc_r = self.u_net(t_bc, x1)

        return {
            "ics": u_pred_ic - self.u0,
            "bcs_l": u_bc_l,
            "bcs_r": u_bc_r,
            "res": r_pred,
        }
