"""
Joint training: Tokenizer + Market JEPA (end-to-end).

This script trains the tokenizer and JEPA together, with the JEPA loss
backpropagating through the tokenizer. After training, use cache_tokens.py
to freeze the tokenizer outputs.

Usage:
    python -m finjepa.scripts.train_joint --config finjepa/configs/joint.yaml
    python -m finjepa.scripts.train_joint --config finjepa/configs/joint.yaml --wandb
"""

import argparse
import gc
import itertools
import math
import random
import time
from pathlib import Path

import numpy as np
import torch
import yaml
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

from equitiesjepa.models import MarketJEPA, EquityTokenizer, update_ema, forward_target, compute_loss
from equitiesjepa.data import CrossSectionDataset, MarketMaskCollator


def collate_clips(batch):
    """Merge per-sample clip dicts into a single batch dict.

    Defined at module level so it can be used with multiprocessing workers.

    Args:
        batch: sequence of dicts, each holding tensors under "X" and "mask"
            and a value under "stride".

    Returns:
        Dict with "X" and "mask" stacked along a new leading dimension and
        "stride" set to the stride shared by every sample.

    Raises:
        RuntimeError: if the samples do not all carry the same stride.
    """
    stacked_X = torch.stack([sample["X"] for sample in batch])
    stacked_mask = torch.stack([sample["mask"] for sample in batch])

    strides = [sample["stride"] for sample in batch]
    stride = strides[0]
    for other in strides:
        if not other == stride:
            raise RuntimeError(f"Mixed strides: {set(strides)}")

    return {"X": stacked_X, "mask": stacked_mask, "stride": stride}


# Schedulers (from V-JEPA 2: vjepa2/src/utils/schedulers.py)

class WSDSchedule:
    """
    Warmup-Stable-Decay learning-rate schedule (V-JEPA 2 style).

    Three phases, driven by an internal step counter:
    - warmup: linear ramp from start_lr towards ref_lr over warmup_steps
    - stable: constant ref_lr for (T_max - warmup_steps - anneal_steps) steps
    - decay:  linear ramp from ref_lr down to final_lr over anneal_steps

    Each call to step() writes the new LR into every optimizer param group
    (multiplied by the group's optional "lr_scale") and returns the unscaled LR.
    """

    def __init__(self, optimizer, warmup_steps, anneal_steps, T_max, start_lr, ref_lr, final_lr=0.0):
        self.optimizer = optimizer
        self.start_lr, self.ref_lr, self.final_lr = start_lr, ref_lr, final_lr
        self.warmup_steps = warmup_steps
        self.anneal_steps = anneal_steps
        # After construction, T_max holds the length of the constant phase only.
        self.T_max = T_max - warmup_steps - anneal_steps
        self._step = 0

    def step(self):
        """Advance one step, update the optimizer's LRs, return the base LR."""
        self._step += 1
        current = self._step
        decay_start = self.T_max + self.warmup_steps

        if current < self.warmup_steps:
            # Warmup phase
            frac = float(current) / float(max(1, self.warmup_steps))
            new_lr = self.start_lr + frac * (self.ref_lr - self.start_lr)
        elif current < decay_start:
            # Stable phase
            new_lr = self.ref_lr
        else:
            # Decay phase; progress saturates at 1 so LR never drops below final_lr
            elapsed = current - decay_start
            frac = min(1.0, float(elapsed) / float(max(1, self.anneal_steps)))
            new_lr = max(self.final_lr, self.ref_lr + frac * (self.final_lr - self.ref_lr))

        for group in self.optimizer.param_groups:
            group["lr"] = new_lr * group["lr_scale"] if "lr_scale" in group else new_lr
        return new_lr

    def load_state(self, step):
        """Set the internal step counter (used when resuming)."""
        self._step = step


class WarmupCosineSchedule:
    """Linear warmup followed by cosine decay (original schedule, kept for compatibility).

    step() writes the new LR into every optimizer param group (multiplied by
    the group's optional "lr_scale") and returns the unscaled LR.
    """

    def __init__(self, optimizer, warmup_steps, start_lr, ref_lr, T_max, final_lr=0.0):
        self.optimizer = optimizer
        self.start_lr, self.ref_lr, self.final_lr = start_lr, ref_lr, final_lr
        self.warmup_steps = warmup_steps
        # After construction, T_max holds the cosine phase length (at least 1).
        self.T_max = max(1, T_max - warmup_steps)
        self._step = 0

    def step(self):
        """Advance one step, update the optimizer's LRs, return the base LR."""
        self._step += 1

        if self._step >= self.warmup_steps:
            # Cosine phase; progress saturates at 1 once the phase is over
            frac = min(1.0, (self._step - self.warmup_steps) / self.T_max)
            cosine_lr = self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1 + math.cos(math.pi * frac))
            new_lr = max(self.final_lr, cosine_lr)
        else:
            # Warmup phase
            frac = self._step / max(1, self.warmup_steps)
            new_lr = self.start_lr + frac * (self.ref_lr - self.start_lr)

        for group in self.optimizer.param_groups:
            scale = group.get("lr_scale", 1.0)
            group["lr"] = new_lr * scale
        return new_lr

    def load_state(self, step):
        """Set the internal step counter (used when resuming)."""
        self._step = step


