# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch

#################################################################################
#                         Performance Settings (FiT-style)                      #
#################################################################################
# Allow TF32 tensor-core math for float32 matmuls and cuDNN ops (faster training
# on Ampere/Ada GPUs at slightly reduced float32 precision).
# NOTE: these are process-global torch settings, applied at import time before the
# remaining imports below.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# 'high' permits reduced-precision (e.g. TF32) internal math for float32 matmuls.
torch.set_float32_matmul_precision('high')
from torch.utils.data import DataLoader
from collections import OrderedDict
from copy import deepcopy
from glob import glob
from time import time
import logging
import os
import shutil
import sys

from accelerate import Accelerator
from accelerate.utils import set_seed

from src.diffusion import create_diffusion
from src.model import NNiT
from src.dataset import MultiHDF5ArchitectureWeightDataset
from src.train_utils import load_yaml_config
import math


#################################################################################
#                             Training Helper Functions                         #
#################################################################################

def get_cosine_schedule_with_warmup_and_min_lr(
    optimizer, num_warmup_steps: int, num_training_steps: int, min_lr_ratio: float = 0.0
):
    """
    Build a LambdaLR: linear warmup 0 -> base LR, then cosine decay base LR -> min LR.

    The LR multiplier for step ``s`` is:

    * ``s / num_warmup_steps`` while ``s < num_warmup_steps``;
    * ``min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + cos(pi * p))`` afterwards,
      where ``p`` is the completed fraction of the decay phase, clamped to at most 1.

    Because ``p`` is clamped, the LR stays at ``min_lr_ratio * base_lr`` once
    ``num_training_steps`` is reached. Without the clamp it would climb back up
    along the cosine curve if training runs longer than planned.

    Args:
        optimizer: The optimizer whose param groups are scheduled.
        num_warmup_steps: Steps for linear warmup.
        num_training_steps: Total planned training steps (end of the decay).
        min_lr_ratio: Minimum LR as ratio of initial LR (e.g., 0.143 for 1e-5/7e-5).

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping ``optimizer``.
    """
    # Length of the decay phase; guarded against zero/negative lengths.
    decay_steps = max(1, num_training_steps - num_warmup_steps)

    def lr_lambda(current_step):
        # Warmup phase: 0 -> 1.0
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Cosine decay phase: 1.0 -> min_lr_ratio, then held at min_lr_ratio.
        progress = min(1.0, float(current_step - num_warmup_steps) / float(decay_steps))
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model, in place.

    For every parameter: ``ema = decay * ema + (1 - decay) * param``.
    ``decay=0`` copies the model's parameters into the EMA model.

    Args:
        ema_model: Unwrapped module holding the EMA parameters (updated in place).
        model: Source module. It may be a wrapper that registers the real network
            as a child named ``module`` (as torch DistributedDataParallel does).
            Its parameter names then start with ``module.``, and that single
            leading prefix is stripped before lookup.
        decay: EMA decay factor.

    Raises:
        KeyError: if a (prefix-stripped) parameter name of ``model`` has no
            counterpart in ``ema_model``.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    wrapper_prefix = "module."

    for name, param in model.named_parameters():
        # Strip only a *leading* wrapper prefix. str.replace would also mangle
        # inner names that merely contain "module." (e.g. "submodule.weight"
        # -> "subweight") and fail the lookup below.
        if name.startswith(wrapper_prefix):
            name = name[len(wrapper_prefix):]
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)


def requires_grad(model, flag=True):
    """
    Toggle gradient tracking on every parameter of ``model``, in place.

    Args:
        model: Any object exposing ``parameters()`` (e.g. an ``nn.Module``).
        flag: Value written to each parameter's ``requires_grad`` attribute.
    """
    params_iter = model.parameters()
    for tensor in params_iter:
        tensor.requires_grad = flag


def create_logger(logging_dir):
    """
    Configure root logging to write to stdout and ``<logging_dir>/log.txt``.

    Uses ``logging.basicConfig`` at INFO level with a timestamped format. Per the
    logging docs, basicConfig does nothing if the root logger already has
    handlers, so only the first configuration in a process takes effect.

    Args:
        logging_dir: Existing directory in which ``log.txt`` is created.

    Returns:
        The logger for this module (``logging.getLogger(__name__)``).
    """
    log_path = f"{logging_dir}/log.txt"
    sinks = [logging.StreamHandler(), logging.FileHandler(log_path)]
    logging.basicConfig(
        level=logging.INFO,
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=sinks,
    )
    return logging.getLogger(__name__)


