"""Training script for ACE model with offline GP data and compilation optimizations."""

import os

# Runtime knobs for TorchInductor / ROCm. They are exported before any torch
# module is imported below -- presumably so they are visible when torch reads
# them; TODO confirm which of them are actually consumed that early.
os.environ["TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS"] = "TRITON"
os.environ["TORCHINDUCTOR_CPP_WRAPPER"] = "0"  # was "1" in an earlier revision
os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
os.environ["MIOPEN_FIND_MODE"] = "1"
os.environ["PYTORCH_HIP_ALLOC_CONF"] = "max_split_size_mb:512"
# Mirrors `config.triton.cudagraphs = False` below (both names refer to CUDA graphs).
os.environ["TORCHINDUCTOR_DISABLE_CUDAGRAPHS"] = "1"

# Configure inductor settings programmatically. These assignments mutate the
# process-wide inductor config module, so they affect every torch.compile call
# made after this point.
import torch._inductor.config as config
config.max_autotune_gemm = True
config.rocm.n_max_profiling_configs = 10
config.compile_threads = 8
config.triton.unique_kernel_names = True
config.triton.cudagraphs = False
config.coordinate_descent_tuning = True
config.triton.persistent_reductions = True

import time
from pathlib import Path
from typing import Dict, Optional, Tuple
from collections import defaultdict

import hydra
import math
import torch
import torch.nn as nn
import torch.multiprocessing as mp
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.data.utils import OfflineBatchLoader
from src.models.ace import AmortizedConditioningEngine
from src.models.masks import create_training_block_mask
from src.models.modules import Embedder, MixtureGaussian, Transformer
from src.models.tabular_embedder import TabularEmbedder, TabularACE
from src.utils import DataAttr

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    print("Warning: wandb not available. Install with 'pip install wandb' for experiment tracking.")


