import os
import numpy as np
from PIL import Image
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple
from einops import rearrange, repeat

import torch
import torch.nn.functional as F
import torchvision 
from torch import nn
import lightning
from lightning import LightningModule
import wandb
from lightning_utilities.core.rank_zero import rank_zero_only
from torchmetrics import MaxMetric, MeanMetric
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers import EulerDiscreteScheduler, UNetSpatioTemporalConditionModel, AutoencoderKLTemporalDecoder, StableVideoDiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import _resize_with_antialiasing
from src.models.ema import LitEma
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
from ..base import BaseSystem
from safetensors import safe_open
from torch.cuda.amp import autocast

from geomloss import SamplesLoss

def _get_add_time_ids(
    unet,
    fps,
    motion_bucket_id,
    noise_aug_strength,
    dtype,
    batch_size,
):
    """Build the SVD "added time ids" conditioning tensor.

    Args:
        unet: UNet whose ``config.addition_time_embed_dim`` and
            ``add_embedding.linear_1.in_features`` are checked for consistency.
        fps: frames-per-second id.
        motion_bucket_id: motion bucket id.
        noise_aug_strength: noise augmentation strength.
        dtype: dtype of the returned tensor.
        batch_size: number of rows to return.

    Returns:
        A ``batch_size x 3`` tensor whose every row is
        ``[fps, motion_bucket_id, noise_aug_strength]``.

    Raises:
        ValueError: if ``3 * unet.config.addition_time_embed_dim`` differs from
            ``unet.add_embedding.linear_1.in_features``.
    """
    add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
    per_id_dim = unet.config.addition_time_embed_dim
    passed_add_embed_dim = per_id_dim * len(add_time_ids)
    expected_add_embed_dim = unet.add_embedding.linear_1.in_features

    if expected_add_embed_dim != passed_add_embed_dim:
        # Name the two config values that are actually compared above.
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
            f"but a vector of {passed_add_embed_dim} was created "
            f"({len(add_time_ids)} ids x unet.config.addition_time_embed_dim={per_id_dim}). "
            "The model has an incorrect config. Please check "
            "`unet.config.addition_time_embed_dim` and `unet.add_embedding.linear_1.in_features`."
        )
    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
    return add_time_ids.repeat(batch_size, 1)

def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    """Sample a log-normal tensor by inverse-CDF sampling.

    Uniform draws are squeezed into ``[1e-7, 1 - 1e-7)`` so the Normal inverse
    CDF never receives exactly 0 or 1 (which would map to -inf / +inf), and the
    resulting normal samples are exponentiated.

    Args:
        shape: output shape.
        loc: mean of the underlying normal distribution.
        scale: standard deviation of the underlying normal distribution.
        device: device to draw the uniform samples on.
        dtype: dtype of the uniform samples.

    Returns:
        A tensor of ``shape`` holding log-normal samples.
    """
    uniform = torch.rand(shape, dtype=dtype, device=device)
    clipped = uniform * (1 - 2e-7) + 1e-7
    gaussian = torch.distributions.Normal(loc, scale)
    return torch.exp(gaussian.icdf(clipped))


