import concurrent.futures
import json
import math
import os.path
import shutil
from concurrent.futures import ThreadPoolExecutor
from rich.progress import track
import random
from typing import Literal, Tuple, Optional
from PIL import Image
import numpy as np
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
import cv2
import torch.utils.data
from lightning import LightningDataModule
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS

from src.models.gs.cameras import CameraType, Camera
from src.data.dataparsers.dataparser import ImageSet
from src.data.dataparsers.dataset import DatasetParams
from src.data.dataparsers.colmap_dataparser import ColmapDataParser
from src.data.dataparsers.blender_dataparser import BlenderDataParser
from src.data.dataparsers.nsvf_dataparser import NSVFDataParser
from src.data.dataparsers.nerfies_dataparser import NerfiesDataparser
from src.utils.point import store_ply, BasicPointCloud

from tqdm import tqdm
from src.utils.misc import resolve_dir
from plyfile import PlyData, PlyElement

class Dataset(torch.utils.data.Dataset):
    """Per-image dataset built on top of a parsed ``ImageSet``.

    Each item is ``(camera, (image_name, image_dict, mask))`` where
    ``image_dict`` holds the ``"render"`` image, optional ``"depth"`` and
    ``"normal"`` images, and a ``"points"`` tensor read once from a PLY file
    at construction time.
    """

    # Default kept for backward compatibility with the previously hard-coded path.
    DEFAULT_POINTS_PATH = "data/nerf/my_synthetic/data/lego_8192.ply"

    def __init__(
            self,
            image_set: ImageSet,
            undistort_image: bool = True,
            points_path: str = DEFAULT_POINTS_PATH,
    ) -> None:
        """Store the image set and read the auxiliary point set.

        Args:
            image_set: parsed images, cameras and optional depth/normal/mask paths
            undistort_image: when True, images are undistorted on load and the
                per-image camera stored in ``image_cameras`` is updated
            points_path: PLY file whose ``vertex`` x/y/z columns are attached
                to every sample under the ``"points"`` key
        """
        super().__init__()
        self.image_set = image_set
        # NOTE: this flag is an instance attribute, so the helper that does the
        # actual work must not share its name (it used to, which made the
        # helper uncallable: ``self.undistort_image(...)`` resolved to the bool).
        self.undistort_image = undistort_image
        self.image_cameras: list[Camera] = [i for i in image_set.cameras]  # store undistorted camera
        points = PlyData.read(points_path)
        vertex = points['vertex']
        self.points = np.stack([vertex['x'], vertex['y'], vertex['z']], axis=1)

    def __len__(self):
        return len(self.image_set)

    def _undistort_images(self, index: int, image_numpy_dict: dict) -> dict:
        """Undistort every loaded image of one sample.

        Args:
            index: sample index, used to look up the original camera
            image_numpy_dict: key -> numpy image, as returned by
                ``load_image_from_path``

        Returns:
            The input dict unchanged when the camera has no (non-zero)
            distortion; otherwise a new dict with undistorted images. In the
            latter case ``self.image_cameras[index]`` is updated to the new
            distortion-free intrinsics.
        """
        camera = self.image_set.cameras[index]  # get original camera
        distortion = camera.distortion_params
        if distortion is None or not torch.any(distortion != 0.):
            # nothing to undo: hand the images back instead of dropping them
            return image_numpy_dict

        assert camera.camera_type == CameraType.PERSPECTIVE
        # build intrinsics matrix
        intrinsics_matrix = np.eye(3)
        intrinsics_matrix[0, 0] = float(camera.fx)  # fx
        intrinsics_matrix[1, 1] = float(camera.fy)  # fy
        intrinsics_matrix[0, 2] = float(camera.cx)  # cx
        intrinsics_matrix[1, 2] = float(camera.cy)  # cy
        # calculate new intrinsics matrix, without black border
        image_shape = (int(camera.width), int(camera.height))
        distortion = distortion.numpy()
        new_intrinsics_matrix, _ = cv2.getOptimalNewCameraMatrix(
            intrinsics_matrix,
            distortion,
            image_shape,
            0,
            image_shape,
        )

        undistorted_dict = {}
        for key, numpy_image in image_numpy_dict.items():
            if key == 'depth':
                # load_depth() returns [1, H, W]; OpenCV expects [H, W], so
                # undistort the plane and restore the leading axis.
                # NOTE(review): cv2.undistort interpolates linearly, which
                # blends depth across edges - confirm this is acceptable.
                plane = np.ascontiguousarray(numpy_image[0])
                undistorted_dict[key] = cv2.undistort(
                    plane, intrinsics_matrix, distortion, None, new_intrinsics_matrix)[None]
            else:
                undistorted_dict[key] = cv2.undistort(
                    numpy_image, intrinsics_matrix, distortion, None, new_intrinsics_matrix)

        # the stored camera now describes the undistorted image
        stored_camera = self.image_cameras[index]
        stored_camera.camera_type = torch.tensor(CameraType.PERSPECTIVE)
        stored_camera.fx = torch.tensor(new_intrinsics_matrix[0, 0], dtype=torch.float)
        stored_camera.fy = torch.tensor(new_intrinsics_matrix[1, 1], dtype=torch.float)
        stored_camera.cx = torch.tensor(new_intrinsics_matrix[0, 2], dtype=torch.float)
        stored_camera.cy = torch.tensor(new_intrinsics_matrix[1, 2], dtype=torch.float)
        stored_camera.distortion_params = torch.zeros((4,), dtype=torch.float)
        return undistorted_dict

    def load_depth(self, depth_path):
        """Read a depth image and return it as a ``[1, H, W]`` array.

        Values above 1000 are treated as invalid and set to 0. For
        multi-channel files only the first channel is used.

        Raises:
            OSError: if OpenCV cannot read ``depth_path``
        """
        img = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising
            raise OSError("cannot read depth image: {}".format(depth_path))
        img[img > 1000] = 0  # original note: depth = 65535 is invalid
        if img.ndim == 3:
            img = img[:, :, 0]
        return img[None]

    def load_image_from_path(self, image_dict: dict) -> dict:
        """Replace every non-None path in ``image_dict`` by the loaded array.

        The dict is modified in place and also returned. The ``"depth"`` entry
        goes through ``load_depth``; every other entry is read with PIL as a
        ``uint8`` array.
        """
        for key, image_path in list(image_dict.items()):
            if image_path is None:
                continue
            if key == 'depth':
                image_dict[key] = self.load_depth(image_path)
            else:
                # the context manager closes the file handle once decoded
                with Image.open(image_path) as pil_image:
                    image_dict[key] = np.array(pil_image, dtype="uint8")

        return image_dict

    def depth2real(self, depth, fx, fy, cx, cy):
        """Scale a ``[1, H, W]`` depth map by the per-pixel ray length.

        For every pixel the ray ``((u - cx) / fx, (v - cy) / fy, 1)`` is built
        and ``depth`` is multiplied by its Euclidean norm.

        Returns:
            Tensor of shape ``[1, H, W]``.
        """
        H, W = depth.shape[1:]
        # indexing="xy" yields [H, W] grids (x varies along columns); the
        # default "ij" would give [W, H] and only work for square images
        x, y = torch.meshgrid(torch.arange(W), torch.arange(H), indexing="xy")
        x = (x - cx) / fx
        y = (y - cy) / fy
        xyz = torch.stack([x, y, torch.ones_like(x)], dim=0)
        len_xyz = torch.sqrt(torch.sum(xyz ** 2, dim=0, keepdim=True))
        real_depth = depth * len_xyz

        return real_depth

    def get_image(self, index) -> Tuple[str, dict, Optional[torch.Tensor]]:
        """Load all images of one sample.

        Returns:
            ``(image_name, image_dict, mask)``. ``image_dict`` maps
            ``"render"`` / ``"normal"`` to float ``[C, H, W]`` tensors scaled
            to ``[0, 1]``, ``"depth"`` to the tensor produced from
            ``load_depth``, and ``"points"`` to the point tensor. ``mask`` is
            ``None`` or a boolean ``[C, H, W]`` tensor where True marks masked
            pixels (mask value 0).
        """
        # TODO: resize
        image_path_dict = {"render": self.image_set.image_paths[index]}
        optional_sources = (
            ("depth", self.image_set.depth_paths),
            ("normal", self.image_set.normal_paths),
        )
        for key, paths in optional_sources:
            if paths is not None and paths[index] is not None:
                image_path_dict[key] = paths[index]

        image_numpy_dict = self.load_image_from_path(image_path_dict)

        # undistort image
        if self.undistort_image is True:
            image_numpy_dict = self._undistort_images(index, image_numpy_dict)

        image_dict = {}
        for key, numpy_image in image_numpy_dict.items():
            if key == 'depth':
                image_dict[key] = torch.from_numpy(numpy_image)
                continue
            # convert image to tensor
            image = torch.from_numpy(numpy_image.astype(np.float64) / 255.0)
            # remove alpha channel
            if image.shape[2] == 4:
                # TODO: sync background color with model.background_color
                background_color = torch.tensor([0., 0., 0.])
                image = image[:, :, :3] * image[:, :, 3:4] + background_color * (1 - image[:, :, 3:4])
            image_dict[key] = image.to(torch.float).permute(2, 0, 1)  # [channel, height, width]

        mask = None
        mask_paths = self.image_set.mask_paths
        if mask_paths is not None and mask_paths[index] is not None:
            with Image.open(mask_paths[index]) as pil_image:
                mask = torch.from_numpy(np.array(pil_image))
            # mask must be single channel
            assert len(mask.shape) == 2, "the mask image must be single channel"
            # compare against the render image, which is already [C, H, W]
            render = image_dict["render"]
            assert mask.shape == render.shape[1:], \
                "the shape of mask {} doesn't match to the image {}".format(
                    tuple(mask.shape), tuple(render.shape[1:]))
            # True is the masked pixels; broadcast over channels -> [C, H, W]
            mask = (mask == 0).unsqueeze(0).expand(*render.shape)

        image_dict['points'] = torch.from_numpy(self.points)

        return self.image_set.image_names[index], image_dict, mask

    def __getitem__(self, index) -> Tuple[Camera, Tuple]:
        return self.image_cameras[index], self.get_image(index)


