# %%
import os
from iclr_project.helpers import constants
os.environ["HF_HOME"] = constants.MODELS_DIR
import math, time, torch, torch.nn as nn, torch.nn.functional as F, numpy as np
from nnsight import NNsight
from iclr_project.helpers.typo_gen import typo_generator
from torch.utils.data import Dataset, DataLoader

from datasets import load_dataset
from collections import Counter
from tqdm import tqdm
import time
import matplotlib.pyplot as plt

from copy import deepcopy
import argparse
import json

# ---- Command-line interface ----
# Every option carries a default, so the script also runs without arguments.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--layer_idx',
    type=int,
    default=4,
    help='Layer index to extract activations from',
)
parser.add_argument(
    '--hookpoint',
    type=str,
    choices=['resid', 'attn', 'mlp'],
    default="resid",
    help='Hookpoint in the transformer block',
)
parser.add_argument(
    '--energy_band_width',
    type=int,
    choices=[10, 20, 50],
    default=20,
    help='Width of energy band for high-energy selection',
)
parser.add_argument(
    '--ft_loss_mask',
    type=str,
    choices=["none", "typo_mask"],
    default="none",
    help='Type of loss masking during fine-tuning',
)

args = parser.parse_args()

# Mirror the parsed options into the module-level names the rest of the script reads.
layer_idx, hookpoint = args.layer_idx, args.hookpoint
energy_band_width, ft_loss_mask = args.energy_band_width, args.ft_loss_mask

print(
    f"Arguments: layer_idx={layer_idx}, hookpoint={hookpoint}, "
    f"energy_band_width={energy_band_width}, ft_loss_mask={ft_loss_mask}"
)


# Root directories for saved loss dictionaries and for figures, respectively.
GLOBAL_SAVE_DIR = os.path.join(constants.DATA_DIR, "processed", "toy_model")
GLOBAL_VIZ_DIR = os.path.join(constants.VIZ_DIR, "toy_model")

# %%
# ---- SAE training configuration ----
latent_dim = 4_096   # number of SAE latents
l1_coeff = 2.5       # weight of the L1 sparsity term in the SAE loss
steps = 4_000        # number of SAE optimisation steps
batch_size = 64      # sequences drawn per step
block_size = 512     # tokens per sequence
positions = 8_192    # cap on activation rows fed to the SAE per step
log_every = 100      # SAE training stats are printed every this many steps
device = "cuda"

# %%
# Memory-map the tokenised training split (uint16 ids) so windows can be sliced on demand.
data_dir = os.path.join(constants.MODELS_DIR, "nanoGPT", 'data/tinystories_char')
train_mm = np.memmap(f'{data_dir}/train.bin', dtype=np.uint16, mode='r')
L = len(train_mm)


def get_batch():
    """Return `batch_size` random token windows of length `block_size` on `device`.

    Window start offsets are drawn with `torch.randint`; each window is copied
    out of the memory map as int64 and the windows are stacked row-wise.
    """
    starts = torch.randint(0, L - block_size - 1, (batch_size,))
    windows = []
    for start in starts:
        chunk = train_mm[start:start + block_size]
        windows.append(torch.from_numpy(chunk.astype(np.int64)))
    return torch.stack(windows).to(device)


class SAE(nn.Module):
    """Tied-weight sparse autoencoder with a ReLU latent layer.

    A single matrix ``W`` of shape (d, k) is used for encoding and, transposed,
    for decoding. ``b_enc`` (k,) and ``b_dec`` (d,) start at zero.
    """

    def __init__(self, d, k):
        super().__init__()
        weight = torch.empty(d, k)
        nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
        self.W = nn.Parameter(weight)
        self.b_enc = nn.Parameter(torch.zeros(k))
        self.b_dec = nn.Parameter(torch.zeros(d))

    def forward(self, x):
        """Return ``(z, x_hat)``: the ReLU latents and the reconstruction of ``x``."""
        pre_activation = torch.matmul(x, self.W) + self.b_enc
        z = torch.relu(pre_activation)
        x_hat = torch.matmul(z, self.W.t()) + self.b_dec
        return z, x_hat

# %%
"""
Sample from a trained model
"""
import os
import pickle
from contextlib import nullcontext
import torch
from model import GPTConfig, GPT

# Checkpoint location plus sampling-style settings; only `seed`, `device` and
# `dtype` are read again in this cell.
init_from = 'resume'
out_dir = os.path.join(constants.MODELS_DIR, "nanoGPT", "out-tinystories-char")
start = "Once upon a time"
num_samples = 1
max_new_tokens = 512
temperature = 0.8
top_k = 200
seed = 666
device = 'cuda'
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# Rebuild the model from the checkpoint's stored constructor arguments and weights.
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
state_dict = checkpoint['model']
# Strip the '_orig_mod.' prefix from parameter names (presumably left by a
# compiled-model wrapper at save time) so they match the plain module.
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)

