from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch

from phijax.torch.data import UniformSampler
from phijax.torch.equations.base import IVP, residual_stats_logging
from phijax.torch.equations.registry import register_pde


@dataclass
class WindowState:
    """Plain record bundling an integer index, four tensors and two step counters.

    NOTE(review): nothing in this module constructs or reads this class, so the
    per-field notes below are inferred from the field names only and should be
    confirmed against the code that actually uses it.
    """

    # Integer index; presumably identifies a time window -- TODO confirm.
    idx: int
    # Tensor snapshots; presumably the v, u and p fields at the start of a
    # window -- TODO confirm against the caller.
    v0: torch.Tensor
    u0: torch.Tensor
    p0: torch.Tensor
    # NOTE(review): meaning of ``T0`` is not visible here (temperature? start
    # time?) -- confirm.
    T0: torch.Tensor
    # Step counters; presumably steps taken so far and the step budget -- TODO confirm.
    num_steps: int
    max_steps: int


def get_dataset(ref_path: str, Re: float = 1e4, tf: float = 1.0):
    """Load a reference solution from ``ref_path`` and keep its leading time slice.

    The file is expected to hold a pickled mapping with the keys
    ``"vorticity"``, ``"velocity"``, ``"t"`` and ``"coords"``.

    Args:
        ref_path: Path handed to ``numpy.load``.
        Re: Value whose reciprocal is returned as the last tuple element.
        tf: Fraction of the stored time samples to keep; at least one sample
            is always kept.

    Returns:
        Tuple ``(u_ref, v_ref, w_ref, t, coords, nu)`` where ``u_ref`` and
        ``v_ref`` are components 0 and 1 of the last axis of the stored
        velocity, ``w_ref`` is the stored vorticity, ``t`` is the flattened
        time vector shifted to start at zero, and ``nu`` equals ``1 / Re``.
        The first four entries are truncated along their leading axis.
    """
    # NOTE(review): ``allow_pickle=True`` unpickles arbitrary objects, so this
    # must only be pointed at trusted reference files.
    payload = np.load(ref_path, allow_pickle=True).item()

    vorticity = np.asarray(payload["vorticity"])
    velocity = np.asarray(payload["velocity"])
    u_full, v_full = velocity[..., 0], velocity[..., 1]

    # Flatten the time stamps and re-base them so the first sample is t = 0.
    times = np.asarray(payload["t"]).reshape(-1)
    times = times - times[0]

    # Number of leading samples to keep, never fewer than one.
    keep = max(int(float(tf) * times.shape[0]), 1)
    head = slice(None, keep)

    viscosity = 1.0 / float(Re)
    grid = np.asarray(payload["coords"])
    return u_full[head], v_full[head], vorticity[head], times[head], grid, viscosity


