from __future__ import annotations

import math
from typing import Any, Dict, Optional

import numpy as np
import torch

from phijax.torch.data import UniformSampler, MeshSampler
from phijax.torch.equations.base import IVP
from phijax.torch.equations.registry import register_pde


def get_dataset(ref_path: str):
    """Read the reference Allen–Cahn solution from a MATLAB ``.mat`` file.

    Args:
        ref_path: path handed directly to ``scipy.io.loadmat``.

    Returns:
        ``(u_ref, t_star, x_star)`` as NumPy arrays: the ``"usol"`` entry
        unchanged (expected layout ``(Nt, Nx)``; not checked here), followed by
        the ``"t"`` and ``"x"`` entries flattened to 1-D.

    Raises:
        KeyError: if any of ``"usol"``, ``"t"`` or ``"x"`` is missing.
    """
    # Imported lazily, so scipy is only needed when this function runs.
    from scipy.io import loadmat

    contents = loadmat(ref_path)
    solution = contents["usol"]
    times, coords = (contents[key].reshape(-1) for key in ("t", "x"))
    return solution, times, coords


def _grad(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    (g,) = torch.autograd.grad(
        y,
        x,
        grad_outputs=torch.ones_like(y),
        create_graph=True,
        retain_graph=True,
        allow_unused=False,
    )
    return g


@register_pde("allencahn", aliases=["ac", "allen-cahn"])
class AllenCahn(IVP):
    """1-D Allen–Cahn initial-value problem.

    The residual enforced at collocation points is::

        u_t + reaction_coef * (u**3 - u) - diffusion_coef * u_xx

    The initial condition is the first time slice of a reference solution
    loaded with ``get_dataset``. The sampling domain is the extent of that
    reference grid. No boundary residuals are produced by this class.
    """

    loss_keys = ("ics", "res")

    # Coefficients of the residual above. Override in a subclass to solve a
    # different instance of the equation.
    reaction_coef: float = 5.0
    diffusion_coef: float = 1e-4

    def __init__(self, config: Any, model: torch.nn.Module, *, device: Optional[torch.device] = None):
        """Load the reference solution and build the collocation sampler.

        Args:
            config: experiment config. Reads ``config.pde_config.ref_path``,
                ``config.training.batch_size``, ``config.training.sampler``
                (``"uniform"`` or ``"fixed"``; ``"uniform"`` when unset) and,
                for the ``"fixed"`` sampler, the optional
                ``config.training.mesh_res`` (default ``[100, 200]``).
            model: network whose input is ``(t, x)`` stacked on the last axis
                and whose output channel 0 is ``u``.
            device: forwarded to the ``IVP`` base class.

        Raises:
            ValueError: if ``config.training.sampler`` is not recognised.
        """
        super().__init__(config, model, device=device)
        pcfg = self.config.pde_config

        u_ref, t_star, x_star = get_dataset(pcfg.ref_path)

        self.u_ref = torch.as_tensor(u_ref, dtype=torch.float32, device=self.device)
        self.t_star = torch.as_tensor(t_star, dtype=torch.float32, device=self.device)
        self.x_star = torch.as_tensor(x_star, dtype=torch.float32, device=self.device)

        # Initial condition: first time row of the reference grid, shape (Nx,).
        self.u0 = self.u_ref[0, :]

        self.t0 = self.t_star[0].item()
        self.t1 = self.t_star[-1].item()
        self.x0 = self.x_star[0].item()
        self.x1 = self.x_star[-1].item()

        # Rectangular domain [[t0, t1], [x0, x1]] handed to the samplers.
        self.dom = torch.tensor([[self.t0, self.t1], [self.x0, self.x1]], dtype=torch.float32, device=self.device)

        # Default to the uniform sampler; the choice is written back into the config.
        if getattr(self.config.training, "sampler", None) is None:
            self.config.training.sampler = "uniform"

        bs = int(self.config.training.batch_size)
        sampler = self.config.training.sampler

        if sampler == "uniform":
            self.sampler = UniformSampler(self.dom, batch_size=bs, device=self.device)
        elif sampler == "fixed":
            mesh_res = getattr(self.config.training, "mesh_res", None)
            if mesh_res is None:
                mesh_res = [100, 200]
            self.sampler = MeshSampler(self.dom, res=list(mesh_res), batch_size=bs, device=self.device)
        else:
            raise ValueError(f"Unknown sampler: {self.config.training.sampler}")

    def u_net(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the network at ``(t, x)`` and return output channel 0.

        ``t`` and ``x`` must have the same shape. They are stacked along a new
        last axis to form the model input.
        """
        z = torch.stack([t, x], dim=-1)
        return self.model(z)[..., 0]

    def ux_net(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``du/dx`` at ``(t, x)``, still differentiable w.r.t. the model.

        ``x`` is replaced by a detached copy that requires grad, so the
        derivative is taken w.r.t. a fresh leaf. ``t`` is used as given.
        """
        x_req = x.clone().detach().requires_grad_(True)
        u = self.u_net(t, x_req)
        return _grad(u, x_req)

    def r_net(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Allen–Cahn PDE residual at collocation points ``(t, x)``.

        The inputs are replaced by detached copies that require grad. The
        returned residual stays connected to the model parameters, so it can be
        used directly in a loss.
        """
        t_req = t.clone().detach().requires_grad_(True)
        x_req = x.clone().detach().requires_grad_(True)

        u = self.u_net(t_req, x_req)
        # One backward pass yields both first derivatives. create_graph=True
        # (which also retains the graph) lets u_x be differentiated again below
        # and keeps the residual trainable.
        u_t, u_x = torch.autograd.grad(
            u,
            (t_req, x_req),
            grad_outputs=torch.ones_like(u),
            create_graph=True,
        )
        u_xx = _grad(u_x, x_req)

        return (
            u_t
            + self.reaction_coef * u**3
            - self.reaction_coef * u
            - self.diffusion_coef * u_xx
        )

    def residuals(self, batch: torch.Tensor, *args) -> Dict[str, torch.Tensor]:
        """Compute the loss residuals.

        Args:
            batch: collocation points, read as ``batch[:, 0]`` = ``t`` and
                ``batch[:, 1]`` = ``x``.
            *args: ignored.

        Returns:
            ``"ics"``: network prediction minus the reference initial condition
            at ``t = t0`` on the reference ``x`` grid.
            ``"res"``: PDE residual at the batch points.
        """
        x_ic = self.x_star
        t_ic = torch.full_like(x_ic, float(self.t0))
        u_pred_ic = self.u_net(t_ic, x_ic)

        r_pred = self.r_net(batch[:, 0], batch[:, 1])

        return {
            "ics": u_pred_ic - self.u0,
            "res": r_pred,
        }


@register_pde("soft_allencahn", aliases=["soft_ac", "sac"])
class SoftAllenCahn(AllenCahn):
    """Allen–Cahn with soft (penalised) periodic boundary conditions in ``x``.

    On top of the parent's ``"ics"`` and ``"res"`` residuals, this adds two
    boundary residuals evaluated on the reference time grid ``t_star``:
    ``u(t, x0) - u(t, x1)`` and ``u_x(t, x0) - u_x(t, x1)``.
    """

    loss_keys = ("ics", "bcs_u", "bcs_v", "res")

    def _u_and_ux(self, t: torch.Tensor, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(u, du/dx)`` at ``(t, x)`` from a single forward pass.

        ``x`` is replaced by a detached copy that requires grad. Both outputs
        remain differentiable w.r.t. the model parameters.
        """
        x_req = x.clone().detach().requires_grad_(True)
        u = self.u_net(t, x_req)
        return u, _grad(u, x_req)

    def residuals(self, batch: torch.Tensor, *args) -> Dict[str, torch.Tensor]:
        """Compute IC, periodic-BC and PDE residuals.

        Args:
            batch: collocation points, passed to the parent's ``residuals``.
            *args: forwarded to the parent's ``residuals``.

        Returns:
            ``"ics"`` and ``"res"`` as produced by ``AllenCahn.residuals``,
            plus ``"bcs_u"`` (left minus right boundary value of ``u``) and
            ``"bcs_v"`` (left minus right boundary value of ``u_x``).
        """
        base = super().residuals(batch, *args)

        t_bc = self.t_star
        u_left, ux_left = self._u_and_ux(t_bc, torch.full_like(t_bc, float(self.x0)))
        u_right, ux_right = self._u_and_ux(t_bc, torch.full_like(t_bc, float(self.x1)))

        return {
            "ics": base["ics"],
            "bcs_u": u_left - u_right,
            "bcs_v": ux_left - ux_right,
            "res": base["res"],
        }
