import argparse
import wandb
import numpy as np
import math
from typing import Tuple, List, Dict
from scipy.stats import spearmanr, pearsonr

import random
import torch
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.distributions.categorical import Categorical

import sys
import io
import time
import functools


# Stream stdout line-buffered, utf-8
# Re-wrap the underlying binary stdout buffer so printed text is encoded as UTF-8 and flushed
# at every newline (line_buffering=True). Skipped when sys.stdout has no `.buffer` attribute
# (e.g. when it has already been replaced by a text-only stream).
if hasattr(sys.stdout, "buffer"):
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', line_buffering=True)

# =====================
# Args
# =====================
# Command-line configuration, defined at import time. The flags declared with
# choices=[0,1] (wdb, use_alpha_scheduler, use_exp_weight_decay) are converted to bool
# by process_bool_args(); where parse_args() is called is not visible in this section.
parser = argparse.ArgumentParser()

# Task
# n = bit-string length, k = bits per token; trajectories therefore have n // k steps.
parser.add_argument("--n", default=120, type=int)
parser.add_argument("--k", default=4, type=int)
parser.add_argument("--M_size", default=60, type=int)
# Compared against the Hamming distance to each mode in eval_batch_metrics (<= counts as a hit).
parser.add_argument("--mode_threshold", default=30, type=int, help="Hamming distance threshold to count a mode as hit")
# Multiplies the log-reward in the train steps and raises the reward in evaluation.
parser.add_argument("--reward_exponent", default=2.0, type=float)

# Seeds & device (MODIFIED: split model_seed and sampling_seed to align with set task)
# NOTE(review): sampling_seed presumably seeds args.sampling_gen, the torch.Generator read by
# sample_forward(); confirm where that generator is created.
parser.add_argument("--model_seed", default=0, type=int, help="Seed for model init & training randomness")
parser.add_argument("--sampling_seed", default=0, type=int, help="Seed used for sampling actions only")
parser.add_argument("--device", default='cuda', type=str)

# Train procedure
parser.add_argument("--num_iterations", default=50000, type=int)
parser.add_argument("--batch_size", default=16, type=int)
parser.add_argument("--learning_rate", default=1e-3, type=float)
parser.add_argument("--dropout", default=0.1, type=float)
parser.add_argument("--eval_every", default=50, type=int)
parser.add_argument("--n_test_steps", default=1000, type=int, help="Interval for test-set stats")
parser.add_argument("--corr_num_rounds", default=10, type=int, help="Rounds for test-set correlation rollouts")
# compute_test_metrics divides the logits by this value, i.e. it acts as a softmax temperature.
parser.add_argument("--entropy_coeff", default=1.0, type=float, help="Temperature for policy in test correlation (kept from repo)")

# Objective: tb / db / subtb 
parser.add_argument("--objective", default='tb', type=str, choices=['tb','db','subtb'])
parser.add_argument("--z_learning_rate", default=1e-3, type=float, help="Only used for TB")
# SubTB computes subtb_lambda ** (j - i) for each sub-trajectory i..j (presumably its weight);
# with the default 1.9 (> 1) longer sub-trajectories would get larger weights.
parser.add_argument("--subtb_lambda", default=1.9, type=float)

# Alpha (fixed; scheduler controls value over time like set task)
# NOTE(review): the train steps add log(alpha / (1 - alpha)) terms, so alpha presumably must
# lie strictly in (0, 1); confirm how args.alpha reaches their `alpha` argument.
parser.add_argument("--alpha", default=0.5, type=float)
parser.add_argument("--use_alpha_scheduler", choices=[0,1], default=1, type=int)
parser.add_argument("--alpha_warm_frac", type=float, default=0.8)

# Epsilon (rand_action_prob) scheduling like set's exp_weight schedule
# sample_forward() reads args.rand_action_prob as the per-row probability of replacing the
# policy action with a uniformly sampled valid action.
parser.add_argument("--rand_action_prob", default=0.001, type=float, help="epsilon for epsilon-greedy (kept original implementation)")
parser.add_argument("--use_exp_weight_decay", choices=[0,1], default=0, type=int, help="If 1, decay epsilon from 1.0 -> rand_action_prob")
parser.add_argument("--exp_weight_sched", type=str, default='linear', choices=['linear','cosine'], help="Schedule shape for epsilon")
parser.add_argument("--exp_weight_warm_frac", type=float, default=0.0, help="Warmup fraction for epsilon schedule")

# Gradient clip (MODIFIED: no boolean flag; enabled iff grad_clip_norm>0)
parser.add_argument("--grad_clip_norm", type=float, default=2, help="Global grad-norm clip threshold; <=0 disables")

# Sampling/Training model split
# The train steps sample trajectories with args._sampling_model and compute losses with the
# training model.
parser.add_argument("--steps_to_update_sampling_model", default=1, type=int, help="Copy trained->sampling every N steps (0/large = rare)")

# WandB
parser.add_argument('--wdb_project', default='AlphaGFN-Bitseq', type=str)
parser.add_argument("--wdb", choices=[0,1], default=1, type=int)

# =====================
# Utils & schedules
# =====================
def timer(func):
    """Decorate *func* so every call prints its wall-clock duration (time.perf_counter).

    The message is printed (and flushed) only when the call returns normally; the wrapped
    function's return value is passed through unchanged.
    """
    @functools.wraps(func)
    def _timed(*args, **kwargs):
        t0 = time.perf_counter()
        out = func(*args, **kwargs)
        elapsed = time.perf_counter() - t0
        print(f"[{func.__name__}] Executed in {elapsed:.6f}s", flush=True)
        return out
    return _timed


def _dict_to_str(d: Dict) -> str:
    def _fmt(v):
        if isinstance(v, (float, np.floating)):
            return f"{float(v):.3f}"
        elif isinstance(v, (int, np.integer, str)):
            return str(v)
        else:
            return str(v)
    return ', '.join(f'{k}={_fmt(v)}' for k, v in d.items())

def process_bool_args(args):
    """Convert the 0/1 integer flags on ``args`` to ``bool``, in place.

    The flags ``wdb``, ``use_alpha_scheduler`` and ``use_exp_weight_decay`` are declared
    with ``choices=[0, 1]`` and ``type=int``, so argparse yields ints; they are replaced
    by ``bool(value)``.

    Returns:
        The same ``args`` object (mutated), for chaining.
    """
    for flag in ("wdb", "use_alpha_scheduler", "use_exp_weight_decay"):
        setattr(args, flag, bool(getattr(args, flag)))
    return args