class CompiledTrainer:
    """Trainer class for ACE model with compilation optimizations."""
    
    def __init__(self, cfg: DictConfig) -> None:
        """Construct model, optimizer, data loaders, scheduler and mask cache from ``cfg``.

        The steps below run in a fixed order because later steps read attributes
        set by earlier ones: the scheduler is built after the optimizer and the
        train loader, and pre-warming uses the mask cache and the AMP settings.

        Args:
            cfg: Run configuration. This method reads ``device``, ``logging``,
                ``training`` and ``data``; the ``_build_*`` helpers read the
                ``model``, ``optimizer`` and ``scheduler`` sections as well.
        """
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        self.start_time = time.time()
        
        # Initialize wandb if available and enabled.
        # Both conditions must hold: the config flag and a successful import.
        self.use_wandb = cfg.logging.use_wandb and WANDB_AVAILABLE
        if self.use_wandb:
            wandb.init(
                project=cfg.logging.project,
                name=cfg.logging.run_name,
                config=OmegaConf.to_container(cfg, resolve=True),
                tags=cfg.logging.get("tags", []),
            )
        
        # Setup model
        self.model = self._build_model()
        self.model = self.model.to(self.device)
        
        # Compile create_block_mask function first (before the model itself is compiled)
        if cfg.training.compile_mask:
            self._compile_mask_function()
        
        # Compile model if enabled
        if cfg.training.compile_model:
            print(f"Compiling model with mode: {cfg.training.compile_mode}")
            compile_kwargs = {
                "fullgraph": cfg.training.get("fullgraph", False),
                "dynamic": cfg.training.get("dynamic", False),
            }
            
            # Only add mode if it's not "default"; otherwise torch.compile's own default is used
            if cfg.training.compile_mode != "default":
                compile_kwargs["mode"] = cfg.training.compile_mode
            
            self.model = torch.compile(self.model, **compile_kwargs)
        
        # Setup optimizer. Note this iterates the parameters of self.model as it is
        # at this point, i.e. the torch.compile wrapper when compilation is enabled.
        self.optimizer = self._build_optimizer()
        
        # Setup mixed precision.
        # Any amp_dtype value other than "bfloat16" selects float16.
        self.use_amp = cfg.training.use_amp
        self.amp_dtype = torch.bfloat16 if cfg.training.amp_dtype == "bfloat16" else torch.float16
        
        # Check bfloat16 support; on CUDA devices without it, fall back to float16
        if self.use_amp and self.amp_dtype == torch.bfloat16:
            if self.device.type == "cpu":
                print("Using bfloat16 mixed precision on CPU")
            elif self.device.type == "cuda" and not torch.cuda.is_bf16_supported():
                print("Warning: bfloat16 not supported on this GPU, falling back to float16")
                self.amp_dtype = torch.float16
        
        # Setup gradient scaler for mixed precision training.
        # Enabled whenever AMP runs on a CUDA device, independent of amp_dtype.
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp and self.device.type == "cuda")
        
        # Setup data; validation is optional and skipped when cfg.data.val_path is falsy
        self.train_loader = self._build_dataloader("train")
        self.val_loader = self._build_dataloader("val") if cfg.data.val_path else None
        
        # Setup learning rate scheduler (needs self.optimizer and, for epoch-based
        # runs, len(self.train_loader))
        self.scheduler = self._build_scheduler()
        
        # Training state
        self.global_step = 0
        self.epoch = 0
        self.best_val_loss = float('inf')
        
        # Statistics tracking
        self.stats = {
            "train_loss": [],
            "val_loss": [],
            "learning_rate": [],
            "epoch_times": [],
            "step_times": [],
        }
        
        # Pre-create all possible block masks
        self.mask_cache = self._precreate_masks()
        
        # Pre-warm compilation for all shapes if enabled (on by default when the
        # model is compiled); runs one forward pass per cached mask shape
        if cfg.training.compile_model and cfg.training.get("prewarm_compilation", True):
            self._prewarm_compilation()
    
    def _compile_mask_function(self):
        """Compile the create_block_mask function for better performance."""
        print("Compiling create_block_mask function...")
        import torch.nn.attention.flex_attention as flex
        
        # Save original function
        self.original_create_block_mask = flex.create_block_mask
        
        # Compile and replace
        create_block_mask_compiled = torch.compile(flex.create_block_mask)
        flex.create_block_mask = create_block_mask_compiled
        
        print("create_block_mask function compiled successfully")
    
    def _precreate_masks(self) -> Dict[Tuple[int, int, int, int | None, bool], torch.Tensor]:
        """Pre-create all possible block masks for different sequence lengths."""
        print("Pre-creating block masks for all sequence length combinations...")
        mask_cache: Dict[Tuple[int, int, int, int | None, bool], torch.Tensor] = {}
        
        buffer_len = self.cfg.model.max_buffer_size  # Can be 4, 8, or 16
        attending_chunks = self.cfg.model.get('attending_chunks', None)
        include_diagonal = self.cfg.model.get('include_diagonal_mask', True)
        
        # Check if precompile_shapes is specified in config
        if 'precompile_shapes' in self.cfg.model:
            # Use shapes from config: [context+buffer, target]
            print(f"Using precompile_shapes from config: {len(self.cfg.model.precompile_shapes)} shapes")
            for shape in self.cfg.model.precompile_shapes:
                context_buffer_len, target_len = shape
                ctx_len = context_buffer_len - buffer_len
                total_len = context_buffer_len + target_len
                
                print(f"  Shape: context={ctx_len}, buffer={buffer_len}, target={target_len}")
                key = (total_len, ctx_len, buffer_len, attending_chunks, include_diagonal)
                
                mask_kwargs = {
                    'current_total_q_len': total_len,
                    'current_total_kv_len': total_len,
                    'current_context_section_len': ctx_len,
                    'current_buffer_section_len': buffer_len,
                    'device': self.device,
                }
                
                # Add attending_chunks if specified in config
                if attending_chunks is not None:
                    mask_kwargs['attending_chunks'] = attending_chunks
                # Add diagonal control
                mask_kwargs['include_diagonal'] = include_diagonal
                    
                mask_cache[key] = create_training_block_mask(**mask_kwargs)
        else:
            # Fallback to default context lengths
            print("No precompile_shapes in config, using default context lengths")
            context_lens = [4, 8, 16, 32, 48, 64, 128, 192]  # 8 different sizes
            target_len = self.cfg.model.num_target_points  # Fixed at 256
            
            for ctx_len in context_lens:
                total_len = ctx_len + buffer_len + target_len
                key = (total_len, ctx_len, buffer_len, attending_chunks, include_diagonal)
                
                mask_kwargs = {
                    'current_total_q_len': total_len,
                    'current_total_kv_len': total_len,
                    'current_context_section_len': ctx_len,
                    'current_buffer_section_len': buffer_len,
                    'device': self.device,
                }
                
                # Add attending_chunks if specified in config
                if attending_chunks is not None:
                    mask_kwargs['attending_chunks'] = attending_chunks
                # Add diagonal control
                mask_kwargs['include_diagonal'] = include_diagonal
                    
                mask_cache[key] = create_training_block_mask(**mask_kwargs)
        
        print(f"Created {len(mask_cache)} block masks")
        return mask_cache
    
    def _prewarm_compilation(self):
        """Pre-warm model compilation for all expected shapes."""
        print("Pre-warming model compilation for all shapes...")
        
        # Create dummy batches for each shape
        # Keys are (total_len, ctx_len, buf_len, attending_chunks, include_diagonal)
        for (total_len, ctx_len, buf_len, _att_chunks, _incl_diag), mask in self.mask_cache.items():
            # Calculate dimensions
            tar_len = total_len - ctx_len - buf_len
            
            # Create dummy batch with correct batch size
            # Since our data files are pre-batched, batch_size might be null in config
            batch_size = self.cfg.data.get('batch_size') or 128  # Default to 128 if null
            dummy_batch = DataAttr(
                xc=torch.randn(batch_size, ctx_len, self.cfg.model.dim_x, device=self.device),
                yc=torch.randn(batch_size, ctx_len, self.cfg.model.dim_y, device=self.device),
                xb=torch.randn(batch_size, buf_len, self.cfg.model.dim_x, device=self.device),
                yb=torch.randn(batch_size, buf_len, self.cfg.model.dim_y, device=self.device),
                xt=torch.randn(batch_size, tar_len, self.cfg.model.dim_x, device=self.device),
                yt=torch.randn(batch_size, tar_len, self.cfg.model.dim_y, device=self.device),
            )
            
            # Run forward pass to trigger compilation
            with torch.no_grad():
                with torch.autocast(device_type=self.device.type, dtype=self.amp_dtype, enabled=self.use_amp):
                    _ = self.model(dummy_batch, mask)
            
        print("Pre-warming complete")
    
    def _build_model(self) -> AmortizedConditioningEngine:
        """Build ACE model from config."""
        print("Building model...")
        
        # Check if we should use TabularACE
        embedder_type = self.cfg.model.embedder.get('type', 'standard')
        
        if embedder_type == 'tabular':
            print("Using TabularACE for tabular data")
            # Optional unified FF factor (used to compute per-module FF dims)
            ff_factor = self.cfg.model.get('ff_factor', None)
            ff_dim = int(self.cfg.model.dim_model * float(ff_factor)) if ff_factor is not None else self.cfg.model.backbone.dim_feedforward
            # CLS concat settings
            concat_cls = self.cfg.model.embedder.get('concat_cls', False)
            num_cls_tokens = self.cfg.model.embedder.get('num_cls_tokens', 4)
            col_nhead = self.cfg.model.embedder.get('col_nhead', None)
            row_nhead = self.cfg.model.embedder.get('row_nhead', None)
            row_num_blocks = self.cfg.model.embedder.get('num_layers', 1)
            model = TabularACE(
                num_features=self.cfg.model.embedder.get('max_dim_x', self.cfg.model.dim_x),
                embed_dim=self.cfg.model.dim_model,
                transformer_layers=self.cfg.model.backbone.num_layers,
                nhead=self.cfg.model.backbone.num_heads,
                dim_feedforward=ff_dim,
                num_components=self.cfg.model.head.num_components,
                max_buffer_size=self.cfg.model.max_buffer_size,
                num_target_points=self.cfg.model.num_target_points,
                targets_block_size_for_buffer_attend=self.cfg.model.targets_block_size_for_buffer_attend,
                dropout=self.cfg.model.backbone.dropout,
                num_inducing_points=self.cfg.model.embedder.get('num_inducing_points', 64),
                num_isab_blocks=self.cfg.model.embedder.get('num_isab_blocks', 1),
                row_rope_base=self.cfg.model.embedder.get('row_rope_base', 30000),
                col_nhead=col_nhead,
                row_nhead=row_nhead,
                row_num_blocks=row_num_blocks,
                concat_cls=concat_cls,
                num_cls_tokens=num_cls_tokens,
                ff_factor=ff_factor,
            )
            print(f"Using MixtureGaussian head with {self.cfg.model.head.num_components} components")
        else:
            embedder = Embedder(
                dim_x=self.cfg.model.dim_x,
                dim_y=self.cfg.model.dim_y,
                hidden_dim=self.cfg.model.embedder.hidden_dim,
                out_dim=self.cfg.model.dim_model,
                depth=self.cfg.model.embedder.depth,
            )
            
            ff_factor = self.cfg.model.get('ff_factor', None)
            if ff_factor is not None:
                ff_dim = int(self.cfg.model.dim_model * float(ff_factor))
            else:
                ff_dim = self.cfg.model.backbone.dim_feedforward
            backbone = Transformer(
                num_layers=self.cfg.model.backbone.num_layers,
                dim_model=self.cfg.model.dim_model,
                num_head=self.cfg.model.backbone.num_heads,
                dim_feedforward=ff_dim,
                dropout=self.cfg.model.backbone.dropout,
            )
            
            # Build head with type selection and std_min support
            head_type = self.cfg.model.head.get('type', 'MixtureGaussian')
            head_kwargs = {
                'dim_y': self.cfg.model.dim_y,
                'dim_model': self.cfg.model.dim_model,
                'dim_feedforward': ff_dim,
                'num_components': self.cfg.model.head.num_components,
            }
            
            # Add std_min if specified in config
            if 'std_min' in self.cfg.model.head:
                head_kwargs['std_min'] = self.cfg.model.head.std_min
            
            # Create the head based on type
            if head_type == 'MixtureGaussian':
                head = MixtureGaussian(**head_kwargs)
            else:
                raise ValueError(f"Unknown head type: {head_type}")
            
            print(f"Using {head_type} head with {self.cfg.model.head.num_components} components")
            
            model = AmortizedConditioningEngine(
                embedder=embedder,
                backbone=backbone,
                head=head,
                max_buffer_size=self.cfg.model.max_buffer_size,
                num_target_points=self.cfg.model.num_target_points,
                targets_block_size_for_buffer_attend=self.cfg.model.targets_block_size_for_buffer_attend,
            )
        
        # Count parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Model built: {trainable_params:,} trainable parameters (total: {total_params:,})")
        
        return model
    
    def _build_optimizer(self) -> torch.optim.Optimizer:
        """Build optimizer with param groups (no decay for LN/bias/GMM head)."""
        wd = self.cfg.optimizer.weight_decay

        # Split parameters into groups
        decay_params, nodecay_params, head_params = [], [], []
        head_param_ids = {id(p) for p in self.model.head.parameters()} if hasattr(self.model, 'head') else set()

        for name, p in self.model.named_parameters():
            if not p.requires_grad:
                continue
            if id(p) in head_param_ids:
                head_params.append(p)
                continue

            # Heuristic: no decay for biases and normalization parameters
            if p.ndim < 2 or name.endswith('bias') or 'norm' in name.lower() or 'ln' in name.lower():
                nodecay_params.append(p)
            else:
                decay_params.append(p)

        param_groups = []
        if decay_params:
            param_groups.append({
                'params': decay_params,
                'weight_decay': wd,
            })
        if nodecay_params:
            param_groups.append({
                'params': nodecay_params,
                'weight_decay': 0.0,
            })
        if head_params:
            param_groups.append({
                'params': head_params,
                'weight_decay': 0.0,
            })

        opt_name = self.cfg.optimizer.name.lower()
        if opt_name == "adam":
            return torch.optim.Adam(
                param_groups,
                lr=self.cfg.optimizer.lr,
                betas=tuple(self.cfg.optimizer.betas),
            )
        elif opt_name == "adamw":
            return torch.optim.AdamW(
                param_groups,
                lr=self.cfg.optimizer.lr,
                betas=tuple(self.cfg.optimizer.betas),
            )
        else:
            raise ValueError(f"Unknown optimizer: {self.cfg.optimizer.name}")
    
    def _build_scheduler(self) -> Optional[torch.optim.lr_scheduler._LRScheduler]:
        """Build learning rate scheduler from config."""
        if not self.cfg.scheduler.get("use_scheduler", False):
            return None
        
        # Calculate total training steps (support step-based training)
        max_steps = self.cfg.training.get("max_steps", 0) or 0
        if max_steps > 0:
            total_steps = max_steps
        else:
            steps_per_epoch = len(self.train_loader)
            total_steps = steps_per_epoch * self.cfg.training.num_epochs
        # Warmup: explicit steps override ratio if provided
        warmup_steps = int(total_steps * self.cfg.scheduler.warmup_ratio)
        if 'warmup_steps' in self.cfg.scheduler and self.cfg.scheduler.warmup_steps:
            try:
                ws = int(self.cfg.scheduler.warmup_steps)
                if ws > 0:
                    warmup_steps = ws
            except Exception:
                pass
        
        def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
            """Create a schedule with a learning rate that decreases following the values of the cosine function."""
            def lr_lambda(current_step):
                if current_step < num_warmup_steps:
                    return float(current_step) / float(max(1, num_warmup_steps))
                progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
                return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
            
            return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
        
        num_cycles = float(self.cfg.scheduler.get('num_cycles', 0.5))
        if self.cfg.scheduler.name == "cosine":
            scheduler = get_cosine_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=0,
                num_training_steps=total_steps,
                num_cycles=num_cycles,
            )
        elif self.cfg.scheduler.name == "cosine_with_warmup":
            scheduler = get_cosine_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=total_steps,
                num_cycles=num_cycles,
            )
        else:
            raise ValueError(f"Unknown scheduler: {self.cfg.scheduler.name}")
        
        print(f"Using {self.cfg.scheduler.name} scheduler")
        print(f"  Total steps: {total_steps}, Warmup steps: {warmup_steps} ({self.cfg.scheduler.warmup_ratio*100:.0f}%)")
        return scheduler
    
    def _build_dataloader(self, split: str) -> DataLoader:
        """Build dataloader for given split."""
        # Check if using same path for train/val with chunk splitting
        use_chunk_split = self.cfg.data.get('use_chunk_split', False)
        
        if use_chunk_split and self.cfg.data.train_path == self.cfg.data.val_path:
            # Use same path but different chunks for train/val
            data_path = Path(self.cfg.data.train_path)
            if split == "val":
                # Use first chunk for validation
                chunk_subset = [0]
            else:
                # Use all chunks except first for training
                num_chunks = self.cfg.data.get('num_chunks', 128)
                chunk_subset = list(range(1, num_chunks-1))
        else:
            # Traditional separate paths
            data_path = Path(self.cfg.data.train_path if split == "train" else self.cfg.data.val_path)
            chunk_subset = None
        
        if not data_path.exists():
            raise ValueError(f"Data path does not exist: {data_path}")
        
        # Data caching control for chunked offline loader
        cache_chunks = self.cfg.data.get('cache_chunks', True)
        dataset = OfflineBatchLoader(
            data_path,
            device="cpu",
            cache_chunks=cache_chunks,
            chunk_subset=chunk_subset,
        )
        
        # DataLoader threading and memory behavior (configurable)
        num_workers = self.cfg.data.get('num_workers', 16)
        pin_memory = self.cfg.data.get('pin_memory', self.device.type == "cuda")
        persistent_workers = self.cfg.data.get('persistent_workers', num_workers > 0)
        # Default to a finite timeout to surface stuck workers
        loader_timeout = self.cfg.data.get('loader_timeout', 100)
        # Prefetch factor only valid when num_workers > 0
        prefetch_factor = self.cfg.data.get('prefetch_factor', 4)
        
        # For validation, optionally reduce dataset size
        if split == "val" and self.cfg.data.get('val_subset_size', 0) > 0:
            val_size = min(self.cfg.data.val_subset_size, len(dataset))
            indices = list(range(val_size))
            dataset = torch.utils.data.Subset(dataset, indices)
            print(f"Using subset of validation data: {val_size} batches")
        
        # Build DataLoader kwargs while respecting PyTorch constraints
        loader_kwargs = dict(
            batch_size=None,
            shuffle=(split == "train"),
            num_workers=num_workers,
            pin_memory=pin_memory,
            timeout=loader_timeout,
        )
        if num_workers > 0:
            loader_kwargs["persistent_workers"] = persistent_workers
            loader_kwargs["prefetch_factor"] = prefetch_factor
        else:
            loader_kwargs["persistent_workers"] = False

        dataloader = DataLoader(dataset, **loader_kwargs)
        
        print(f"Loaded {split} dataset: {len(dataset)} batches from {data_path}")
        pf = prefetch_factor if num_workers > 0 else None
        print(f"  Workers={num_workers}, pin_memory={pin_memory}, persistent_workers={loader_kwargs['persistent_workers']}, prefetch_factor={pf}, timeout={loader_timeout}s, cache_chunks={cache_chunks}")
        return dataloader
    
    def _get_cached_mask(self, batch: DataAttr) -> torch.Tensor:
        """Get pre-cached block mask for the batch."""
        total_len = batch.xc.shape[1] + batch.xb.shape[1] + batch.xt.shape[1]
        attending_chunks = self.cfg.model.get('attending_chunks', None)
        include_diagonal = self.cfg.model.get('include_diagonal_mask', True)
        key = (total_len, batch.xc.shape[1], batch.xb.shape[1], attending_chunks, include_diagonal)
        
        if key in self.mask_cache:
            return self.mask_cache[key]
        else:
            # Fallback: create mask on-the-fly and cache it
            print(f"Warning: Creating mask for new shape {key}")
            
            mask_kwargs = {
                'current_total_q_len': total_len,
                'current_total_kv_len': total_len,
                'current_context_section_len': batch.xc.shape[1],
                'current_buffer_section_len': batch.xb.shape[1],
                'device': self.device,
            }
            
            # Add mask options if specified in config
            if attending_chunks is not None:
                mask_kwargs['attending_chunks'] = attending_chunks
            if include_diagonal is not None:
                mask_kwargs['include_diagonal'] = include_diagonal
                
            mask = create_training_block_mask(**mask_kwargs)
            self.mask_cache[key] = mask
            return mask
    
    def train_epoch(self) -> Dict[str, float]:
        """Train for one epoch."""
        self.model.train()
        epoch_start = time.time()
        
        total_loss = 0.0
        num_batches = 0
        step_times = []
        
        progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.epoch}")
        
        for batch_idx, batch in enumerate(progress_bar):
            # Optional global step cap (supports online step-based training)
            max_steps = self.cfg.training.get("max_steps", 0) or 0
            if max_steps > 0 and self.global_step >= max_steps:
                break
            step_start = time.time()
            
            # Move batch to device
            batch = batch.to(self.device)
            
            # Get pre-cached block mask
            block_mask = self._get_cached_mask(batch)
            
            # Forward pass with mixed precision
            with torch.autocast(device_type=self.device.type, dtype=self.amp_dtype, enabled=self.use_amp):
                output = self.model(batch, block_mask)
                loss = output.loss
            
            # Backward pass
            self.optimizer.zero_grad()
            
            if self.use_amp and self.device.type == "cuda":
                # Use gradient scaling for CUDA (works for both NVIDIA and AMD)
                self.scaler.scale(loss).backward()
                
                if self.cfg.training.grad_clip > 0:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), 
                        self.cfg.training.grad_clip
                    )
                
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                # Standard backward pass for CPU or no AMP
                loss.backward()
                
                if self.cfg.training.grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), 
                        self.cfg.training.grad_clip
                    )
                
                self.optimizer.step()
            
            # Synchronize for accurate timing on GPU
            if self.device.type == "cuda" and torch.cuda.is_available():
                torch.cuda.synchronize()
            
            # Update learning rate scheduler
            if self.scheduler is not None:
                self.scheduler.step()
            
            # Track statistics
            step_time = time.time() - step_start
            step_times.append(step_time)
            total_loss += loss.item()
            num_batches += 1
            self.global_step += 1
            
            # Update progress bar
            progress_bar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "lr": f"{self.optimizer.param_groups[0]['lr']:.2e}",
                "step_time": f"{step_time:.3f}s"
            })
            
            # Log to wandb
            if self.use_wandb and self.global_step % self.cfg.logging.log_interval == 0:
                wandb.log({
                    "train/loss": loss.item(),
                    "train/learning_rate": self.optimizer.param_groups[0]['lr'],
                    "train/step_time": step_time,
                    "train/global_step": self.global_step,
                }, step=self.global_step)

            # Optional step-based validation
            val_step_interval = self.cfg.training.get("val_step_interval", 0) or 0
            if (
                self.val_loader is not None
                and val_step_interval > 0
                and (self.global_step % val_step_interval == 0)
            ):
                val_stats = self.validate()
                self.stats["val_loss"].append(val_stats["loss"])
                print(f"Validation Loss: {val_stats['loss']:.4f}")
                if val_stats["loss"] < self.best_val_loss:
                    self.best_val_loss = val_stats["loss"]
                    print("New best validation loss! Saving model...")
                    self.save_checkpoint(
                        Path(self.cfg.checkpoint.save_dir)
                        / f"checkpoint_step_{self.global_step}.pt",
                        is_best=True,
                    )
        
        epoch_time = time.time() - epoch_start
        avg_loss = total_loss / num_batches
        
        return {
            "loss": avg_loss,
            "epoch_time": epoch_time,
            "avg_step_time": sum(step_times) / len(step_times),
            "steps_per_second": len(step_times) / epoch_time,
        }
    
    def validate(self) -> Dict[str, float]:
        """Validate model."""
        if self.val_loader is None:
            return {}
        
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        
        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                batch = batch.to(self.device)
                block_mask = self._get_cached_mask(batch)
                
                with torch.autocast(device_type=self.device.type, dtype=self.amp_dtype, enabled=self.use_amp):
                    output = self.model(batch, block_mask)
                
                total_loss += output.loss.item()
                num_batches += 1
        
        avg_loss = total_loss / num_batches
        
        if self.use_wandb:
            wandb.log({
                "val/loss": avg_loss,
                "epoch": self.epoch,
            }, step=self.global_step)
        
        return {"loss": avg_loss}
    
    def save_checkpoint(self, path: Path, is_best: bool = False):
        """Save model checkpoint."""
        checkpoint = {
            "epoch": self.epoch,
            "global_step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "config": OmegaConf.to_container(self.cfg),
            "stats": self.stats,
        }
        
        torch.save(checkpoint, path)
        
        if is_best:
            best_path = path.parent / "best_model.pt"
            torch.save(checkpoint, best_path)
            print(f"Saved best model to {best_path}")
    
    def _train_steps(self, max_steps: int):
        """Train until ``self.global_step`` reaches ``max_steps``.

        Iterates over ``self.train_loader`` as many times as needed, doing one
        optimizer update per batch and advancing a single step-based progress
        bar. When ``training.val_step_interval`` is positive and a validation
        loader exists, validation runs every that many global steps and a
        checkpoint is written whenever the validation loss improves.

        Args:
            max_steps: Global step count at which training stops.

        Raises:
            RuntimeError: If a full pass over ``self.train_loader`` yields no
                batches, since the step target could then never be reached.
        """
        self.model.train()
        run_start = time.time()
        steps_done = 0
        total_loss = 0.0

        # Loop-invariant settings, read once instead of on every step.
        grad_clip = self.cfg.training.grad_clip
        val_step_interval = self.cfg.training.get("val_step_interval", 0) or 0
        use_scaler = self.use_amp and self.device.type == "cuda"
        sync_cuda = self.device.type == "cuda" and torch.cuda.is_available()

        # Clamp so an already-exceeded step target does not give a negative total.
        progress_bar = tqdm(total=max(0, max_steps - self.global_step), desc="Steps")

        try:
            while self.global_step < max_steps:
                step_at_pass_start = self.global_step

                for batch in self.train_loader:
                    if self.global_step >= max_steps:
                        break
                    iter_start = time.time()
                    batch = batch.to(self.device)
                    block_mask = self._get_cached_mask(batch)
                    with torch.autocast(device_type=self.device.type, dtype=self.amp_dtype, enabled=self.use_amp):
                        output = self.model(batch, block_mask)
                        loss = output.loss
                    self.optimizer.zero_grad()
                    if use_scaler:
                        self.scaler.scale(loss).backward()
                        if grad_clip > 0:
                            # Gradients must be unscaled before their norm is clipped.
                            self.scaler.unscale_(self.optimizer)
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        loss.backward()
                        if grad_clip > 0:
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)
                        self.optimizer.step()

                    # Wait for queued GPU work so the measured step time is meaningful.
                    if sync_cuda:
                        torch.cuda.synchronize()
                    if self.scheduler is not None:
                        self.scheduler.step()

                    step_time = time.time() - iter_start
                    # Read the scalar once; it is reused for the running sum and logging.
                    loss_value = loss.item()
                    total_loss += loss_value
                    steps_done += 1
                    self.global_step += 1
                    progress_bar.update(1)

                    # Logging
                    if self.use_wandb and self.global_step % self.cfg.logging.log_interval == 0:
                        wandb.log({
                            "train/loss": loss_value,
                            "train/learning_rate": self.optimizer.param_groups[0]['lr'],
                            "train/step_time": step_time,
                            "train/global_step": self.global_step,
                        }, step=self.global_step)

                    # Step-based validation
                    if (
                        self.val_loader is not None
                        and val_step_interval > 0
                        and (self.global_step % val_step_interval == 0)
                    ):
                        val_stats = self.validate()
                        # validate() puts the model in eval mode; switch back
                        # before the next optimizer step.
                        self.model.train()
                        self.stats.setdefault("val_loss", []).append(val_stats["loss"])
                        print(f"Validation Loss: {val_stats['loss']:.4f}")
                        if val_stats["loss"] < self.best_val_loss:
                            self.best_val_loss = val_stats["loss"]
                            print("New best validation loss! Saving model...")
                            self.save_checkpoint(
                                Path(self.cfg.checkpoint.save_dir) / f"checkpoint_step_{self.global_step}.pt",
                                is_best=True,
                            )

                if self.global_step == step_at_pass_start:
                    # A pass that made no progress would repeat forever.
                    raise RuntimeError(
                        "train_loader yielded no batches; cannot reach "
                        f"max_steps={max_steps} from global_step={self.global_step}"
                    )
        finally:
            progress_bar.close()

        total_time = time.time() - run_start
        avg_loss = total_loss / max(1, steps_done)
        print(f"\nTraining complete!")
        print(f"Total time: {total_time/3600:.2f} hours ({total_time:.1f}s)")
        print(f"Final train loss: {avg_loss:.4f}")
        if self.val_loader is not None:
            print(f"Best validation loss: {self.best_val_loss:.4f}")
        if self.use_wandb:
            wandb.summary["total_training_time"] = total_time
            wandb.summary["best_val_loss"] = self.best_val_loss
            wandb.summary["final_train_loss"] = avg_loss
            wandb.finish()

    def train(self):
        """Run training in step mode or epoch mode, depending on the config.

        When ``training.max_steps`` is positive, control is handed to
        ``_train_steps`` and its result is returned. Otherwise the model is
        trained for ``training.num_epochs`` epochs with optional epoch-based
        validation, best-model checkpoints and periodic checkpoints, followed
        by a summary that is also written to wandb when enabled.

        In both modes the original ``create_block_mask`` function is put back
        afterwards if it was replaced by a compiled version.
        """
        training_cfg = self.cfg.training
        max_steps = training_cfg.get("max_steps", 0) or 0
        if max_steps > 0:
            print(f"\nStarting training for {max_steps} steps")
        else:
            print(f"\nStarting training for {self.cfg.training.num_epochs} epochs")
        print(f"Device: {self.device}")
        print(f"Compile Model: {self.cfg.training.compile_model}")
        print(f"Compile Mask: {self.cfg.training.compile_mask}")
        print(f"Mixed Precision: {self.use_amp} ({self.amp_dtype if self.use_amp else 'disabled'})")
        print(f"WandB: {self.use_wandb}")
        print(f"Scheduler: {self.cfg.scheduler.name if self.scheduler else 'None'}")
        print(f"Mask Cache Size: {len(self.mask_cache)}")

        # Detect if running on AMD GPU. ``torch.version.hip`` exists on every
        # build and is None unless PyTorch was built for ROCm, so test the
        # value rather than the attribute's presence.
        if self.device.type == "cuda" and getattr(torch.version, "hip", None) is not None:
            print("Running on AMD GPU (ROCm)")

        def restore_mask_fn():
            # Restore original mask function if it was compiled
            if self.cfg.training.compile_mask and hasattr(self, 'original_create_block_mask'):
                import torch.nn.attention.flex_attention as flex
                flex.create_block_mask = self.original_create_block_mask

        if max_steps > 0:
            try:
                return self._train_steps(max_steps)
            finally:
                restore_mask_fn()

        # Epoch-based validation is used only if step-based is not enabled.
        val_step_interval = training_cfg.get("val_step_interval", 0) or 0
        val_interval = training_cfg.get('val_interval', 1)

        for epoch in range(self.cfg.training.num_epochs):
            self.epoch = epoch

            # Train epoch
            train_stats = self.train_epoch()
            self.stats["train_loss"].append(train_stats["loss"])
            self.stats["epoch_times"].append(train_stats["epoch_time"])

            print(f"\nEpoch {epoch} - Train Loss: {train_stats['loss']:.4f}, "
                  f"Time: {train_stats['epoch_time']:.1f}s, "
                  f"Steps/s: {train_stats['steps_per_second']:.1f}")

            # Best and periodic checkpoints share this file name for an epoch.
            epoch_ckpt = Path(self.cfg.checkpoint.save_dir) / f"checkpoint_epoch_{epoch}.pt"
            saved_this_epoch = False

            if self.val_loader is not None and val_step_interval <= 0 and (epoch + 1) % val_interval == 0:
                val_stats = self.validate()
                self.stats["val_loss"].append(val_stats["loss"])
                print(f"Validation Loss: {val_stats['loss']:.4f}")

                # Save best model
                if val_stats["loss"] < self.best_val_loss:
                    self.best_val_loss = val_stats["loss"]
                    print(f"New best validation loss! Saving model...")
                    self.save_checkpoint(epoch_ckpt, is_best=True)
                    saved_this_epoch = True

            # Save periodic checkpoint, unless the identical file was just
            # written as the best-model checkpoint for this epoch.
            if (epoch + 1) % self.cfg.checkpoint.save_interval == 0 and not saved_this_epoch:
                self.save_checkpoint(epoch_ckpt)

        # Training complete
        total_time = time.time() - self.start_time
        # With zero epochs there is no recorded loss; report NaN instead of
        # failing on an empty list.
        train_losses = self.stats.get("train_loss") or []
        final_train_loss = train_losses[-1] if train_losses else float("nan")
        print(f"\nTraining complete!")
        print(f"Total time: {total_time/3600:.2f} hours ({total_time:.1f}s)")
        print(f"Final train loss: {final_train_loss:.4f}")

        if self.val_loader is not None:
            print(f"Best validation loss: {self.best_val_loss:.4f}")

        restore_mask_fn()

        if self.use_wandb:
            wandb.summary["total_training_time"] = total_time
            wandb.summary["best_val_loss"] = self.best_val_loss
            wandb.summary["final_train_loss"] = final_train_loss
            wandb.finish()


@hydra.main(version_base=None, config_path="configs", config_name="train")
def main(cfg: DictConfig):
    """Hydra entry point: prepare the checkpoint directory, then train.

    Args:
        cfg: Configuration supplied by Hydra; ``cfg.checkpoint.save_dir``
            names the directory that is created before training starts.
    """
    # Make sure the checkpoint directory exists before the trainer is built.
    Path(cfg.checkpoint.save_dir).mkdir(parents=True, exist_ok=True)

    # Build the trainer from the config and hand control to its training loop.
    CompiledTrainer(cfg).train()


if __name__ == "__main__":
    # Request the "spawn" multiprocessing start method before any workers are
    # created. NOTE(review): presumably needed so worker processes do not
    # inherit GPU state through fork; confirm against the data loader setup.
    try:
        mp.set_start_method("spawn", force=True)
    except RuntimeError as exc:
        # Best effort: continue with whatever start method is active, but
        # report the failure instead of hiding it.
        print(f"Warning: could not set multiprocessing start method to 'spawn': {exc}")
    main()
