"""Implementation of EDM2 model initialization for AutoFID with rejection sampling."""

import os
import logging
import torch
from accelerate.utils import set_seed
from torch.nn.utils import parameters_to_vector

from distrl.models.diffusion_models.utils import create_diffusion_model

logger = logging.getLogger(__name__)

def _checkpoint_step(name):
    """Return the integer step encoded in a ``checkpoint-<step>`` path, or None if absent.

    Only the final path component is inspected, so dashes in parent directory
    names (e.g. ``/runs/my-exp/checkpoint-500``) do not confuse the parser.
    """
    parts = os.path.basename(os.path.normpath(name)).split("-")
    if len(parts) < 2:
        return None
    try:
        return int(parts[1])
    except ValueError:
        return None


def _find_latest_checkpoint(output_dir):
    """Return the name of the highest-step ``checkpoint*`` entry in *output_dir*, or None.

    Entries whose step cannot be parsed (e.g. ``checkpoints`` or
    ``checkpoint-final``) are skipped rather than aborting the search, and a
    missing *output_dir* is treated as "no checkpoints".
    """
    try:
        entries = os.listdir(output_dir)
    except FileNotFoundError:
        return None

    best_step, best_name = None, None
    for entry in entries:
        if not entry.startswith("checkpoint"):
            continue
        step = _checkpoint_step(entry)
        if step is not None and (best_step is None or step > best_step):
            best_step, best_name = step, entry
    return best_name


def _parameter_norm(params):
    """Return the global L2 norm of *params* as a Python float.

    Norms are taken per tensor (in float32, on CPU) and then combined, so the
    model is never concatenated into one large vector and no autograd graph
    is built. An empty parameter list yields 0.0.
    """
    with torch.no_grad():
        per_tensor = [p.detach().float().norm().cpu() for p in params]
        if not per_tensor:
            return 0.0
        # sqrt(sum ||p_i||^2) == || [||p_1||, ..., ||p_k||] ||
        return torch.norm(torch.stack(per_tensor)).item()


def init_edm2_model(args, accelerator, weight_dtype):
    """
    Initialize EDM2 model components for training.

    Builds the EDM2 diffusion model, an AdamW (optionally 8-bit) optimizer over
    all UNet parameters and a learning-rate scheduler, prepares them with
    ``accelerator`` and, if requested, restores accelerator state from a
    ``checkpoint-<step>`` directory.

    Args:
        args: Parsed command-line arguments. Reads ``seed``,
            ``pretrained_model_name_or_path``, ``resume_from_saved_model``,
            ``revision``, ``enable_xformers_memory_efficient_attention``,
            ``learning_rate``, ``adam_beta1``, ``adam_beta2``,
            ``adam_weight_decay``, ``adam_epsilon``, ``use_8bit_adam``,
            ``lr_scheduler``, ``lr_warmup_steps``,
            ``gradient_accumulation_steps``, ``max_train_steps``,
            ``resume_from_checkpoint`` and ``output_dir``.
        accelerator: Accelerator object for distributed training.
        weight_dtype: Data type for model weights.

    Returns:
        Tuple of (pipe, unet, vae, text_encoder, lora_layers, optimizer_policy,
        lr_scheduler_policy, start_count). ``lora_layers`` is always None;
        ``start_count`` is the step parsed from the resumed checkpoint's name,
        or 0 when not resuming.

    Raises:
        ValueError: If an explicit ``resume_from_checkpoint`` path does not end
            in a ``checkpoint-<step>`` component, so no step can be inferred.
    """
    # Set random seed if provided
    if args.seed is not None:
        set_seed(args.seed, device_specific=True)

    # Create the EDM2 diffusion model; a saved SFT model takes precedence over
    # the pretrained weights.
    model = create_diffusion_model(
        model_type="edm2",
        pretrained_path=args.pretrained_model_name_or_path if not args.resume_from_saved_model else None,
        sft_path=args.resume_from_saved_model,
        revision=args.revision,
        weight_dtype=weight_dtype
    )

    # Get model components
    components = model.get_components()
    pipe = components["pipe"]
    unet = components["unet"]
    vae = components["vae"]
    text_encoder = components["text_encoder"]  # Will be None for EDM2

    # Enable gradient for UNet
    model.set_gradient(True)

    # Enable xformers if specified (though EDM2 doesn't support it)
    if args.enable_xformers_memory_efficient_attention:
        model.enable_xformers()

    # Since EDM2 doesn't use LoRA layers, set it to None
    lora_layers = None

    # All UNet parameters go to the optimizer; the list order defines the
    # optimizer state layout, so it must stay stable across resumes.
    params_to_optimize = list(unet.parameters())

    # Log the number of parameters being optimized
    if accelerator.is_main_process:
        param_count = sum(p.numel() for p in params_to_optimize if p.requires_grad)
        logger.info(f"Number of optimized parameters: {param_count:,}")
        logger.info(f"Parameter vector norm: {_parameter_norm(params_to_optimize):,}")

    # Configure optimizer
    optimizer_kwargs = {
        "lr": args.learning_rate,
        "betas": (args.adam_beta1, args.adam_beta2),
        "weight_decay": args.adam_weight_decay,
        "eps": args.adam_epsilon,
    }

    # Use 8-bit Adam if specified, falling back to standard AdamW when
    # bitsandbytes is unavailable.
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
            optimizer_cls = bnb.optim.AdamW8bit
            logger.info("Using 8-bit Adam optimizer")
        except ImportError:
            logger.warning("bitsandbytes not found, using standard AdamW optimizer")
            optimizer_cls = torch.optim.AdamW
    else:
        optimizer_cls = torch.optim.AdamW

    # Create optimizer
    optimizer_policy = optimizer_cls(params_to_optimize, **optimizer_kwargs)

    # Create learning rate scheduler; step counts are scaled by the gradient
    # accumulation factor.
    lr_scheduler_policy = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer_policy,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    # Prepare all components with accelerator
    unet, optimizer_policy, lr_scheduler_policy = accelerator.prepare(
        unet, optimizer_policy, lr_scheduler_policy
    )

    # Start count is 0 or resume from checkpoint
    start_count = 0
    if args.resume_from_checkpoint:
        path = args.resume_from_checkpoint
        if path == "latest":
            # Get the most recent checkpoint
            path = _find_latest_checkpoint(args.output_dir)
            if path is None:
                accelerator.print("No checkpoint found. Starting from scratch.")

        if path is not None:
            # Parse the step before loading so a malformed name fails fast,
            # without partially restoring state.
            step = _checkpoint_step(path)
            if step is None:
                raise ValueError(
                    f"Cannot infer the training step from checkpoint path {path!r}; "
                    "expected a final path component like 'checkpoint-<step>'."
                )
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            start_count = step

    return pipe, unet, vae, text_encoder, lora_layers, optimizer_policy, lr_scheduler_policy, start_count