def set_model_seed(seed: int, device: str):
    """Seed Python, NumPy and torch (CPU and, if available, all CUDA devices) RNGs.

    Also forces deterministic cuDNN kernels and disables cuDNN autotuning.
    ``device`` is accepted for call-site compatibility but is not used.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def tf(x, device):
    """Convert *x* to a tensor on *device* using ``torch.as_tensor`` (dtype inferred)."""
    tensor = torch.as_tensor(x, device=device)
    return tensor

class ExpWeightScheduler:
    """Anneal a value from 1.0 down to ``end`` over ``total_steps`` steps.

    Used here to schedule epsilon (``rand_action_prob``) while keeping the original
    epsilon-greedy sampler. The value is held at 1.0 for the first
    ``int(warm_frac * total_steps)`` steps, then follows a linear or cosine curve that
    reaches exactly ``end`` at the final step ``total_steps - 1``. Steps outside
    ``[0, total_steps - 1]`` are clamped and the result is clipped to ``[0, 1]``.

    Args:
        end: target value at the last step.
        total_steps: schedule length (at least 1).
        kind: ``"linear"`` or ``"cosine"``; any other value raises ``ValueError`` on call.
        warm_frac: fraction of steps (clipped to [0, 1]) held at 1.0.
    """
    def __init__(self, end: float, total_steps: int, kind: str = "linear", warm_frac: float = 0.0):
        self.math = math  # kept as an attribute for backward compatibility
        self.start = 1.0
        self.end = float(end)
        self.T = max(1, int(total_steps))
        self.warm = int(max(0.0, min(1.0, warm_frac)) * self.T)
        self.kind = kind

    def __call__(self, step: int) -> float:
        t = max(0, min(step, self.T - 1))
        if t < self.warm:
            return self.start
        # Progress in [0, 1]: 0 at the first post-warmup step and 1 at the last step (T - 1).
        # (Dividing by T - warm, as before, never reached 1 because t <= T - 1.)
        x = (t - self.warm) / max(1, self.T - 1 - self.warm)
        if self.kind == "linear":
            w = self.start + (self.end - self.start) * x
        elif self.kind == "cosine":
            w = self.end + (self.start - self.end) * 0.5 * (1 + math.cos(math.pi * x))
        else:
            raise ValueError(f"Unknown exp_weight schedule kind: {self.kind}")
        return float(max(0.0, min(1.0, w)))

class AlphaScheduler:
    """Hold alpha at ``alpha0``, then decay it exponentially toward ``alpha_final``.

    For the first ``int(total_steps * warm_frac)`` steps the value is ``alpha0``. After that
    it is ``alpha_final + (alpha0 - alpha_final) * exp(-decay_k * elapsed / span)``, where
    ``elapsed`` counts steps since the hold ended, ``span = max(1, total_steps - hold)`` and
    ``decay_k = 4``. Steps are clamped to ``[0, total_steps - 1]``.
    """
    def __init__(self, total_steps: int, alpha0: float, warm_frac: float = 0.4,
                alpha_final: float = 0.5):
        self.math = math
        self.T = max(1, int(total_steps))
        self.a0 = float(alpha0)
        self.af = float(alpha_final)
        self.warm_frac = max(0.0, min(1.0, float(warm_frac)))
        self.T_hold = int(self.T * self.warm_frac)
        self.decay_k = 4.0
        self.poly_p = 0.5  # not used by __call__; kept as an attribute

    def __call__(self, step: int) -> float:
        t = max(0, min(step, self.T - 1))
        if t >= self.T_hold:
            elapsed = t - self.T_hold
            span = max(1, self.T - self.T_hold)
            decay = self.math.exp(-self.decay_k * elapsed / span)
            return self.af + (self.a0 - self.af) * decay
        return self.a0

# =====================
# Model
# =====================
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first input, then dropout.

    Even feature columns hold ``sin(pos * f_i)`` and odd columns ``cos(pos * f_i)``. Odd
    ``d_model`` is supported: it has one fewer cosine column than sine columns.

    Args:
        d_model: feature size of the input.
        dropout: dropout probability applied after adding the encodings.
        max_len: maximum supported sequence length.
    """
    def __init__(self, d_model: int, dropout: float = 0.2, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # There are d_model // 2 odd columns; for odd d_model drop the surplus cosine column.
        pe[:, 0, 1::2] = torch.cos(position * div_term)[:, : d_model // 2]
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """Add encodings for the first ``x.size(0)`` positions (x is sequence-first)."""
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class TransformerModel(nn.Module):
    """Transformer-encoder policy network.

    Embeds token ids, adds sinusoidal positional encodings, runs a ``TransformerEncoder`` and
    projects every sequence slot to ``ntoken + seq_len + 1`` logits. Callers in this file invoke
    the policy as ``model(batch.T)`` with ``batch`` of shape [B, T + 1], i.e. sequence-first
    (presumably this class is that policy — confirm in main). ``process_logits`` reads the first
    ``2 ** k`` outputs as word logits and the last ``n // k + 1`` outputs of slot 0 as position
    logits; the DB/SubTB steps read column ``2 ** k`` of slot 0 as a log-flow.

    Args:
        ntoken: embedding vocabulary size (also the leading part of the output width).
        d_model: embedding / model width.
        nhead: number of attention heads.
        d_hid: feed-forward hidden size of each encoder layer.
        nlayers: number of encoder layers.
        seq_len: used for the output width and the positional-encoding length (seq_len + 2).
        dropout: dropout probability for the positional encoding and the encoder layers.
    """
    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, seq_len: int, dropout: float = 0.2):
        super().__init__()
        # NOTE: the construction order below fixes how parameter initialization consumes the
        # seeded RNG; reordering submodules changes the initial weights for a given seed.
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len=seq_len + 2)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, ntoken + seq_len + 1)
        # fixed alpha (non-trainable); value possibly scheduled per step
        self.alpha = nn.Parameter(torch.tensor(0.5), requires_grad=False)
    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        """Map token ids ``src`` to per-slot logits; ``src_mask`` is passed to the encoder."""
        src = self.embedding(src)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.linear(output)
        return output

# =====================
# Task helpers
# =====================

def construct_M(n: int, b: int, H: List[str], M_size: int, seed: int = 0) -> List[str]:
    """Build ``M_size`` mode strings of length ``n``.

    Each mode concatenates ``n // b`` blocks, each drawn uniformly (with replacement) from
    ``H`` by a NumPy Generator seeded with ``seed``. Draws are made one block at a time so
    the RNG stream (and therefore the modes) stay fixed for a given seed.
    """
    gen = np.random.default_rng(seed)
    n_blocks = n // b
    modes: List[str] = []
    for _ in range(M_size):
        mode = "".join(gen.choice(H) for _ in range(n_blocks))
        assert len(mode) == n
        modes.append(mode)
    return modes

def distance(s1: str, s2: str) -> int:
    """Return the Hamming distance between two equal-length sequences.

    Raises AssertionError when the lengths differ.
    """
    assert len(s1) == len(s2), f"{len(s1)} is not equal to {len(s2)}"
    return sum(ch1 != ch2 for ch1, ch2 in zip(s1, s2))

def M_distance(s: str, M: List[str]) -> int:
    """Return the smallest Hamming distance from *s* to any string in *M*.

    Raises ValueError (from ``min``) when *M* is empty.
    """
    return min(map(lambda mode: distance(s, mode), M))

