import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
from transformers import AutoTokenizer, CLIPTextModel
from diffusers import DDIMScheduler, DDPMScheduler, StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available


def hash_prompt(model, prompt):
    """Return the hex MD5 digest identifying a (model, prompt) pair.

    The digest is taken over the encoded string ``"<model>-<prompt>"``; within
    this file it is used as the file stem of cached text embeddings.
    """
    from hashlib import md5

    cache_key = "{}-{}".format(model, prompt)
    digest = md5(cache_key.encode())
    return digest.hexdigest()


class StableDiffusionGuidance(nn.Module):
    def __init__(self, device, input_prompt, negative_prompt, fp16=True):
        """Load the Stable Diffusion pipeline and prepare view-dependent text embeddings.

        Args:
            device: Device the pipeline is moved to. It is also used when
                sampling timesteps and as the ``map_location`` of cached
                embeddings.
            input_prompt: Positive text prompt; a view suffix is appended to it
                for each of the four view directions.
            negative_prompt: Negative prompt, reused unchanged for all four views.
            fp16: If True the pipeline is loaded with ``torch.float16``,
                otherwise with ``torch.float32``.
        """
        super().__init__()

        self.device = device
        self.fp16 = fp16
        self.dtype = torch.float16 if self.fp16 else torch.float32

        self.input_prompt = input_prompt
        self.negative_prompt = negative_prompt

        # Text-embedding cache settings, read by prepare_text_embeddings():
        # use_cache skips prompts whose cache file already exists, spawn runs
        # the text encoder in a separate "spawn" process.
        self.use_cache: bool = True
        self.spawn: bool = True
        self.cache_dir = "load/text_embeddings_cache"
        # Local checkpoint directory; also part of the cache key of every prompt.
        self.pretrained_model_path = "load/stable-diffusion-2-1-base"

        self.pipe = StableDiffusionPipeline.from_pretrained(
            # "stabilityai/stable-diffusion-2-1-base",
            self.pretrained_model_path,
            safety_checker=None,
            requires_safety_checker=False,
            torch_dtype=self.dtype
        ).to(self.device)

        # Keep direct handles to the pipeline components; VAE and UNet are put
        # in eval mode.
        self.vae = self.pipe.vae.eval()
        self.unet = self.pipe.unet.eval()
        self.tokenizer = self.pipe.tokenizer
        self.text_encoder = self.pipe.text_encoder

        # VAE and UNet parameters are frozen. The text encoder's parameters are
        # not touched here.
        for p in self.vae.parameters():
            p.requires_grad_(False)
        for p in self.unet.parameters():
            p.requires_grad_(False)

        self.scheduler = DDIMScheduler.from_pretrained(
            # "stabilityai/stable-diffusion-2-1-base",
            self.pretrained_model_path,
            subfolder="scheduler",
            torch_dtype=self.dtype,
        )

        # Timestep sampling range, expressed as fractions of the training
        # schedule and converted to integer step indices.
        # NOTE: set_min_max_steps() has a different default upper bound (0.98).
        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step_percent = 0.02
        self.max_step_percent = 0.55
        self.min_step = int(self.num_train_timesteps * self.min_step_percent)
        self.max_step = int(self.num_train_timesteps * self.max_step_percent)
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)

        # grad_clip is a schedule spec resolved by C() in update_step():
        # [start_step, start_value, end_value, end_step]. grad_clip_val stays
        # None (no clipping in train_step) until update_step() is called.
        self.grad_clip = [0, 1.5, 2.0, 1000]
        self.grad_clip_val = None
        self.guidance_scale = 100.

        # Thresholds used by get_text_embeds() to pick a view direction
        # (presumably degrees, matching shift_azimuth_deg -- confirm with caller).
        self.overhead_threshold: float = 60.0
        self.front_threshold: float = 45.0
        self.back_threshold: float = 45.0

        # Direction name -> row index; the order matches `views` below.
        self.direction2idx = {
            "side": 0,
            "front": 1,
            "back": 2,
            "overhead": 3
        }

        # One positive prompt per view; the negative prompt is repeated as-is.
        views = ['side view', 'front view', 'back view', 'overhead view']
        self.prompts_vd = [f'{self.input_prompt}, {view}' for view in views]
        self.negative_prompts_vd = [self.negative_prompt] * 4

        # Encode any prompts missing from the cache, then load the per-view
        # embeddings from the cache files.
        self.prepare_text_embeddings()
        self.load_text_embeddings()

    def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
        """Set the sampled timestep range from fractions of the training schedule.

        Each fraction is multiplied by ``self.num_train_timesteps`` and truncated
        to an integer, then stored as ``self.min_step`` / ``self.max_step``.
        """
        horizon = self.num_train_timesteps
        bounds = (("min_step", min_step_percent), ("max_step", max_step_percent))
        for attr_name, fraction in bounds:
            setattr(self, attr_name, int(horizon * fraction))

    @staticmethod
    def spawn_func(pretrained_model_path, prompts, cache_dir):
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_path, subfolder="tokenizer"
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_path,
            subfolder="text_encoder",
            device_map="auto",
        )

        with torch.no_grad():
            tokens = tokenizer(
                prompts,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                return_tensors="pt",
            )
            text_embeddings = text_encoder(tokens.input_ids.to(text_encoder.device))[0]

        for prompt, embedding in zip(prompts, text_embeddings):
            torch.save(
                embedding,
                os.path.join(
                    cache_dir,
                    f"{hash_prompt(pretrained_model_path, prompt)}.pt",
                ),
            )

        del text_encoder

    def prepare_text_embeddings(self):
        """Make sure every prompt used by this module has a cached text embedding.

        Collects the input prompt, the negative prompt and their view-dependent
        variants, drops duplicates, skips prompts whose cache file already exists
        (when ``self.use_cache`` is set) and encodes the rest with
        ``spawn_func`` -- in a separate "spawn" process when ``self.spawn`` is
        set, otherwise in-process.

        Raises:
            RuntimeError: If the encoding subprocess exits with a non-zero code.
        """
        os.makedirs(self.cache_dir, exist_ok=True)

        # The negative prompt appears several times (once plain, once per view);
        # dict.fromkeys removes duplicates while keeping first-seen order so
        # each distinct prompt is encoded and written only once.
        all_prompts = dict.fromkeys(
            [
                self.input_prompt,
                self.negative_prompt,
                *self.prompts_vd,
                *self.negative_prompts_vd,
            ]
        )

        prompts_to_process = []
        for prompt in all_prompts:
            if self.use_cache:
                # some text embeddings are already in cache
                # do not process them
                cache_path = os.path.join(
                    self.cache_dir,
                    f"{hash_prompt(self.pretrained_model_path, prompt)}.pt",
                )
                if os.path.exists(cache_path):
                    print(f"Text embeddings for model {self.pretrained_model_path} and prompt [{prompt}] are already in cache, skip processing.")
                    continue
            prompts_to_process.append(prompt)

        if not prompts_to_process:
            return

        if self.spawn:
            ctx = mp.get_context("spawn")
            worker = ctx.Process(
                target=self.spawn_func,
                args=(
                    self.pretrained_model_path,
                    prompts_to_process,
                    self.cache_dir,
                ),
            )
            worker.start()
            worker.join()
            # Surface a failed encoder run here instead of as a later
            # "cache file not found" error.
            if worker.exitcode != 0:
                raise RuntimeError(
                    f"Text embedding subprocess exited with code {worker.exitcode}."
                )
        else:
            self.spawn_func(
                self.pretrained_model_path,
                prompts_to_process,
                self.cache_dir,
            )
        torch.cuda.empty_cache()

    def load_text_embeddings(self):
        self.text_embeddings_vd = torch.stack(
            [self.load_from_cache(prompt) for prompt in self.prompts_vd], dim=0
        )
        self.uncond_text_embeddings_vd = torch.stack(
            [self.load_from_cache(prompt) for prompt in self.negative_prompts_vd], dim=0
        )

    def load_from_cache(self, prompt):
        """Load the cached text embedding of ``prompt``, mapped to ``self.device``.

        The file name is derived from ``hash_prompt(self.pretrained_model_path,
        prompt)`` inside ``self.cache_dir``.

        Raises:
            FileNotFoundError: If no cache file exists for this model/prompt pair.
        """
        file_name = f"{hash_prompt(self.pretrained_model_path, prompt)}.pt"
        cache_path = os.path.join(self.cache_dir, file_name)

        if os.path.exists(cache_path):
            # NOTE(review): torch.load may unpickle arbitrary objects depending
            # on the torch version; the cache directory should be trusted.
            return torch.load(cache_path, map_location=self.device)

        raise FileNotFoundError(
            f"Text embedding file {cache_path} for model {self.pretrained_model_path} and prompt [{prompt}] not found."
        )


    def shift_azimuth_deg(self, azimuth):
        """Wrap an azimuth angle in degrees into the half-open range [-180, 180).

        An input of exactly 180 maps to -180.
        """
        half_turn = 180
        full_turn = 360
        wrapped = (azimuth + half_turn) % full_turn
        return wrapped - half_turn

    def get_text_embeds(self, elevation, azimuth, camera_distances):
        # Get direction
        direction_idx = [0 for _ in elevation]

        for idx in range(len(elevation)):
            ele = elevation[idx]
            azi = azimuth[idx]
            dis = camera_distances[idx]

            if ele > self.overhead_threshold:
                direction = "overhead"
            elif -self.front_threshold < self.shift_azimuth_deg(azi) < self.front_threshold:
                direction = "front"
            elif (self.shift_azimuth_deg(azi) > 180 - self.back_threshold) or (
                    self.shift_azimuth_deg(azi) < -180 + self.back_threshold):
                direction = "back"
            else:
                direction = "side"

            direction_idx[idx] = self.direction2idx[direction]

        # Get text embeddings
        text_embeddings = self.text_embeddings_vd[direction_idx]
        uncond_text_embeddings = self.uncond_text_embeddings_vd[direction_idx]

        # IMPORTANT: we return (cond, uncond), which is in different order than other implementations!
        return torch.cat([text_embeddings, uncond_text_embeddings], dim=0)

    def encode_text(self, prompt):
        # prompt: [str]
        inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
        return embeddings

    def encode_images(self, imgs):
        # pixel value from [0, 1] to [-1, 1]
        imgs = imgs * 2.0 - 1.0

        posterior = self.vae.encode(imgs.to(self.dtype)).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor

        return latents

    def train_step(self, pred_rgb, elevation, azimuth, camera_distances, supervise_mean, supervise_variance, as_latents=False):
        """Compute the SDS loss for a batch of rendered images.

        Args:
            pred_rgb: Rendered batch; its first dimension is the batch size. It
                is resized to 512x512 and VAE-encoded, or resized to 64x64 and
                used directly as latents when ``as_latents`` is True
                (assumes a BCHW tensor -- TODO confirm with caller).
            elevation: Per-sample elevations, forwarded to ``get_text_embeds``.
            azimuth: Per-sample azimuths, forwarded to ``get_text_embeds``.
            camera_distances: Per-sample distances, forwarded to
                ``get_text_embeds``.
            supervise_mean: Sequence of tensors stacked with ``torch.stack``;
                the target per-channel mean. Only used when ``as_latents`` is
                False.
            supervise_variance: Sequence of tensors stacked with
                ``torch.stack``; the target per-channel variance (its square
                root is applied as the scale). Only used when ``as_latents`` is
                False.
            as_latents: Treat ``pred_rgb`` as latents and skip normalization
                and VAE encoding.

        Returns:
            Scalar loss: ``0.5 * sum((latents - target) ** 2) / batch_size``.
        """
        batch_size = pred_rgb.shape[0]
        pred_rgb = pred_rgb.to(self.dtype)

        if as_latents:
            latents = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False)
        else:
            pred_rgb_BCHW_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)

            # change list format to tensor
            supervise_variance = torch.stack(supervise_variance)  # Shape should be [4, 3]
            supervise_mean = torch.stack(supervise_mean)

            # Re-standardize every image channel to the supervised statistics.
            # NOTE: despite its name, variance_pred_rgb holds the standard
            # deviation (sqrt of the biased variance over H and W).
            mean_pred_rgb = torch.mean(pred_rgb_BCHW_512, dim=[2, 3]).unsqueeze(-1).unsqueeze(-1)
            variance_pred_rgb = torch.var(pred_rgb_BCHW_512, dim=[2, 3], unbiased=False).unsqueeze(-1).unsqueeze(-1).sqrt()
            supervise_variance_expanded = supervise_variance.unsqueeze(-1).unsqueeze(-1).sqrt()
            supervise_mean_expanded = supervise_mean.unsqueeze(-1).unsqueeze(-1)

            # NOTE(review): no epsilon in this division; a constant-valued
            # channel (zero std) yields inf/NaN here -- confirm this cannot occur.
            normalized_pred_rgb = (pred_rgb_BCHW_512 - mean_pred_rgb) / variance_pred_rgb

            pred_rgb_BCHW_512 = normalized_pred_rgb * supervise_variance_expanded + supervise_mean_expanded
            pred_rgb_BCHW_512 = torch.clamp(pred_rgb_BCHW_512, 0, 1)

            # encode image into latents with vae
            latents = self.encode_images(pred_rgb_BCHW_512)   # latents: Float[Tensor, "B 4 64 64"]

        # One integer timestep per sample, drawn uniformly from
        # [self.min_step, self.max_step] (both inclusive).
        t = torch.randint(self.min_step, self.max_step + 1, [batch_size], dtype=torch.long, device=self.device)

        # Per-sample weight (1 - alpha_cumprod[t]), broadcast over C, H, W.
        w = (1 - self.alphas[t]).view(-1, 1, 1, 1)

        #  compute_grad_sds
        text_embeddings = self.get_text_embeds(elevation, azimuth, camera_distances)

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise; the batch is doubled to match the (cond, uncond)
            # layout returned by get_text_embeds
            latent_model_input = torch.cat([latents_noisy] * 2, dim=0)

            input_dtype = latent_model_input.dtype
            noise_pred = self.unet(
                latent_model_input.to(self.dtype),
                torch.cat([t] * 2).to(self.dtype),
                encoder_hidden_states=text_embeddings.to(self.dtype),
            ).sample.to(input_dtype)

        # perform guidance (high scale from paper!)
        # The first half is the conditional prediction because text embeddings
        # are ordered (cond, uncond). The guided prediction starts from the
        # conditional prediction, not the unconditional one.
        noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
        noise_pred = noise_pred_text + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

        grad = w * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        # clip grad for stable training
        if self.grad_clip_val is not None:
            grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)

        # target is detached, so the gradient of the loss w.r.t. latents is
        # (latents - target) / batch_size = grad / batch_size.
        target = (latents - grad).detach()
        loss_sds = 0.5 * F.mse_loss(latents.float(), target, reduction="sum") / batch_size

        return loss_sds

    def C(self, value, global_step: int) -> float:
        """Resolve a constant or linearly scheduled scalar at ``global_step``.

        Args:
            value: Either a plain number, returned unchanged, or a list
                ``[start_step, start_value, end_value, end_step]``. A
                three-element list ``[start_value, end_value, end_step]`` is
                accepted with ``start_step`` taken as 0.
            global_step: Step at which the schedule is evaluated.

        Returns:
            The number itself, or the value linearly interpolated from
            ``start_value`` to ``end_value`` between ``start_step`` and
            ``end_step`` and held constant outside that range.

        Raises:
            TypeError: If ``value`` is neither a number nor a list.
            ValueError: If the list does not have three or four elements.
        """
        if isinstance(value, (int, float)):
            return value

        if not isinstance(value, list):
            raise TypeError(
                f"Scalar specification only supports list, got {type(value)}"
            )
        if len(value) == 3:
            value = [0] + value
        if len(value) != 4:
            raise ValueError(
                "Scalar specification must have 3 or 4 elements, "
                f"got {len(value)}"
            )
        start_step, start_value, end_value, end_step = value

        if end_step == start_step:
            # Degenerate interval: act as a step function instead of dividing
            # by zero.
            progress = 1.0 if global_step >= end_step else 0.0
        else:
            progress = (global_step - start_step) / (end_step - start_step)
            progress = max(min(1.0, progress), 0.0)

        return start_value + (end_value - start_value) * progress

    def update_step(self, global_step: int):
        """Refresh the step-dependent settings for ``global_step``.

        Resolves ``self.grad_clip`` into ``self.grad_clip_val`` (skipped when
        ``self.grad_clip`` is None) and recomputes ``self.min_step`` /
        ``self.max_step`` from ``self.min_step_percent`` and
        ``self.max_step_percent``, each evaluated through ``self.C``.
        """
        clip_spec = self.grad_clip
        if clip_spec is not None:
            self.grad_clip_val = self.C(clip_spec, global_step)

        lower_fraction = self.C(self.min_step_percent, global_step)
        upper_fraction = self.C(self.max_step_percent, global_step)
        self.set_min_max_steps(
            min_step_percent=lower_fraction,
            max_step_percent=upper_fraction,
        )

