"""Training script for ACE model with offline GP data and compilation optimizations."""

import os

# Environment settings for TorchInductor and the ROCm (HIP/MIOpen) stack.
# They are assigned before the first `torch` import below, presumably so the
# libraries see them when they initialise — keep them above any torch import.
os.environ["TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS"] = "TRITON"
os.environ["TORCHINDUCTOR_CPP_WRAPPER"] = "0"  # C++ wrapper disabled (previously "1")
os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
os.environ["MIOPEN_FIND_MODE"] = "1"
os.environ["PYTORCH_HIP_ALLOC_CONF"] = "max_split_size_mb:512"  # HIP caching-allocator options
os.environ["TORCHINDUCTOR_DISABLE_CUDAGRAPHS"] = "1"  # cudagraphs are also turned off via config.triton.cudagraphs below

# Configure inductor settings. These assign attributes on torch._inductor's
# module-level config, so they affect every torch.compile call made later in
# this process.
import torch._inductor.config as config
config.max_autotune_gemm = True
config.rocm.n_max_profiling_configs = 10  # NOTE(review): presumably caps the ROCm autotune candidates profiled — confirm
config.compile_threads = 8
config.triton.unique_kernel_names = True
config.triton.cudagraphs = False
config.coordinate_descent_tuning = True
config.triton.persistent_reductions = True

import time
import random
from pathlib import Path
from typing import Dict, Optional, Tuple
from collections import defaultdict

import hydra
import math
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import numpy as np
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.data.utils import OfflineBatchLoader
from src.models.ace import AmortizedConditioningEngine
from src.models.masks import create_training_block_mask
from src.models.modules import Embedder, MixtureGaussian, Transformer, MultiChannelMixtureGaussian
from src.utils import DataAttr

try:
    import wandb
except ImportError:
    # wandb is optional: without it, CompiledTrainer never enables logging.
    WANDB_AVAILABLE = False
    print("Warning: wandb not available. Install with 'pip install wandb' for experiment tracking.")
else:
    WANDB_AVAILABLE = True

def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: Value passed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