def construct_test_set(M: List[str], seed: int = 0) -> List[str]:
    """Build the fixed evaluation set of bit strings around the modes in ``M``.

    For each mode string s in M (|s| = n), the function:
    1) includes s itself (Hamming distance 0), then
    2) for each cnt in 1..n-1, samples ONE random subset of ``cnt`` distinct indices,
       flips those bits in s, and appends the resulting string (Hamming distance cnt).

    Implications:
       - The test set is *not* exhaustive over all C(n, cnt) perturbations. It contains
         exactly n items per s (distances 0..n-1; the full complement at distance n is
         never included), so the total size is |M|*n.
       - Each Hamming distance 0..n-1 is represented by a single random string per mode.
       - The randomness comes from a NumPy Generator seeded with ``seed``, independent of
         training, so the set is reproducible.

    Behavior is identical to the original repo constructor; only this docstring changed.
    """
    rng = np.random.default_rng(seed)
    test_set = []
    for s in M:
        test_set.append(s)
        for cnt in range(1, len(s)):
            new_s = list(s)
            subset = rng.choice(len(s), size=cnt, replace=False)
            for i in subset:
                new_s[i] = '0' if s[i] == '1' else '1'
            cand = "".join(new_s)
            assert len(cand) == len(s) and distance(cand, s) == cnt
            test_set.append(cand)
    return test_set

def log_reward(s: str, M: List[str]) -> float:
    """Return the log-reward of *s*: minus its minimum Hamming distance to *M*, as a float."""
    nearest = M_distance(s, M)
    return -float(nearest)

def reward(s: str, M: List[str]) -> float:
    """Return ``np.exp(log_reward(s, M))``, i.e. exp(-min Hamming distance to M)."""
    lr = log_reward(s, M)
    return np.exp(lr)

def token_seq_to_str(seq: torch.Tensor, k: int) -> str:
    """Concatenate the ``k``-bit, zero-padded binary form of every token in ``seq``."""
    width = f'0{k}b'
    return "".join(format(int(tok), width) for tok in seq)

def batch_rewards(batch, M, k):
    """Return a tensor with the raw reward of every row of ``batch`` (kept name & semantics).

    Each row is decoded with ``token_seq_to_str(row, k)`` and scored with ``reward``.
    NOTE: ``reward_exponent`` is intentionally NOT applied here; callers apply it.
    """
    rows = batch.detach().cpu().numpy()
    values = []
    for r in range(rows.shape[0]):
        values.append(reward(token_seq_to_str(rows[r], k), M))
    return torch.tensor(values)

def batch_log_rewards(batch, M, k):
    """Return a tensor with the raw log-reward of every row of ``batch`` (no exponent)."""
    rows = batch.detach().cpu().numpy()
    values = []
    for r in range(rows.shape[0]):
        values.append(log_reward(token_seq_to_str(rows[r], k), M))
    return torch.tensor(values)

# ----- logits processing & sampling -----

def process_logits(all_logits: torch.Tensor, pos_mask: torch.Tensor, args):
    """Split raw model outputs into position, word and joint (position, word) logits.

    Args:
        all_logits: sequence-first model output [n/k + 1, B, C].
        pos_mask: [B, n/k + 1] boolean; True marks positions that may not be chosen.
        args: reads ``n`` and ``k``.

    Returns:
        pos_logits: [B, n/k + 1], slot-0 output's last n/k + 1 columns, -inf where masked.
        word_logits: [n/k + 1, B, 2^k], the first 2^k columns of every slot.
        sum_logits: [B, (n/k + 1) * 2^k], flattened word + position logits
            (action index = position * 2^k + word).
    """
    num_pos = args.n // args.k + 1
    num_words = 2 ** args.k
    pos_logits = all_logits[0, :, -num_pos:].masked_fill(pos_mask, -float("inf"))
    word_logits = all_logits[:, :, :num_words]
    joint = word_logits.permute(1, 0, 2) + pos_logits.unsqueeze(-1)
    sum_logits = joint.reshape(pos_logits.shape[0], num_pos * num_words)
    return pos_logits, word_logits, sum_logits

# NOTE: Keep original epsilon-greedy implementation.
# MODIFIED: now uses args.sampling_gen (torch.Generator) for reproducible RNG and scheduled args.rand_action_prob.
#           Function name and overall behavior remain unchanged.

def sample_forward(sum_logits, sum_uniform, batch, args):
    """Sample one (position, word) action per trajectory with epsilon-greedy exploration.

    Policy actions are drawn from ``softmax(sum_logits)``. With probability
    ``args.rand_action_prob`` each row instead takes an action drawn from
    ``softmax(sum_uniform)``. All draws use ``args.sampling_gen``. Flattened action ids
    decode as ``position = action // 2**k`` and ``word = action % 2**k``.

    The batch size comes from ``sum_logits.shape[0]`` rather than ``args.batch_size``, so
    evaluation batches of a different size are sampled correctly.

    Args:
        sum_logits: [B, (n//k + 1) * 2**k] policy logits (-inf on unavailable positions).
        sum_uniform: same shape; logits of the exploration distribution.
        batch: [B, n//k + 1] current states; the empty-slot token is ``2**k``.
        args: reads k, n, device, rand_action_prob, sampling_gen.

    Returns:
        (actions, positions, words), each a [B] long tensor on ``args.device``.
    """
    empty_tok = 2 ** args.k
    B = sum_logits.shape[0]
    # Softmax is loop-invariant: compute the two sampling distributions once.
    policy_probs = sum_logits.softmax(dim=-1)
    uniform_probs = sum_uniform.softmax(dim=-1)
    # PyTorch's multinomial can, very rarely, return an index with zero probability.
    # Resample the whole batch until every chosen position is still empty.
    while True:
        actions = torch.multinomial(policy_probs, 1, generator=args.sampling_gen).squeeze(-1).to(args.device)
        uniform_actions = torch.multinomial(uniform_probs, 1, generator=args.sampling_gen).squeeze(-1).to(args.device)
        # use scheduled epsilon stored in args.rand_action_prob
        uniform_mask = torch.rand(B, generator=args.sampling_gen, device=args.device) < args.rand_action_prob
        actions[uniform_mask] = uniform_actions[uniform_mask]
        positions = actions // empty_tok
        if (batch[range(B), positions] == empty_tok).sum() == B:
            break
    assert positions.min() >= 1
    assert positions.max() <= args.n // args.k
    words = actions % empty_tok
    return actions, positions, words