class CacheDataLoader(torch.utils.data.DataLoader):
    """DataLoader that serves items from an in-memory cache.

    ``max_cache_num`` selects the strategy:

    * negative, or at least the number of images used: every item is loaded
      once in ``__init__`` and replayed on each epoch;
    * ``0``: nothing is cached, items are loaded on demand;
    * otherwise: items are loaded and yielded in chunks of ``max_cache_num``.

    Only ``batch_size=1`` is supported; items are yielded exactly as returned
    by ``dataset.__getitem__``.
    """

    def __init__(
            self,
            dataset: torch.utils.data.Dataset,
            max_cache_num: int,
            shuffle: bool,
            seed: int = -1,
            distributed: bool = False,
            world_size: int = -1,
            global_rank: int = -1,
            **kwargs,
    ):
        """
        Args:
            dataset: source dataset
            max_cache_num: caching strategy, see the class docstring
            shuffle: draw a new random order on every epoch
            seed: seed of the loader's private generator; required (>= 0)
                when ``shuffle`` is True
            distributed: when True and ``max_cache_num != 0``, this loader
                only uses its rank's contiguous slice of the dataset
            world_size: number of ranks, required when sharding
            global_rank: rank of this process, required when sharding
            **kwargs: forwarded to ``torch.utils.data.DataLoader``
        """
        assert kwargs.get("batch_size", 1) == 1, "only batch_size=1 is supported"

        self.dataset = dataset

        super().__init__(dataset=dataset, **kwargs)

        self.shuffle = shuffle
        self.max_cache_num = max_cache_num

        # image indices to use
        self.indices = list(range(len(self.dataset)))
        if distributed is True and self.max_cache_num != 0:
            assert world_size > 0
            assert global_rank >= 0
            image_num_to_use = math.ceil(len(self.indices) / world_size)
            start = global_rank * image_num_to_use
            shard = self.indices[start:start + image_num_to_use]
            # pad a short (last) shard from the head so all ranks have equal length
            shard += self.indices[:image_num_to_use - len(shard)]
            self.indices = shard

            print("#{} distributed indices (total: {}): {}".format(os.getpid(), len(self.indices), self.indices))

        # cache all images if max_cache_num > len(dataset)
        if self.max_cache_num >= len(self.indices):
            self.max_cache_num = -1

        self.num_workers = kwargs.get("num_workers", 0)

        if self.max_cache_num < 0:
            # cache all data
            print("cache all images")
            self.cached = self._cache_data(self.indices)

        # use dedicated random number generator foreach dataloader
        if self.shuffle is True:
            assert seed >= 0, "seed must be provided when shuffle=True"
            self.generator = torch.Generator()
            self.generator.manual_seed(seed)
            print("#{} dataloader seed to {}".format(os.getpid(), seed))

    def _cache_data(self, indices: list):
        """Load ``dataset[i]`` for every ``i`` in ``indices``, preserving order.

        Uses a thread pool of ``num_workers`` threads when ``num_workers > 0``.
        """
        if not indices:
            # an empty dataset/shard has nothing to load (and no indices[0]
            # for the progress-bar caption)
            return []

        # TODO: speedup image loading
        cached = []
        if self.num_workers > 0:
            with ThreadPoolExecutor(max_workers=self.num_workers) as e:
                for i in tqdm(
                        e.map(self.dataset.__getitem__, indices),
                        total=len(indices),
                        desc="#{} caching images (1st: {})".format(os.getpid(), indices[0]),
                ):
                    cached.append(i)
        else:
            for i in tqdm(indices, desc="#{} loading images (1st: {})".format(os.getpid(), indices[0])):
                cached.append(self.dataset.__getitem__(i))

        return cached

    def _epoch_order(self, count: int) -> list:
        """Return the positions ``0..count-1`` in this epoch's visiting order."""
        if self.shuffle is True:
            return torch.randperm(count, generator=self.generator).tolist()  # shuffle for each epoch
        return list(range(count))

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx):
        return self.dataset.__getitem__(idx)

    def __iter__(self):
        # TODO: support batching
        if self.max_cache_num < 0:
            # everything is cached: positions refer to self.cached
            for position in self._epoch_order(len(self.cached)):
                yield self.cached[position]
            return

        # Map positions to the dataset indices owned by this loader. Using
        # the permutation itself as dataset indices would read images outside
        # this rank's shard when `distributed` is enabled.
        indices = [self.indices[position] for position in self._epoch_order(len(self.indices))]

        if self.max_cache_num == 0:
            # no cache
            for i in indices:
                yield self.__getitem__(i)
            return

        # cache: load and yield max_cache_num items at a time
        for start in range(0, len(indices), self.max_cache_num):
            yield from self._cache_data(indices[start:start + self.max_cache_num])


