"""Implementation of SIT model initialization for AutoFID with rejection sampling."""

import os
import logging
import torch
from copy import deepcopy
from accelerate.utils import set_seed
from torch.nn.utils import parameters_to_vector

from distrl.models.diffusion_models.utils import create_diffusion_model

logger = logging.getLogger(__name__)

def _sit_checkpoint_step(name):
    """Return the step encoded in a ``<prefix>-<step>`` directory name, or None if absent."""
    # Only the final path component is parsed so dashes in parent directories
    # (e.g. "my-run/checkpoint-500") cannot be mistaken for the step separator.
    pieces = os.path.basename(os.path.normpath(name)).split("-")
    try:
        return int(pieces[1])
    except (IndexError, ValueError):
        return None


def _sit_latest_checkpoint(output_dir):
    """Return the ``checkpoint-<step>`` entry in ``output_dir`` with the highest step, or None."""
    if not os.path.isdir(output_dir):
        return None
    # Skip entries such as "checkpoints" or "checkpoint-tmp" that start with the
    # prefix but carry no integer step; they would otherwise break the sort.
    candidates = [
        entry for entry in os.listdir(output_dir)
        if entry.startswith("checkpoint") and _sit_checkpoint_step(entry) is not None
    ]
    if not candidates:
        return None
    return sorted(candidates, key=_sit_checkpoint_step)[-1]


def _sit_resolve_start_count(args, accelerator):
    """Restore accelerator state if requested and return the step count to resume from."""
    if args.resume_from_checkpoint:
        path = args.resume_from_checkpoint
        if path == "latest":
            path = _sit_latest_checkpoint(args.output_dir)
            if path is None:
                accelerator.print("No checkpoint found. Starting from scratch.")
                return 0
        # Validate the name before loading so a malformed name fails early
        # with a clear message instead of after the state has been restored.
        step = _sit_checkpoint_step(path)
        if step is None:
            raise ValueError(
                f"Cannot parse a step count from checkpoint name {path!r}; "
                "expected a name of the form '<prefix>-<step>'."
            )
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(args.output_dir, path))
        return step

    if args.resume_from_saved_model:
        # normpath drops a trailing separator, which would otherwise make
        # basename() return "" and silently reset the step count to 0.
        dir_name = os.path.basename(os.path.normpath(args.resume_from_saved_model))
        if dir_name.startswith("save_"):
            try:
                return int(dir_name.split("_")[1])
            except (IndexError, ValueError):
                logger.warning(f"Could not parse step count from {dir_name}, starting from 0")
    return 0


