from __future__ import annotations

from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch

from phijax.torch.data import UniformSampler
from phijax.torch.equations.base import IVP
from phijax.torch.equations.registry import register_pde


def get_dataset(
    ref_path: str,
    fraction: Tuple[float, float] = (0.0, 1.0),
) -> Tuple[
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    Optional[float],
    Optional[float],
]:
    """Load a reference solution from a ``.mat`` file and cut out a time window.

    Args:
        ref_path: Path to a MATLAB file holding ``usol``, ``vsol``, ``t``,
            ``x`` and ``y`` (and optionally the scalars ``eps`` and ``k``).
        fraction: ``(start, end)`` fractions of the stored time steps that
            select the window. The resulting indices are clamped so that the
            window always lies inside the stored steps and contains at least
            one step.

    Returns:
        ``(u_ref, v_ref, t_star, x_star, y_star, eps, k)``. ``u_ref`` and
        ``v_ref`` are the windowed solution arrays, the ``*_star`` arrays are
        flattened coordinates, and ``eps`` / ``k`` are floats, or ``None``
        when the file does not contain them.

    Note:
        ``t_star`` is taken from the *beginning* of the stored time axis
        (``t[:num]``), not from ``t[start:end]``, so every window is returned
        on the first ``num`` time values of the file.
        NOTE(review): this looks like deliberate re-basing of each window's
        clock -- confirm against the callers before changing it.
    """
    import scipy.io

    data = scipy.io.loadmat(ref_path)

    u_ref = data["usol"]  # (Nt, Nx, Ny)
    v_ref = data["vsol"]  # (Nt, Nx, Ny)
    t_star = data["t"].reshape(-1)
    x_star = data["x"].reshape(-1)
    y_star = data["y"].reshape(-1)

    n_steps = len(t_star)
    start = int(float(fraction[0]) * n_steps)
    end = int(float(fraction[1]) * n_steps)

    # Clamp to the stored range. Without this, an out-of-range fraction makes
    # the solution slice shorter than (or empty relative to) the time slice
    # below, and the mismatch only surfaces later as a reshape/index error.
    start = min(max(start, 0), max(n_steps - 1, 0))
    end = min(max(end, start + 1), max(n_steps, start + 1))
    num = end - start

    u_ref = u_ref[start:end, :, :]
    v_ref = v_ref[start:end, :, :]
    t_star = t_star[:num]

    def _optional_scalar(name: str) -> Optional[float]:
        # loadmat returns scalars as small arrays; take the first element.
        if name not in data:
            return None
        return float(np.asarray(data[name]).reshape(-1)[0])

    eps = _optional_scalar("eps")
    k = _optional_scalar("k")
    return u_ref, v_ref, t_star, x_star, y_star, eps, k


