import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from diffusers import UniPCMultistepScheduler, DDIMScheduler

from PIL import Image
from .sc_adapter.selective_adapter_norm_detailed_clip import SC_Adapter
from .sc_adapter.pipeline_adapter_plus import StableDiffusionPipeline


class SpecifyGradient(torch.autograd.Function):
    """Autograd function whose backward pass delivers a precomputed gradient.

    ``forward(input_tensor, gt_grad)`` stores ``gt_grad`` and returns a
    one-element tensor of ones with ``input_tensor``'s device and dtype.
    When that dummy output is back-propagated:
    - ``input_tensor`` receives ``gt_grad`` multiplied by the incoming upstream
      gradient. Under AMP, that upstream gradient carries the loss scale.
    - ``gt_grad`` itself receives no gradient.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input_tensor, gt_grad):
        ctx.save_for_backward(gt_grad)
        # The returned value is irrelevant; it exists so that any loss scaling
        # applied to it reappears as the upstream gradient in ``backward``.
        placeholder = torch.ones(1, dtype=input_tensor.dtype, device=input_tensor.device)
        return placeholder

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_scale):
        (stored_grad,) = ctx.saved_tensors
        # First slot -> input_tensor, second slot -> gt_grad (no gradient).
        return stored_grad * grad_scale, None
    

class sc_adapter(nn.Module):
    """Score Distillation Sampling (SDS) guidance from an SC-Adapter-conditioned SD v1.5 UNet.

    The UNet is conditioned on text-prompt embeddings concatenated with
    per-concept image-prompt embeddings produced by ``SC_Adapter``.
    :meth:`train_step` returns a dummy loss. Back-propagating it injects the
    classifier-free-guided SDS gradient into the latents of the given image.
    """

    # Tokens each concept contributes to the image-prompt embedding sequence.
    # NOTE(review): 77 * 6 presumably means 6 groups of 77 tokens per concept
    # (an earlier variant kept only the first 77 as "z0"); confirm against SC_Adapter.
    _TOKENS_PER_CONCEPT = 77 * 6
    # Last-dim width that text and image embeddings are reshaped to (SD v1.5: 768).
    _EMBED_DIM = 768
    _DEFAULT_PROMPT = "best quality, high quality"
    _QUALITY_SUFFIX = ", best quality, high quality"
    _DEFAULT_NEGATIVE_PROMPT = "monochrome, lowres, bad anatomy, worst quality, low quality"

    def __init__(
        self,
        device,
        t_range=(0.02, 0.98),
        *,
        base_model_path="runwayml/stable-diffusion-v1-5",
        sc_ckpt=("SSR_Encoder/pytorch_model.bin", "SSR_Encoder/pytorch_model_1.bin"),
        image_encoder_path="SSR_Encoder/image_encoder_l",
        **kwargs,
    ):
        """Load SD v1.5 with a DDIM scheduler and attach the SC-Adapter.

        Args:
            device: Device the pipeline and ``alphas`` are moved to.
            t_range: ``(low, high)`` fractions of the training timesteps. SDS
                timesteps are sampled uniformly from the resulting integer range.
            base_model_path: Diffusers model id or path of the base SD pipeline.
            sc_ckpt: Two checkpoint paths, passed to ``SC_Adapter.load_sc_adapter``.
            image_encoder_path: Image-encoder path, passed to ``SC_Adapter``.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        super().__init__()

        noise_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )

        pipe = StableDiffusionPipeline.from_pretrained(
            base_model_path,
            safety_checker=None,
            scheduler=noise_scheduler,  # follow previous dreamer works to use the ddim sampler
            torch_dtype=torch.float32,
        ).to(device)
        self.sc_model = SC_Adapter(pipe.unet, image_encoder_path, device, dtype=torch.float32)
        self.sc_model.get_pipe(pipe)
        self.sc_model.load_sc_adapter(sc_ckpt[0], sc_ckpt[1])

        self.device = device
        scheduler = self.sc_model.pipe.scheduler
        self.num_train_timesteps = scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        # Clamp so that t_range[1] == 1.0 cannot produce an out-of-range index into alphas.
        self.max_step = min(
            int(self.num_train_timesteps * t_range[1]), self.num_train_timesteps - 1
        )
        # Cumulative alpha products, kept on-device for the w(t) lookup in train_step.
        self.alphas = scheduler.alphas_cumprod.to(self.device)

    def encode_imgs(self, imgs):
        """Encode images into scaled VAE latents; gradients flow through the encoder.

        Args:
            imgs: ``[B, 3, H, W]`` tensor. It is mapped with ``2 * x - 1``, so
                values are assumed to lie in ``[0, 1]``.

        Returns:
            A latent sample from the VAE posterior, multiplied by ``vae.config.scaling_factor``.
        """
        imgs = 2 * imgs - 1
        posterior = self.sc_model.pipe.vae.encode(imgs).latent_dist
        return posterior.sample() * self.sc_model.pipe.vae.config.scaling_factor

    @classmethod
    def _normalize_prompts(cls, prompt, negative_prompt):
        """Return ``(prompts, negative_prompts)`` as equal-length lists of strings.

        - ``prompt=None`` gives the default quality prompt, with no suffix appended.
        - A single ``str`` prompt is wrapped in a list rather than iterated per character.
        - Every user-supplied prompt gets the quality suffix appended.
        - ``negative_prompt=None`` gives the default negative prompt, once per prompt.
        - A single ``str`` negative prompt is repeated once per prompt.
        - Any other negative prompt is passed through unchanged.
        """
        if prompt is None:
            prompts = [cls._DEFAULT_PROMPT]
        else:
            if isinstance(prompt, str):
                prompt = [prompt]
            prompts = [p + cls._QUALITY_SUFFIX for p in prompt]

        if negative_prompt is None:
            negative_prompts = [cls._DEFAULT_NEGATIVE_PROMPT] * len(prompts)
        elif isinstance(negative_prompt, str):
            negative_prompts = [negative_prompt] * len(prompts)
        else:
            negative_prompts = negative_prompt
        return prompts, negative_prompts

    @classmethod
    def _scale_concept_tokens(cls, image_prompt_embeds, scale):
        """Scale each concept's token block by its weight and concatenate the blocks.

        Concept ``i`` owns tokens ``[i * T, (i + 1) * T)`` along dim 1, where
        ``T = _TOKENS_PER_CONCEPT``. Only the first ``len(scale)`` blocks are kept.
        """
        step = cls._TOKENS_PER_CONCEPT
        blocks = [
            weight * image_prompt_embeds[:, i * step:(i + 1) * step, :]
            for i, weight in enumerate(scale)
        ]
        return torch.cat(blocks, dim=1)

    def _build_conditioning(self, pil_image_list, concept_list, uncond_concept, prompts,
                            negative_prompts, scale, num_samples, is_style):
        """Return ``(negative, positive)`` cross-attention embeddings, each of shape ``(1, L, 768)``.

        Each is the text-prompt embedding sequence followed by the image-prompt
        embedding sequence. The positive image tokens are weighted per concept
        by ``scale``. No gradient is needed here: the UNet runs under
        ``no_grad`` and the loss only reaches the latents.
        """
        with torch.no_grad():
            image_embeds, uncond_image_embeds = self.sc_model.get_image_embeds(
                pil_image_list, num_samples, concept_list, uncond_concept, is_style
            )
            image_embeds = image_embeds.view(1, -1, self._EMBED_DIM)
            uncond_image_embeds = uncond_image_embeds.view(1, -1, self._EMBED_DIM)
            image_embeds = self._scale_concept_tokens(image_embeds, scale)

            text_embeds = self.sc_model.pipe._encode_prompt(
                prompts, device=self.device, num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True, negative_prompt=negative_prompts,
            )
            # _encode_prompt stacks [negative; positive] along the batch dim.
            text_neg, text_pos = text_embeds.chunk(2)
            # All prompts/samples are flattened into one sequence of batch 1.
            text_pos = text_pos.view(1, -1, self._EMBED_DIM)
            text_neg = text_neg.view(1, -1, self._EMBED_DIM)

            positive = torch.cat([text_pos, image_embeds], dim=1)
            negative = torch.cat([text_neg, uncond_image_embeds], dim=1)
        return negative, positive

    def train_step(self,
                   pred_rgb,
                   pil_image_list,
                   concept_list,
                   uncond_concept=None,
                   prompt=None,
                   negative_prompt=None,
                   scale=None,
                   grad_scale=1,
                   guidance_scale=100,
                   as_latent=False,
                   num_samples=1,
                   is_style=True,):
        """Compute one SDS step and return a dummy loss carrying the SDS gradient.

        Args:
            pred_rgb: Rendered batch ``[B, C, H, W]``.
                - If ``as_latent`` is true, it is resized to 64x64 and mapped
                  with ``2 * x - 1``. It is presumably already 4-channel latent
                  data (TODO confirm).
                - Otherwise it is resized to 512x512 and VAE-encoded.
            pil_image_list, concept_list, is_style: Forwarded to
                ``SC_Adapter.get_image_embeds``.
            uncond_concept: Unconditional concept strings. Defaults to ``""``
                for each concept.
            prompt: ``str``, list of ``str``, or ``None``. User prompts get a
                quality suffix appended.
            negative_prompt: ``str``, list of ``str``, or ``None``. Defaults to
                a generic quality negative prompt.
            scale: Per-concept weight on image-prompt tokens. Defaults to 1.0 each.
            grad_scale: Extra multiplier on the SDS gradient.
            guidance_scale: Classifier-free guidance weight (high, DreamFusion-style).
            as_latent: Treat ``pred_rgb`` as latents instead of RGB.
            num_samples: Forwarded as images-per-prompt to both encoders.

        Returns:
            A 1-element tensor. Back-propagating it sets the latents' gradient to
            ``grad_scale * (1 - alpha_bar_t) * (eps_hat - eps)``.
        """
        prompts, negative_prompts = self._normalize_prompts(prompt, negative_prompt)
        if scale is None:
            scale = [1.0] * len(concept_list)
        if uncond_concept is None:
            uncond_concept = [""] * len(concept_list)

        negative_embeds, positive_embeds = self._build_conditioning(
            pil_image_list, concept_list, uncond_concept, prompts,
            negative_prompts, scale, num_samples, is_style,
        )

        if as_latent:
            latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
        else:
            # Resize to 512x512 for the VAE; encoding keeps the graph (requires grad).
            pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_512)

        batch_size = latents.shape[0]
        # Layout matches cat([latents_noisy] * 2): first B rows uncond, next B rows cond.
        img_prompt_embeddings = torch.cat(
            [negative_embeds.expand(batch_size, -1, -1),
             positive_embeds.expand(batch_size, -1, -1)],
            dim=0,
        )

        # t ~ U{min_step, ..., max_step}; avoids extreme noise levels.
        t = torch.randint(self.min_step, self.max_step + 1, (batch_size,),
                          dtype=torch.long, device=self.device)

        # Predict the noise residual with the UNet; no gradient through the UNet.
        with torch.no_grad():
            noise = torch.randn_like(latents)
            latents_noisy = self.sc_model.pipe.scheduler.add_noise(latents, noise, t)
            latent_model_input = torch.cat([latents_noisy] * 2)
            tt = torch.cat([t] * 2)
            noise_pred = self.sc_model.pipe.unet(
                latent_model_input, tt, encoder_hidden_states=img_prompt_embeddings
            ).sample

            # Classifier-free guidance (high scale, as in the SDS paper).
            noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond)

        # w(t) = 1 - alpha_bar_t (= sigma_t^2).
        w = 1 - self.alphas[t]
        grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        return SpecifyGradient.apply(latents, grad)