# Helper function to get the lr scheduler
def get_scheduler(
    name,
    optimizer,
    num_warmup_steps=0,
    num_training_steps=0,
):
    """
    Build a learning-rate scheduler by delegating to ``transformers``.

    ``transformers`` is imported inside the function, so it is only needed at
    the moment a scheduler is actually constructed.

    Args:
        name: Scheduler name, forwarded unchanged to
            ``transformers.optimization.get_scheduler``.
        optimizer: Optimizer whose learning rate the scheduler will drive.
        num_warmup_steps: Number of warmup steps.
        num_training_steps: Total number of training steps.

    Returns:
        The scheduler object returned by ``transformers``.
    """
    # Alias the import so it does not shadow this wrapper's own name.
    from transformers.optimization import get_scheduler as hf_get_scheduler

    scheduler_kwargs = dict(
        name=name,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )
    return hf_get_scheduler(**scheduler_kwargs)

class EDM2Loss:
    """EDM2-style denoising loss with a learned log-variance term.

    ``net`` is called as ``net(x, sigma, labels)``; with ``return_logvar=True``
    it must return a ``(denoised, logvar)`` pair.
    """

    def __init__(self, P_mean=-0.4, P_std=1.0, sigma_data=0.5):
        # ln(sigma) is sampled as P_mean + P_std * N(0, 1) in __call__.
        self.P_mean, self.P_std, self.sigma_data = P_mean, P_std, sigma_data

    def _loss_weight(self, sigma):
        """Return (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
        numerator = sigma ** 2 + self.sigma_data ** 2
        denominator = (sigma * self.sigma_data) ** 2
        return numerator / denominator

    def __call__(self, net, images, labels=None):
        """Draw one noise level per image and return the elementwise loss.

        The result is ``weight / exp(logvar) * (denoised - images)**2 + logvar``
        and has the broadcast shape of those tensors.
        """
        batch_size = images.shape[0]
        log_sigma = torch.randn([batch_size, 1, 1, 1], device=images.device) * self.P_std + self.P_mean
        sigma = log_sigma.exp()
        lam = self._loss_weight(sigma)
        # Noise is drawn before ``net`` runs; keep this order so RNG
        # consumption stays reproducible.
        noisy_images = images + torch.randn_like(images) * sigma
        denoised, logvar = net(noisy_images, sigma, labels, return_logvar=True)
        squared_error = (denoised - images) ** 2
        return (lam / logvar.exp()) * squared_error + logvar

    def denoise(self, net, x, sigma, labels):
        """Run ``net`` once on ``x`` at noise level ``sigma`` and return its output."""
        return net(x, sigma, labels)

    def loss_known_sigma(self, net, xt, sigma, xt_1, class_labels):
        """Loss at a caller-supplied noise level, regressing ``net``'s output onto ``xt_1``.

        The weight is reshaped with ``view(-1, 1, 1, 1)``, so ``sigma`` must be a
        tensor (presumably one value per sample — confirm against callers).
        """
        denoised, logvar = net(xt, sigma, class_labels, return_logvar=True)
        lam = self._loss_weight(sigma).view(-1, 1, 1, 1)
        squared_error = (denoised - xt_1) ** 2
        return (lam / logvar.exp()) * squared_error + logvar