def _grad(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Differentiate ``y`` with respect to ``x`` using an all-ones upstream gradient.

    The autograd graph is both created and retained, so the result can be
    differentiated again and ``y`` can be differentiated with respect to other
    inputs afterwards. An error is raised when ``x`` was not used to compute
    ``y`` (``allow_unused=False``).

    Args:
        y: Tensor to differentiate.
        x: Tensor to differentiate with respect to.

    Returns:
        The gradient tensor returned by autograd for ``x``.
    """
    upstream = torch.ones_like(y)
    grads = torch.autograd.grad(
        outputs=y,
        inputs=x,
        grad_outputs=upstream,
        retain_graph=True,
        create_graph=True,
        allow_unused=False,
    )
    # Exactly one input was given, so exactly one gradient is expected back.
    (dy_dx,) = grads
    return dy_dx


def _laplace_scalar(f: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the sum of the unmixed second derivatives of ``f`` in ``x`` and ``y``.

    Each second derivative is obtained by applying ``_grad`` twice with respect
    to the same coordinate tensor.

    Args:
        f: Tensor computed from ``x`` and ``y``.
        x: First coordinate tensor.
        y: Second coordinate tensor.

    Returns:
        ``d2f/dx2 + d2f/dy2``.
    """
    d2_dx2, d2_dy2 = [_grad(_grad(f, axis), axis) for axis in (x, y)]
    return d2_dx2 + d2_dy2


@register_pde("kolmogorov", aliases=["kf"])
class KolmogorovFlow(IVP):
    """Physics-informed formulation of a forced 2-D incompressible Navier-Stokes flow.

    The wrapped model receives stacked ``(t / t_scale, x, y)`` inputs and its
    first three output channels are read as ``(u, v, p)``. Residuals are built
    from the two momentum equations (with a forcing term in the ``u`` equation)
    and the divergence-free constraint; initial-condition mismatches are taken
    against the first reference snapshot supplied in the batch.
    """

    loss_keys = ("uics", "vics", "ru", "rv", "rc")

    def __init__(self, config: Any, model: torch.nn.Module, *, device: Optional[torch.device] = None):
        """Load the reference solution and cache it as float32 tensors on ``self.device``.

        Args:
            config: Experiment configuration; ``config.pde_config.ref_path`` is
                passed to ``get_dataset``.
            model: Network evaluated by :meth:`net`.
            device: Forwarded unchanged to the base-class constructor.
        """
        super().__init__(config, model, device=device)

        pcfg = self.config.pde_config
        # NOTE(review): only ``ref_path`` is forwarded, so ``get_dataset`` runs
        # with its default ``Re`` and ``tf`` -- confirm that this is intended.
        u_ref, v_ref, w_ref, t_star, coords, nu = get_dataset(pcfg.ref_path)

        def to_tensor(array) -> torch.Tensor:
            return torch.as_tensor(array, dtype=torch.float32, device=self.device)

        self.u_ref = to_tensor(u_ref)
        self.v_ref = to_tensor(v_ref)
        self.w_ref = to_tensor(w_ref)

        self.t_star = to_tensor(t_star)
        self.coords = to_tensor(coords)
        self.nu = float(nu)

        # First reference snapshot of each field.
        self.u0 = self.u_ref[0, ...]
        self.v0 = self.v_ref[0, ...]
        self.w0 = self.w_ref[0, ...]

        # Forcing added to the u-momentum equation; it depends on y only.
        self.force_fn = lambda x, y: 2.0 * torch.sin(4.0 * torch.pi * y)

        # Normalise time by the final reference time. Read it once (``.item()``
        # forces a device sync) and fall back to 1.0 to avoid dividing by zero.
        t_end = float(self.t_star[-1].item())
        self._t_scale = t_end if t_end != 0.0 else 1.0

    @staticmethod
    def _leaf(tensor: torch.Tensor) -> torch.Tensor:
        """Return a detached copy of ``tensor`` that tracks gradients."""
        return tensor.clone().detach().requires_grad_(True)

    def get_ref(self):
        """Return ``(u_ref, v_ref, w_ref, t_star, coords)`` as cached tensors."""
        return self.u_ref, self.v_ref, self.w_ref, self.t_star, self.coords

    def net(self, t: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate the model and split its output into ``(u, v, p)``.

        ``t`` is divided by the stored time scale, then ``(t, x, y)`` are
        stacked along a new last axis before being passed to the model.
        """
        t_scaled = t / self._t_scale
        z = torch.stack([t_scaled, x, y], dim=-1)
        out = self.model(z)
        return out[..., 0], out[..., 1], out[..., 2]

    def u_net(self, t: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return only the first output channel (``u``) of :meth:`net`."""
        u, _, _ = self.net(t, x, y)
        return u

    def v_net(self, t: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return only the second output channel (``v``) of :meth:`net`."""
        _, v, _ = self.net(t, x, y)
        return v

    def p_net(self, t: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return only the third output channel (``p``) of :meth:`net`."""
        _, _, p = self.net(t, x, y)
        return p

    def w_net(self, t: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the vorticity ``dv/dx - du/dy`` of the predicted velocity.

        Inputs are detached and copied, so no gradient flows back to the
        caller's tensors. Gradient tracking is force-enabled so the method also
        works when the caller is inside ``torch.no_grad()``.
        """
        with torch.enable_grad():
            t_req = self._leaf(t)
            x_req = self._leaf(x)
            y_req = self._leaf(y)
            # One forward pass yields both velocity components.
            u, v, _ = self.net(t_req, x_req, y_req)
            u_y = _grad(u, y_req)
            v_x = _grad(v, x_req)
            return v_x - u_y

    def r_net(self, t: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the PDE residuals ``(ru, rv, rc)`` at the given points.

        ``ru`` and ``rv`` are the residuals of the two momentum equations
        (viscosity ``self.nu``; the forcing enters ``ru`` only) and ``rc`` is
        the divergence ``du/dx + dv/dy``. Inputs are detached and copied, and
        gradient tracking is force-enabled because the residuals are built
        from autograd derivatives.
        """
        with torch.enable_grad():
            t_req = self._leaf(t)
            x_req = self._leaf(x)
            y_req = self._leaf(y)

            u, v, p = self.net(t_req, x_req, y_req)

            u_t = _grad(u, t_req)
            u_x = _grad(u, x_req)
            u_y = _grad(u, y_req)

            v_t = _grad(v, t_req)
            v_x = _grad(v, x_req)
            v_y = _grad(v, y_req)

            p_x = _grad(p, x_req)
            p_y = _grad(p, y_req)

            u_lap = _laplace_scalar(u, x_req, y_req)
            v_lap = _laplace_scalar(v, x_req, y_req)

            force = self.force_fn(x_req, y_req)

            ru = u_t + u * u_x + v * u_y + p_x - self.nu * u_lap - force
            rv = v_t + u * v_x + v * v_y + p_y - self.nu * v_lap
            rc = u_x + v_y
            return ru, rv, rc

    def residuals(self, batch: Dict[str, Any], *args) -> Dict[str, torch.Tensor]:
        """Compute every loss residual for one training batch.

        Args:
            batch: Mapping with ``"ics"``, a 4-tuple
                ``(coords, u, v, w)`` whose ``coords`` columns 0 and 1 are used
                as ``x`` and ``y``, and ``"res"``, a 2-D tensor whose columns
                0, 1, 2 are used as ``t``, ``x``, ``y``.
            *args: Ignored.

        Returns:
            Dict keyed by the names in ``loss_keys``: initial-condition
            mismatches for ``u`` and ``v`` plus the three PDE residuals.
        """
        ics_batch = batch["ics"]
        res_batch = batch["res"]

        # The vorticity targets are part of the batch but not used as a loss.
        coords_batch, u_batch, v_batch, _w_batch = ics_batch

        # Initial-condition points are evaluated at t = 0; a single forward
        # pass provides both velocity components.
        t0 = torch.zeros((coords_batch.shape[0],), dtype=coords_batch.dtype, device=coords_batch.device)
        u_ic_pred, v_ic_pred, _ = self.net(t0, coords_batch[:, 0], coords_batch[:, 1])

        ru_pred, rv_pred, rc_pred = self.r_net(res_batch[:, 0], res_batch[:, 1], res_batch[:, 2])

        return {
            "uics": u_ic_pred - u_batch,
            "vics": v_ic_pred - v_batch,
            "ru": ru_pred,
            "rv": rv_pred,
            "rc": rc_pred,
        }

    def compute_metrics(
        self,
        t: torch.Tensor,
        coords: torch.Tensor,
        u_ref: torch.Tensor,
        v_ref: torch.Tensor,
        w_ref: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return relative 2-norm errors of ``u``, ``v`` and vorticity on a time-space grid.

        Every time in ``t`` is paired with every row of ``coords`` (time-major
        order); predictions are reshaped to the shape of the matching reference
        tensor, so each reference must hold ``len(t) * len(coords)`` elements.
        """
        n_t, n_x = t.shape[0], coords.shape[0]
        TT = t[:, None].expand(n_t, n_x).reshape(-1)
        XX = coords[:, 0][None, :].expand(n_t, n_x).reshape(-1)
        YY = coords[:, 1][None, :].expand(n_t, n_x).reshape(-1)

        u_pred, v_pred, _ = self.net(TT, XX, YY)
        u_pred = u_pred.reshape_as(u_ref)
        v_pred = v_pred.reshape_as(v_ref)

        w_pred = self.w_net(TT, XX, YY).reshape_as(w_ref)

        u_err = torch.linalg.norm(u_ref - u_pred) / torch.linalg.norm(u_ref)
        v_err = torch.linalg.norm(v_ref - v_pred) / torch.linalg.norm(v_ref)
        w_err = torch.linalg.norm(w_ref - w_pred) / torch.linalg.norm(w_ref)
        return u_err, v_err, w_err

    def log_errors(
        self,
        t: torch.Tensor,
        coords: torch.Tensor,
        u_ref: torch.Tensor,
        v_ref: torch.Tensor,
        w_ref: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Return the errors from :meth:`compute_metrics` as a detached dict."""
        u_e, v_e, w_e = self.compute_metrics(t, coords, u_ref, v_ref, w_ref)
        return {"u_error": u_e.detach(), "v_error": v_e.detach(), "w_error": w_e.detach()}

    def _log_stats(self, batch: Dict[str, Any], *args):
        """Store ``residual_stats_logging`` output for each PDE residual in ``self.log_dict``.

        Entries are written under ``stats/<residual>_<stat>`` for the residuals
        ``ru``, ``rv`` and ``rc``. Extra positional arguments are ignored.
        """
        res = self.residuals(batch)
        for key in ("ru", "rv", "rc"):
            stats = residual_stats_logging(res[key])
            for stat_name, stat_value in stats.items():
                self.log_dict[f"stats/{key}_{stat_name}"] = stat_value
