from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch

from phijax.torch.equations.base import IVP, residual_stats_logging
from phijax.torch.equations.registry import register_pde


@dataclass
class WindowState:
    """Plain record bundling the state associated with one window.

    NOTE(review): this class is not constructed or read in the visible part
    of this module, so the per-field notes below are inferred from the field
    names only and should be confirmed against the code that uses it.
    """

    idx: int  # presumably the index of the window — TODO confirm
    v0: torch.Tensor  # presumably the initial v-velocity for the window — TODO confirm
    u0: torch.Tensor  # presumably the initial u-velocity for the window — TODO confirm
    p0: torch.Tensor  # presumably the initial pressure for the window — TODO confirm
    T0: torch.Tensor  # presumably the initial temperature for the window — TODO confirm
    num_steps: int  # presumably steps taken so far in this window — TODO confirm
    max_steps: int  # presumably the step budget for this window — TODO confirm


def get_dataset(
    ref_path: str, start_idx: int = 5
) -> Tuple[
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    float,
    float,
    float,
    float,
    float,
    float,
    float,
]:
    """Load the reference fields and scalar parameters stored at ``ref_path``.

    The file must be a NumPy ``.npy`` file holding a single pickled ``dict``
    with the keys ``velocity``, ``pressure``, ``temperature``, ``t``,
    ``coords``, ``alpha1``..``alpha4``, ``Ra``, ``Pr`` and ``Ge``.

    Args:
        ref_path: Path to the ``.npy`` file.
        start_idx: Number of leading entries dropped along the first axis of
            ``velocity``, ``pressure``, ``temperature`` and ``t``. Defaults
            to 5, the value that used to be hard-coded here.

    Returns:
        ``(velocity, pressure, temperature, t, coords, alpha1, alpha2,
        alpha3, alpha4, Ra, Pr, Ge)``. The four time-dependent arrays are
        sliced with ``[start_idx:]``; ``t`` is additionally flattened and
        shifted so its first retained entry is zero. ``coords`` is returned
        unsliced and every scalar is a Python ``float``.

    Raises:
        IndexError: If no time samples remain after dropping ``start_idx``
            entries.
        KeyError: If one of the expected keys is missing from the file.
    """
    # SECURITY: allow_pickle=True lets np.load unpickle arbitrary objects,
    # which can execute code embedded in the file. Only point this at
    # reference files from a trusted source.
    data = np.load(ref_path, allow_pickle=True).item()

    def _time_series(key: str) -> np.ndarray:
        # Drop the leading `start_idx` entries along the first axis.
        return np.asarray(data[key])[start_idx:]

    def _scalar(key: str) -> float:
        # Accepts plain numbers as well as size-1 arrays.
        return float(np.asarray(data[key]).item())

    velocity = _time_series("velocity")
    pressure = _time_series("pressure")
    temperature = _time_series("temperature")

    t = _time_series("t").reshape(-1)
    if t.size == 0:
        # Same exception type as the bare `t[0]` lookup, clearer message.
        raise IndexError(
            f"no time samples left in {ref_path!r} after dropping the first {start_idx} entries"
        )
    # Re-origin the time axis at the first retained sample.
    t = t - t[0]
    coords = np.asarray(data["coords"])

    alpha1, alpha2, alpha3, alpha4 = (_scalar(f"alpha{i}") for i in range(1, 5))
    Ra, Pr, Ge = (_scalar(key) for key in ("Ra", "Pr", "Ge"))

    return velocity, pressure, temperature, t, coords, alpha1, alpha2, alpha3, alpha4, Ra, Pr, Ge


