import torch
from diffusers import DDPMScheduler, DDIMScheduler
import sys
import os
sys.path.append('./')
from pipeline.pipeline_stable_diffusion_dap import StableDiffusionPipelineDAP
from model.unet_2d_condition_dap import UNet2DConditionModelDAP
from model.dit_transformer_2d_dap import DiTTransformer2DModelDAP
from pipeline.pipeline_dit_dap import DiTPipelineDAP
from diffusers import AutoencoderKL


def load_sampling_model(args, accelerator, wandb_run,
                        transformer_path='path to your transformer model',
                        unet_path='path to your unet models'):
    """Load a DAP sampling pipeline (DiT or Stable Diffusion) for inference.

    The base pipeline is loaded from ``args.pretrained_model_name_or_path`` in
    float32. Its denoiser (the DiT transformer or the SD UNet) is replaced by a
    DAP variant loaded from ``transformer_path`` / ``unet_path``. The
    pipeline's scheduler is replaced by a ``DDIMScheduler`` built from the
    original scheduler's config.

    Args:
        args: Object with ``model_type`` ('dit' or 'sd', case-insensitive) and
            ``pretrained_model_name_or_path``.
        accelerator: Object exposing ``.device``; the pipeline is moved there.
        wandb_run: Forwarded to the DAP pipeline's ``from_pretrained``.
        transformer_path: Location of the DAP DiT transformer weights (DiT
            only). The default is the original placeholder string, so callers
            must override it with a real path.
        unet_path: Location of the DAP UNet weights (SD only). The default is
            the original placeholder string, so callers must override it.

    Returns:
        The configured pipeline. Its VAE, its denoiser and (for SD) its text
        encoder have ``requires_grad_(False)`` applied.

    Raises:
        ValueError: If ``args.model_type`` is neither 'dit' nor 'sd'.
    """
    model_type = args.model_type.lower()
    if model_type == 'dit':
        model = DiTPipelineDAP.from_pretrained(
            args.pretrained_model_name_or_path,
            torch_dtype=torch.float32,
            wandb_run=wandb_run,
        )
        # Swap in the DAP transformer *before* freezing, enabling attention
        # slicing and moving to the device. Otherwise those settings are
        # applied to the stock transformer, which is then discarded.
        model.transformer = DiTTransformer2DModelDAP.from_pretrained(transformer_path)
        frozen_modules = (model.transformer, model.vae)
    elif model_type == 'sd':
        model = StableDiffusionPipelineDAP.from_pretrained(
            args.pretrained_model_name_or_path,
            torch_dtype=torch.float32,
            safety_checker=None,
            wandb_run=wandb_run,
        )
        # Same reasoning as above: replace the UNet first so the freeze,
        # slicing and device move apply to the module used for sampling.
        model.unet = UNet2DConditionModelDAP.from_pretrained(unet_path)
        frozen_modules = (model.vae, model.text_encoder, model.unet)
    else:
        # Without this, an unknown type fell through to `return model` and
        # raised an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported model_type {args.model_type!r}; expected 'dit' or 'sd'."
        )

    # Sampling only: no parameter of these components should track gradients.
    for module in frozen_modules:
        module.requires_grad_(False)
    model.enable_attention_slicing()
    model.scheduler = DDIMScheduler.from_config(model.scheduler.config)
    model.to(accelerator.device)
    return model

    
def initialize_models(args, accelerator, wandb_run):
    """Create the sampling pipeline and place it on the accelerator's device.

    Args:
        args: Run configuration, forwarded unchanged to ``load_sampling_model``.
        accelerator: Object exposing ``.device``; also forwarded to
            ``load_sampling_model``.
        wandb_run: Forwarded unchanged to ``load_sampling_model``.

    Returns:
        The pipeline produced by ``load_sampling_model``, after
        ``.to(accelerator.device)`` has been called on it.
    """
    pipeline = load_sampling_model(args, accelerator, wandb_run)
    # Re-apply the device move so every component ends up on
    # accelerator.device. This includes any submodule assigned after an
    # earlier `.to()` call.
    target_device = accelerator.device
    pipeline.to(target_device)
    return pipeline