from .sd_step import *
from .perpneg_utils import weighted_perpendicular_aggregator
from torch.cuda.amp import custom_bwd, custom_fwd
from torchvision.utils import save_image
import torch.nn.functional as F
import torch.nn as nn
import torch
import numpy as np
from audioop import mul
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import StableDiffusionPipeline, DiffusionPipeline, DDPMScheduler, DDIMScheduler, EulerDiscreteScheduler, \
    EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ControlNetModel, \
    DDIMInverseScheduler, UNet2DConditionModel
from diffusers.utils.import_utils import is_xformers_available
from os.path import isfile
from pathlib import Path
import os
import random
from .pipeline_mvdream import MVDreamPipeline
import torchvision.transforms as T
# suppress partial model loading warning
logging.set_verbosity_error()


def rgb2sat(img, T=None):
    """Compute a saturation map over dimension 1 of ``img``.

    Saturation is ``(max - min) / max`` where the maximum and minimum are
    taken along dimension 1 (kept as a size-1 dimension). The maximum is
    offset by 1e-5 before it is used, which keeps the division finite when
    every value along that dimension is zero.

    Args:
        img: Tensor with at least two dimensions.
        T: Optional factor; when not ``None`` the saturation is multiplied
            by ``1 - T``.

    Returns:
        Tensor shaped like ``img`` with dimension 1 reduced to size 1.
    """
    channel_max = img.max(dim=1, keepdim=True).values + 1e-5
    channel_min = img.min(dim=1, keepdim=True).values
    saturation = (channel_max - channel_min) / channel_max
    return saturation if T is None else (1 - T) * saturation


class SpecifyGradient(torch.autograd.Function):
    """Autograd function that injects a precomputed gradient for its input.

    The forward pass ignores the values of ``input_tensor`` and returns a
    one-element tensor of ones; the backward pass hands back the stored
    ``gt_grad`` multiplied by the incoming gradient as the gradient of
    ``input_tensor`` (and ``None`` for ``gt_grad`` itself).
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input_tensor, gt_grad):
        """Store ``gt_grad`` for backward and return a dummy ``[1]`` tensor."""
        ctx.save_for_backward(gt_grad)
        # The output is a placeholder value of 1 on the input's device and
        # dtype; whatever scale is applied to it downstream (e.g. by an AMP
        # grad scaler) arrives in backward() as the incoming gradient.
        return input_tensor.new_ones([1])

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_scale):
        """Return the stored gradient scaled by the incoming gradient."""
        (stored_grad,) = ctx.saved_tensors
        return stored_grad * grad_scale, None


def seed_everything(seed):
    """Seed PyTorch's random number generators with ``seed``.

    Calls ``torch.manual_seed`` and then ``torch.cuda.manual_seed``. The
    cuDNN ``deterministic`` / ``benchmark`` flags are not modified here.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)