def _grad(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Differentiate ``y`` with respect to ``x`` using an all-ones seed.

    The seed ``torch.ones_like(y)`` means the result is the gradient of
    ``y.sum()`` with respect to ``x``. The graph is both kept and recorded
    (``retain_graph=True``, ``create_graph=True``), so the returned tensor
    can itself be differentiated and ``y`` can be differentiated again.
    Because ``allow_unused=False``, autograd raises if ``x`` did not take
    part in computing ``y``.
    """
    seed = torch.ones_like(y)
    derivatives = torch.autograd.grad(
        outputs=y,
        inputs=x,
        grad_outputs=seed,
        retain_graph=True,
        create_graph=True,
        allow_unused=False,
    )
    # A single input tensor yields a one-element tuple.
    return derivatives[0]


def _laplace_scalar(f: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return ``f_xx + f_yy``, the 2-D Laplacian of ``f`` built from ``_grad``.

    Each second derivative is obtained by applying ``_grad`` twice with
    respect to the same coordinate tensor, first for ``x`` and then for ``y``.
    """
    second_x = _grad(_grad(f, x), x)
    second_y = _grad(_grad(f, y), y)
    return second_x + second_y


@register_pde("rayleightaylor", aliases=["rt"])
class RayleighTaylor(IVP):
    """Initial-value problem coupling 2-D velocity, pressure and temperature.

    The wrapped model maps ``(t, x, y)`` to four channels ``(u, v, p, temp)``.
    The PDE residuals evaluated in :meth:`r_net` are::

        ru = u_t + u*u_x + v*u_y + p_x - alpha1 * (u_xx + u_yy)
        rv = v_t + u*v_x + v*v_y + p_y - alpha1 * (v_xx + v_yy) - alpha2 * temp
        rc = u_x + v_y
        re = temp_t + u*temp_x + v*temp_y - alpha4 * (temp_xx + temp_yy)

    Reference fields and coefficients are read from ``config.pde_config.ref_path``
    via ``get_dataset``.
    """

    loss_keys = ("uics", "vics", "pics", "Tics", "ubcs", "vbcs", "Tbcs", "resu", "resv", "resc", "rese")

    def __init__(self, config: Any, model: torch.nn.Module, *, device: Optional[torch.device] = None):
        """Load the reference data and cache the tensors/coefficients on ``self``.

        Args:
            config: Configuration object; ``config.pde_config.ref_path`` must
                point at the reference dataset.
            model: Network evaluated by :meth:`net`.
            device: Forwarded to the base class; the reference tensors are
                created on ``self.device``.
        """
        super().__init__(config, model, device=device)

        pcfg = self.config.pde_config
        (
            velocity,
            pressure,
            temperature,
            t_star,
            coords,
            alpha1,
            alpha2,
            alpha3,
            alpha4,
            Ra,
            Pr,
            Ge,
        ) = get_dataset(pcfg.ref_path)

        self.velocity = torch.as_tensor(velocity, dtype=torch.float32, device=self.device)
        self.pressure = torch.as_tensor(pressure, dtype=torch.float32, device=self.device)
        self.temperature = torch.as_tensor(temperature, dtype=torch.float32, device=self.device)
        self.t_star = torch.as_tensor(t_star, dtype=torch.float32, device=self.device)
        self.coords = torch.as_tensor(coords, dtype=torch.float32, device=self.device)

        # Initial condition = first retained snapshot.
        # assumes velocity is laid out as (time, point, component) — TODO confirm
        self.u0 = self.velocity[0, :, 0]
        self.v0 = self.velocity[0, :, 1]
        self.T0 = self.temperature[0]
        self.p0 = self.pressure[0]

        self.alpha1 = float(alpha1)
        self.alpha2 = float(alpha2)
        self.alpha3 = float(alpha3)
        self.alpha4 = float(alpha4)
        self.Ra = float(Ra)
        self.Pr = float(Pr)
        self.Ge = float(Ge)

        # Largest initial speed, padded by 1e-3 so it is never exactly zero.
        vmag0 = torch.sqrt(self.u0.square() + self.v0.square())
        self.v_scale = float(vmag0.max().detach().cpu().item()) + 1e-3

        # Normalise time by the final reference time; fall back to 1.0 when
        # that time is zero to avoid dividing by zero in `net`. The value is
        # read once so only one host transfer is needed.
        t_end = float(self.t_star[-1].item())
        self._t_scale = t_end if t_end != 0.0 else 1.0

    def get_ref(
        self,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        float,
        float,
        float,
        float,
    ]:
        """Return the cached reference data.

        Returns:
            ``(velocity, pressure, temperature, t_star, coords, alpha1,
            alpha2, alpha3, alpha4)``. ``Ra``, ``Pr`` and ``Ge`` are not part
            of the tuple.
        """
        return (
            self.velocity,
            self.pressure,
            self.temperature,
            self.t_star,
            self.coords,
            self.alpha1,
            self.alpha2,
            self.alpha3,
            self.alpha4,
        )

    def net(
        self, t: torch.Tensor, x: torch.Tensor, y: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate the model at ``(t, x, y)`` and split its output channels.

        ``t`` is divided by ``self._t_scale`` before the three inputs are
        stacked along a new last axis. The last axis of the model output is
        split into ``(u, v, p, temp)`` from channels 0..3.
        """
        t_scaled = t / self._t_scale
        z = torch.stack([t_scaled, x, y], dim=-1)
        out = self.model(z)
        return out[..., 0], out[..., 1], out[..., 2], out[..., 3]

    def u_net(self, t: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return only the ``u`` channel of :meth:`net`."""
        u, _, _, _ = self.net(t, x, y)
        return u

    def v_net(self, t: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return only the ``v`` channel of :meth:`net`."""
        _, v, _, _ = self.net(t, x, y)
        return v

    def p_net(self, t: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return only the ``p`` channel of :meth:`net`."""
        _, _, p, _ = self.net(t, x, y)
        return p

    def temp_net(self, t: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return only the ``temp`` channel of :meth:`net`."""
        _, _, _, temp = self.net(t, x, y)
        return temp

    def r_net(
        self, t: torch.Tensor, x: torch.Tensor, y: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate the PDE residuals at the given collocation points.

        Args:
            t, x, y: Coordinates of the collocation points. They are cloned
                and detached, so gradients never flow back into the callers'
                tensors.

        Returns:
            ``(ru, rv, rc, re)``: the two momentum residuals, the
            divergence ``u_x + v_y`` and the temperature-equation residual.
        """
        # Fresh leaf tensors so autograd can differentiate w.r.t. the inputs.
        t_req = t.clone().detach().requires_grad_(True)
        x_req = x.clone().detach().requires_grad_(True)
        y_req = y.clone().detach().requires_grad_(True)

        u, v, p, temp = self.net(t_req, x_req, y_req)

        u_t = _grad(u, t_req)
        u_x = _grad(u, x_req)
        u_y = _grad(u, y_req)

        v_t = _grad(v, t_req)
        v_x = _grad(v, x_req)
        v_y = _grad(v, y_req)

        p_x = _grad(p, x_req)
        p_y = _grad(p, y_req)

        temp_t = _grad(temp, t_req)
        temp_x = _grad(temp, x_req)
        temp_y = _grad(temp, y_req)

        # Laplacians reuse the first derivatives computed above instead of
        # differentiating u, v and temp from scratch a second time.
        u_lap = _grad(u_x, x_req) + _grad(u_y, y_req)
        v_lap = _grad(v_x, x_req) + _grad(v_y, y_req)
        temp_lap = _grad(temp_x, x_req) + _grad(temp_y, y_req)

        ru = u_t + u * u_x + v * u_y + p_x - self.alpha1 * u_lap
        rv = v_t + u * v_x + v * v_y + p_y - self.alpha1 * v_lap - self.alpha2 * temp
        rc = u_x + v_y
        re = temp_t + u * temp_x + v * temp_y - self.alpha4 * temp_lap

        return ru, rv, rc, re

    def residuals(self, batch: Dict[str, Any], *args) -> Dict[str, torch.Tensor]:
        """Compute every residual listed in ``loss_keys`` for one batch.

        Args:
            batch: Mapping with three entries:
                ``"ics"`` unpacks into ``(coords, u, v, p, temp)`` where
                ``coords[:, 0]`` / ``coords[:, 1]`` are used as x / y;
                ``"bcs"`` and ``"res"`` are 2-D tensors whose columns 0, 1, 2
                are used as t, x, y.
            *args: Ignored.

        Returns:
            Dict keyed by the entries of ``loss_keys``. Initial-condition
            entries are ``prediction - target``; boundary entries are the raw
            predictions (i.e. the target is zero); ``res*`` entries come from
            :meth:`r_net`.
        """
        ics_batch = batch["ics"]
        bcs_batch = batch["bcs"]
        res_batch = batch["res"]

        coords_batch, u_batch, v_batch, p_batch, temp_batch = ics_batch

        # Initial conditions are enforced at t = 0 for every sampled point.
        t0 = torch.zeros((coords_batch.shape[0],), dtype=coords_batch.dtype, device=coords_batch.device)
        u_ic_pred, v_ic_pred, p_ic_pred, temp_ic_pred = self.net(t0, coords_batch[:, 0], coords_batch[:, 1])

        # Pressure is not constrained on the boundary.
        u_bc_pred, v_bc_pred, _, temp_bc_pred = self.net(bcs_batch[:, 0], bcs_batch[:, 1], bcs_batch[:, 2])

        ru_pred, rv_pred, rc_pred, re_pred = self.r_net(res_batch[:, 0], res_batch[:, 1], res_batch[:, 2])

        return {
            "uics": u_ic_pred - u_batch,
            "vics": v_ic_pred - v_batch,
            "pics": p_ic_pred - p_batch,
            "Tics": temp_ic_pred - temp_batch,
            "ubcs": u_bc_pred,
            "vbcs": v_bc_pred,
            "Tbcs": temp_bc_pred,
            "resu": ru_pred,
            "resv": rv_pred,
            "resc": rc_pred,
            "rese": re_pred,
        }

    def _log_stats(self, batch: Dict[str, Any], *args):
        """Store residual statistics in ``self.log_dict``.

        For each of ``resu``, ``resv`` and ``resc`` the dict returned by
        ``residual_stats_logging`` is written under
        ``stats/<key>_<stat_name>``.
        """
        res = self.residuals(batch)
        # NOTE(review): "rese" is computed but not logged here — confirm
        # whether leaving the temperature residual out is intentional.
        for key in ("resu", "resv", "resc"):
            stats = residual_stats_logging(res[key])
            for stat_name, stat_value in stats.items():
                self.log_dict[f"stats/{key}_{stat_name}"] = stat_value