# =====================
# Train / Eval / Test helpers (aligned with "set" code)
# =====================
@torch.no_grad()
def eval_batch_metrics(tr_model: nn.Module, args, batch_size: int, device: torch.device,
                       M=None, reward_exponent=None,
                       acc_unique_strs=None, acc_unique_Rs=None, mode_hits=None):
    """Roll out one on-policy evaluation batch with the training model and summarize it.

    Sampling reuses ``sample_forward`` (the training sampler) with epsilon disabled. For the
    duration of the rollout ``args.rand_action_prob`` is set to 0.0 and ``args.batch_size`` to
    ``batch_size`` (``sample_forward`` may read the batch size from ``args``). Both are
    restored on exit, even if an exception is raised.

    Per step it records:
      - the forward-policy entropy over the actions valid in the state the action was sampled
        from (the valid-action mask is built BEFORE the step is applied);
      - the probability of the sampled action;
      - the corresponding quantities of the uniform-over-parents backward policy.

    Args:
        tr_model: policy network, called as ``tr_model(batch.T)``; left in eval mode.
        args: run config (reads n, k, mode_threshold; temporarily overrides
            rand_action_prob and batch_size).
        batch_size: number of trajectories to sample.
        device: device on which the rollout state is built.
        M: mode strings; required (with ``reward_exponent``) for reward metrics.
        reward_exponent: exponent applied to raw rewards.
        acc_unique_strs, acc_unique_Rs: running lists of unique sampled strings and their
            exponentiated rewards; extended in place when provided.
        mode_hits: per-mode list of booleans; set to True in place when a sample lies within
            ``args.mode_threshold`` Hamming distance of that mode.

    Returns:
        (metrics, eval_tokens, eval_strings): a dict of scalar metrics, the sampled tokens
        [B, T] (BOS removed, on ``device``) and their decoded bit strings.
    """
    tr_model.eval()
    n, k = args.n, args.k
    T = n // k
    V = 2 ** k
    B = batch_size
    rows = torch.arange(B)

    # [BOS, empty, ..., empty]
    batch = torch.full((B, T + 1), V, device=device, dtype=torch.long)
    batch[:, 0] = V + 1

    fwd_ent_list, fwd_taken_p_list = [], []
    back_ent_list, back_taken_p_list = [], []

    old_eps, old_bs = args.rand_action_prob, args.batch_size
    try:
        args.rand_action_prob = 0.0  # pure on-policy evaluation
        args.batch_size = B          # keep sample_forward's RNG draws sized to this batch

        for t in range(T):
            pos_mask = (batch != V)

            # Valid actions of the CURRENT state: every word at every still-empty slot.
            # BOS is never selectable. Built before the step so the chosen slot is included.
            valid_pos = (batch == V)
            valid_pos[:, 0] = False
            valid_action_mask = valid_pos.unsqueeze(-1).expand(-1, T + 1, V).reshape(B, -1)

            # clone() avoids in-place versioning issues with Embedding indices.
            logits = tr_model(batch.T.clone())
            _, _, sum_logits = process_logits(logits, pos_mask, args)
            # Uniform branch logits (same trick as training)
            _, _, sum_uniform = process_logits(0.0 * logits, pos_mask, args)

            actions, positions, words = sample_forward(sum_logits, sum_uniform, batch, args)

            # NaN-safe forward distribution over the valid actions: a large negative fill
            # (not -inf) keeps the 0 * log(0) terms finite.
            masked_logits = sum_logits.masked_fill(~valid_action_mask, -1e30)
            logp = masked_logits - torch.logsumexp(masked_logits, dim=-1, keepdim=True)
            ent_each = -(logp.exp() * logp).sum(dim=-1)   # [B]
            ent_each[~torch.isfinite(ent_each)] = 0.0     # extra guard
            fwd_ent_list.append(float(ent_each.mean().item()))

            # Probability of the taken action: exp(logit_a - logsumexp).
            taken_logp = sum_logits[rows, actions] - torch.logsumexp(sum_logits, dim=-1)
            fwd_taken_p_list.append(float(taken_logp.exp().mean().item()))

            # Backward (uniform over the t + 1 filled slots of the next state) metrics
            back_ent_list.append(float(math.log(t + 1 + 1e-12)))
            back_taken_p_list.append(float(1.0 / (t + 1)))

            # Apply the step only after the metrics above used the pre-step state.
            batch[rows, positions] = words
    finally:
        args.rand_action_prob = old_eps
        args.batch_size = old_bs

    # Tokens (without BOS) and corresponding strings (decoded on CPU to avoid per-element syncs)
    eval_tokens = batch[:, 1:]  # [B, T]
    eval_strings = [token_seq_to_str(seq, k) for seq in eval_tokens.cpu()]

    metrics = {
        'forward_policy_entropy_eval': float(np.mean(fwd_ent_list)),
        'forward_avg_action_prob_eval': float(np.mean(fwd_taken_p_list)),
        'backward_policy_entropy_eval': float(np.mean(back_ent_list)),
        'backward_avg_action_prob_eval': float(np.mean(back_taken_p_list)),
    }

    # Optional: reward on the current eval batch
    if (M is not None) and (reward_exponent is not None):
        avg_current_reward = float(((batch_rewards(eval_tokens, M, k).to(device)) ** reward_exponent).mean().item())
        metrics['avg_current_reward_eval'] = avg_current_reward

    # Optional: accumulate unique samples & compute top-k/modes
    if (M is not None) and (reward_exponent is not None) and \
       (acc_unique_strs is not None) and (acc_unique_Rs is not None):
        seen = set(acc_unique_strs)  # O(1) membership instead of scanning the list per sample
        for s in eval_strings:
            if s not in seen:
                seen.add(s)
                acc_unique_strs.append(s)
                acc_unique_Rs.append((reward(s, M) ** reward_exponent))

        if mode_hits is not None:
            for i, m in enumerate(M):
                if not mode_hits[i]:
                    if any(distance(m, s) <= args.mode_threshold for s in eval_strings):
                        mode_hits[i] = True

        modes = int(sum(mode_hits)) if mode_hits is not None else 0
        if len(acc_unique_Rs) > 0:
            arr = np.array(sorted(acc_unique_Rs, reverse=True))
            top100 = arr[:100] if arr.size >= 100 else arr
            top1000 = arr[:1000] if arr.size >= 1000 else arr
            metrics.update({
                'mean_top_100_R': float(np.mean(top100)),
                'mean_top_1000_R': float(np.mean(top1000)),
                'mean_R_all_unique': float(np.mean(arr)),
                'unique_samples_eval': int(len(acc_unique_strs)),
                'modes': modes,
            })
        else:
            metrics.update({
                'mean_top_100_R': float('nan'),
                'mean_top_1000_R': float('nan'),
                'mean_R_all_unique': float('nan'),
                'unique_samples_eval': 0,
                'modes': modes,
            })

    return metrics, eval_tokens, eval_strings


