import copy
import random

import PIL
from PIL import Image
from tqdm import tqdm

from ldm.util import instantiate_from_config, load_and_preprocess
from omegaconf import OmegaConf
import torch
from torchvision import transforms
import math
import numpy as np
from ldm.modules.diffusionmodules.util import noise_like

def cal_xt_with_skip_noise(sampler, x, noise, start_idx=0, end_idx=16, device='cuda'):
    """Diffuse a latent forward to the noise level of DDIM step ``end_idx``.

    Let ``a_k = sampler.ddim_alphas[k]``, the same alpha table that
    ``cal_x0_prev_x`` uses for ``a_t``.

    * ``start_idx == 0``: ``x`` is treated as a clean latent and
      ``x_t = sqrt(a_end) * x + sqrt(1 - a_end) * noise``.
    * ``start_idx > 0``: ``x`` is treated as a latent already at DDIM step
      ``start_idx`` and is moved forward with
      ``x_t = sqrt(a_end / a_start) * x + sqrt(1 - a_end / a_start) * noise``.
      Previously this case silently returned unrelated random noise.

    Args:
        sampler: DDIM sampler exposing ``ddim_alphas``.
        x: latent to noise; its first dim is the batch size.
        noise: Gaussian noise broadcastable to ``x``.
        start_idx: DDIM step index ``x`` is currently at (0 = clean).
        end_idx: DDIM step index to noise ``x`` to.
        device: device for the computation and the result.

    Returns:
        The noised latent on ``device``.

    Raises:
        ValueError: if ``start_idx > 0`` and ``end_idx < start_idx``. Going
            backwards would need a variance ``1 - a_end / a_start < 0``.
    """
    x = x.to(device)
    noise = noise.to(device)
    b = x.shape[0]
    alphas = sampler.ddim_alphas
    # select parameters corresponding to the target timestep
    a_t = torch.full((b, 1, 1, 1), alphas[end_idx], device=device)
    if start_idx == 0:
        return x * a_t.sqrt() + (1 - a_t).sqrt() * noise

    if end_idx < start_idx:
        raise ValueError(
            f'end_idx ({end_idx}) must be >= start_idx ({start_idx}) to add noise')
    a_start = torch.full((b, 1, 1, 1), alphas[start_idx], device=device)
    # Transition from step start_idx to step end_idx
    # (ratio of cumulative alphas, per the DDPM forward process).
    ratio = a_t / a_start
    return x * ratio.sqrt() + (1 - ratio).sqrt() * noise

'''
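From the noise predicted by the UNet, compute the noisy latent for the next sampling step and the denoised (x0) estimate for the current step.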
'''
def cal_x0_prev_x(sampler, x, e_t, index, device='cpu'):
    """Apply one DDIM update using the UNet's noise prediction ``e_t``.

    All schedule coefficients are read at position ``index`` of the sampler's
    DDIM tables and broadcast as ``(batch, 1, 1, 1)`` tensors.

    Args:
        sampler: DDIM sampler exposing ``ddim_alphas``, ``ddim_alphas_prev``,
            ``ddim_sigmas`` and ``ddim_sqrt_one_minus_alphas``.
        x: current noisy latent; first dim is the batch size.
        e_t: predicted noise, same shape as ``x``.
        index: position in the DDIM tables for the current step.
        device: device on which everything is computed.

    Returns:
        ``(x_prev, pred_x0)``, computed as follows:
        ``pred_x0 = (x - sqrt(1 - a_t) * e_t) / sqrt(a_t)``.
        ``x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma^2) * e_t + sigma * z``.
        ``z`` is fresh noise from ``noise_like``.
    """
    x = x.to(device)
    e_t = e_t.to(device)
    batch = x.shape[0]

    def _coef(table):
        # Broadcastable per-sample copy of table[index].
        return torch.full((batch, 1, 1, 1), table[index], device=device)

    alpha_t = _coef(sampler.ddim_alphas)
    alpha_prev = _coef(sampler.ddim_alphas_prev)
    sigma = _coef(sampler.ddim_sigmas)
    sqrt_one_minus_alpha_t = _coef(sampler.ddim_sqrt_one_minus_alphas)

    # Estimate of the clean latent implied by the noise prediction.
    pred_x0 = (x - sqrt_one_minus_alpha_t * e_t) / alpha_t.sqrt()

    # Deterministic direction towards x_t plus the stochastic (sigma) term.
    direction = (1. - alpha_prev - sigma ** 2).sqrt() * e_t
    stochastic = sigma * noise_like(x.shape, device) * 1.

    x_prev = alpha_prev.sqrt() * pred_x0 + direction + stochastic
    return x_prev, pred_x0

'''
Image preprocessing.
'''
def img_preprocess(input_im, carvekit_model):
    """Turn a raw input image into a normalized CHW float tensor.

    The image is first passed through ``load_and_preprocess`` together with
    the carvekit model. It is then divided by 255, cast to float32, converted
    with ``transforms.ToTensor()`` and finally mapped by ``v * 2 - 1``.

    NOTE(review): the final range is [-1, 1] only if ``load_and_preprocess``
    returns 0-255 pixel values — confirm.

    Args:
        input_im: raw image accepted by ``load_and_preprocess``.
        carvekit_model: carvekit interface used by ``load_and_preprocess``.

    Returns:
        The preprocessed image tensor.
    """
    pixels = load_and_preprocess(carvekit_model, input_im)
    unit_range = (pixels / 255.0).astype(np.float32)
    tensor = transforms.ToTensor()(unit_range)
    return tensor * 2 - 1


'''
Load the pretrained Zero-1-to-3 model from a .ckpt checkpoint file.
'''
def load_zero_123(ckpt_path='105000.ckpt', config_path='configs/sd-objaverse-finetune-c_concat-256.yaml',
                  device='cuda'):
    """Instantiate the Zero-1-to-3 diffusion model from a config and load its weights.

    Args:
        ckpt_path: checkpoint file whose ``'state_dict'`` entry holds the weights.
        config_path: OmegaConf YAML whose ``model`` node is passed to
            ``instantiate_from_config``.
        device: device the returned model is moved to.

    Returns:
        The model with the checkpoint weights loaded (non-strictly), on ``device``.
        Missing or unexpected keys are reported on stdout instead of being ignored.
    """
    config = OmegaConf.load(config_path)
    print(f'Loading model from {ckpt_path}')
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    # The checkpoint is read on the CPU so that the whole file (which may hold
    # more than the state dict) is not materialised on the accelerator; the
    # model is moved to `device` at the end anyway.
    model_dict = torch.load(ckpt_path, map_location='cpu')
    sd = model_dict['state_dict']
    diffusion_model = instantiate_from_config(config.model)
    missing, unexpected = diffusion_model.load_state_dict(sd, strict=False)
    # strict=False tolerates mismatches; surface them rather than drop them silently.
    if missing:
        print(f'missing keys ({len(missing)}): {missing}')
    if unexpected:
        print(f'unexpected keys ({len(unexpected)}): {unexpected}')
    del model_dict, sd
    print('whole model: ', diffusion_model.__class__.__name__)
    return diffusion_model.to(device)

'''
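Essentially the same as in Zero-1-to-3: build the cross-attention ("text") conditioning and the image (concat) conditioning from the input image and camera angles.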
'''
def get_context(model, input_image, cam_angle, guidance_scale, device):
    """Build the Zero-1-to-3 conditioning for a batch of images and camera poses.

    Args:
        model: diffusion model exposing ``cond_stage_model``, ``cc_projection``
            and ``encode_first_stage`` (whose result has ``.mode()``).
        input_image: image batch; its first dim is the batch size ``b``.
        cam_angle: iterable of ``b`` poses. Each is indexed as
            ``angle[0]`` (degrees, used directly in radians), ``angle[1]``
            (degrees, encoded as sin/cos) and ``angle[2]`` (used as-is).
            Presumably these are (polar, azimuth, radius) as in Zero-1-to-3;
            confirm this.
        guidance_scale: if exactly 1, only the conditional context is returned.
            Otherwise an all-zero unconditional context is prepended along the
            batch dim for classifier-free guidance.
        device: device the input image is moved to.

    Returns:
        ``(context, img_context)``: the cross-attention context and the concat
        (image-latent) context. Each has batch ``b`` when ``guidance_scale == 1``.
        Otherwise each has batch ``2b``, ordered ``[uncond; cond]``.
    """
    input_image = input_image.to(device)

    with torch.no_grad():
        # tile(1, 1, 1) prepends a unit dim to a 2-D embedding; the permute then
        # moves the batch back to the front, e.g. (b, D) -> (1, b, D) -> (b, 1, D).
        c = model.cond_stage_model(input_image).tile(1, 1, 1).permute(1, 0, 2)
        T = torch.stack([
            torch.tensor([math.radians(angle[0]),
                          math.sin(math.radians(angle[1])),
                          math.cos(math.radians(angle[1])),
                          angle[2]])
            for angle in cam_angle
        ], dim=0).unsqueeze(1).to(c.device)  # (b, 1, 4)
        c = model.cc_projection(torch.cat([c, T], dim=-1))  # embedding + pose -> projected embedding

        # clone(): .mode() may return a view into a larger encoder output; return an
        # independent tensor like the original concatenation did.
        img_latent = model.encode_first_stage(input_image.to(c.device)).mode().clone()

        if guidance_scale == 1:
            # no unconditional context
            return c, img_latent

        # Unconditional branch: zeros with exactly the shape/dtype/device of the
        # conditional tensors (no hard-coded latent channel count or downsampling).
        uc = torch.zeros_like(c)
        uc_latent = torch.zeros_like(img_latent)
        return torch.cat([uc, c], 0), torch.cat([uc_latent, img_latent], 0)


'''
Sample and denoise step by step with the UNet following the DDIM schedule.
'''
def zero123_sampler(unet, sampler, device,
                    context, img_context, init_timesteps=1000, dtype=torch.float32,
                    return_images=False, guidance_scale=1, num_steps=16,
                    train_sampler=True, num_steps_eval=3):
    """Run DDIM sampling with the Zero-1-to-3 UNet, starting from Gaussian noise.

    Args:
        unet: noise predictor called as ``unet(x=..., timesteps=..., context=...)``.
            ``x`` is the current latent concatenated channel-wise with ``img_context``.
        sampler: DDIM sampler providing ``make_schedule`` and the ``ddim_*`` tables.
            ``sampler.model.decode_first_stage`` is used when ``return_images`` is set.
        device: device for latents and conditioning.
        context: cross-attention conditioning, as produced by ``get_context``.
        img_context: concat conditioning, as produced by ``get_context``.
        init_timesteps: unused; kept for backward compatibility.
        dtype: dtype of the UNet input. The UNet output is cast back to float32.
        return_images: if True, decode the final ``pred_x0`` to images.
        guidance_scale: classifier-free guidance scale. When it is not 1, the
            conditioning must be the ``[uncond; cond]`` batch from ``get_context``.
        num_steps: DDIM schedule length used when ``train_sampler`` is True.
        train_sampler: if True, the loop runs in the caller's autograd mode, so
            gradients can flow. If False, it runs under ``torch.no_grad()``
            with ``num_steps_eval`` steps.
        num_steps_eval: DDIM schedule length used when ``train_sampler`` is False.

    Returns:
        Decoded images if ``return_images``, else the final predicted ``x0``
        latent. Both are returned as float32.
    """
    import contextlib

    guided = guidance_scale != 1
    # With guidance the conditioning batch is [uncond; cond] (2b); without it, b.
    b = img_context.shape[0] // 2 if guided else img_context.shape[0]
    # Latent spatial size must match img_context for the channel-wise concat.
    img_shape = (b, 4, *img_context.shape[2:])
    img_context = img_context.to(device)
    context = context.to(device)

    sampler.make_schedule(ddim_num_steps=num_steps if train_sampler else num_steps_eval, ddim_eta=1.0)
    timesteps = sampler.ddim_timesteps
    if not train_sampler:
        print('Sampling timesteps: ', timesteps.shape)
    # Drop the last schedule entry (presumably the largest timestep); the loop
    # index below stays aligned with the DDIM tables for the remaining steps.
    timesteps = timesteps[: -1]
    time_range = np.flip(timesteps)
    total_steps = timesteps.shape[0]
    print(sampler.ddim_sigmas)

    # Training keeps the caller's grad mode (original behaviour); eval disables autograd.
    grad_ctx = contextlib.nullcontext() if train_sampler else torch.no_grad()
    D_x = torch.randn(img_shape).to(device)
    pred_x0 = torch.zeros(img_shape).to(device)
    with grad_ctx:
        for i, step in enumerate(tqdm(time_range, desc='DDIM Sampler', total=total_steps)):
            index = total_steps - i - 1
            ts = torch.full((b,), step, dtype=torch.long, device=device)
            if guided:
                x_in, ts_in = torch.cat([D_x, D_x]), torch.cat([ts, ts])
            else:
                x_in, ts_in = D_x, ts
            unet_input = torch.cat((x_in, img_context), dim=1)
            noise_pred = unet(x=unet_input.to(dtype), timesteps=ts_in, context=context).to(torch.float32)
            if guided:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            D_x, pred_x0 = cal_x0_prev_x(sampler, D_x, noise_pred, index=index, device=img_context.device)

    if return_images:
        images = sampler.model.decode_first_stage(pred_x0.to(device))
        return images.to(torch.float32)
    return pred_x0.to(torch.float32)


'''
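Run UNet DDIM sampling and return the intermediate latent after every step; used to cache results for the reference camera angle.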
'''
def zero123_sampler_noise_by_step(unet, sampler, device,
                                  context, img_context, dtype=torch.float32, guidance_scale=1, num_steps=16):
    """Run a full no-grad DDIM sampling pass and record the latent after every step.

    Args:
        unet: noise predictor called as ``unet(x=..., timesteps=..., context=...)``.
        sampler: DDIM sampler providing ``make_schedule`` and the ``ddim_*`` tables.
        device: device for latents and conditioning.
        context: cross-attention conditioning, as produced by ``get_context``.
        img_context: concat conditioning, as produced by ``get_context``.
        dtype: dtype of the UNet input. The UNet output is cast back to float32.
        guidance_scale: classifier-free guidance scale. When it is not 1, the
            conditioning must be the ``[uncond; cond]`` batch from ``get_context``.
        num_steps: DDIM schedule length.

    Returns:
        A list with one latent per executed step: the ``x_prev`` returned by
        ``cal_x0_prev_x``. The list is ordered from the noisiest step to the last.
    """
    guided = guidance_scale != 1
    # With guidance the conditioning batch is [uncond; cond] (2b); without it, b.
    b = img_context.shape[0] // 2 if guided else img_context.shape[0]
    # Latent spatial size must match img_context for the channel-wise concat.
    img_shape = (b, 4, *img_context.shape[2:])
    img_context = img_context.to(device)
    context = context.to(device)

    sampler.make_schedule(ddim_num_steps=num_steps, ddim_eta=1.0)
    timesteps = sampler.ddim_timesteps
    print('Sampling timesteps: ', timesteps.shape)
    # Drop the last schedule entry (presumably the largest timestep).
    timesteps = timesteps[: -1]
    time_range = np.flip(timesteps)
    total_steps = timesteps.shape[0]
    print(sampler.ddim_sigmas)

    latents_by_step = []
    D_x = torch.randn(img_shape).to(device)
    with torch.no_grad():
        for i, step in enumerate(tqdm(time_range, desc='DDIM Sampler', total=total_steps)):
            index = total_steps - i - 1
            ts = torch.full((b,), step, dtype=torch.long, device=device)
            if guided:
                x_in, ts_in = torch.cat([D_x, D_x]), torch.cat([ts, ts])
            else:
                x_in, ts_in = D_x, ts
            unet_input = torch.cat((x_in, img_context), dim=1)
            noise_pred = unet(x=unet_input.to(dtype), timesteps=ts_in, context=context).to(torch.float32)
            if guided:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            D_x, _ = cal_x0_prev_x(sampler, D_x, noise_pred, index=index, device=img_context.device)
            latents_by_step.append(D_x)
    return latents_by_step


'''
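Skip-step sampling: skip the specified number of steps, then finish the remaining sampling steps starting from the interpolated noise alter_noise.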
'''
def zero123_sampler_noise_skip(unet, sampler, device, skip_start_idx, num_skip_steps, alter_noise, return_images,
                               context, img_context, dtype=torch.float32, guidance_scale=1, num_steps=32, train_sampler=False):
    """DDIM sampling that starts from ``alter_noise`` and skips a window of steps.

    The loop positions ``i`` with ``skip_start_idx <= i < skip_start_idx + num_skip_steps``
    are not executed, and the latent is carried through them unchanged. All
    other steps are normal DDIM updates.

    Args:
        unet: noise predictor called as ``unet(x=..., timesteps=..., context=...)``.
        sampler: DDIM sampler providing ``make_schedule`` and the ``ddim_*`` tables.
            ``sampler.model.decode_first_stage`` is used when ``return_images`` is set.
        device: device for latents and conditioning.
        skip_start_idx: first loop position to skip.
        num_skip_steps: number of consecutive loop positions to skip.
        alter_noise: starting latent for the loop (described elsewhere as
            interpolated noise).
        return_images: if True, decode the final ``pred_x0`` to images.
        context: cross-attention conditioning, as produced by ``get_context``.
        img_context: concat conditioning, as produced by ``get_context``.
        dtype: dtype of the UNet input. The UNet output is cast back to float32.
        guidance_scale: classifier-free guidance scale. When it is not 1, the
            conditioning must be the ``[uncond; cond]`` batch from ``get_context``.
        num_steps: DDIM schedule length.
        train_sampler: if True, run in the caller's autograd mode; otherwise
            run under ``torch.no_grad()``.

    Returns:
        Decoded images if ``return_images``, else the final predicted ``x0``
        latent, both as float32. If every step is skipped, the result comes
        from an all-zero ``x0``.
    """
    import contextlib

    guided = guidance_scale != 1
    # With guidance the conditioning batch is [uncond; cond] (2b); without it, b.
    b = img_context.shape[0] // 2 if guided else img_context.shape[0]
    # Latent spatial size must match img_context for the channel-wise concat.
    img_shape = (b, 4, *img_context.shape[2:])
    img_context = img_context.to(device)
    context = context.to(device)

    sampler.make_schedule(ddim_num_steps=num_steps, ddim_eta=1.0)
    # Drop the last schedule entry (presumably the largest timestep).
    timesteps = sampler.ddim_timesteps[: -1]
    time_range = np.flip(timesteps)
    total_steps = timesteps.shape[0]
    skip_end_idx = skip_start_idx + num_skip_steps

    D_x = alter_noise.to(device)
    pred_x0 = torch.zeros(img_shape).to(device)
    # Training keeps the caller's grad mode (original behaviour); eval disables autograd.
    grad_ctx = contextlib.nullcontext() if train_sampler else torch.no_grad()
    with grad_ctx:
        for i, step in enumerate(tqdm(time_range, desc='DDIM Sampler', total=total_steps)):
            if skip_start_idx <= i < skip_end_idx:
                continue
            index = total_steps - i - 1
            ts = torch.full((b,), step, dtype=torch.long, device=device)
            if guided:
                x_in, ts_in = torch.cat([D_x, D_x]), torch.cat([ts, ts])
            else:
                x_in, ts_in = D_x, ts
            unet_input = torch.cat((x_in, img_context), dim=1)
            noise_pred = unet(x=unet_input.to(dtype), timesteps=ts_in, context=context).to(torch.float32)
            if guided:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            D_x, pred_x0 = cal_x0_prev_x(sampler, D_x, noise_pred, index=index, device=img_context.device)

    if return_images:
        images = sampler.model.decode_first_stage(pred_x0.to(device))
        return images.to(torch.float32)
    return pred_x0.to(torch.float32)

'''
Pick one random timestep and run a single UNet denoising step; used by the score network.
'''
def zero123_denoise(
        unet, sampler, sampler_noise, noise,
        context, img_context, timesteps, device,
        dtype=torch.float32, predict_x0=True, guidance_scale=1, num_steps=32
):
    """Noise a latent to one randomly chosen DDIM step and run a single UNet prediction.

    Args:
        unet: noise predictor called as ``unet(x=..., timesteps=..., context=...)``.
        sampler: DDIM sampler providing ``make_schedule`` and the ``ddim_*`` tables.
        sampler_noise: latent to be noised. It is moved to ``device`` and marked
            ``requires_grad_(True)``; this is in place if it is already on ``device``.
        noise: Gaussian noise with the shape of ``sampler_noise``.
        context: cross-attention conditioning, as produced by ``get_context``.
        img_context: concat conditioning, as produced by ``get_context``.
        timesteps: ignored; the step is drawn uniformly at random from the DDIM
            schedule. Kept for backward compatibility.
        device: device for the computation.
        dtype: dtype of the UNet input. The UNet output is cast to float32.
        predict_x0: if True, return the ``x0`` estimate; otherwise return the
            (guided) noise prediction.
        guidance_scale: classifier-free guidance scale. When it is not 1, the
            conditioning must be the ``[uncond; cond]`` batch from ``get_context``.
        num_steps: DDIM schedule length (previously hard-coded to 32).

    Returns:
        The ``pred_x0`` or ``noise_pred`` tensor, in float32.
    """
    sampler_noise = sampler_noise.to(device)
    sampler_noise.requires_grad_(True)
    img_context = img_context.to(device)
    context = context.to(device)

    sampler.make_schedule(ddim_num_steps=num_steps, ddim_eta=1.0)
    # Drop the last schedule entry (presumably the largest timestep).
    ddim_timesteps = sampler.ddim_timesteps[: -1]
    time_range = np.flip(ddim_timesteps)
    total_steps = ddim_timesteps.shape[0]
    # Uniformly pick one schedule position (same single draw as random.choice over it).
    i = random.randrange(total_steps)
    step = time_range[i]
    index = total_steps - i - 1

    b = sampler_noise.shape[0]
    ts = torch.full((b,), step, dtype=torch.long, device=device)
    a_prev = torch.full((b, 1, 1, 1), sampler.ddim_alphas_prev[index], device=device)
    sigma_t = torch.full((b, 1, 1, 1), sampler.ddim_sigmas[index], device=device)
    sqrt_one_minus_at = torch.full((b, 1, 1, 1), sampler.ddim_sqrt_one_minus_alphas[index], device=device)
    # NOTE(review): this is not the standard forward process
    # x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps for the chosen step. The signal
    # is scaled by sqrt(a_prev) and the noise by sigma_t * sqrt(1 - a_t). It is
    # kept unchanged to preserve the training signal; confirm this is intended.
    noisy_latent = a_prev.sqrt() * sampler_noise + sqrt_one_minus_at * (sigma_t * noise.to(device))

    guided = guidance_scale != 1
    if guided:
        x_in, ts_in = torch.cat([noisy_latent] * 2), torch.cat([ts, ts])
    else:
        x_in, ts_in = noisy_latent, ts
    unet_input = torch.cat((x_in, img_context), dim=1)
    noise_pred = unet(x=unet_input.to(dtype), timesteps=ts_in, context=context).to(torch.float32)
    if guided:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    print(
        f'sampler noise requires grad:{noisy_latent.requires_grad}, noise pred requires grad:{noise_pred.requires_grad}')
    if not predict_x0:
        return noise_pred.to(torch.float32)
    # Only pay for the DDIM update when the x0 estimate is actually requested.
    _, pred_x0 = cal_x0_prev_x(sampler=sampler, x=noisy_latent, e_t=noise_pred, index=index,
                               device=img_context.device)
    return pred_x0.to(torch.float32)
