"""
Simple training loop; Boilerplate that could apply to any arbitrary neural network,
so nothing in this file really has anything to do with GPT specifically.
"""

import logging
import math
import os
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from utils import sample

logger = logging.getLogger(__name__)

# wandb is an optional dependency: when it cannot be imported, the trainer runs
# without remote logging and `WANDB_AVAILABLE` stays False.
WANDB_AVAILABLE = False
try:
    import wandb
except ImportError:
    logger.warning("wandb not installed, logging will be disabled")
else:
    WANDB_AVAILABLE = True


class TrainerConfig:
    """Hyperparameters and settings consumed by ``Trainer``.

    Every setting below has a default; any keyword argument passed to the
    constructor overrides it (or adds a new attribute) verbatim.
    """

    def __init__(self, **kwargs):
        settings = {
            # --- optimization ---
            "max_epochs": 10,
            "batch_size": 64,
            "learning_rate": 3e-4,
            "betas": (0.9, 0.95),
            "grad_norm_clip": 1.0,
            # presumably only applied to matmul weights by the model's
            # configure_optimizers -- confirm there
            "weight_decay": 0.1,
            # --- learning-rate schedule ---
            "lr_decay": False,
            "warmup_iters": 0,
            # iteration at which the cosine decay reaches its end; the scheduled
            # LR never drops below 10% of learning_rate
            "final_iters": 0,
            # --- checkpoints and samples ---
            "ckpt_dir": None,
            "samples_dir": None,
            "samples_every": 1,
            "save_every": 10,
            # --- data loading ---
            "num_workers": 0,  # DataLoader worker processes
            # --- loss weighting ---
            "loss_weight": 1,
            "length_loss_weight": 1,
            # --- logging ---
            "use_wandb": True,
        }
        # Caller-supplied values win; unknown keys become new attributes.
        settings.update(kwargs)
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)

