"""
spatial_temporal_decoder.py

temporal decoders for graph-level latent z_g -> low-dim dynamics h_t -> node signals Y_hat

default:
- conditioning = "ic+context"

optional:
- ODETemporalDecoder requires torchdyn
"""

from __future__ import annotations

from typing import Optional, Tuple, Literal, Dict, Any

import torch
import torch.nn as nn
import torch.nn.functional as F

# torchdyn is an optional dependency: remember whether the import worked so
# that ODETemporalDecoder can raise a clear ImportError when it is constructed,
# instead of this module failing at import time.
try:
    from torchdyn.core import NeuralODE as TorchDynNeuralODE
    _TORCHDYN_AVAILABLE = True
except Exception:
    # Any failure while importing torchdyn is treated as "not available".
    _TORCHDYN_AVAILABLE = False


# readouts

class LinearReadout(nn.Module):
    """
    affine readout from latent states to node signals.

    args
      N: number of output features of the linear layer (one per node)
      r: number of input features of the linear layer (latent width)
    """

    def __init__(self, N: int, r: int):
        super().__init__()
        self.L = nn.Linear(in_features=r, out_features=N)

    def forward(self, h_seq: torch.Tensor) -> torch.Tensor:
        """apply the linear map on the last axis, then swap axes 1 and 2."""
        per_step = self.L(h_seq)  # [B,T,N]
        node_major = torch.transpose(per_step, 1, 2)  # [B,N,T]
        return node_major


class SpectralReadout(nn.Module):
    """
    readout through a fixed basis: latent state -> k coefficients -> N nodes.

    a bias-free linear layer turns each latent state into k coefficients, the
    fixed matrix U_k combines them into N node values, and a learnable
    per-node bias is added.

    args
      U_k: 2-D basis of shape [N,k]; a copy is stored as a buffer (not trained)
      r:   latent width (input features of the coefficient layer)

    raises
      ValueError: U_k is not 2-D
    """

    def __init__(self, U_k: torch.Tensor, r: int):
        super().__init__()
        if U_k.dim() != 2:
            raise ValueError("U_k must be [N,k]")
        num_nodes, num_modes = U_k.shape
        # Registration order (buffer, R, bias) is kept so parameter ordering
        # and state_dict layout stay the same.
        self.register_buffer("U_k", U_k.clone())
        self.R = nn.Linear(r, num_modes, bias=False)
        self.bias = nn.Parameter(torch.zeros(num_nodes))

    def forward(self, h_seq: torch.Tensor) -> torch.Tensor:
        """decode h_seq into node signals and return them with axes 1 and 2 swapped."""
        mode_coeff = self.R(h_seq)  # [B,T,k]
        per_step = torch.einsum("nk,btk->btn", self.U_k, mode_coeff)  # [B,T,N]
        per_step = per_step + self.bias.view(1, 1, -1)
        return torch.transpose(per_step, 1, 2)  # [B,N,T]


