import torch
import random
from libs.dpm_solver_pp import NoiseScheduleVP, DPM_Solver
import einops
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from libs.people_dataset import _transform


def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
    """Return the Stable Diffusion "scaled linear" beta schedule as a NumPy array.

    The betas are the squares of ``n_timestep`` points spaced evenly between
    ``sqrt(linear_start)`` and ``sqrt(linear_end)``. The computation is done
    in float64, and the result is a float64 ``np.ndarray`` of length
    ``n_timestep``.
    """
    sqrt_start = linear_start ** 0.5
    sqrt_end = linear_end ** 0.5
    sqrt_betas = torch.linspace(sqrt_start, sqrt_end, n_timestep, dtype=torch.float64)
    return sqrt_betas.pow(2).numpy()


def prepare_contexts(prompt, config, clip_text_model, device):
    """Encode the prompt and draw random image-side placeholder tensors.

    Returns ``(contexts, img_contexts, clip_imgs)``:
      * ``contexts`` is ``clip_text_model.encode`` applied to a list holding
        ``config.n_samples`` copies of ``prompt``;
      * ``img_contexts`` is standard-normal noise of shape
        ``(n_samples, 2 * z_shape[0], z_shape[1], z_shape[2])``;
      * ``clip_imgs`` is standard-normal noise of shape
        ``(n_samples, 1, config.clip_img_dim)``.

    Both noise tensors are drawn on torch's default device, in that order,
    before the prompt is encoded. ``device`` is accepted but not used here.
    """
    num = config.n_samples
    shape = config.z_shape
    img_contexts = torch.randn(num, 2 * shape[0], shape[1], shape[2])
    clip_imgs = torch.randn(num, 1, config.clip_img_dim)
    contexts = clip_text_model.encode([prompt for _ in range(num)])
    return contexts, img_contexts, clip_imgs


def unpreprocess(v):
    """Map values from [-1, 1] to [0, 1], clipping anything outside that range.

    The shape is unchanged; no reshaping to B C H W happens here. A new
    tensor is returned and the input tensor is not modified.
    """
    return (0.5 * (v + 1.)).clamp(0., 1.)


def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU and every CUDA device)."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

def split(x, config):
    """Split a flat joint vector into image latents and a CLIP image embedding.

    ``x`` must be 2-D with ``C*H*W + config.clip_img_dim`` columns, where
    ``config.z_shape == (C, H, W)``. Returns ``(z, clip_img)`` with shapes
    ``(B, C, H, W)`` and ``(B, 1, config.clip_img_dim)``. This is the
    inverse of ``combine``.
    """
    C, H, W = config.z_shape
    clip_dim = config.clip_img_dim
    z_flat, clip_flat = x.split([C * H * W, clip_dim], dim=1)
    z = einops.rearrange(z_flat, 'B (C H W) -> B C H W', C=C, H=H, W=W)
    clip_img = einops.rearrange(clip_flat, 'B (L D) -> B L D', L=1, D=clip_dim)
    return z, clip_img


def combine(z, clip_img):
    """Flatten and concatenate image latents with a CLIP image embedding.

    ``z`` must be 4-D ``(B, C, H, W)`` and ``clip_img`` 3-D ``(B, L, D)``.
    Returns a 2-D tensor ``(B, C*H*W + L*D)`` with the latent values first.
    When ``L == 1`` this is the inverse of ``split``.
    """
    flat_parts = [
        einops.rearrange(z, 'B C H W -> B (C H W)'),
        einops.rearrange(clip_img, 'B L D -> B (L D)'),
    ]
    return torch.cat(flat_parts, dim=-1)

# `torch.cuda.amp.autocast()` is deprecated; `torch.autocast(device_type="cuda")`
# is the equivalent, non-deprecated form (same CUDA device type, same default
# fp16 dtype). The file already uses `torch.autocast` elsewhere.
@torch.autocast(device_type="cuda")
def decode(_batch, autoencoder):
    """Decode latents with ``autoencoder.decode`` under CUDA mixed-precision autocast.

    Autocast only affects CUDA ops. When CUDA is unavailable, torch warns and
    the region runs without autocast. Returns whatever
    ``autoencoder.decode`` returns.
    """
    return autoencoder.decode(_batch)

