import os
from lightning.pytorch.utilities.types import STEP_OUTPUT
import numpy as np
from PIL import Image

from typing import Any, Dict, List, Optional, Tuple
from einops import rearrange, repeat

import torch
import torch.nn.functional as F
import torchvision as tv
from torch import nn
from tqdm import tqdm
from lightning_utilities.core.rank_zero import rank_zero_only
from diffusers import DDIMScheduler, UNet2DModel, DDIMPipeline
from src.models.ema import EMA
from ..base import BaseSystem
from ...utils.visual import concatenate_images
from torchmetrics.image.fid import FrechetInceptionDistance
from ...models.network.unet import DeepSupervisonUNet2DModel

class DiffusionSystem(BaseSystem):
    """Lightning system that trains a deep-supervision UNet as a DDIM diffusion model.

    * Training: one random timestep in ``[0, num_train_timesteps)`` is drawn per
      image. The batch is noised with ``DDIMScheduler.add_noise``, and every
      output of the UNet is regressed (MSE) against the target implied by
      ``mode``. The logged and returned loss is the mean over outputs.
    * Validation: images are sampled from Gaussian noise through ``self.ema``,
      saved as horizontal PNG strips and accumulated into an FID metric. The
      metric's real statistics are loaded once from :meth:`setup`.
    * Test: images are sampled and saved the same way (no FID).
    """

    def __init__(
        self,
        diffusion_timesteps = 50,
        lr = 2.0e-4,
        mode = "epsilon",
        ema_decay=0.9999,
        ema_update_every=10,
        **kwargs,
    ) -> None:
        """Build the scheduler, UNet, EMA wrapper and FID metric.

        Args:
            diffusion_timesteps: number of DDIM steps used when sampling images.
            lr: AdamW learning rate for the UNet parameters.
            mode: DDIM ``prediction_type`` and training target. ``"epsilon"``
                predicts the noise, ``"v_prediction"`` predicts the velocity, and
                any other value (e.g. ``"sample"``) predicts the clean image.
            ema_decay: ``beta`` passed to ``EMA``.
            ema_update_every: ``update_every`` passed to ``EMA``.
            **kwargs: forwarded unchanged to ``BaseSystem``.
        """
        super().__init__(**kwargs)

        # Makes init params available as ``self.hparams`` and stores them in
        # checkpoints (but not in the logger).
        self.save_hyperparameters(logger=False, ignore=["diffusion"])
        self.scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.0001,
            beta_end=0.02,
            beta_schedule="linear",
            prediction_type=mode,
        )
        # Only the architecture comes from the hub config; ``from_config`` does not
        # load weights. To start from pretrained weights instead:
        #   self.unet = DDIMPipeline.from_pretrained("pretrain/cifar10").unet
        self.unet = DeepSupervisonUNet2DModel.from_config("google/ddpm-cifar10-32")
        self.ema = EMA(self.unet, beta=ema_decay, update_every=ema_update_every)
        self.lr = lr
        # Real-image statistics survive ``fid.reset()`` so they only have to be
        # loaded once per process.
        self.fid = FrechetInceptionDistance(reset_real_features=False)
        self._fid_real_loaded = False

    @rank_zero_only
    def add_real_data(self):
        """Feed the precomputed real CIFAR-10 data into the FID metric (rank 0 only).

        Later calls are no-ops. The metric keeps its real statistics
        (``reset_real_features=False``), so reloading would count the same
        samples again every time ``setup`` runs for a new stage.

        NOTE(review): assumes ``data/cifar10/real_features.pt`` holds a tensor that
        ``FrechetInceptionDistance.update(..., real=True)`` accepts (by default
        uint8 images of shape (N, 3, H, W)). TODO confirm.
        """
        if self._fid_real_loaded:
            return
        print("loading real data")
        real_data = torch.load("data/cifar10/real_features.pt", map_location=self.device)
        self.fid.to(self.device)
        step = 500
        # Stepping by ``step`` covers a trailing partial chunk too, so no real
        # samples are dropped when N is not a multiple of ``step``.
        for start in tqdm(range(0, real_data.shape[0], step)):
            self.fid.update(real_data[start:start + step], real=True)
        self._fid_real_loaded = True
        print("load fid real data finish!")

    def setup(self, stage: str) -> None:
        """Run ``BaseSystem.setup`` and make sure FID real statistics are loaded."""
        super().setup(stage)
        self.add_real_data()

    def configure_optimizers(self):
        """Return an AdamW optimizer over the UNet parameters (no LR scheduler)."""
        param_groups = [{"params": self.unet.parameters(), "lr": self.lr}]
        return torch.optim.AdamW(param_groups)

    def training_step(self, batch, batch_idx):
        """Return the mean MSE over all UNet outputs for one batch of clean images.

        Each output's loss is logged as ``train_loss_{idx}``. The mean is logged
        as ``train_loss`` and returned.
        """
        images = batch["images"]  # clean images (b, c, h, w) that get noised below
        t = torch.randint(
            0,
            self.scheduler.config.num_train_timesteps,
            (images.shape[0],),
            device=images.device,
        ).long()
        noise = torch.randn_like(images)
        noisy_images = self.scheduler.add_noise(images, noise, t)
        # The UNet returns a sequence of predictions; every one is supervised.
        predictions = self.unet(noisy_images, t)

        # The target must match the prediction type the scheduler's ``step``
        # interprets the model output as during sampling.
        if self.hparams.mode == "epsilon":
            target = noise
        elif self.hparams.mode == "v_prediction":
            target = self.scheduler.get_velocity(images, noise, t)
        else:
            target = images

        losses = []
        for idx, prediction in enumerate(predictions):
            loss = F.mse_loss(prediction, target)
            self.log("train_loss_{}".format(idx), loss)
            losses.append(loss)
        total_loss = sum(losses) / len(losses)
        self.log("train_loss", total_loss, prog_bar=True)
        return total_loss

    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
        """Call ``EMA.update()`` after every training batch.

        Any ``update_every`` throttling is presumably handled inside ``EMA``.
        """
        self.ema.update()

    def on_validation_epoch_start(self) -> None:
        return super().on_validation_epoch_start()

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Sample a batch, add it to the fake FID statistics and save it as a PNG strip.

        Returns the saved PIL image strip.
        """
        images_pred = self._to_uint8(self._generate_images(batch))
        self.fid.update(images_pred, real=False)
        strip = self._to_strip(images_pred)
        self._save_image(
            strip,
            "",
            f"{dataloader_idx}_{batch_idx}_{self.global_rank}",
            stage="validation",
        )
        return strip

    def on_validation_epoch_end(self) -> None:
        """Compute and log FID, then clear the fake statistics (real ones are kept)."""
        fid = self.fid.compute()
        self.fid.reset()
        self.log("fid", fid, prog_bar=True)

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Sample a batch, save it as a PNG strip and return the file path.

        On non-zero ranks the return value is ``None``, because ``_save_image``
        is ``rank_zero_only``.
        """
        strip = self._to_strip(self._to_uint8(self._generate_images(batch)))
        return self._save_image(strip, "", batch_idx, stage="test")

    def on_test_epoch_end(self) -> None:
        """Lightning hook that is called when a test epoch ends."""
        # log images
        if "wandb" in str(self.logger):
            self._log_to_wandb("test")

    @torch.no_grad()
    def _generate_images(self, batch):
        """Sample images from Gaussian noise with ``diffusion_timesteps`` DDIM steps.

        ``batch["images"]`` only provides the output shape (b, c, h, w) and the
        device. The EMA model is queried at every step and only its last output
        is used. Callers treat the returned latents as images in [-1, 1].
        """
        images = batch["images"]
        bs, c, h, w = images.shape
        device = images.device
        latents = torch.randn(bs, c, h, w, device=device)
        # NOTE: sets inference timesteps on the scheduler shared with training.
        self.scheduler.set_timesteps(self.hparams.diffusion_timesteps, device=device)
        for t in self.scheduler.timesteps:
            model_output = self.ema(latents, t)[-1]
            latents = self.scheduler.step(model_output, t, latents).prev_sample
        return latents

    @staticmethod
    def _to_uint8(images):
        """Map samples from [-1, 1] to a uint8 tensor in [0, 255].

        Values are clamped first, so out-of-range samples saturate instead of
        wrapping around during the integer cast.
        """
        return ((images / 2 + 0.5).clamp(0, 1) * 255).to(torch.uint8)

    @staticmethod
    def _to_strip(images):
        """Tile a uint8 (b, c, h, w) tensor horizontally into one (h, b*w) PIL image."""
        return Image.fromarray(rearrange(images.cpu().numpy(), "b c h w -> h (b w) c"))

    @torch.no_grad()
    @rank_zero_only
    def _save_image(self, im, prompt, batch_idx, stage="validation"):
        """Save ``im`` to ``save_dir/stage/{global_step}_{batch_idx}.png`` (rank 0 only).

        For validation with a TensorBoard logger, the image is also added to the
        logger. ``prompt`` is unused. Returns the file path.
        """
        save_dir = os.path.join(self.save_dir, stage)
        os.makedirs(save_dir, exist_ok=True)
        im_fp = os.path.join(
            save_dir,
            f"{self.global_step}_{batch_idx}.png",
        )
        im.save(im_fp)
        # add image to logger
        if "tensorboard" in str(self.logger) and stage == "validation":
            log_image = torch.tensor(np.array(im) / 255.0).permute(2, 0, 1).float().cpu()
            self.logger.experiment.add_image(
                f"{stage}/{self.global_step}_{batch_idx}",
                log_image,
                global_step=self.global_step,
            )

        return im_fp

    @torch.no_grad()
    @rank_zero_only
    def _log_to_wandb(self, stage, output_images_fp: Optional[List[Any]] = None):
        """Upload PNG images to the wandb run under the key ``stage`` (rank 0 only).

        Args:
            stage: wandb key, and also the sub-directory of ``self.save_dir``
                that ``_save_image`` writes into.
            output_images_fp: explicit image paths. When ``None``, every
                ``{global_step}_*.png`` in ``save_dir/stage`` is uploaded, which
                matches the names ``_save_image`` uses for the current step.
        """
        import wandb

        if output_images_fp is None:
            image_dir = os.path.join(self.save_dir, stage)
            prefix = f"{self.global_step}_"
            captions = []
            if os.path.isdir(image_dir):
                captions = sorted(
                    f for f in os.listdir(image_dir)
                    if f.startswith(prefix) and f.endswith(".png")
                )
            images = [os.path.join(image_dir, f) for f in captions]
        else:
            images = list(output_images_fp)
            captions = [os.path.basename(fp) for fp in images]

        if not images:
            return

        self.logger.experiment.log(
            {
                stage: [
                    wandb.Image(im_fp, caption=caption)
                    for im_fp, caption in zip(images, captions)
                ]
            },
            step=self.global_step,
        )
