import argparse
from enum import Enum

import numpy as np
import torch
from torchvision import transforms
import argparse
from guided_diffusion.guided_diffusion import dist_util
from guided_diffusion.guided_diffusion.script_util import (
    create_model_and_diffusion,
    args_to_dict,
    add_dict_to_argparser,
)
from latent_diffusion.latent_diffusion import create_latent_diffusion_model
from stable_diffusion import HuggingFaceDiffusionWrapper


def diffusion_defaults():
    """
    Build the default diffusion-process settings (image and classifier training).

    Returns:
        A freshly created dict on every call, so callers may mutate it freely.
    """
    defaults = {
        "learn_sigma": True,
        "diffusion_steps": 1000,
        "noise_schedule": "linear",
        "timestep_respacing": "250",
        "use_kl": False,
        "predict_xstart": False,
        "rescale_timesteps": False,
        "rescale_learned_sigmas": False,
    }
    return defaults


def classifier_defaults():
    """
    Build the default settings for classifier models.

    Returns:
        A freshly created dict on every call.
    """
    # NOTE(review): the trailing "was ..." remarks carry over the values noted
    # next to these entries in the previous revision; presumably they are the
    # upstream defaults -- confirm before relying on them.
    defaults = {
        "image_size": 256,
        "classifier_use_fp16": False,
        "classifier_width": 128,
        "classifier_depth": 2,
        "classifier_attention_resolutions": "32,16,8",  # was: 16
        "classifier_use_scale_shift_norm": True,  # was: False
        "classifier_resblock_updown": True,  # was: False
        "classifier_pool": "attention",
    }
    return defaults


def model_and_diffusion_defaults():
    """
    Build the default settings for image training.

    The model-architecture entries come first, followed by everything returned
    by ``diffusion_defaults()``.

    Returns:
        A freshly created dict on every call.
    """
    model_defaults = {
        "image_size": 256,
        "num_channels": 256,
        "num_res_blocks": 2,
        "num_heads": 4,
        "num_heads_upsample": -1,
        "num_head_channels": 64,
        "attention_resolutions": "32,16,8",
        "channel_mult": "",
        "dropout": 0.0,
        "class_cond": False,
        "use_checkpoint": False,
        "use_scale_shift_norm": True,
        "resblock_updown": True,
        "use_fp16": True,
        "use_new_attention_order": False,
    }
    # Diffusion entries are merged last, so they would win on a key clash.
    return {**model_defaults, **diffusion_defaults()}