class EMASchedule:
    """
    Linear EMA-momentum schedule (V-JEPA 2 style).

    The momentum is interpolated linearly from start_momentum to end_momentum
    across total_steps and then held at end_momentum:
        momentum = start + min(1, step / total_steps) * (end - start)

    Reference values: I-JEPA 0.996 → 1.0, V-JEPA 0.998 → 1.0 over training.
    """

    def __init__(self, start_momentum, end_momentum, total_steps):
        self.start, self.end = start_momentum, end_momentum
        self.total_steps = total_steps
        self._step = 0

    def step(self):
        """Advance one step and return the momentum for that step."""
        self._step += 1
        denom = max(1, self.total_steps)
        frac = min(1.0, self._step / denom)
        return self.start + frac * (self.end - self.start)

    def load_state(self, step):
        """Set the internal step counter (used when resuming)."""
        self._step = step


class AverageMeter:
    """Keeps the latest value together with a count-weighted sum and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out the latest value, mean, sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as observed `n` times and refresh the running mean."""
        self.val = val
        total = self.sum + val * n
        seen = self.count + n
        self.sum, self.count = total, seen
        self.avg = total / seen


class PerStrideLossTracker:
    """Keeps one running-average loss meter per stride, for diagnostics."""

    def __init__(self):
        self.meters = {}

    def update(self, stride, loss, n=1):
        """Feed `loss` (weighted by `n`) into the meter for `stride`, creating it on first use."""
        meter = self.meters.get(stride)
        if meter is None:
            meter = AverageMeter()
            self.meters[stride] = meter
        meter.update(loss, n)

    def get_losses(self):
        """Return {stride: average loss}, ordered by stride."""
        averages = {}
        for stride, meter in sorted(self.meters.items()):
            averages[stride] = meter.avg
        return averages

    def reset(self):
        """Forget all strides seen so far."""
        self.meters = {}


# Temporal Clip Dataset (loads raw cross-sections, tokenizes on-the-fly)

class TemporalClipDataset(IterableDataset):
    """
    Samples temporal clips of cross-sections for joint training.

    Each sample is `clip_length` days spaced `stride` apart, ending at a
    randomly chosen endpoint day. One stride is drawn for every run of
    `batch_size` consecutive samples, so that a DataLoader using the same
    batch size receives stride-homogeneous batches.

    Args:
        cs_dataset: indexable dataset; `cs_dataset[i]` must return a dict with
            tensors under "X" and "mask", and `len(cs_dataset)` is the number
            of days.
        clip_length: number of days per clip.
        strides: candidate day strides (defaults to [1]).
        stride_weights: sampling weight per stride; normalized internally.
            Defaults to uniform over `strides`.
        samples_per_epoch: total samples per epoch, split evenly (floor)
            between DataLoader workers.
        batch_size: number of consecutive samples sharing one stride.
        seed: base seed for the sampling RNG.

    Raises:
        ValueError: on a non-positive batch size, mismatched or non-positive
            stride weights, or a stride for which no clip fits in the dataset.
    """

    def __init__(
        self,
        cs_dataset: "CrossSectionDataset",
        clip_length: int = 21,
        strides: list = None,
        stride_weights: list = None,
        samples_per_epoch: int = 10000,
        batch_size: int = 64,
        seed: int = 42,
    ):
        if batch_size < 1:
            # A zero batch size would make __iter__ spin forever without yielding.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.cs_dataset = cs_dataset
        self.clip_length = clip_length
        self.strides = list(strides) if strides else [1]
        self.samples_per_epoch = samples_per_epoch
        self.batch_size = batch_size
        self.seed = seed

        # Default to uniform weights so that passing several strides without
        # weights does not end up with a length-1 probability vector.
        weights = list(stride_weights) if stride_weights else [1.0] * len(self.strides)
        if len(weights) != len(self.strides):
            raise ValueError(
                f"stride_weights has {len(weights)} entries but strides has {len(self.strides)}"
            )
        total = sum(weights)
        if total <= 0:
            raise ValueError(f"stride_weights must sum to a positive value, got {total}")
        self.stride_weights = [w / total for w in weights]

        # Build valid endpoints per stride: the earliest endpoint is the first
        # day for which all clip_length strided days exist.
        self.valid_endpoints = {}
        T = len(cs_dataset)
        for s in self.strides:
            min_ep = (clip_length - 1) * s
            self.valid_endpoints[s] = list(range(min_ep, T))
            if not self.valid_endpoints[s]:
                raise ValueError(f"No valid endpoints for stride {s}")

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:
            rng = np.random.default_rng(self.seed)
            samples_per_worker = self.samples_per_epoch
        else:
            # Seed with the (seed, worker id) pair rather than their sum:
            # callers that vary `seed` per epoch would otherwise give worker
            # w+1 of one epoch the same stream as worker w of the next epoch.
            rng = np.random.default_rng([self.seed, worker_info.id])
            samples_per_worker = self.samples_per_epoch // worker_info.num_workers

        samples_yielded = 0
        while samples_yielded < samples_per_worker:
            # Sample ONE stride for this batch (stride-homogeneous)
            stride = int(rng.choice(self.strides, p=self.stride_weights))
            # Convert once per batch instead of once per sample.
            endpoints = np.asarray(self.valid_endpoints[stride])

            remaining = samples_per_worker - samples_yielded
            for _ in range(min(self.batch_size, remaining)):
                endpoint = int(rng.choice(endpoints))

                # Day indices for the clip, oldest first, ending at `endpoint`
                first_day = endpoint - (self.clip_length - 1) * stride
                day_indices = [first_day + i * stride for i in range(self.clip_length)]

                # Load cross-sections
                items = [self.cs_dataset[day_idx] for day_idx in day_indices]

                # Shape comments assume each day's "X" is [N, F] and "mask" is [N]
                # (per the original annotations; not checked here).
                yield {
                    "X": torch.stack([item["X"] for item in items]),        # [L, N, F]
                    "mask": torch.stack([item["mask"] for item in items]),  # [L, N]
                    "stride": stride,
                }
                samples_yielded += 1

    def __len__(self):
        return self.samples_per_epoch