class CompiledTrainer:
    """Trainer for the ACE model with optional ``torch.compile`` optimizations.

    Construction builds the model, optimizer, mixed-precision settings, offline
    data loaders and LR scheduler from ``cfg``. It can optionally compile the
    model and flex-attention's ``create_block_mask``. It pre-builds block masks
    for the expected sequence layouts and can optionally pre-warm compilation
    for them. ``train()`` then runs the epoch loop with validation,
    checkpointing and optional Weights & Biases logging.
    """

    def __init__(self, cfg: DictConfig):
        """Build every training component described by ``cfg``.

        Args:
            cfg: Full training config. The ``device``, ``logging``, ``model``,
                ``training``, ``optimizer``, ``scheduler``, ``data`` and
                ``checkpoint`` sections are read.
        """
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        self.start_time = time.time()

        # Initialize wandb if available and enabled
        self.use_wandb = cfg.logging.use_wandb and WANDB_AVAILABLE
        if self.use_wandb:
            wandb.init(
                project=cfg.logging.project,
                name=cfg.logging.run_name,
                config=OmegaConf.to_container(cfg, resolve=True),
                tags=cfg.logging.get("tags", []),
            )

        # Setup model
        self.model = self._build_model()
        self.model = self.model.to(self.device)

        # Compile create_block_mask before any mask is created below.
        if cfg.training.compile_mask:
            self._compile_mask_function()

        # Compile model if enabled
        if cfg.training.compile_model:
            print(f"Compiling model with mode: {cfg.training.compile_mode}")
            compile_kwargs = {
                "fullgraph": cfg.training.get("fullgraph", False),
                "dynamic": cfg.training.get("dynamic", False),
            }
            # For "default", torch.compile's own default mode is used.
            if cfg.training.compile_mode != "default":
                compile_kwargs["mode"] = cfg.training.compile_mode
            self.model = torch.compile(self.model, **compile_kwargs)

        self.optimizer = self._build_optimizer()

        # Mixed precision: bfloat16 when amp_dtype == "bfloat16", otherwise float16.
        self.use_amp = cfg.training.use_amp
        self.amp_dtype = torch.bfloat16 if cfg.training.amp_dtype == "bfloat16" else torch.float16

        # Check bfloat16 support
        if self.use_amp and self.amp_dtype == torch.bfloat16:
            if self.device.type == "cpu":
                print("Using bfloat16 mixed precision on CPU")
            elif self.device.type == "cuda" and not torch.cuda.is_bf16_supported():
                print("Warning: bfloat16 not supported on this GPU, falling back to float16")
                self.amp_dtype = torch.float16

        # Loss scaling is only enabled for AMP on "cuda" devices (CUDA or ROCm).
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp and self.device.type == "cuda")

        # Setup data. ``get`` lets configs omit ``val_path`` to skip validation.
        self.train_loader = self._build_dataloader("train")
        self.val_loader = self._build_dataloader("val") if cfg.data.get("val_path") else None

        # The scheduler needs len(train_loader), so it is built after the data.
        self.scheduler = self._build_scheduler()

        # Training state
        self.global_step = 0
        self.epoch = 0
        self.best_val_loss = float('inf')

        # Statistics tracking
        self.stats = {
            "train_loss": [],
            "val_loss": [],
            "learning_rate": [],
            "epoch_times": [],
            "step_times": [],
        }

        # Pre-create all expected block masks
        self.mask_cache = self._precreate_masks()

        # Pre-warm compilation for all cached shapes if enabled
        if cfg.training.compile_model and cfg.training.get("prewarm_compilation", True):
            self._prewarm_compilation()

    def _compile_mask_function(self):
        """Replace ``flex_attention.create_block_mask`` with a compiled version.

        The original function is kept in ``self.original_create_block_mask``
        and restored at the end of ``train()``.
        """
        print("Compiling create_block_mask function...")
        import torch.nn.attention.flex_attention as flex

        self.original_create_block_mask = flex.create_block_mask
        # NOTE(review): this rebinds the module attribute only; code that did
        # `from ... import create_block_mask` before this point keeps the
        # uncompiled function — confirm how src.models.masks imports it.
        flex.create_block_mask = torch.compile(flex.create_block_mask)

        print("create_block_mask function compiled successfully")

    def _make_block_mask(self, total_len: int, ctx_len: int, buf_len: int):
        """Create one training block mask for a (total, context, buffer) layout.

        The query and key/value lengths are both ``total_len``.
        ``attending_chunks`` is forwarded only when the model config sets it.
        """
        mask_kwargs = {
            'current_total_q_len': total_len,
            'current_total_kv_len': total_len,
            'current_context_section_len': ctx_len,
            'current_buffer_section_len': buf_len,
            'device': self.device,
        }
        attending_chunks = self.cfg.model.get('attending_chunks', None)
        if attending_chunks is not None:
            mask_kwargs['attending_chunks'] = attending_chunks
        return create_training_block_mask(**mask_kwargs)

    def _precreate_masks(self) -> Dict[Tuple[int, int, int], torch.Tensor]:
        """Pre-create block masks for every expected sequence layout.

        Layouts come from ``model.precompile_shapes`` (pairs of
        ``[context + buffer length, target length]``) when present. Otherwise
        a fixed list of context lengths is combined with
        ``model.num_target_points``. The buffer length is always
        ``model.max_buffer_size``.

        Returns:
            Mapping ``(total_len, ctx_len, buffer_len) -> mask``.

        Raises:
            ValueError: If a configured shape is shorter than the buffer.
        """
        print("Pre-creating block masks for all sequence length combinations...")
        mask_cache = {}
        buffer_len = self.cfg.model.max_buffer_size

        def add(ctx_len: int, target_len: int) -> None:
            total_len = ctx_len + buffer_len + target_len
            key = (total_len, ctx_len, buffer_len)
            mask_cache[key] = self._make_block_mask(total_len, ctx_len, buffer_len)

        if 'precompile_shapes' in self.cfg.model:
            shapes = self.cfg.model.precompile_shapes
            print(f"Using precompile_shapes from config: {len(shapes)} shapes")
            for context_buffer_len, target_len in shapes:
                ctx_len = context_buffer_len - buffer_len
                if ctx_len < 0:
                    raise ValueError(
                        f"precompile_shapes entry [{context_buffer_len}, {target_len}] is shorter "
                        f"than max_buffer_size={buffer_len}"
                    )
                print(f"  Shape: context={ctx_len}, buffer={buffer_len}, target={target_len}")
                add(ctx_len, target_len)
        else:
            print("No precompile_shapes in config, using default context lengths")
            target_len = self.cfg.model.num_target_points
            for ctx_len in (4, 8, 16, 32, 48, 64, 128, 192):
                add(ctx_len, target_len)

        print(f"Created {len(mask_cache)} block masks")
        return mask_cache

    def _prewarm_compilation(self):
        """Trigger compilation for every cached shape before training starts.

        Each cached mask gets one forward and one backward pass on random
        dummy data. The model is in train mode with autograd enabled because
        Dynamo guards on grad mode and on the module's training flag. A
        ``no_grad`` warm-up would compile graphs the training loop never
        reuses, and they would still use up recompile-cache entries. The
        gradients are discarded and the optimizer is not stepped.
        """
        print("Pre-warming model compilation for all shapes...")

        # Data files are pre-batched, so data.batch_size may be null; default to 128.
        # NOTE(review): with dynamic=False a different real batch size still recompiles.
        batch_size = self.cfg.data.get('batch_size') or 128
        dim_x = self.cfg.model.dim_x
        dim_y = self.cfg.model.dim_y

        def rand(length: int, dim: int) -> torch.Tensor:
            return torch.randn(batch_size, length, dim, device=self.device)

        self.model.train()
        for (total_len, ctx_len, buf_len), mask in self.mask_cache.items():
            tar_len = total_len - ctx_len - buf_len
            dummy_batch = DataAttr(
                xc=rand(ctx_len, dim_x),
                yc=rand(ctx_len, dim_y),
                xb=rand(buf_len, dim_x),
                yb=rand(buf_len, dim_y),
                xt=rand(tar_len, dim_x),
                yt=rand(tar_len, dim_y),
            )

            with torch.autocast(device_type=self.device.type, dtype=self.amp_dtype, enabled=self.use_amp):
                output = self.model(dummy_batch, mask)
            output.loss.backward()
            # Drop the warm-up gradients so training starts from a clean state.
            self.optimizer.zero_grad(set_to_none=True)

        print("Pre-warming complete")

    def _build_model(self) -> AmortizedConditioningEngine:
        """Assemble the embedder, transformer backbone and head into an ACE model.

        Raises:
            ValueError: If ``model.head.type`` is not a supported head class.
        """
        print("Building model...")
        model_cfg = self.cfg.model

        embedder = Embedder(
            dim_x=model_cfg.dim_x,
            dim_y=model_cfg.dim_y,
            hidden_dim=model_cfg.embedder.hidden_dim,
            out_dim=model_cfg.dim_model,
            depth=model_cfg.embedder.depth,
        )

        backbone = Transformer(
            num_layers=model_cfg.backbone.num_layers,
            dim_model=model_cfg.dim_model,
            num_head=model_cfg.backbone.num_heads,
            dim_feedforward=model_cfg.backbone.dim_feedforward,
            dropout=model_cfg.backbone.dropout,
        )

        # Head type is selectable; std_min is forwarded only when configured.
        head_type = model_cfg.head.get('type', 'MixtureGaussian')
        head_kwargs = {
            'dim_y': model_cfg.dim_y,
            'dim_model': model_cfg.dim_model,
            'dim_feedforward': model_cfg.head.dim_feedforward,
            'num_components': model_cfg.head.num_components,
        }
        if 'std_min' in model_cfg.head:
            head_kwargs['std_min'] = model_cfg.head.std_min

        head_classes = {
            'MixtureGaussian': MixtureGaussian,
            'MultiChannelMixtureGaussian': MultiChannelMixtureGaussian,
        }
        if head_type not in head_classes:
            raise ValueError(f"Unknown head type: {head_type}")
        head = head_classes[head_type](**head_kwargs)

        print(f"Using {head_type} head with {model_cfg.head.num_components} components")

        model = AmortizedConditioningEngine(
            embedder=embedder,
            backbone=backbone,
            head=head,
            max_buffer_size=model_cfg.max_buffer_size,
            num_target_points=model_cfg.num_target_points,
            targets_block_size_for_buffer_attend=model_cfg.targets_block_size_for_buffer_attend,
        )

        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Model built: {trainable_params:,} trainable parameters (total: {total_params:,})")

        return model

    def _build_optimizer(self) -> torch.optim.Optimizer:
        """Create an Adam or AdamW optimizer over all model parameters.

        Raises:
            ValueError: If ``optimizer.name`` is neither ``adam`` nor ``adamw``.
        """
        optimizer_classes = {"adam": torch.optim.Adam, "adamw": torch.optim.AdamW}
        name = self.cfg.optimizer.name
        if name not in optimizer_classes:
            raise ValueError(f"Unknown optimizer: {name}")
        return optimizer_classes[name](
            self.model.parameters(),
            lr=self.cfg.optimizer.lr,
            betas=tuple(self.cfg.optimizer.betas),
            weight_decay=self.cfg.optimizer.weight_decay,
        )

    def _build_scheduler(self) -> Optional[torch.optim.lr_scheduler._LRScheduler]:
        """Build a per-step cosine LR schedule, optionally with linear warmup.

        The schedule spans ``len(train_loader) * training.num_epochs`` steps.
        ``"cosine"`` has no warmup. ``"cosine_with_warmup"`` warms up linearly
        for ``scheduler.warmup_ratio`` of the total steps.

        Returns:
            A ``LambdaLR``, or ``None`` when ``scheduler.use_scheduler`` is false.

        Raises:
            ValueError: For an unknown ``scheduler.name``.
        """
        if not self.cfg.scheduler.get("use_scheduler", False):
            return None

        name = self.cfg.scheduler.name
        if name == "cosine":
            warmup_ratio = 0.0
        elif name == "cosine_with_warmup":
            warmup_ratio = self.cfg.scheduler.warmup_ratio
        else:
            raise ValueError(f"Unknown scheduler: {name}")

        steps_per_epoch = len(self.train_loader)
        total_steps = steps_per_epoch * self.cfg.training.num_epochs
        warmup_steps = int(total_steps * warmup_ratio)
        num_cycles = 0.5  # a single half-cosine from the base LR down to 0

        def lr_lambda(current_step: int) -> float:
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))
            progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

        scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

        print(f"Using {name} scheduler")
        print(f"  Total steps: {total_steps}, Warmup steps: {warmup_steps} ({warmup_ratio*100:.0f}%)")
        return scheduler

    def _build_dataloader(self, split: str) -> DataLoader:
        """Build the DataLoader for ``split`` (``"train"`` or ``"val"``).

        Files hold pre-batched data, so ``batch_size=None`` disables automatic
        batching. Only the training split is shuffled. ``data.val_subset_size``
        (if > 0) truncates validation to its first N batches.
        ``persistent_workers`` and ``prefetch_factor`` (``data.prefetch_factor``,
        default 2) apply only when ``data.num_workers > 0``.

        Raises:
            ValueError: If the split's data path does not exist.
        """
        data_path = Path(self.cfg.data.train_path if split == "train" else self.cfg.data.val_path)
        if not data_path.exists():
            raise ValueError(f"Data path does not exist: {data_path}")

        dataset = OfflineBatchLoader(data_path, device="cpu", cache_chunks=True,
                                     max_buffer_size=self.cfg.model.max_buffer_size)

        # For validation, optionally reduce dataset size
        val_subset_size = self.cfg.data.get('val_subset_size') or 0
        if split == "val" and val_subset_size > 0:
            val_size = min(val_subset_size, len(dataset))
            dataset = torch.utils.data.Subset(dataset, list(range(val_size)))
            print(f"Using subset of validation data: {val_size} batches")

        num_workers = self.cfg.data.get('num_workers', 16)
        worker_kwargs = {}
        if num_workers > 0:
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            prefetch_factor = self.cfg.data.get('prefetch_factor', 2)
            worker_kwargs = {"persistent_workers": True, "prefetch_factor": prefetch_factor}

        dataloader = DataLoader(
            dataset,
            batch_size=None,
            shuffle=(split == "train"),
            num_workers=num_workers,
            pin_memory=(self.device.type == "cuda"),
            **worker_kwargs,
        )

        print(f"Loaded {split} dataset: {len(dataset)} batches from {data_path}")
        if num_workers > 0:
            print(f"  Using {num_workers} workers with prefetch_factor={prefetch_factor}")
        else:
            print("  Loading batches in the main process (num_workers=0)")
        return dataloader

    def _get_cached_mask(self, batch: DataAttr) -> torch.Tensor:
        """Return the block mask for ``batch``'s layout, creating and caching it if new.

        The key is ``(context + buffer + target length, context length,
        buffer length)``, read from dim 1 of ``xc``/``xb``/``xt``.
        """
        ctx_len = batch.xc.shape[1]
        buf_len = batch.xb.shape[1]
        total_len = ctx_len + buf_len + batch.xt.shape[1]
        key = (total_len, ctx_len, buf_len)

        mask = self.mask_cache.get(key)
        if mask is None:
            print(f"Warning: Creating mask for new shape {key}")
            mask = self._make_block_mask(total_len, ctx_len, buf_len)
            self.mask_cache[key] = mask
        return mask

    def train_epoch(self) -> Dict[str, float]:
        """Run one pass over the training loader.

        Each step runs the forward pass under autocast, then backward (through
        the GradScaler for AMP on CUDA), optional gradient clipping, the
        optimizer step and the scheduler step. Progress is logged to tqdm and,
        every ``logging.log_interval`` steps, to wandb.

        Returns:
            ``loss`` (mean batch loss), ``epoch_time`` (s), ``avg_step_time``
            (s) and ``steps_per_second``.

        Raises:
            RuntimeError: If the training loader yields no batches.
        """
        self.model.train()
        epoch_start = time.time()

        total_loss = 0.0
        step_times = []
        use_scaler = self.use_amp and self.device.type == "cuda"
        grad_clip = self.cfg.training.grad_clip

        progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.epoch}")
        for batch in progress_bar:
            step_start = time.time()

            batch = batch.to(self.device)
            block_mask = self._get_cached_mask(batch)

            with torch.autocast(device_type=self.device.type, dtype=self.amp_dtype, enabled=self.use_amp):
                output = self.model(batch, block_mask)
                loss = output.loss

            self.optimizer.zero_grad()
            if use_scaler:
                # Gradient scaling on "cuda" devices (NVIDIA and AMD/ROCm).
                self.scaler.scale(loss).backward()
                if grad_clip > 0:
                    # Unscale first so clipping sees the true gradient norm.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                if grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)
                self.optimizer.step()

            # Synchronize so step_time measures GPU work, not just kernel launches.
            if self.device.type == "cuda" and torch.cuda.is_available():
                torch.cuda.synchronize()

            if self.scheduler is not None:
                self.scheduler.step()

            step_time = time.time() - step_start
            step_times.append(step_time)
            loss_value = loss.item()  # one host sync, reused below
            total_loss += loss_value
            self.global_step += 1

            lr = self.optimizer.param_groups[0]['lr']
            progress_bar.set_postfix({
                "loss": f"{loss_value:.4f}",
                "lr": f"{lr:.2e}",
                "step_time": f"{step_time:.3f}s"
            })

            if self.use_wandb and self.global_step % self.cfg.logging.log_interval == 0:
                wandb.log({
                    "train/loss": loss_value,
                    "train/learning_rate": lr,
                    "train/step_time": step_time,
                    "train/global_step": self.global_step,
                }, step=self.global_step)

        epoch_time = time.time() - epoch_start
        num_batches = len(step_times)
        if num_batches == 0:
            raise RuntimeError("Training loader yielded no batches; check data.train_path")

        return {
            "loss": total_loss / num_batches,
            "epoch_time": epoch_time,
            "avg_step_time": sum(step_times) / num_batches,
            "steps_per_second": num_batches / epoch_time,
        }

    def validate(self) -> Dict[str, float]:
        """Compute the mean loss over the validation loader without gradients.

        Returns:
            ``{"loss": mean batch loss}``, or ``{}`` when there is no val loader.

        Raises:
            RuntimeError: If the validation loader yields no batches.
        """
        if self.val_loader is None:
            return {}

        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                batch = batch.to(self.device)
                block_mask = self._get_cached_mask(batch)

                with torch.autocast(device_type=self.device.type, dtype=self.amp_dtype, enabled=self.use_amp):
                    output = self.model(batch, block_mask)

                total_loss += output.loss.item()
                num_batches += 1

        if num_batches == 0:
            raise RuntimeError("Validation loader yielded no batches; check data.val_path")
        avg_loss = total_loss / num_batches

        if self.use_wandb:
            wandb.log({
                "val/loss": avg_loss,
                "epoch": self.epoch,
            }, step=self.global_step)

        return {"loss": avg_loss}

    def save_checkpoint(self, path: Path, is_best: bool = False):
        """Save training state to ``path`` (and to ``best_model.pt`` beside it if ``is_best``).

        Besides the model and optimizer state, this stores the scheduler
        state (when a scheduler exists) and the GradScaler state, so a run
        can be resumed faithfully.
        """
        checkpoint = {
            "epoch": self.epoch,
            "global_step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scaler_state_dict": self.scaler.state_dict(),
            "config": OmegaConf.to_container(self.cfg),
            "stats": self.stats,
        }
        if self.scheduler is not None:
            checkpoint["scheduler_state_dict"] = self.scheduler.state_dict()

        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(checkpoint, path)

        if is_best:
            best_path = path.parent / "best_model.pt"
            torch.save(checkpoint, best_path)
            print(f"Saved best model to {best_path}")

    def train(self):
        """Run ``training.num_epochs`` epochs with validation and checkpointing.

        Validation runs every ``training.val_interval`` epochs (default 1);
        an improved validation loss saves the epoch checkpoint plus
        ``best_model.pt``. A periodic checkpoint is saved every
        ``checkpoint.save_interval`` epochs. On completion, any compiled
        ``create_block_mask`` is restored and the wandb run is finished.
        """
        print(f"\nStarting training for {self.cfg.training.num_epochs} epochs")
        print(f"Device: {self.device}")
        print(f"Compile Model: {self.cfg.training.compile_model}")
        print(f"Compile Mask: {self.cfg.training.compile_mask}")
        print(f"Mixed Precision: {self.use_amp} ({self.amp_dtype if self.use_amp else 'disabled'})")
        print(f"WandB: {self.use_wandb}")
        print(f"Scheduler: {self.cfg.scheduler.name if self.scheduler else 'None'}")
        print(f"Mask Cache Size: {len(self.mask_cache)}")

        # torch.version.hip exists on every build (None when not ROCm), so test its value.
        if self.device.type == "cuda" and getattr(torch.version, "hip", None):
            print("Running on AMD GPU (ROCm)")

        save_dir = Path(self.cfg.checkpoint.save_dir)
        val_interval = self.cfg.training.get('val_interval', 1)

        for epoch in range(self.cfg.training.num_epochs):
            self.epoch = epoch
            checkpoint_path = save_dir / f"checkpoint_epoch_{epoch}.pt"
            saved_this_epoch = False

            train_stats = self.train_epoch()
            self.stats["train_loss"].append(train_stats["loss"])
            self.stats["epoch_times"].append(train_stats["epoch_time"])

            print(f"\nEpoch {epoch} - Train Loss: {train_stats['loss']:.4f}, "
                  f"Time: {train_stats['epoch_time']:.1f}s, "
                  f"Steps/s: {train_stats['steps_per_second']:.1f}")

            if self.val_loader is not None and (epoch + 1) % val_interval == 0:
                val_stats = self.validate()
                self.stats["val_loss"].append(val_stats["loss"])
                print(f"Validation Loss: {val_stats['loss']:.4f}")

                if val_stats["loss"] < self.best_val_loss:
                    self.best_val_loss = val_stats["loss"]
                    print("New best validation loss! Saving model...")
                    self.save_checkpoint(checkpoint_path, is_best=True)
                    saved_this_epoch = True

            # Periodic checkpoint; skipped if this epoch's file was just written.
            if (epoch + 1) % self.cfg.checkpoint.save_interval == 0 and not saved_this_epoch:
                self.save_checkpoint(checkpoint_path)

        total_time = time.time() - self.start_time
        final_train_loss = self.stats["train_loss"][-1] if self.stats["train_loss"] else float("nan")
        print(f"\nTraining complete!")
        print(f"Total time: {total_time/3600:.2f} hours ({total_time:.1f}s)")
        print(f"Final train loss: {final_train_loss:.4f}")

        if self.val_loader is not None:
            print(f"Best validation loss: {self.best_val_loss:.4f}")

        # Restore original mask function if it was compiled
        if self.cfg.training.compile_mask and hasattr(self, 'original_create_block_mask'):
            import torch.nn.attention.flex_attention as flex
            flex.create_block_mask = self.original_create_block_mask

        if self.use_wandb:
            wandb.summary["total_training_time"] = total_time
            wandb.summary["best_val_loss"] = self.best_val_loss
            wandb.summary["final_train_loss"] = final_train_loss
            wandb.finish()