# RNN backend
class FiLMGRUCell(nn.Module):
    """
    GRU cell whose input is mixed with a context vector before the update.

    when a non-empty context c is given, [x_t, c] is concatenated on the last
    axis and projected back to input_size by a bias-free linear adapter; the
    result is fed to a plain nn.GRUCell.  without context the cell is exactly
    that nn.GRUCell.  note: despite the class name, no scale/shift modulation
    is applied here; the context only enters through the input adapter.

    args
      input_size:  width of x_t
      hidden_size: width of the recurrent state
      d_c:         width of the context vector c; when > 0 the adapter is
                   built in the constructor with input_size + d_c inputs
    """

    def __init__(self, input_size: int, hidden_size: int, d_c: int):
        super().__init__()
        self.input_size = int(input_size)
        self.hidden_size = int(hidden_size)
        self.d_c = int(d_c)
        self.base = nn.GRUCell(input_size=self.input_size, hidden_size=self.hidden_size)
        # Build the adapter eagerly.  Creating it inside forward() meant it was
        # absent from parameters()/state_dict() until the first call, so an
        # optimizer constructed beforehand never updated it and checkpoints
        # taken before the first forward pass could not restore it.
        self.x_adapter: Optional[nn.Linear] = (
            nn.Linear(self.input_size + self.d_c, self.input_size, bias=False)
            if self.d_c > 0
            else None
        )

    def forward(self, x_t: torch.Tensor, h_t: torch.Tensor, c: Optional[torch.Tensor]) -> torch.Tensor:
        """
        one GRU step.

        args
          x_t: step input, last axis of width input_size
          h_t: previous hidden state
          c:   optional context; ignored when None or empty.  its last axis is
               expected to have width d_c so it matches the adapter.

        rtns
          next hidden state as produced by nn.GRUCell
        """
        if c is not None and c.numel() > 0:
            x_aug = torch.cat([x_t, c], dim=-1)
            if self.x_adapter is None:
                # Fallback for cells built with d_c <= 0 that still receive a
                # context: keep the previous lazy behaviour rather than fail.
                # Parameters created here are NOT known to optimizers that
                # were constructed earlier.
                self.x_adapter = nn.Linear(x_aug.size(-1), self.input_size, bias=False).to(
                    device=x_aug.device, dtype=x_aug.dtype
                )
            x_t = self.x_adapter(x_aug)
        return self.base(x_t, h_t)


