from typing import Callable, List, Optional, Union, Tuple
import copy
import torch
from transformers import get_cosine_schedule_with_warmup
from torch.optim.lr_scheduler import ReduceLROnPlateau
from diffusers import StableDiffusionPipeline
### credit to: https://github.com/cccntu/efficient-prompt-to-prompt

class InversableStableDiffusionPipeline(StableDiffusionPipeline):
    def __init__(self,
        vae,
        text_encoder,
        tokenizer,
        unet,
        scheduler,
        safety_checker,
        feature_extractor,
        requires_safety_checker: bool = True,
    ):
        """Build the pipeline by forwarding every component to the parent class.

        The components are passed by keyword rather than by position so that
        `requires_safety_checker` always reaches the parameter of the same
        name, even if the parent constructor has additional optional
        parameters placed before it.
        """
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker,
        )

    def get_random_latents(self, latents=None, height=512, width=512, generator=None):
        """Return latents for a single sample via `self.prepare_latents`.

        Args:
            latents: Forwarded unchanged to `prepare_latents`.
            height: Image height; when falsy (e.g. None or 0) it falls back to
                `unet.config.sample_size * vae_scale_factor`.
            width: Image width; same fallback rule as `height`.
            generator: Forwarded unchanged to `prepare_latents`.

        Returns:
            Whatever `self.prepare_latents` returns for a batch size of 1, the
            UNet's `in_channels`, the text encoder's dtype and the pipeline's
            execution device.
        """
        if not height:
            height = self.unet.config.sample_size * self.vae_scale_factor
        if not width:
            width = self.unet.config.sample_size * self.vae_scale_factor

        return self.prepare_latents(
            1,  # batch size is fixed to a single sample
            self.unet.config.in_channels,
            height,
            width,
            self.text_encoder.dtype,
            self._execution_device,
            generator,
            latents,
        )

    @torch.inference_mode()
    def get_text_embedding(self, prompt):
        """Tokenize `prompt` and run it through the text encoder.

        The prompt is padded/truncated to `tokenizer.model_max_length`, the
        token ids are moved to `self.device`, and the first element of the
        text encoder's output is returned (presumably the per-token hidden
        states -- confirm against the encoder in use).
        """
        tokenized = self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        token_ids = tokenized.input_ids.to(self.device)
        encoder_output = self.text_encoder(token_ids)
        return encoder_output[0]
    
    @torch.inference_mode()
    def get_image_latents(self, image, sample=True, rng_generator=None):
        """Encode `image` with the VAE and return the scaled latents.

        Args:
            image: Input handed directly to `self.vae.encode`.
            sample: If true, draw from the encoder's latent distribution
                (using `rng_generator`); otherwise take its mode.
            rng_generator: Generator passed to the distribution's `sample`.

        Returns:
            The encoding multiplied by 0.18215 (presumably the Stable
            Diffusion VAE scaling factor -- confirm for other VAEs).
        """
        posterior = self.vae.encode(image).latent_dist
        raw = posterior.sample(generator=rng_generator) if sample else posterior.mode()
        return raw * 0.18215
    
    @torch.inference_mode()
    def decode_image(self, latents: torch.FloatTensor, **kwargs):
        """Decode `latents` into images, one sample at a time.

        The latents are rescaled by 1 / 0.18215 before decoding. As a side
        effect `self.vae` is replaced by the result of `self.vae.float()`.
        Extra keyword arguments are accepted but ignored.

        Returns:
            The per-sample decoder outputs concatenated along dim 0.
        """
        unscaled = 1 / 0.18215 * latents
        self.vae = self.vae.float()

        decoded = []
        # Decode slice by slice (each slice keeps its leading batch dim).
        for idx in range(len(latents)):
            single = unscaled[idx : idx + 1]
            decoded.append(self.vae.decode(single).sample)
        return torch.cat(decoded, dim=0)

    def decode_image_for_gradient_float(self, latents: torch.FloatTensor, **kwargs):
        """Decode `latents` with a float copy of the VAE, outside inference mode.

        Unlike `decode_image`, this method is not wrapped in
        `torch.inference_mode()` and it decodes with a deep copy of
        `self.vae` converted via `.float()`, so `self.vae` itself is left
        untouched. Note that the deep copy is made on every call.
        Extra keyword arguments are accepted but ignored.

        Returns:
            The per-sample decoder outputs concatenated along dim 0.
        """
        unscaled = 1 / 0.18215 * latents
        float_vae = copy.deepcopy(self.vae).float()

        decoded = []
        for idx in range(len(latents)):
            single = unscaled[idx : idx + 1]
            decoded.append(float_vae.decode(single).sample)
        return torch.cat(decoded, dim=0)

    @torch.inference_mode()
    def torch_to_numpy(self, image):
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        return image

    def apply_guidance_scale(self, model_output, guidance_scale):
        if guidance_scale > 1.0:
            noise_pred_uncond, noise_pred_text = model_output.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            return noise_pred
        else:
            return model_output         
        
    @torch.inference_mode()
    def fixedpoint_correction(self, x, s, t, x_t, r=None, order=1, n_iter=500, step_size=0.1, th=1e-3, 
                                model_s_output=None, model_r_output=None, text_embeddings=None, guidance_scale=3.0, 
                                scheduler=False, factor=0.5, patience=20, anchor=False, warmup=True, warmup_time=20):
        """Iteratively refine a latent at timestep `s` so that one solver step lands on `x_t`.

        Starting from a copy of `x`, each iteration evaluates the UNet at
        timestep `s`, predicts the latent at timestep `t`, and moves the
        candidate by `-step_size * (prediction - x_t)` ("forward step
        method"). The loop stops after `n_iter` iterations or as soon as the
        summed squared error between the prediction and `x_t` drops below
        `th`.

        Args:
            x: Initial guess for the latent at timestep `s`.
            s, t: Timesteps used to index `self.scheduler.lambda_t`,
                `sigma_t` and `alpha_t`.
            x_t: Target latent that the prediction should reproduce.
            r: Extra timestep required when `order == 2`.
            order: 1 for the first-order update, 2 to additionally subtract a
                fixed higher-order term built from `model_s_output` and
                `model_r_output`.
            n_iter: Maximum number of iterations.
            step_size: Base step size of the update.
            th: Early-stopping threshold on the summed squared error.
            model_s_output, model_r_output: Model outputs used for the
                higher-order term; required when `order == 2`.
            text_embeddings: Passed to the UNet as `encoder_hidden_states`.
            guidance_scale: Classifier-free guidance weight; guidance is used
                when it is greater than 1.0.
            scheduler: If true, a `StepScheduler` reduces the step size when
                the loss stops improving.
            factor, patience: Forwarded to `StepScheduler`.
            anchor: Only honoured for `order == 2`; blends the candidate with
                `x` using weight `1 / (i + 2)` after each iteration.
            warmup: If true, ramp the step size linearly over the first
                `warmup_time` iterations.
            warmup_time: Number of warm-up iterations.

        Returns:
            The refined latent.

        Raises:
            NotImplementedError: If `order` is neither 1 nor 2.
            ValueError: If `order == 2` and `r`, `model_s_output` or
                `model_r_output` is missing.
        """
        if order not in (1, 2):
            raise NotImplementedError
        if order == 2 and (r is None or model_s_output is None or model_r_output is None):
            # Explicit check instead of `assert`, which is stripped under -O.
            raise ValueError(
                "order=2 requires `r`, `model_s_output` and `model_r_output`."
            )

        do_classifier_free_guidance = guidance_scale > 1.0
        current = x.clone()
        base_step_size = step_size

        # step size scheduler, reduce when not improved
        step_scheduler = None
        if scheduler:
            step_scheduler = StepScheduler(current_lr=step_size, factor=factor, patience=patience)

        lambda_s, lambda_t = self.scheduler.lambda_t[s], self.scheduler.lambda_t[t]
        sigma_s, sigma_t = self.scheduler.sigma_t[s], self.scheduler.sigma_t[t]
        alpha_t = self.scheduler.alpha_t[t]
        h = lambda_t - lambda_s
        phi_1 = torch.expm1(-h)

        # high-order term approximation; it does not depend on the iterate,
        # so it is computed once up front.
        diff_term = None
        if order == 2:
            lambda_r = self.scheduler.lambda_t[r]
            r0 = (lambda_s - lambda_r) / h
            d = (1. / r0) * (model_s_output - model_r_output)
            diff_term = 0.5 * alpha_t * phi_1 * d

        for i in range(n_iter):
            # step size warmup
            if warmup and i < warmup_time:
                step_size = base_step_size * (i + 1) / (warmup_time)

            latent_model_input = torch.cat([current] * 2) if do_classifier_free_guidance else current
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.unet(latent_model_input, s, encoder_hidden_states=text_embeddings).sample
            noise_pred = self.apply_guidance_scale(noise_pred, guidance_scale)
            model_output = self.scheduler.convert_model_output(noise_pred, s, current)

            x_t_pred = (sigma_t / sigma_s) * current - (alpha_t * phi_1) * model_output
            if diff_term is not None:
                x_t_pred = x_t_pred - diff_term

            loss = torch.nn.functional.mse_loss(x_t_pred, x_t, reduction='sum')

            if loss.item() < th:
                break

            # forward step method
            current = current - step_size * (x_t_pred - x_t)

            if step_scheduler is not None:
                step_size = step_scheduler.step(loss)
            # anchoring towards the initial guess is only applied for order 2
            if anchor and order == 2:
                current = (1 - 1 / (i + 2)) * current + (1 / (i + 2)) * x

        return current

    @torch.inference_mode()
    def forward_diffusion(
        self,
        use_old_emb_i=25,
        text_embeddings=None,
        old_text_embeddings=None,
        new_text_embeddings=None,
        latents: Optional[torch.FloatTensor] = None,
        num_inference_steps: int = 10,
        guidance_scale: float = 7.5,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        inverse_opt=True,
        inv_order=None,
        **kwargs,
    ):  
        """Invert `latents` by walking the scheduler's timesteps in reverse order.

        For every (reversed) timestep the latents are pushed one step towards
        higher noise. "Algorithm 1" (first order) is used when
        `inv_order < 2` and for the first step when `inv_order == 2`;
        "Algorithm 2" is used for the remaining steps when `inv_order == 2`.
        When `inverse_opt` is true each step is refined with
        `fixedpoint_correction`.

        Side effects: calls `self.scheduler.set_timesteps` and replaces
        `self.unet` with `self.unet.float()`. The scheduler must expose
        `lambda_t`, `sigma_t`, `alpha_t`, `solver_order` and
        `convert_model_output` (presumably a DPM-Solver multistep scheduler
        -- confirm against the caller).

        Args:
            use_old_emb_i: In prompt-to-prompt mode, steps with index below
                this value use `old_text_embeddings`, later ones use
                `new_text_embeddings`.
            text_embeddings: Embeddings passed to the UNet. May be None when
                both `old_text_embeddings` and `new_text_embeddings` are given.
            old_text_embeddings, new_text_embeddings: Enable prompt-to-prompt
                mode when both are provided.
            latents: Latents to invert (required).
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-free guidance weight; guidance is used
                when it is greater than 1.0.
            callback: Called as `callback(i, t, latents)` every
                `callback_steps` steps.
            callback_steps: Callback frequency.
            inverse_opt: Whether to apply the fixed-point correction.
            inv_order: Inversion order; defaults to
                `self.scheduler.solver_order`.
            **kwargs: Accepted but ignored.

        Returns:
            The inverted latents.

        Raises:
            ValueError: If `latents` is missing, or if neither
                `text_embeddings` nor both prompt-to-prompt embeddings are
                provided.
        """
        with torch.no_grad():
            # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
            # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
            # corresponds to doing no classifier free guidance.
            do_classifier_free_guidance = guidance_scale > 1.0

            prompt_to_prompt = (
                old_text_embeddings is not None and new_text_embeddings is not None
            )

            # Fail early with a clear message instead of crashing later on None.
            if latents is None:
                raise ValueError("`latents` must be provided.")
            if text_embeddings is None and not prompt_to_prompt:
                raise ValueError(
                    "Provide `text_embeddings`, or both `old_text_embeddings` "
                    "and `new_text_embeddings`."
                )

            self.scheduler.set_timesteps(num_inference_steps)
            timesteps_tensor = self.scheduler.timesteps.to(self.device)
            latents = latents * self.scheduler.init_noise_sigma

            if inv_order is None:
                inv_order = self.scheduler.solver_order

            timesteps_tensor = reversed(timesteps_tensor) # inversion process

            # Everything below runs in float32: the UNet, the latents and
            # whichever embeddings were supplied.
            self.unet = self.unet.float()
            latents = latents.float()
            if text_embeddings is not None:
                text_embeddings = text_embeddings.float()
            if prompt_to_prompt:
                old_text_embeddings = old_text_embeddings.float()
                new_text_embeddings = new_text_embeddings.float()

            for i, t in enumerate(self.progress_bar(timesteps_tensor)):
                if prompt_to_prompt:
                    if i < use_old_emb_i:
                        text_embeddings = old_text_embeddings
                    else:
                        text_embeddings = new_text_embeddings

                prev_timestep = (
                    t
                    - self.scheduler.config.num_train_timesteps
                    // self.scheduler.num_inference_steps
                )

                # call the callback, if provided
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)
                

                # Our Algorithm

                # Algorithm 1
                if inv_order < 2 or (inv_order == 2 and i == 0):
                    s = t 
                    t = prev_timestep
                    
                    lambda_s, lambda_t = self.scheduler.lambda_t[s], self.scheduler.lambda_t[t]
                    sigma_s, sigma_t = self.scheduler.sigma_t[s], self.scheduler.sigma_t[t]
                    h = lambda_t - lambda_s
                    alpha_s, alpha_t = self.scheduler.alpha_t[s], self.scheduler.alpha_t[t]
                    phi_1 = torch.expm1(-h)

                    # expand the latents if classifier free guidance is used
                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)                
                    
                    # predict the noise residual
                    noise_pred = self.unet(latent_model_input, s, encoder_hidden_states=text_embeddings).sample 
                    # perform guidance
                    noise_pred = self.apply_guidance_scale(noise_pred, guidance_scale)                 
                    # convert noise to model output
                    model_s = self.scheduler.convert_model_output(noise_pred, t, latents) 
                    x_t = latents
                    
                    # Line 5
                    latents = (sigma_s / sigma_t) * (latents + alpha_t * phi_1 * model_s)      

                    # Line 7 : Update
                    if (inverse_opt):
                        # Alg.2 Line 11
                        if (inv_order == 2 and i == 0):
                            latents = self.fixedpoint_correction(latents, s, t, x_t, order=1, text_embeddings=text_embeddings, guidance_scale=guidance_scale,
                                                                 step_size=1, scheduler=True)
                        else:
                            latents = self.fixedpoint_correction(latents, s, t, x_t, order=1, text_embeddings=text_embeddings, guidance_scale=guidance_scale,
                                                                 step_size=0.5, scheduler=True)

                # Algorithm 2
                elif inv_order == 2:
                    with torch.no_grad():
                        # Line 3 ~ 13
                        if (i + 1 < len(timesteps_tensor)):                          
                            y = latents.clone()

                            s = t
                            t = prev_timestep
                            # i + 1 is a valid index here (guarded by the enclosing `if`)
                            r = timesteps_tensor[i + 1]
                            
                            # Line 3 ~ 6 : fine-grained naive DDIM inversion
                            for tt in range(t,s,10):
                                ss = tt + 10
                                lambda_s, lambda_t = self.scheduler.lambda_t[ss], self.scheduler.lambda_t[tt]
                                sigma_s, sigma_t = self.scheduler.sigma_t[ss], self.scheduler.sigma_t[tt]
                                h = lambda_t - lambda_s
                                alpha_s, alpha_t = self.scheduler.alpha_t[ss], self.scheduler.alpha_t[tt]
                                phi_1 = torch.expm1(-h)

                                y_input = torch.cat([y] * 2) if do_classifier_free_guidance else y
                                y_input = self.scheduler.scale_model_input(y_input, tt)

                                noise_pred = self.unet(y_input, ss, encoder_hidden_states=text_embeddings).sample
                                noise_pred = self.apply_guidance_scale(noise_pred, guidance_scale)    
                                model_s = self.scheduler.convert_model_output(noise_pred, tt, y)
                                y = (sigma_s / sigma_t) * (y + alpha_t * phi_1 * model_s) # Line 5
                            # y_t: fine-grained estimate at timestep s; y continues on to r
                            y_t = y.clone()
                            for tt in range(s, r,10):
                                ss = tt + 10
                                lambda_s, lambda_t = self.scheduler.lambda_t[ss], self.scheduler.lambda_t[tt]
                                sigma_s, sigma_t = self.scheduler.sigma_t[ss], self.scheduler.sigma_t[tt]
                                h = lambda_t - lambda_s
                                alpha_s, alpha_t = self.scheduler.alpha_t[ss], self.scheduler.alpha_t[tt]
                                phi_1 = torch.expm1(-h)

                                y_input = torch.cat([y] * 2) if do_classifier_free_guidance else y
                                y_input = self.scheduler.scale_model_input(y_input, tt)

                                noise_pred = self.unet(y_input, ss, encoder_hidden_states=text_embeddings).sample
                                noise_pred = self.apply_guidance_scale(noise_pred, guidance_scale)    
                                model_s = self.scheduler.convert_model_output(noise_pred, tt, y) 
                                y = (sigma_s / sigma_t) * (y + alpha_t * phi_1 * model_s) # Line 5


                            # Line 8 ~ 12 : backward Euler
                            t = prev_timestep
                            s = timesteps_tensor[i]
                            r = timesteps_tensor[i+1]
                            
                            lambda_s, lambda_t = self.scheduler.lambda_t[s], self.scheduler.lambda_t[t]
                            sigma_s, sigma_t = self.scheduler.sigma_t[s], self.scheduler.sigma_t[t]
                            h = lambda_t - lambda_s
                            alpha_s, alpha_t = self.scheduler.alpha_t[s], self.scheduler.alpha_t[t]
                            phi_1 = torch.expm1(-h)
                            
                            x_t = latents
                            
                            y_t_model_input = torch.cat([y_t] * 2) if do_classifier_free_guidance else y_t
                            y_t_model_input = self.scheduler.scale_model_input(y_t_model_input, s)
                            
                            noise_pred = self.unet(y_t_model_input, s, encoder_hidden_states=text_embeddings).sample 
                            noise_pred = self.apply_guidance_scale(noise_pred, guidance_scale)
                            model_s_output = self.scheduler.convert_model_output(noise_pred, s, y_t)
                            
                            y_model_input = torch.cat([y] * 2) if do_classifier_free_guidance else y
                            y_model_input = self.scheduler.scale_model_input(y_model_input, r)
                            
                            noise_pred = self.unet(y_model_input, r, encoder_hidden_states=text_embeddings).sample
                            noise_pred = self.apply_guidance_scale(noise_pred, guidance_scale)                           
                            model_r_output = self.scheduler.convert_model_output(noise_pred, r, y)
                            
                            latents = y_t.clone() # Line 7
                            
                            # Line 11 : Update
                            if inverse_opt:
                                latents = self.fixedpoint_correction(latents, s, t, x_t, order=2, r=r,
                                                                    model_s_output=model_s_output, model_r_output=model_r_output, text_embeddings=text_embeddings, guidance_scale=guidance_scale,
                                                                    step_size=10/t, scheduler=False) 
                            
                        # Line 14 ~ 17
                        elif (i + 1 == len(timesteps_tensor)):
                            s = t
                            t = prev_timestep
                            
                            lambda_s, lambda_t = self.scheduler.lambda_t[s], self.scheduler.lambda_t[t]
                            sigma_s, sigma_t = self.scheduler.sigma_t[s], self.scheduler.sigma_t[t]
                            h = lambda_t - lambda_s
                            alpha_s, alpha_t = self.scheduler.alpha_t[s], self.scheduler.alpha_t[t]
                            phi_1 = torch.expm1(-h)

                            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents                          
                            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                            
                            noise_pred = self.unet(latent_model_input, s, encoder_hidden_states=text_embeddings).sample
                            noise_pred = self.apply_guidance_scale(noise_pred, guidance_scale)        
                            model_s = self.scheduler.convert_model_output(noise_pred, t, latents)

                            x_t = latents
                            
                            # Line 16
                            latents = (sigma_s / sigma_t) * (latents + alpha_t * phi_1 * model_s)
                            
                            # Line 17 : Update
                            if (inverse_opt):
                                latents = self.fixedpoint_correction(latents, s, t, x_t, order=1, text_embeddings=text_embeddings, guidance_scale=guidance_scale,
                                                                     step_size=10/t, scheduler=True)   
                        else:
                            raise Exception("Index Error!")
                else:
                    # NOTE(review): any other `inv_order` (e.g. 3) leaves the
                    # latents untouched for this step -- confirm this is intended.
                    pass

        return latents

  
    def decoder_inv(self, x):
        """
        decoder_inv calculates latents z of the image x by optimizing z so that
        the VAE decoder reproduces x, instead of only encoding x with the VAE
        encoder. "Decoder inversion"

        INPUT
        x : image data (1, 3, 512, 512)
        OUTPUT
        z : modified latent data (1, 4, 64, 64)

        Goal : minimize the summed squared error ||D(z) - x||^2, where D is the
        (float) VAE decoder applied to z / 0.18215. z is initialised from
        `get_image_latents(x)` and optimised with Adam (lr=0.1) for 100 steps
        under a cosine schedule with 10 warm-up steps.

        This method back-propagates through the decoder, so it must be called
        with gradients enabled. The returned z is the optimised leaf tensor
        (it still has requires_grad=True).
        """
        target = x.clone().float()

        z = self.get_image_latents(x).clone().float()
        z.requires_grad_(True)

        # Build the float copy of the VAE once instead of deep-copying it on
        # every optimisation step; self.vae itself is left untouched.
        vae = copy.deepcopy(self.vae).float()
        # Only z is optimised, so gradients for the copied weights are not needed.
        vae.requires_grad_(False)

        loss_function = torch.nn.MSELoss(reduction='sum')
        optimizer = torch.optim.Adam([z], lr=0.1)
        lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=100)

        for _ in self.progress_bar(range(100)):
            # Same decoding as decode_image_for_gradient_float, using the shared copy.
            scaled_latents = 1 / 0.18215 * z
            x_pred = torch.cat(
                [vae.decode(scaled_latents[i : i + 1]).sample for i in range(len(z))],
                dim=0,
            )

            loss = loss_function(x_pred, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
        return z
    
   
        
class StepScheduler(ReduceLROnPlateau):
    """Plateau-based step-size controller that needs no optimizer.

    It reuses the comparison helpers of `ReduceLROnPlateau`
    (`_init_is_better`, `_reset`, `is_better`, `in_cooldown`) but keeps the
    current step size in `self.current_lr`. The parent constructor is not
    called.
    """

    def __init__(self, mode='min', current_lr=0, factor=0.1, patience=10,
                 threshold=1e-4, threshold_mode='rel', cooldown=0,
                 min_lr=0, eps=1e-8, verbose=False):
        """Validate the arguments and initialise the plateau-tracking state.

        Raises:
            ValueError: If `factor >= 1.0` or `current_lr == 0`.
        """
        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        if current_lr == 0:
            raise ValueError('Step size cannot be 0')

        # configuration
        self.factor = factor
        self.min_lr = min_lr
        self.current_lr = current_lr
        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.eps = eps

        # running state; the real values are filled in by the two calls below
        self.cooldown_counter = 0
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.last_epoch = 0

        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset()

    def step(self, metrics, epoch=None):
        """Record one metric value and return the (possibly reduced) step size."""
        # `metrics` may be a zero-dim tensor, so normalise it to a float
        value = float(metrics)

        if epoch is not None:
            import warnings
            warnings.warn("EPOCH_DEPRECATION_WARNING", UserWarning)
        else:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if not self.is_better(value, self.best):
            self.num_bad_epochs += 1
        else:
            self.best = value
            self.num_bad_epochs = 0

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # bad epochs do not count during cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

        return self.current_lr

    def _reduce_lr(self, epoch):
        """Scale the step size by `factor`, never going below `min_lr`.

        The change is only applied when it is larger than `eps`.
        """
        previous = self.current_lr
        reduced = max(self.current_lr * self.factor, self.min_lr)
        if previous - reduced > self.eps:
            self.current_lr = reduced
            if self.verbose:
                fmt = "%.2f" if isinstance(epoch, float) else "%.5d"
                epoch_str = fmt % epoch
                print('Epoch {}: reducing learning rate'
                        ' to {:.4e}.'.format(epoch_str, reduced))