"""SAMPLING ONLY."""

from typing import Dict, Optional, Tuple

import numpy as np
import torch
from tqdm import tqdm

from stable_diffusion.ldm.modules.diffusionmodules.sige_openaimodel import SIGEUNetModel
#from .highlight_openaimodel import HighlightUNetModel
from stable_diffusion.ldm.modules.diffusionmodules.util import (
    extract_into_tensor,
    make_ddim_sampling_parameters,
    make_ddim_timesteps,
    noise_like,
)

import nltk
from nltk.tokenize import word_tokenize
from nltk import pos_tag
# nltk.download('averaged_perceptron_tagger')

from .utils import (preprocess,top_var_tokens,binary_mask_gen,downsample_mask,optimize_binary_mask)
from transformers import CLIPTokenizer

import pickle
import time
import os

from PIL import Image

def create_binary_array_with_blocks(array_size, block_size, num_blocks, seed=None):
    """Return a uint8 array with ``num_blocks`` random square blocks set to 1.

    The array is tiled into a grid of ``block_size`` x ``block_size`` cells
    (partial cells at the bottom/right edge are ignored) and ``num_blocks``
    distinct cells are filled with ones.

    Args:
        array_size: ``(rows, cols)`` shape of the 2-D output array.
        block_size: Side length of each square block.
        num_blocks: Number of distinct blocks to switch on.
        seed: Optional seed. When given, NumPy's *global* random generator is
            reseeded, so the call is reproducible (and affects later global
            ``np.random`` draws).

    Returns:
        ``np.ndarray`` of dtype ``uint8`` and shape ``array_size``.

    Raises:
        ValueError: Propagated from ``np.random.choice`` when ``num_blocks``
            exceeds the number of whole blocks that fit in the array.
    """
    if seed is not None:
        np.random.seed(seed)

    binary_array = np.zeros(array_size, dtype=np.uint8)

    # Size the grid per axis. Using the row count for both axes (as before)
    # mis-placed or dropped blocks whenever the array was not square.
    block_rows = array_size[0] // block_size
    block_cols = array_size[1] // block_size

    # Draw flat cell indices without replacement. For square arrays this is
    # the same draw (and the same row-major cell order) as the old position list.
    chosen_cells = np.random.choice(block_rows * block_cols, num_blocks, replace=False)

    for flat_index in chosen_cells:
        i, j = divmod(int(flat_index), block_cols)
        binary_array[
            i * block_size:(i + 1) * block_size,
            j * block_size:(j + 1) * block_size,
        ] = 1

    return binary_array

def create_binary_array_with_block(array_size, block_shape, start_position, seed=None):
    """Return a zeroed uint8 array of shape ``array_size`` with one rectangle set to 1.

    Args:
        array_size: Shape of the output array.
        block_shape: ``(height, width)`` of the rectangle.
        start_position: ``(row, col)`` of the rectangle's first cell.
        seed: Optional seed. Nothing random is drawn here, so it does not
            change the result; it only reseeds NumPy's global generator.

    Returns:
        ``np.ndarray`` of dtype ``uint8``. Parts of the rectangle that fall
        outside the array are clipped by normal slice semantics.
    """
    if seed is not None:
        np.random.seed(seed)

    canvas = np.zeros(array_size, dtype=np.uint8)

    first_row, first_col = start_position
    row_span = slice(first_row, first_row + block_shape[0])
    col_span = slice(first_col, first_col + block_shape[1])
    canvas[row_span, col_span] = 1

    return canvas