# Target Encoder Wrapper (tokenizer + JEPA encoder)

class JointTargetEncoder(torch.nn.Module):
    """EMA target network: the tokenizer followed by the temporal encoder."""

    def __init__(self, tokenizer, temporal_encoder):
        super().__init__()
        # Registration order (tokenizer, then temporal encoder) fixes the
        # order of parameters()/state_dict() for this module.
        self.tokenizer = tokenizer
        self.temporal_encoder = temporal_encoder

    def forward(self, X, cs_mask, stride=1):
        """
        Tokenize every day of every clip, then encode the token sequence.

        Args:
            X: [B, L, N, F] cross-section features
            cs_mask: [B, L, N] cross-section validity mask
            stride: temporal stride forwarded to the temporal encoder
        Returns:
            Token-level output of the temporal encoder as returned by its
            `get_token_output` (documented upstream as [B, L*K, d], i.e.
            without a CLS token).
        """
        batch, steps, assets, feats = X.shape

        # Fold (batch, day) together so the tokenizer sees one cross-section per row
        day_tokens = self.tokenizer(
            X.reshape(batch * steps, assets, feats),
            mask=cs_mask.reshape(batch * steps, assets),
        )  # [B*L, K, d]

        # Unfold back to one token grid per clip
        slots = day_tokens.shape[1]
        clip_tokens = day_tokens.reshape(batch, steps, slots, -1)  # [B, L, K, d]

        # Temporal encoding over the full (unmasked) clip
        encoded = self.temporal_encoder(clip_tokens, mask=None, stride=stride)

        # Keep only the per-token outputs
        return self.temporal_encoder.get_token_output(encoded)


def forward_target_joint(target_encoder, X, cs_mask, stride, normalize=True):
    """Run the joint target encoder without tracking gradients.

    Args:
        target_encoder: callable invoked as target_encoder(X, cs_mask, stride=stride).
        X: input features passed through unchanged.
        cs_mask: validity mask passed through unchanged.
        stride: temporal stride passed through as a keyword argument.
        normalize: if True, apply layer normalization (no affine parameters)
            over the last dimension of the output.

    Returns:
        The (optionally normalized) target representations, detached from autograd.
    """
    with torch.no_grad():
        target = target_encoder(X, cs_mask, stride=stride)
        if not normalize:
            return target
        feat_dim = target.shape[-1]
        return torch.nn.functional.layer_norm(target, (feat_dim,))


# Training

def _split_decay_params(module, no_decay_keywords):
    """Split a module's trainable parameters into (decay, no_decay) lists by name."""
    decay = []
    no_decay = []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            # Frozen parameters never receive updates; keep them out of the optimizer.
            continue
        if any(keyword in name for keyword in no_decay_keywords):
            no_decay.append(param)
        else:
            decay.append(param)
    return decay, no_decay


def create_param_groups(tokenizer, jepa_model, weight_decay):
    """Build optimizer param groups with and without weight decay.

    A parameter is exempt from weight decay when its name contains one of the
    module-specific substrings: "bias" / "ln" for the tokenizer, and
    "bias" / "norm" / "embed" / "token" for the JEPA model. Parameters with
    requires_grad=False are skipped for both modules.

    Args:
        tokenizer: module whose named_parameters() are grouped first.
        jepa_model: module whose named_parameters() are grouped second.
        weight_decay: decay coefficient for the decayed group.

    Returns:
        Two param-group dicts: [decayed params, non-decayed params]. Within
        each group tokenizer parameters precede JEPA parameters.
    """
    tok_decay, tok_no_decay = _split_decay_params(tokenizer, ("bias", "ln"))
    jepa_decay, jepa_no_decay = _split_decay_params(
        jepa_model, ("bias", "norm", "embed", "token")
    )

    return [
        {"params": tok_decay + jepa_decay, "weight_decay": weight_decay},
        {"params": tok_no_decay + jepa_no_decay, "weight_decay": 0.0},
    ]


