"""Training script for ACE model with offline data."""

import time
from pathlib import Path
from typing import Dict, Optional, Tuple

import hydra
import math
import torch
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.data.utils import OfflineBatchLoader
from src.models.ace import AmortizedConditioningEngine
from src.models.masks import create_training_block_mask
from src.models.modules import Embedder, MixtureGaussian, Transformer
from src.utils import DataAttr

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    print("Warning: wandb not available. Install with 'pip install wandb' for experiment tracking.")


class Trainer:
    """Trainer for the ACE model on pre-batched offline data.

    Builds the model, optimizer, optional LR scheduler, mixed-precision
    machinery and dataloaders from a Hydra/OmegaConf config, then runs an
    epoch-based loop with optional validation, checkpointing and
    Weights & Biases logging.
    """

    def __init__(self, cfg: DictConfig):
        """Set up every training component from ``cfg``.

        Args:
            cfg: Full training config; reads the ``device``, ``logging``,
                ``model``, ``optimizer``, ``scheduler``, ``training``,
                ``data`` and ``checkpoint`` sections.
        """
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        self.start_time = time.time()

        # Initialize wandb if available and enabled
        self.use_wandb = cfg.logging.use_wandb and WANDB_AVAILABLE
        if self.use_wandb:
            wandb.init(
                project=cfg.logging.project,
                name=cfg.logging.run_name,
                config=OmegaConf.to_container(cfg, resolve=True),
                tags=cfg.logging.get("tags", []),
            )

        # Setup model
        self.model = self._build_model()
        self.model = self.model.to(self.device)

        # Compile model if enabled
        if cfg.training.compile_model:
            print(f"Compiling model with backend: {cfg.training.compile_backend}")
            self.model = torch.compile(self.model, backend=cfg.training.compile_backend)

        # Setup optimizer
        self.optimizer = self._build_optimizer()

        # Setup mixed precision
        self.use_amp = cfg.training.use_amp
        self.amp_dtype = torch.bfloat16 if cfg.training.amp_dtype == "bfloat16" else torch.float16

        # Check bfloat16 support
        if self.use_amp and self.amp_dtype == torch.bfloat16:
            if self.device.type == "cpu":
                # CPU supports bfloat16 on recent PyTorch versions
                print("Using bfloat16 mixed precision on CPU")
            elif self.device.type == "cuda" and not torch.cuda.is_bf16_supported():
                # Only query CUDA capabilities when actually running on CUDA.
                print("Warning: bfloat16 not supported on this GPU, falling back to float16")
                self.amp_dtype = torch.float16

        # Loss scaling is only needed for float16 autocast on CUDA. bfloat16
        # has float32's exponent range, so its gradients do not underflow.
        # A disabled scaler forwards scale/unscale_/step/update to the plain
        # operations, so the training step can always go through it.
        self.scaler = torch.cuda.amp.GradScaler(
            enabled=(
                self.use_amp
                and self.device.type == "cuda"
                and self.amp_dtype == torch.float16
            )
        )

        # Setup data
        self.train_loader = self._build_dataloader("train")
        self.val_loader = self._build_dataloader("val") if cfg.data.val_path else None

        # Setup learning rate scheduler (after dataloaders are created, since
        # the schedule length depends on the number of training batches)
        self.scheduler = self._build_scheduler()

        # Training state
        self.global_step = 0
        self.epoch = 0
        self.best_val_loss = float('inf')

        # Statistics tracking. Every list gets one entry per epoch, except
        # val_loss, which gets one entry per validation run.
        self.stats = {
            "train_loss": [],
            "val_loss": [],
            "learning_rate": [],  # LR at the end of each epoch
            "epoch_times": [],
            "step_times": [],  # mean step time of each epoch
        }

    def _build_model(self) -> AmortizedConditioningEngine:
        """Build the ACE model (embedder + transformer backbone + mixture head) from config."""
        print("Building model...")

        embedder = Embedder(
            dim_x=self.cfg.model.dim_x,
            dim_y=self.cfg.model.dim_y,
            hidden_dim=self.cfg.model.embedder.hidden_dim,
            out_dim=self.cfg.model.dim_model,
            depth=self.cfg.model.embedder.depth,
        )

        backbone = Transformer(
            num_layers=self.cfg.model.backbone.num_layers,
            dim_model=self.cfg.model.dim_model,
            num_head=self.cfg.model.backbone.num_heads,
            dim_feedforward=self.cfg.model.backbone.dim_feedforward,
            dropout=self.cfg.model.backbone.dropout,
        )

        head = MixtureGaussian(
            dim_y=self.cfg.model.dim_y,
            dim_model=self.cfg.model.dim_model,
            dim_feedforward=self.cfg.model.head.dim_feedforward,
            num_components=self.cfg.model.head.num_components,
        )

        model = AmortizedConditioningEngine(
            embedder=embedder,
            backbone=backbone,
            head=head,
            max_buffer_size=self.cfg.model.max_buffer_size,
            num_target_points=self.cfg.model.num_target_points,
            targets_block_size_for_buffer_attend=self.cfg.model.targets_block_size_for_buffer_attend,
        )

        # Count parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Model built: {trainable_params:,} trainable parameters (total: {total_params:,})")

        return model

    def _build_optimizer(self) -> torch.optim.Optimizer:
        """Build the optimizer named by ``cfg.optimizer.name`` ("adam" or "adamw").

        Raises:
            ValueError: If the optimizer name is not recognised.
        """
        optimizer_classes = {
            "adam": torch.optim.Adam,
            "adamw": torch.optim.AdamW,
        }
        name = self.cfg.optimizer.name
        if name not in optimizer_classes:
            raise ValueError(f"Unknown optimizer: {name}")
        return optimizer_classes[name](
            self.model.parameters(),
            lr=self.cfg.optimizer.lr,
            betas=tuple(self.cfg.optimizer.betas),
            weight_decay=self.cfg.optimizer.weight_decay,
        )

    def _build_scheduler(self) -> Optional[torch.optim.lr_scheduler._LRScheduler]:
        """Build a per-step cosine LR scheduler, or return None if disabled.

        ``"cosine"`` decays from the base LR to 0 over all training steps.
        ``"cosine_with_warmup"`` first ramps linearly from 0 for
        ``warmup_ratio * total_steps`` steps.

        Raises:
            ValueError: If the scheduler name is not recognised.
        """
        if not self.cfg.scheduler.get("use_scheduler", False):
            return None

        name = self.cfg.scheduler.name

        # Calculate total training steps (the scheduler is stepped once per batch)
        steps_per_epoch = len(self.train_loader)
        total_steps = steps_per_epoch * self.cfg.training.num_epochs

        # Resolve the warmup actually used, so the printout below matches the
        # schedule. Pure cosine has no warmup and needs no warmup_ratio key.
        if name == "cosine":
            warmup_ratio = 0.0
        elif name == "cosine_with_warmup":
            warmup_ratio = self.cfg.scheduler.warmup_ratio
        else:
            raise ValueError(f"Unknown scheduler: {name}")
        warmup_steps = int(total_steps * warmup_ratio)

        # Cosine schedule with linear warmup, implemented locally on LambdaLR
        # (the original note says this avoids a warning from a library
        # scheduler -- TODO confirm which).
        def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
            """Return a LambdaLR: linear warmup, then cosine decay (half a cycle by default)."""
            def lr_lambda(current_step):
                if current_step < num_warmup_steps:
                    return float(current_step) / float(max(1, num_warmup_steps))
                progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
                return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

            return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)

        scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
        )

        print(f"Using {name} scheduler")
        print(f"  Total steps: {total_steps}, Warmup steps: {warmup_steps} ({warmup_ratio*100:.0f}%)")
        return scheduler

    def _build_dataloader(self, split: str) -> DataLoader:
        """Build the dataloader for ``split`` ("train" or "val").

        Raises:
            ValueError: If the configured data path does not exist.
        """
        data_path = Path(self.cfg.data.train_path if split == "train" else self.cfg.data.val_path)

        if not data_path.exists():
            raise ValueError(f"Data path does not exist: {data_path}")

        dataset = OfflineBatchLoader(data_path, device="cpu")  # Load on CPU, move to device later

        # DataLoader with batch_size=None since data is pre-batched
        dataloader = DataLoader(
            dataset,
            batch_size=None,
            shuffle=(split == "train"),
            num_workers=self.cfg.data.num_workers,
            pin_memory=(self.device.type == "cuda"),
        )

        print(f"Loaded {split} dataset: {len(dataset)} batches from {data_path}")
        return dataloader

    def _create_block_mask(self, batch: DataAttr) -> torch.Tensor:
        """Create the training attention block mask for ``batch``.

        The sequence is the concatenation of the context (``xc``), buffer
        (``xb``) and target (``xt``) sections, so query and key/value lengths
        are both the sum of their sizes along dim 1.
        """
        total_len = batch.xc.shape[1] + batch.xb.shape[1] + batch.xt.shape[1]

        return create_training_block_mask(
            current_total_q_len=total_len,
            current_total_kv_len=total_len,
            current_context_section_len=batch.xc.shape[1],
            current_buffer_section_len=batch.xb.shape[1],
            current_target_block_size_for_buffer_attend=self.cfg.model.targets_block_size_for_buffer_attend,
            q_block_size=self.cfg.model.q_block_size,
            kv_block_size=self.cfg.model.kv_block_size,
            device=self.device,
        )

    def _optimizer_step(self, loss: torch.Tensor) -> None:
        """Backpropagate ``loss``, optionally clip gradients, and step the optimizer.

        Always goes through ``self.scaler``. When loss scaling is disabled
        (CPU, no AMP, or bfloat16), each scaler call reduces to plain
        ``loss.backward()`` / ``optimizer.step()``.
        """
        self.optimizer.zero_grad()
        self.scaler.scale(loss).backward()

        if self.cfg.training.grad_clip > 0:
            # Unscale first so the clip threshold applies to true gradient
            # magnitudes (a no-op when scaling is disabled).
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.cfg.training.grad_clip,
            )

        self.scaler.step(self.optimizer)
        self.scaler.update()

    def train_epoch(self) -> Dict[str, float]:
        """Train for one pass over ``self.train_loader``.

        Returns:
            Dict with the mean ``loss`` (NaN if the loader was empty),
            ``epoch_time`` (s), ``avg_step_time`` (s) and ``steps_per_second``.
        """
        self.model.train()
        epoch_start = time.perf_counter()

        total_loss = 0.0
        step_times = []

        progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.epoch}")

        for batch in progress_bar:
            step_start = time.perf_counter()

            # Move batch to device
            batch = batch.to(self.device)

            # Create block mask
            block_mask = self._create_block_mask(batch)

            # Forward pass with mixed precision
            with torch.autocast(device_type=self.device.type, dtype=self.amp_dtype, enabled=self.use_amp):
                output = self.model(batch, block_mask)
                loss = output.loss

            self._optimizer_step(loss)

            # Update learning rate scheduler (stepped per batch)
            if self.scheduler is not None:
                self.scheduler.step()

            # .item() synchronises with the device. Call it once and reuse the
            # value, and do so before stopping the step timer so step_time
            # includes the queued device work.
            loss_value = loss.item()
            step_time = time.perf_counter() - step_start
            step_times.append(step_time)
            total_loss += loss_value
            self.global_step += 1
            lr = self.optimizer.param_groups[0]['lr']

            # Update progress bar
            progress_bar.set_postfix({
                "loss": f"{loss_value:.4f}",
                "lr": f"{lr:.2e}",
                "step_time": f"{step_time:.3f}s"
            })

            # Log to wandb
            if self.use_wandb and self.global_step % self.cfg.logging.log_interval == 0:
                wandb.log({
                    "train/loss": loss_value,
                    "train/learning_rate": lr,
                    "train/step_time": step_time,
                    "train/global_step": self.global_step,
                }, step=self.global_step)

        epoch_time = time.perf_counter() - epoch_start
        num_batches = len(step_times)

        if num_batches == 0:
            # Empty loader: report NaN instead of crashing with ZeroDivisionError.
            return {
                "loss": float("nan"),
                "epoch_time": epoch_time,
                "avg_step_time": 0.0,
                "steps_per_second": 0.0,
            }

        return {
            "loss": total_loss / num_batches,
            "epoch_time": epoch_time,
            "avg_step_time": sum(step_times) / num_batches,
            "steps_per_second": num_batches / epoch_time if epoch_time > 0 else 0.0,
        }

    def validate(self) -> Dict[str, float]:
        """Compute the mean validation loss without gradient tracking.

        Returns:
            ``{}`` when there is no validation loader. Otherwise
            ``{"loss": mean_loss}``, where the loss is NaN if the loader
            yielded no batches.
        """
        if self.val_loader is None:
            return {}

        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                batch = batch.to(self.device)
                block_mask = self._create_block_mask(batch)

                with torch.autocast(device_type=self.device.type, dtype=self.amp_dtype, enabled=self.use_amp):
                    output = self.model(batch, block_mask)

                total_loss += output.loss.item()
                num_batches += 1

        # Guard against an empty validation set instead of dividing by zero.
        avg_loss = total_loss / num_batches if num_batches else float("nan")

        # Log to wandb
        if self.use_wandb:
            wandb.log({
                "val/loss": avg_loss,
                "epoch": self.epoch,
            }, step=self.global_step)

        return {"loss": avg_loss}

    def save_checkpoint(self, path: Path, is_best: bool = False) -> None:
        """Save a training checkpoint to ``path``.

        Args:
            path: Destination file. Its parent directory must already exist.
            is_best: Also write the same checkpoint to ``best_model.pt`` in
                the same directory.

        Note:
            ``model_state_dict`` comes from ``self.model``. When the model was
            wrapped by ``torch.compile``, its keys carry the wrapper's prefix.
        """
        checkpoint = {
            "epoch": self.epoch,
            "global_step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            # Scheduler and scaler state are needed to resume with the same
            # LR position and loss scale.
            "scheduler_state_dict": self.scheduler.state_dict() if self.scheduler is not None else None,
            "scaler_state_dict": self.scaler.state_dict(),
            "best_val_loss": self.best_val_loss,
            "config": OmegaConf.to_container(self.cfg),
            "stats": self.stats,
        }

        torch.save(checkpoint, path)

        if is_best:
            best_path = path.parent / "best_model.pt"
            torch.save(checkpoint, best_path)
            print(f"Saved best model to {best_path}")

    def train(self) -> None:
        """Run the main training loop, then print a summary and record it in wandb."""
        print(f"\nStarting training for {self.cfg.training.num_epochs} epochs")
        print(f"Device: {self.device}")
        print(f"Compile: {self.cfg.training.compile_model}")
        print(f"Mixed Precision: {self.use_amp} ({self.amp_dtype if self.use_amp else 'disabled'})")
        print(f"WandB: {self.use_wandb}")
        print(f"Scheduler: {self.cfg.scheduler.name if self.scheduler is not None else 'None'}")

        save_dir = Path(self.cfg.checkpoint.save_dir)
        val_interval = self.cfg.training.get('val_interval', 1)

        for epoch in range(self.cfg.training.num_epochs):
            self.epoch = epoch

            # Train epoch
            train_stats = self.train_epoch()
            self.stats["train_loss"].append(train_stats["loss"])
            self.stats["epoch_times"].append(train_stats["epoch_time"])
            self.stats["step_times"].append(train_stats["avg_step_time"])
            self.stats["learning_rate"].append(self.optimizer.param_groups[0]['lr'])

            print(f"\nEpoch {epoch} - Train Loss: {train_stats['loss']:.4f}, "
                  f"Time: {train_stats['epoch_time']:.1f}s, "
                  f"Steps/s: {train_stats['steps_per_second']:.1f}")

            checkpoint_path = save_dir / f"checkpoint_epoch_{epoch}.pt"
            saved_this_epoch = False

            # Validate at specified intervals
            if self.val_loader is not None and (epoch + 1) % val_interval == 0:
                val_stats = self.validate()
                self.stats["val_loss"].append(val_stats["loss"])
                print(f"Validation Loss: {val_stats['loss']:.4f}")

                # Save best model
                if val_stats["loss"] < self.best_val_loss:
                    self.best_val_loss = val_stats["loss"]
                    print("New best validation loss! Saving model...")
                    self.save_checkpoint(checkpoint_path, is_best=True)
                    saved_this_epoch = True

            # Save periodic checkpoint, unless the best-model save above
            # already wrote this exact file for this epoch.
            if (epoch + 1) % self.cfg.checkpoint.save_interval == 0 and not saved_this_epoch:
                self.save_checkpoint(checkpoint_path)

        # Training complete
        total_time = time.time() - self.start_time
        final_train_loss = self.stats["train_loss"][-1] if self.stats["train_loss"] else None
        final_val_loss = self.stats["val_loss"][-1] if self.stats["val_loss"] else None

        print("\nTraining complete!")
        print(f"Total time: {total_time/3600:.2f} hours ({total_time:.1f}s)")
        if final_train_loss is not None:
            print(f"Final train loss: {final_train_loss:.4f}")

        if self.val_loader is not None:
            print(f"Best validation loss: {self.best_val_loss:.4f}")
            print(f"Final validation loss: {final_val_loss:.4f}" if final_val_loss is not None else "No validation performed")

        # Mean over epochs of (batches per epoch / epoch duration).
        epoch_durations = [t for t in self.stats["epoch_times"] if t > 0]
        if epoch_durations:
            steps_per_epoch = len(self.train_loader)
            avg_steps_per_sec = sum(steps_per_epoch / t for t in epoch_durations) / len(epoch_durations)
            print(f"Average training speed: {avg_steps_per_sec:.1f} steps/second")

        # Final statistics
        if self.use_wandb:
            wandb.summary["total_training_time"] = total_time
            wandb.summary["best_val_loss"] = self.best_val_loss
            wandb.summary["final_train_loss"] = final_train_loss
            wandb.summary["final_val_loss"] = final_val_loss


@hydra.main(version_base=None, config_path="../configs", config_name="train")
def main(cfg: DictConfig):
    """Hydra entry point: ensure the checkpoint directory exists, then train.

    Args:
        cfg: Config composed by Hydra from ``config_name="train"`` under
            ``config_path="../configs"``.
    """
    # Checkpoints are written here during training, so create it up front.
    Path(cfg.checkpoint.save_dir).mkdir(parents=True, exist_ok=True)

    trainer = Trainer(cfg)
    trainer.train()

    # Trainer opens the wandb run when logging is enabled; close it here.
    if trainer.use_wandb:
        wandb.finish()


if __name__ == "__main__":
    main()