@timer
@torch.no_grad()
def compute_test_metrics(model: nn.Module, M: List[str], test_set: List[str], args, rounds: int = 10, batch_size: int = 180) -> Dict[str, float]:
    """Correlate the model's estimated log-probability of each test string with its log-reward.

    For every test string x, ``rounds`` trajectories ending in x are built. Each one fills the
    positions in a uniformly random order (one random permutation per sample and round,
    drawn from the global torch RNG) with x's ground-truth words. Each trajectory's forward
    log-probability is evaluated at temperature ``args.entropy_coeff``, and the per-string
    estimate is the logsumexp over rounds. The function returns Spearman and Pearson
    correlations against ``reward_exponent * log_reward``, plus step-averaged
    forward/backward policy metrics.

    PERF: (1) one random permutation per sample replaces per-step uniform position sampling;
          (2) the test_set tokenization is cached on ``args`` (on device) across calls.

    Args:
        model: policy network, called as ``model(batch.T)``; left in eval mode.
        M: mode strings used by ``log_reward``.
        test_set: bit strings of length ``args.n``.
        args: run config (reads n, k, device, entropy_coeff, reward_exponent; caches
            ``_test_tokens`` and ``_test_tokens_src`` on it).
        rounds: number of independent trajectories per test string.
        batch_size: mini-batch size. If ``len(test_set)`` is not a multiple of it, the last
            mini-batch is smaller.

    Returns:
        dict with spearman/pearson correlations and forward/backward policy metrics.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    device = torch.device(args.device)
    model.eval()

    T = args.n // args.k
    V = 2 ** args.k
    N = len(test_set)

    # ---- (PERF) Cache test_set as [N, T] LongTensor of word tokens on device ----
    # Rebuilt when the shape OR the strings change; a shape-only key would silently reuse
    # tokens from a different test set of the same size.
    src_key = tuple(test_set)
    cached = getattr(args, "_test_tokens", None)
    if cached is None or cached.shape != (N, T) or getattr(args, "_test_tokens_src", None) != src_key:
        toks = torch.empty((N, T), dtype=torch.long)
        for i, s in enumerate(test_set):
            toks[i] = torch.tensor([int(s[t * args.k:(t + 1) * args.k], 2) for t in range(T)], dtype=torch.long)
        args._test_tokens = toks.to(device)
        args._test_tokens_src = src_key

    test_tokens = args._test_tokens  # [N, T] on device

    # ---- accumulators ----
    p_forward_sums = torch.zeros(N, rounds, device=device)
    fwd_ent_accum: List[float] = []
    fwd_taken_prob_accum: List[float] = []
    back_ent_accum: List[float] = []
    back_taken_prob_accum: List[float] = []

    for r in range(rounds):
        for idx0 in range(0, N, batch_size):
            idx1 = min(idx0 + batch_size, N)
            B = idx1 - idx0  # equals batch_size except possibly for the last mini-batch
            rows = torch.arange(B, device=device)

            # [BOS, empty, ..., empty]
            batch = torch.full((B, T + 1), V, device=device, dtype=torch.long)
            batch[:, 0] = V + 1
            gt_words = test_tokens[idx0:idx1]  # [B, T] ground-truth words per position

            # ---- (PERF) One UNIFORM permutation of positions per sample ----
            # Equivalent to choosing uniformly among the remaining empty slots at each step.
            # Positions use graph indexing 1..T (0 is BOS).
            perm = torch.argsort(torch.rand(B, T, device=device), dim=1)  # [B, T] in 0..T-1
            pos_seq = perm + 1                                            # [B, T] in 1..T
            words_seq = gt_words.gather(1, perm)                          # [B, T]

            for t in range(T):
                positions = pos_seq[:, t]  # [B], 1..T
                words = words_seq[:, t]    # [B], 0..V-1

                pos_mask = (batch != V)
                all_logits = model(batch.T.clone())
                _, _, sum_logits = process_logits(all_logits, pos_mask, args)  # [B, (T+1)*V]

                # Flattened action index (pos, word)
                actions = positions * V + words

                # ===== NaN-safe forward metrics (masked logits over VALID actions only) =====
                valid_pos = (batch == V)
                valid_pos[:, 0] = False  # BOS not selectable
                valid_action_mask = valid_pos.unsqueeze(-1).expand(-1, T + 1, V).reshape(B, -1)

                masked_logits = sum_logits.masked_fill(~valid_action_mask, -1e30)
                logp = masked_logits - torch.logsumexp(masked_logits, dim=-1, keepdim=True)
                ent_each = -(logp.exp() * logp).sum(dim=-1)  # [B]
                ent_each[~torch.isfinite(ent_each)] = 0.0
                fwd_ent_accum.append(float(ent_each.mean().item()))

                # Chosen-action probability: exp(logit_a - logsumexp)
                taken_logp = sum_logits[rows, actions] - torch.logsumexp(sum_logits, dim=-1)
                fwd_taken_prob_accum.append(float(taken_logp.exp().mean().item()))

                # PF(tau) contribution at temperature `entropy_coeff`
                norm_logits = sum_logits / args.entropy_coeff
                log_pf = norm_logits[rows, actions] - torch.logsumexp(norm_logits, dim=-1)
                p_forward_sums[idx0:idx1, r] += log_pf

                # Backward (parents-uniform) metrics
                back_ent_accum.append(float(math.log(t + 1 + 1e-12)))
                back_taken_prob_accum.append(float(1.0 / (t + 1)))

                # Step: write the chosen words into the chosen positions
                batch[rows, positions] = words

    # Aggregate across rounds via logsumexp
    p_forward_sum = torch.logsumexp(p_forward_sums, dim=-1).detach().cpu().numpy()
    log_rewards = np.array([log_reward(s, M) for s in test_set]) * args.reward_exponent

    sp = spearmanr(log_rewards, p_forward_sum).statistic
    pr = pearsonr(log_rewards, p_forward_sum)[0]
    return {
        'spearman_corr_test': float(sp),
        'pearson_corr_test': float(pr),
        'forward_policy_entropy_test': float(np.mean(fwd_ent_accum)) if fwd_ent_accum else float('nan'),
        'forward_avg_action_prob_test': float(np.mean(fwd_taken_prob_accum)) if fwd_taken_prob_accum else float('nan'),
        'backward_policy_entropy_test': float(np.mean(back_ent_accum)) if back_ent_accum else float('nan'),
        'backward_avg_action_prob_test': float(np.mean(back_taken_prob_accum)) if back_taken_prob_accum else float('nan'),
    }


# =====================
# Train losses (keep original function names; note the MODIFIED comments)
# =====================
def TB_train_step(model, log_Z, optimizer, Z_optimizer, M, args, alpha):
    """One Trajectory-Balance update on a freshly sampled batch.

    - Trajectories are sampled with ``args._sampling_model`` and the original epsilon-greedy
      ``sample_forward`` (epsilon = ``args.rand_action_prob``, scheduled by the caller).
      The sampling forward pass runs under ``torch.no_grad()`` because only the sampled
      action indices are used.
    - The loss uses the training ``model``:
      ``(log_Z.sum() + sum log PF - sum [log(1/(i+1)) - log(alpha/(1-alpha))] - R_exp * logR)^2``
      averaged over the batch.
    - If ``args.grad_clip_norm > 0``, gradients of ``optimizer``'s parameters are clipped to
      that global norm. ``args._last_clip_scale`` and ``args._last_clip_trigger`` record the
      effect.

    Returns:
        (loss as a Python float, sampled tokens [batch_size, n // k] on CPU without BOS)
    """
    # All trajectories in this graph have the same length (n // k steps).
    model.train()
    optimizer.zero_grad()
    Z_optimizer.zero_grad()

    n, k = args.n, args.k
    T = n // k
    batch = torch.tensor([[2 ** k + 1] + ([2 ** k] * (T)) for _ in range(args.batch_size)], device=args.device)
    p_forward_sum = torch.zeros(args.batch_size, device=args.device)
    p_backward_sum = torch.zeros(args.batch_size, device=args.device)
    # Constant alpha term added at every step; computed once instead of per step.
    log_alpha_ratio = torch.log(torch.tensor(alpha / (1 - alpha), device=args.device))

    for i in range(T):
        pos_mask = batch != 2 ** k
        # Sample with sampling_model + original epsilon-greedy; no autograd graph needed.
        with torch.no_grad():
            all_logits_samp = args._sampling_model(batch.T.clone())
            _, _, sum_logits_samp = process_logits(all_logits_samp, pos_mask, args)
            _, _, sum_uniform = process_logits(0.0 * all_logits_samp, pos_mask, args)
            actions, positions, words = sample_forward(sum_logits_samp, sum_uniform, batch, args)

        # Forward log-prob of the sampled action under the training model
        all_logits = model(batch.T.clone())
        _, _, sum_logits = process_logits(all_logits, pos_mask, args)
        p_forward_sum += sum_logits[range(args.batch_size), actions] - torch.logsumexp(sum_logits, dim=-1)
        # Uniform backward policy over the i + 1 filled slots, shifted by the alpha term
        p_backward_sum += torch.log(torch.tensor(1 / (i + 1), device=args.device)) - log_alpha_ratio

        batch[range(args.batch_size), positions] = words

    log_rewards = args.reward_exponent * batch_log_rewards(batch[:, 1:], M, k).to(args.device).detach()
    loss = (log_Z.sum() + p_forward_sum - p_backward_sum - log_rewards).pow(2).mean()

    loss.backward()
    # Global grad clip inside the train step to keep the function interface unchanged
    if args.grad_clip_norm > 0:
        params_to_clip = [p for g in optimizer.param_groups for p in g['params'] if p.grad is not None]
        total_norm = torch.nn.utils.clip_grad_norm_(params_to_clip, max_norm=args.grad_clip_norm)
        total_norm = float(total_norm.detach().cpu().item())
        tau = float(args.grad_clip_norm)
        args._last_clip_scale = float(tau / (total_norm + 1e-12)) if total_norm > tau else 1.0
        args._last_clip_trigger = 1.0 if total_norm > tau else 0.0
    else:
        args._last_clip_scale = 1.0
        args._last_clip_trigger = 0.0

    optimizer.step()
    Z_optimizer.step()

    assert batch[:, 1:].max() < 2 ** k
    return loss.cpu().item(), batch[:, 1:].cpu()

def DB_train_step(model, optimizer, M, args, alpha):
    """One Detailed-Balance update on a freshly sampled batch.

    - Trajectories are sampled with ``args._sampling_model`` and the original epsilon-greedy
      ``sample_forward`` (epsilon = ``args.rand_action_prob``, scheduled by the caller).
      The sampling forward pass runs under ``torch.no_grad()`` because only the sampled
      action indices are used.
    - For each transition s_{i-1} -> s_i the squared residual
      ``log F(s_{i-1}) + log PF - log F(s_i) - log(1/i) + log(alpha/(1-alpha))`` is added.
      Here log F is column ``2**k`` of the training model's slot-0 output. The final
      transition uses ``reward_exponent * log R`` and ``log(1/T)``. The sum is divided by T.
    - If ``args.grad_clip_norm > 0``, gradients are clipped to that global norm and the effect
      is recorded in ``args._last_clip_scale`` / ``args._last_clip_trigger``.

    Returns:
        (loss as a Python float, sampled tokens [batch_size, n // k] on CPU without BOS)
    """
    model.train()
    optimizer.zero_grad()

    n, k = args.n, args.k
    T = n // k
    batch = torch.tensor([[2 ** k + 1] + ([2 ** k] * T) for _ in range(args.batch_size)], device=args.device)
    loss_val = torch.tensor(0.0, device=args.device)
    pred_logits = None
    pred_f = None
    # Constant alpha term used in every residual; computed once instead of per step.
    log_alpha_ratio = torch.log(torch.tensor(alpha / (1 - alpha), device=args.device))

    for i in range(T):
        pos_mask = batch != 2 ** k
        # Sample with sampling_model + epsilon-greedy; no autograd graph needed.
        with torch.no_grad():
            all_logits_samp = args._sampling_model(batch.T.clone())
            _, _, sum_logits_samp = process_logits(all_logits_samp, pos_mask, args)
            _, _, sum_uniform = process_logits(0.0 * all_logits_samp, pos_mask, args)
            actions, positions, words = sample_forward(sum_logits_samp, sum_uniform, batch, args)

        # Training-model forward: log-flow of the current state and PF of the sampled action
        all_logits = model(batch.T.clone())
        _, _, sum_logits = process_logits(all_logits, pos_mask, args)
        log_f = all_logits[0, :, 2 ** k]
        if pred_logits is not None:
            # Residual for the transition into the current state (i filled slots -> PB = 1/i)
            loss_val += (pred_f + pred_logits - log_f - torch.log(torch.tensor(1 / i, device=args.device)) + log_alpha_ratio).pow(2).mean()

        lp = torch.log_softmax(sum_logits, dim=-1)
        pred_logits = lp[torch.arange(args.batch_size), actions]
        pred_f = log_f

        batch[torch.arange(args.batch_size), positions] = words

    # Terminal transition: the flow of the complete state is replaced by the (exponentiated) reward
    log_rewards = args.reward_exponent * batch_log_rewards(batch[:, 1:], M, k).to(args.device).detach()
    loss_val += (pred_f + pred_logits - log_rewards - torch.log(torch.tensor(1 / T, device=args.device)) + log_alpha_ratio).pow(2).mean()
    loss = loss_val / T

    loss.backward()
    if args.grad_clip_norm > 0:
        params_to_clip = [p for g in optimizer.param_groups for p in g['params'] if p.grad is not None]
        total_norm = torch.nn.utils.clip_grad_norm_(params_to_clip, max_norm=args.grad_clip_norm)
        total_norm = float(total_norm.detach().cpu().item())
        tau = float(args.grad_clip_norm)
        args._last_clip_scale = float(tau / (total_norm + 1e-12)) if total_norm > tau else 1.0
        args._last_clip_trigger = 1.0 if total_norm > tau else 0.0
    else:
        args._last_clip_scale = 1.0
        args._last_clip_trigger = 0.0

    optimizer.step()

    assert batch[:, 1:].max() < 2 ** k
    return loss.cpu().item(), batch[:, 1:].cpu()

def SubTB_train_step(model, optimizer, M, args, alpha):
    """Run one Sub-Trajectory Balance (SubTB) optimisation step.

    Trajectories are rolled out with ``args._sampling_model``. Actions are
    chosen by ``sample_forward()``, which also receives uniform logits for
    the epsilon-greedy mix; epsilon is scheduled in ``main`` via
    ``args.rand_action_prob``. Log-flows and forward log-probabilities come
    from ``model``. Every sub-trajectory ``(i, j)`` with ``0 <= i < j <= T``
    contributes a squared residual, weighted by ``args.subtb_lambda ** (j - i)``.
    The weighted sum is normalised by the total weight.

    If ``args.grad_clip_norm > 0``, gradients are clipped before the optimiser
    step. The clip statistics are then stored in ``args._last_clip_scale`` and
    ``args._last_clip_trigger``; otherwise these are set to 1.0 and 0.0.
    Function name and signature are preserved.

    Args:
        model: training policy; its output at position 0, channel ``2 ** k``
            is used as the log-flow of the current state.
        optimizer: optimiser over the training model's parameters.
        M: target set forwarded to ``batch_log_rewards``.
        args: run namespace. Reads ``n``, ``k``, ``batch_size``, ``device``,
            ``reward_exponent``, ``subtb_lambda``, ``grad_clip_norm`` and
            ``_sampling_model``; writes ``_last_clip_scale`` and
            ``_last_clip_trigger``.
        alpha: scalar strictly inside (0, 1). Each residual adds
            ``(j - i - 1) * log(alpha / (1 - alpha))``.

    Returns:
        ``(loss, tokens)``: the normalised loss as a Python float, and the
        final token batch without its leading start token, on CPU.

    Raises:
        ValueError: if ``alpha`` is not strictly between 0 and 1. At the
            endpoints the log-odds term would be infinite or divide by zero.
    """
    if not 0.0 < alpha < 1.0:
        raise ValueError(f"alpha must lie strictly in (0, 1), got {alpha}")

    model.train()
    optimizer.zero_grad()

    n, k = args.n, args.k
    T = n // k
    B = args.batch_size
    empty_token = 2 ** k
    # Each row: start token (2**k + 1) followed by T empty slots (2**k).
    batch = torch.tensor([[empty_token + 1] + ([empty_token] * T) for _ in range(B)], device=args.device)

    # Row t holds the quantities recorded at the t-th step; the last row
    # (index T) is filled in after the rollout for the terminal state.
    log_pfs = torch.zeros(T + 1, B, device=args.device)
    log_pbs = torch.zeros(T + 1, B, device=args.device)
    log_fs = torch.zeros(T + 1, B, device=args.device)
    rows = torch.arange(B)

    for i in range(T):
        pos_mask = batch != empty_token
        # The sampling pass only chooses actions; nothing from it enters the
        # loss, so skip building an autograd graph for it.
        with torch.no_grad():
            all_logits_samp = args._sampling_model(batch.T.clone())
            _, _, sum_logits_samp = process_logits(all_logits_samp, pos_mask, args)
            _, _, sum_uniform = process_logits(0.0 * all_logits_samp, pos_mask, args)
        actions, positions, words = sample_forward(sum_logits_samp, sum_uniform, batch, args)

        all_logits = model(batch.T.clone())
        _, _, sum_logits = process_logits(all_logits, pos_mask, args)
        log_fs[i] = all_logits[0, :, empty_token]
        log_pfs[i] = torch.log_softmax(sum_logits, dim=-1)[rows, actions]
        if i > 0:
            # Backward log-prob at step i is log(1 / i) (uniform over i choices).
            log_pbs[i] = torch.log(torch.tensor(1 / i, device=args.device))

        batch[rows, positions] = words

    # Terminal state: flow is the (exponentiated) reward, held fixed.
    log_fs[-1] = args.reward_exponent * batch_log_rewards(batch[:, 1:], M, k).to(args.device).detach()
    log_pbs[-1] = torch.log(torch.tensor(1 / T, device=args.device))

    # Hoisted out of the O(T^2) loop: identical for every sub-trajectory.
    log_odds = torch.log(torch.tensor(alpha / (1 - alpha), device=args.device))

    loss = torch.zeros((), device=args.device)
    total_lambda = torch.zeros((), device=args.device)
    for i in range(T):  # i == T has no j > i, so it contributes nothing
        for j in range(i + 1, T + 1):
            lmbd = args.subtb_lambda ** (j - i)
            # NOTE(review): DB_train_step adds log(alpha/(1-alpha)) once per
            # transition, but here a (j - i)-transition sub-trajectory gets it
            # (j - i - 1) times. Confirm which convention is intended.
            residual = (
                log_fs[i]
                + log_pfs[i:j].sum(dim=0)
                - log_fs[j]
                - log_pbs[i + 1:j + 1].sum(dim=0)
                + (j - i - 1) * log_odds
            )
            loss = loss + lmbd * residual.pow(2).mean()
            total_lambda = total_lambda + lmbd
    loss = loss / (total_lambda + 1e-12)

    loss.backward()
    if args.grad_clip_norm > 0:
        params_to_clip = [
            p
            for group in optimizer.param_groups
            for p in group['params']
            if p.grad is not None
        ]
        total_norm = torch.nn.utils.clip_grad_norm_(params_to_clip, max_norm=args.grad_clip_norm)
        total_norm = float(total_norm.detach().cpu().item())
        tau = float(args.grad_clip_norm)
        clipped = total_norm > tau
        # Effective scale factor applied by clip_grad_norm_ (1.0 when unclipped).
        args._last_clip_scale = float(tau / (total_norm + 1e-12)) if clipped else 1.0
        args._last_clip_trigger = 1.0 if clipped else 0.0
    else:
        args._last_clip_scale = 1.0
        args._last_clip_trigger = 0.0

    optimizer.step()

    assert batch[:, 1:].max() < 2 ** k
    return loss.cpu().item(), batch[:, 1:].cpu()


# =====================
# Main
# =====================

def main(args):
    """Train the policy for ``args.num_iterations`` steps and log metrics.

    Outline:
      1. Seed model/training randomness with ``args.model_seed``. Create a
         dedicated torch generator ``args.sampling_gen`` seeded with
         ``args.sampling_seed``.
      2. Build the target set ``M`` with ``construct_M`` from the fixed
         patterns ``H``, and the test set derived from it.
      3. Create the training model and a sampling model initialised with
         identical weights. The sampling model is handed to the train steps
         through ``args._sampling_model``. It is re-synchronised every
         ``args.steps_to_update_sampling_model`` iterations (when > 0).
      4. Each iteration:
         - update the optional alpha / epsilon schedules;
         - run one train step for ``args.objective``;
         - accumulate gradient-clipping statistics;
         - compute on-policy eval metrics;
         - every ``args.n_test_steps`` iterations, also compute test metrics.
      5. Emit a log record every ``args.eval_every`` iterations, on the final
         iteration, and whenever test metrics were computed. The printed line
         and the Weights & Biases record are only produced when ``args.wdb``
         is set.

    Args:
        args: parsed command-line namespace. Mutated in place:
            - sets ``rand_action_prob_final``, ``alpha_init``, ``sampling_gen``
              and ``_sampling_model``;
            - overwrites ``rand_action_prob`` when the epsilon schedule is on;
            - has ``_last_clip_*`` written by the train steps.

    Raises:
        ValueError: if ``args.objective`` is not 'tb', 'db' or 'subtb'.
    """
    device = torch.device(args.device)
    # Preserve the user-supplied values for logging: the epsilon schedule
    # overwrites args.rand_action_prob, and alpha is tracked on the model.
    args.rand_action_prob_final = args.rand_action_prob
    args.alpha_init = args.alpha

    # Seeds: model init / training randomness vs. a separate sampling generator.
    set_model_seed(args.model_seed, args.device)
    args.sampling_gen = torch.Generator(device=device).manual_seed(args.sampling_seed)

    # Build M and the test set.
    assert args.n % args.k == 0
    H = ["00000000", "11111111", "11110000", "00001111", "00111100"]
    assert args.n % len(H[0]) == 0
    M = construct_M(args.n, len(H[0]), H, args.M_size, seed=args.model_seed)
    test_set = construct_test_set(M, seed=args.model_seed)

    # Models: training vs sampling (same architecture, same initial weights).
    model_kwargs = dict(ntoken=2 ** args.k + 2, d_model=64, d_hid=64, nhead=8, nlayers=3,
                        seq_len=args.n // args.k, dropout=args.dropout)
    training_model = TransformerModel(**model_kwargs).to(device)
    training_model.alpha.data = torch.tensor(args.alpha, device=device)
    sampling_model = TransformerModel(**model_kwargs).to(device)
    sampling_model.load_state_dict(training_model.state_dict())
    # Passed via args so the train-step function signatures stay unchanged.
    args._sampling_model = sampling_model

    # Optimizers
    params = [p for p in training_model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=args.learning_rate, weight_decay=1e-5)
    if args.objective == 'tb':
        log_Z = nn.Parameter(torch.zeros(64, device=device), requires_grad=True)
        Z_optimizer = torch.optim.Adam([log_Z], lr=args.z_learning_rate, weight_decay=1e-5)
    else:
        log_Z = None
        Z_optimizer = None

    # Schedulers
    if args.use_exp_weight_decay:
        eps_sched = ExpWeightScheduler(end=args.rand_action_prob, total_steps=args.num_iterations,
                                       kind=args.exp_weight_sched, warm_frac=args.exp_weight_warm_frac)
    if args.use_alpha_scheduler:
        alpha_sched = AlphaScheduler(total_steps=args.num_iterations, alpha0=args.alpha,
                                     warm_frac=args.alpha_warm_frac, alpha_final=0.5)

    # WandB
    wdb_name = f"k({args.k})_ms({args.model_seed})_ss({args.sampling_seed})_m({args.objective})_a({args.alpha})"
    if args.wdb:
        wandb.init(project=args.wdb_project, name=wdb_name)

    # Eval accumulators shared across iterations (passed to eval_batch_metrics).
    all_eval_unique_strs: List[str] = []
    all_eval_unique_Rs: List[float] = []
    eval_mode_hits = [False] * len(M)

    # Grad-clip accumulators, averaged and reset at each log point.
    clip_sum_trigger = 0.0
    clip_sum_scale = 0.0
    clip_count = 0

    last_eval_wall = time.time()
    last_iter = args.num_iterations - 1

    for it in range(args.num_iterations):
        is_last = it == last_iter

        # Update alpha / epsilon schedules
        if args.use_alpha_scheduler:
            training_model.alpha.data = torch.tensor(alpha_sched(it), device=device)
        if args.use_exp_weight_decay:
            args.rand_action_prob = eps_sched(it)

        # Train step (compute loss and update parameters)
        alpha_now = float(training_model.alpha.item())
        if args.objective == 'tb':
            loss, train_batch_tokens = TB_train_step(training_model, log_Z, optimizer, Z_optimizer, M, args, alpha_now)
        elif args.objective == 'db':
            loss, train_batch_tokens = DB_train_step(training_model, optimizer, M, args, alpha_now)
        elif args.objective == 'subtb':
            loss, train_batch_tokens = SubTB_train_step(training_model, optimizer, M, args, alpha_now)
        else:
            raise ValueError(f"Unknown objective {args.objective}")

        # Accumulate clip stats (written into args by the train-step functions).
        if hasattr(args, '_last_clip_trigger'):
            clip_sum_trigger += float(args._last_clip_trigger)
            clip_sum_scale += float(args._last_clip_scale)
            clip_count += 1

        # Periodically sync the sampling model with the training model.
        if args.steps_to_update_sampling_model > 0 and it > 0 and (it % args.steps_to_update_sampling_model == 0):
            sampling_model.load_state_dict(training_model.state_dict())
            args._sampling_model = sampling_model

        # ----- EVAL BLOCK -----
        # On-policy samples from the training model, generated every iteration
        # (not only at log points). The cumulative unique-string and mode-hit
        # accumulators are passed in each time.
        eval_stats, eval_tokens, eval_strings = eval_batch_metrics(
                training_model, args, args.batch_size, device,
                M=M, reward_exponent=args.reward_exponent,
                acc_unique_strs=all_eval_unique_strs,
                acc_unique_Rs=all_eval_unique_Rs,
                mode_hits=eval_mode_hits,
            )
        # Test-set metrics (correlation plus extra test metrics).
        if (it % args.n_test_steps == 0) or is_last:
            test_stats = compute_test_metrics(training_model, M, test_set, args, rounds=args.corr_num_rounds, batch_size=180)
        else:
            test_stats = {}

        # Log on the regular cadence, on the final iteration, and whenever test
        # metrics were just computed. Otherwise those metrics would be thrown
        # away whenever n_test_steps is not a multiple of eval_every.
        if (it % args.eval_every == 0) or is_last or test_stats:
            now = time.time()
            eval_interval_sec = now - last_eval_wall  # seconds since previous log point
            last_eval_wall = now

            if args.wdb:
                wandb.log({"eval_interval_sec": eval_interval_sec}, step=it)

            # Clip interval averages & reset
            if clip_count > 0:
                clip_trigger_rate_interval = clip_sum_trigger / clip_count
                clip_scale_avg_interval = clip_sum_scale / clip_count
            else:
                clip_trigger_rate_interval = 0.0
                clip_scale_avg_interval = 1.0
            clip_sum_trigger = 0.0
            clip_sum_scale = 0.0
            clip_count = 0

            if args.wdb:
                log_dict = {
                    'step': it,
                    'loss': float(loss),
                    'rand_action_prob': float(args.rand_action_prob),
                    'alpha': float(alpha_now),
                    'clip_trigger_rate_interval': clip_trigger_rate_interval,
                    'clip_scale_avg_interval': clip_scale_avg_interval,
                    'log_Z': float(log_Z.mean().item()) if log_Z is not None else 0.0,
                }
                log_dict.update(eval_stats)
                log_dict.update(test_stats)
                print(f"task={wdb_name}, eval_interval={eval_interval_sec:.3f}s, "+_dict_to_str(log_dict), flush=True)
                if is_last:
                    # Attach the run settings to the final record, minus noisy
                    # or non-serialisable keys.
                    setting_dict = vars(args).copy()
                    for key in ('alpha', 'rand_action_prob', 'wdb', 'device', 'wdb_project', 'sampling_gen', '_sampling_model', '_last_clip_trigger', '_last_clip_scale'):
                        setting_dict.pop(key, None)
                    log_dict.update(setting_dict)
                wandb.log(log_dict)

    # Close the W&B run (settings were attached to the last logged record).
    if args.wdb:
        wandb.finish()

    print(f'{wdb_name} finish training', flush=True)


if __name__ == '__main__':
    # Parse the CLI, normalise boolean-style flags, then launch training.
    # `args` is kept as a module-level name, exactly as before.
    args = process_bool_args(parser.parse_args())
    main(args)