class SVDSystem(BaseSystem):
    """Lightning system that trains a reconstruction model on SVD VAE latents.

    Images from ``batch["diffusion_images"]`` (``b x m x c x h x w`` in [-1, 1])
    are encoded with the Stable Video Diffusion VAE and passed, together with the
    batch and a noise-level timestep, to ``recon_model``. Its ``images_pred``
    output is supervised with L1 + Sinkhorn + LPIPS losses. Validation/test run
    the same model at the scheduler's final timestep and report PSNR/SSIM/LPIPS.
    """

    def __init__(
        self,
        lr: float,
        mv_model: torch.nn.Module,
        recon_model: torch.nn.Module,
        base_model_id: str = "stabilityai/stable-video-diffusion-img2vid",
        variant: str = "fp16",
        cfg: float = 0.1,
        report_to: str = "wandb",
        ema_decay_rate: float = 0.9999,
        compile: bool = False,
        use_ema: bool = False,
        recon_ckpt_path: str = "pretrain/LGM/model_fp16.safetensors",
        **kwargs,
    ) -> None:
        """Build the SVD pipeline components, the models, losses and metrics.

        Args:
            lr: base learning rate (multiplied by each parameter group's scale).
            mv_model: callable invoked with the SVD UNet; its result is stored as
                ``self.mv_model``.
            recon_model: reconstruction module; must expose ``output_channels``.
                Initialised (non-strictly) from ``recon_ckpt_path``.
            base_model_id: Hugging Face id of the SVD pipeline.
            variant: weight variant passed to ``from_pretrained``.
            cfg: stored as ``self.conditioning_dropout_prob``.
            report_to: not read in this class (kept in ``hparams``).
            ema_decay_rate: decay for ``LitEma`` when ``use_ema`` is set.
            compile: wrap ``mv_model`` in ``torch.compile`` during ``fit`` setup.
            use_ema: keep an EMA copy of ``mv_model``'s weights.
            recon_ckpt_path: LGM safetensors checkpoint used to initialise
                ``recon_model``.
            **kwargs: forwarded to ``BaseSystem``.
        """
        super().__init__(**kwargs)

        # Expose init params as ``self.hparams`` and store them in checkpoints;
        # the two model arguments are excluded.
        self.save_hyperparameters(logger=False, ignore=["mv_model", "recon_model"])

        self.pipeline = StableVideoDiffusionPipeline.from_pretrained(
            base_model_id, subfolder="pipeline", variant=variant
        )
        self.scheduler = self.pipeline.scheduler
        self.image_encoder = self.pipeline.image_encoder
        self.vae = self.pipeline.vae
        self.unet = self.pipeline.unet
        self.feature_extractor = self.pipeline.feature_extractor
        self.image_processor = self.pipeline.image_processor
        # The CLIP image encoder and the VAE are frozen; the UNet is trainable.
        self.image_encoder.requires_grad_(False)
        self.vae.requires_grad_(False)
        self.unet.requires_grad_(True)

        self.mv_model = mv_model(self.pipeline.unet)
        self.recon_model = recon_model

        # Non-strict load: adapted/dropped layers and keys absent from the
        # checkpoint keep their freshly initialised values.
        state_dict = self._load_lgm_state_dict(
            recon_ckpt_path, self.recon_model.output_channels
        )
        self.recon_model.load_state_dict(state_dict, strict=False)

        # Metric objects for calculating and averaging quality across batches.
        self.psnr = PeakSignalNoiseRatio(data_range=1.0)
        self.ssim = StructuralSimilarityIndexMeasure()
        self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True)
        self.emd = SamplesLoss("sinkhorn", blur=0.01)

        self.use_ema = use_ema
        if use_ema:
            self.model_ema = LitEma(self.mv_model, decay=ema_decay_rate)

        # NOTE(review): ``Module.parameters()`` returns one-shot generators, so
        # configure_optimizers() can consume these only once — confirm it is not
        # called repeatedly (e.g. by an LR finder).
        self.trainable_parameters = [
            (self.mv_model.parameters(), 1.0),
            (self.recon_model.parameters(), 1.0),
        ]

        self.num_inference_steps = 25
        self.min_guidance_scale = 1.0
        self.max_guidance_scale = 3.0
        self.conditioning_dropout_prob = cfg

    @staticmethod
    def _load_lgm_state_dict(model_path, output_channels):
        """Read an LGM safetensors checkpoint and adapt it to ``recon_model``.

        * ``unet.conv_in.weight`` is rebuilt with 10 input channels: channels 0-3
          are zero and channels 4+ take the checkpoint's input channels 3+ (the
          checkpoint's first 3 input channels are dropped).
        * When ``output_channels == 12``, ``unet.conv_out.weight`` is rebuilt with
          12 output channels (channel 0 zero, channels 1+ from the checkpoint's
          output channels 3+), and ``unet.conv_out.bias``, ``conv.weight`` and
          ``conv.bias`` are omitted.

        Returns:
            Dict mapping parameter names to CPU tensors.
        """
        omitted_for_12 = ("conv.weight", "conv.bias", "unet.conv_out.bias")
        tensors = {}
        with safe_open(model_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                if output_channels == 12 and key in omitted_for_12:
                    continue
                tensor = f.get_tensor(key)
                if key == "unet.conv_in.weight":
                    # Zero-initialised leading channels (presumably the 4 VAE latent
                    # channels — confirm), then the checkpoint's ray channels.
                    shape = list(tensor.shape)
                    shape[1] = 4 + 6
                    new_weight = torch.zeros(shape)
                    new_weight[:, 4:, :, :] = tensor[:, 3:, :, :]
                    tensor = new_weight
                elif output_channels == 12 and key == "unet.conv_out.weight":
                    shape = list(tensor.shape)
                    shape[0] = 1 + 11
                    new_weight = torch.zeros(shape)
                    new_weight[1:, :, :, :] = tensor[3:, :, :, :]
                    tensor = new_weight
                tensors[key] = tensor
        return tensors

    def _ema_target(self):
        """Return the module tracked by ``model_ema``, unwrapping ``torch.compile``."""
        # torch.compile wraps the module in an OptimizedModule whose parameter names
        # gain an ``_orig_mod.`` prefix; LitEma was built from the unwrapped module.
        return getattr(self.mv_model, "_orig_mod", self.mv_model)

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily load EMA weights into ``mv_model``; restore them on exit.

        A no-op when ``use_ema`` is False. ``context`` (if given) is used as a
        prefix for the switch/restore messages.
        """
        if self.use_ema:
            target = self._ema_target()
            self.model_ema.store(target.parameters())
            self.model_ema.copy_to(target)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self._ema_target().parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def setup(self, stage: str) -> None:
        """Optionally compile ``mv_model`` and pick the image-logging backend."""
        super().setup(stage)
        if self.hparams.compile and stage == "fit":
            self.mv_model = torch.compile(self.mv_model)

        # Stays None for loggers other than TensorBoard / W&B; callers check it.
        self.log_image = None
        if isinstance(self.logger, lightning.pytorch.loggers.TensorBoardLogger):
            self.log_image = self.tensorboard_log_image
        elif isinstance(self.logger, lightning.pytorch.loggers.WandbLogger):
            self.log_image = self.wandb_log_image
            self.logger.watch(self)

    def configure_optimizers(self):
        """Configure an AdamW optimizer with one group per trainable-parameter set."""
        param_groups = [
            {"params": params, "lr": self.hparams.lr * lr_scale}
            for params, lr_scale in self.trainable_parameters
        ]
        return torch.optim.AdamW(param_groups)

    def forward(self, latents, timestep, encoder_hidden_states, added_time_ids, cond):
        """Delegate to ``mv_model``."""
        return self.mv_model(latents, timestep, encoder_hidden_states, added_time_ids, cond)

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA shadow weights after each optimisation step."""
        if self.use_ema:
            self.model_ema(self._ema_target())

    def _emd_loss(self, pred, gt_images, target_resolution):
        """Sinkhorn (approximate EMD) loss between predicted and target images.

        Each image is resized to ``target_resolution`` x ``target_resolution`` and
        turned into a point cloud of ``(r, g, b, x, y)`` samples. Colours are
        mapped from [0, 1] to [-1, 1] (identically for both inputs) and ``(x, y)``
        is a fixed pixel grid over [-1, 1].

        Args:
            pred: ``b x m x 3 x h x w`` predicted images, expected in [0, 1].
            gt_images: ``b x m x 3 x h x w`` target images in [0, 1].
            target_resolution: side length of the square resize.

        Returns:
            Scalar mean Sinkhorn loss over all ``b * m`` images.
        """
        pred = rearrange(pred, "b m c h w -> (b m) c h w")
        gt_images = rearrange(gt_images, "b m c h w -> (b m) c h w")
        pred = F.interpolate(pred, size=target_resolution, mode="bilinear", align_corners=False)
        gt_images = F.interpolate(gt_images, size=target_resolution, mode="bilinear", align_corners=False)

        coords = torch.linspace(-1, 1, target_resolution, device=pred.device)
        xy = torch.stack(torch.meshgrid(coords, coords, indexing="ij"), dim=0)  # 2 x r x r
        xy = xy.unsqueeze(0).expand(pred.shape[0], -1, -1, -1)

        def to_point_cloud(images):
            # Both clouds get the same colour mapping, so identical images give ~0 loss.
            cloud = torch.cat([images.float() * 2 - 1, xy], dim=1)  # (b m) x 5 x r x r
            return rearrange(cloud, "bm c h w -> bm (h w) c").contiguous()

        return self.emd(to_point_cloud(gt_images), to_point_cloud(pred)).mean()

    def training_step(self, batch, batch_idx):
        """Compute the reconstruction loss (L1 + Sinkhorn + 0.2 * LPIPS)."""
        diffusion_images = batch["diffusion_images"]  # b x m x c x h x w, in [-1, 1]
        bsz, _, _, h, w = diffusion_images.shape
        dtype = diffusion_images.dtype
        latents = self.tensor_to_vae_latent(diffusion_images, self.vae)

        sigmas = rand_log_normal(shape=[bsz, 1, 1, 1, 1], loc=0.7, scale=1.6).to(latents.device)
        # Per-sample timestep value 0.25 * log(sigma).
        timesteps = torch.Tensor(
                    [0.25 * sigma.log() for sigma in sigmas]).to(latents.device)

        render_x1 = self.recon_model(latents, batch, timesteps)
        images_pred = render_x1["images_pred"]

        recon_images = diffusion_images * 0.5 + 0.5  # targets in [0, 1]
        recon_loss = torch.abs(images_pred - recon_images).mean()
        recon_loss = recon_loss + self._emd_loss(images_pred, recon_images, 128)
        recon_loss = recon_loss + 0.2 * self.lpips(
            images_pred.reshape(-1, 3, h, w).to(dtype),
            recon_images.reshape(-1, 3, h, w),
        )
        self.log("recon_loss", recon_loss, prog_bar=True)

        return recon_loss

    def _log_gaussians(self, gaussians, stage):
        """Log the Gaussian centres as a point cloud plus attribute histograms to W&B."""
        # .float() first: .numpy() does not support bfloat16.
        g = gaussians.detach().float().cpu()
        xyz = g[..., :3].reshape(-1, 3).numpy()
        opacity = g[..., 3:4].reshape(-1, 1).numpy()
        scale = g[..., 4:7].reshape(-1, 3).numpy()
        rgbs = g[..., 11:].reshape(-1, 3).numpy()
        hist = self.visual_hist(xyz, rgbs, scale, opacity)
        # One call with an explicit step keeps both entries on the trainer's step.
        self.logger.experiment.log(
            {
                f"{stage}/point_cloud": wandb.Object3D(xyz),
                f"{stage}/hist_img": wandb.Image(hist),
            },
            step=self.global_step,
        )

    def inference_step(self, batch, batch_idx, dataloader_idx=0, stage="val"):
        """Render the batch, save/log images, and log PSNR/SSIM/LPIPS for ``stage``."""
        images_pred, gaussians = self._generate_images(batch)  # b x m x 3 x h x w, [0, 1]
        images_gt = batch["diffusion_images"] * 0.5 + 0.5
        image_fp = self._save_image(
            images_pred, images_gt, batch["prompt"], f"{batch_idx}_{self.global_rank}", stage=stage
        )

        images_gt = rearrange(images_gt, "b m c h w -> (b m) c h w").float()
        images_pred = rearrange(images_pred, "b m c h w -> (b m) c h w").float()

        # Metrics run in full precision with autocast disabled.
        with autocast(enabled=False):
            psnr = self.psnr(images_pred, images_gt)
            ssim = self.ssim(images_pred, images_gt)
            lpips = self.lpips(images_pred, images_gt)
        log_kwargs = dict(
            on_step=False, on_epoch=True, prog_bar=True, add_dataloader_idx=False, sync_dist=True
        )
        self.log(f"{stage}_psnr", psnr, **log_kwargs)
        self.log(f"{stage}_ssim", ssim, **log_kwargs)
        self.log(f"{stage}_lpips", lpips, **log_kwargs)
        if isinstance(self.logger, lightning.pytorch.loggers.WandbLogger):
            self._log_gaussians(gaussians, stage)

        return image_fp

    @rank_zero_only
    def save_cond_image(self, cond, stage, batch_idx):
        """Log condition images (``b x m x c x h x w``) as a grid, if a backend is set."""
        if self.log_image is None:
            return
        cond_image = rearrange(cond, "b m c h w -> b c h (m w)")
        grid = torchvision.utils.make_grid(cond_image, nrow=2)
        self.log_image(
            tag="{}_cond_images/{}".format(stage, batch_idx),
            image_tensor=grid,
        )

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        return self.inference_step(batch, batch_idx, dataloader_idx, stage="val")

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx defaults to 0: Lightning omits it with a single test loader.
        return self.inference_step(batch, batch_idx, dataloader_idx, stage="test")

    @torch.no_grad()
    def _encode_cond(self, image, do_classifier_free_guidance=False):
        """Encode condition images with the CLIP image encoder.

        Args:
            image: image batch, assumed ``b x c x h x w`` in [-1, 1] (TODO confirm);
                resized to 224 x 224 and rescaled to [0, 1] before CLIP normalisation.
            do_classifier_free_guidance: if True, prepend an all-zero copy of the
                embeddings along the batch dimension.

        Returns:
            Embeddings with a singleton token axis inserted at dim 1
            (doubled along dim 0 when guidance is enabled).
        """
        device, dtype = image.device, image.dtype
        image = image.to(torch.float32)
        image = _resize_with_antialiasing(image, (224, 224))
        image = (image + 1.0) / 2.0
        image = image.to(dtype)

        # Normalize the image for CLIP input.
        image = self.feature_extractor(
            images=image,
            do_normalize=True,
            do_center_crop=False,
            do_resize=False,
            do_rescale=False,
            return_tensors="pt",
        ).pixel_values
        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)

        if do_classifier_free_guidance:
            negative_image_embeddings = torch.zeros_like(image_embeddings)
            image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
        return image_embeddings

    @torch.no_grad()
    def tensor_to_vae_latent(self, t, vae, needs_upcasting=True):
        """Encode images into scaled VAE latents.

        Args:
            t: ``b x c x h x w`` or ``b x f x c x h x w`` image tensor.
            vae: VAE whose ``encode(...).latent_dist`` is sampled.
            needs_upcasting: run the encoder in float32, then restore the VAE's
                original dtype.

        Returns:
            Latents multiplied by ``vae.config.scaling_factor``, with the same
            rank (4-D or 5-D) as ``t``.
        """
        ori_shape_len = t.dim()
        if ori_shape_len == 4:
            t = t.unsqueeze(1)
        video_length = t.shape[1]
        t = rearrange(t, "b f c h w -> (b f) c h w")

        original_dtype = next(vae.parameters()).dtype
        if needs_upcasting:
            vae.to(dtype=torch.float32)
        try:
            t = t.to(next(vae.parameters()).dtype)
            latents = vae.encode(t).latent_dist.sample()
        finally:
            # Restore the dtype the VAE had on entry rather than forcing fp16.
            if needs_upcasting:
                vae.to(dtype=original_dtype)

        latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
        if ori_shape_len == 4:
            latents = latents.squeeze(1)
        return latents * vae.config.scaling_factor

    @torch.no_grad()
    def _generate_images(self, batch):
        """Run ``recon_model`` at the scheduler's final inference timestep.

        Returns:
            ``(images_pred, gaussians)`` where ``images_pred`` is cast to the
            dtype of ``batch["diffusion_images"]``.
        """
        diffusion_images = batch["diffusion_images"]  # b x m x c x h x w
        latents = self.tensor_to_vae_latent(diffusion_images, self.vae)

        self.scheduler.set_timesteps(self.num_inference_steps, device=diffusion_images.device)
        t = self.scheduler.timesteps[-1]
        results = self.recon_model(latents, batch, t)
        preds = results["images_pred"].to(diffusion_images.dtype)

        return preds, results["gaussians"]

    @torch.no_grad()
    @rank_zero_only
    def _save_image(self, images_pred, images, prompt, batch_idx, stage="validation"):
        """Log a GT-over-prediction grid and write it (plus prompts) to ``save_dir``.

        Args:
            images_pred: ``b x m x c x h x w`` predictions in [0, 1].
            images: ``b x m x c x h x w`` ground truth in [0, 1].
            prompt: sequence of prompt strings; the first names the PNG file.
            batch_idx: identifier used in tags and file names.
            stage: prefix for tags and file names.

        Returns:
            Path of the written PNG.
        """
        save_dir = self.save_dir
        if self.log_image is not None:
            _images = rearrange(images, "b m c h w -> 1 c (b h) (m w)")
            _images_pred = rearrange(images_pred, "b m c h w -> 1 c (b h) (m w)")
            _full_image = torch.cat([_images, _images_pred], dim=2)
            grid = torchvision.utils.make_grid(_full_image, nrow=2)
            self.log_image(
                tag="{}_images/{}".format(stage, batch_idx),
                image_tensor=grid,
            )
        images = rearrange(images, "b m c h w -> (b h) (m w) c")
        images_pred = rearrange(images_pred, "b m c h w -> (b h) (m w) c")
        full_image = torch.cat([images, images_pred], dim=0)
        # Clamp before the uint8 cast: out-of-range floats would otherwise wrap
        # around instead of saturating. .float() also makes bf16 numpy-convertible.
        full_image = (full_image.float().clamp(0, 1) * 255).cpu().numpy().astype(np.uint8)
        with open(
            os.path.join(save_dir, f"{stage}_{self.global_step}_{batch_idx}.txt"),
            "w",
            encoding="utf-8",
        ) as f:
            f.write("\n".join(prompt))

        im = Image.fromarray(full_image)
        im_fp = os.path.join(
            save_dir,
            f"{stage}_{self.global_step}_{batch_idx}--{prompt[0].replace(' ', '_').replace('/', '_')}.png",
        )
        im.save(im_fp)
        return im_fp

    def visual_hist(self, xyz, rgb, scaling, opacity):
        """Render 2x2 histograms of Gaussian attributes into an image array.

        Args:
            xyz, rgb, scaling: ``N x 3`` arrays, one histogram per column.
            opacity: array of opacities (flattened).

        Returns:
            The rendered PNG decoded into a NumPy array.
        """
        from io import BytesIO
        from matplotlib.figure import Figure

        # A standalone Figure avoids pyplot's global state: nothing leaks or
        # overlays between calls, and no GUI backend is involved.
        fig = Figure()
        axes = fig.subplots(2, 2)
        panels = (
            (axes[0, 0], xyz, "xyz", ("x", "y", "z")),
            (axes[0, 1], rgb, "rgb", ("r", "g", "b")),
            (axes[1, 0], scaling, "scaling", ("s1", "s2", "s3")),
        )
        for ax, values, title, labels in panels:
            for col, (colour, label) in enumerate(zip("rgb", labels)):
                ax.hist(values[:, col], bins=100, color=colour, alpha=0.7, label=label)
            ax.set_title(title)
            ax.legend()
        axes[1, 1].hist(opacity.flatten(), bins=100, color="r")
        axes[1, 1].set_title("opacity")
        fig.tight_layout()

        with BytesIO() as buf:
            fig.savefig(buf, format="png")
            buf.seek(0)
            # np.array forces PIL to decode before the buffer is closed.
            image_np = np.array(Image.open(buf))
        return image_np

    def tensorboard_log_image(self, tag: str, image_tensor):
        """Log an image tensor to TensorBoard at the trainer's global step."""
        self.logger.experiment.add_image(
            tag,
            image_tensor,
            self.trainer.global_step,
        )

    def wandb_log_image(self, tag: str, image_tensor):
        """Log an image tensor to W&B at the trainer's global step."""
        image_dict = {
            tag: wandb.Image(image_tensor),
        }
        self.logger.experiment.log(
            image_dict,
            step=self.trainer.global_step,
        )