@hydra.main(version_base=None, config_path="configs", config_name="train")
def main(cfg: DictConfig):
    """Hydra entry point: seed RNGs, ensure the checkpoint dir exists, then train.

    Args:
        cfg: Config composed from the ``train`` config under ``configs/``
            plus any command-line overrides.
    """
    set_seed(cfg.get('seed', 42))

    # CompiledTrainer writes its checkpoint files into this directory.
    ckpt_dir = Path(cfg.checkpoint.save_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    CompiledTrainer(cfg).train()


if __name__ == "__main__":
    import sys
    import traceback

    # NOTE(review): "spawn" is presumably chosen so DataLoader worker processes
    # do not fork the parent's GPU runtime state — confirm.
    try:
        mp.set_start_method("spawn", force=True)
    except RuntimeError:
        pass

    try:
        main()
    except SystemExit:
        # Keep the exit status chosen by Hydra / sys.exit (Hydra has already
        # reported the error, if any).
        raise
    except BaseException as e:
        # Report with a full traceback, then re-raise so the process exits
        # non-zero and job schedulers / CI see the run as failed.
        print(f"Exception: {e}", flush=True)
        traceback.print_exc()
        raise
    finally:
        # Make sure buffered output reaches the logs even on failure.
        sys.stdout.flush()
        sys.stderr.flush()
    