class HighlightDDIMSampler(object):
    """DDIM sampler that switches the wrapped UNet between "full" and "sparse" passes.

    The first ten denoising steps, and every ``interval``-th step after that,
    run the diffusion model in ``"full"`` mode; every other step runs in
    ``"sparse"`` mode. Right after the tenth step the attention maps returned
    by the model are converted into binary masks (through the project helpers
    ``preprocess``, ``binary_mask_gen``, ``optimize_binary_mask`` and
    ``downsample_mask``) and passed to the diffusion model via ``set_masks``.
    """

    def __init__(self, args, model, schedule="linear", **kwargs):
        """Wrap ``model``; ``model.num_timesteps`` fixes the DDPM schedule length.

        ``args`` and ``schedule`` are stored as-is; extra ``kwargs`` are ignored.
        """
        super().__init__()
        self.args = args
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        """Store ``attr`` on the sampler under ``name``.

        Tensors are first moved to the wrapped model's device (previously a
        hard-coded ``"cuda"``, which broke on CPU-only hosts). Non-tensor
        values are stored unchanged.
        """
        if isinstance(attr, torch.Tensor):
            target_device = self.model.device
            if attr.device != target_device:
                attr = attr.to(target_device)
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True):
        """Compute and register the DDIM timestep subset and its coefficients.

        Args:
            ddim_num_steps: Number of DDIM steps to sample with.
            ddim_discretize: Discretisation method forwarded to
                ``make_ddim_timesteps``.
            ddim_eta: DDIM eta; scales the per-step sigma.
            verbose: Forwarded to the schedule helpers.
        """
        self.ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            num_ddpm_timesteps=self.ddpm_num_timesteps,
            verbose=verbose,
        )
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, "alphas have to be defined for each timestep"
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer("betas", to_torch(self.model.betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        alphas_cumprod_cpu = alphas_cumprod.cpu()
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod_cpu)))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod_cpu)))
        self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod_cpu)))
        self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod_cpu)))
        self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod_cpu - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=alphas_cumprod_cpu, ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, verbose=verbose
        )
        self.register_buffer("ddim_sigmas", ddim_sigmas)
        self.register_buffer("ddim_alphas", ddim_alphas)
        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev)
            / (1 - self.alphas_cumprod)
            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
        )
        self.register_buffer("ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=1,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        conv_masks: Optional[Dict[Tuple[int, int], torch.Tensor]] = None,
        prompts = None,
        interval = None,
        prompt_i = None,
        threshold = None,
        save_pkl = False,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):
        """Build the schedule for ``S`` steps and run :meth:`ddim_sampling`.

        Args:
            S: Number of DDIM steps.
            batch_size: Batch size of the generated latents.
            shape: ``(C, H, W)`` of one latent sample.
            conditioning: Conditioning tensor, or a dict of tensors; only used
                here for a batch-size sanity warning before being forwarded.
            interval: Period of "full" passes after the warm-up steps
                (see :meth:`ddim_sampling`).
            prompts, prompt_i, threshold: Forwarded to the mask-generation
                step of :meth:`ddim_sampling`.

        ``normals_sequence`` and ``**kwargs`` are accepted but not used. All
        remaining arguments are forwarded unchanged.

        Returns:
            ``(samples, intermediates)`` exactly as returned by
            :meth:`ddim_sampling`.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        samples, intermediates = self.ddim_sampling(
            conditioning,
            size,
            callback=callback,
            img_callback=img_callback,
            quantize_denoised=quantize_x0,
            mask=mask,
            x0=x0,
            ddim_use_original_steps=False,
            noise_dropout=noise_dropout,
            temperature=temperature,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
            x_T=x_T,
            log_every_t=log_every_t,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            conv_masks=conv_masks,
            prompts = prompts,
            interval = interval,
            prompt_i = prompt_i,
            threshold = threshold,
            save_pkl = save_pkl,
        )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(
        self,
        cond,
        shape,
        x_T=None,
        ddim_use_original_steps=False,
        callback=None,
        timesteps=None,
        quantize_denoised=False,
        mask=None,
        x0=None,
        img_callback=None,
        log_every_t=1,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        conv_masks: Optional[Dict[Tuple[int, int], torch.Tensor]] = None,
        prompts = None,
        interval = None,
        prompt_i = None,
        threshold = None,
        save_pkl = False,
    ):
        """Run the denoising loop, toggling the UNet between full and sparse mode.

        Args:
            cond: Conditioning passed to the model at every step.
            shape: Full latent shape ``(B, C, H, W)``.
            x_T: Optional starting latent; random normal noise when ``None``.
            interval: After the first ten steps, every ``interval``-th step is
                a "full" pass and the rest are "sparse". ``None`` is treated
                as ``1`` (every step full, no mask generation).
            prompts, prompt_i, threshold: Passed to ``binary_mask_gen`` when
                the masks are built after step ten.
            callback: Called with the step index after each step.
            img_callback: Called with ``(pred_x0, step_index)`` after each step.

        ``x0``, ``log_every_t``, ``conv_masks`` and ``save_pkl`` are accepted
        but not used by this loop.

        Returns:
            ``(img, intermediates)``. ``intermediates`` holds only the initial
            latent; nothing is appended during sampling.
        """
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {"x_inter": [img], "pred_x0": [img]}
        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)

        # The default ``interval=None`` used to crash on the modulo below. An
        # interval of 1 keeps every step in "full" mode and skips mask generation.
        if interval is None:
            interval = 1

        # Number of leading steps that always run in "full" mode; the masks
        # are built from the attention maps of the last of these steps.
        warmup_steps = 10

        # Bound up front so ``img_callback`` never sees an undefined name when
        # the loop body below is skipped (``mask`` given).
        pred_x0 = img

        diffusion_model = self.model.model.diffusion_model
        mode = "full"
        diffusion_model.set_mode(mode)
        print(f"Mode: {diffusion_model.mode}")
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            # NOTE(review): masked (inpainting-style) sampling is not implemented;
            # when ``mask`` is given no denoising happens and ``img`` is returned
            # unchanged. Confirm whether callers ever pass a mask.
            if mask is None:
                step_number = i + 1
                is_full_step = step_number % interval == 0 or step_number <= warmup_steps
                wanted_mode = "full" if is_full_step else "sparse"
                # Only notify the model when the mode actually changes.
                if wanted_mode != mode:
                    mode = wanted_mode
                    diffusion_model.set_mode(mode)

                img, pred_x0, attn_masks = self.p_sample_ddim(
                    img,
                    cond,
                    ts,
                    index=index,
                    use_original_steps=ddim_use_original_steps,
                    quantize_denoised=quantize_denoised,
                    temperature=temperature,
                    noise_dropout=noise_dropout,
                    score_corrector=score_corrector,
                    corrector_kwargs=corrector_kwargs,
                    unconditional_guidance_scale=unconditional_guidance_scale,
                    unconditional_conditioning=unconditional_conditioning,
                    need_result=True,
                    mode = mode
                )

                # Build the sparse-pass masks once, from the attention maps of
                # the final warm-up step. Skipped when every step is "full".
                if step_number == warmup_steps and interval != 1:
                    attn_map = preprocess(512, 512, attn_masks)
                    binary_mask = binary_mask_gen(
                        prompt_index = prompt_i,
                        num = 1,
                        attn_map = attn_map,
                        tokenizer = self.model.cond_stage_model.tokenizer,
                        prompt = prompts,
                        threshold=threshold,
                        save_mask = False,
                    )
                    binary_mask = optimize_binary_mask(binary_mask)
                    binary_masks = downsample_mask(binary_mask, min_res=8, dilation=1)
                    diffusion_model.set_masks(binary_masks)

            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)
        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(
        self,
        x,
        c,
        t,
        index,
        repeat_noise=False,
        use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        need_result=True,
        mode = None,
    ):
        """Perform one DDIM update step.

        Args:
            x: Current latent.
            c: Conditioning for the model.
            t: Timestep tensor, one entry per batch element.
            index: Position of this step in the DDIM coefficient tables.
            use_original_steps: Use the full DDPM tables instead of the DDIM
                subset.
            unconditional_guidance_scale, unconditional_conditioning: When a
                conditioning is given and the scale differs from 1.0, the
                model is run on a doubled batch and the two noise predictions
                are combined (classifier-free guidance).
            need_result: When false, only the model forward pass is run and
                ``None`` is returned.
            mode: Forwarded to ``apply_model`` on the guided path.

        Returns:
            ``(x_prev, pred_x0, attn_masks)``, or ``None`` when
            ``need_result`` is false. ``attn_masks`` is ``None`` if the model
            call did not return attention maps.
        """
        b = x.shape[0]
        device = x.device

        # Always bound, so the unguided path no longer fails at the final return.
        attn_masks = None

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
            model_out = self.model.apply_model(x, t, c)
            # NOTE(review): the guided branch shows apply_model returning
            # (eps, attn_masks) when given ``mode``; whether it also does so
            # without ``mode`` is not visible here, so accept both forms.
            if isinstance(model_out, tuple):
                e_t, attn_masks = model_out
            else:
                e_t = model_out
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            c_in = torch.cat([unconditional_conditioning, c])
            e_t_batch, attn_masks = self.model.apply_model(x_in, t_in, c_in, mode)
            e_t_uncond, e_t = e_t_batch.chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if not need_result:
            return

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = (
            self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        )
        # ``ddim_sigmas_for_original_num_steps`` is registered on the sampler
        # by ``make_schedule``, so read it from ``self`` rather than the model.
        sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.0:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0, attn_masks
