
from dataclasses import dataclass
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

# ----------------- utils -----------------
def sinkhorn(log_alpha, n_iters=3, eps=1e-6):
    """Turn a matrix of logits into an approximately doubly-stochastic matrix.

    Performs ``n_iters`` rounds of log-space normalization (rows over the last
    dim, then columns over the second-to-last dim), exponentiates, clamps the
    entries into ``[eps, 1]`` and finishes with one plain row normalization
    followed by one column normalization.

    Args:
        log_alpha: tensor of logits; its last two dims are normalized.
        n_iters: number of row/column normalization rounds in log space.
        eps: lower clamp for the entries and additive guard in the final
            divisions.

    Returns:
        Tensor with the same shape as ``log_alpha`` and strictly positive
        entries.
    """
    log_p = log_alpha
    for _ in range(n_iters):
        # Rows (dim=-1) first, then columns (dim=-2).
        for axis in (-1, -2):
            log_p = log_p - torch.logsumexp(log_p, dim=axis, keepdim=True)

    # Leave log space; the clamp keeps every entry away from exact zero.
    perm = log_p.exp().clamp(eps, 1.0)
    for axis in (-1, -2):
        perm = perm / (perm.sum(dim=axis, keepdim=True) + eps)
    return perm

def pairwise_cosine_topk(z, k):
    """Return, for every row of ``z``, the indices of its ``k`` most similar other rows.

    Similarity is the plain dot product between rows, which equals cosine
    similarity only when the caller has L2-normalized the rows of ``z``.

    Args:
        z: 2-D tensor ``[B, d]`` of feature rows.
        k: requested number of neighbours; clamped into ``[0, B - 1]``.

    Returns:
        Index tensor ``[B, k']`` with ``k' = max(0, min(k, B - 1))``; a row
        never lists itself as a neighbour.
    """
    n_rows = z.shape[0]
    sim = z @ z.T  # [B, B] pairwise dot products
    # Mask self-similarity with -inf rather than -1.0: a finite sentinel can
    # tie with (or beat) real similarities when rows are antipodal or not
    # normalized, which would let a sample be selected as its own neighbour.
    sim.fill_diagonal_(float("-inf"))
    k_eff = max(0, min(k, n_rows - 1))
    return torch.topk(sim, k=k_eff, dim=1).indices

