import os
import numpy as np
from PIL import Image

from typing import Any, Dict, List, Optional, Tuple
from einops import rearrange, repeat

import torch
import torch.nn.functional as F
import torchvision as tv
from torch import nn
import lightning
from lightning_utilities.core.rank_zero import rank_zero_only
from diffusers import DDIMScheduler, UNet2DConditionModel, AutoencoderKL, StableDiffusionXLPipeline, DDPMScheduler
from transformers import CLIPTokenizer, CLIPTextModel
from ..base import BaseSystem
import wandb
from src.models.unet.hyperlora import HyperLordModel

class SDSystem(BaseSystem):
    def __init__(
        self,
        pretrained_model_name_or_path,
        variant: str = "fp16_ema",
        cfg: float = 0.1,  # probability of training a batch on empty prompts (see training_step)
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        num_val_dataloaders: int = 2,
        num_test_dataloaders: int = 2,
        **kwargs,
    ) -> None:
        """Load the pretrained Stable Diffusion components and wrap the UNet.

        Args:
            pretrained_model_name_or_path: Passed to every ``from_pretrained``
                call; each component is read from its own subfolder.
            variant: Weight variant forwarded to the UNet loader only.
            cfg: Threshold compared against ``torch.rand(1)`` in
                ``training_step``; when the draw is below it, the batch is
                trained with empty prompts.
            num_inference_steps: Number of scheduler steps used when sampling.
            guidance_scale: Guidance weight used when sampling. Guidance is
                enabled whenever this is greater than 0.
            num_val_dataloaders: Number of validation dataloaders iterated by
                ``_log_to_wandb``.
            num_test_dataloaders: Number of test dataloaders iterated by
                ``_log_to_wandb``.
            **kwargs: Forwarded unchanged to ``BaseSystem.__init__``.
        """
        super().__init__(**kwargs)

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters()
        self.vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
        self.scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
        unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", variant=variant)
        # NOTE(review): torch_dtype is passed to the tokenizer as well; it looks
        # unnecessary there — confirm it is harmless for the tokenizer in use.
        self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer", torch_dtype=torch.float16)
        # The text encoder is loaded in float16; encode_text casts its hidden
        # states back to float32 before returning them.
        self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch.float16)
        self.do_classifier_free_guidance = guidance_scale > 0
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        # Freeze every pretrained component; the optimizer only receives
        # self.model.trainable_parameters (see configure_optimizers).
        self.vae.requires_grad_(False)
        unet.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.model = HyperLordModel(unet)

    @torch.no_grad()
    def encode_text(self, text, device):
        text_inputs = self.tokenizer(
            text, padding="max_length", max_length=self.tokenizer.model_max_length,
            truncation=True, return_tensors="pt"
        )
        text_input_ids = text_inputs.input_ids
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.cuda()
        else:
            attention_mask = None
        prompt_embeds = self.text_encoder(
            text_input_ids.to(device), attention_mask=attention_mask)
        return prompt_embeds[0].float(), prompt_embeds[1]
    
    def setup(self, stage: str):
        """Run the base setup, then pick the image-logging callback.

        ``self.log_image`` becomes ``tensorboard_log_image`` for a
        TensorBoard logger, ``wandb_log_image`` for a W&B logger, and
        ``None`` for anything else.
        """
        super().setup(stage)

        self.log_image = None
        active_logger = self.logger
        if isinstance(active_logger, lightning.pytorch.loggers.TensorBoardLogger):
            self.log_image = self.tensorboard_log_image
            return
        if isinstance(active_logger, lightning.pytorch.loggers.WandbLogger):
            self.log_image = self.wandb_log_image

    def configure_optimizers(self):
        """Configure optimizers and learning rate schedulers for training."""
        # param_groups = []
        # for params, lr_scale in self.trainable_parameters:
        #     param_groups.append({"params": params, "lr": self.hparams.lr * lr_scale})

        # optimizer = torch.optim.AdamW(param_groups)
        optimizer = torch.optim.AdamW(self.model.trainable_parameters, lr=1.0e-5)
        return optimizer

    def forward(self, latents, timestep, prompt_embd, meta) -> torch.Tensor:
        """Pass all four arguments, unchanged, to ``self.mv_model``.

        NOTE(review): ``mv_model`` is never assigned in this class —
        ``__init__`` only sets ``self.model``, and the training/sampling
        code calls ``self.model`` directly. Unless ``BaseSystem`` provides
        ``mv_model``, calling this method raises ``AttributeError``.
        Confirm whether it should delegate to ``self.model`` instead.
        """
        return self.mv_model(latents, timestep, prompt_embd, meta)

    def encode_prompt(self, prompt, device, do_classifier_free_guidance=False, do_neg_prompt=False):
        prompt_embds = []
        for p in prompt:
            if do_neg_prompt:
                prompt_embed = self.encode_text("", device)[0]
            else:
                prompt_embed = self.encode_text(p, device)[0]

            if do_classifier_free_guidance:
                negative_prompt_embed = self.encode_text("", device)[0]
                prompt_embed = torch.cat([negative_prompt_embed, prompt_embed], dim=0)
            prompt_embds.append(prompt_embed)

        prompt_embeds = torch.concat(prompt_embds, dim=0)
        return prompt_embeds

    def training_step(self, batch, batch_idx):
        image = batch["image"]  # cond image (b c h w)
        bsz, c, h ,w = image.shape
        model_input = self.vae.encode(image).latent_dist.sample()
        model_input = model_input * self.vae.config.scaling_factor
        device, weight_dtype = model_input.device, model_input.dtype

        noise = torch.randn_like(model_input)
        timesteps = torch.randint(
            0, self.scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
        ).long()
        noisy_model_input = self.scheduler.add_noise(model_input, noise, timesteps)
        
        prompt_embeds = self.encode_prompt(batch['prompt'],device=device)
        if torch.rand(1) < self.hparams.cfg:
            prompt_embeds = self.encode_prompt(batch['prompt'],device=device, do_neg_prompt=True)

        model_pred = self.model(noisy_model_input, timesteps, prompt_embeds)

        if self.scheduler.config.prediction_type == "epsilon":
            target = noise
        elif self.scheduler.config.prediction_type == "v_prediction":
            target = self.scheduler.get_velocity(model_input, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {self.scheduler.config.prediction_type}")

        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        images_pred = self._generate_images(batch)
        # compute image & save
        image_fp = self._save_image(
            images_pred,
            batch["prompt"],
            f"{dataloader_idx}_{batch_idx}_{self.global_rank}",
            stage="validation",
        )
        return image_fp


    def test_step(self, batch, batch_idx, dataloader_idx=0):
        images_pred = self._generate_images(batch)
        # save images
        image_fp = self._save_image(
            images_pred,
            batch["prompt"],
            f"{dataloader_idx}_{batch_idx}_{self.global_rank}",
            stage="test",
        )
        return image_fp

    @torch.no_grad()
    def _generate_images(self, batch, generator=None):
        image, prompt = batch["image"], batch["prompt"]
        bs, c, h, w = image.shape
        device = image.device

        prompt_embeds = self.encode_prompt(prompt=batch['prompt'],device=device, do_classifier_free_guidance=self.do_classifier_free_guidance)
        self.scheduler.set_timesteps(self.num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        latents = torch.randn(bs, 4, h // 8, w // 8, device=device) * self.scheduler.init_noise_sigma
        prompt_embeds = prompt_embeds.to(device)

        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.model(latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                return_dict=False,
            )

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        return image

    @torch.no_grad()
    @rank_zero_only
    def _save_image(self, images_pred, prompt, batch_idx, stage="validation"):
        """Write a batch's prompts and generated images into ``self.save_dir``.

        Two files share the stem ``{stage}_{global_step}_{batch_idx}``: a
        ``.txt`` file with one prompt per line, and a ``.png`` whose name also
        carries the first prompt (spaces and slashes replaced by underscores).

        Args:
            images_pred: Image tensor handed to ``torchvision.utils.save_image``.
            prompt: Non-empty sequence of prompt strings.
            batch_idx: Tag identifying the batch inside the file names.
            stage: File-name prefix, e.g. ``"validation"`` or ``"test"``.

        Returns:
            Path of the written PNG. The method is wrapped in
            ``rank_zero_only``, so callers on other ranks should not rely on
            the return value.
        """
        save_dir = self.save_dir
        stem = f"{stage}_{self.global_step}_{batch_idx}"

        # Prompts may contain non-ASCII text; pin the encoding so the file
        # does not depend on the platform's locale default.
        with open(os.path.join(save_dir, f"{stem}.txt"), "w", encoding="utf-8") as f:
            f.write("\n".join(prompt))

        im_fp = os.path.join(
            save_dir,
            f"{stem}--{prompt[0].replace(' ', '_').replace('/', '_')}.png",
        )
        tv.utils.save_image(images_pred, im_fp)
        return im_fp

    @torch.no_grad()
    @rank_zero_only
    def _log_to_wandb(self, stage):
        """Upload the PNGs saved for the current global step to W&B.

        For each dataloader index ``i``, every ``.png`` in ``self.save_dir``
        whose name starts with ``{stage}_{global_step}_{i}`` is logged under
        the key ``{stage}_{i}``, captioned with its file name.

        Args:
            stage: ``"validation"`` selects ``num_val_dataloaders``; any other
                value selects ``num_test_dataloaders``.
        """
        import wandb

        num_dataloaders = (
            self.hparams.num_val_dataloaders
            if stage == "validation"
            else self.hparams.num_test_dataloaders
        )

        # List the directory once (not once per dataloader) and sort so the
        # logged order is deterministic.
        png_names = sorted(
            name for name in os.listdir(self.save_dir) if name.endswith(".png")
        )

        for i in range(num_dataloaders):
            prefix = f"{stage}_{self.global_step}_{i}"
            # Require a delimiter after the index so that index 1 does not
            # also pick up the files of dataloaders 10, 11, ...
            matched = [
                name
                for name in png_names
                if name.startswith((f"{prefix}_", f"{prefix}--"))
            ]
            self.logger.experiment.log(
                {
                    f"{stage}_{i}": [
                        wandb.Image(os.path.join(self.save_dir, name), caption=name)
                        for name in matched
                    ]
                },
                step=self.global_step,
            )
    def tensorboard_log_image(self, tag: str, image_tensor):
            self.logger.experiment.add_image(
                tag,
                image_tensor,
                self.trainer.global_step,
            )

    def wandb_log_image(self, tag: str, image_tensor):
        """Log ``image_tensor`` as a ``wandb.Image`` under ``tag``.

        The trainer's current global step is used as the step index.
        """
        payload = {tag: wandb.Image(image_tensor)}
        self.logger.experiment.log(payload, step=self.trainer.global_step)