import os.path
from typing import Tuple, List, Union

import torch.optim
import torchvision
import wandb
from lightning.pytorch.core.module import MODULE_OPTIMIZERS
from torchmetrics.image import PeakSignalNoiseRatio
from lightning.pytorch.utilities.types import LRSchedulerPLType
import lightning.pytorch.loggers
from src.models.gs.renderers.renderer import Renderer

from .base import BaseSystem
import torch.nn.functional as F

import numpy as np
from PIL import Image
from io import BytesIO
from plyfile import PlyData, PlyElement
import matplotlib.pyplot as plt
from src.models.gs.configs.optimization import OptimizationParams


class GaussianSplattingSystem(BaseSystem):
    """Lightning system that trains and evaluates a Gaussian-splatting model.

    Optimization is manual (``automatic_optimization = False``): the optimizer
    is created by ``gaussian_model.training_setup`` in
    :meth:`configure_optimizers` and stepped explicitly in :meth:`training_step`.

    Args:
        gaussian: The Gaussian model. It is expected to provide
            ``create_from_pcd``, ``training_setup``, ``optimizer``,
            ``oneupSHdegree`` and the ``get_*`` accessors used by
            :meth:`save_ply`.
        renderer: Renderer used for training (``training_forward``) and
            evaluation (``__call__``) forwards.
        loss_fn: Callable ``loss_fn(outputs, meta) -> (loss, info)``; ``info``
            must contain a ``"total_loss"`` entry, which is back-propagated.
        camera_extent_factor: Multiplier applied to the datamodule's camera
            extent before it is passed to ``gaussian_model.training_setup``.
        background_color: Background colour (three floats) passed to the
            renderer as ``bg_color``.
        save_val_output: Whether to write/log validation images.
        max_save_val_output: Maximum number of validation batches to write
            images for; a negative value means no limit.
    """

    def __init__(
            self,
            gaussian,
            renderer: Renderer,
            loss_fn,
            camera_extent_factor: float = 1.,
            background_color: Tuple[float, float, float] = (0., 0., 0.),
            save_val_output: bool = False,
            max_save_val_output: int = -1,
    ) -> None:
        super().__init__()
        self.automatic_optimization = False
        # `ignore` accepts one name or a *list* of names. A single
        # comma-separated string is treated as one (non-existent) argument
        # name, which would leave the model/renderer/loss in the hparams.
        self.save_hyperparameters(ignore=["gaussian", "renderer", "loss_fn"])

        # setup models
        self.gaussian_model = gaussian
        self.gaussian_model.set_use_memory_efficient_attention_xformers(True)
        try:
            self.gaussian_model.fold_transformer.set_use_memory_efficient_attention_xformers(False)
        except AttributeError:
            # Best effort: models without a compatible `fold_transformer` skip this.
            pass
        self.renderer = renderer
        self.loss_fn = loss_fn
        self.psnr = PeakSignalNoiseRatio()

        self.background_color = torch.tensor(background_color, dtype=torch.float32)
        self.optimization_params = OptimizationParams()
        self.batch_size = 1
        self.restored_epoch = 0
        self.restored_global_step = 0
        # Selected in `setup` based on the logger type; None disables image logging.
        self.log_image = None

    def setup(self, stage: str):
        """Initialise the Gaussians from the datamodule point cloud and pick an image logger."""
        super().setup(stage)

        # NOTE(review): `deivce` looks like a typo for `device`. It is kept
        # as-is because `create_from_pcd`'s signature is not visible here —
        # confirm against the Gaussian model before renaming.
        self.gaussian_model.create_from_pcd(
            self.trainer.datamodule.point_cloud,
            deivce=self.device,
        )

        self.renderer.setup(stage, lightning_module=self)

        # use different image log method based on the logger type
        self.log_image = None
        if isinstance(self.logger, lightning.pytorch.loggers.TensorBoardLogger):
            self.log_image = self.tensorboard_log_image
        elif isinstance(self.logger, lightning.pytorch.loggers.WandbLogger):
            self.log_image = self.wandb_log_image

    def tensorboard_log_image(self, tag: str, image_tensor):
        """Log ``image_tensor`` under ``tag`` to TensorBoard at the current global step."""
        self.logger.experiment.add_image(
            tag,
            image_tensor,
            self.trainer.global_step,
        )

    def wandb_log_image(self, tag: str, image_tensor):
        """Log ``image_tensor`` under ``tag`` to Weights & Biases at the current global step."""
        image_dict = {
            tag: wandb.Image(image_tensor),
        }
        self.logger.experiment.log(
            image_dict,
            step=self.trainer.global_step,
        )

    def forward(self, camera):
        """Render ``camera``.

        Uses ``renderer.training_forward`` in training mode and the plain
        renderer call otherwise.
        """
        bg_color = self.background_color.to(camera.R.device)
        if self.training:
            return self.renderer.training_forward(
                self.trainer.global_step,
                self,
                camera,
                self.gaussian_model,
                bg_color=bg_color,
            )
        return self.renderer(
            camera,
            self.gaussian_model,
            bg_color=bg_color,
        )

    def forward_with_loss_calculation(self, camera, image_info):
        """Render ``camera`` and evaluate ``loss_fn``.

        Args:
            camera: Camera passed to :meth:`forward`.
            image_info: ``(image_name, meta, masked_pixels)``. ``meta["render"]``
                is the ground-truth image; ``masked_pixels`` is an index/mask
                into it, or None.

        Returns:
            ``(outputs, info)``, where ``outputs`` is the renderer output dict
            and ``info`` is the second value returned by ``loss_fn``.

        Note:
            When ``masked_pixels`` is given, ``meta["render"]`` is modified IN
            PLACE: the masked pixels are overwritten with the (detached)
            prediction. Presumably this makes those pixels contribute no
            residual to a pixel-wise loss — confirm against ``loss_fn``.
        """
        _image_name, meta, masked_pixels = image_info
        outputs = self(camera)
        if masked_pixels is not None:
            gt_image = meta['render']
            gt_image[masked_pixels] = outputs["render"].detach()[masked_pixels]

        _loss, info = self.loss_fn(outputs, meta)
        return outputs, info

    def training_step(self, batch, batch_idx):
        """Run one manually-optimized training step and log the loss terms."""
        camera, image_info = batch

        # `oneupSHdegree` is called every 1000 steps (presumably raising the
        # active spherical-harmonics degree).
        global_step = self.trainer.global_step + 1
        if global_step % 1000 == 0:
            self.gaussian_model.oneupSHdegree()

        _outputs, loss_info = self.forward_with_loss_calculation(camera, image_info)
        self.log_dict(loss_info, on_step=True, on_epoch=False, prog_bar=True, batch_size=self.batch_size)

        optimizer = self.optimizers()
        total_loss = loss_info["total_loss"]
        optimizer.zero_grad(set_to_none=True)
        self.manual_backward(total_loss)
        optimizer.step()
        # NOTE: learning-rate scheduling via
        # `self.gaussian_model.update_learning_rate(global_step)` is disabled.

    def _should_save_val_output(self, batch_idx: int) -> bool:
        """Return True if images for this validation batch should be written (rank 0 only)."""
        if self.trainer.global_rank != 0 or not self.hparams["save_val_output"]:
            return False
        limit = self.hparams["max_save_val_output"]
        return limit < 0 or batch_idx < limit

    @staticmethod
    def _min_max_normalize(tensor, eps: float = 1e-8):
        """Rescale ``tensor`` to [0, 1]; a constant tensor maps to zeros instead of NaN."""
        minimum = tensor.min()
        value_range = (tensor.max() - minimum).clamp_min(eps)
        return (tensor - minimum) / value_range

    def validation_step(self, batch, batch_idx, name: str = "val"):
        """Log losses and PSNR; optionally save/log RGB and depth comparison images.

        ``name`` prefixes the metric names, logger tags and output directory
        (``"test"`` when called from :meth:`test_step`).
        """
        camera, image_info = batch
        image_name, meta = image_info[0], image_info[1]
        gt_image = meta['render']

        outputs, loss_info = self.forward_with_loss_calculation(camera, image_info)
        rendered = outputs["render"]

        self.log_dict(loss_info, on_step=False, on_epoch=True, prog_bar=True, batch_size=self.batch_size)
        self.log(f"{name}/psnr", self.psnr(rendered, gt_image), on_epoch=True, prog_bar=True,
                 batch_size=self.batch_size)

        if not self._should_save_val_output(batch_idx):
            return

        safe_name = image_name.replace("/", "_")

        # RGB: prediction and ground truth side by side along the last axis.
        rgb_pair = torch.concat([rendered, gt_image], dim=-1)
        if self.log_image is not None:
            self.log_image(
                tag=f"{name}_images/{safe_name}",
                image_tensor=torchvision.utils.make_grid(rgb_pair),
            )

        # `self.save_dir` is presumably provided by BaseSystem.
        image_output_path = os.path.join(
            self.save_dir,
            name,
            "epoch={}-step={}".format(
                max(self.trainer.current_epoch, self.restored_epoch),
                max(self.trainer.global_step, self.restored_global_step),
            ),
            f"{safe_name}.png",
        )
        os.makedirs(os.path.dirname(image_output_path), exist_ok=True)
        torchvision.utils.save_image(rgb_pair, image_output_path)

        # Depth: |pred - gt| is computed on raw depths (values outside [0, 1]
        # are clipped by save_image), while pred and gt are each min-max
        # normalised separately for display.
        depth_image = outputs["depth"]
        gt_depth = meta['depth']
        delta_depth = torch.abs(depth_image - gt_depth)
        depth_image = self._min_max_normalize(depth_image)
        gt_depth = self._min_max_normalize(gt_depth)
        depth_triplet = torch.concat([depth_image, gt_depth, delta_depth], dim=-1)

        if self.log_image is not None:
            self.log_image(
                tag=f"{name}_depth/{safe_name}",
                image_tensor=torchvision.utils.make_grid(depth_triplet),
            )

        torchvision.utils.save_image(
            depth_triplet,
            image_output_path.replace(".png", "_depth.png"),
        )

    @torch.no_grad()
    def save_ply(self, path):
        """Write the Gaussians to a binary PLY file at ``path``.

        With a W&B logger, it also logs the point positions as a 3D object and
        a histogram image of xyz / colour / scaling / opacity.
        """
        directory = os.path.dirname(path)
        if directory:  # dirname is "" for a bare filename; makedirs("") would fail
            os.makedirs(directory, exist_ok=True)
        xyz = self.gaussian_model.get_xyz().cpu().numpy()

        opacities = self.gaussian_model.get_opacity().cpu().numpy()
        scale = self.gaussian_model.get_scaling().cpu().numpy()
        rotation = self.gaussian_model.get_rotation().cpu().numpy()
        color = self.gaussian_model.get_color().cpu().numpy()

        # Normals are written as zeros. The column order below must match
        # `construct_list_of_attributes()` — confirm when changing either.
        normals = np.zeros_like(xyz)
        dtype_full = [(attribute, 'f4') for attribute in self.gaussian_model.construct_list_of_attributes()]
        elements = np.empty(xyz.shape[0], dtype=dtype_full)
        attributes = np.concatenate((xyz, normals, opacities, scale, rotation), axis=1)
        elements[:] = list(map(tuple, attributes))
        el = PlyElement.describe(elements, 'vertex')
        PlyData([el]).write(path)

        if isinstance(self.logger, lightning.pytorch.loggers.WandbLogger):
            self.logger.experiment.log({"val_hist/point_cloud": wandb.Object3D(xyz)})
            hist = self.visual_hist(xyz, color, scale, opacities)
            self.logger.experiment.log({"val_hist/hist_img": wandb.Image(hist)})

    def test_step(self, batch, batch_idx):
        """Run :meth:`validation_step` with the ``"test"`` prefix."""
        return self.validation_step(batch, batch_idx, name="test")

    def configure_optimizers(self):
        """Build the Gaussian model's optimizer via ``training_setup`` and return it.

        The camera extent from the datamodule is scaled by the
        ``camera_extent_factor`` hyperparameter (default 1.0, i.e. unchanged).
        """
        self.cameras_extent = (
            self.trainer.datamodule.dataparser_outputs.camera_extent
            * self.hparams["camera_extent_factor"]
        )
        self.gaussian_model.training_setup(self.optimization_params, self.cameras_extent)
        return self.gaussian_model.optimizer

    def visual_hist(self, xyz, rgb, scaling, opacity):
        """Render per-channel histograms of Gaussian attributes into an RGBA-like image array.

        Args:
            xyz, rgb, scaling: arrays indexed as ``[:, 0..2]`` (one histogram per column).
            opacity: array of any shape; flattened before plotting.

        Returns:
            The PNG-rendered figure decoded by PIL as a NumPy array.
        """
        # A dedicated figure (closed in `finally`) avoids drawing onto, and
        # leaking, pyplot's global "current" figure across repeated calls.
        fig = plt.figure()
        try:
            panels = (
                (1, xyz, "xyz", ("x", "y", "z")),
                (2, rgb, "rgb", ("r", "g", "b")),
                (3, scaling, "scaling", ("s1", "s2", "s3")),
            )
            for position, values, title, labels in panels:
                ax = fig.add_subplot(2, 2, position)
                for channel, (label, colour) in enumerate(zip(labels, ("r", "g", "b"))):
                    ax.hist(values[:, channel], bins=100, color=colour, alpha=0.7, label=label)
                ax.set_title(title)
                ax.legend()

            ax = fig.add_subplot(2, 2, 4)
            ax.hist(opacity.flatten(), bins=100, color='r')
            ax.set_title("opacity")
            fig.tight_layout()

            # Round-trip through an in-memory PNG and decode it to a NumPy array.
            with BytesIO() as buf:
                fig.savefig(buf, format='png')
                buf.seek(0)
                with Image.open(buf) as image:
                    image_np = np.array(image)
        finally:
            plt.close(fig)
        return image_np
    