class StableDiffusion(nn.Module):
    def __init__(self, device, fp16, vram_O, t_range=[0.02, 0.98], max_t_range=0.98, num_train_timesteps=None,
                 ddim_inv=False, use_control_net=False, textual_inversion_path=None,
                 LoRA_path=None, guidance_opt=None):
        """Load the MVDream pipeline and precompute the timestep bookkeeping.

        Args:
            device: Device the pipeline is moved to; also used for the noise
                generator and the cached tensors created here.
            fp16: When truthy the pipeline is loaded as ``torch.float16``,
                otherwise ``torch.float32``.
            vram_O: When truthy, the memory-saving pipeline options below are
                enabled before the pipeline is moved to ``device``.
            t_range: Two fractions of ``num_train_timesteps`` that become
                ``min_step`` and ``max_step``. The list default is only read,
                never mutated, in this method.
            max_t_range: Upper fraction used to derive ``warmup_step`` as
                ``int(num_train_timesteps * (max_t_range - t_range[1]))``.
            num_train_timesteps: Optional override; falls back to the value in
                the scheduler config when ``None``.
            ddim_inv: Not referenced in this constructor.
            use_control_net: Not referenced in this constructor.
            textual_inversion_path: Not referenced in this constructor.
            LoRA_path: Not referenced in this constructor.
            guidance_opt: Options object; ``.sds`` and ``.noise_seed`` are
                read here. NOTE(review): the default ``None`` cannot work
                because ``guidance_opt.sds`` is accessed unconditionally.
        """
        super().__init__()

        self.device = device
        self.precision_t = torch.float16 if fp16 else torch.float32

        print(f'[INFO] loading stable diffusion...')

        pipe = MVDreamPipeline.from_pretrained(
            'ashawkey/mvdream-sd2.1-diffusers',
            torch_dtype=self.precision_t,
            trust_remote_code=True
        )

        # ISM mode is the default; it is switched off only when the options
        # request plain SDS.
        self.ism = not guidance_opt.sds
        self.scheduler = pipe.scheduler
        # ddim_step is not defined in this file; presumably it comes from the
        # star import of .sd_step — TODO confirm.
        self.sche_func = ddim_step

        if vram_O:
            # NOTE(review): both sequential and model CPU offload are enabled
            # here and the pipeline is still moved to the device right after;
            # confirm these calls are compatible with the installed diffusers
            # version and with MVDreamPipeline.
            pipe.enable_sequential_cpu_offload()
            pipe.enable_vae_slicing()
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.enable_attention_slicing(1)
            pipe.enable_model_cpu_offload()

        pipe = pipe.to(self.device)

        self.pipe = pipe
        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet

        self.num_train_timesteps = num_train_timesteps if num_train_timesteps is not None else self.scheduler.config.num_train_timesteps
        self.scheduler.set_timesteps(self.num_train_timesteps, device=device)

        # Reverse the scheduler's timestep order so that the index arithmetic
        # in the training steps (ind_t / ind_prev_t) can walk it forwards;
        # presumably this makes the list ascending — TODO confirm.
        self.timesteps = torch.flip(self.scheduler.timesteps, dims=(0, ))
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        # Extra index range added on top of max_step, scaled by warm_up_rate
        # in the training steps.
        self.warmup_step = int(
            self.num_train_timesteps*(max_t_range-t_range[1]))

        # Cached noise used by train_step when guidance_opt.fix_noise is set.
        self.noise_temp = None
        # Dedicated generator so noise and timestep sampling are seeded from
        # guidance_opt.noise_seed.
        self.noise_gen = torch.Generator(self.device)
        self.noise_gen.manual_seed(guidance_opt.noise_seed)

        self.alphas = self.scheduler.alphas_cumprod.to(
            self.device)  # for convenience
        # 4x3 matrix applied to latent channels by the lat2rgb helpers in the
        # training steps to obtain an RGB preview for visualisation.
        self.rgb_latent_factors = torch.tensor([
            # R       G       B
            [0.298,  0.207,  0.208],
            [0.187,  0.286,  0.173],
            [-0.158,  0.189,  0.264],
            [-0.184, -0.271, -0.473]
        ], device=self.device)

        print(f'[INFO] loaded stable diffusion!')

    def augmentation(self, tensor):
        """Run ``tensor`` through a random horizontal flip with ``p=0.5``.

        The transform pipeline is rebuilt on every call and currently holds a
        single ``RandomHorizontalFlip``.

        NOTE(review): the training steps in this class flip the rendered
        images with this method but pass the camera matrices on unchanged —
        confirm that is intended.

        Returns:
            The output of the transform pipeline applied to ``tensor``.
        """
        flip_pipeline = T.Compose([T.RandomHorizontalFlip(p=0.5)])
        return flip_pipeline(tensor)

    def normalize_and_flatten_cams(self, camera_matrix):
        camera_matrix = camera_matrix.reshape(-1,4,4)
        translation = camera_matrix[:,:3,3]
        translation = translation / (torch.norm(translation, dim=1, keepdim=True) + 1e-8)
        camera_matrix[:,:3,3] = translation
        camera_matrix = camera_matrix.reshape(-1,16)
        return camera_matrix.flatten(start_dim=1)

    def add_noise_with_cfg(self, latents, noise,
                           ind_t, ind_prev_t,
                           text_embeddings=None, cfg=1.0,
                           delta_t=1, inv_steps=1,
                           is_noisy_latent=False,
                           eta=0.0, cams=None):
        """Advance latents from timestep index ``ind_prev_t`` towards ``ind_t``.

        Starting from ``latents`` (noised at ``self.timesteps[ind_prev_t]``
        unless ``is_noisy_latent`` is set), the loop repeatedly queries the
        U-Net and applies ``self.sche_func`` with a negated timestep delta,
        moving the timestep index up by ``delta_t`` per iteration and never
        past ``ind_t``. It stops after ``inv_steps`` iterations or as soon as
        ``ind_t`` is reached.

        Args:
            latents: Clean latents, or already-noised latents when
                ``is_noisy_latent`` is true.
            noise: Noise passed to ``scheduler.add_noise``; unused when
                ``is_noisy_latent`` is true.
            ind_t: Target index into ``self.timesteps``.
            ind_prev_t: Starting index into ``self.timesteps``.
            text_embeddings: Conditioning embeddings. NOTE(review): the
                ``None`` default is dereferenced immediately, so a value is
                required in practice.
            cfg: Guidance scale; values above 1.0 run a doubled batch and
                combine the two halves, otherwise a single batch is run.
            delta_t: Index increment per iteration.
            inv_steps: Maximum number of iterations.
            is_noisy_latent: Skip the initial ``add_noise`` call when true.
            eta: Forwarded to ``self.sche_func``.
            cams: Camera tensor forwarded to the U-Net. NOTE(review): the
                ``None`` default is not usable; both branches use it.

        Returns:
            Tuple ``(prev_noisy_lat, cur_noisy_lat, pred_scores)`` where
            ``prev_noisy_lat`` is the starting noisy latent,
            ``cur_noisy_lat`` is the latent after the last step, and
            ``pred_scores`` is the list of ``(index, unet_output)`` pairs in
            reverse iteration order (last step first).
        """

        text_embeddings = text_embeddings.to(self.precision_t)
        if cfg <= 1.0:
            # Without guidance only the second of the two embedding groups is
            # used; presumably the layout is [group 0, group 1] stacked along
            # the first dimension — TODO confirm against the caller.
            uncond_text_embedding = text_embeddings.reshape(
                2, -1, text_embeddings.shape[-2], text_embeddings.shape[-1])[1]

        unet = self.unet

        if is_noisy_latent:
            prev_noisy_lat = latents
        else:
            prev_noisy_lat = self.scheduler.add_noise(
                latents, noise, self.timesteps[ind_prev_t])

        cur_ind_t = ind_prev_t
        cur_noisy_lat = prev_noisy_lat

        # Collected as (timestep index, U-Net output) per iteration.
        pred_scores = []

        for i in range(inv_steps):
            # pred noise
            cur_noisy_lat_ = self.scheduler.scale_model_input(
                cur_noisy_lat, self.timesteps[cur_ind_t]).to(self.precision_t)

            if cfg > 1.0:
                # Duplicate the latents (and cameras) so both embedding groups
                # are evaluated in one forward pass.
                latent_model_input = torch.cat(
                    [cur_noisy_lat_, cur_noisy_lat_])
                timestep_model_input = self.timesteps[cur_ind_t].reshape(
                    1, 1).repeat(latent_model_input.shape[0], 1).reshape(-1)

                unet_inputs = {
                    'x': latent_model_input,
                    'timesteps': timestep_model_input.to(self.precision_t),
                    'context': text_embeddings,
                    'num_frames': 4,
                    'camera': torch.cat([cams] * 2).to(self.precision_t),
                }

                unet_output = unet.forward(**unet_inputs)

                uncond, cond = torch.chunk(unet_output, chunks=2)

                # reverse cfg to enhance the distillation
                unet_output = cond + cfg * (uncond - cond)
            else:
                timestep_model_input = self.timesteps[cur_ind_t].reshape(
                    1, 1).repeat(cur_noisy_lat_.shape[0], 1).reshape(-1)


                unet_inputs = {
                    'x': cur_noisy_lat_,
                    'timesteps': timestep_model_input.to(self.precision_t),
                    'context': uncond_text_embedding,
                    'num_frames': 4,
                    'camera': cams.to(self.precision_t),
                }

                unet_output = unet.forward(**unet_inputs)

            pred_scores.append((cur_ind_t, unet_output))

            # Never step past the target index.
            next_ind_t = min(cur_ind_t + delta_t, ind_t)
            cur_t, next_t = self.timesteps[cur_ind_t], self.timesteps[next_ind_t]
            # DDIM schedulers get the delta in timestep units, any other
            # scheduler gets it in index units.
            delta_t_ = next_t - \
                cur_t if isinstance(
                    self.scheduler, DDIMScheduler) else next_ind_t-cur_ind_t

            # The delta is negated before it is handed to the step function.
            cur_noisy_lat = self.sche_func(
                self.scheduler, unet_output, cur_t, cur_noisy_lat, -delta_t_, eta).prev_sample
            cur_ind_t = next_ind_t

            # Drop the local reference (the tensor stays alive in pred_scores)
            # and release cached CUDA memory between steps.
            del unet_output
            torch.cuda.empty_cache()

            if cur_ind_t == ind_t:
                break

        return prev_noisy_lat, cur_noisy_lat, pred_scores[::-1]

    @torch.no_grad()
    def get_text_embeds(self, prompt, resolution=(512, 512)):
        """Tokenize ``prompt`` and return the text encoder's first output.

        The prompt is padded/truncated to the tokenizer's
        ``model_max_length``. ``resolution`` is accepted but not used by this
        method.
        """
        tokenized = self.tokenizer(
            prompt,
            padding='max_length',
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors='pt',
        )
        token_ids = tokenized.input_ids.to(self.device)
        encoder_outputs = self.text_encoder(token_ids)
        return encoder_outputs[0]

    def train_step_perpneg(self, text_embeddings, pred_rgb,
                           grad_scale=1, use_control_net=False,
                           save_folder: Path = None, iteration=0, warm_up_rate=0, weights=0,
                           resolution=(512, 512), guidance_opt=None, as_latent=False, embedding_inverse=None, cams=None):
        """Run one guidance step with perpendicular-negative aggregation.

        The rendered images are flipped at random, encoded to latents, noised
        (either directly or through ``add_noise_with_cfg`` when ``self.ism``
        is set), and passed through the U-Net once per embedding group. The
        resulting guidance gradient is attached to the latents through
        ``SpecifyGradient``.

        Args:
            text_embeddings: Embeddings whose first dimension is ``1 + K``;
                group 0 is used as the unconditional prediction and the
                remaining ``K`` groups as text-conditioned predictions.
            pred_rgb: Rendered images; forwarded to ``encode_imgs``.
            grad_scale: Multiplier applied to the gradient.
            use_control_net: Unused.
            save_folder: Directory for the periodic visualisation image. When
                ``None`` the visualisation is skipped.
            iteration: Current iteration; a visualisation is written when it
                is a multiple of ``guidance_opt.vis_interval``.
            warm_up_rate: Factor used for the annealed step interval and the
                extended upper bound of the sampled timestep index.
            weights: Tensor of per-direction weights, flattened and passed to
                ``weighted_perpendicular_aggregator``. NOTE(review): the
                integer default cannot be reshaped, so a tensor is required.
            resolution: ``(height, width)``; latents are created at one
                eighth of each.
            guidance_opt: Options object. Attributes read here:
                ``annealing_intervals``, ``delta_t``, ``delta_t_start``,
                ``xs_delta_t``, ``xs_inv_steps``, ``xs_eta``,
                ``denoise_guidance_scale``, ``guidance_scale``,
                ``vis_interval``.
            as_latent: Unused.
            embedding_inverse: Embeddings repeated per batch element and used
                for the ``add_noise_with_cfg`` calls.
            cams: Camera matrices; normalised and flattened before use.

        Returns:
            The tensor returned by ``SpecifyGradient.apply``; calling
            ``backward`` on it delivers the guidance gradient to the latents.
        """
        # flip aug
        pred_rgb = self.augmentation(pred_rgb)

        B = pred_rgb.shape[0]
        K = text_embeddings.shape[0] - 1
        lat_h, lat_w = resolution[0] // 8, resolution[1] // 8

        cams = self.normalize_and_flatten_cams(cams)

        latents, _ = self.encode_imgs(pred_rgb.to(self.precision_t))
        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level

        weights = weights.reshape(-1)
        # Seeded per-element noise plus a small per-channel offset that is
        # shared across the batch (the offset does not use self.noise_gen).
        noise = torch.randn((latents.shape[0], 4, lat_h, lat_w), dtype=latents.dtype, device=latents.device,
                            generator=self.noise_gen) + 0.1 * torch.randn((1, 4, 1, 1), device=latents.device).repeat(latents.shape[0], 1, 1, 1)

        inverse_text_embeddings = embedding_inverse.unsqueeze(1).repeat(
            1, B, 1, 1).reshape(-1, embedding_inverse.shape[-2], embedding_inverse.shape[-1])

        # make it k+1, c * t, ...
        text_embeddings = text_embeddings.reshape(
            -1, text_embeddings.shape[-2], text_embeddings.shape[-1])

        if guidance_opt.annealing_intervals:
            current_delta_t = int(guidance_opt.delta_t + np.ceil(
                warm_up_rate * (guidance_opt.delta_t_start - guidance_opt.delta_t)))
        else:
            current_delta_t = guidance_opt.delta_t

        ind_t = torch.randint(self.min_step, self.max_step + int(self.warmup_step * warm_up_rate),
                              (1, ), dtype=torch.long, generator=self.noise_gen, device=self.device)[0]
        # Clamp the previous index at zero.
        ind_prev_t = max(ind_t - current_delta_t, torch.ones_like(ind_t) * 0)

        t = self.timesteps[ind_t]
        prev_t = self.timesteps[ind_prev_t]

        with torch.no_grad():
            # step unroll via ddim inversion
            if not self.ism:
                prev_latents_noisy = self.scheduler.add_noise(
                    latents, noise, prev_t)
                latents_noisy = self.scheduler.add_noise(latents, noise, t)
                target = noise
            else:
                # Step 1: sample x_s with larger steps
                xs_delta_t = guidance_opt.xs_delta_t if guidance_opt.xs_delta_t is not None else current_delta_t
                if guidance_opt.xs_inv_steps is not None:
                    xs_inv_steps = guidance_opt.xs_inv_steps
                else:
                    # Convert the index to a Python int first: np.ceil cannot
                    # consume a tensor that lives on a CUDA device.
                    xs_inv_steps = int(
                        np.ceil(int(ind_prev_t) / xs_delta_t))
                starting_ind = max(ind_prev_t - xs_delta_t *
                                   xs_inv_steps, torch.ones_like(ind_t) * 0)

                _, prev_latents_noisy, pred_scores_xs = self.add_noise_with_cfg(latents, noise, ind_prev_t, starting_ind, inverse_text_embeddings,
                                                                                guidance_opt.denoise_guidance_scale, xs_delta_t, xs_inv_steps, eta=guidance_opt.xs_eta, cams=cams)

                # Step 2: sample x_t
                _, latents_noisy, pred_scores_xt = self.add_noise_with_cfg(prev_latents_noisy, noise, ind_t, ind_prev_t, inverse_text_embeddings,
                                                                           guidance_opt.denoise_guidance_scale, current_delta_t, 1, is_noisy_latent=True, cams=cams)

                # The most recent prediction of the x_t pass comes first and
                # serves as the regression target.
                pred_scores = pred_scores_xt + pred_scores_xs
                target = pred_scores[0][1]

        with torch.no_grad():
            # One copy of the noisy latents per embedding group.
            latent_model_input = latents_noisy[None, :, ...].repeat(
                1 + K, 1, 1, 1, 1).reshape(-1, 4, lat_h, lat_w)
            tt = t.reshape(1, 1).repeat(
                latent_model_input.shape[0], 1).reshape(-1)

            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, tt[0])

            unet_inputs = {
                'x': latent_model_input.to(self.precision_t),
                'timesteps': tt.to(self.precision_t),
                'context': text_embeddings.to(self.precision_t),
                'num_frames': 4,
                'camera': torch.cat([cams] * (1 + K)).to(self.precision_t),
            }

            unet_output = self.unet.forward(**unet_inputs)

            unet_output = unet_output.reshape(1 + K, -1, 4, lat_h, lat_w)
            noise_pred_uncond = unet_output[:1].reshape(-1, 4, lat_h, lat_w)
            noise_pred_text = unet_output[1:].reshape(-1, 4, lat_h, lat_w)
            delta_noise_preds = noise_pred_text - \
                noise_pred_uncond.repeat(K, 1, 1, 1)
            delta_DSD = weighted_perpendicular_aggregator(delta_noise_preds,
                                                          weights,
                                                          B)

        pred_noise = noise_pred_uncond + guidance_opt.guidance_scale * delta_DSD

        # Weighting sqrt((1 - alpha) / alpha) at the sampled timestep.
        alpha_t = self.alphas[t]
        grad = (((1 - alpha_t) / alpha_t) ** 0.5) * (pred_noise - target)

        grad = torch.nan_to_num(grad_scale * grad)
        loss = SpecifyGradient.apply(latents, grad)

        # Skip the preview when no output folder was supplied instead of
        # failing inside os.path.join.
        if save_folder is not None and iteration % guidance_opt.vis_interval == 0:
            noise_pred_post = noise_pred_uncond + guidance_opt.guidance_scale * delta_DSD

            def lat2rgb(x): return torch.clip((x.permute(
                0, 2, 3, 1) @ self.rgb_latent_factors.to(x.dtype)).permute(0, 3, 1, 2), 0., 1.)
            save_path_iter = os.path.join(
                save_folder, "iter_{}_step_{}.jpg".format(iteration, prev_t.item()))
            with torch.no_grad():
                pred_x0_latent_sp = pred_original(
                    self.scheduler, noise_pred_uncond, prev_t, prev_latents_noisy)
                pred_x0_latent_pos = pred_original(
                    self.scheduler, noise_pred_post, prev_t, prev_latents_noisy)
                pred_x0_pos = self.decode_latents(
                    pred_x0_latent_pos.type(self.precision_t))
                pred_x0_sp = self.decode_latents(
                    pred_x0_latent_sp.type(self.precision_t))

                grad_abs = torch.abs(grad.detach())
                norm_grad = F.interpolate((grad_abs / grad_abs.max()).mean(dim=1, keepdim=True),
                                          (resolution[0], resolution[1]), mode='bilinear', align_corners=False).repeat(1, 3, 1, 1)

                latents_rgb = F.interpolate(lat2rgb(
                    latents), (resolution[0], resolution[1]), mode='bilinear', align_corners=False)
                latents_sp_rgb = F.interpolate(lat2rgb(
                    pred_x0_latent_sp), (resolution[0], resolution[1]), mode='bilinear', align_corners=False)

                viz_images = torch.cat([pred_rgb,
                                        latents_rgb, latents_sp_rgb,
                                        norm_grad,
                                        pred_x0_sp, pred_x0_pos], dim=0)
                save_image(viz_images, save_path_iter)

        return loss

    def train_step(self, text_embeddings, pred_rgb,
                   grad_scale=1, use_control_net=False,
                   save_folder: Path = None, iteration=0, warm_up_rate=0,
                   resolution=(512, 512), guidance_opt=None, as_latent=False, embedding_inverse=None, cams=None):
        """Run one guidance step with two embedding groups.

        The rendered images are flipped at random, encoded to latents, noised
        (either directly or through ``add_noise_with_cfg`` when ``self.ism``
        is set), and passed through the U-Net for both embedding groups. The
        resulting guidance gradient is attached to the latents through
        ``SpecifyGradient``.

        Args:
            text_embeddings: Embeddings for two groups; the first half of the
                U-Net batch is used as the unconditional prediction and the
                second half as the text-conditioned one.
            pred_rgb: Rendered images; forwarded to ``encode_imgs``.
            grad_scale: Multiplier applied to the gradient.
            use_control_net: Unused.
            save_folder: Directory for the periodic visualisation image. When
                ``None`` the visualisation is skipped.
            iteration: Current iteration; a visualisation is written when it
                is a multiple of ``guidance_opt.vis_interval``.
            warm_up_rate: Factor used for the annealed step interval and the
                extended upper bound of the sampled timestep index.
            resolution: ``(height, width)``; latents are created at one
                eighth of each.
            guidance_opt: Options object. Attributes read here:
                ``fix_noise``, ``annealing_intervals``, ``delta_t``,
                ``delta_t_start``, ``xs_delta_t``, ``xs_inv_steps``,
                ``xs_eta``, ``denoise_guidance_scale``, ``guidance_scale``,
                ``vis_interval``.
            as_latent: Unused.
            embedding_inverse: Embeddings repeated per batch element and used
                for the ``add_noise_with_cfg`` calls.
            cams: Camera matrices; normalised and flattened before use.

        Returns:
            The tensor returned by ``SpecifyGradient.apply``; calling
            ``backward`` on it delivers the guidance gradient to the latents.
        """
        pred_rgb = self.augmentation(pred_rgb)

        B = pred_rgb.shape[0]
        lat_h, lat_w = resolution[0] // 8, resolution[1] // 8

        cams = self.normalize_and_flatten_cams(cams)

        latents, _ = self.encode_imgs(pred_rgb.to(self.precision_t))
        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level

        def draw_noise():
            # Seeded per-element noise plus a small per-channel offset that is
            # shared across the batch (the offset does not use self.noise_gen).
            return torch.randn((latents.shape[0], 4, lat_h, lat_w), dtype=latents.dtype, device=latents.device,
                               generator=self.noise_gen) + 0.1 * torch.randn((1, 4, 1, 1), device=latents.device).repeat(latents.shape[0], 1, 1, 1)

        # The cached noise is drawn on the first call regardless of fix_noise,
        # which keeps the generator's stream the same in both modes.
        if self.noise_temp is None:
            self.noise_temp = draw_noise()

        if guidance_opt.fix_noise:
            noise = self.noise_temp
        else:
            noise = draw_noise()

        # make it k+1, c * t, ...
        text_embeddings = text_embeddings.reshape(
            -1, text_embeddings.shape[-2], text_embeddings.shape[-1])

        inverse_text_embeddings = embedding_inverse.unsqueeze(1).repeat(
            1, B, 1, 1).reshape(-1, embedding_inverse.shape[-2], embedding_inverse.shape[-1])

        if guidance_opt.annealing_intervals:
            current_delta_t = int(guidance_opt.delta_t + (warm_up_rate)
                                  * (guidance_opt.delta_t_start - guidance_opt.delta_t))
        else:
            current_delta_t = guidance_opt.delta_t

        ind_t = torch.randint(self.min_step, self.max_step + int(self.warmup_step * warm_up_rate),
                              (1, ), dtype=torch.long, generator=self.noise_gen, device=self.device)[0]
        # Clamp the previous index at zero.
        ind_prev_t = max(ind_t - current_delta_t, torch.ones_like(ind_t) * 0)

        t = self.timesteps[ind_t]
        prev_t = self.timesteps[ind_prev_t]

        with torch.no_grad():
            # step unroll via ddim inversion
            if not self.ism:
                prev_latents_noisy = self.scheduler.add_noise(
                    latents, noise, prev_t)
                latents_noisy = self.scheduler.add_noise(latents, noise, t)
                target = noise
            else:
                # Step 1: sample x_s with larger steps
                xs_delta_t = guidance_opt.xs_delta_t if guidance_opt.xs_delta_t is not None else current_delta_t
                if guidance_opt.xs_inv_steps is not None:
                    xs_inv_steps = guidance_opt.xs_inv_steps
                else:
                    # Convert the index to a Python int first: np.ceil cannot
                    # consume a tensor that lives on a CUDA device.
                    xs_inv_steps = int(
                        np.ceil(int(ind_prev_t) / xs_delta_t))
                starting_ind = max(ind_prev_t - xs_delta_t *
                                   xs_inv_steps, torch.ones_like(ind_t) * 0)

                _, prev_latents_noisy, pred_scores_xs = self.add_noise_with_cfg(latents, noise, ind_prev_t, starting_ind, inverse_text_embeddings,
                                                                                guidance_opt.denoise_guidance_scale, xs_delta_t, xs_inv_steps, eta=guidance_opt.xs_eta, cams=cams)
                # Step 2: sample x_t
                _, latents_noisy, pred_scores_xt = self.add_noise_with_cfg(prev_latents_noisy, noise, ind_t, ind_prev_t, inverse_text_embeddings,
                                                                           guidance_opt.denoise_guidance_scale, current_delta_t, 1, is_noisy_latent=True, cams=cams)

                # The most recent prediction of the x_t pass comes first and
                # serves as the regression target.
                pred_scores = pred_scores_xt + pred_scores_xs
                target = pred_scores[0][1]

        with torch.no_grad():
            # One copy of the noisy latents per embedding group.
            latent_model_input = latents_noisy[None, :, ...].repeat(
                2, 1, 1, 1, 1).reshape(-1, 4, lat_h, lat_w)
            tt = t.reshape(1, 1).repeat(
                latent_model_input.shape[0], 1).reshape(-1)

            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, tt[0])

            unet_inputs = {
                'x': latent_model_input.to(self.precision_t),
                'timesteps': tt.to(self.precision_t),
                'context': text_embeddings.to(self.precision_t),
                'num_frames': 4,
                'camera': torch.cat([cams] * 2).to(self.precision_t),
            }

            unet_output = self.unet.forward(**unet_inputs)

            unet_output = unet_output.reshape(2, -1, 4, lat_h, lat_w)
            noise_pred_uncond = unet_output[:1].reshape(-1, 4, lat_h, lat_w)
            noise_pred_text = unet_output[1:].reshape(-1, 4, lat_h, lat_w)
            delta_DSD = noise_pred_text - noise_pred_uncond

        pred_noise = noise_pred_uncond + guidance_opt.guidance_scale * delta_DSD

        # Weighting sqrt((1 - alpha) / alpha) at the sampled timestep.
        alpha_t = self.alphas[t]
        grad = (((1 - alpha_t) / alpha_t) ** 0.5) * (pred_noise - target)

        grad = torch.nan_to_num(grad_scale * grad)
        loss = SpecifyGradient.apply(latents, grad)

        # Skip the preview when no output folder was supplied instead of
        # failing inside os.path.join.
        if save_folder is not None and iteration % guidance_opt.vis_interval == 0:
            # NOTE(review): the preview uses a fixed scale of 7.5 while
            # train_step_perpneg uses guidance_opt.guidance_scale — confirm
            # which one is intended.
            noise_pred_post = noise_pred_uncond + 7.5 * delta_DSD

            def lat2rgb(x): return torch.clip((x.permute(
                0, 2, 3, 1) @ self.rgb_latent_factors.to(x.dtype)).permute(0, 3, 1, 2), 0., 1.)
            save_path_iter = os.path.join(
                save_folder, "iter_{}_step_{}.jpg".format(iteration, prev_t.item()))
            with torch.no_grad():
                pred_x0_latent_sp = pred_original(
                    self.scheduler, noise_pred_uncond, prev_t, prev_latents_noisy)
                pred_x0_latent_pos = pred_original(
                    self.scheduler, noise_pred_post, prev_t, prev_latents_noisy)
                pred_x0_pos = self.decode_latents(
                    pred_x0_latent_pos.type(self.precision_t))
                pred_x0_sp = self.decode_latents(
                    pred_x0_latent_sp.type(self.precision_t))

                grad_abs = torch.abs(grad.detach())
                norm_grad = F.interpolate((grad_abs / grad_abs.max()).mean(dim=1, keepdim=True),
                                          (resolution[0], resolution[1]), mode='bilinear', align_corners=False).repeat(1, 3, 1, 1)

                latents_rgb = F.interpolate(lat2rgb(
                    latents), (resolution[0], resolution[1]), mode='bilinear', align_corners=False)
                latents_sp_rgb = F.interpolate(lat2rgb(
                    pred_x0_latent_sp), (resolution[0], resolution[1]), mode='bilinear', align_corners=False)

                viz_images = torch.cat([pred_rgb,
                                        latents_rgb, latents_sp_rgb, norm_grad,
                                        pred_x0_sp, pred_x0_pos], dim=0)
                save_image(viz_images, save_path_iter)

        return loss

    def decode_latents(self, latents):
        """Decode latents with the VAE and map the result into ``[0, 1]``.

        The latents are divided by the VAE's ``scaling_factor``, cast to the
        VAE's dtype for decoding, rescaled with ``x / 2 + 0.5``, clamped to
        ``[0, 1]`` and finally cast back to the dtype of the input.
        """
        out_dtype = latents.dtype
        unscaled = latents / self.vae.config.scaling_factor
        decoded = self.vae.decode(unscaled.to(self.vae.dtype)).sample
        images = (decoded / 2 + 0.5).clamp(0, 1)
        return images.to(out_dtype)

    def encode_imgs(self, imgs):
        """Encode images with the VAE and return sampled latents plus the KL term.

        The images are mapped with ``2 * x - 1`` and cast to the VAE's dtype
        before encoding. A sample is drawn from the resulting posterior and
        multiplied by the VAE's ``scaling_factor``.

        Args:
            imgs: Image batch; the original comment describes it as
                ``[B, 3, H, W]``.

        Returns:
            Tuple ``(latents, kl_divergence)`` where ``latents`` is cast back
            to the dtype of ``imgs`` and ``kl_divergence`` is the value of
            ``posterior.kl()``.
        """
        out_dtype = imgs.dtype
        vae_input = (2 * imgs - 1).to(self.vae.dtype)
        latent_dist = self.vae.encode(vae_input).latent_dist
        kl_divergence = latent_dist.kl()
        scaled_sample = latent_dist.sample() * self.vae.config.scaling_factor
        return scaled_sample.to(out_dtype), kl_divergence