def _checkpoint_step(path):
    """
    Return the training step encoded in a checkpoint filename, or None.

    Step checkpoints are named ``<step>.pt`` (e.g. ``0001000.pt``). Any other
    name, such as ``final_checkpoint.pt``, carries no step and yields None.
    """
    stem = os.path.basename(path).split('.')[0]
    try:
        return int(stem)
    except ValueError:
        return None


def load_latest_checkpoint(checkpoint_dir, model, ema=None, opt=None, lr_scheduler=None, accelerator=None, device="cuda"):
    """
    Load the checkpoint with the highest step number from ``checkpoint_dir``.

    Only ``<step>.pt`` files are candidates. Other ``.pt`` files in the same
    directory (e.g. ``final_checkpoint.pt``) are ignored instead of breaking
    the step sort.

    Args:
        checkpoint_dir: Directory containing checkpoint files
        model: Model to load the weights into (should be already wrapped by accelerator)
        ema: Optional EMA model to load state
        opt: Optional optimizer to load state
        lr_scheduler: Optional learning rate scheduler to load state
        accelerator: Accelerator instance for unwrapping model
        device: Device to load the model to (passed as ``map_location``)

    Returns:
        Tuple ``(loaded_step, config, epoch)``:
        loaded_step: The step parsed from the checkpoint filename
        config: The checkpoint's ``"config"`` entry, or None if absent
        epoch: The checkpoint's ``"epoch"`` entry, or 0 if absent

        On any error, the traceback is printed and ``(0, None, 0)`` is returned.
        NOTE: objects restored before the error keep their loaded state.
    """
    try:
        # Keep only files whose name encodes a step. int() on a name such as
        # "final_checkpoint" would otherwise raise and abort the resume.
        candidates = []
        for path in glob(f"{checkpoint_dir}/*.pt"):
            step = _checkpoint_step(path)
            if step is not None:
                candidates.append((step, path))
        if not candidates:
            raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}.")

        # Highest step wins (stable sort: among equal steps the last globbed file is used).
        candidates.sort(key=lambda item: item[0])
        loaded_step, latest_checkpoint = candidates[-1]

        print(f"Loading checkpoint from {latest_checkpoint}")
        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        checkpoint = torch.load(latest_checkpoint, map_location=device, weights_only=False)

        # Load model weights - unwrap from accelerator wrapper
        if "model" in checkpoint:
            if accelerator is not None:
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.load_state_dict(checkpoint["model"])
            else:
                model.load_state_dict(checkpoint["model"])
            print("Loaded model weights")

        # Load EMA model weights if available
        if ema is not None and "ema" in checkpoint:
            ema.load_state_dict(checkpoint["ema"])
            print("Loaded EMA model weights")

        # Load optimizer state if provided
        if opt is not None and "opt" in checkpoint:
            opt.load_state_dict(checkpoint["opt"])
            print("Loaded optimizer state")

        # Load scheduler state if provided
        if lr_scheduler is not None and "scheduler" in checkpoint:
            lr_scheduler.load_state_dict(checkpoint["scheduler"])
            print("Loaded scheduler state")

        config = checkpoint.get("config", None)
        epoch = checkpoint.get("epoch", 0)

        return loaded_step, config, epoch
    except Exception as e:
        # Best-effort resume: report the failure and let the caller start from scratch.
        print(f"Error loading checkpoint: {str(e)}")
        import traceback
        traceback.print_exc()
        return 0, None, 0

#################################################################################
#                                  Training Loop                                #
#################################################################################