model.eval()
model.to(device)

# Tokenizer metadata. `load_meta` now starts False and `meta_path` is always
# bound, so a checkpoint without 'config'/'dataset' no longer triggers a
# NameError on `meta_path` below.
load_meta = False
meta_path = None
if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these...
    # NOTE(review): "metapath" looks like a placeholder rather than a real
    # file location — confirm where the meta pickle actually lives.
    meta_path = "metapath"
    load_meta = os.path.exists(meta_path)
if load_meta:
    print(f"Loading meta from {meta_path}...")
    # NOTE: pickle can execute arbitrary code; only load trusted meta files.
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    stoi, itos = meta['stoi'], meta['itos']
    encode = lambda s: [stoi[c] for c in s]
    decode = lambda l: ''.join([itos[i] for i in l])
else:
    # Make the missing tokenizer visible here instead of via a later NameError.
    print(f"WARNING: tokenizer meta not loaded (meta_path={meta_path!r}); "
          "`encode`/`decode` are not defined.")

# Wrap the model so intermediate activations can be captured with nnsight traces.
trace_model = NNsight(model)


# %%
# Run one probe batch through the traced model only to read off the width of
# the hooked activation, then size the SAE and its optimiser accordingly.
with torch.no_grad():
    toks = get_batch()
    with trace_model.trace(toks) as tr:
        if hookpoint == 'mlp':
            act_node = trace_model.transformer.h[layer_idx].mlp.output.save()
        elif hookpoint == 'attn':
            act_node = trace_model.transformer.h[layer_idx].attn.output.save()
        elif hookpoint == 'resid':
            act_node = trace_model.transformer.h[layer_idx].output.save()
    # The saved node is unwrapped via `.value` when it has one, else used as-is.
    sample_act = getattr(act_node, 'value', act_node)
    d_model = sample_act.shape[-1]
print(f'd_model={d_model}')

sae = SAE(d_model, latent_dim)
sae = sae.to(device)
opt = torch.optim.AdamW(sae.parameters(), lr=1e-3)

# %%
def ev_fn(x, xr):
    """Return 1 - MSE(xr, x) / (mean(x**2) + 1e-8)."""
    return 1 - F.mse_loss(xr, x) / (x.pow(2).mean() + 1e-8)


# Exponentially smoothed training statistics, all starting from zero.
state = dict.fromkeys(['loss', 'recon', 'l1', 'L0', 'EV'], 0.0)


def ema(k, v):
    """Blend `v` into `state[k]` with weight 0.02 (0.98 on the old value)."""
    state[k] = 0.98 * state[k] + 0.02 * v


t0 = time.time()
for step in range(1, steps + 1):
    # --- Collect activations for this step (no gradients through the LM) ---
    with torch.no_grad():
        toks = get_batch()
        with trace_model.trace(toks) as tr:
            if hookpoint == 'mlp':
                act_node = trace_model.transformer.h[layer_idx].mlp.output.save()
            elif hookpoint == 'attn':
                act_node = trace_model.transformer.h[layer_idx].attn.output.save()
            elif hookpoint == 'resid':
                act_node = trace_model.transformer.h[layer_idx].output.save()
        acts_btC = getattr(act_node, 'value', act_node)  # (B,T,C)
        xs = acts_btC.reshape(-1, d_model).to(torch.float32)
        n_rows = xs.size(0)
        if n_rows > positions:
            # Subsample rows (with replacement) down to `positions`.
            idx = torch.randint(0, n_rows, (positions,), device=xs.device)
            xs = xs[idx]

    # --- SAE update: reconstruction error plus L1 penalty on the latents ---
    z, xs_hat = sae(xs)
    recon = F.mse_loss(xs_hat, xs)
    l1 = z.abs().mean()
    loss = recon + l1_coeff * l1

    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()

    # --- Book-keeping ---
    with torch.no_grad():
        EV = ev_fn(xs, xs_hat).item()
        active_counts = (z > 0).sum(dim=1).float()
        L0_mean = active_counts.mean().item()
        for stat_name, stat_value in (
            ('loss', loss.item()),
            ('recon', recon.item()),
            ('l1', l1.item()),
            ('L0', L0_mean),
            ('EV', EV),
        ):
            ema(stat_name, stat_value)

    if step == 1 or step % log_every == 0:
        dt = (time.time() - t0) / step
        print(f"{step:5d} | loss {state['loss']:.3e} | recon {state['recon']:.3e} | l1 {state['l1']:.3e} | L0(avg_active) {state['L0']:.2f} | EV {state['EV']:.4f} | {dt:.3f}s/step")

