import os
import numpy as np
from PIL import Image

from typing import Any, Dict, List, Optional, Tuple
from einops import rearrange, repeat

import torch
import torch.nn.functional as F
import torchvision as tv
from torch import nn
import lightning
from lightning_utilities.core.rank_zero import rank_zero_only
from diffusers import DDIMScheduler, UNet2DConditionModel, AutoencoderKL, StableDiffusionXLPipeline, DDPMScheduler
from ..base import BaseSystem
import wandb

def tokenize_prompt(tokenizer, prompt):
    """Tokenize ``prompt`` and return only the resulting token ids.

    The tokenizer is invoked with padding to ``tokenizer.model_max_length``,
    truncation enabled, and PyTorch tensors requested as the return format.

    Args:
        tokenizer: Callable tokenizer exposing ``model_max_length``.
        prompt: Text (or batch of texts) forwarded unchanged to the tokenizer.

    Returns:
        The ``input_ids`` attribute of the tokenizer output.
    """
    tokenizer_kwargs = {
        "padding": "max_length",
        "max_length": tokenizer.model_max_length,
        "truncation": True,
        "return_tensors": "pt",
    }
    return tokenizer(prompt, **tokenizer_kwargs).input_ids

def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
    """Encode a prompt with every text encoder and merge the results.

    Args:
        text_encoders: Sequence of text encoders. Each is called with token
            ids and ``output_hidden_states=True`` and must expose ``device``.
        tokenizers: Sequence of tokenizers aligned with ``text_encoders``, or
            ``None`` to use pre-tokenized ids from ``text_input_ids_list``.
        prompt: Prompt text; only used when ``tokenizers`` is given.
        text_input_ids_list: Pre-tokenized ids, one entry per encoder.
            Required when ``tokenizers`` is ``None``.

    Returns:
        A ``(prompt_embeds, pooled_prompt_embeds)`` tuple. ``prompt_embeds``
        is the second-to-last hidden state of every encoder concatenated on
        the last dimension. ``pooled_prompt_embeds`` is element 0 of the
        *last* encoder's output, flattened to ``(batch, -1)``.
    """
    per_encoder_states = []

    for idx, encoder in enumerate(text_encoders):
        if tokenizers is None:
            assert text_input_ids_list is not None
            token_ids = text_input_ids_list[idx]
        else:
            token_ids = tokenize_prompt(tokenizers[idx], prompt)

        encoder_out = encoder(
            token_ids.to(encoder.device),
            output_hidden_states=True,
        )

        # Reassigned on every iteration on purpose: only the value produced
        # by the final text encoder is used as the pooled embedding.
        pooled_prompt_embeds = encoder_out[0]

        penultimate = encoder_out.hidden_states[-2]
        bs_embed, seq_len, _ = penultimate.shape
        per_encoder_states.append(penultimate.view(bs_embed, seq_len, -1))

    prompt_embeds = torch.concat(per_encoder_states, dim=-1)
    return prompt_embeds, pooled_prompt_embeds.view(bs_embed, -1)