def main(config, config_file_path=None):
    """
    Train an NNiT model with a diffusion objective using Hugging Face Accelerate.

    Builds the NNiT network, a frozen EMA copy, the diffusion process, an AdamW
    optimizer, an LR scheduler (cosine or constant-with-warmup) and a DataLoader
    over ``MultiHDF5ArchitectureWeightDataset``. It then optionally resumes from
    the latest checkpoint in ``config['train']['checkpoint_dir']`` and trains for
    ``config['train']['epochs']`` epochs. It logs every ``log_every`` steps,
    checkpoints every ``ckpt_every`` steps, and saves a final checkpoint at the end.

    Args:
        config: Nested configuration dict with ``'train'``, ``'diffusion'`` and
            ``'data'`` sections (keys as read below). On resume it is updated in
            place with values from the checkpoint's saved config.
        config_file_path: Optional path of the YAML file ``config`` came from; if
            it exists it is copied into the new experiment directory.

    Raises:
        AssertionError: if CUDA is not available.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    accelerator = Accelerator()
    set_seed(config['train']['global_seed'])
    device = accelerator.device

    # Experiment directory, logger and config copy are created on the main process
    # only. NOTE: `logger`, `experiment_dir` and `checkpoint_dir` are therefore
    # undefined on other ranks; every later use is guarded by is_main_process.
    if accelerator.is_main_process:
        os.makedirs(config['train']['results_dir'], exist_ok=True)
        # Each run gets the next index: <results_dir>/<NNN>-<name>
        experiment_index = len(glob(f"{config['train']['results_dir']}/*"))
        model_string_name = config['train']['name'].replace("/", "-")
        experiment_dir = f"{config['train']['results_dir']}/{experiment_index:03d}-{model_string_name}"
        checkpoint_dir = f"{experiment_dir}/checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")

        if config_file_path and os.path.exists(config_file_path):
            config_filename = os.path.basename(config_file_path)
            config_dest_path = os.path.join(experiment_dir, config_filename)
            shutil.copy2(config_file_path, config_dest_path)
            logger.info(f"Configuration file saved to {config_dest_path}")
        else:
            logger.warning("No configuration file path provided or file not found - config not saved to experiment directory")

    # Shape hyper-parameters shared by the network and the dataset.
    architecture_n_vocab = config['diffusion']['architecture_n_vocab']
    architecture_max_layers = config['diffusion']['architecture_max_layer']
    weight_max_size = config['diffusion']['weight_max_size']
    patch_size = config['diffusion']['patch_size']

    model = NNiT(
        architecture_max_layer=architecture_max_layers,
        architecture_n_vocab=architecture_n_vocab,
        weight_max_size=weight_max_size,
        patch_size=patch_size,
        hidden_size=config['diffusion']['hidden_size'],
        depth = config['diffusion']['depth'],
        num_heads=config['diffusion']['num_heads'],
        mlp_ratio=config['diffusion']['mlp_ratio'],
        learn_sigma=config['diffusion']['learn_sigma'],
        use_swiglu=config['diffusion']['use_swiglu'],
        use_swiglu_large=config['diffusion']['use_swiglu_large'],
    ).to(device)

    # EMA copy: gradients disabled; only updated through update_ema().
    ema = deepcopy(model).to(device)
    requires_grad(ema, False)
    diffusion = create_diffusion(
        timestep_respacing="",
        learn_sigma=config['diffusion']['learn_sigma'],
        predict_xstart=False,
    )
    if accelerator.is_main_process:
        logger.info(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # AdamW without weight decay. float() presumably guards against the YAML
    # loader returning values like "1e-4" as strings.
    opt = torch.optim.AdamW(model.parameters(), lr=float(config['train']['lr']), weight_decay=0)

    dataset = MultiHDF5ArchitectureWeightDataset(
        file_path=config['data']['path'],
        architecture_max_layer=architecture_max_layers,
        weight_max_size=weight_max_size,
        patch_size=patch_size,
        with_bias=True,
    )

    # Seeded generator so the shuffle order is reproducible.
    seed = config['train']['global_seed']
    generator = torch.Generator()
    generator.manual_seed(seed)

    # Per-process batch size. NOTE(review): if global_batch_size is not divisible
    # by num_processes, the effective global batch is smaller than configured.
    loader = DataLoader(
        dataset,
        batch_size=int(config['train']['global_batch_size'] // accelerator.num_processes),
        shuffle=True,
        num_workers=config['train']['num_workers'],
        pin_memory=True,
        drop_last=True,
        generator=generator,
    )
    if accelerator.is_main_process:
        logger.info(f"Dataset contains {len(dataset):,} MLP policies ({config['data']['path']})")

    # LR schedule selection; anything other than 'cosine' uses constant-with-warmup.
    warmup_steps = config['train']['warmup_steps']
    lr_schedule = config['train'].get('lr_schedule', 'constant_with_warmup')

    # Planned optimizer steps, used as the cosine decay horizon.
    # NOTE(review): estimated from the global batch size; the number of iterations
    # per epoch actually comes from the prepared loader and may differ slightly.
    steps_per_epoch = len(dataset) // config['train']['global_batch_size']
    num_training_steps = config['train']['epochs'] * steps_per_epoch

    if lr_schedule == 'cosine':
        # The scheduler works with multipliers, so min_lr becomes a ratio of the base LR.
        min_lr = float(config['train'].get('min_lr', 1e-6))
        min_lr_ratio = min_lr / float(config['train']['lr'])
        lr_scheduler = get_cosine_schedule_with_warmup_and_min_lr(
            opt,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_training_steps,
            min_lr_ratio=min_lr_ratio
        )
        if accelerator.is_main_process:
            logger.info(f"Using cosine LR schedule: {config['train']['lr']} -> {min_lr} over {num_training_steps} steps")
    else:
        # Default: constant with warmup (linear ramp 0 -> 1 over warmup_steps).
        def warmup_lambda(step):
            if step < warmup_steps:
                return step / warmup_steps
            return 1.0
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=warmup_lambda)
        if accelerator.is_main_process:
            logger.info(f"Using constant LR with warmup: {config['train']['lr']} after {warmup_steps} warmup steps")

    # decay=0 copies the current weights into the EMA. This runs before
    # accelerator.prepare(), so parameter names carry no wrapper prefix yet.
    update_ema(ema, model, decay=0)
    model.train()
    ema.eval()

    # The scheduler is not passed to prepare(); it is stepped manually once per
    # iteration in the loop below.
    model, opt, loader = accelerator.prepare(model, opt, loader)

    # Optional resume. Loading happens after prepare(); load_latest_checkpoint
    # unwraps the model via the accelerator before load_state_dict.
    start_step = 0
    start_epoch = 0
    if config['train']['checkpoint_dir'] is not None:
        if os.path.exists(config['train']['checkpoint_dir']):
            start_step, loaded_config, start_epoch = load_latest_checkpoint(
                config['train']['checkpoint_dir'],
                model,
                ema=ema,
                opt=opt,
                lr_scheduler=lr_scheduler,
                accelerator=accelerator,
                device=device
            )
            if loaded_config is not None:
                # Merge the checkpoint's config into the current one. For keys present in
                # both, the checkpoint's value wins (nested dicts are merged with
                # dict.update, so keys that exist only in the current section survive).
                # Keys that exist only in the current config are kept unchanged.
                for key in loaded_config:
                    if key not in config:
                        config[key] = loaded_config[key]
                    else:
                        # For nested dictionaries, update recursively
                        if isinstance(config[key], dict) and isinstance(loaded_config[key], dict):
                            config[key].update(loaded_config[key])
                        else:
                            config[key] = loaded_config[key]
            if accelerator.is_main_process:
                logger.info(f"Resuming training from step {start_step}")
        else:
            if accelerator.is_main_process:
                logger.info("No checkpoint directory found, starting training from scratch")
    else:
        if accelerator.is_main_process:
            logger.info("No checkpoint directory specified, starting training from scratch")

    # Ensure all processes are synchronized after loading checkpoint
    accelerator.wait_for_everyone()

    # Variables for monitoring/logging purposes:
    train_steps = start_step
    log_steps = 0
    running_loss = 0
    running_metrics = {}  # Dictionary to accumulate all loss components
    start_time = time()
    if accelerator.is_main_process:
        logger.info(f"Training for {config['train']['epochs']} epochs...")

    # NOTE(review): on resume, epoch `start_epoch` restarts from its first batch
    # while train_steps continues from start_step. Batches already seen in that
    # epoch are therefore replayed, and the total step count can exceed
    # num_training_steps.
    for epoch in range(start_epoch, config['train']['epochs']):
        if accelerator.is_main_process:
            logger.info(f"Beginning epoch {epoch}...")
        for architecture, weight in loader:
            weight = weight.to(device)
            architecture = architecture.to(device)

            model_kwargs = dict()

            # NOTE(review): training_losses receives a dict keyed by modality;
            # confirm in src.diffusion whether key order matters.
            x_start = {"architecture": architecture, "weight": weight}

            # Use MoNL mixed training (learns both vanilla and per-modality)
            with accelerator.autocast():
                loss_dict = diffusion.training_losses(model, x_start, model_kwargs)

            # Optimizer step with gradient-norm clipping at 1.0.
            loss = loss_dict["loss"].mean()
            opt.zero_grad()
            accelerator.backward(loss)
            accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()

            # Step learning rate scheduler
            lr_scheduler.step()

            update_ema(ema, model)

            # Log loss values:
            running_loss += loss.item()

            # Accumulate all loss components
            for key, value in loss_dict.items():
                if key != "loss":  # Skip the total loss as we already track it
                    if key not in running_metrics:
                        running_metrics[key] = 0
                    running_metrics[key] += value.mean().item()

            log_steps += 1
            train_steps += 1
            if train_steps % config['train']['log_every'] == 0:
                # Measure training speed:
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                # Reduce loss history over all processes (mean across ranks):
                avg_loss = torch.tensor(running_loss / log_steps, device=device)
                avg_loss = accelerator.reduce(avg_loss, reduction="mean").item()

                # Prepare the log message with specific order
                log_msg = f"(step={train_steps:07d}) Train Loss: {avg_loss:.6f}"

                # Add metrics in specific order: arch MSE, arch VB, weight MSE, weight VB
                metric_order = ["mse_architecture", "vb_architecture", "mse_weight", "vb_weight"]
                for key in metric_order:
                    if key in running_metrics:
                        avg_metric = torch.tensor(running_metrics[key] / log_steps, device=device)
                        avg_metric = accelerator.reduce(avg_metric, reduction="mean").item()
                        # Simplify the key names for display
                        display_name = key.replace("_", " ").title()
                        log_msg += f", {display_name}: {avg_metric:.6f}"

                log_msg += f", Train Steps/Sec: {steps_per_sec:.2f}"

                # Log current learning rate (first param group)
                current_lr = lr_scheduler.get_last_lr()[0]
                log_msg += f", LR: {current_lr:.2e}"

                if accelerator.is_main_process:
                    logger.info(log_msg)
                # Reset monitoring variables:
                running_loss = 0
                running_metrics = {}
                log_steps = 0
                start_time = time()

            # Save DiT checkpoint (main process writes <checkpoint_dir>/<step:07d>.pt;
            # "epoch" is the epoch in progress, which a resume will restart):
            if train_steps % config['train']['ckpt_every'] == 0 and train_steps > 0:
                if accelerator.is_main_process:
                    checkpoint_data = {
                        "model": accelerator.unwrap_model(model).state_dict(),
                        "ema": ema.state_dict(),
                        "opt": opt.state_dict(),
                        "scheduler": lr_scheduler.state_dict(),
                        "config": config,  # Save the entire config for reproducibility
                        "epoch": epoch,
                    }
                    checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
                    torch.save(checkpoint_data, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")
                # Wait for checkpoint to be saved before continuing
                accelerator.wait_for_everyone()


    model.eval()
    # Switch to eval mode before the final save; do any sampling/evaluation with
    # ema (or model) in eval mode.

    # Save final checkpoint. NOTE: unlike the periodic checkpoints it has no
    # "epoch" entry and its filename carries no step number.
    if accelerator.is_main_process:
        final_checkpoint_data = {
            "model": accelerator.unwrap_model(model).state_dict(),
            "ema": ema.state_dict(),
            "opt": opt.state_dict(),
            "scheduler": lr_scheduler.state_dict(),
            "config": config
        }
        final_checkpoint_path = f"{checkpoint_dir}/final_checkpoint.pt"
        torch.save(final_checkpoint_data, final_checkpoint_path)
        logger.info(f"Saved final checkpoint to {final_checkpoint_path}")

    accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        logger.info("Done!")


if __name__ == "__main__":
    import argparse

    # CLI: a single required path to the YAML training configuration.
    cli = argparse.ArgumentParser(description='Train NNiT model')
    cli.add_argument('--config', type=str, required=True, help='Path to config YAML file')
    cli_args = cli.parse_args()

    # The path is parsed into a config dict and also forwarded to main(), so the
    # raw YAML file can be copied into the experiment directory.
    yaml_path = cli_args.config
    main(load_yaml_config(yaml_path), yaml_path)
    
    