# %%
# Save SAE
# NOTE(review): the file name encodes layer, latent size and L1 coefficient
# but not `hookpoint`, so runs that differ only in hookpoint share one file.
sae_path = os.path.join(out_dir, f'sae-l{layer_idx}-d{latent_dim}-l1{l1_coeff}.pt')
torch.save(sae.state_dict(), sae_path)
print(f"Saved SAE to {sae_path}")

# %%


# %%
# Load SAE
sae_path = os.path.join(out_dir, f'sae-l{layer_idx}-d{latent_dim}-l1{l1_coeff}.pt')
sae_state = torch.load(sae_path, map_location=device)
# Take both dimensions from the stored weight (W has shape (d_model, latents))
# rather than hard-coding 512 / 4096, so the load matches whatever was saved.
d_model, sae_latents = sae_state['W'].shape
sae = SAE(d_model, sae_latents).to(device)
sae.load_state_dict(sae_state)
print(f"Loaded SAE from {sae_path}")
sae.eval()

# %%
@torch.no_grad()
def calibrate_energy_constants(sae: torch.nn.Module,
                               acts: torch.Tensor,
                               eps: float = 1e-6):
    """
    Estimate the constants used by the energy score from a batch of activations.

    Parameters
    ----------
    sae  : callable returning ``(z, x_hat)`` for ``acts`` (the trained SAE)
    acts : 2-d tensor of activations, shape (N, d)
    eps  : firing rates are clamped to [eps, 1 - eps] before taking logs

    Returns
    -------
    sigma2    : Python float, mean squared reconstruction residual over all elements
    log_odds  : tensor of length k, log((1 - p_i) / p_i) with p_i the
                fraction of rows on which latent i is > 0
    """
    latents, reconstruction = sae(acts)
    residual = acts - reconstruction
    noise_var = residual.pow(2).mean()

    firing_rate = (latents > 0).float().mean(dim=0)
    firing_rate = firing_rate.clamp(eps, 1.0 - eps)
    log_odds = torch.log((1.0 - firing_rate) / firing_rate)

    return noise_var.item(), log_odds

@torch.no_grad()
def energy_score(sae           : torch.nn.Module,
                 acts          : torch.Tensor,
                 sigma2        : float,
                 log_odds      : torch.Tensor):
    """
    Score each activation row with a reconstruction term plus a sparsity term.

    Parameters
    ----------
    sae       : callable returning ``(z, x_hat)`` for ``acts``
    acts      : 2-d tensor of activations, shape (N, d)
    sigma2    : noise variance used to scale the squared residual
    log_odds  : per-latent log-odds, length k

    Returns
    -------
    (E, rec_term, sparsity), each a 1-d tensor of length N, where
        rec_term = ‖acts - x_hat‖² / (2·sigma2)
        sparsity = Σ_i 1[z_i > 0] · log_odds_i
        E        = rec_term + sparsity
    """
    latents, reconstruction = sae(acts)
    residual = acts - reconstruction
    rec_term = residual.pow(2).sum(dim=1) / (2.0 * sigma2)
    fired = (latents > 0).float()
    sparsity = fired @ log_odds
    return rec_term + sparsity, rec_term, sparsity

dataset = load_dataset(
    "roneneldan/TinyStories",
    cache_dir=os.path.join(constants.DATA_DIR, "hf_datasets"),
)

# Work on fixed prefixes of each split: 2000 training and 1000 validation rows.
selected_dataset_train = dataset["train"].select(range(2000))
selected_dataset_val = dataset["validation"].select(range(1000))

# Concatenate each split's texts into one string and encode it with `encode`.
train_string = "".join(example["text"] for example in selected_dataset_train)
train_normal_encoded = np.array(encode(train_string))

val_string = "".join(example["text"] for example in selected_dataset_val)
val_normal_encoded = np.array(encode(val_string))