class DataModule(LightningDataModule):
    """Lightning data module that parses a scene and serves cached dataloaders.

    ``setup`` runs the dataparser matching the dataset type, optionally adds
    a background sphere to the initial point cloud, and (on global rank 0
    during ``fit``) writes helper files into the run's ``data`` directory.
    """

    def __init__(
            self,
            path: str,
            params: DatasetParams,
            type: Optional[Literal["colmap", "blender", "nsvf", "nerfies"]] = None,
            distributed: bool = False,
            undistort_image: bool = False,
            n_pts: int | None = None,
    ) -> None:
        r"""Load dataset

            Args:
                path: the path to the dataset

                params: dataset parameters; holds the per-parser settings,
                    the cache sizes and the background sphere options

                type: the dataset type; detected from the files found in
                    ``path`` when None

                distributed: forwarded to the training ``CacheDataLoader``

                undistort_image: forwarded to every ``Dataset``

                n_pts: forwarded to ``BlenderDataParser`` only
        """

        super().__init__()
        self.save_hyperparameters()

    def setup(self, stage: str) -> None:
        """Parse the dataset and prepare the point cloud and viewer files."""
        super().setup(stage)
        print(self.trainer.default_root_dir)
        output_path = resolve_dir(self.trainer, None, "data")
        os.makedirs(output_path, exist_ok=True)

        # store global rank, will be used as the seed of the CacheDataLoader
        self.global_rank = self.trainer.global_rank

        # detect dataset type
        if self.hparams["type"] is None:
            self.hparams["type"] = self._detect_dataset_type(self.hparams["path"])

        # load dataset
        self.dataparser_outputs = self._build_dataparser(output_path).get_outputs()

        self.prune_extent = self.dataparser_outputs.camera_extent
        # add background sphere: https://github.com/graphdeco-inria/gaussian-splatting/issues/300#issuecomment-1756073909
        if self.hparams["params"].add_background_sphere is True:
            self._add_background_sphere()

        # convert point cloud
        self.point_cloud = BasicPointCloud(
            points=self.dataparser_outputs.point_cloud.xyz,
            colors=self.dataparser_outputs.point_cloud.rgb / 255.,
            normals=np.zeros_like(self.dataparser_outputs.point_cloud.xyz),
        )

        # write some files that SIBR_viewer required
        if self.global_rank == 0 and stage == "fit":
            self._write_viewer_files(output_path)

    @staticmethod
    def _detect_dataset_type(path: str) -> str:
        """Guess the dataset type from marker files/directories under ``path``."""
        print(os.path.join(path, "transforms_train.json"))
        if os.path.isdir(os.path.join(path, "sparse")):
            return "colmap"
        if os.path.exists(os.path.join(path, "transforms_train.json")):
            return "blender"
        if os.path.exists(os.path.join(path, "intrinsics.txt")) \
                and os.path.exists(os.path.join(path, "bbox.txt")):
            return "nsvf"
        if os.path.exists(os.path.join(path, "dataset.json")):
            return "nerfies"
        raise ValueError("Can not detect dataset type automatically")

    def _build_dataparser(self, output_path: str):
        """Instantiate the dataparser selected by ``hparams["type"]``."""
        dataset_type = self.hparams["type"]
        params = self.hparams["params"]
        # build dataparser params
        dataparser_params = {
            "path": self.hparams["path"],
            "output_path": output_path,
            "global_rank": self.global_rank,
        }

        if dataset_type == "colmap":
            return ColmapDataParser(params=params.colmap, **dataparser_params)
        if dataset_type == "blender":
            # only the blender parser receives n_pts
            dataparser_params["n_pts"] = self.hparams["n_pts"]
            return BlenderDataParser(params=params.blender, **dataparser_params)
        if dataset_type == "nsvf":
            return NSVFDataParser(params=params.nsvf, **dataparser_params)
        if dataset_type == "nerfies":
            return NerfiesDataparser(params=params.nerfies, **dataparser_params)
        raise ValueError("unsupported dataset type {}".format(dataset_type))

    def _add_background_sphere(self) -> None:
        """Append randomly colored sphere points around the scene and widen ``prune_extent``."""
        params = self.hparams["params"]
        point_cloud = self.dataparser_outputs.point_cloud

        # find the scene center and size
        point_max_coordinate = np.max(point_cloud.xyz, axis=0)
        point_min_coordinate = np.min(point_cloud.xyz, axis=0)
        scene_center = (point_max_coordinate + point_min_coordinate) / 2
        scene_size = np.max(point_max_coordinate - point_min_coordinate)
        scene_radius = scene_size / 2.

        # build unit sphere points
        n_points = params.background_sphere_points
        samples = np.arange(n_points)
        # y goes from 1 to -1; the max() avoids 0/0 (NaN) for a single point
        y = 1 - (samples / float(max(n_points - 1, 1))) * 2
        radius = np.sqrt(1 - y * y)  # radius at y
        phi = math.pi * (math.sqrt(5.) - 1.)  # golden angle in radians
        theta = phi * samples  # golden angle increment
        x = np.cos(theta) * radius
        z = np.sin(theta) * radius
        unit_sphere_points = np.concatenate([x[:, None], y[:, None], z[:, None]], axis=1)

        # build background sphere
        sphere_radius = scene_radius * params.background_sphere_distance
        background_sphere_point_xyz = (unit_sphere_points * sphere_radius) + scene_center
        background_sphere_point_rgb = np.asarray(
            np.random.random(background_sphere_point_xyz.shape) * 255, dtype=np.uint8)

        # add background sphere to scene
        point_cloud.xyz = np.concatenate([point_cloud.xyz, background_sphere_point_xyz], axis=0)
        point_cloud.rgb = np.concatenate([point_cloud.rgb, background_sphere_point_rgb], axis=0)

        # increase prune extent
        # TODO: resize scene_extent without changing lr
        self.prune_extent = sphere_radius * 1.0001

        print("added {} background sphere points, rescale prune extent from {} to {}".format(
            n_points, self.dataparser_outputs.camera_extent, self.prune_extent))

    def _write_viewer_files(self, output_path: str) -> None:
        """Write appearance group ids, ``cameras.json`` and ``input.ply`` to ``output_path``."""
        outputs = self.dataparser_outputs

        # write appearance group id
        if outputs.appearance_group_ids is not None:
            torch.save(
                outputs.appearance_group_ids,
                os.path.join(output_path, "appearance_group_ids.pth"),
            )
            with open(os.path.join(output_path, "appearance_group_ids.json"), "w") as f:
                json.dump(outputs.appearance_group_ids, f, indent=4, ensure_ascii=False)

        # write cameras.json
        camera_to_world = torch.linalg.inv(
            torch.transpose(outputs.train_set.cameras.world_to_camera, 1, 2)
        ).numpy()
        cameras = []
        for idx, image_dict in enumerate(outputs.train_set):
            camera = image_dict["camera"]
            cameras.append({
                'id': idx,
                'img_name': image_dict["image_name"],
                'width': int(camera.width),
                'height': int(camera.height),
                'position': camera_to_world[idx, :3, 3].tolist(),
                'rotation': [x.tolist() for x in camera_to_world[idx, :3, :3]],
                'fy': float(camera.fy),
                'fx': float(camera.fx),
            })
        with open(os.path.join(output_path, "cameras.json"), "w") as f:
            json.dump(cameras, f, indent=4, ensure_ascii=False)

        # save input point cloud to ply file
        store_ply(
            os.path.join(output_path, "input.ply"),
            xyz=outputs.point_cloud.xyz,
            rgb=outputs.point_cloud.rgb,
        )

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """Shuffled loader over the train set, seeded per rank."""
        return CacheDataLoader(
            Dataset(self.dataparser_outputs.train_set, undistort_image=self.hparams["undistort_image"]),
            max_cache_num=self.hparams["params"].train_max_num_images_to_cache,
            shuffle=True,
            seed=torch.initial_seed() + self.global_rank,  # seed with global rank
            num_workers=self.hparams["params"].num_workers,
            distributed=self.hparams["distributed"],
            world_size=self.trainer.world_size,
            global_rank=self.trainer.global_rank,
        )

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """Sequential, single-threaded loader over the test set."""
        return CacheDataLoader(
            Dataset(self.dataparser_outputs.test_set, undistort_image=self.hparams["undistort_image"]),
            max_cache_num=self.hparams["params"].test_max_num_images_to_cache,
            shuffle=False,
            num_workers=0,
        )

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """Sequential, single-threaded loader over the validation set."""
        return CacheDataLoader(
            Dataset(self.dataparser_outputs.val_set, undistort_image=self.hparams["undistort_image"]),
            max_cache_num=self.hparams["params"].val_max_num_images_to_cache,
            shuffle=False,
            num_workers=0,
        )
