from copy import deepcopy

import torch
from diffusers.models import AutoencoderKL
from loss import SILoss
from models.sit import SiT_models
from utils.model_utils import requires_grad
from utils.utils import load_encoders


def setup_model_and_optimizer(args, device, accelerator):
    """Assemble the model, its EMA copy, the VAE, encoders, optimizer and loss.

    Args:
        args: Namespace-like configuration. Attributes read here are
            ``enc_type``, ``resolution``, ``fused_attn``, ``qk_norm``,
            ``model``, ``num_classes``, ``cfg_prob``, ``encoder_depth``,
            ``in_channels``, ``prediction``, ``path_type``, ``weighting``,
            ``learning_rate``, ``adam_beta1``, ``adam_beta2``,
            ``adam_weight_decay`` and ``adam_epsilon``.
        device: Device that the model, EMA copy, VAE and latent statistics
            are moved to; also forwarded to ``load_encoders``.
        accelerator: Forwarded unchanged to ``SILoss``.

    Returns:
        A 10-tuple, in this order: ``(model, ema, vae, encoders, optimizer,
        loss_fn, latents_scale, latents_bias, encoder_types,
        architectures)``.

    Raises:
        NotImplementedError: If ``args.enc_type`` is ``None``.
    """
    # Guard clause: there is no code path for a missing encoder spec.
    if args.enc_type is None:
        raise NotImplementedError()

    encoders, encoder_types, architectures = load_encoders(
        args.enc_type, device, args.resolution
    )

    # The literal string "None" (not the None object) yields z_dims == [0];
    # any other value contributes one embed_dim per loaded encoder.
    if args.enc_type == "None":
        z_dims = [0]
    else:
        z_dims = [enc.embed_dim for enc in encoders]

    attn_options = dict(fused_attn=args.fused_attn, qk_norm=args.qk_norm)
    build_sit = SiT_models[args.model]
    # NOTE(review): the // 8 presumably mirrors the VAE's spatial
    # downsampling factor - confirm against the data pipeline.
    model = build_sit(
        input_size=args.resolution // 8,
        num_classes=args.num_classes,
        use_cfg=(args.cfg_prob > 0),
        z_dims=z_dims,
        encoder_depth=args.encoder_depth,
        in_channels=args.in_channels,
        **attn_options,
    ).to(device)

    # The EMA copy starts as an exact duplicate of the freshly built model.
    ema = deepcopy(model).to(device)
    requires_grad(ema, False)

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)

    # Per-channel latent statistics, shaped (1, 4, 1, 1).
    # NOTE(review): 0.18215 looks like the Stable Diffusion VAE latent
    # scaling factor - confirm.
    stats_shape = (1, 4, 1, 1)
    latents_scale = torch.tensor([0.18215] * 4).view(*stats_shape).to(device)
    latents_bias = torch.tensor([0.0] * 4).view(*stats_shape).to(device)

    loss_fn = SILoss(
        prediction=args.prediction,
        path_type=args.path_type,
        encoders=encoders,
        accelerator=accelerator,
        latents_scale=latents_scale,
        latents_bias=latents_bias,
        weighting=args.weighting,
    )

    # Only the live model's parameters are optimised; the EMA copy is not
    # handed to the optimizer.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    return (
        model,
        ema,
        vae,
        encoders,
        optimizer,
        loss_fn,
        latents_scale,
        latents_bias,
        encoder_types,
        architectures,
    )