def create_argparser():
    """
    Create an ``argparse.ArgumentParser`` pre-populated with sampling defaults.

    The sampling-specific options defined here are combined with
    ``model_and_diffusion_defaults()`` and registered on the parser through
    ``add_dict_to_argparser``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    sampling_defaults = {
        "clip_denoised": True,
        "num_samples": 1,
        "batch_size": 4,
        "use_ddim": False,
        "model_path": "./guided_diffusion/models/256x256_diffusion_uncond.pt",
    }
    # Model/diffusion entries are merged last, so they would win on a key clash.
    all_defaults = {**sampling_defaults, **model_and_diffusion_defaults()}

    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, all_defaults)
    return parser

# Resize-only pipelines: each one rescales its input to a fixed square size.
trans_to_256 = transforms.Compose([transforms.Resize((256, 256))])

trans_to_224 = transforms.Compose([transforms.Resize((224, 224))])

class DiffusionTypes(Enum):
    """Kinds of diffusion back-end distinguished by this module."""

    original = 1
    latent = 2
    stable = 3

class DiffusionWrapper:
    """
    Common ``p_sample`` / ``encode`` / ``decode`` front-end over a diffusion object.

    Two calling conventions are handled:

    * ``DiffusionTypes.latent``: ``diffusion`` is called with keyword arguments
      ``x``, ``c``, ``t`` and must provide ``get_learned_conditioning`` and
      ``cond_stage_key``.
    * anything else: ``diffusion.p_sample`` is called with ``model`` as its
      first positional argument.
    """

    def __init__(self, diffusion, args, model=None, diffusion_type=DiffusionTypes.original):
        """
        Args:
            diffusion: Object providing ``p_sample`` (and, for non-original
                types, ``encode_first_stage`` / ``decode_first_stage``).
            args: Namespace-like object; only ``clip_denoised`` is read here.
            model: Network handed to ``diffusion.p_sample`` on the non-latent path.
            diffusion_type: A ``DiffusionTypes`` member selecting the code path.
        """
        self.diffusion = diffusion
        self.model = model
        self.args = args
        self.diffusion_type = diffusion_type
        # Not read by any method of this class.
        self.context = None
        # False -> the latent path samples with an all-zero conditioning tensor.
        self.label = False

    def p_sample(self, img, t):
        """
        Run one reverse-diffusion step on ``img`` at timestep(s) ``t``.

        Returns:
            On the latent path, a dict with keys ``'sample'`` and
            ``'pred_xstart'`` built from the first two items returned by
            ``diffusion.p_sample``. Otherwise, whatever ``diffusion.p_sample``
            returns, unchanged.
        """
        if self.diffusion_type != DiffusionTypes.latent:
            return self.diffusion.p_sample(
                self.model,
                img,
                t,
                clip_denoised=self.args.clip_denoised,
                denoised_fn=None,
                cond_fn=None,
                model_kwargs={},
            )

        batch_size = img.size(0)
        if self.label:
            # One random class id in [0, 100) per batch element.
            class_ids = torch.randint(0, 100, (batch_size, ), device=img.device)
        else:
            class_ids = torch.zeros((batch_size, ), dtype=torch.int, device=img.device)

        conditioning = self.diffusion.get_learned_conditioning(
            {self.diffusion.cond_stage_key: class_ids}
        )
        if not self.label:
            # The class-0 lookup above only supplies shape/dtype/device; the
            # conditioning actually used is all zeros.
            conditioning = torch.zeros_like(conditioning)

        result = self.diffusion.p_sample(
            x=img,
            c=conditioning,
            t=t,
            clip_denoised=self.args.clip_denoised,
            return_x0=True,
        )
        return {'sample': result[0], 'pred_xstart': result[1]}

    def encode(self, img):
        """Return ``img`` unchanged for the original type, else its first-stage encoding."""
        if self.diffusion_type == DiffusionTypes.original:
            return img
        return self.diffusion.encode_first_stage(img)

    def decode(self, z):
        """Return ``z`` unchanged for the original type, else its first-stage decoding."""
        if self.diffusion_type == DiffusionTypes.original:
            return z
        return self.diffusion.decode_first_stage(z)


def create_diffusion_model(diffusion_type):
    """
    Construct the diffusion back-end selected by ``diffusion_type``.

    Args:
        diffusion_type: ``"stable"`` or ``"latent"`` select those back-ends;
            every other value falls through to the guided-diffusion model
            loaded from ``args.model_path``.

    Returns:
        A ``HuggingFaceDiffusionWrapper`` for ``"stable"``, otherwise a
        ``DiffusionWrapper``.
    """
    if diffusion_type == "stable":
        return HuggingFaceDiffusionWrapper()

    # An empty argv means only the built-in defaults are used.
    args = create_argparser().parse_args([])

    if diffusion_type == "latent":
        latent_diffusion = create_latent_diffusion_model()
        return DiffusionWrapper(
            latent_diffusion,
            args,
            model=None,
            diffusion_type=DiffusionTypes.latent,
        )

    dist_util.setup_dist()

    model_keys = model_and_diffusion_defaults().keys()
    net, diffusion = create_model_and_diffusion(**args_to_dict(args, model_keys))

    # Weights are loaded on CPU first, then the model is moved to the device
    # reported by dist_util.
    net.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    net.to(dist_util.dev())
    if args.use_fp16:
        net.convert_to_fp16()
    net.eval()

    return DiffusionWrapper(diffusion, args, model=net)


def beta(t, steps, start, end):
    """
    Linearly interpolate between ``start`` and ``end`` over ``steps`` steps.

    ``t == 1`` yields ``start`` and ``t == steps`` yields ``end``.
    """
    fraction = (t - 1) / (steps - 1)
    return fraction * (end - start) + start


def add_noise(x, delta, opt_t, steps, start, end):
    """
    Perturb ``x`` with Gaussian noise of scale ``delta`` and rescale the result.

    Computes ``sqrt(1 - beta(opt_t, steps, start, end)) * (x + delta * eps)``
    where ``eps`` is drawn with ``torch.randn_like(x)``.
    """
    scale = np.sqrt(1 - beta(opt_t, steps, start, end))
    perturbed = x + torch.randn_like(x) * delta
    return scale * perturbed


def get_opt_t(delta, start, end, steps):
    """
    Pick the timestep whose linear-schedule value matches noise level ``delta``.

    Solves ``start + (t - 1) / (steps - 1) * (end - start) == 1 - 1 / (1 + delta**2)``
    for ``t``, rounds to the nearest integer and clips the result to
    ``[0, steps]``.

    Returns:
        The clipped timestep, as returned by ``np.clip``.
    """
    slope = (steps - 1) / (end - start)
    target = 1 - 1 / (1 + delta ** 2)
    unclipped = int(np.around(1 + slope * (target - start)))
    return np.clip(unclipped, 0, steps)


def denoise(img, diffusion_wrapper, steps, start, end, delta, direct_pred=False):
    """
    Noise ``img`` to the level implied by ``delta`` and run the reverse process.

    Args:
        img: Input tensor; a 3-dimensional noised tensor gets a leading batch
            axis added before sampling.
        diffusion_wrapper: A ``DiffusionWrapper`` or ``HuggingFaceDiffusionWrapper``.
        steps, start, end: Schedule parameters forwarded to ``get_opt_t`` and
            ``add_noise``.
        delta: Noise scale forwarded to ``get_opt_t`` and ``add_noise``.
        direct_pred: When True, perform a single ``p_sample`` call at the
            highest timestep and use its ``'pred_xstart'`` entry instead of
            iterating down to timestep 0. No progress bar is shown in that case.

    Returns:
        For a ``HuggingFaceDiffusionWrapper``, the result of calling it with
        the noised input and the chosen timestep. Otherwise
        ``diffusion_wrapper.decode`` applied to the final iterate.
    """
    opt_t = get_opt_t(delta, start, end, steps)

    # Only DiffusionWrapper instances are asked to encode the input first.
    if isinstance(diffusion_wrapper, DiffusionWrapper):
        img = diffusion_wrapper.encode(img)

    noisy = add_noise(img, delta, opt_t, steps, start, end)

    # The HuggingFace wrapper runs its own sampling and is not decoded here.
    if isinstance(diffusion_wrapper, HuggingFaceDiffusionWrapper):
        return diffusion_wrapper(noisy, opt_t)

    # Promote a single un-batched input to a batch of one.
    if noisy.dim() == 3:
        noisy = noisy.unsqueeze(0)

    batch_size = noisy.size(0)
    device = noisy.device

    # Timesteps opt_t - 1, ..., 0 in descending order.
    timesteps = list(range(opt_t - 1, -1, -1))
    if not direct_pred:
        from tqdm.auto import tqdm
        timesteps = tqdm(timesteps)

    current = noisy
    for step in timesteps:
        t = torch.tensor([step] * batch_size, device=device)
        with torch.no_grad():
            step_out = diffusion_wrapper.p_sample(current, t)
        current = step_out['sample']
        if direct_pred:
            # One-shot mode: keep the predicted clean sample and stop.
            current = step_out['pred_xstart']
            break

    return diffusion_wrapper.decode(current)