# Calibrate energy constants on clean training data
# --- Prepare clean batches ---
# Drop the ragged tail so the ids reshape into whole `block_size` sequences.
n_clean_tokens = (train_normal_encoded.size // block_size) * block_size
clean_encoded = train_normal_encoded[:n_clean_tokens]
clean_batches = clean_encoded.reshape(-1, block_size)
num_clean_seqs = clean_batches.shape[0]

# --- Calibrate sigma2 and log-odds on clean activations (streaming over sequences) ---
# Running totals are accumulated one sequence at a time and normalised at the end.
eps = 1e-6
k = sae.b_enc.shape[0]
active_counts = torch.zeros(k, device=device)
sum_sse = 0.0
elem_count = 0
n_positions = 0

with torch.no_grad():
    for clean_seq in tqdm(clean_batches, desc='Calibrating (clean)'):
        toks_batch = torch.from_numpy(clean_seq[None, :].astype(np.int64)).to(device)  # (1, T)
        with trace_model.trace(toks_batch) as tr:
            if hookpoint == 'mlp':
                act_node = trace_model.transformer.h[layer_idx].mlp.output.save()
            elif hookpoint == 'attn':
                act_node = trace_model.transformer.h[layer_idx].attn.output.save()
            elif hookpoint == 'resid':
                act_node = trace_model.transformer.h[layer_idx].output.save()
        acts_btC = getattr(act_node, 'value', act_node)  # (1, T, C)
        acts_flat = acts_btC.reshape(-1, d_model).to(torch.float32)  # (T, C)

        z, x_hat = sae(acts_flat)
        g = acts_flat - x_hat

        # Squared-error total and element count -> sigma2;
        # per-latent firing counts and row count -> firing rates.
        sum_sse += g.pow(2).sum().item()
        elem_count += g.numel()
        active_counts += (z > 0).sum(dim=0).float()
        n_positions += z.size(0)

# The max(1, ...) guards avoid division by zero when no sequences were seen.
sigma2 = sum_sse / max(1, elem_count)
p = active_counts / max(1, n_positions)
p = p.clamp(eps, 1.0 - eps)
log_odds = torch.log((1.0 - p) / p)

# --- Compute per-sequence mean energy for clean data (to compare with noisy) ---
clean_means = []
with torch.no_grad():
    for clean_seq in tqdm(clean_batches, desc='Energies (clean)'):
        toks_batch = torch.from_numpy(clean_seq[None, :].astype(np.int64)).to(device)
        with trace_model.trace(toks_batch) as tr:
            if hookpoint == 'mlp':
                act_node = trace_model.transformer.h[layer_idx].mlp.output.save()
            elif hookpoint == 'attn':
                act_node = trace_model.transformer.h[layer_idx].attn.output.save()
            elif hookpoint == 'resid':
                act_node = trace_model.transformer.h[layer_idx].output.save()
        acts_btC = getattr(act_node, 'value', act_node)
        acts_flat = acts_btC.reshape(-1, d_model).to(torch.float32)
        # Only the total energy is needed here; average it over the sequence's rows.
        E_seq, _, _ = energy_score(sae, acts_flat, sigma2, log_odds)
        clean_means.append(E_seq.mean().item())

seq_mean_energies_clean = np.array(clean_means)

# %%

for noise_level in range(0, 101, 5):
    print(f'*=*=*=* Running Noise Level: {noise_level} *=*=*=*')

    # Generate noisy text (+ optional per-character mask).
    # The same seed (46) is used for the training and validation strings.
    train_mask_chars = val_mask_chars = None
    if ft_loss_mask != "typo_mask":
        train_string_noisy = typo_generator(train_string, percent_of_words=noise_level, seed=46)
        val_string_noisy = typo_generator(val_string, percent_of_words=noise_level, seed=46)
    else:
        train_string_noisy, train_mask_chars = typo_generator(train_string, percent_of_words=noise_level, seed=46, return_mask=True)
        val_string_noisy, val_mask_chars = typo_generator(val_string, percent_of_words=noise_level, seed=46, return_mask=True)

    train_noisy_encoded = np.array(encode(train_string_noisy))
    val_noisy_encoded = np.array(encode(val_string_noisy))

    # Masks become uint8 arrays; they are presumably aligned one-to-one with the
    # encoded characters (char-level tokenizer) — confirm against typo_generator.
    train_noisy_mask = None if train_mask_chars is None else np.array(train_mask_chars, dtype=np.uint8)
    val_noisy_mask = None if val_mask_chars is None else np.array(val_mask_chars, dtype=np.uint8)
        
    # Upper edges of the energy bands: every multiple of the band width up to 100.
    # (`--energy_band_width` is restricted to 10, 20 or 50 by the argument parser.)
    percentiles = list(range(energy_band_width, 101, energy_band_width))

    # --- Compute energies for ALL noisy sequences ---
    # Trim to a whole number of `block_size` sequences before reshaping.
    n_noisy_tokens = (train_noisy_encoded.size // block_size) * block_size
    train_noisy_encoded = train_noisy_encoded[:n_noisy_tokens]
    noisy_batches = train_noisy_encoded.reshape(-1, block_size)
    num_seqs = noisy_batches.shape[0]

    # Align mask batches if available; a mask shorter than the token array is ignored.
    mask_batches = None
    if train_noisy_mask is not None and train_noisy_mask.size >= train_noisy_encoded.size:
        train_noisy_mask = train_noisy_mask[:train_noisy_encoded.size]
        mask_batches = train_noisy_mask.reshape(-1, block_size)

    keep_indices = []  # will decide later
    E_per_seq = []  # per-sequence energy vectors (moved to CPU), kept for analysis
    noisy_means = []

    with torch.no_grad():
        for noisy_seq in tqdm(noisy_batches, desc='Computing energies (noisy)'):
            toks_batch = torch.from_numpy(noisy_seq[None, :].astype(np.int64)).to(device)  # (1, T)
            with trace_model.trace(toks_batch) as tr:
                if hookpoint == 'mlp':
                    act_node = trace_model.transformer.h[layer_idx].mlp.output.save()
                elif hookpoint == 'attn':
                    act_node = trace_model.transformer.h[layer_idx].attn.output.save()
                elif hookpoint == 'resid':
                    act_node = trace_model.transformer.h[layer_idx].output.save()
            acts_btC = getattr(act_node, 'value', act_node)  # (1, T, C)
            acts_flat = acts_btC.reshape(-1, d_model).to(torch.float32)
            E_seq, rec_seq, sparsity_seq = energy_score(sae, acts_flat, sigma2, log_odds)
            mean_E = E_seq.mean().item()
            noisy_means.append(mean_E)
            E_per_seq.append(E_seq.cpu())

    # One mean energy per noisy sequence, in the same order as `noisy_batches`.
    seq_mean_energies = np.array(noisy_means)

    # --- Plot overlapping histogram: clean vs noisy sequence mean energies ---
    # Figures for this (hookpoint, layer, noise, band, mask) setting share one root.
    viz_root = os.path.join(
        GLOBAL_VIZ_DIR,
        f"{hookpoint}_{layer_idx}",
        f"noise_{noise_level:.2f}",
        f"band_{energy_band_width}",
        f"mask_{ft_loss_mask}",
    )
    energies_dir = os.path.join(viz_root, "energies")
    curves_dir = os.path.join(viz_root, "curves")
    for figure_dir in (energies_dir, curves_dir):
        os.makedirs(figure_dir, exist_ok=True)

    plt.figure(figsize=(10, 6))
    clean_E = seq_mean_energies_clean
    noisy_E = seq_mean_energies
    # Shared bin edges so the two histograms are directly comparable.
    if len(clean_E) and len(noisy_E):
        common_min = float(min(clean_E.min(), noisy_E.min()))
        common_max = float(max(clean_E.max(), noisy_E.max()))
    else:
        common_min, common_max = 0.0, 1.0
    bins = np.linspace(common_min, common_max, 60)

    hist_series = (
        (clean_E, 'Clean', 'tab:blue'),
        (noisy_E, f'Noisy ({noise_level:.2f})', 'tab:orange'),
    )
    for energies, series_label, series_color in hist_series:
        plt.hist(energies, bins=bins, density=True, alpha=0.5, label=series_label, color=series_color, edgecolor='white', linewidth=0.5)
    # Dashed vertical line at each non-empty series' mean.
    for energies, _, series_color in hist_series:
        if len(energies):
            plt.axvline(energies.mean(), color=series_color, linestyle='--', linewidth=1)

    plt.title(f'Sequence Mean Energy — Hook={hookpoint}, Layer={layer_idx}, Noise={noise_level:.2f}')
    plt.xlabel('Mean Energy per Sequence')
    plt.ylabel('Density')
    plt.legend(frameon=False)
    plt.grid(alpha=0.25)
    plt.tight_layout()
    hist_path = os.path.join(energies_dir, 'energy_hist_clean_vs_noisy.png')
    plt.savefig(hist_path, dpi=200)
    plt.show()
    plt.close()

    # Loss histories per band, keyed by the band's upper percentile.
    percentile_loss_dict = {}
    for percentile in percentiles:
        percentile_loss_dict[percentile]= {
            'train_loss': [],
            'val_noisy_loss': [],
            'val_clean_loss': []
        }

        print(f'============ Trying Percentile: {percentile} ============')

        # The band covers sequences whose mean energy lies between the
        # (percentile - width)th and the percentile-th percentiles.
        band_lower_pct = max(percentile - energy_band_width, 0)
        energy_threshold_max = np.percentile(seq_mean_energies, percentile)
        energy_threshold_min = np.percentile(seq_mean_energies, band_lower_pct)
        print(f"{percentile}th percentile energy threshold: {energy_threshold_max:.4f}")

        # Bands are (min, max], except the lowest one, which is closed at the
        # bottom: the 0th percentile is the minimum itself, so a strict `>`
        # would leave the lowest-energy sequence out of every band.
        high_energy_mask = seq_mean_energies <= energy_threshold_max
        if band_lower_pct > 0:
            high_energy_mask &= seq_mean_energies > energy_threshold_min
        high_energy_indices = np.where(high_energy_mask)[0]
        lower_bracket = '(' if band_lower_pct > 0 else '['
        print(f"Selected {high_energy_indices.size} / {num_seqs} sequences "
              f"(energy in {lower_bracket}{energy_threshold_min:.4f}, {energy_threshold_max:.4f}])")

        high_energy_data = noisy_batches[high_energy_indices]  # numpy array (N_high, block_size)
        high_energy_mask_data = mask_batches[high_energy_indices] if mask_batches is not None else None
        print('High energy data shape:', high_energy_data.shape)

        # --- High Energy Fine-Tuning Setup ---
        # Wrap the selected band in a Dataset/DataLoader for next-token training.

        assert 'high_energy_data' in globals(), 'high_energy_data not found. Run previous cells first.'

        class HighEnergyDataset(Dataset):
            """Rows of token ids served as (input, target[, target_mask]) triples.

            Inputs are each row without its last token, targets the row without
            its first token. When a mask array is given, the mask shifted the
            same way as the targets is returned as a bool tensor.
            """

            def __init__(self, arr: np.ndarray, mask: np.ndarray = None):
                # arr shape: (N, block_size)
                self.data = torch.from_numpy(arr.astype(np.int64))
                self.mask = None if mask is None else torch.from_numpy(mask.astype(np.uint8))

            def __len__(self):
                return self.data.size(0)

            def __getitem__(self, idx):
                tokens = self.data[idx]  # (block_size,)
                inputs, targets = tokens[:-1], tokens[1:]
                if self.mask is not None:
                    # Aligned with targets; per the original note, True marks a typo.
                    target_mask = self.mask[idx][1:].bool()
                    return inputs, targets, target_mask
                return inputs, targets

        he_dataset = HighEnergyDataset(high_energy_data, high_energy_mask_data)
        ft_batch_size = min(32, len(he_dataset))  # never larger than the dataset
        he_loader = DataLoader(
            he_dataset,
            batch_size=ft_batch_size,
            shuffle=True,
            drop_last=True,
        )
        print(f'High-energy dataset size: {len(he_dataset)} sequences; loader batch_size={ft_batch_size}')

        # Inputs and targets are shifted by one token, so the usable context is one shorter.
        ft_ctx = high_energy_data.shape[1] - 1
        print(f'Fine-tune context length: {ft_ctx}')


        # --- Fine-Tuning Loop (High-Energy Only + Validation) ---
        # Each band starts from a fresh copy of the checkpoint read back from disk.
        ckpt_path = os.path.join(out_dir, 'ckpt.pt')
        checkpoint = torch.load(ckpt_path, map_location=device)
        gptconf = GPTConfig(**checkpoint['model_args'])
        model_ft = GPT(gptconf)

        # Rename any '_orig_mod.'-prefixed parameter keys to their bare names.
        state_dict = checkpoint['model']
        unwanted_prefix = '_orig_mod.'
        prefix_len = len(unwanted_prefix)
        for k, v in list(state_dict.items()):
            if not k.startswith(unwanted_prefix):
                continue
            state_dict[k[prefix_len:]] = state_dict.pop(k)
        model_ft.load_state_dict(state_dict)

        model_ft.to(device)

        model_ft.train()

        # Build validation loaders from val_noisy_encoded and val_normal_encoded
        assert 'val_noisy_encoded' in globals() and 'val_normal_encoded' in globals(), 'Validation arrays not found.'

        def build_val_loader(encoded_array: np.ndarray, mask_array: np.ndarray = None, batch_size: int = 32):
            """Return an unshuffled DataLoader over whole `block_size` rows, or None.

            None is returned when `encoded_array` holds less than one full row.
            `mask_array` is used only if it is at least as long as `encoded_array`.
            """
            usable = (encoded_array.size // block_size) * block_size
            arr = encoded_array[:usable]
            if arr.size == 0:
                return None
            arr = arr.reshape(-1, block_size)

            m_arr = None
            if mask_array is not None and mask_array.size >= encoded_array.size:
                m_arr = mask_array[:usable].reshape(-1, block_size)

            ds = HighEnergyDataset(arr, m_arr)
            bs = min(batch_size, len(ds))
            # Keep the last partial batch only when the dataset has fewer than two rows.
            return DataLoader(ds, batch_size=bs, shuffle=False, drop_last=len(ds) >= 2)

        val_noisy_loader = build_val_loader(val_noisy_encoded, val_noisy_mask)
        val_clean_loader = build_val_loader(val_normal_encoded, None)

        # Optimizer comes from the model's own helper; the learning rate is kept small.
        ft_learning_rate = 1e-5
        optimizer = model_ft.configure_optimizers(
            weight_decay=0.1,
            learning_rate=ft_learning_rate,
            betas=(0.9, 0.95),
            device_type=device_type,
        )

        max_steps = 2000  # optimiser steps per band
        log_every = 50    # log/validate every this many steps (replaces the SAE-stage value)
        # Gradient scaling is on for CUDA whenever the autocast dtype is not float32.
        use_grad_scaling = (device_type=='cuda' and ptdtype!=torch.float32)
        scaler = torch.amp.GradScaler(device_type, enabled=use_grad_scaling)

        # Running counters for the training loop.
        seen_steps = 0
        running_loss = 0.0
        start_time = time.time()

        @torch.no_grad()
        def eval_loader(loader):
            """Return the mean per-batch loss of `model_ft` over `loader`.

            Returns NaN when `loader` is None or yields no batches. With
            `ft_loss_mask == 'typo_mask'` and a mask in the batch, the loss is
            the cross-entropy averaged over unmasked targets only; otherwise
            the loss returned by the model is used. The model is put in eval
            mode for the pass and returned to train mode afterwards, even if
            evaluation raises.
            """
            if loader is None: return float('nan')
            model_ft.eval()
            losses = []
            try:
                for batch in loader:
                    if isinstance(batch, (list, tuple)) and len(batch) == 3:
                        xb, yb, ymask = batch
                    else:
                        xb, yb = batch
                        ymask = None
                    xb, yb = xb.to(device), yb.to(device)
                    if ymask is not None:
                        ymask = ymask.to(device)
                    with torch.autocast(device_type=device_type, dtype=ptdtype, enabled=(device_type=='cuda')):
                        # One forward pass yields both logits and the model's own
                        # loss (previously the unmasked branch ran a second pass).
                        logits, loss_val = model_ft(xb, yb)
                        if ft_loss_mask == 'typo_mask' and ymask is not None:
                            vocab = logits.size(-1)
                            per_tok = F.cross_entropy(logits.view(-1, vocab), yb.view(-1), reduction='none')
                            include = (~ymask).float().view(-1)
                            # clamp_min avoids 0/0 when every target is masked
                            denom = include.sum().clamp_min(1.0)
                            loss_val = (per_tok * include).sum() / denom
                        elif not torch.is_tensor(loss_val):
                            loss_val = torch.tensor(loss_val, device=device)
                    losses.append(loss_val.item())
            finally:
                model_ft.train()
            return float(np.mean(losses)) if losses else float('nan')

        # Fine-tune on the selected band until `max_steps` optimiser steps have run.
        for epoch in range(max_steps):  # epoch is conceptual here
            for batch in he_loader:
                if isinstance(batch, (list, tuple)) and len(batch) == 3:
                    x, y, y_mask = batch
                else:
                    x, y = batch
                    y_mask = None
                x = x.to(device)
                y = y.to(device)
                if y_mask is not None:
                    y_mask = y_mask.to(device)
                optimizer.zero_grad(set_to_none=True)
                with torch.autocast(device_type=device_type, dtype=ptdtype, enabled=(device_type=='cuda')):
                    logits, base_loss = model_ft(x, y)
                    if ft_loss_mask == 'typo_mask' and y_mask is not None:
                        # Average the cross-entropy over targets NOT flagged by the mask.
                        vocab = logits.size(-1)
                        per_tok = F.cross_entropy(logits.view(-1, vocab), y.view(-1), reduction='none')
                        include = (~y_mask).float().view(-1)
                        denom = include.sum().clamp_min(1.0)  # avoid 0/0 if all targets are masked
                        loss = (per_tok * include).sum() / denom
                    else:
                        loss = base_loss
                scaler.scale(loss).backward()
                # Unscale first so the clipping threshold applies to the true
                # gradients rather than the loss-scaled ones (no-op when the
                # scaler is disabled).
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model_ft.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()

                running_loss += loss.item()
                seen_steps += 1

                if seen_steps % log_every == 0:
                    avg_train_loss = running_loss / log_every
                    # Validation
                    val_noisy_loss = eval_loader(val_noisy_loader)
                    val_clean_loss = eval_loader(val_clean_loader)
                    # Throughput over wall-clock time since start (validation time included).
                    tps = (seen_steps * ft_batch_size * (ft_ctx)) / (time.time() - start_time)
                    print(f'step {seen_steps:5d} | train {avg_train_loss:.4f} | val_noisy {val_noisy_loss:.4f} | val_clean {val_clean_loss:.4f} | toks/sec {tps:,.0f}')
                    running_loss = 0.0

                    percentile_loss_dict[percentile]['train_loss'].append(avg_train_loss)
                    percentile_loss_dict[percentile]['val_noisy_loss'].append(val_noisy_loss)
                    percentile_loss_dict[percentile]['val_clean_loss'].append(val_clean_loss)

                # Leave both loops once the step budget is spent.
                if seen_steps >= max_steps:
                    break
            if seen_steps >= max_steps:
                break

    print('Fine-tuning complete.')
    model_ft.eval()

    # === Save losses dict (pickle) ===
    # Directory layout mirrors the figure tree: hookpoint_layer / noise / band / mask.
    save_root = os.path.join(
        GLOBAL_SAVE_DIR,
        f"{hookpoint}_{layer_idx}",
        f"noise_{noise_level:.2f}",
        f"band_{energy_band_width}",
        f"mask_{ft_loss_mask}",
    )
    os.makedirs(save_root, exist_ok=True)

    # The full percentile_loss_dict is pickled as-is (no extra metadata).
    losses_path = os.path.join(save_root, 'losses.pkl')
    with open(losses_path, 'wb') as f:
        pickle.dump(percentile_loss_dict, f)

    print(f"Saved losses dict to {losses_path}")

    # === Plotting loss curves separately ===
    # One figure per metric; each band gets its own shade of red.
    reds = plt.cm.Reds(np.linspace(0.9, 0.3, len(percentiles)))

    curve_specs = (
        # (metric key, title prefix, output file name)
        ('train_loss', 'Training Cross-Entropy vs Steps', 'loss_curve_train.png'),
        ('val_noisy_loss', 'Validation Cross-Entropy vs Steps (Noisy)', 'loss_curve_val_noisy.png'),
        ('val_clean_loss', 'Validation Cross-Entropy vs Steps (Clean)', 'loss_curve_val_clean.png'),
    )
    for metric_key, title_prefix, curve_file in curve_specs:
        plt.figure(figsize=(12, 8))
        for i, percentile in enumerate(percentiles):
            losses = percentile_loss_dict[percentile]
            # Entries were logged every `log_every` steps, so the x axis is in steps.
            steps_ax = np.arange(1, len(losses[metric_key]) + 1) * log_every
            if len(steps_ax):
                plt.plot(steps_ax, losses[metric_key], label=f'Percentile {percentile}', color=reds[i])
        plt.xlabel('Optimization Steps')
        plt.ylabel('Cross-Entropy Loss')
        plt.title(f'{title_prefix} — Hook={hookpoint}, Layer={layer_idx}, Noise={noise_level:.2f}')
        plt.legend(frameon=False, ncol=2)
        plt.grid(alpha=0.25)
        plt.tight_layout()
        plt.savefig(os.path.join(curves_dir, curve_file), dpi=200)
        plt.show()
        plt.close()

    # === Final loss vs percentile for each metric ===
    def _final_losses_for(metric):
        """Return the last logged value of `metric` per band (NaN if none was logged)."""
        finals = []
        for p in percentiles:
            history = percentile_loss_dict[p][metric]
            finals.append(history[-1] if len(history) else float('nan'))
        return finals

    final_specs = (
        # (metric key, y-axis label, title prefix, output file name)
        ('train_loss', 'Final Cross-Entropy Loss',
         'Final Training Cross-Entropy vs Percentile', 'final_loss_vs_percentile_train.png'),
        ('val_noisy_loss', 'Final Cross-Entropy Loss (Noisy)',
         'Final Validation Cross-Entropy vs Percentile (Noisy)', 'final_loss_vs_percentile_val_noisy.png'),
        ('val_clean_loss', 'Final Cross-Entropy Loss (Clean)',
         'Final Validation Cross-Entropy vs Percentile (Clean)', 'final_loss_vs_percentile_val_clean.png'),
    )
    for metric_key, y_label, title_prefix, final_file in final_specs:
        plt.figure(figsize=(10, 6))
        plt.plot(percentiles, _final_losses_for(metric_key), marker='o')
        plt.xlabel('Upper Percentile of High-Energy Band')
        plt.ylabel(y_label)
        plt.title(f'{title_prefix} — Hook={hookpoint}, Layer={layer_idx}, Noise={noise_level:.2f}')
        plt.grid(alpha=0.25)
        plt.tight_layout()
        plt.savefig(os.path.join(curves_dir, final_file), dpi=200)
        plt.show()
        plt.close()