def _grad(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Return d(sum(y))/dx, keeping the graph so it can be differentiated again.

    An all-ones ``grad_outputs`` is used, so when each entry of ``y`` depends
    only on the matching entry of ``x`` the result is the pointwise
    derivative. ``x`` must have taken part in computing ``y``
    (``allow_unused=False``).
    """
    seed = torch.ones_like(y)
    gradients = torch.autograd.grad(
        outputs=y,
        inputs=x,
        grad_outputs=seed,
        retain_graph=True,
        create_graph=True,
        allow_unused=False,
    )
    (derivative,) = gradients
    return derivative


def _laplace_scalar(f: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return ``f_xx + f_yy``, the second derivatives of ``f`` in ``x`` and ``y``.

    Each second derivative is obtained by applying ``_grad`` twice with
    respect to the same coordinate tensor.
    """
    d2f_dx2, d2f_dy2 = [_grad(_grad(f, coord), coord) for coord in (x, y)]
    return d2f_dx2 + d2f_dy2


@register_pde("ginzburg_landau", aliases=["gl", "ginzburg", "ginzburglandau"])
class GinzburgLandau(IVP):
    """Initial-value problem for a two-component Ginzburg-Landau system.

    The model maps ``(t, x, y)`` to two channels ``(u, v)``. With
    ``r2 = u**2 + v**2`` the PDE residuals enforced by :meth:`r_net` are::

        ru = u_t - eps * (u_xx + u_yy) - k * (u - u * r2 + 1.5 * v * r2)
        rv = v_t - eps * (v_xx + v_yy) - k * (v - v * r2 - 1.5 * u * r2)

    which is ``A_t = eps * lap(A) + k * (A - (1 + 1.5i) |A|^2 A)`` written for
    ``A = u + i v``. The initial condition is the first frame of the
    reference solution, matched on the reference ``(x, y)`` grid.

    Relies on the ``IVP`` base class to provide ``self.config``,
    ``self.model`` and ``self.device``.
    """

    loss_keys = ("uics", "vics", "ru", "rv")

    def __init__(self, config: Any, model: torch.nn.Module, *, device: Optional[torch.device] = None):
        """Load the reference data and set up grids, coefficients and sampler.

        Args:
            config: Experiment configuration. Reads ``pde_config.ref_path``,
                optional ``pde_config.data_fraction`` / ``eps`` / ``k``, and
                ``training.batch_size``.
            model: Network called on a ``(..., 3)`` input and indexed on its
                last axis for the ``u`` and ``v`` channels.
            device: Forwarded to the base class.
        """
        super().__init__(config, model, device=device)

        pcfg = self.config.pde_config
        fraction = getattr(pcfg, "data_fraction", (0.0, 1.0))

        u_ref, v_ref, t_star, x_star, y_star, eps_data, k_data = get_dataset(
            pcfg.ref_path, fraction=fraction
        )

        self.x_star = torch.as_tensor(x_star, dtype=torch.float32, device=self.device)
        self.y_star = torch.as_tensor(y_star, dtype=torch.float32, device=self.device)
        self.t_star = torch.as_tensor(t_star, dtype=torch.float32, device=self.device)

        self.u_ref = torch.as_tensor(u_ref, dtype=torch.float32, device=self.device)
        self.v_ref = torch.as_tensor(v_ref, dtype=torch.float32, device=self.device)

        # Precedence for the coefficients: config value, then the value stored
        # in the data file, then a hard-coded fallback.
        self.eps = float(getattr(pcfg, "eps", eps_data if eps_data is not None else 1e-2))
        self.k = float(getattr(pcfg, "k", k_data if k_data is not None else 1.0))

        self.t0 = float(self.t_star[0].item())
        self.t1 = float(self.t_star[-1].item())

        self.u0 = self.u_ref[0, ...]  # (Nx, Ny)
        self.v0 = self.v_ref[0, ...]  # (Nx, Ny)

        x0, x1 = float(self.x_star[0].item()), float(self.x_star[-1].item())
        y0, y1 = float(self.y_star[0].item()), float(self.y_star[-1].item())

        # Rows are [min, max] for t, x and y, in that order.
        self.dom = torch.tensor([[self.t0, self.t1], [x0, x1], [y0, y1]], dtype=torch.float32, device=self.device)

        bs = int(self.config.training.batch_size)
        self.sampler = UniformSampler(self.dom, batch_size=bs, device=self.device)

        # Time is divided by the last time value before entering the network;
        # fall back to 1.0 so a zero end time cannot cause a division by zero.
        self._t_scale = self.t1 if self.t1 != 0.0 else 1.0

    def neural_net(self, t: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate the model at ``(t, x, y)`` and return ``(u, v)``.

        ``t`` is divided by ``self._t_scale``; the three coordinates are
        stacked on a new last axis and channels 0 and 1 of the model output
        are returned.
        """
        t_scaled = t / self._t_scale
        z = torch.stack([t_scaled, x, y], dim=-1)
        out = self.model(z)
        return out[..., 0], out[..., 1]

    def u_net(self, t: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return only the ``u`` channel of :meth:`neural_net`."""
        u, _ = self.neural_net(t, x, y)
        return u

    def v_net(self, t: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return only the ``v`` channel of :meth:`neural_net`."""
        _, v = self.neural_net(t, x, y)
        return v

    def r_net(self, t: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the PDE residuals ``(ru, rv)`` at the given points.

        The inputs are cloned and detached into fresh leaf tensors so that
        derivatives are taken with respect to the coordinates themselves.
        """
        t_req = t.clone().detach().requires_grad_(True)
        x_req = x.clone().detach().requires_grad_(True)
        y_req = y.clone().detach().requires_grad_(True)

        u, v = self.neural_net(t_req, x_req, y_req)

        u_t = _grad(u, t_req)
        v_t = _grad(v, t_req)

        u_lap = _laplace_scalar(u, x_req, y_req)
        v_lap = _laplace_scalar(v, x_req, y_req)

        r2 = u * u + v * v

        ru = u_t - self.eps * u_lap - self.k * (u - u * r2 + 1.5 * v * r2)
        rv = v_t - self.eps * v_lap - self.k * (v - v * r2 - 1.5 * u * r2)
        return ru, rv

    def residuals(self, batch: torch.Tensor, *args) -> Dict[str, torch.Tensor]:
        """Return the initial-condition and PDE residuals for one batch.

        Args:
            batch: Collocation points; columns 0, 1, 2 are read as ``t``,
                ``x`` and ``y``.
            *args: Ignored.

        Returns:
            Dict keyed by ``loss_keys``: ``uics`` / ``vics`` are the
            prediction-minus-reference errors at ``t0`` on the full reference
            ``(x, y)`` grid, ``ru`` / ``rv`` the PDE residuals on ``batch``.
        """
        xx, yy = torch.meshgrid(self.x_star, self.y_star, indexing="ij")
        X = xx.reshape(-1)
        Y = yy.reshape(-1)
        T0 = torch.full_like(X, float(self.t0))

        # One forward pass yields both channels; calling u_net and v_net
        # separately would evaluate the model twice on the same grid.
        u0_flat, v0_flat = self.neural_net(T0, X, Y)
        u0_pred = u0_flat.reshape_as(self.u0)
        v0_pred = v0_flat.reshape_as(self.v0)

        t_b = batch[:, 0]
        x_b = batch[:, 1]
        y_b = batch[:, 2]
        ru, rv = self.r_net(t_b, x_b, y_b)

        return {
            "uics": u0_pred - self.u0,
            "vics": v0_pred - self.v0,
            "ru": ru,
            "rv": rv,
        }

    def compute_l2_error(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the relative L2 errors of ``u`` and ``v`` on the reference grid.

        The model is evaluated on the full ``t_star x x_star x y_star`` grid
        in a single call; each error is ``||pred - ref|| / ||ref||`` over all
        entries.
        """
        t = self.t_star
        x = self.x_star
        y = self.y_star
        TT, XX, YY = torch.meshgrid(t, x, y, indexing="ij")
        u_pred, v_pred = self.neural_net(TT.reshape(-1), XX.reshape(-1), YY.reshape(-1))
        u_pred = u_pred.reshape_as(self.u_ref)
        v_pred = v_pred.reshape_as(self.v_ref)
        u_err = torch.linalg.norm(u_pred - self.u_ref) / torch.linalg.norm(self.u_ref)
        v_err = torch.linalg.norm(v_pred - self.v_ref) / torch.linalg.norm(self.v_ref)
        return u_err, v_err

    def log_errors(self) -> Dict[str, torch.Tensor]:
        """Return detached error metrics for logging.

        Despite the ``*_rmse_error`` key names, the values are the relative
        L2 errors from :meth:`compute_l2_error`.
        """
        # Metrics are never back-propagated, so skip building an autograd
        # graph over the whole space-time grid.
        with torch.no_grad():
            u_e, v_e = self.compute_l2_error()
        return {"u_rmse_error": u_e.detach(), "v_rmse_error": v_e.detach()}


@register_pde(
    "ginzburg_landau_tm",
    aliases=["gl_tm", "tmgl", "curriculum_ginzburg", "cgl", "cginzburg", "cginzburglandau"],
)
class GinzburgLandauTM(GinzburgLandau):
    """Windowed variant of :class:`GinzburgLandau`.

    The initial condition, the reference solution and the time axis used for
    scaling and evaluation come from the window installed by
    :meth:`set_initial_condition`. That method must be called before
    :meth:`neural_net` or :meth:`compute_l2_error`, because both read
    ``self.wt_star``.
    """

    def set_initial_condition(self, u0, v0, u_star, v_star, window_t_star, *args):
        """Install the initial condition and reference data of a window.

        Args:
            u0, v0: Initial-condition fields; they replace ``self.u0`` and
                ``self.v0``.
            u_star, v_star: Reference solution for the window.
            window_t_star: Time values of the window.
            *args: Ignored.
        """
        self.wt_star = window_t_star
        self.u_ref_window, self.v_ref_window = u_star, v_star
        self.u0, self.v0 = u0, v0

    def neural_net(self, t: torch.Tensor, x: torch.Tensor, y: torch.Tensor):
        """Evaluate the model with ``t`` scaled by the window's last time value.

        A last time value of exactly zero is replaced by ``1.0`` to avoid a
        division by zero. Returns the ``(u, v)`` output channels.
        """
        window_end = float(self.wt_star[-1].item())
        scale = 1.0 if window_end == 0.0 else window_end
        features = torch.stack([t / scale, x, y], dim=-1)
        prediction = self.model(features)
        return prediction[..., 0], prediction[..., 1]

    def compute_l2_error(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the relative L2 errors of ``u`` and ``v`` on the current window."""
        grid_t, grid_x, grid_y = torch.meshgrid(
            self.wt_star, self.x_star, self.y_star, indexing="ij"
        )
        u_hat, v_hat = self.neural_net(
            grid_t.reshape(-1), grid_x.reshape(-1), grid_y.reshape(-1)
        )

        u_target = self.u_ref_window
        v_target = self.v_ref_window
        u_hat = u_hat.reshape_as(u_target)
        v_hat = v_hat.reshape_as(v_target)

        u_rel = torch.linalg.norm(u_hat - u_target) / torch.linalg.norm(u_target)
        v_rel = torch.linalg.norm(v_hat - v_target) / torch.linalg.norm(v_target)
        return u_rel, v_rel