@torch.no_grad()
def sample(prompt, image2: str, config, nnet, clip_text_model, feed_model, autoencoder, caption_decoder, device, return_map=False, **kwargs):
    """Run text-to-image sampling with DPM-Solver and classifier-free guidance.

    Args:
        prompt: Text prompt. ``prepare_contexts`` encodes ``config.n_samples``
            copies of it.
        image2: A file path (``str``) or an ``np.ndarray``. Either one is
            converted to an RGB PIL image, passed through ``_transform(224)``,
            given a leading batch dimension and moved to ``device``. Any other
            value is left as-is. NOTE(review): the processed value is not used
            anywhere in this function. Confirm whether this step is still
            needed.
        config: Sampling config. This function reads ``seed``, ``benchmark``
            (optional, via ``config.get``), ``n_samples``, ``n_iter``,
            ``z_shape``, ``clip_img_dim``, ``data_type``,
            ``sample.t2i_cfg_mode``, ``sample.scale`` and
            ``sample.sample_steps``.
        nnet: Denoiser. It is called with latents, the CLIP image embedding,
            the text embedding and timesteps. It must return a dict with at
            least ``"img_out"`` and ``"clip_img_out"``, plus
            ``"attention_maps"`` when ``return_map`` is true.
        clip_text_model: Object with an ``encode(list_of_str)`` method.
        feed_model: Not used by this function.
        autoencoder: Object whose ``decode`` maps latents to images.
        caption_decoder: Object whose ``encode_prefix`` maps CLIP text
            contexts to the low-dimensional text input of ``nnet``.
        device: Device on which the initial noise and the beta tensor are
            created.
        return_map: If true, the attention maps from every conditional
            ``nnet`` call are copied to CPU and collected.
        **kwargs: Ignored.

    Returns:
        ``{"samples": [PIL images], "attention_maps": [...]}``. The
        ``attention_maps`` list is empty unless ``return_map`` is true.

    Raises:
        NotImplementedError: On the first denoiser call, if
            ``config.sample.t2i_cfg_mode`` is neither ``'empty_token'`` nor
            ``'true_uncond'``.
    """
    set_seed(config.seed)
    if isinstance(image2, str):
        # Use a context manager so the file handle is released deterministically.
        with Image.open(image2) as pil_img:
            image2 = _transform(224)(pil_img.convert("RGB")).unsqueeze(0).to(device)
    elif isinstance(image2, np.ndarray):
        image2 = _transform(224)(Image.fromarray(image2).convert("RGB")).unsqueeze(0).to(device)
    if config.get('benchmark', False):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    _betas = stable_diffusion_beta_schedule()
    N = len(_betas)

    empty_context = clip_text_model.encode([''])[0]
    # img_contexts / clip_imgs are unused below. The call is still made as-is
    # because it consumes RNG draws, and skipping it would change seeded
    # results.
    contexts, img_contexts, clip_imgs = prepare_contexts(prompt, config, clip_text_model, device)
    # Low-dimensional version of the contexts; this is the nnet's text input.
    contexts_low_dim = caption_decoder.encode_prefix(contexts)
    _n_samples = contexts_low_dim.size(0)

    attention_maps = []

    def t2i_nnet(x, timesteps, text):  # text is the low-dimensional text clip embedding
        """Return the classifier-free-guided model output for joint input ``x``.

        1. Compute the conditional model output.
        2. Compute the unconditional model output:
           config.sample.t2i_cfg_mode == 'empty_token': use the encoded empty
           string as the text condition;
           config.sample.t2i_cfg_mode == 'true_uncond': use random text with
           ``t_text = N``, i.e. the unconditional model learned jointly.
        3. Return ``cond + scale * (cond - uncond)``.
        """
        z, clip_img = split(x, config)
        t_text = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)
        # Same integer data-type tag for both nnet calls; build it once.
        data_type = torch.zeros_like(t_text, device=device, dtype=torch.int) + config.data_type

        dict_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=t_text,
                        data_type=data_type, return_map=return_map)
        if return_map:
            attention_maps.append([m.detach().cpu() for m in dict_out["attention_maps"]])
        x_out = combine(dict_out["img_out"], dict_out["clip_img_out"])

        cfg_mode = config.sample.t2i_cfg_mode
        if cfg_mode == 'empty_token':
            _empty_context = einops.repeat(empty_context, 'L D -> B L D', B=x.size(0))
            _empty_context = caption_decoder.encode_prefix(_empty_context)
            dict_out = nnet(z, clip_img, text=_empty_context, t_img=timesteps, t_text=t_text,
                            data_type=data_type)
        elif cfg_mode == 'true_uncond':
            text_N = torch.randn_like(text)  # 3 other possible choices
            dict_out = nnet(z, clip_img, text=text_N, t_img=timesteps,
                            t_text=torch.ones_like(timesteps) * N, data_type=data_type)
        else:
            raise NotImplementedError(f"unsupported config.sample.t2i_cfg_mode: {cfg_mode!r}")
        x_out_uncond = combine(dict_out["img_out"], dict_out["clip_img_out"])

        return x_out + config.sample.scale * (x_out - x_out_uncond)

    def sample_fn(text):
        """Draw fresh joint noise and run DPM-Solver; return ``(z, clip_img)``."""
        _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
        _clip_img_init = torch.randn(_n_samples, 1, config.clip_img_dim, device=device)
        _x_init = combine(_z_init, _clip_img_init)

        noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())

        def model_fn(x, t_continuous):
            # DPM-Solver works in continuous time in (0, 1]; the nnet expects
            # discrete-scale timesteps.
            t = t_continuous * N
            return t2i_nnet(x, t, text)

        dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
        with torch.no_grad(), torch.autocast(device_type="cuda" if "cuda" in str(device) else "cpu"):
            x = dpm_solver.sample(_x_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)

        return split(x, config)

    to_pil = transforms.ToPILImage()  # stateless; reuse for every image
    samples = []
    for _ in range(config.n_iter):
        _z, _clip_img = sample_fn(text=contexts_low_dim)  # conditioned on the text embedding
        new_samples = unpreprocess(decode(_z, autoencoder))
        # Named `img`, not `sample`, to avoid shadowing this function's name.
        samples.extend(to_pil(img) for img in new_samples)
    return {"samples": samples,
            "attention_maps": attention_maps,
            }