class RNNTemporalDecoder(nn.Module):
    """
    discrete-time decoder with conditioning in {"ic", "ic+context", "ic+modulation"}.

    z_g is mapped to an initial latent state h0 (and, unless conditioning is
    "ic", to a context vector c).  a GRU cell is stepped T times; each hidden
    state is projected to the r-dim latent space and decoded by the readout.
    the returned sequence holds the T states produced AFTER each step, i.e.
    h0 itself is not part of it.

    conditioning
      "ic":            cell input is a constant zero column; z_g only sets h0
      "ic+context":    cell input at every step is the context c
      "ic+modulation": cell is a FiLMGRUCell fed a zero column plus c

    args
      N:            number of nodes
      r:            latent width
      d_g:          width of z_g
      hidden_size:  GRU width H; defaults to r.  when H != r an extra linear
                    layer lifts h0 from r to H before the first step
      conditioning: see above
      d_c:          context width (forced to 0 for "ic")
      readout:      "linear" or "spectral"
      U_k:          [N,k] basis, required for the spectral readout
      dropout_p:    dropout applied to the hidden state before the latent
                    projection (the recurrent state itself is not dropped)

    rtns (forward)
      h_seq: [B,T,r]
      Y_hat: [B,N,T]

    raises
      ValueError: unknown conditioning / readout, missing U_k, or
                  "ic+context" with d_c <= 0
    """

    _CONDITIONINGS = ("ic", "ic+context", "ic+modulation")

    def __init__(
        self,
        N: int,
        r: int,
        d_g: int,
        *,
        hidden_size: Optional[int] = None,
        conditioning: Literal["ic", "ic+context", "ic+modulation"] = "ic+context",
        d_c: int = 16,
        readout: Literal["linear", "spectral"] = "linear",
        U_k: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
    ):
        super().__init__()
        # An unknown mode used to be accepted silently and produced a model
        # with an unused context layer and a zero input; reject it instead.
        if conditioning not in self._CONDITIONINGS:
            raise ValueError("conditioning must be {'ic','ic+context','ic+modulation'}")

        self.N, self.r, self.d_g = int(N), int(r), int(d_g)
        self.conditioning = conditioning
        self.d_c = int(d_c) if conditioning != "ic" else 0
        if conditioning == "ic+context" and self.d_c <= 0:
            # The context is the only cell input in this mode.
            raise ValueError("d_c must be > 0 for conditioning='ic+context'")

        self.ic = nn.Linear(self.d_g, self.r)
        self.ctx = nn.Linear(self.d_g, self.d_c) if self.d_c > 0 else None

        self.input_dim = self.d_c if conditioning == "ic+context" else 1
        H = int(hidden_size) if hidden_size is not None else self.r

        if conditioning == "ic+modulation":
            self.cell = FiLMGRUCell(self.input_dim, H, d_c=self.d_c)
        else:
            self.cell = nn.GRUCell(self.input_dim, H)

        self.h2h = nn.Linear(H, self.r)
        self.drop = nn.Dropout(float(dropout_p))

        if readout == "linear":
            self.readout = LinearReadout(self.N, self.r)
        elif readout == "spectral":
            if U_k is None:
                raise ValueError("U_k must be provided for spectral readout")
            self.readout = SpectralReadout(U_k, self.r)
        else:
            raise ValueError("readout must be {'linear','spectral'}")

        # h0 has width r but the GRU state has width H.  Feeding h0 directly
        # only works when H == r, so lift it when the sizes differ.  Created
        # last (and only when needed) so default configurations keep the same
        # modules, parameter order and initialisation as before.
        self.h0_proj: Optional[nn.Linear] = nn.Linear(self.r, H) if H != self.r else None

    def init_from_z(self, z_g: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """map z_g to (h0, c); c is None when no context layer exists."""
        h0 = self.ic(z_g)
        c = self.ctx(z_g) if self.ctx is not None else None
        return h0, c

    def unroll(self, h0: torch.Tensor, c: Optional[torch.Tensor], T: int) -> torch.Tensor:
        """
        step the cell T times from h0 and collect the latent states.

        rtns
          [B,T,r]; an empty [B,0,r] tensor when T <= 0
        """
        B = h0.size(0)
        h_init = h0 if self.h0_proj is None else self.h0_proj(h0)
        h_hid = torch.tanh(h_init)

        # The cell input does not change across steps: either the context or
        # a zero column.
        if self.conditioning == "ic+context":
            x_t = c
        else:
            x_t = h0.new_zeros(B, 1)
        modulated = self.conditioning == "ic+modulation"

        states = []
        for _ in range(int(T)):
            if modulated:
                h_hid = self.cell(x_t, h_hid, c)
            else:
                h_hid = self.cell(x_t, h_hid)
            # Dropout touches only the projected output, not the recurrence.
            states.append(self.h2h(self.drop(h_hid)))

        if not states:
            # torch.stack rejects an empty list; return an empty sequence.
            return h0.new_zeros((B, 0, self.r))
        return torch.stack(states, dim=1)  # [B,T,r]

    def forward(self, z_g: torch.Tensor, T: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """decode z_g into a length-T latent sequence and the node signals."""
        h0, c = self.init_from_z(z_g)
        h_seq = self.unroll(h0, c, T)
        Y_hat = self.readout(h_seq)
        return h_seq, Y_hat


# ODE backend
class _ContextualVectorField(nn.Module):
    """
    MLP vector field v = f(h, c) used by the ODE backend.

    the network takes the state h concatenated with a context c that is
    supplied out-of-band through set_context().  with modulation enabled (and
    d_c > 0) the output is additionally scaled and shifted by a linear
    function of c.

    args
      r:          state width (also the output width)
      d_c:        context width; the first layer expects r + d_c inputs
      hidden:     width of the hidden layers
      layers:     number of linear layers; layers - 1 of them are followed by Tanh
      modulation: enable the context-dependent scale/shift of the output
    """

    def __init__(self, r: int, d_c: int, hidden: int = 64, layers: int = 2, modulation: bool = False):
        super().__init__()
        self.r = int(r)
        self.d_c = int(d_c)
        self.modulation = bool(modulation)

        # widths = [r + d_c, hidden, ..., hidden, r]
        widths = [self.r + self.d_c, *([int(hidden)] * (int(layers) - 1)), self.r]
        stack = []
        for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.Tanh())
        stack.append(nn.Linear(widths[-2], widths[-1]))  # no activation on the output
        self.net = nn.Sequential(*stack)

        if self.modulation and self.d_c > 0:
            # One linear layer emits both the scale and the shift (r each).
            self.film = nn.Linear(self.d_c, 2 * self.r)
        else:
            self.film = None
        self._ctx: Optional[torch.Tensor] = None

    def set_context(self, c: Optional[torch.Tensor]) -> None:
        """store the context used by subsequent forward() calls (None clears it)."""
        self._ctx = c

    def forward(self, t: torch.Tensor, h: torch.Tensor, args: Optional[Dict[str, Any]] = None) -> torch.Tensor:
        """
        evaluate the field at state h; t and args are accepted but not used.

        a stored context whose first dimension differs from h's is expanded
        to h's batch size before concatenation.
        """
        stored = self._ctx
        if stored is None or stored.numel() == 0:
            cond = None
            features = h
        else:
            batch = h.shape[0]
            cond = stored if stored.shape[0] == batch else stored.expand(batch, -1)
            features = torch.cat([h, cond], dim=-1)

        velocity = self.net(features)
        if not (self.modulation and cond is not None and self.film is not None):
            return velocity

        scale, shift = self.film(cond).chunk(2, dim=-1)
        return scale * velocity + shift


class ODETemporalDecoder(nn.Module):
    """
    continuous-time decoder using torchdyn.core.NeuralODE.
    conditioning in {"ic", "ic+context", "ic+modulation"}.

    z_g is mapped to an initial latent state h0 (and, unless conditioning is
    "ic", to a context vector c).  the vector field is integrated over T
    evenly spaced points in [0, 1]; the first point is t = 0, so the sequence
    starts at h0.  the trajectory is then decoded by the readout.

    args
      N:            number of nodes
      r:            latent width
      d_g:          width of z_g
      conditioning: "ic" (no context), "ic+context" (c concatenated to the
                    state), "ic+modulation" (as context, plus scale/shift of
                    the field output)
      d_c:          context width (forced to 0 for "ic")
      vf_hidden:    hidden width of the vector-field MLP
      vf_layers:    number of linear layers of the vector-field MLP
      readout:      "linear" or "spectral"
      U_k:          [N,k] basis, required for the spectral readout
      solver:       solver name forwarded to NeuralODE
      atol, rtol:   tolerances forwarded to NeuralODE when not None
      dropout_p:    dropout applied to the latent trajectory
      sensitivity:  sensitivity method forwarded to NeuralODE (default
                    "adjoint", the value that used to be hard-coded)

    rtns (forward)
      h_seq: [B,T,r]
      Y_hat: [B,N,T]

    raises
      ImportError: torchdyn could not be imported
      ValueError:  unknown conditioning / readout, or missing U_k

    NOTE(review): the context reaches the vector field through an attribute,
    not through the ODE state.  with an adjoint sensitivity method gradients
    presumably do not flow back into self.ctx through that path - confirm
    against torchdyn, and consider sensitivity="autograd" if the context
    layer must be trained.
    """

    _CONDITIONINGS = ("ic", "ic+context", "ic+modulation")

    def __init__(
        self,
        N: int,
        r: int,
        d_g: int,
        *,
        conditioning: Literal["ic", "ic+context", "ic+modulation"] = "ic+context",
        d_c: int = 16,
        vf_hidden: int = 64,
        vf_layers: int = 2,
        readout: Literal["linear", "spectral"] = "linear",
        U_k: Optional[torch.Tensor] = None,
        solver: str = "rk4",
        atol: Optional[float] = None,
        rtol: Optional[float] = None,
        dropout_p: float = 0.0,
        sensitivity: str = "adjoint",
    ):
        super().__init__()
        if not _TORCHDYN_AVAILABLE:
            raise ImportError("torchdyn is not installed. Install with `pip install torchdyn` to use ODETemporalDecoder.")
        # An unknown mode used to be accepted silently (context layer built,
        # modulation off); reject it instead.
        if conditioning not in self._CONDITIONINGS:
            raise ValueError("conditioning must be {'ic','ic+context','ic+modulation'}")

        self.N, self.r, self.d_g = int(N), int(r), int(d_g)
        self.conditioning = conditioning
        self.d_c = int(d_c) if conditioning != "ic" else 0

        self.ic = nn.Linear(self.d_g, self.r)
        self.ctx = nn.Linear(self.d_g, self.d_c) if self.d_c > 0 else None

        self.vf = _ContextualVectorField(
            r=self.r,
            d_c=self.d_c,
            hidden=int(vf_hidden),
            layers=int(vf_layers),
            modulation=(conditioning == "ic+modulation"),
        )

        # Only forward tolerances that were given so NeuralODE keeps its own
        # defaults otherwise.
        ode_kwargs: Dict[str, Any] = dict(solver=solver, sensitivity=sensitivity)
        if atol is not None:
            ode_kwargs["atol"] = float(atol)
        if rtol is not None:
            ode_kwargs["rtol"] = float(rtol)

        self.ode = TorchDynNeuralODE(self.vf, **ode_kwargs)
        self.drop = nn.Dropout(float(dropout_p))

        if readout == "linear":
            self.readout = LinearReadout(self.N, self.r)
        elif readout == "spectral":
            if U_k is None:
                raise ValueError("U_k must be provided for spectral readout")
            self.readout = SpectralReadout(U_k, self.r)
        else:
            raise ValueError("readout must be {'linear','spectral'}")

    def init_from_z(self, z_g: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """map z_g to (h0, c); c is None when no context layer exists."""
        h0 = self.ic(z_g)
        c = self.ctx(z_g) if self.ctx is not None else None
        return h0, c

    def _as_batch_first(self, traj: torch.Tensor, B: int, T: int) -> torch.Tensor:
        """reorder a solver trajectory to [B,T,r]; RuntimeError if the layout is unclear."""
        shape = tuple(traj.shape)
        if shape == (T, B, self.r):
            return traj.transpose(0, 1).contiguous()
        if shape == (B, T, self.r):
            return traj
        if traj.ndim == 3 and self.r in shape and T in shape and B in shape:
            order = (shape.index(B), shape.index(T), shape.index(self.r))
            # Equal sizes make the axis lookup ambiguous (the same index is
            # found twice); only permute when all three axes are distinct.
            if len(set(order)) == 3:
                return traj.permute(*order).contiguous()
        raise RuntimeError(f"Cannot parse ODE output shape {shape} for B={B}, T={T}, r={self.r}")

    def unroll(self, h0: torch.Tensor, c: Optional[torch.Tensor], T: int) -> torch.Tensor:
        """
        integrate the vector field from h0 over T points in [0, 1].

        rtns
          [B,T,r] with dropout applied; an empty [B,0,r] tensor when T <= 0
        """
        steps = int(T)
        B = h0.shape[0]
        # The context is left in place after the call on purpose: the solver's
        # backward pass may evaluate the vector field again.
        self.vf.set_context(c)

        if steps <= 0:
            return h0.new_zeros((B, 0, self.r))
        if steps == 1:
            # A one-point time grid is just t = 0, whose solution is h0; the
            # solver has no step to take over a single point.
            return self.drop(h0.unsqueeze(1))

        t_span = torch.linspace(0.0, 1.0, steps, device=h0.device)
        out = self.ode(h0, t_span)
        # NeuralODE may return (t_eval, trajectory) or the trajectory alone.
        traj = out[1] if isinstance(out, (tuple, list)) and len(out) == 2 else out

        return self.drop(self._as_batch_first(traj, B, steps))

    def forward(self, z_g: torch.Tensor, T: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """decode z_g into a length-T latent trajectory and the node signals."""
        h0, c = self.init_from_z(z_g)
        h_seq = self.unroll(h0, c, T)
        Y_hat = self.readout(h_seq)
        return h_seq, Y_hat