def init_sit_model(args, accelerator, weight_dtype, unet_copy=False):
    """
    Initialize SIT model components for training.

    Builds the SIT model via ``create_diffusion_model``, creates an AdamW
    (optionally 8-bit) optimizer and LR scheduler over all UNet parameters,
    wraps them with ``accelerator.prepare`` and resolves the starting step
    from a checkpoint or saved-model directory name.

    Side effects: sets the ``DISTRL_SIT_IMAGE_SIZE``, ``DISTRL_SIT_MODE``
    (only if unset) and ``DISTRL_SIT_NUM_STEPS`` environment variables, and
    seeds the RNGs when ``args.seed`` is not None.

    Args:
        args: Command line arguments
        accelerator: Accelerator object for distributed training
        weight_dtype: Data type for model weights
        unet_copy: When truthy, additionally create a frozen, eval-mode deep
            copy of the UNet (taken before ``accelerator.prepare``) and include
            it in the returned tuple.

    Returns:
        Tuple of (pipe, unet, vae, text_encoder, lora_layers, optimizer_policy,
        lr_scheduler_policy, start_count). When ``unet_copy`` is truthy the
        frozen copy is inserted after ``unet``, giving a 9-tuple
        (pipe, unet, unet_copy, vae, ...). ``lora_layers`` is always None.

    Raises:
        ValueError: If an explicit ``args.resume_from_checkpoint`` name does
            not contain a parseable ``-<step>`` suffix.
    """
    # Set environment variables for SIT configuration
    os.environ['DISTRL_SIT_IMAGE_SIZE'] = str(getattr(args, 'image_size', '256'))
    os.environ.setdefault('DISTRL_SIT_MODE', 'ODE')  # Default to ODE sampling
    os.environ['DISTRL_SIT_NUM_STEPS'] = str(args.num_inference_steps)

    # Set random seed if provided
    if args.seed is not None:
        set_seed(args.seed, device_specific=True)

    # Create the SIT diffusion model; when resuming from a saved model the
    # pretrained path is deliberately dropped in favour of sft_path.
    model = create_diffusion_model(
        model_type="sit",
        pretrained_path=args.pretrained_model_name_or_path if not args.resume_from_saved_model else None,
        sft_path=args.resume_from_saved_model,
        revision=args.revision,
        weight_dtype=weight_dtype
    )

    # Get model components
    components = model.get_components()
    pipe = components["pipe"]
    unet = components["unet"]
    vae = components["vae"]
    text_encoder = components["text_encoder"]  # Will be None for SIT

    # Enable gradient for UNet
    model.set_gradient(True)

    # Enable xformers if specified (though SIT may not support it)
    if args.enable_xformers_memory_efficient_attention:
        model.enable_xformers()

    # Since SIT doesn't use LoRA layers, set it to None
    lora_layers = None

    # Initialize optimizer over every UNet parameter
    params_to_optimize = list(unet.parameters())

    # Log the number of parameters being optimized
    if accelerator.is_main_process:
        param_count = sum(p.numel() for p in params_to_optimize if p.requires_grad)
        logger.info(f"Number of optimized parameters: {param_count:,}")
        # The norm is only logged, so avoid recording an autograd graph for it.
        with torch.no_grad():
            param_norm = torch.norm(parameters_to_vector(params_to_optimize))
        logger.info(f"Parameter vector norm: {param_norm:,}")

    # Configure optimizer with AdamW (following SIT's train.py)
    optimizer_kwargs = {
        "lr": args.learning_rate,
        "betas": (args.adam_beta1, args.adam_beta2),
        "weight_decay": args.adam_weight_decay,
        "eps": args.adam_epsilon,
    }

    # Use 8-bit Adam if specified; fall back to standard AdamW when
    # bitsandbytes is not installed (best-effort, logged as a warning).
    optimizer_cls = torch.optim.AdamW
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            logger.warning("bitsandbytes not found, using standard AdamW optimizer")
        else:
            optimizer_cls = bnb.optim.AdamW8bit
            logger.info("Using 8-bit Adam optimizer")

    # Create optimizer
    optimizer_policy = optimizer_cls(params_to_optimize, **optimizer_kwargs)

    # Create learning rate scheduler
    lr_scheduler_policy = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer_policy,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    # Keep the frozen copy in its own variable instead of rebinding the
    # ``unet_copy`` flag, so the return shape never depends on module truthiness.
    frozen_unet = None
    if unet_copy:
        frozen_unet = deepcopy(unet)
        frozen_unet.eval()
        frozen_unet.requires_grad_(False)
        frozen_unet.to(accelerator.device, dtype=weight_dtype)

    # Prepare all components with accelerator
    unet, optimizer_policy, lr_scheduler_policy = accelerator.prepare(
        unet, optimizer_policy, lr_scheduler_policy
    )
    vae = vae.to(accelerator.device)
    pipe.vae = vae
    pipe.vae.eval()

    # Start count is 0 or resume from checkpoint / saved model
    start_count = _sit_resolve_start_count(args, accelerator)

    if frozen_unet is not None:
        return (pipe, unet, frozen_unet, vae, text_encoder, lora_layers,
                optimizer_policy, lr_scheduler_policy, start_count)
    return (pipe, unet, vae, text_encoder, lora_layers,
            optimizer_policy, lr_scheduler_policy, start_count)

