from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .rqs import RQSConfig, _params_to_spline, _rqs_forward_with_derivative


class RQSQuantile(nn.Module):
    """Batch-independent RQS quantile with shared spline parameters.

    Maps ``u`` (expected in (0, 1)) to a base quantile ``q`` by

    1. clamping ``u`` to ``[eps, 1 - eps]`` and applying the logit,
    2. passing the result through ``num_layers`` RQS spline layers
       (``_params_to_spline`` / ``_rqs_forward_with_derivative``) whose
       parameters are stored as ``(1, dim, ...)`` tensors and broadcast over
       the batch, i.e. they do not depend on the sample, and
    3. applying a per-dimension affine map with scale
       ``softplus(log_scale) + 1e-4`` and offset ``bias``.

    ``forward`` returns ``tau * q`` elementwise.
    """

    def __init__(
        self,
        dim: int,
        *,
        hidden: int = 128,
        depth: int = 2,
        n_bins: int = 64,
        bound: float = 10.0,
        num_layers: int = 1,
        eps: float = 1e-6,
    ) -> None:
        """Create the module.

        Args:
            dim: Number of dimensions ``D`` of ``u`` and of the output.
            hidden: Ignored; kept only for backward compatibility.
            depth: Ignored; kept only for backward compatibility.
            n_bins: Number of spline bins per dimension and layer.
            bound: Forwarded to ``RQSConfig``. NOTE(review): presumably the
                half-width of the spline's active interval; confirm in ``.rqs``.
            num_layers: Number of stacked spline layers. With 0 the map is
                just logit followed by the affine transform.
            eps: Clamp margin applied to ``u`` before the logit.
        """
        super().__init__()
        del hidden, depth  # legacy arguments retained for compatibility

        self.dim = int(dim)
        self.eps = float(eps)
        self.cfg = RQSConfig(n_bins=int(n_bins), bound=float(bound))
        self.num_layers = int(num_layers)

        # Per-dimension affine output transform (scale via softplus below).
        self.log_scale = nn.Parameter(torch.zeros(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

        # Raw (unconstrained) spline parameters, one set per layer. The leading
        # singleton dimension is broadcast to the batch at evaluation time.
        self.raw_w = nn.ParameterList()
        self.raw_h = nn.ParameterList()
        self.raw_s = nn.ParameterList()
        for _ in range(self.num_layers):
            self.raw_w.append(nn.Parameter(torch.zeros(1, dim, self.cfg.n_bins)))
            self.raw_h.append(nn.Parameter(torch.zeros(1, dim, self.cfg.n_bins)))
            self.raw_s.append(nn.Parameter(torch.zeros(1, dim, self.cfg.n_bins + 1)))

    def _expand_tau(self, tau: torch.Tensor, batch: int) -> torch.Tensor:
        """Broadcast ``tau`` to shape ``(batch, self.dim)``.

        Accepts ``(batch, 1)`` or ``(1, 1)``, both expanded as views, or
        ``(batch, self.dim)``, returned unchanged.

        Raises:
            ValueError: for any other shape.
        """
        if tau.dim() == 2 and tau.shape[1] == 1 and tau.shape[0] in (1, batch):
            return tau.expand(batch, self.dim)
        if tau.shape == (batch, self.dim):
            return tau
        # Previously a (B', 1) tau with B' not in {1, batch} escaped this check
        # and failed inside ``expand`` with an unrelated RuntimeError.
        raise ValueError(
            "tau must be shaped (B,1) or (B,D)"
            f"; got {tuple(tau.shape)} with B={batch}, D={self.dim}"
        )

    def forward(
        self,
        u: torch.Tensor,
        tau: torch.Tensor,
        x_aux: Optional[torch.Tensor] = None,
        *,
        return_dqdt: bool = False,
        requires_grad: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
        """Return ``eps = tau * q(u)``.

        Args:
            u: ``(B, D)`` tensor, expected in (0, 1); clamped internally.
            tau: ``(B, 1)``, ``(1, 1)`` or ``(B, D)`` tensor (see ``_expand_tau``).
            x_aux: Ignored.
            return_dqdt: If True, also return ``q``. Because ``eps`` is linear
                in ``tau``, ``q`` equals ``d eps / d tau`` elementwise.
            requires_grad: Only consulted with ``return_dqdt``. When False,
                both returned tensors are detached.

        Returns:
            ``eps``, or ``(eps, q)`` when ``return_dqdt`` is True.
        """
        del x_aux
        batch, _ = u.shape
        tau_nd = self._expand_tau(tau, batch)
        q = self.base_quantile(u)
        eps = tau_nd * q
        if not return_dqdt:
            return eps
        if requires_grad:
            return eps, q
        return eps.detach(), q.detach()

    def base_quantile(self, u: torch.Tensor) -> torch.Tensor:
        """Return the base quantile ``q(u)`` (shape ``(B, D)``)."""
        q, _ = self._compute_quantile_and_jac(u, return_dq_du=False)
        return q

    def base_quantile_with_dqdu(self, u: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(q(u), dq/du)``, where ``dq/du`` is elementwise (``(B, D)``)."""
        q, dq_du = self._compute_quantile_and_jac(u, return_dq_du=True)
        if dq_du is None:
            raise RuntimeError("Expected dq/du when return_dq_du=True")
        return q, dq_du

    def _compute_quantile_and_jac(
        self,
        u: torch.Tensor,
        *,
        return_dq_du: bool,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Evaluate the base quantile and, optionally, its derivative in ``u``.

        The derivative is accumulated as an elementwise product of per-layer
        derivatives (chain rule). The result is therefore returned as a
        ``(B, D)`` tensor, i.e. the diagonal of the Jacobian.

        NOTE: ``dz/du`` is the logit derivative evaluated at the *clamped* ``u``.
        For ``u`` outside ``[eps, 1 - eps]`` it differs from the autograd
        derivative of the clamped map, which is 0 there.

        Raises:
            ValueError: if ``u`` is not shaped ``(B, self.dim)``.
            RuntimeError: if a spline layer returns no derivative although one
                was requested. Silently skipping it would drop that layer's
                factor from ``dq/du``.
        """
        if u.dim() != 2 or u.shape[1] != self.dim:
            raise ValueError(f"u must be shaped (B,{self.dim}); got {tuple(u.shape)}")
        batch = u.shape[0]

        u_safe = u.clamp(self.eps, 1.0 - self.eps)
        z = torch.log(u_safe) - torch.log1p(-u_safe)  # logit(u_safe)

        # Derivative accumulators, only allocated when requested.
        dz_du = 1.0 / (u_safe * (1.0 - u_safe)) if return_dq_du else None
        dydz_total = torch.ones_like(u_safe) if return_dq_du else None

        z_curr = z
        for layer_idx in range(self.num_layers):
            xk, yk, widths, heights, slopes = _params_to_spline(
                self.raw_w[layer_idx],
                self.raw_h[layer_idx],
                self.raw_s[layer_idx],
                self.cfg,
            )

            # Stored parameters have a leading singleton batch dim; broadcast
            # them (``expand`` returns views, no copy) to match ``z_curr``.
            if xk.shape[0] != batch:
                xk = xk.expand(batch, -1, -1)
                yk = yk.expand(batch, -1, -1)
                widths = widths.expand(batch, -1, -1)
                heights = heights.expand(batch, -1, -1)
                slopes = slopes.expand(batch, -1, -1)

            z_curr, dydz = _rqs_forward_with_derivative(
                z_curr,
                xk,
                yk,
                widths,
                heights,
                slopes,
                self.cfg,
                return_dydx=return_dq_du,
            )

            if return_dq_du:
                if dydz is None:
                    raise RuntimeError(
                        f"Spline layer {layer_idx} returned no derivative although return_dydx=True"
                    )
                dydz_total = dydz_total * dydz

        # The 1e-4 floor keeps the per-dimension scale bounded away from zero.
        scale = (F.softplus(self.log_scale) + 1e-4).view(1, self.dim)
        q = z_curr * scale + self.bias.view(1, self.dim)

        if not return_dq_du:
            return q, None

        dq_du = (scale * dydz_total) * dz_du
        return q, dq_du

    def diag_du(
        self,
        u: torch.Tensor,
        tau: torch.Tensor,
        x_aux: Optional[torch.Tensor],
        *,
        create_graph: bool = True,
    ) -> torch.Tensor:
        """Return the elementwise derivative ``d eps / d u = tau * dq/du``.

        ``tau`` is detached, so no gradient flows into it through this result.
        With ``create_graph=False`` the returned tensor is detached as well.

        Args:
            u: ``(B, D)`` tensor, as in ``forward``.
            tau: ``(B, 1)``, ``(1, 1)`` or ``(B, D)`` tensor.
            x_aux: Ignored.
            create_graph: Keep the autograd graph of the result when True.
        """
        del x_aux
        batch, _ = u.shape
        tau_nd = self._expand_tau(tau, batch).detach()
        _, dq_du = self.base_quantile_with_dqdu(u)
        diag = tau_nd * dq_du
        return diag if create_graph else diag.detach()