def train_epoch(
    tokenizer,
    model,
    target_encoder,
    dataloader,
    optimizer,
    scaler,
    lr_scheduler,
    ema_scheduler,
    mask_collator,
    device,
    dtype,
    loss_exp,
    log_freq,
    epoch,
    iters_per_epoch,
    clip_length,
    num_tokens,
    grad_clip=1.0,
    use_amp=True,
    use_wandb=False,
    global_step_start=0,
):
    """Run one epoch of joint tokenizer + JEPA training.

    Per iteration: set the learning rate, tokenize the clip, compute stop-grad
    targets with the EMA target encoder, predict with the student, take an
    optimizer step, then EMA-update the target tokenizer and temporal encoder.

    Args:
        tokenizer: student tokenizer (trained).
        model: student JEPA model; `model.encoder` is the EMA source.
        target_encoder: EMA target exposing `.tokenizer` and `.temporal_encoder`.
        dataloader: yields dicts with "X", "mask" and "stride".
        optimizer: optimizer over tokenizer and model parameters.
        scaler: gradient scaler, or None when loss scaling is not used.
        lr_scheduler: object whose step() sets and returns the learning rate.
        ema_scheduler: object whose step() returns the EMA momentum.
        mask_collator: callable returning a JEPA mask for a given batch size.
        device: torch device to run on.
        dtype: autocast dtype.
        loss_exp: exponent forwarded to compute_loss.
        log_freq: log every `log_freq` iterations.
        epoch: zero-based epoch index (used for logging only).
        iters_per_epoch: maximum number of batches consumed from `dataloader`.
        clip_length: days per clip, used to reshape targets for diagnostics.
        num_tokens: tokens per day, used to reshape targets for diagnostics.
        grad_clip: max gradient norm; clipping is skipped when <= 0.
        use_amp: whether autocast is enabled.
        use_wandb: whether to log to wandb (if available).
        global_step_start: global step count before this epoch.

    Returns:
        (average loss, global step after the epoch, {stride: average loss},
        collapse statistics dict — empty if no diagnostics were collected).
    """
    tokenizer.train()
    model.train()
    # Target encoder stays in eval mode

    loss_meter = AverageMeter()
    time_meter = AverageMeter()
    stride_tracker = PerStrideLossTracker()  # Per-stride loss tracking
    global_step = global_step_start

    # Parameters subject to gradient clipping; built once rather than every step.
    clip_params = list(tokenizer.parameters()) + list(model.parameters())

    # Collapse diagnostics
    target_norms = []
    target_stds = []     # Cross-batch std of pooled features (key collapse metric)
    token_stds = []      # Per-dim std across all tokens

    # V-JEPA 2 style: iterate exactly iters_per_epoch times
    # This ensures scheduler step counts match actual steps
    for itr, batch in enumerate(itertools.islice(dataloader, iters_per_epoch)):
        start_time = time.time()

        X = batch["X"].to(device, non_blocking=True)      # [B, L, N, F]
        cs_mask = batch["mask"].to(device, non_blocking=True)  # [B, L, N]
        stride = batch["stride"]

        B, L, N, F = X.shape

        # Generate JEPA mask
        jepa_mask = mask_collator(B).to(device)  # [B, L, K]

        # Set the LR for THIS step before the optimizer update. Stepping the
        # scheduler afterwards would run the very first update at the LR the
        # optimizer was constructed with instead of the warmup LR, and leave
        # every later update one step behind the schedule.
        new_lr = lr_scheduler.step()

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast(device.type, dtype=dtype, enabled=use_amp):
            # Tokenize (both for student and target)
            X_flat = X.reshape(B * L, N, F)
            cs_mask_flat = cs_mask.reshape(B * L, N)
            tokens_flat = tokenizer(X_flat, mask=cs_mask_flat)
            tokens = tokens_flat.reshape(B, L, -1, tokens_flat.shape[-1])  # [B, L, K, d]

            # Target: stop-grad through target encoder
            h = forward_target_joint(target_encoder, X, cs_mask, stride=stride, normalize=True)

            if itr % 10 == 0:
                with torch.no_grad():
                    # Pool to final timestep representation (how we'll use it downstream)
                    h_reshaped = h.float().reshape(B, clip_length, num_tokens, -1)  # [B, L, K, d]
                    z_final = h_reshaped[:, -1].mean(dim=1)  # [B, d] - final timestep pooled

                    # Collapse check: per-dimension std across batch
                    # If collapsed, all samples look the same → std ≈ 0
                    batch_std = z_final.std(dim=0).mean().item()  # Should be > 0.1

                    # Also track raw feature diversity (before pooling)
                    h_flat = h.float().reshape(-1, h.shape[-1])  # [B*L*K, d]
                    token_std = h_flat.std(dim=0).mean().item()  # Per-dim std across all tokens

                    target_norms.append(z_final.norm(dim=-1).mean().item())
                    target_stds.append(batch_std)  # This is the key collapse metric
                    token_stds.append(token_std)

            # Student prediction
            pred = model(tokens, mask=jepa_mask, stride=stride)

            # Loss
            loss = compute_loss(pred, h, jepa_mask, loss_exp=loss_exp)

        # Backward (through the scaler when loss scaling is active)
        scaled = scaler is not None
        if scaled:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        if grad_clip > 0:
            if scaled:
                # Gradients must be unscaled before their norm is measured
                scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(clip_params, grad_clip)
        else:
            grad_norm = 0.0

        if scaled:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()

        # EMA update (both tokenizer and encoder)
        momentum = ema_scheduler.step()
        with torch.no_grad():
            for p_src, p_tgt in zip(tokenizer.parameters(), target_encoder.tokenizer.parameters()):
                p_tgt.data.mul_(momentum).add_(p_src.data, alpha=1 - momentum)
            for p_src, p_tgt in zip(model.encoder.parameters(), target_encoder.temporal_encoder.parameters()):
                p_tgt.data.mul_(momentum).add_(p_src.data, alpha=1 - momentum)

        # Read the loss back once; each .item() forces a device sync
        loss_value = loss.item()
        loss_meter.update(loss_value)
        stride_tracker.update(stride, loss_value)  # Track per-stride
        time_meter.update(time.time() - start_time)
        global_step += 1

        # Logging
        if itr % log_freq == 0:
            mem_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0
            stride_losses = stride_tracker.get_losses()
            stride_str = " ".join([f"s{s}={l:.4f}" for s, l in stride_losses.items()])
            print(
                f"[E{epoch+1} I{itr:04d}] "
                f"loss={loss_meter.avg:.4f} ({stride_str}) "
                f"lr={new_lr:.1e} "
                f"ema={momentum:.4f} "
                f"t={time_meter.avg*1000:.0f}ms "
                f"mem={mem_gb:.1f}GB"
            )

            # wandb logging
            if use_wandb and WANDB_AVAILABLE:
                log_dict = {
                    "train/loss": loss_value,
                    "train/loss_avg": loss_meter.avg,
                    "train/lr": new_lr,
                    "train/ema_momentum": momentum,
                    "train/grad_norm": grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm,
                    # Computed only here so non-logging steps avoid the sync
                    "train/mask_ratio": jepa_mask.float().mean().item(),
                    "train/stride": stride,
                    "train/step_time_ms": time_meter.val * 1000,
                    "train/memory_gb": mem_gb,
                    "epoch": epoch + 1,
                    "global_step": global_step,
                }
                # Per-stride losses
                for s, l in stride_losses.items():
                    log_dict[f"train/loss_s{s}"] = l

                # Collapse diagnostics (running averages over the epoch so far)
                if target_norms:
                    log_dict["collapse/target_norm"] = sum(target_norms) / len(target_norms)
                    log_dict["collapse/batch_std"] = sum(target_stds) / len(target_stds)  # Key metric
                    log_dict["collapse/token_std"] = sum(token_stds) / len(token_stds)

                wandb.log(log_dict)

    # Collapse summary for epoch
    collapse_stats = {}
    if target_norms:
        collapse_stats["target_norm_avg"] = sum(target_norms) / len(target_norms)
        collapse_stats["batch_std_avg"] = sum(target_stds) / len(target_stds)
        collapse_stats["token_std_avg"] = sum(token_stds) / len(token_stds)

    # Return per-stride losses for epoch summary
    return loss_meter.avg, global_step, stride_tracker.get_losses(), collapse_stats