class QNet(nn.Module):
    """Q-network split into a feature encoder and an action-value head.

    ``features`` maps observations to ``feat_dim``-dimensional features,
    ``head_from_features`` maps features to ``act_dim`` Q-values, and
    ``forward`` composes the two.
    """

    def __init__(self, obs_dim, act_dim, feat_dim=64):
        super().__init__()
        self.feat_dim = feat_dim
        self.act_dim = act_dim

        hidden = 128
        # Layers are instantiated encoder-first, then head, so parameter
        # initialization consumes the RNG in the same order as before.
        encoder_layers = [
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, feat_dim),
            nn.ReLU(),
            nn.LayerNorm(feat_dim),
        ]
        head_layers = [
            nn.Linear(feat_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, act_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.head = nn.Sequential(*head_layers)

    def features(self, x):
        """Encode observations ``x`` into features."""
        return self.encoder(x)

    def head_from_features(self, z):
        """Map features ``z`` to one Q-value per action."""
        return self.head(z)

    def forward(self, x):
        """Return Q-values for observations ``x``."""
        feats = self.features(x)
        return self.head_from_features(feats)

@dataclass
class SALOONConfig:
    """Hyper-parameters for the SALOON agent.

    ``obs_dim`` and ``act_dim`` are required; every other field has a default.
    """
    # --- network / DQN basics ---
    obs_dim: int                     # observation size (input of QNet's first linear layer)
    act_dim: int                     # number of discrete actions (Q-value outputs)
    gamma: float = 0.99              # discount factor in the TD target
    lr: float = 3e-4                 # Adam learning rate
    target_tau: float = 0.005        # mixing coefficient of the soft target-network update
    eps_start: float = 1.0           # epsilon-greedy exploration rate at step 0
    eps_end: float = 0.01            # exploration rate once the decay has finished
    eps_decay_steps: int = 10_000    # length of the linear epsilon decay; also the horizon of the other schedules
    batch_size: int = 64             # not read by the code visible here; presumably consumed by the training loop
    device: str = "cpu"              # device for networks, structure parameters and batches

    # --- learned structure: feature transforms W_k and action-relabel logits L_k ---
    K: int = 2                       # number of (W_k, L_k) pairs; W_0 is initialized to the identity
    feat_dim: int = 64               # encoder output size; each W_k is feat_dim x feat_dim
    t_init: float = 2.0              # Sinkhorn temperature at step 0 (L_k is divided by the temperature)
    t_final: float = 0.5             # Sinkhorn temperature after cosine annealing
    sinkhorn_iters: int = 3          # log-space normalization rounds per Sinkhorn call

    # --- preference-edge construction ---
    tau: float = 0.5                 # temperature of the sigmoid applied to the TD gap when weighting samples
    k_nn: int = 5                    # neighbours per sample in the feature-similarity graph
    thr_w: float = 0.4               # minimum weight for a sample to be considered as an edge seed
    p_keep: float = 0.2              # fraction of above-threshold samples kept as seeds (at least one)

    # isotonic inner loop
    T_iso: int = 1                   # number of inner correction steps
    eta_iso: float = 0.1             # step size of each correction step
    delta_margin: float = 0.01       # required margin between V[u] and V[v] for an edge (u, v)

    # soft-closure sharpness (softmax scale), ramped linearly from mu_init to mu_final
    mu_init: float = 1.0
    mu_final: float = 5.0

    # --- loss weights ---
    lambda_eq: float = 0.5           # final weight of the equivariance loss (ramped after warmup_steps_eq)
    gamma_grp: float = 0.02          # weight of R_id + R_ortho + R_clo
    lambda_iso: float = 0.1          # base weight of the isotonic loss (scaled by two schedules)
    beta: float = 0.0                # weight of the rank loss, which `update` currently always sets to zero
    lambda_perm: float = 0.02        # final weight of the permutation-entropy term
    lambda_div: float = 0.01         # weight of the diversity term R_div

    # warm-ups: step at which the corresponding weight starts ramping up from zero
    warmup_steps_eq: int = 5_000     # also gates whether the equivariance loss is computed at all
    warmup_steps_iso: int = 10_000

    # --- scheduling of the auxiliary work (all tested as `step % value == 0`) ---
    iso_every: int = 10              # isotonic loss is computed on these steps
    knn_every: int = 10              # on isotonic steps, the neighbour cache is cleared on these steps
    group_reg_every: int = 10        # group regularizers are computed on these steps
    qr_every: int = 20               # QR re-orthonormalization of W_k is considered on these steps
    qr_thresh: float = 1e-2          # re-orthonormalize when ||W^T W - I||_F / feat_dim reaches this value
    eq_subsample_frac: float = 0.5   # fraction of the batch used for the equivariance loss ...
    eq_subsample_min: int = 16       # ... but at least this many samples ...
    eq_subsample_max: int = 128      # ... and at most this many (and never more than the batch)
    struct_every: int = 5            # W_k and L_k only receive gradients on these steps

    

# ----------------- SALOON -----------------
class SALOON:
    def __init__(self, cfg: SALOONConfig):
        """Create the online/target networks, the structure parameters and the optimizer.

        Args:
            cfg: hyper-parameters; ``obs_dim``, ``act_dim``, ``feat_dim``,
                ``K``, ``lr`` and ``device`` are read here.
        """
        self.cfg = cfg
        dev = cfg.device
        d, n_act = cfg.feat_dim, cfg.act_dim

        # Online network plus a target copy that starts from identical weights.
        self.q = QNet(cfg.obs_dim, n_act, feat_dim=d).to(dev)
        self.q_tgt = QNet(cfg.obs_dim, n_act, feat_dim=d).to(dev)
        self.q_tgt.load_state_dict(self.q.state_dict())

        # Feature-space transforms W_k: W_0 starts as the identity, every other
        # W_k as the Q factor of a random Gaussian matrix.
        self.W = nn.ParameterList()
        for k in range(cfg.K):
            if k == 0:
                init = torch.eye(d, device=dev)
            else:
                init, _ = torch.linalg.qr(torch.randn(d, d, device=dev), mode='reduced')
            self.W.append(nn.Parameter(init))

        # Action-relabel logits L_k start near the identity permutation:
        # +2.5 on the diagonal, -2.5 everywhere else.
        near_identity = torch.eye(n_act, device=dev) * 5.0 - 2.5
        self.L = nn.ParameterList()
        for _ in range(cfg.K):
            self.L.append(nn.Parameter(near_identity.clone()))

        # One optimizer over the network and both structure parameter lists.
        trainable = [*self.q.parameters(), *self.W.parameters(), *self.L.parameters()]
        self.opt = Adam(trainable, lr=cfg.lr)

        self._step = 0          # shared counter driving epsilon and all schedules
        self._nbr_cache = None  # cached neighbour indices for preference edges

    # -------- policy --------
    def act(self, obs_np, eval=False):
        self._step += 1
        eps = self.epsilon()
        if eval:
            eps = 0.0
        if np.random.rand() < eps:
            return np.random.randint(0, self.cfg.act_dim)
        with torch.no_grad():
            q = self.q(torch.tensor(obs_np, dtype=torch.float32, device=self.cfg.device).unsqueeze(0))
            return int(q.argmax(-1).item())

    def epsilon(self):
        t = min(self._step, self.cfg.eps_decay_steps)
        ratio = 1.0 - t / self.cfg.eps_decay_steps
        return self.cfg.eps_end + (self.cfg.eps_start - self.cfg.eps_end) * max(0.0, ratio)

    # -------- schedules --------
    def _temp(self):
        s = min(1.0, self._step / max(1, self.cfg.eps_decay_steps))
        return self.cfg.t_final + 0.5*(self.cfg.t_init - self.cfg.t_final)*(1+math.cos(math.pi*s))

    def _mu(self):
        s = min(1.0, self._step / max(1, self.cfg.eps_decay_steps))
        return self.cfg.mu_init + (self.cfg.mu_final - self.cfg.mu_init)*s

    def _lambda_iso(self):
        s = min(1.0, self._step / max(1, self.cfg.eps_decay_steps))
        return self.cfg.lambda_iso * (0.5 + 0.5*s)

    def _lambda_eq(self):
        s = min(1.0, max(0, self._step - self.cfg.warmup_steps_eq) / max(1, self.cfg.eps_decay_steps))
        return self.cfg.lambda_eq * s

    def _lambda_iso_runtime(self):
        s = min(1.0, max(0, self._step - self.cfg.warmup_steps_iso) / max(1, self.cfg.eps_decay_steps))
        return self._lambda_iso() * s

    def _lambda_perm_runtime(self):
        s = min(1.0, max(0, self._step - self.cfg.warmup_steps_eq) / max(1, self.cfg.eps_decay_steps))
        return self.cfg.lambda_perm * s


    # -------- losses --------
    def _td_target(self, rew, next_obs, done):

        with torch.no_grad():
            q_next_online = self.q(next_obs)                 
            a_star = q_next_online.argmax(dim=-1, keepdim=True)
            q_next_tgt = self.q_tgt(next_obs).gather(1, a_star).squeeze(1)
            y = rew + (1.0 - done) * self.cfg.gamma * q_next_tgt
        return y  # [B]

    def _equivariance(self, z, Qs, mask=None):
        """Equivariance loss and permutation entropy, averaged over the K transforms.

        For each ``k`` the features are mapped by ``W_k`` and passed through
        the Q head; the resulting action values are mixed by the Sinkhorn
        matrix ``P_k`` built from ``L_k / t`` and compared to ``Qs`` with an
        MSE.  The entropy term sums the mean row entropy and the mean column
        entropy of ``P_k``.

        Args:
            z: [B,d] features.
            Qs: [B,A] Q-values from the current head applied to ``z``.
            mask: accepted but not used.

        Returns:
            Tuple ``(eq, perm_entropy)``.
        """
        temperature = self._temp()
        n_transforms = self.cfg.K
        log_floor = 1e-8  # keeps log() finite

        eq_total = 0.0
        ent_total = 0.0
        for k in range(n_transforms):
            P_k = sinkhorn(self.L[k] / temperature, n_iters=self.cfg.sinkhorn_iters)  # [A,A]
            q_transformed = self.q.head_from_features(z @ self.W[k].T)  # [b,A]
            eq_total += F.mse_loss(Qs, q_transformed @ P_k.T)

            # The elementwise p*log(p) is shared by the row and column entropies.
            plogp = P_k * (P_k + log_floor).log()
            row_ent = -plogp.sum(dim=-1).mean()
            col_ent = -plogp.sum(dim=-2).mean()
            ent_total = ent_total + row_ent + col_ent

        denom = max(1, n_transforms)
        return eq_total / denom, ent_total / denom
        

    def _group_regs(self):
        """R_id, R_inv, R_ortho, R_clo (soft closure), R_div"""
        K, d = self.cfg.K, self.cfg.feat_dim
        I = torch.eye(d, device=self.cfg.device)

        R_id = torch.norm(self.W[0] - I, p='fro')**2

        R_ortho = 0.0
        for k in range(K):
            Wk = self.W[k]
            R_ortho = R_ortho + torch.norm(Wk.T @ Wk - I, p='fro')**2
        R_ortho = R_ortho / max(1, K)

        mu = self._mu()
        R_clo = 0.0
        for i in range(K):
            for j in range(K):
                comp = self.W[i] @ self.W[j]
                dists = []
                for m in range(K):
                    d2 = torch.norm(comp - self.W[m], p='fro')**2
                    dists.append(d2)
                D = torch.stack(dists)  # [K]
                w = torch.softmax(-mu * D.detach(), dim=0)
                R_clo = R_clo + (w * D).sum()
        R_clo = R_clo / max(1, K*K)

        R_div = 0.0
        cnt = 0
        for i in range(K):
            for j in range(i+1, K):
                R_div = R_div + torch.exp(-torch.norm(self.W[i] - self.W[j], p='fro')**2)
                cnt += 1
        if cnt > 0:
            R_div = R_div / cnt

        return R_id, R_ortho, R_clo, R_div
       

    def _build_pref_edges(self, z, Delta, y, Qnext_var):

        B = z.shape[0]
        if B <= 1:
            return []

        with torch.no_grad():
            Delta = Delta.detach()
            Qnext_var = Qnext_var.detach()
            w = torch.sigmoid(Delta / max(1e-6, self.cfg.tau)) * torch.exp(-Qnext_var)
            keep = (Delta > 0.0) & (w > self.cfg.thr_w)
            idx_all = torch.arange(B, device=z.device)
            idx_keep = idx_all[keep]
            if idx_keep.numel() == 0:
                return []

            # top-p% by w
            Kp = max(1, int(self.cfg.p_keep * idx_keep.numel()))
            w_keep = w[idx_keep]
            top_idx = torch.topk(w_keep, k=Kp).indices
            seeds = idx_keep[top_idx]  # [Kp]

            # normalize features
            zn = F.normalize(z, dim=-1)
            if (self._nbr_cache is None) or (self._nbr_cache.shape[0] != B):
                nbrs = pairwise_cosine_topk(zn, k=self.cfg.k_nn)  # [B, k]
                self._nbr_cache = nbrs
            else:
                nbrs = self._nbr_cache

            edges = []
            for u in seeds.tolist():
                cand = nbrs[u]
                for v in cand.tolist():
                    if Delta[u] > Delta[v]:
                        edges.append((u, v))
            edges = list(dict.fromkeys(edges))
        return edges
        

    def _isotonic_inner(self, V, edges):

        if len(edges) == 0:
            return V.detach()
        V_hat = V.detach().clone()
        if self.cfg.T_iso <= 0:
            return V_hat
        u = torch.tensor([e[0] for e in edges], device=V.device, dtype=torch.long)
        v = torch.tensor([e[1] for e in edges], device=V.device, dtype=torch.long)
        for _ in range(self.cfg.T_iso):
            viol = self.cfg.delta_margin + V_hat[v] - V_hat[u]   # [E]
            m = (viol > 0).float()
            if m.sum() == 0:
                break
            grad = torch.zeros_like(V_hat)
            scale = 1.0 / max(1, len(edges))
            grad.index_add_(0, u, -m * scale)
            grad.index_add_(0, v,  m * scale)
            V_hat = V_hat - self.cfg.eta_iso * grad
        return V_hat.detach()
        

    def update(self, batch, env=None):
        obs, act, rew, next_obs, done = (batch["obs"].to(self.cfg.device),
                                         batch["act"].to(self.cfg.device),
                                         batch["rew"].to(self.cfg.device),
                                         batch["next_obs"].to(self.cfg.device),
                                         batch["done"].to(self.cfg.device))
        
        # TD targets
        y = self._td_target(rew, next_obs, done)          # [B]
        Qnext = self.q_tgt(next_obs)                      # [B,A]
        Qnext_var = Qnext.var(dim=-1, unbiased=False)     # [B]

        # Current preds
        z = self.q.features(obs)                          # [B,d]
        Qs = self.q.head_from_features(z)                 # [B,A]
        V = Qs.max(-1).values                             # [B]
        q_taken = Qs.gather(1, act.view(-1,1)).squeeze(1) # [B]
        td_loss = F.mse_loss(q_taken, y)

       
        update_struct = (self._step % self.cfg.struct_every == 0)
        for Wk in self.W: Wk.requires_grad_(update_struct)
        for Lk in self.L: Lk.requires_grad_(update_struct)

        if update_struct and (self._step >= self.cfg.warmup_steps_eq):
            B = obs.shape[0]
            m = max(self.cfg.eq_subsample_min, int(B * self.cfg.eq_subsample_frac))
            m = min(m, self.cfg.eq_subsample_max, B)
            idx_sorted = torch.argsort(Qnext_var)[:m]
            eq_loss, perm_entropy = self._equivariance(z.index_select(0, idx_sorted),
                                                       Qs.index_select(0, idx_sorted))
        else:
            eq_loss = z.new_tensor(0.0); perm_entropy = z.new_tensor(0.0)

        if self._step % self.cfg.group_reg_every == 0:
            R_id, R_ortho, R_clo, R_div = self._group_regs()
        else:
            R_id = z.new_tensor(0.0)
            R_ortho = z.new_tensor(0.0)
            R_clo = z.new_tensor(0.0)
            R_div = z.new_tensor(0.0)
        group_reg = R_id + R_ortho + R_clo

        if self._step % self.cfg.iso_every == 0:
            if self._step % self.cfg.knn_every == 0:
                self._nbr_cache = None  
            Delta_taken = y - q_taken.detach()
            edges = self._build_pref_edges(z, Delta_taken, y, Qnext_var)
            V_hat = self._isotonic_inner(V, edges)
            iso_loss = F.mse_loss(V, V_hat)
            rank_loss = z.new_tensor(0.0)
        else:
            iso_loss = z.new_tensor(0.0)
            rank_loss = z.new_tensor(0.0)

        # === Total
        total_loss = td_loss \
             + self._lambda_eq() * eq_loss \
             + self.cfg.gamma_grp * group_reg \
             + self._lambda_perm_runtime() * perm_entropy \
             + self.cfg.lambda_div * R_div \
             + self._lambda_iso_runtime() * iso_loss \
             + self.cfg.beta * rank_loss

        self.opt.zero_grad(set_to_none=True)
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q.parameters(), 10.0)
        self.opt.step()

        if self._step % self.cfg.qr_every == 0:
            with torch.no_grad():
                I = torch.eye(self.cfg.feat_dim, device=self.cfg.device)
                for Wk in self.W:
                    dev = torch.norm(Wk.T @ Wk - I, p='fro') / self.cfg.feat_dim
                    if dev >= self.cfg.qr_thresh:
                        q, _ = torch.linalg.qr(Wk.data, mode='reduced')
                        Wk.data.copy_(q)

        # ---- target update ----
        with torch.no_grad():
            for p, p_tgt in zip(self.q.parameters(), self.q_tgt.parameters()):
                p_tgt.data.mul_(1 - self.cfg.target_tau).add_(self.cfg.target_tau * p.data)

        self._step += 1

        return {
            "td": float(td_loss.item()),
            "eq": float(eq_loss.item()),
            "iso": float(iso_loss.item()),
            "rank": float(rank_loss.item()),
            "grp": float(group_reg.item()),
            "perm": float(perm_entropy.item()),
            "div": float(R_div.item()),
            "total": float(total_loss.item()),
        }

       