class Trainer:
    """Generic training loop with optional wandb logging, checkpointing and sampling.

    The model must provide ``configure_optimizers(config)``. Its forward pass
    ``model(x, y, mask)`` must return ``(ce_loss, diffusion_loss, length_loss)``,
    which are combined as
    ``ce + config.loss_weight * diffusion + config.length_loss_weight * length``.
    """

    def __init__(self, model, train_dataset, test_dataset, config, args):
        """Set up the trainer.

        Args:
            model: Network to train. When CUDA is available it is wrapped in
                ``torch.nn.DataParallel`` and moved to the current CUDA device.
            train_dataset: Dataset yielding ``(x, y, mask)`` training samples.
            test_dataset: Evaluation dataset, or ``None`` to skip validation.
            config: ``TrainerConfig``-like object holding the hyperparameters.
            args: Run arguments; ``args.exp`` names the wandb run, and the object
                is recorded in the wandb config.
        """
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config
        self.iters = 0
        self.fixed_x = None
        self.fixed_y = None
        # Most recent validation loss; updated at the end of every evaluation epoch.
        self.current_loss = None

        self.use_wandb = False
        if WANDB_AVAILABLE and config.use_wandb:
            try:
                wandb.init(project="LayoutTransformer", name=args.exp)
                wandb.config.update(args)
                self.use_wandb = True
                logger.info("Wandb initialized successfully")
            except Exception as e:
                logger.warning("Failed to initialize wandb: %s", e)
        elif config.use_wandb:
            logger.warning("Wandb not available but was requested. Logging disabled.")

        # Set up device
        self.device = "cpu"
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
            self.model = torch.nn.DataParallel(self.model).to(self.device)

        self.pad_token = getattr(train_dataset, "pad_token", None)
        if self.pad_token is None:
            # Default pad token when the dataset does not provide one.
            self.pad_token = torch.zeros(12, dtype=torch.float32)
            self.pad_token[-1] = 1.0  # assumes the last position flags padding

    def _raw_model(self):
        """Return the underlying model, unwrapping ``DataParallel`` (``.module``) if present."""
        return self.model.module if hasattr(self.model, "module") else self.model

    @staticmethod
    def _to_scalar(value) -> float:
        """Reduce a loss tensor to a Python float.

        Under ``DataParallel`` each replica's scalar loss may be gathered into a
        vector with one entry per GPU, so ``.item()`` alone would fail; averaging
        first works for both the scalar and the gathered case.
        """
        return value.mean().item()

    def save_checkpoint(self, epoch: int) -> None:
        """Save the raw model's ``state_dict`` to ``<ckpt_dir>/checkpoint_<epoch>.pth``.

        Logs a warning and does nothing when ``config.ckpt_dir`` is ``None``.
        The directory is created if it does not exist.
        """
        ckpt_dir = self.config.ckpt_dir
        if ckpt_dir is None:
            logger.warning("ckpt_dir is not set; skipping checkpoint for epoch %d", epoch)
            return
        os.makedirs(ckpt_dir, exist_ok=True)
        ckpt_path = os.path.join(ckpt_dir, f"checkpoint_{epoch}.pth")
        logger.info("saving %s", ckpt_path)
        torch.save(self._raw_model().state_dict(), ckpt_path)

    def load_checkpoint(self, path: str) -> None:
        """Load a ``state_dict`` from ``path`` into the raw (unwrapped) model.

        Failures are logged with a traceback rather than raised.

        NOTE(review): ``torch.load`` unpickles the file, so only load trusted
        checkpoints (newer PyTorch versions accept ``weights_only=True``).
        """
        try:
            ckpt = torch.load(path, map_location=self.device)
            self._raw_model().load_state_dict(ckpt)
            logger.info("Loaded checkpoint from %s", path)
        except Exception as e:
            logger.exception("Failed to load checkpoint: %s", e)

    def train(self) -> None:
        """Train for ``config.max_epochs`` epochs.

        Each epoch runs validation (when a test set exists), periodic
        checkpointing and periodic sampling. A final checkpoint is saved at the end.
        """
        config = self.config
        optimizer = self._raw_model().configure_optimizers(config)

        # TrainerConfig declares ``samples_every``; ``sample_every`` is the name
        # this loop historically read, so honour it when a caller passes it.
        sample_every = max(
            1, getattr(config, "sample_every", getattr(config, "samples_every", 1))
        )
        save_every = max(1, getattr(config, "save_every", 10))

        logger.info("Starting training for %d epochs", config.max_epochs)

        for epoch in range(config.max_epochs):
            self._run_epoch("train", epoch, optimizer)

            if self.test_dataset is not None:
                with torch.no_grad():
                    self._run_epoch("test", epoch, optimizer)

            # NOTE: checkpoints are periodic only; keeping the best model by
            # validation loss (early stopping) is not implemented.
            if config.ckpt_dir is not None and epoch % save_every == 0:
                self.save_checkpoint(epoch)

            if config.samples_dir is not None and (epoch + 1) % sample_every == 0:
                self._visualize_samples(epoch)

        if config.ckpt_dir is not None:
            self.save_checkpoint(config.max_epochs - 1)

        logger.info("Training completed")

    def _run_epoch(self, split: str, epoch: int, optimizer) -> float:
        """Run one pass over the training or evaluation data.

        Args:
            split: ``"train"`` optimizes on ``train_dataset``; any other value
                evaluates on ``test_dataset``.
            epoch: Zero-based epoch index, used for logging and to capture the
                fixed visualization examples on the first epoch.
            optimizer: Optimizer stepped after every training batch.

        Returns:
            The mean combined loss over batches whose forward pass succeeded.
            Returns ``0.0`` when the split has no dataset, and ``inf`` when no
            batch succeeded.
        """
        config = self.config
        model = self.model
        is_train = split == "train"
        model.train(is_train)

        data = self.train_dataset if is_train else self.test_dataset
        if data is None:
            logger.warning("No data available for %s", split)
            return 0.0

        loader = DataLoader(
            data,
            shuffle=is_train,  # evaluation order does not need to be random
            pin_memory=self.device != "cpu",  # pinning only helps host->GPU copies
            batch_size=config.batch_size,
            num_workers=config.num_workers,
        )

        # Capture a few examples once for visualization. Use the evaluation
        # split, or the training split when there is no evaluation set.
        capture_fixed = (
            epoch == 0
            and self.fixed_x is None
            and (not is_train or self.test_dataset is None)
        )

        losses = []
        pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)

        for it, (x, y, mask) in pbar:
            if capture_fixed:
                self.fixed_x = x[:4]
                self.fixed_y = y[:4]
                capture_fixed = False

            x = x.to(self.device)
            y = y.to(self.device)
            mask = mask.to(self.device)

            with torch.set_grad_enabled(is_train):
                try:
                    ce_loss, diffusion_loss, length_loss = model(x, y, mask)
                    loss = (
                        ce_loss
                        + config.loss_weight * diffusion_loss
                        + config.length_loss_weight * length_loss
                    ).mean()  # .mean() also reduces per-GPU losses under DataParallel
                    losses.append(loss.item())
                except Exception as e:
                    # Best-effort: skip batches whose forward pass fails.
                    logger.exception("Error in forward pass: %s", e)
                    continue

            if not is_train:
                continue

            # Backprop, clip and update the parameters.
            model.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
            optimizer.step()
            self.iters += 1

            lr = self._update_learning_rate(optimizer) if config.lr_decay else config.learning_rate

            loss_val = losses[-1]
            ce_val = self._to_scalar(ce_loss)
            diff_val = self._to_scalar(diffusion_loss)
            len_val = self._to_scalar(length_loss)

            if self.use_wandb:
                wandb.log(
                    {
                        "train loss": loss_val,
                        "ce loss": ce_val,
                        "diffusion loss": diff_val,
                        "length loss": len_val,
                        "lr": lr,
                        "epoch": epoch + 1,
                    },
                    step=self.iters,
                )

            if isinstance(pbar, tqdm):
                pbar.set_description(
                    f"epoch {epoch + 1} iter {it}: train loss {loss_val:.5f}"
                    f" ce {ce_val:.5f} diff {diff_val:.5f}"
                    f" len {len_val:.5f} lr {lr:e}"
                )

        avg_loss = float(np.mean(losses)) if losses else float("inf")

        if not is_train:
            logger.info("Validation loss: %.5f", avg_loss)
            if self.use_wandb:
                wandb.log({"val_loss": avg_loss}, step=self.iters)
            self.current_loss = avg_loss

        return avg_loss

    def _update_learning_rate(self, optimizer) -> float:
        """Set every param group's LR from the schedule at ``self.iters`` and return it.

        The schedule ramps linearly from 0 over ``config.warmup_iters`` iterations.
        It then follows a cosine decay across
        ``[warmup_iters, final_iters]``. The multiplier is floored at 0.1 (10% of
        ``config.learning_rate``) and stays at the floor after ``final_iters``.
        """
        config = self.config

        if self.iters < config.warmup_iters:
            lr_mult = self.iters / max(1, config.warmup_iters)
        else:
            decay_iters = max(1, config.final_iters - config.warmup_iters)
            # Clamp so the cosine does not swing back up after final_iters.
            progress = min(1.0, (self.iters - config.warmup_iters) / decay_iters)
            lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))

        lr = config.learning_rate * lr_mult
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        return lr

    def _render_all(self, layouts, what: str) -> list:
        """Render each layout with ``train_dataset.render``; log and skip failures."""
        images = []
        for layout in layouts:
            try:
                images.append(self.train_dataset.render(layout))
            except Exception as e:
                logger.error("Error rendering %s: %s", what, e)
        return images

    def _visualize_samples(self, epoch: int) -> None:
        """Render the fixed examples and the model's outputs for them.

        Images are logged to wandb when it is enabled. Images that expose
        ``save`` are also written as PNGs to ``config.samples_dir``. Any failure
        is logged rather than raised.

        Args:
            epoch: Current epoch number, used in captions and file names.
        """
        if self.fixed_x is None or not hasattr(self.train_dataset, "render"):
            logger.warning("Cannot visualize samples: missing fixed samples or render method")
            return

        try:
            self.model.eval()

            with torch.no_grad():
                # Assumes fixed_x is (batch, seq, 4 + num_categories): 4 coordinates
                # followed by category scores -- TODO confirm against the dataset.
                coords = self.fixed_x[:, :, :4]
                # argmax yields int64; cast so torch.cat sees a single dtype.
                categories = self.fixed_x[:, :, 4:].argmax(dim=2, keepdim=True).to(coords.dtype)
                input_layouts = torch.cat((categories, coords), dim=-1).cpu().numpy()
                input_images = self._render_all(input_layouts, "input layout")

                x_cond = self.fixed_x.to(self.device)
                reconstructions = self.model(x_cond).cpu().numpy()
                recon_images = self._render_all(reconstructions, "reconstruction")

            if self.use_wandb and input_images and recon_images:
                wandb.log(
                    {
                        "input_layouts": [
                            wandb.Image(img, caption=f"input_{epoch:02d}_{i:02d}")
                            for i, img in enumerate(input_images)
                        ],
                        "reconstructions": [
                            wandb.Image(img, caption=f"recon_{epoch:02d}_{i:02d}")
                            for i, img in enumerate(recon_images)
                        ],
                    },
                    step=self.iters,
                )

            samples_dir = self.config.samples_dir
            if samples_dir is not None:
                os.makedirs(samples_dir, exist_ok=True)
                for i, (inp_img, rec_img) in enumerate(zip(input_images, recon_images)):
                    if hasattr(inp_img, "save"):
                        inp_img.save(os.path.join(samples_dir, f"input_{epoch:03d}_{i:02d}.png"))
                    if hasattr(rec_img, "save"):
                        rec_img.save(os.path.join(samples_dir, f"recon_{epoch:03d}_{i:02d}.png"))

        except Exception as e:
            logger.exception("Error in sample visualization: %s", e)