import os.path
from typing import Tuple, List, Union

import torch.optim
import torchvision
import wandb
from lightning.pytorch.core.module import MODULE_OPTIMIZERS
from torchmetrics.image import PeakSignalNoiseRatio
from lightning.pytorch.utilities.types import LRSchedulerPLType
import lightning.pytorch.loggers
from src.models.gs.renderers.renderer import Renderer

from .base import BaseSystem

import numpy as np
from PIL import Image
from io import BytesIO
from plyfile import PlyData, PlyElement
import matplotlib.pyplot as plt
from einops import rearrange, repeat
from ..utils.ops import batch_gather
class GaussianSplattingSystem(BaseSystem):
    """Training/validation system wrapping a Gaussian model, a renderer and a loss.

    The Gaussian model is called with the renderer, a background colour, the
    cameras and the source images, and returns a list of per-layer outputs.
    Each layer's output is scored by ``loss_fn`` against the target views and
    the per-layer ``total_loss`` values are summed into ``final_loss``.
    """

    def __init__(
            self,
            gaussian,
            renderer: Renderer,
            loss_fn,
            camera_extent_factor: float = 1.,
            # enable_appearance_model: bool = False,
            background_color: Tuple[float, float, float] = (0., 0., 0.),
            save_val_output: bool = False,
            max_save_val_output: int = -1,
    ) -> None:
        """Store the sub-modules and record the scalar hyper-parameters.

        Args:
            gaussian: Model producing the per-layer Gaussian outputs. Must
                expose ``set_use_memory_efficient_attention_xformers``.
            renderer: Renderer handed to ``gaussian`` on every forward pass.
            loss_fn: Callable ``loss_fn(outputs, batch) -> dict``; it also has
                to expose ``loss_list`` (read in :meth:`setup`).
            camera_extent_factor: Not read in this class; only recorded in
                ``self.hparams``.
            background_color: RGB background colour handed to ``gaussian``.
            save_val_output: Whether validation images are written/logged.
            max_save_val_output: Maximum ``batch_idx`` (exclusive) for which
                validation images are saved; a negative value means no limit.
        """
        super().__init__()
        # ``ignore`` must be a sequence of argument names: a single
        # comma-joined string is treated as ONE name and ignores nothing.
        self.save_hyperparameters(ignore=["gaussian", "renderer", "loss_fn"])

        # setup models
        self.gaussian_model = gaussian
        self.gaussian_model.set_use_memory_efficient_attention_xformers(True)
        try:
            self.gaussian_model.fold_transformer.set_use_memory_efficient_attention_xformers(False)
        except AttributeError:
            # Best effort: the model may not have a ``fold_transformer``.
            pass
        self.renderer = renderer
        self.loss_fn = loss_fn
        self.psnr = PeakSignalNoiseRatio()
        self.background_color = torch.tensor(background_color, dtype=torch.float32)

        # Batch size reported to ``log_dict``.
        self.batch_size = 1
        # Lower bounds for the epoch/step used when naming validation output
        # directories (presumably updated elsewhere on checkpoint restore --
        # TODO confirm).
        self.restored_epoch = 0
        self.restored_global_step = 0

        # Image logging callback; selected in ``setup`` from the logger type.
        self.log_image = None

    def setup(self, stage: str):
        """Move the LPIPS loss to the module device and pick an image logger.

        Args:
            stage: Stage name forwarded to the parent class and the renderer.
        """
        super().setup(stage)

        for item in self.loss_fn.loss_list:
            if item['name'] == 'lpips_loss':
                item["loss_fn"].lpips_loss = item["loss_fn"].lpips_loss.to(self.device)
        self.renderer.setup(stage, lightning_module=self)

        # use different image log method based on the logger type
        self.log_image = None
        if isinstance(self.logger, lightning.pytorch.loggers.TensorBoardLogger):
            self.log_image = self.tensorboard_log_image
        elif isinstance(self.logger, lightning.pytorch.loggers.WandbLogger):
            self.log_image = self.wandb_log_image
            self.logger.watch(self)

    def tensorboard_log_image(self, tag: str, image_tensor):
        """Log ``image_tensor`` under ``tag`` at the current global step (TensorBoard)."""
        self.logger.experiment.add_image(
            tag,
            image_tensor,
            self.trainer.global_step,
        )

    def wandb_log_image(self, tag: str, image_tensor):
        """Log ``image_tensor`` under ``tag`` at the current global step (W&B)."""
        image_dict = {
            tag: wandb.Image(image_tensor),
        }
        self.logger.experiment.log(
            image_dict,
            step=self.trainer.global_step,
        )

    def forward(self, cameras, images):
        """Run the Gaussian model and return its list of per-layer outputs.

        Args:
            cameras: Cameras forwarded to the Gaussian model.
            images: Source images forwarded to the Gaussian model.
        """
        outputs_list = self.gaussian_model(
            renderer=self.renderer,
            background_color=self.background_color,
            cameras=cameras,
            images=images,
        )
        return outputs_list

    def forward_with_loss_calculation(self, batch):
        """Run the model on the source views and score every output layer.

        Side effect: ``batch['render']`` and ``batch['depth']`` are set to the
        gathered target images/depths so that ``loss_fn`` can read them.

        Args:
            batch: Mapping with the keys ``src_indices``, ``target_indices``,
                ``cameras``, ``images`` and ``depths``.

        Returns:
            ``(outputs_list, total_info)`` where ``outputs_list`` is the model
            output in reversed order and ``total_info`` maps
            ``"<key>_<layer_idx>"`` to every value returned by ``loss_fn`` plus
            ``"final_loss"``, the sum of the per-layer ``total_loss`` values.
        """
        src_indices = batch['src_indices']
        target_indices = batch['target_indices']
        cameras = batch['cameras']
        images = batch['images']
        depths = batch['depths']

        src_images = batch_gather(images, src_indices)
        target_images = batch_gather(images, target_indices)
        target_depths = batch_gather(depths, target_indices)

        # forward, then reverse the layer order
        outputs_list = self(cameras, src_images)
        outputs_list = outputs_list[::-1]

        # Ground truth consumed by ``loss_fn``.
        batch['render'] = target_images
        batch['depth'] = target_depths

        total_info = {}
        total_loss = 0
        for layer_idx, outputs in enumerate(outputs_list):
            info = self.loss_fn(outputs, batch)
            for key, value in info.items():
                total_info["{}_{}".format(key, layer_idx)] = value
                if key == "total_loss":
                    total_loss += value
        total_info["final_loss"] = total_loss

        return outputs_list, total_info

    def training_step(self, batch, batch_idx):
        """Compute and log the losses; return ``final_loss`` for optimisation."""
        # forward
        outputs_list, loss_info = self.forward_with_loss_calculation(batch)

        self.log_dict(loss_info, on_step=True, on_epoch=False, prog_bar=True, batch_size=self.batch_size)

        return loss_info["final_loss"]

    @staticmethod
    def _normalize_to_unit_range(tensor, eps: float = 1.0e-8):
        """Min-max normalise ``tensor`` to [0, 1]; a constant tensor maps to zeros."""
        lowest = tensor.min()
        # Clamp the span so that a constant input does not divide by zero.
        span = (tensor.max() - lowest).clamp_min(eps)
        return (tensor - lowest) / span

    def validation_step(self, batch, batch_idx, name: str = "val"):
        """Log losses and PSNR per layer and optionally save comparison images.

        Images are only written on global rank 0, when ``save_val_output`` is
        enabled and ``batch_idx`` is below ``max_save_val_output`` (if that
        limit is non-negative).

        Args:
            batch: Same mapping as for :meth:`forward_with_loss_calculation`;
                its leading (batch) dimension must be 1.
            batch_idx: Index of the batch, used in tags and file names.
            name: Prefix for metric tags and the output sub-directory.
        """
        # batch size must = 1
        assert batch["src_indices"].shape[0] == 1, "validation requires batch size 1"

        target_indices, images, depths = batch['target_indices'], batch['images'], batch['depths']

        gt_image = batch_gather(images, target_indices)
        gt_depth = batch_gather(depths, target_indices)

        # forward
        outputs_list, loss_info = self.forward_with_loss_calculation(batch)

        self.log_dict(loss_info, on_step=False, on_epoch=True, prog_bar=True, batch_size=self.batch_size)

        for layer_idx, outputs in enumerate(outputs_list):
            self.log(f"{name}/psnr_{layer_idx}", self.psnr(outputs["render"], gt_image), on_step=False, on_epoch=True, prog_bar=True, logger=True)

        # write validation image
        if self.trainer.global_rank == 0 and self.hparams["save_val_output"] is True and (
                self.hparams["max_save_val_output"] < 0 or batch_idx < self.hparams["max_save_val_output"]
        ):
            # Normalised copy used for display only; ``gt_depth`` itself must
            # stay untouched so every layer's error map uses the real depth.
            gt_depth_vis = self._normalize_to_unit_range(gt_depth)
            output_dir = os.path.join(
                self.save_dir,
                name,
                "epoch={}-step={}".format(
                    max(self.trainer.current_epoch, self.restored_epoch),
                    max(self.trainer.global_step, self.restored_global_step),
                ),
            )

            for layer_idx, outputs in enumerate(outputs_list):
                # Rendered image and ground truth side by side.
                cat_img = torch.concat([outputs["render"], gt_image], dim=-1).squeeze(0)
                if self.log_image is not None:
                    grid = torchvision.utils.make_grid(cat_img, nrow=2)
                    self.log_image(
                        tag="{}_images/{}_{}".format(name, batch_idx, layer_idx),
                        image_tensor=grid,
                    )

                os.makedirs(output_dir, exist_ok=True)
                torchvision.utils.save_image(
                    cat_img,
                    os.path.join(output_dir, "{}_{}.png".format(batch_idx, layer_idx)),
                )

                # Predicted depth, ground-truth depth and absolute error.
                depth_image = outputs["depth"]
                delta_depth = torch.abs(depth_image - gt_depth)
                cat_depth = torch.concat(
                    [self._normalize_to_unit_range(depth_image), gt_depth_vis, delta_depth],
                    dim=-1,
                ).squeeze(0)
                if self.log_image is not None:
                    grid = torchvision.utils.make_grid(cat_depth, nrow=2)
                    self.log_image(
                        tag="{}_depth/{}_{}".format(name, batch_idx, layer_idx),
                        image_tensor=grid,
                    )
                torchvision.utils.save_image(
                    cat_depth,
                    os.path.join(output_dir, "{}_{}_depth.png".format(batch_idx, layer_idx)),
                )

    @torch.no_grad()
    def save_ply(self, path):
        """Log point-cloud visualisations of the current Gaussians to W&B.

        NOTE: PLY serialisation is currently disabled; ``path`` is only used
        to make sure its parent directory exists. Nothing is logged unless the
        logger's class name contains ``"WandbLogger"``.

        Args:
            path: Intended output file path.
        """
        output_dir = os.path.dirname(path)
        if output_dir:
            # ``os.makedirs("")`` raises, so skip it for bare file names.
            os.makedirs(output_dir, exist_ok=True)

        outputs_list = self.gaussian_model()
        for layer_idx, outputs in enumerate(outputs_list):
            opacity = outputs["opacity"].squeeze(0).cpu().numpy()
            scale = outputs["scale"].squeeze(0).cpu().numpy()
            rotation = outputs["rotation"].squeeze(0).cpu().numpy()
            color = outputs["color"].squeeze(0).cpu().numpy()
            xyz = outputs["means3D"].squeeze(0).cpu().numpy()

            if color.shape[1] == 16:
                # A 16-wide colour is not plotted as RGB (presumably a
                # non-RGB feature encoding -- TODO confirm); use zeros instead.
                color = np.zeros_like(xyz)

            if "WandbLogger" in self.logger.__class__.__name__:
                self.logger.experiment.log({"val_hist/point_cloud_{}".format(layer_idx): wandb.Object3D(xyz)})
                hist = self.visual_hist(xyz, color, scale, opacity)
                self.logger.experiment.log({"val_hist/hist_img_{}".format(layer_idx): wandb.Image(hist)})

    def test_step(self, batch, batch_idx):
        """Run :meth:`validation_step` with the ``"test"`` prefix."""
        return self.validation_step(batch, batch_idx, name="test")

    def configure_optimizers(self):
        """Return an Adam optimiser over the Gaussian model parameters (lr 1e-5)."""
        optimizer = torch.optim.Adam(self.gaussian_model.parameters(), lr=1.0e-5)
        return optimizer

    def visual_hist(self, xyz, rgb, scaling, opacity):
        """Render a 2x2 grid of histograms and return it as an image array.

        The panels show the first three columns of ``xyz``, ``rgb`` and
        ``scaling`` (one histogram per column) and the flattened ``opacity``.
        A private figure is used and closed afterwards so repeated calls do
        not draw on top of each other or leak figures.

        Args:
            xyz: 2-D array; columns 0-2 are plotted.
            rgb: 2-D array; columns 0-2 are plotted.
            scaling: 2-D array; columns 0-2 are plotted.
            opacity: Array that is flattened before plotting.

        Returns:
            The rendered PNG decoded into a NumPy array.
        """
        fig, axes = plt.subplots(2, 2)
        try:
            panels = (
                (axes[0, 0], "xyz", xyz, ("x", "y", "z")),
                (axes[0, 1], "rgb", rgb, ("r", "g", "b")),
                (axes[1, 0], "scaling", scaling, ("s1", "s2", "s3")),
            )
            for ax, title, data, labels in panels:
                for column, (bar_color, label) in enumerate(zip("rgb", labels)):
                    ax.hist(data[:, column], bins=100, color=bar_color, alpha=0.7, label=label)
                ax.set_title(title)
                ax.legend()

            opacity_ax = axes[1, 1]
            opacity_ax.hist(opacity.flatten(), bins=100, color='r', label="opacity")
            opacity_ax.set_title("opacity")
            opacity_ax.legend()

            fig.tight_layout()
            with BytesIO() as buf:
                fig.savefig(buf, format='png')
                buf.seek(0)
                # Open the image with PIL and convert it to a NumPy array
                # before the in-memory buffer is closed.
                with Image.open(buf) as image:
                    image_np = np.array(image)
            return image_np
        finally:
            plt.close(fig)
    