class SDXLSystem(BaseSystem):
    """Lightning system that fine-tunes the UNet of a Stable Diffusion XL pipeline.

    A ``StableDiffusionXLPipeline`` is loaded, its VAE is replaced by a
    separately loaded ``AutoencoderKL``, the VAE and both text encoders are
    frozen, and the UNet is trained with an MSE loss against the scheduler's
    prediction target. Validation/test steps sample an image with
    classifier-free guidance and write it to ``self.save_dir``.
    """

    def __init__(
        self,
        pretrained_model_name_or_path,
        variant: str = "fp16_ema",
        cfg: float = 0.1,  # classifier free guidance
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        num_val_dataloaders: int = 2,
        num_test_dataloaders: int = 2,
        **kwargs,
    ) -> None:
        """Load the SDXL components and freeze everything except the UNet.

        Args:
            pretrained_model_name_or_path: Identifier or path passed to
                ``StableDiffusionXLPipeline.from_pretrained``.
            variant: Weight variant passed to ``from_pretrained``.
            cfg: Probability of dropping the text condition for a training
                batch (classifier-free guidance training).
            num_inference_steps: Number of sampler steps for image generation.
            guidance_scale: Guidance weight at sampling time; guidance is
                enabled when this is greater than 0.
            num_val_dataloaders: Number of validation dataloaders scanned by
                ``_log_to_wandb``.
            num_test_dataloaders: Number of test dataloaders scanned by
                ``_log_to_wandb``.
            **kwargs: Forwarded to ``BaseSystem.__init__``.
        """
        super().__init__(**kwargs)

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters()

        pipeline = StableDiffusionXLPipeline.from_pretrained(
            pretrained_model_name_or_path,
            use_safetensors=True,
            variant=variant,
        )
        # The pipeline's own VAE is discarded; a separately loaded VAE is
        # used for encoding/decoding instead.
        del pipeline.vae
        pipeline.enable_xformers_memory_efficient_attention()
        self.pipeline = pipeline

        self.vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
        self.text_encoder = pipeline.text_encoder
        self.text_encoder_2 = pipeline.text_encoder_2
        self.unet = pipeline.unet
        # ``scheduler`` drives sampling; ``ddpm_scheduler`` drives training
        # (noise injection and target computation).
        self.scheduler = pipeline.scheduler
        self.ddpm_scheduler = DDPMScheduler.from_pretrained(
            pretrained_model_name_or_path, subfolder="scheduler"
        )

        # Only the UNet is trainable.
        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.text_encoder_2.requires_grad_(False)
        self.unet.requires_grad_(True)

        self.do_classifier_free_guidance = guidance_scale > 0
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale

    def setup(self, stage: str):
        """Select the image logging callback that matches the active logger.

        ``self.log_image`` stays ``None`` when the logger is neither a
        TensorBoard nor a Weights & Biases logger.
        """
        super().setup(stage)
        # use different image log method based on the logger type
        self.log_image = None
        if isinstance(self.logger, lightning.pytorch.loggers.TensorBoardLogger):
            self.log_image = self.tensorboard_log_image
        elif isinstance(self.logger, lightning.pytorch.loggers.WandbLogger):
            self.log_image = self.wandb_log_image

    def configure_optimizers(self):
        """Return an AdamW optimizer over the UNet parameters (fixed lr 1e-5)."""
        return torch.optim.AdamW(self.unet.parameters(), lr=1.0e-5)

    def forward(self, latents, timestep, prompt_embd, meta) -> torch.Tensor:
        """Delegate to ``self.mv_model``.

        NOTE(review): ``mv_model`` is not assigned anywhere in this class; it
        must be provided by ``BaseSystem`` or this call raises
        ``AttributeError`` — TODO confirm. None of the step methods in this
        class call ``forward``.
        """
        return self.mv_model(latents, timestep, prompt_embd, meta)

    def training_step(self, batch, batch_idx):
        """Compute the denoising MSE loss for one batch.

        Reads ``batch["image"]``, ``batch["prompt"]``, ``batch["original_size"]``
        and ``batch["crop_top_left"]``. With probability ``hparams.cfg`` the
        text condition of the whole batch is replaced by zeros.

        Returns:
            The scalar training loss (also logged as ``train_loss``).

        Raises:
            ValueError: If the training scheduler's prediction type is neither
                ``epsilon`` nor ``v_prediction``.
        """
        image = batch["image"]  # cond image (b c h w)
        bsz, _, h, w = image.shape
        model_input = self.vae.encode(image).latent_dist.sample()
        model_input = model_input * self.vae.config.scaling_factor
        device, weight_dtype = model_input.device, model_input.dtype

        # Decide up front whether this batch trains the unconditional branch.
        # The original code tried to zero variables that were never defined
        # here, which raised whenever this branch was taken.
        drop_condition = bool(torch.rand(1) < self.hparams.cfg)

        noise = torch.randn_like(model_input)
        timesteps = torch.randint(
            0, self.ddpm_scheduler.config.num_train_timesteps, (bsz,), device=device
        ).long()
        noisy_model_input = self.ddpm_scheduler.add_noise(model_input, noise, timesteps)

        # Size conditioning, adapted from StableDiffusionXLPipeline._get_add_time_ids.
        # The target size is the same for every sample, so build it once.
        target_size = torch.tensor([h, w], device=device, dtype=weight_dtype)

        def compute_time_ids(original_size, crops_coords_top_left):
            return torch.concat([original_size, crops_coords_top_left, target_size], dim=0)

        # assumes each entry of original_size / crop_top_left is a 1-D tensor
        # on the same device as the latents — TODO confirm against the dataset
        add_time_ids = torch.cat(
            [
                compute_time_ids(size, crop)
                for size, crop in zip(batch["original_size"], batch["crop_top_left"])
            ]
        )

        prompt_embeds, _, pooled_prompt_embeds, _ = self.pipeline.encode_prompt(
            prompt=batch["prompt"],
            device=device,
            do_classifier_free_guidance=False,
        )
        if drop_condition:
            prompt_embeds = torch.zeros_like(prompt_embeds)
            pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)

        unet_added_conditions = {
            "time_ids": add_time_ids,
            "text_embeds": pooled_prompt_embeds,
        }
        model_pred = self.unet(
            noisy_model_input,
            timesteps,
            prompt_embeds,
            added_cond_kwargs=unet_added_conditions,
        ).sample

        # Use the same scheduler that produced the noisy input so the target
        # is consistent with the noising process.
        prediction_type = self.ddpm_scheduler.config.prediction_type
        if prediction_type == "epsilon":
            target = noise
        elif prediction_type == "v_prediction":
            target = self.ddpm_scheduler.get_velocity(model_input, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {prediction_type}")

        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        self.log("train_loss", loss)

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Generate an image for the batch and save it under the validation stage.

        Returns:
            Whatever ``_save_image`` returns (the image path on rank zero).
        """
        images_pred = self._generate_images(batch)
        return self._save_image(
            images_pred,
            batch["prompt"],
            f"{dataloader_idx}_{batch_idx}_{self.global_rank}",
            stage="validation",
        )

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Generate an image for the batch and save it under the test stage.

        Returns:
            Whatever ``_save_image`` returns (the image path on rank zero).
        """
        images_pred = self._generate_images(batch)
        return self._save_image(
            images_pred,
            batch["prompt"],
            f"{dataloader_idx}_{batch_idx}_{self.global_rank}",
            stage="test",
        )

    @torch.no_grad()
    def _generate_images(self, batch, generator=None):
        """Sample images for ``batch["prompt"]`` and return the first one.

        The spatial size of the result is taken from ``batch["image"]``.

        Args:
            batch: Mapping with ``"image"`` (b c h w) and ``"prompt"``.
            generator: Optional ``torch.Generator`` used for the initial
                latents; it must live on the same device as the batch.

        Returns:
            The first decoded image of the batch as a PIL image.
        """
        image, prompt = batch["image"], batch["prompt"]
        bs, _, h, w = image.shape
        device = image.device
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.pipeline.encode_prompt(
            prompt=prompt,
            device=device,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
        )

        self.scheduler.set_timesteps(self.num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        latents = torch.randn(
            bs,
            self.unet.config.in_channels,
            h // 8,
            w // 8,
            generator=generator,
            device=device,
        ) * self.scheduler.init_noise_sigma

        # Condition on the size actually being generated, consistent with the
        # target size used during training.
        original_size = (h, w)
        crops_coords_top_left = (0, 0)
        target_size = (h, w)
        add_text_embeds = pooled_prompt_embeds
        add_time_ids = self.pipeline._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=self.text_encoder_2.config.projection_dim,
        )

        if self.do_classifier_free_guidance:
            # Unconditional half first, conditional half second; the order
            # must match the ``chunk(2)`` split in the sampling loop.
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        # Every row is identical, so a plain repeat yields one row per sample
        # (and per guidance branch).
        add_time_ids = add_time_ids.to(device).repeat(bs, 1)
        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

        for t in timesteps:
            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=None,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        # Only the first image of the batch is returned.
        return self.pipeline.image_processor.postprocess(image, output_type="pil")[0]

    @torch.no_grad()
    @rank_zero_only
    def _save_image(self, images_pred, prompt, batch_idx, stage="validation"):
        """Write the prompts to a ``.txt`` file and the image to a ``.png`` file.

        Runs on rank zero only. Both files are placed in ``self.save_dir`` and
        named ``{stage}_{global_step}_{batch_idx}``; the image name is
        additionally suffixed with the first prompt (spaces and slashes
        replaced by underscores).

        Args:
            images_pred: Object with a ``save(path)`` method (a PIL image).
            prompt: Sequence of prompt strings for the batch.
            batch_idx: Identifier embedded in the file names.
            stage: Stage prefix for the file names.

        Returns:
            Path of the saved image.
        """
        save_dir = self.save_dir
        stem = f"{stage}_{self.global_step}_{batch_idx}"

        # Explicit encoding so non-ASCII prompts are written the same way on
        # every platform.
        with open(os.path.join(save_dir, f"{stem}.txt"), "w", encoding="utf-8") as f:
            f.write("\n".join(prompt))

        im_fp = os.path.join(
            save_dir,
            f"{stem}--{prompt[0].replace(' ', '_').replace('/', '_')}.png",
        )
        images_pred.save(im_fp)
        return im_fp

    @torch.no_grad()
    @rank_zero_only
    def _log_to_wandb(self, stage):
        """Log the images saved at the current global step to Weights & Biases.

        Runs on rank zero only. For every dataloader index, the ``.png`` files
        in ``self.save_dir`` belonging to ``stage``, the current global step
        and that dataloader are logged under the key ``{stage}_{index}``, with
        the file name as caption.
        """
        num_dataloaders = (
            self.hparams.num_val_dataloaders
            if stage == "validation"
            else self.hparams.num_test_dataloaders
        )
        # Sorted for a deterministic logging order; os.listdir order is arbitrary.
        file_names = sorted(os.listdir(self.save_dir))
        for i in range(num_dataloaders):
            # The trailing underscore keeps dataloader 1 from also matching
            # the files of dataloaders 10, 11, ...
            prefix = f"{stage}_{self.global_step}_{i}_"
            captions = [
                name
                for name in file_names
                if name.startswith(prefix) and name.endswith(".png")
            ]

            self.logger.experiment.log(
                {
                    f"{stage}_{i}": [
                        wandb.Image(os.path.join(self.save_dir, caption), caption=caption)
                        for caption in captions
                    ]
                },
                step=self.global_step,
            )

    def tensorboard_log_image(self, tag: str, image_tensor):
        """Log ``image_tensor`` under ``tag`` via the TensorBoard experiment."""
        self.logger.experiment.add_image(
            tag,
            image_tensor,
            self.trainer.global_step,
        )

    def wandb_log_image(self, tag: str, image_tensor):
        """Log ``image_tensor`` under ``tag`` via the Weights & Biases experiment."""
        image_dict = {
            tag: wandb.Image(image_tensor),
        }
        self.logger.experiment.log(
            image_dict,
            step=self.trainer.global_step,
        )