def load_config(path):
    """Parse the YAML config file at `path` and return its contents.

    The file is read as UTF-8 explicitly so that non-ASCII characters decode
    the same way regardless of the platform's default encoding.
    """
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)


def get_stride_schedule(cfg_stride, epoch):
    """Return (strides, weights) for a zero-based `epoch`.

    `cfg_stride` may define "phase1", "phase2" and "phase3"; each defined
    phase provides "epochs", "strides" and "weights". Phases run back to back
    in that order. Once `epoch` is past every phase, the last defined phase
    is used; with no phases defined the result is ([1], [1.0]).
    """
    candidates = (cfg_stride.get(name) for name in ("phase1", "phase2", "phase3"))
    phases = [phase for phase in candidates if phase is not None]

    boundary = 0
    for phase in phases:
        boundary += phase["epochs"]
        if epoch < boundary:
            return phase["strides"], phase["weights"]

    if not phases:
        return [1], [1.0]

    final_phase = phases[-1]
    return final_phase["strides"], final_phase["weights"]


def main(args):
    """Train the tokenizer and JEPA model jointly from a YAML config.

    Builds the models, EMA target, data pipeline, optimizer and schedules,
    optionally resumes from `<output_dir>/latest.pt`, then trains epoch by
    epoch, writing `latest.pt` every epoch, `best.pt` on improvement and a
    numbered checkpoint every `save_freq` epochs.

    Args:
        args: parsed CLI namespace with `config`, `resume`, `wandb`,
            `wandb_project` and `wandb_name`.
    """
    cfg = load_config(args.config)

    # wandb setup
    use_wandb = args.wandb and WANDB_AVAILABLE
    if args.wandb and not WANDB_AVAILABLE:
        print("Warning: wandb not installed, logging disabled. Install with: pip install wandb")

    if use_wandb:
        wandb.init(
            project=args.wandb_project,
            name=args.wandb_name or f"finjepa-{cfg['meta'].get('seed', 42)}",
            config=cfg,
            tags=["stage2", "joint-training"],
        )
        print(f"wandb run: {wandb.run.url}")

    # Device & dtype
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype_str = cfg["meta"].get("dtype", "bfloat16")
    # "float16" must map to torch.float16, otherwise the GradScaler path below
    # can never be enabled. Unknown strings fall back to float32 as before.
    dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16}.get(dtype_str, torch.float32)

    # Autocast only makes sense for reduced precision on CUDA
    use_amp = device.type == "cuda" and dtype != torch.float32
    use_scaler = use_amp and dtype == torch.float16

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # Seed
    seed = cfg["meta"]["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Output
    output_dir = Path(cfg["meta"]["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Config: {args.config}")
    print(f"Device: {device}, dtype: {dtype}")
    print(f"Output: {output_dir}")

    # Build tokenizer
    tcfg = cfg["tokenizer"]
    tokenizer = EquityTokenizer(
        dim_input=tcfg["dim_input"],
        dim_hidden=tcfg["dim_hidden"],
        num_tokens=tcfg["num_tokens"],
        num_inds=tcfg["num_inds"],
        num_heads=tcfg["num_heads"],
        num_isab_layers=tcfg["num_isab_layers"],
        ln=tcfg.get("ln", True),
        dropout=tcfg.get("dropout", 0.1),
    ).to(device)

    # Build JEPA model
    mcfg = cfg["model"]
    model = MarketJEPA(
        dim=mcfg["dim"],
        encoder_depth=mcfg["encoder_depth"],
        num_heads=mcfg["num_heads"],
        mlp_ratio=mcfg["mlp_ratio"],
        predictor_dim=mcfg["predictor_dim"],
        predictor_depth=mcfg["predictor_depth"],
        max_delta=mcfg["max_delta"],
        num_slots=tcfg["num_tokens"],  # K from tokenizer
        use_fourier_pos=mcfg["use_fourier_pos"],
        drop_rate=mcfg["drop_rate"],
        attn_drop_rate=mcfg["attn_drop_rate"],
        drop_path_rate=mcfg["drop_path_rate"],
        use_cls=mcfg["use_cls"],
        init_std=mcfg["init_std"],
    ).to(device)

    # Create joint target encoder (EMA copy of tokenizer + temporal encoder)
    import copy
    target_tokenizer = copy.deepcopy(tokenizer)
    target_temporal_encoder = copy.deepcopy(model.encoder)
    target_encoder = JointTargetEncoder(target_tokenizer, target_temporal_encoder).to(device)
    target_encoder.eval()
    for p in target_encoder.parameters():
        p.requires_grad_(False)

    # Count params
    tok_params = sum(p.numel() for p in tokenizer.parameters())
    jepa_params = sum(p.numel() for p in model.parameters())
    print(f"Tokenizer params: {tok_params:,} ({tok_params/1e6:.1f}M)")
    print(f"JEPA params: {jepa_params:,} ({jepa_params/1e6:.1f}M)")
    print(f"Total: {(tok_params + jepa_params):,} ({(tok_params + jepa_params)/1e6:.1f}M)")

    # Data
    dcfg = cfg["data"]
    cs_dataset = CrossSectionDataset(
        zarr_path=dcfg["zarr_path"],
        split="train",
        max_assets=dcfg["max_assets"],
    )
    print(f"Cross-section dataset: {len(cs_dataset)} days, {cs_dataset.num_features} features")

    # Mask collator
    L = dcfg["clip_length"]
    K = tcfg["num_tokens"]
    mask_collator = MarketMaskCollator(
        clip_length=L,
        num_tokens=K,
        min_visible_ratio=cfg["mask"]["min_visible_ratio"],
        causal_ratio=cfg["mask"]["causal_ratio"],
    )

    # Optimizer
    ocfg = cfg["optimization"]
    param_groups = create_param_groups(tokenizer, model, ocfg["weight_decay"])
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=ocfg["lr"],
        betas=tuple(ocfg["betas"]),
        eps=ocfg["eps"],
    )

    scaler = torch.amp.GradScaler("cuda", enabled=use_scaler) if use_scaler else None

    # Schedulers
    num_epochs = ocfg["epochs"]
    samples_per_epoch = dcfg["samples_per_epoch"]
    batch_size = dcfg["batch_size"]
    # Single source of truth for the worker count: the same value sizes the
    # schedules below and configures the DataLoader, so the scheduled step
    # count always matches the number of batches the loader can produce.
    num_workers = dcfg.get("num_workers", 4)

    if num_workers > 0:
        chunk = batch_size * num_workers
        if samples_per_epoch % chunk != 0:
            suggested = ((samples_per_epoch // chunk) + 1) * chunk
            print(f"WARNING: samples_per_epoch={samples_per_epoch} is not divisible by "
                  f"(batch_size * num_workers)={chunk}. This will cause step count mismatch!")
            print(f"         Suggested value: samples_per_epoch={suggested}")
        # Compute actual iters_per_epoch accounting for multi-worker behavior:
        # each worker gets an equal (floored) share and drops its partial batch
        samples_per_worker = samples_per_epoch // num_workers
        batches_per_worker = samples_per_worker // batch_size
        iters_per_epoch = batches_per_worker * num_workers
    else:
        iters_per_epoch = samples_per_epoch // batch_size

    total_steps = iters_per_epoch * num_epochs
    print(f"Steps: {iters_per_epoch}/epoch × {num_epochs} epochs = {total_steps} total")

    # LR schedule: WSD (V-JEPA 2 style) or Cosine
    lr_schedule_type = ocfg.get("lr_schedule", "cosine")
    if lr_schedule_type == "wsd":
        # V-JEPA 2 style: warmup → constant → cooldown
        warmup_steps = ocfg["warmup_epochs"] * iters_per_epoch
        anneal_steps = ocfg.get("anneal_epochs", 10) * iters_per_epoch
        lr_scheduler = WSDSchedule(
            optimizer,
            warmup_steps=warmup_steps,
            anneal_steps=anneal_steps,
            T_max=total_steps,
            start_lr=ocfg["start_lr"],
            ref_lr=ocfg["lr"],
            final_lr=ocfg["final_lr"],
        )
        print(f"LR schedule: WSD (warmup={ocfg['warmup_epochs']}ep, stable={num_epochs - ocfg['warmup_epochs'] - ocfg.get('anneal_epochs', 10)}ep, anneal={ocfg.get('anneal_epochs', 10)}ep)")
    else:
        # Original cosine schedule
        lr_scheduler = WarmupCosineSchedule(
            optimizer,
            warmup_steps=ocfg["warmup_epochs"] * iters_per_epoch,
            start_lr=ocfg["start_lr"],
            ref_lr=ocfg["lr"],
            T_max=total_steps,
            final_lr=ocfg["final_lr"],
        )
        print(f"LR schedule: Cosine (warmup={ocfg['warmup_epochs']}ep)")

    # V-JEPA 2 style: EMA momentum ramps linearly over ALL training
    ema_scheduler = EMASchedule(
        start_momentum=ocfg["ema_start"],
        end_momentum=ocfg["ema_end"],
        total_steps=total_steps,
    )
    print(f"EMA schedule: {ocfg['ema_start']} → {ocfg['ema_end']} over {total_steps} steps (V-JEPA 2 style)")

    # Resume
    start_epoch = 0
    global_step = 0
    best_loss = float("inf")
    latest_path = output_dir / "latest.pt"

    if args.resume and latest_path.exists():
        print(f"Resuming from {latest_path}")
        # NOTE: weights_only=False unpickles arbitrary Python objects. Only
        # resume from checkpoints written by this script in a trusted location.
        ckpt = torch.load(latest_path, map_location=device, weights_only=False)
        tokenizer.load_state_dict(ckpt["tokenizer"])
        model.load_state_dict(ckpt["model"])
        target_encoder.load_state_dict(ckpt["target_encoder"])
        target_encoder.eval()
        optimizer.load_state_dict(ckpt["optimizer"])
        if scaler is not None and ckpt.get("scaler"):
            scaler.load_state_dict(ckpt["scaler"])
        start_epoch = ckpt["epoch"]
        # Older checkpoints without a stored step fall back to whole epochs
        global_step = ckpt.get("global_step", start_epoch * iters_per_epoch)
        best_loss = ckpt.get("best_loss", float("inf"))
        lr_scheduler.load_state(global_step)
        ema_scheduler.load_state(global_step)
        print(f"Resumed at epoch {start_epoch}, step {global_step}")

    # Training loop
    log_freq = cfg["logging"]["log_freq"]
    save_freq = cfg["logging"]["save_freq"]
    loss_exp = cfg["loss"]["loss_exp"]
    grad_clip = ocfg.get("grad_clip", 1.0)

    for epoch in range(start_epoch, num_epochs):
        print(f"\n{'='*60}")
        print(f"Epoch {epoch + 1}/{num_epochs}")
        print(f"{'='*60}")

        strides, weights = get_stride_schedule(cfg["stride_schedule"], epoch)
        print(f"Strides: {strides}, weights: {weights}")

        # A fresh dataset per epoch: the stride mix and the sampling seed change
        dataset = TemporalClipDataset(
            cs_dataset=cs_dataset,
            clip_length=L,
            strides=strides,
            stride_weights=weights,
            samples_per_epoch=samples_per_epoch,
            batch_size=batch_size,
            seed=seed + epoch,
        )

        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=collate_clips,
            num_workers=num_workers,
            pin_memory=dcfg.get("pin_memory", True) and device.type == "cuda",
            drop_last=True,
            # persistent_workers=False: we create new loader each epoch (stride schedule changes)
            # Setting True would accumulate file handles → "Too many open files"
            persistent_workers=False,
        )

        train_loss, global_step, stride_losses, collapse_stats = train_epoch(
            tokenizer=tokenizer,
            model=model,
            target_encoder=target_encoder,
            dataloader=loader,
            optimizer=optimizer,
            scaler=scaler,
            lr_scheduler=lr_scheduler,
            ema_scheduler=ema_scheduler,
            mask_collator=mask_collator,
            device=device,
            dtype=dtype,
            loss_exp=loss_exp,
            log_freq=log_freq,
            epoch=epoch,
            iters_per_epoch=iters_per_epoch,  # Enforce exact step count (V-JEPA 2 style)
            clip_length=L,
            num_tokens=K,
            grad_clip=grad_clip,
            use_amp=use_amp,
            use_wandb=use_wandb,
            global_step_start=global_step,
        )

        # Cleanup to prevent "Too many open files" error
        del loader, dataset
        gc.collect()

        # Print per-stride losses and collapse diagnostics
        stride_str = " ".join([f"s{s}={l:.4f}" for s, l in stride_losses.items()])
        collapse_str = ""
        if collapse_stats:
            batch_std = collapse_stats.get('batch_std_avg', 0)
            # Warn if batch_std is very low (potential collapse)
            warn = " LOW!" if batch_std < 0.1 else ""
            collapse_str = f" [batch_std={batch_std:.4f}{warn}]"
        print(f"Epoch {epoch + 1} loss: {train_loss:.4f} ({stride_str}){collapse_str}")

        is_best = train_loss < best_loss
        if is_best:
            best_loss = train_loss

        # wandb epoch logging
        if use_wandb:
            epoch_log = {
                "epoch/train_loss": train_loss,
                "epoch/best_loss": best_loss,
                "epoch/strides": str(strides),
                "epoch": epoch + 1,
            }
            # Per-stride epoch losses
            for s, l in stride_losses.items():
                epoch_log[f"epoch/loss_s{s}"] = l
            # Collapse diagnostics
            if collapse_stats:
                epoch_log["collapse/epoch_target_norm"] = collapse_stats.get("target_norm_avg", 0)
                epoch_log["collapse/epoch_batch_std"] = collapse_stats.get("batch_std_avg", 0)
                epoch_log["collapse/epoch_token_std"] = collapse_stats.get("token_std_avg", 0)
            wandb.log(epoch_log)

        # "epoch" stores the NEXT epoch to run, so resume continues after this one
        checkpoint = {
            "epoch": epoch + 1,
            "global_step": global_step,
            "tokenizer": tokenizer.state_dict(),
            "model": model.state_dict(),
            "target_encoder": target_encoder.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scaler": scaler.state_dict() if scaler else None,
            "train_loss": train_loss,
            "best_loss": best_loss,
            "config": cfg,
        }

        torch.save(checkpoint, output_dir / "latest.pt")

        if is_best:
            torch.save(checkpoint, output_dir / "best.pt")
            print(f"New best! loss={best_loss:.4f}")
            # Save best model to wandb
            if use_wandb:
                wandb.save(str(output_dir / "best.pt"))

        if (epoch + 1) % save_freq == 0:
            torch.save(checkpoint, output_dir / f"epoch_{epoch+1:03d}.pt")

    print(f"\nTraining complete. Best loss: {best_loss:.4f}")
    print(f"Checkpoints: {output_dir}")
    print("\nNext: run cache_tokens.py to freeze tokenizer outputs for Stage 3")

    # Finish wandb
    if use_wandb:
        wandb.finish()


if __name__ == "__main__":
    # Command-line entry point: parse the flags and hand them to main().
    cli = argparse.ArgumentParser(
        description="Train FinJEPA (Tokenizer + JEPA joint training)",
    )

    # Run configuration
    cli.add_argument(
        "--config",
        type=str,
        default="finjepa/configs/joint.yaml",
        help="Path to config file",
    )
    cli.add_argument(
        "--resume",
        action="store_true",
        help="Resume from latest checkpoint",
    )

    # Experiment tracking (wandb)
    cli.add_argument(
        "--wandb",
        action="store_true",
        help="Enable wandb logging",
    )
    cli.add_argument(
        "--wandb-project",
        type=str,
        default="finjepa",
        help="wandb project name",
    )
    cli.add_argument(
        "--wandb-name",
        type=str,
        default=None,
        help="wandb run name (auto-generated if not set)",
    )

    main(cli.parse_args())