# Helper function to get the lr scheduler
def get_scheduler(
    name,
    optimizer,
    num_warmup_steps=0,
    num_training_steps=0,
):
    """
    Build a learning rate scheduler by delegating to transformers.

    Args:
        name: Name of the scheduler
        optimizer: Optimizer the scheduler will drive
        num_warmup_steps: Number of warmup steps
        num_training_steps: Total number of training steps

    Returns:
        Whatever ``transformers.optimization.get_scheduler`` returns for the
        given arguments.
    """
    # Imported lazily and under an alias: the library function has the same
    # name as this wrapper, and the alias makes it obvious which one is called.
    from transformers.optimization import get_scheduler as _transformers_get_scheduler

    scheduler_kwargs = {
        "name": name,
        "optimizer": optimizer,
        "num_warmup_steps": num_warmup_steps,
        "num_training_steps": num_training_steps,
    }
    return _transformers_get_scheduler(**scheduler_kwargs)

class SITLoss:
    """
    Training-loss helper for SIT models.

    Wraps a transport object obtained from ``create_transport`` and exposes
    the transport loss (``__call__``), a raw forward pass (``denoise``) and a
    plain MSE loss against a given target (``loss_known_sigma``).
    """
    def __init__(self, path_type='linear', prediction='velocity', loss_weight='snr+mse',
                 sigma_data=0.5, train_eps=1e-3, sample_eps=1e-3):
        """
        Store the transport settings and build the transport object.

        Args:
            path_type: Type of transport path, forwarded to ``create_transport``
            prediction: Type of prediction, forwarded to ``create_transport``
            loss_weight: Loss weighting scheme, forwarded to ``create_transport``
            sigma_data: Data noise level (stored on the instance only; not
                passed to ``create_transport``)
            train_eps: Training epsilon value
            sample_eps: Sampling epsilon value
        """
        # Local import: the transport module is only needed once an instance is built.
        from distrl.models.diffusion_models.sit.original.transport import create_transport

        self.path_type, self.prediction, self.loss_weight = path_type, prediction, loss_weight
        self.sigma_data = sigma_data
        self.train_eps, self.sample_eps = train_eps, sample_eps

        # NOTE(review): the upstream SiT create_transport keys path types as
        # 'Linear'/'GVP'/'VP' and loss weights as 'velocity'/'likelihood' —
        # confirm the vendored copy accepts these lowercase/'snr+mse' defaults.
        self.transport = create_transport(
            path_type, prediction, loss_weight, train_eps, sample_eps
        )

    @staticmethod
    def _label_kwargs(labels):
        """Return ``{'y': labels}`` when labels are given, else an empty dict."""
        if labels is None:
            return {}
        return {'y': labels}

    def __call__(self, net, images, labels=None):
        """
        Compute the mean transport training loss.

        Args:
            net: SIT model
            images: Input images
            labels: Optional class labels, passed to the model as ``y``

        Returns:
            Mean of the ``"loss"`` entry returned by the transport's
            ``training_losses``.
        """
        conditioning = self._label_kwargs(labels)
        losses = self.transport.training_losses(net, images, conditioning)
        per_sample = losses["loss"]
        return per_sample.mean()

    def denoise(self, net, x, sigma, labels=None):
        """
        Run the model once on a noisy input.

        Args:
            net: SIT model
            x: Noisy input
            sigma: Noise level
            labels: Optional class labels, passed to the model as ``y``

        Returns:
            The raw output of ``net(x, sigma, ...)``
        """
        conditioning = self._label_kwargs(labels)
        return net(x, sigma, **conditioning)

    def loss_known_sigma(self, net, xt, sigma, xt_1, class_labels=None):
        """
        Mean squared error between the model output and a given target.

        Args:
            net: SIT model
            xt: Current noisy state
            sigma: Noise level
            xt_1: Target state the prediction is compared against
            class_labels: Optional class labels, passed to the model as ``y``

        Returns:
            Scalar MSE loss
        """
        conditioning = self._label_kwargs(class_labels)
        prediction = net(xt, sigma, **conditioning)
        squared_error = (prediction - xt_1) ** 2
        return squared_error.mean()
