import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from latentplan.models.autoencoders import SymbolWiseTransformer
from latentplan.models.transformers import *
from latentplan.models.ein import EinLinear

# -----------------------------
# Low-level VQ ops (unchanged)
# -----------------------------
class VectorQuantization(Function):
    """Nearest-codebook-entry lookup that returns indices only.

    Call it through ``VectorQuantization.apply(inputs, codebook)``. The
    returned indices are marked non-differentiable and ``backward`` always
    raises, so this op cannot be part of a differentiated path.
    """

    @staticmethod
    def forward(ctx, inputs, codebook):  # inputs: [..., D], codebook: [K, D]
        """Return the index of the closest code for every vector in ``inputs``.

        Args:
            ctx: autograd context supplied by ``Function.apply``.
            inputs: tensor of shape ``[..., D]``; it does not have to be
                contiguous.
            codebook: tensor of shape ``[K, D]``.

        Returns:
            Integer tensor of shape ``inputs.shape[:-1]``; each entry is the
            row of ``codebook`` with the smallest squared Euclidean distance
            to the corresponding input vector.
        """
        with torch.no_grad():
            D = codebook.size(1)
            lead_shape = inputs.shape[:-1]
            # reshape (not view): view raises for inputs whose leading axes
            # cannot be merged without a copy, e.g. a transposed batch.
            x = inputs.reshape(-1, D)
            cb_sqr = torch.sum(codebook ** 2, dim=1)                 # [K]
            x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)           # [N,1]
            # dist = ||x||^2 + ||e||^2 - 2 x·e
            distances = torch.addmm(cb_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
            _, idx_flat = torch.min(distances, dim=1)                # [N]
            indices = idx_flat.view(lead_shape)                      # [...]
            ctx.mark_non_differentiable(indices)
            return indices

    @staticmethod
    def backward(ctx, grad_output):
        """Always raise: an index lookup has no gradient."""
        raise RuntimeError("VectorQuantization is not differentiable.")

class VectorQuantizationStraightThrough(Function):
    """Codebook lookup whose gradient is passed straight through to the input.

    Forward returns the selected code vectors (shaped like ``inputs``) and the
    flat indices. Backward copies the incoming gradient to ``inputs`` and
    scatter-adds it onto the selected codebook rows.
    """

    @staticmethod
    def forward(ctx, inputs, codebook):  # codebook: Tensor [K,D]
        """Return ``(codes, flat_indices)`` for ``inputs`` of shape ``[..., D]``."""
        flat_ids = vq(inputs, codebook).view(-1)          # nearest-code index per vector
        ctx.save_for_backward(flat_ids, codebook)
        ctx.mark_non_differentiable(flat_ids)
        selected = torch.index_select(codebook, dim=0, index=flat_ids)
        return selected.view_as(inputs), flat_ids

    @staticmethod
    def backward(ctx, grad_output, grad_indices):
        """Return ``(grad_inputs, grad_codebook)``; each is None when not required."""
        wants_inputs = ctx.needs_input_grad[0]
        wants_codebook = ctx.needs_input_grad[1]

        # Straight-through estimator: the input receives the code gradient unchanged.
        grad_inputs = None
        if wants_inputs:
            grad_inputs = grad_output.clone()

        grad_codebook = None
        if wants_codebook:
            flat_ids, codebook = ctx.saved_tensors
            per_vector = grad_output.contiguous().view(-1, codebook.size(1))
            # Accumulate every vector's gradient onto the code row it selected.
            grad_codebook = torch.zeros_like(codebook)
            grad_codebook.index_add_(0, flat_ids, per_vector)

        return grad_inputs, grad_codebook

# Autograd `Function` subclasses are invoked through `.apply`; these aliases
# give the two ops plain function-call syntax:
#   vq(inputs, codebook)    -> indices shaped like inputs.shape[:-1]
#   vq_st(inputs, codebook) -> (codes shaped like inputs, flat indices)
vq = VectorQuantization.apply
vq_st = VectorQuantizationStraightThrough.apply


# --------------------------------------------------------
# EMA codebook that supports residual + masked/unmasked
# --------------------------------------------------------
class VQEmbeddingEMA(nn.Module):
    """
    Codebook maintained by exponential moving averages, with straight-through lookup.

    All state lives in buffers (nothing here is an optimizer parameter):
    - embedding: [K, D] current code vectors
    - ema_count: [K]    running usage count per code
    - ema_w:     [K, D] running sum of the vectors assigned to each code
    """
    def __init__(self, K, D, decay=0.99, eps=1e-5):
        super().__init__()
        init_codes = torch.empty(K, D).uniform_(-1. / K, 1. / K)
        self.decay = decay
        self.eps = eps
        self.register_buffer("embedding", init_codes)          # [K,D]
        self.register_buffer("ema_count", torch.ones(K))
        self.register_buffer("ema_w", init_codes.clone())

    @torch.no_grad()
    def _ema_update(self, indices, x_update):
        """Fold one batch of assignments into the running statistics.

        indices:  code index per vector, any shape with as many entries as
                  x_update has vectors
        x_update: [..., D] vectors whose per-code sums feed the average
        """
        num_codes, dim = self.embedding.size()
        flat_idx = indices.reshape(-1)
        flat_x = x_update.reshape(-1, dim)
        keep = self.decay
        fresh = 1 - self.decay

        # usage counts
        hits = torch.bincount(flat_idx, minlength=num_codes).float()
        self.ema_count.mul_(keep).add_(fresh * hits)

        # per-code vector sums
        sums = torch.zeros_like(self.ema_w).index_add_(0, flat_idx, flat_x)   # [K,D]
        self.ema_w.mul_(keep).add_(fresh * sums)

        # codes = running sum / (running count + eps); written in place so the
        # buffer object itself is kept
        normaliser = (self.ema_count + self.eps).unsqueeze(-1)
        self.embedding.data.copy_(self.ema_w / normaliser)

    def straight_through(self, x_assign, x_update):
        """
        Quantise x_assign and, in training mode, update the codebook from x_update.

        x_assign: [..., D] tensor used to pick indices; receives the
                  straight-through gradient of z_q
        x_update: [..., D] tensor fed to the EMA statistics
        Returns: z_q (straight-through codes), z_bar (plain lookup, no gradient
                 path to x_assign), indices shaped like x_assign.shape[:-1]
        """
        lead_shape = x_assign.shape[:-1]
        z_q, flat_idx = vq_st(x_assign.contiguous(), self.embedding)
        z_q = z_q.contiguous()
        if self.training:
            self._ema_update(flat_idx.view(*lead_shape), x_update)
        # Looked up after the EMA step, so in training z_bar reads the refreshed codes.
        z_bar = self.embedding.index_select(0, flat_idx).view_as(x_assign).contiguous()
        return z_q, z_bar, flat_idx.view(*lead_shape)


class VQEmbedding(nn.Module):
    """Non-EMA codebook; kept for completeness (not used if ma_update=True).

    The codes are the weight of an ``nn.Embedding`` and are therefore ordinary
    trainable parameters.
    """
    def __init__(self, K, D):
        super().__init__()
        self.embedding = nn.Embedding(K, D)
        self.embedding.weight.data.uniform_(-1./K, 1./K)

    def straight_through(self, x_assign, x_update=None):
        """
        Quantise x_assign against the codebook.

        x_assign: [..., D] tensor used to pick indices; receives the
                  straight-through gradient of z_q
        x_update: accepted and ignored. It exists only so this method can be
                  called with the same two arguments as
                  VQEmbeddingEMA.straight_through; there are no EMA statistics
                  to update here.
        Returns: z_q (straight-through codes; the codebook is detached on this
                 path), z_bar (plain lookup that does carry gradient to the
                 embedding weight), indices shaped like x_assign.shape[:-1]
        """
        z_q, idx_flat = vq_st(x_assign.contiguous(), self.embedding.weight.detach())
        z_q = z_q.contiguous()
        z_bar = torch.index_select(self.embedding.weight, 0, idx_flat).view_as(x_assign).contiguous()
        return z_q, z_bar, idx_flat.view(*x_assign.shape[:-1])


class VQStepWiseTransformer(nn.Module):
    """Step-wise transformer autoencoder with a residual VQ bottleneck.

    A trajectory ``[B, T, encoder_input_dim]`` is embedded, run through the
    encoder blocks, reduced along time (max-pooling with window ``latent_step``
    or an ``AsymBlock``), projected to ``trajectory_embd`` and quantised
    ``n_levels`` times against ONE shared codebook, each level quantising the
    residual left by the previous level. The decoder maps the summed codes,
    together with a conditioning state, back to per-step predictions.
    """

    def __init__(self, config, feature_dim):
        """Build encoder, decoder, bottleneck and the shared codebook.

        Args:
            config: object providing ``K``, ``trajectory_embd``, ``n_embd``,
                ``block_size``, ``action_dim``, ``transition_dim``,
                ``latent_step``, ``state_conditional``, ``n_layer`` and
                ``embd_pdrop``; the remaining options are read with
                ``getattr`` and have defaults (see below).
            feature_dim: size of the conditioning state vector.

        Raises:
            ValueError: if ``config.bottleneck`` is neither ``"pooling"`` nor
                ``"attention"``.
        """
        super().__init__()
        self.K = config.K
        self.latent_size = config.trajectory_embd
        self.embedding_dim = config.n_embd
        self.trajectory_length = config.block_size
        self.block_size = config.block_size
        self.observation_dim = feature_dim
        self.action_dim = config.action_dim
        self.transition_dim = config.transition_dim
        self.latent_step = config.latent_step
        self.state_conditional = config.state_conditional

        self.masking = getattr(config, "masking", "none")
        self.bottleneck = getattr(config, "bottleneck", "pooling")
        self.n_levels = getattr(config, "n_levels", 2)
        self.ma_update = getattr(config, "ma_update", True)
        self.assign_from = getattr(config, "assign_from", "masked")   # "masked" or "full"
        self.update_from = getattr(config, "update_from", "full")     # "masked" or "full"
        self.commitment_beta = getattr(config, "commitment_beta", 0.25)

        # optionally append terminals as an extra encoder input channel
        self.use_term_channel = getattr(config, "use_term_channel", True)
        self.encoder_input_dim = self.transition_dim

        self.encoder = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        self.decoder = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        self.pos_emb = nn.Parameter(torch.zeros(1, self.trajectory_length, self.embedding_dim))

        # project per-step inputs -> model width
        self.embed = nn.Linear(self.encoder_input_dim, self.embedding_dim)
        self.predict = nn.Linear(self.embedding_dim, self.transition_dim)
        self.cast_embed = nn.Linear(self.embedding_dim, self.latent_size)
        self.latent_mixing = nn.Linear(self.latent_size + self.observation_dim, self.embedding_dim)

        if self.bottleneck == "pooling":
            self.latent_pooling = nn.MaxPool1d(self.latent_step, stride=self.latent_step)
        elif self.bottleneck == "attention":
            self.latent_pooling = AsymBlock(config, self.trajectory_length // self.latent_step)
            self.expand = AsymBlock(config, self.trajectory_length)
        else:
            raise ValueError(f"Unknown bottleneck type {self.bottleneck}")

        self.ln_f = nn.LayerNorm(self.embedding_dim)
        self.drop = nn.Dropout(config.embd_pdrop)

        # --- shared residual codebook across depths (paper setup) ---
        # One codebook serves every residual level; see `codebooks` for a
        # per-level view.
        if self.ma_update:
            self.codebook = VQEmbeddingEMA(
                self.K, self.latent_size,
                decay=getattr(config, "ema_decay", 0.99),
                eps=getattr(config, "ema_eps", 1e-5)
            )
        else:
            self.codebook = VQEmbedding(self.K, self.latent_size)

        # optional: keep old behavior by cutting gradients between levels
        # NOTE(review): this flag is stored but not consulted by forward(),
        # which always detaches the residual between levels.
        self.stop_grad_between_levels = getattr(config, "stop_grad_between_levels", False)

    @property
    def codebooks(self):
        """Per-level view of the shared codebook.

        Returns a list holding the same ``self.codebook`` module ``n_levels``
        times, so code written as ``for cb in model.codebooks`` or
        ``model.codebooks[l]`` keeps working with the shared-codebook setup.
        """
        return [self.codebook] * self.n_levels

    def codebook_weight(self):
        """Return the tensor [K, latent_size] for tying with the AR prior."""
        if isinstance(self.codebook, VQEmbeddingEMA):
            return self.codebook.embedding
        else:
            return self.codebook.embedding.weight

    def _append_terminals(self, inputs, terminals):
        """Concatenate ``terminals`` as a last channel when the model is configured for it."""
        if not self.use_term_channel:
            return inputs
        assert terminals is not None, "terminals must be provided when use_term_channel=True"
        return torch.cat([inputs, terminals], dim=2)

    def _encode_core(self, enc_inputs):
        """Encode ``enc_inputs`` of shape [B, T, encoder_input_dim] to latents [B, T', latent_size]."""
        x = enc_inputs.to(dtype=torch.float32)
        B, T, _ = x.size()
        assert T <= self.block_size, "Block size exhausted."

        tok = self.embed(x)                              # [B,T,emb]
        pos = self.pos_emb[:, :T, :]
        h = self.drop(tok + pos)
        h = self.encoder(h)
        if self.bottleneck == "pooling":
            # MaxPool1d pools the last axis, so time is moved there and back
            h = self.latent_pooling(h.transpose(1, 2)).transpose(1, 2)  # [B,T',emb]
        else:
            h = self.latent_pooling(h)                                   # [B,T',emb]
        z = self.cast_embed(h)                                           # [B,T',latent_size]
        return z

    def encode(self, joined_inputs, terminals=None):
        """
        Convenience wrapper when you want to call encode directly.
        If use_term_channel==True, terminals must be provided, and we will concat them.
        Returns the continuous (un-quantised) latents [B, T', latent_size].
        """
        return self._encode_core(self._append_terminals(joined_inputs, terminals))

    def decode(self, latents, state):
        """Decode latents [B, T', latent_size] into per-step predictions.

        Args:
            latents: summed code vectors.
            state: conditioning state; it is reshaped to [B, 1, -1] and
                replaced by zeros when ``state_conditional`` is False.

        Returns:
            Tensor [B, T, transition_dim]. The last channel has been passed
            through a sigmoid, and ``state`` has been added to channels
            ``1 .. observation_dim``.
        """
        B, Tprime, _ = latents.shape
        state_flat = state.view(B, 1, -1).repeat(1, Tprime, 1)
        if not self.state_conditional:
            state_flat = torch.zeros_like(state_flat)
        x = torch.cat([state_flat, latents], dim=-1)
        x = self.latent_mixing(x)
        if self.bottleneck == "pooling":
            # undo the temporal pooling by repeating every latent step
            x = torch.repeat_interleave(x, self.latent_step, dim=1)
        else:
            x = self.expand(x)
        x = x + self.pos_emb[:, :x.shape[1]]
        x = self.decoder(x)
        x = self.ln_f(x)
        out = self.predict(x)                              # [B,T,transition_dim]
        out[:, :, -1] = torch.sigmoid(out[:, :, -1])       # terminals
        # the observation channels are predicted as offsets from the conditioning state
        out[:, :, 1:self.observation_dim+1] += state.view(B, 1, -1)
        return out

    def forward(self, joined_inputs, state, terminals=None):
        """Encode, residually quantise and decode a batch of trajectories.

        Two encoder passes are made: one on the inputs as given ("full") and
        one with channel 0 zeroed ("masked"). ``assign_from`` picks the stream
        that selects codes and ``update_from`` the stream handed to the
        codebook for its EMA update.

        Args:
            joined_inputs: [B, T, D] trajectories.
            state: conditioning state passed to ``decode``.
            terminals: [B, T, 1]; required when ``use_term_channel`` is True.

        Returns:
            ``(joined_pred, quantised, feat_full, feat_masked)`` where
            ``quantised`` is the list of per-level straight-through codes and
            ``joined_pred`` is decoded from their sum.
        """
        # mask value (channel 0)
        feat_mask = torch.ones_like(joined_inputs)
        feat_mask[:, :, 0] = 0.0
        x_masked = joined_inputs * feat_mask

        # Build encoder inputs with/without terminal channel
        enc_full = self._append_terminals(joined_inputs, terminals)
        enc_masked = self._append_terminals(x_masked, terminals)

        feat_full   = self._encode_core(enc_full)      # [B,T',D_lat]
        feat_masked = self._encode_core(enc_masked)    # [B,T',D_lat]

        # choose assignment/update streams
        assign_in = feat_masked if self.assign_from == "masked" else feat_full
        update_in = feat_full  if self.update_from  == "full"   else feat_masked

        # residual VQ: every level quantises what the previous level left over,
        # always against the single shared codebook
        uses_ema = isinstance(self.codebook, VQEmbeddingEMA)
        res_assign = assign_in
        res_update = update_in
        quantised, z_bars = [], []
        for _level in range(self.n_levels):
            if uses_ema:
                z_q, z_bar, _idx = self.codebook.straight_through(res_assign, res_update)
            else:
                # the non-EMA codebook keeps no running statistics to update
                z_q, z_bar, _idx = self.codebook.straight_through(res_assign)
            quantised.append(z_q)
            z_bars.append(z_bar)
            res_assign = (res_assign - z_bar).detach()
            res_update = (res_update - z_bar).detach()

        latents_st = torch.stack(quantised, dim=0).sum(0)  # [B,T',D_lat]
        joined_pred = self.decode(latents_st, state)
        return joined_pred, quantised, feat_full, feat_masked



class VQContinuousVAE(nn.Module):
    """Training and inference wrapper around ``VQStepWiseTransformer``.

    Adds terminal padding of the inputs, the reconstruction / consistency /
    commitment losses, the weight-decay parameter split for AdamW, and helpers
    that convert trajectories to code indices and back.
    """

    def __init__(self, config):
        super().__init__()
        self.model = VQStepWiseTransformer(config, config.observation_dim)
        self.trajectory_embd = config.trajectory_embd
        self.vocab_size = config.vocab_size
        self.block_size = config.block_size
        self.observation_dim = config.observation_dim
        self.masking = getattr(config, "masking", "none")
        self.action_dim = config.action_dim
        self.trajectory_length = config.block_size
        self.transition_dim = config.transition_dim
        self.action_weight = config.action_weight
        self.reward_weight = config.reward_weight
        self.value_weight = config.value_weight
        self.position_weight = config.position_weight
        self.first_action_weight = config.first_action_weight
        self.sum_reward_weight = config.sum_reward_weight
        self.last_value_weight = config.last_value_weight
        self.latent_step = config.latent_step
        # value written into steps flagged as terminal; see set_padding_vector
        self.padding_vector = torch.zeros(self.transition_dim)
        self.apply(self._init_weights)

    def get_block_size(self): return self.block_size
    def set_padding_vector(self, padding): self.padding_vector = padding

    def _init_weights(self, module):
        """Normal(0, 0.02) for Linear/Embedding weights, zero biases, identity LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_(); module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        """Return an AdamW optimizer with weight decay on Linear/EinLinear weights only.

        Biases, LayerNorm/Embedding weights, the positional embedding and (for
        the attention bottleneck) the query / in-projection parameters are put
        in the no-decay group. Asserts that every parameter lands in exactly
        one group.
        """
        decay, no_decay = set(), set()
        whitelist = (torch.nn.Linear, EinLinear)
        blacklist = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = f"{mn}.{pn}" if mn else pn
                if pn.endswith("bias"): no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist): decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist): no_decay.add(fpn)
        if isinstance(self.model, (SymbolWiseTransformer, VQStepWiseTransformer)):
            no_decay.add('model.pos_emb')
            if self.model.bottleneck == "attention":
                no_decay.update({'model.latent_pooling.query', 'model.expand.query',
                                 'model.latent_pooling.attention.in_proj_weight',
                                 'model.expand.attention.in_proj_weight'})
        param_dict = {pn: p for pn, p in self.named_parameters()}
        assert len(decay & no_decay) == 0
        assert len(param_dict.keys() - (decay | no_decay)) == 0
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]
        return torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)

    def _pad_terminals(self, x, terminals):
        """Overwrite the steps of ``x`` where ``terminals`` is 1 with ``padding_vector``.

        ``x`` is [B, T, D] and ``terminals`` is [B, T, 1] (or None, in which
        case ``x`` is returned unchanged).
        """
        if terminals is None:
            return x
        B, T, D = x.size()
        padded = torch.as_tensor(self.padding_vector, dtype=torch.float32, device=x.device).repeat(B, T, 1)
        tmask = (1 - terminals).repeat(1, 1, D)
        return x * tmask + (1 - tmask) * padded

    @torch.no_grad()
    def encode(self, joined_inputs, terminals=None):
        """
        Returns indices for each residual level: [L, B, T'].

        The inputs are pre-processed exactly as in forward(): terminal padding
        first, then (for assign_from == "masked") channel 0 is zeroed, then the
        terminal channel is appended when the model uses one. Codes are picked
        greedily level by level from the model's shared codebook.
        """
        x = self._pad_terminals(joined_inputs.to(dtype=torch.float32), terminals)

        if self.model.assign_from == "masked":
            # mask channel 0 for encoder invariance
            chan_mask = torch.ones_like(x)
            chan_mask[:, :, 0] = 0.0
            stream = x * chan_mask
        else:
            stream = x

        # Build encoder input (concat terminals if requested by model)
        if self.model.use_term_channel:
            assert terminals is not None, "terminals required when use_term_channel=True"
            stream = torch.cat([stream, terminals], dim=2)

        res = self.model._encode_core(stream)            # [B,T',D_lat]
        emb = self.model.codebook_weight()               # shared across levels

        idxs = []
        for _level in range(self.model.n_levels):
            idx = vq(res, emb)  # [B,T']
            idxs.append(idx)
            res = res - F.embedding(idx.reshape(-1), emb).view_as(res)
        return torch.stack(idxs, dim=0)

    def decode(self, latent, state):
        """Decode already-summed code vectors; thin wrapper over ``model.decode``."""
        return self.model.decode(latent, state)

    def decode_from_indices(self, indices, state):
        """Look up and sum the codes for ``indices`` and decode them.

        Args:
            indices: [L, B, T'] per-level code indices, or [B, T'] for a
                single level.
            state: [B, obs_dim] conditioning state.

        Returns:
            The output of ``model.decode`` for the summed code vectors.
        """
        if indices.dim() == 2:
            indices = indices.unsqueeze(0)  # [1,B,T']
        L, B, Tprime = indices.shape
        emb = self.model.codebook_weight()               # shared across levels
        # [L,B,T',D_lat] -> sum over levels -> [B,T',D_lat]
        latent = F.embedding(indices, emb).sum(dim=0)
        if self.model.bottleneck == "attention":
            pad_T = self.trajectory_length // self.latent_step
            latent = torch.cat([latent, torch.zeros(B, pad_T, latent.size(-1), device=latent.device)], dim=1)
        return self.model.decode(latent, state[:, None, :])

    def forward(self, joined_inputs, targets=None, mask=None, terminals=None, returnx=False):
        """Reconstruct a batch of trajectories and, if ``targets`` is given, compute losses.

        Args:
            joined_inputs: [B, T, D] trajectories.
            targets: only tested against None; the reconstruction target is
                ``joined_inputs`` itself.
            mask: optional [B, T, D] loss mask. It is copied before channel 0
                is zeroed, so the caller's tensor is left untouched.
            terminals: optional [B, T, 1] terminal flags.
            returnx: unused; kept for interface compatibility.

        Returns:
            ``(reconstructed, reconstruction_loss, loss_vq, loss_commit,
            current_state_loss, next_state_loss, value_loss,
            first_action_loss)``. Without ``targets`` the first three losses
            are None and the remaining four are a zero tensor.
        """
        joined_inputs = joined_inputs.to(dtype=torch.float32)
        B, T, D = joined_inputs.size()

        # terminal padding into features
        x = self._pad_terminals(joined_inputs, terminals)

        mask = torch.ones_like(x) if mask is None else mask.clone()
        mask[:, :, 0] = 0.0  # don't train on value channel by default

        # choose conditioning state (here: from the penultimate step)
        state = x[:, -2, 1:(self.observation_dim + 1)]

        reconstructed, quantised, feat_full, feat_masked = self.model(x, state, terminals=terminals)

        pred_traj  = reconstructed[:, :, :-1]          # [B,T,D]
        pred_terms = reconstructed[:, :, -1, None]

        # ----- losses -----
        # Defaults so the return statement is valid when no targets are given.
        current_state_loss = next_state_loss = value_loss = first_action_loss = torch.tensor(0.0, device=x.device)
        reconstruction_loss = loss_vq = loss_commit = None

        if targets is not None:
            # per-channel weights; NOTE(review): the segment sizes must add up
            # to D — confirm the layout against the data pipeline.
            weights = torch.cat([
                torch.ones(1, device=x.device) * self.value_weight,
                torch.ones(self.observation_dim, device=x.device),
                torch.ones(self.action_dim, device=x.device) * self.action_weight,
                torch.ones(self.action_dim, device=x.device) * self.action_weight,
                torch.ones(self.action_dim, device=x.device) * self.action_weight,
            ])

            # the last two steps are weighted fully, the earlier context at 0.1 (below)
            mse_tail = F.mse_loss(pred_traj[:, -2:, :], joined_inputs[:, -2:, :], reduction='none') * weights[None, None, :]
            mse_ctx  = F.mse_loss(pred_traj[:, :-2, :], joined_inputs[:, :-2, :], reduction='none') * weights[None, None, :]

            ce = 0.0
            if terminals is not None:
                ce = F.binary_cross_entropy(pred_terms, torch.clamp(terminals.float(), 0.0, 1.0))

            act_lo = 1 + self.observation_dim
            act_hi = act_lo + self.action_dim
            first_action_loss = F.mse_loss(joined_inputs[:, -2, act_lo:act_hi], pred_traj[:, -2, act_lo:act_hi])
            value_loss = F.mse_loss(joined_inputs[:, :, 0].mean(dim=1), pred_traj[:, :, 0].mean(dim=1))
            # reported only; not part of reconstruction_loss
            current_state_loss = F.mse_loss(
                joined_inputs[:, -2, 1:(self.observation_dim+1)], pred_traj[:, -2, 1:(self.observation_dim+1)]
            )
            next_state_loss = F.mse_loss(
                joined_inputs[:, -1, 1:(self.observation_dim+1)], pred_traj[:, -1, 1:(self.observation_dim+1)]
            )

            if terminals is None:
                term_mask = torch.ones(B, T, 1, device=x.device, dtype=x.dtype)
            else:
                term_mask = 1 - terminals
            term_mask = term_mask.repeat(1, 1, D)

            reconstruction_loss = (mse_tail * mask[:, -2:, :] * term_mask[:, -2:, :]).mean() \
                                  + 0.1 * (mse_ctx * mask[:, :-2, :] * term_mask[:, :-2, :]).mean() \
                                  + ce \
                                  + value_loss + next_state_loss + first_action_loss

            # masked vs unmasked feature consistency
            loss_vq = F.mse_loss(feat_full, feat_masked)

            # residual commitment proxy with z_q
            loss_commit = 0.0
            res = feat_masked if self.model.assign_from == "masked" else feat_full
            for z_q in quantised:
                loss_commit += F.mse_loss(res, z_q.detach())
                res = res - z_q.detach()
            loss_commit = loss_commit / len(quantised) * self.model.commitment_beta

        return reconstructed, reconstruction_loss, loss_vq, loss_commit, current_state_loss, next_state_loss, value_loss, first_action_loss


class TransformerPrior(nn.Module):
    """
    Two-level residual prior:
      - trunk sees ONLY the coarse history (level-0 tokens) + a state vector
      - head0 predicts next coarse token
      - head1 predicts next fine token conditioned on the (teacher-forced) coarse target at the same position
    Forward supports an optional loss_mask over time positions.
    """
    def __init__(self, config):
        """Create the coarse-token trunk and the coarse / fine prediction heads.

        Reads ``block_size``, ``n_embd``, ``n_layer``, ``embd_pdrop``,
        ``observation_dim`` and ``K`` from ``config``; ``K0`` / ``K1`` are
        optional and default to ``K``.
        """
        super().__init__()
        self.block_size = config.block_size        # context length in latent time
        self.n_embd = config.n_embd
        self.n_layer = config.n_layer
        self.embd_pdrop = config.embd_pdrop
        self.observation_dim = config.observation_dim

        # vocabulary sizes: level-0 (coarse) and level-1 (fine)
        self.K0 = getattr(config, "K0", config.K)
        self.K1 = getattr(config, "K1", config.K)

        width = self.n_embd
        half_width = width // 2

        # trunk inputs: coarse token + position + state embeddings
        self.tok_emb = nn.Embedding(self.K0, width)
        self.pos_emb = nn.Parameter(torch.zeros(1, self.block_size, width))
        self.state_emb = nn.Linear(self.observation_dim, width)
        self.drop = nn.Dropout(self.embd_pdrop)

        # transformer trunk (Block is expected to be causal)
        trunk_layers = [Block(config) for _ in range(self.n_layer)]
        self.blocks = nn.Sequential(*trunk_layers)
        self.ln_f = nn.LayerNorm(width)

        # coarse head: trunk features -> K0 logits
        self.head0 = nn.Linear(width, self.K0, bias=False)

        # fine head: trunk features joined with an embedding of the coarse
        # target at the same position (teacher forcing) -> K1 logits
        self.coarse_target_emb = nn.Embedding(self.K0, half_width)
        self.fine_adapter = nn.Sequential(
            nn.Linear(width + half_width, width),
            nn.GELU(),
            nn.Linear(width, width),
        )
        self.head1 = nn.Linear(width, self.K1, bias=False)

        self.apply(self._init_weights)

    def configure_optimizers(self, train_config):
        """
        Build the AdamW optimizer with two parameter groups.

        Weights of Linear / EinLinear modules are decayed with
        ``train_config.weight_decay``; every bias, LayerNorm / Embedding weights
        and the positional embedding are not decayed. Two assertions check that
        each parameter ends up in exactly one group.
        """
        decayed_modules = (torch.nn.Linear, EinLinear)
        undecayed_modules = (torch.nn.LayerNorm, torch.nn.Embedding)

        decay, no_decay = set(), set()
        for mod_name, module in self.named_modules():
            for par_name, _ in module.named_parameters():
                full_name = f"{mod_name}.{par_name}" if mod_name else par_name
                if par_name.endswith('bias'):
                    # biases are never decayed
                    no_decay.add(full_name)
                    continue
                if not par_name.endswith('weight'):
                    continue
                if isinstance(module, decayed_modules):
                    decay.add(full_name)
                elif isinstance(module, undecayed_modules):
                    no_decay.add(full_name)

        # the root-level positional embedding is a bare Parameter, not a module weight
        no_decay.add('pos_emb')

        # every parameter must be in exactly one of the two sets
        param_dict = dict(self.named_parameters())
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
        assert len(
            param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params),)

        decayed = [param_dict[name] for name in sorted(decay)]
        undecayed = [param_dict[name] for name in sorted(no_decay)]
        optim_groups = [
            {"params": decayed, "weight_decay": train_config.weight_decay},
            {"params": undecayed, "weight_decay": 0.0},
        ]
        return torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias); nn.init.ones_(m.weight)

    def _masked_ce(self, logits, targets, loss_mask=None, label_smoothing=0.0):
        """
        logits : [B, T, V]
        targets: [B, T] (Long)
        loss_mask: [B, T] (bool/float) or None
        """
        if loss_mask is None:
            if label_smoothing > 0:
                logp = F.log_softmax(logits, dim=-1)
                V = logits.size(-1)
                with torch.no_grad():
                    true = torch.zeros_like(logits).scatter_(-1, targets.unsqueeze(-1), 1.0)
                    smooth = label_smoothing / V
                    target_dist = (1.0 - label_smoothing) * true + smooth
                return -(target_dist * logp).sum(-1).mean()
            return F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

        # masked
        logp = F.log_softmax(logits, dim=-1)                           # [B,T,V]
        nll  = -logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)     # [B,T]
        m = loss_mask.to(logits.dtype)
        denom = m.sum().clamp_min(1.0)
        return (nll * m).sum() / denom

    def forward(self, coarse_in, state, targets0=None, targets1=None,
                loss_mask=None, label_smoothing=0.0, return_logits=False):
        """
        coarse_in : [B, Tctx]   (level-0 tokens, context)
        state     : [B, obs_dim]
        targets0  : [B, Tctx] or None (next coarse tokens)
        targets1  : [B, Tctx] or None (next fine tokens)
        loss_mask : [B, Tctx] positions to include in loss (None = all)
        Returns: loss (and logits if return_logits=True)
        """
        B, Tctx = coarse_in.shape
        assert Tctx <= self.block_size, "Context too long for block_size."

        # trunk
        tok = self.tok_emb(coarse_in)                   # [B,T,D]
        pos = self.pos_emb[:, :Tctx, :]
        st  = self.state_emb(state).unsqueeze(1)        # [B,1,D]
        x = self.drop(tok + pos + st)
        x = self.blocks(x)
        x = self.ln_f(x)                                # [B,T,D]

        # coarse logits
        logits0 = self.head0(x)                         # [B,T,K0]

        loss = None
        logits1 = None
        if (targets0 is not None) and (targets1 is not None):
            # fine head is conditioned on the *ground-truth* coarse target at each position
            cte   = self.coarse_target_emb(targets0)    # [B,T,De/2]
            h_f   = self.fine_adapter(torch.cat([x, cte], dim=-1))
            logits1 = self.head1(h_f)                   # [B,T,K1]

            loss0 = self._masked_ce(logits0, targets0, loss_mask, label_smoothing)
            loss1 = self._masked_ce(logits1, targets1, loss_mask, label_smoothing)
            loss  = 0.5 * (loss0 + loss1)
        elif targets0 is not None:
            loss  = self._masked_ce(logits0, targets0, loss_mask, label_smoothing)

        if return_logits:
            return loss, (logits0, logits1)
        return loss

    def _last_hidden(self, coarse_ctx: torch.Tensor, state: torch.Tensor):
        """
        coarse_ctx: [B, Tctx] (Long)
        state     : [B, obs_dim] (Float)
        returns   : h_last [B, D] (Float)
        """
        B, T = coarse_ctx.shape
        tok = self.tok_emb(coarse_ctx)  # [B,T,D]
        pos = self.pos_emb[:, :T, :]
        st = self.state_emb(state).unsqueeze(1)  # [B,1,D]
        x = self.drop(tok + pos + st)
        x = self.blocks(x)
        x = self.ln_f(x)
        return x[:, -1, :]  # [B,D]

    @torch.no_grad()
    def sample_next_pairs(self,
                          coarse_ctx: torch.Tensor,
                          state: torch.Tensor,
                          num_samples: int = 10,
                          replacement: bool = True,
                          temperature: float = 1.0,
                          topk_coarse: int | None = None,
                          fine_shortlist: dict | list | None = None,
                          fine_policy: str = "argmax",  # "argmax" or "sample"
                          return_batch_pack: bool = True):
        """
        Sample S (=num_samples) (coarse, fine) pairs for the NEXT step.

        Args
          coarse_ctx      : Long [B, Tctx]
          state           : Float [B, obs_dim]
          num_samples     : int S
          replacement     : if False, coarse samples are unique (no replacement)
          temperature     : sampling temperature for both heads
          topk_coarse     : restrict coarse support to top-k (None = full)
          fine_shortlist  : optional map coarse_id -> LongTensor[<=K1] of allowed fine ids
          fine_policy     : "argmax" (fast/stable) or "sample" (stochastic fine)
          return_batch_pack: if True, also returns a packed form ([2,B*S,1], states)

        Returns dict with:
          coarse        : Long [B, S]
          fine          : Long [B, S]
          p_coarse      : Float [B, S]
          p_fine        : Float [B, S]
          p_joint       : Float [B, S]
          indices_time  : Long [2, B, S]             # ready for decode_from_indices (as time)
          indices_batch : Long [2, B*S, 1] (optional)
          state_batch   : Float [B*S, obs_dim] (optional)
        """
        self.eval()
        device = coarse_ctx.device
        B = coarse_ctx.size(0)
        S = int(num_samples)

        # 1) Trunk once → last hidden
        h = self._last_hidden(coarse_ctx, state)  # [B,D]

        # 2) Coarse probabilities (masked top-k if requested), softmax (not log)
        logits0 = self.head0(h) / max(1e-8, float(temperature))  # [B,K0]
        if topk_coarse is not None:
            k = min(topk_coarse, logits0.size(-1))
            topv, topi = torch.topk(logits0, k, dim=-1)  # [B,k]
            masked = torch.full_like(logits0, -float('inf'))
            masked.scatter_(dim=-1, index=topi, src=topv)
            logits0 = masked
        p0 = torch.softmax(logits0, dim=-1)  # [B,K0]

        # 3) Sample coarse ids with/without replacement
        if (not replacement) and topk_coarse is not None:
            S_eff = min(S, k)  # cannot sample more unique items than support size
        else:
            S_eff = S
        coarse_ids = torch.multinomial(p0, num_samples=S_eff, replacement=replacement)  # [B,S_eff]
        p_coarse = p0.gather(-1, coarse_ids)  # [B,S_eff]

        # If S_eff < S (no replacement with topk smaller than S), pad by re-sampling with replacement
        if S_eff < S:
            extra = torch.multinomial(p0, num_samples=(S - S_eff), replacement=True)  # [B,S-S_eff]
            coarse_ids = torch.cat([coarse_ids, extra], dim=-1)  # [B,S]
            p_extra = p0.gather(-1, extra)
            p_coarse = torch.cat([p_coarse, p_extra], dim=-1)  # [B,S]

        # 4) Fine conditional for each (B,S) coarse sample
        #    Vectorize by flattening samples into batch dimension
        c_flat = coarse_ids.reshape(-1)  # [B*S]
        e0 = self.coarse_target_emb(c_flat)  # [B*S, De/2]
        h_rep = h.repeat_interleave(S, dim=0)  # [B*S, D]
        logits1 = self.fine_adapter(torch.cat([h_rep, e0], dim=-1))
        logits1 = self.head1(logits1) / max(1e-8, float(temperature))  # [B*S, K1]

        # Optional fine shortlist masking (per row)
        if fine_shortlist is not None:
            masked = torch.full_like(logits1, -float('inf'))
            # loop over rows; S is usually small (<= ~32), so this is fine
            for r in range(c_flat.size(0)):
                allowed = fine_shortlist[int(c_flat[r].item())]  # LongTensor of fine ids
                masked[r, allowed] = logits1[r, allowed]
            logits1 = masked

        p1 = torch.softmax(logits1, dim=-1)  # [B*S, K1]

        # Choose fine ids
        if fine_policy == "sample":
            fine_flat = torch.multinomial(p1, num_samples=1, replacement=True).squeeze(-1)  # [B*S]
        else:  # "argmax"
            fine_flat = torch.argmax(p1, dim=-1)  # [B*S]

        p_fine_flat = p1.gather(-1, fine_flat.unsqueeze(-1)).squeeze(-1)  # [B*S]

        # Reshape back to [B,S]
        fine_ids = fine_flat.view(B, S)  # [B,S]
        p_fine = p_fine_flat.view(B, S)  # [B,S]
        p_joint = (p_coarse * p_fine)  # [B,S]

        # 5) Pack outputs for your VQ-VAE decode_from_indices
        indices_time = torch.stack([coarse_ids, fine_ids], dim=0)  # [2,B,S]

        out = {
            "coarse": coarse_ids,  # [B,S]
            "fine": fine_ids,  # [B,S]
            "p_coarse": p_coarse,  # [B,S]
            "p_fine": p_fine,  # [B,S]
            "p_joint": p_joint,  # [B,S]
            "indices_time": indices_time  # [2,B,S]
        }

        if return_batch_pack:
            # Flatten candidates into batch to decode S one-step proposals in parallel
            indices_batch = torch.stack(
                [coarse_ids.reshape(-1), fine_ids.reshape(-1)], dim=0
            ).unsqueeze(-1)  # [2, B*S, 1]
            state_batch = state.repeat_interleave(S, dim=0)  # [B*S, obs]
            out["indices_batch"] = indices_batch
            out["state_batch"] = state_batch

        return out