import importlib
import os
import sys

import torch

from omegaconf import OmegaConf

def create_latent_diffusion_model(
    config_path="latent_diffusion/configs/latent-diffusion/cin-ldm-vq-f8.yaml",
    ckpt_path="latent_diffusion/ckpts/cin/model.ckpt",
    device="cuda",
):
    """Build the latent-diffusion model from its YAML config and load its weights.

    The ``model`` section of the OmegaConf config is instantiated with
    ``instantiate_from_config``. The ``"state_dict"`` stored in the checkpoint
    is loaded into that model (minus a few DDIM buffers), and the model is
    moved to ``device``.

    Args:
        config_path: Path to the OmegaConf YAML config. Relative paths are
            resolved against the current working directory.
        ckpt_path: Path to a checkpoint file holding a ``"state_dict"`` entry.
        device: Device the model is moved to after the weights are loaded.
            Defaults to ``"cuda"``.

    Returns:
        The instantiated model with its weights loaded, on ``device``.

    Raises:
        KeyError: If the config has no ``model`` section or the checkpoint has
            no ``"state_dict"`` entry; also propagated from
            ``instantiate_from_config``.
    """
    # Make packages under the working directory (e.g. ``latent_diffusion``)
    # importable; only add the entry once so repeated calls don't grow sys.path.
    cwd = os.getcwd()
    if cwd not in sys.path:
        sys.path.append(cwd)

    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config["model"])

    # Deserialize onto the CPU: the model above is built on the CPU and only
    # moved to ``device`` at the end, so loading straight to the saved device
    # would just hold a redundant second copy of the weights there.
    # NOTE(review): newer torch versions default ``weights_only=True``; confirm
    # this checkpoint loads under the torch version in use.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = checkpoint["state_dict"]

    # Drop the DDIM schedule entries before the strict load below. Presumably
    # the freshly built model does not register these buffers, so a strict
    # load would reject them as unexpected keys (TODO confirm). ``pop`` with a
    # default tolerates checkpoints that never contained them.
    for key in (
        "ddim_sigmas",
        "ddim_alphas",
        "ddim_alphas_prev",
        "ddim_sqrt_one_minus_alphas",
    ):
        state_dict.pop(key, None)

    model.load_state_dict(state_dict)
    model.to(device)
    return model


def instantiate_from_config(config):
    """Create the object described by a config node.

    ``config["target"]`` is a dotted path resolved by ``get_obj_from_str``.
    The optional ``config["params"]`` mapping is passed to it as keyword
    arguments.

    The placeholder strings ``"__is_first_stage__"`` and
    ``"__is_unconditional__"`` may appear instead of a mapping; both yield
    ``None``.

    Raises:
        KeyError: If ``config`` has no ``target`` entry and is not one of the
            placeholder strings.
    """
    if "target" in config:
        factory = get_obj_from_str(config["target"])
        kwargs = config.get("params", dict())
        return factory(**kwargs)

    # Placeholder strings mean "nothing to build" for this slot.
    is_placeholder = (
        config == '__is_first_stage__' or config == "__is_unconditional__"
    )
    if is_placeholder:
        return None

    raise KeyError("Expected key `target` to instantiate.")


def get_obj_from_str(string, reload=False, prefix="latent_diffusion."):
    """Resolve a dotted ``"module.path.Name"`` string to the named object.

    ``prefix`` is prepended to the module part before importing. With the
    default prefix, ``"a.b.C"`` imports ``latent_diffusion.a.b`` and returns
    its attribute ``C``.

    Args:
        string: Dotted path whose last component is the attribute to fetch.
        reload: If True, re-execute the module with ``importlib.reload``
            before fetching the attribute.
        prefix: Text prepended to the module path. Defaults to
            ``"latent_diffusion."``; pass ``""`` to resolve an absolute path.

    Returns:
        The requested attribute of the imported module.

    Raises:
        ValueError: If ``string`` contains no ``"."`` separator.
        ImportError: If the prefixed module cannot be imported.
        AttributeError: If the module has no such attribute.
    """
    module_name, sep, attr_name = string.rpartition(".")
    if not sep:
        raise ValueError(
            f"Expected a dotted 'module.Name' path, got {string!r}"
        )
    module = importlib.import_module(prefix + module_name)
    if reload:
        # reload() re-executes the module in place and returns the same object.
        module = importlib.reload(module)
    return getattr(module, attr_name)

if __name__ == "__main__":
    # Script entry point: builds the model from the config and checkpoint
    # (which also moves it to CUDA) as a smoke test; the result is not used.
    model = create_latent_diffusion_model()