import os
import json
import math
import numpy as np
from PIL import Image
from omegaconf import DictConfig, ListConfig
from typing import Any, Tuple, Optional, List, Dict
import torch
from lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
import cv2
from src.models.gs.cameras import Cameras
from src.utils.point import fov2focal
from torch.utils.data.dataloader import default_collate
from pytorch3d.ops import sample_farthest_points
from einops import rearrange, repeat
def collect_fn(batch):
    """Collate a list of sample dicts while passing ``'cameras'`` through as a list.

    Each sample's ``'cameras'`` entry is popped out of the sample dict (the
    input dicts are mutated), so it never reaches ``default_collate``. The
    remaining keys are collated normally. The cameras are then re-attached
    to the collated batch as a plain Python list in batch order.
    """
    camera_list = []
    for sample in batch:
        camera_list.append(sample.pop('cameras'))
    collated = default_collate(batch)
    collated['cameras'] = camera_list
    return collated
    

def point_random_sample(pc_coords: torch.Tensor, num_points: int):
    """Return up to ``num_points`` rows of ``pc_coords`` drawn without replacement.

    Args:
        pc_coords: ``(N, 3)`` point coordinates. Any tensor indexed along dim 0
            works.
        num_points: Maximum number of rows to keep. When ``N < num_points``,
            all ``N`` rows are returned in random order.

    Returns:
        Tensor of shape ``(min(N, num_points), ...)`` on the same device.
    """
    # Build the permutation on the points' own device, so GPU inputs do not
    # need a host-to-device index transfer.
    perm = torch.randperm(pc_coords.shape[0], device=pc_coords.device)
    return pc_coords[perm[:num_points]]

def ray_sample(cam2world_matrix, fx, fy, cx, cy, sk, resolution, depth_tensor):
    """Back-project every pixel of a square grid to world space using per-pixel depth.

    Pixel centres (index + 0.5) of a ``resolution x resolution`` grid are lifted
    onto the z = 1 plane of each camera by inverting a pinhole intrinsic with
    skew: x = (u - cx - sk * (v - cy) / fy) / fx and y = (v - cy) / fy. The
    lifted points are scaled by ``depth_tensor`` and mapped to world space
    with ``cam2world_matrix``.

    Args:
        cam2world_matrix: Batched camera-to-world matrices, ``(N, 4, 4)``.
            They are ``torch.bmm``'d with homogeneous ``(N, 4, M)`` points, and
            only the first 3 output rows are kept.
        fx, fy, cx, cy, sk: Per-camera intrinsics, in the same pixel units as
            the grid (the grid is not normalised). Each one is
            ``unsqueeze(-1)``'d and broadcast against ``(N, M)`` pixel
            coordinates, so presumably shape ``(N,)`` (TODO confirm).
        resolution: Side length R of the square pixel grid; M = R * R.
        depth_tensor: Per-pixel depth, multiplied elementwise with the
            ``(N, M, 3)`` camera-space points, so it must broadcast to that
            shape. The points lie on z = 1, so the value acts as depth along
            the camera z axis, not as distance along the ray.

    Returns:
        ``(N, M, 3)`` world-space points.
    """
    N, M = cam2world_matrix.shape[0], resolution**2
    # (2, R, R) grid of pixel-centre coordinates (+0.5). With indexing="ij",
    # channel 0 varies along the first grid axis and channel 1 along the second.
    uv = torch.stack(
        torch.meshgrid(
            torch.arange(
                resolution, dtype=torch.float32, device=cam2world_matrix.device
            ),
            torch.arange(
                resolution, dtype=torch.float32, device=cam2world_matrix.device
            ),
            indexing="ij",
        )
    ) + 0.5
    # Normalised-coordinate variant (disabled; intrinsics are in pixels):
    # * (1.0 / resolution) + (0.5 / resolution)
    # Flatten to (N, M, 2). For flat index k: uv[..., 0] = k // R + 0.5 and
    # uv[..., 1] = k % R + 0.5. So x_cam takes the slow (outer) index and y_cam
    # the fast one.
    # NOTE(review): a depth map flattened as 'b 1 h w -> b (w h) c' (outer
    # index = image column) lines x up with columns. A row-major '(h w)'
    # flatten would swap x and y.
    uv = repeat(uv, "c h w -> b (h w) c", b=N)
    x_cam = uv[:, :, 0].view(N, -1)
    y_cam = uv[:, :, 1].view(N, -1)
    z_cam = torch.ones((N, M), device=cam2world_matrix.device)

    # Inverse of K = [[fx, sk, cx], [0, fy, cy], [0, 0, 1]] applied to (u, v, 1);
    # algebraically (x - cx - sk * (y - cy) / fy) / fx.
    x_lift = (
        (
            x_cam
            - cx.unsqueeze(-1)
            + cy.unsqueeze(-1) * sk.unsqueeze(-1) / fy.unsqueeze(-1)
            - sk.unsqueeze(-1) * y_cam / fy.unsqueeze(-1)
        )
        / fx.unsqueeze(-1)
        * z_cam
    )
    y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
    # (N, M, 3) camera-space points on the z = 1 plane.
    cam_rel_points = torch.stack(
        (x_lift, y_lift, z_cam), dim=-1)

    # Scale each unit-z point by its depth.
    cam_depth_points = depth_tensor * cam_rel_points
    
    # Homogeneous coordinates (N, M, 4) so the 4x4 transform applies rotation and translation.
    cam_depth_points = torch.cat(
        (cam_depth_points, torch.ones((N, M, 1), device=cam2world_matrix.device)), dim=-1
    )
    # (N, 4, 4) @ (N, 4, M) -> (N, 4, M), back to (N, M, 4), then drop the w component.
    world_rel_depth_points = torch.bmm(
        cam2world_matrix, cam_depth_points.permute(0, 2, 1)
    ).permute(0, 2, 1)[:, :, :3]
    return world_rel_depth_points



class MultiViewDataset(Dataset):
    """Multi-view object dataset returning RGB / normal / depth renders with cameras.

    The object list is read from ``caption_path``. Each line is either
    ``<obj_id>`` or ``<obj_id><TAB><caption>``. Each object directory
    ``root_dir/<obj_id>`` must contain a ``meta.json`` with a ``"locations"``
    list. Every location provides ``frames[0]`` (RGB), ``frames[1]`` (normal)
    and ``frames[2]`` (depth) file names, plus a ``"transform_matrix"``
    camera-to-world pose. That pose is treated as Blender convention and
    converted to OpenCV convention.

    Args:
        root_dir: Directory containing one sub-directory per object.
        src_views: Number of source views per item. These are the first
            sampled views.
        target_views: Number of additional target views per item.
        bg_color: ``"white"``, ``"black"``, ``"gray"``, ``"random"`` or a float
            gray level. RGBA images are composited over this colour.
        img_wh: Output ``(width, height)`` for all loaded maps.
        caption_path: Object-id / caption list described above.
        include_src: If True, ``target_indices`` also covers the source views.
        num_samples: If a positive int, keep only the last ``num_samples``
            objects. ``None`` or ``<= 0`` keeps all objects.
        relative_pose: Express all poses relative to the first sampled view.
        return_pc: Also return a farthest-point-sampled point cloud
            back-projected from the depth maps.
        repeat: Multiplier on the reported dataset length.
    """

    def __init__(
        self,
        root_dir: str,
        src_views: int,
        target_views: int,
        bg_color: str,
        img_wh: Tuple[int, int],
        caption_path: str,
        include_src: bool = True,
        num_samples: Optional[int] = -1,
        relative_pose: bool = False,
        return_pc: bool = False,
        repeat: int = 1,
    ):
        self.root_dir = root_dir
        self.src_views = src_views
        self.target_views = target_views
        self.bg_color = bg_color
        self.img_wh = img_wh
        self.caption_path = caption_path
        self.obj_paths = []
        self.captions = []
        self.include_src = include_src
        self.relative_pose = relative_pose
        self.repeat = repeat
        self.return_pc = return_pc
        self.load_obj_caption_pairs(caption_path)
        if num_samples is not None and num_samples > 0:
            # Keep the last `num_samples` objects. Slice obj_paths and captions
            # together so that index i still pairs an object with its caption.
            self.obj_paths = self.obj_paths[-num_samples:]
            self.captions = self.captions[-num_samples:]

    def get_bg_color(self, bg_color):
        """Return a float32 RGB array in [0, 1] for a background specification.

        Raises:
            NotImplementedError: If ``bg_color`` is neither a known name nor a float.
        """
        if bg_color == "white":
            return np.array([1.0, 1.0, 1.0], dtype=np.float32)
        if bg_color == "black":
            return np.array([0.0, 0.0, 0.0], dtype=np.float32)
        if bg_color == "gray":
            return np.array([0.5, 0.5, 0.5], dtype=np.float32)
        if bg_color == "random":
            # np.random.rand yields float64. Cast it so composited images stay
            # float32, like every other background option.
            return np.random.rand(3).astype(np.float32)
        if isinstance(bg_color, float):
            return np.array([bg_color] * 3, dtype=np.float32)
        raise NotImplementedError

    def load_obj_caption_pairs(self, caption_path):
        """(Re)populate ``self.obj_paths`` and ``self.captions`` from ``caption_path``.

        The text before the first tab is the object id. Anything after it,
        including further tabs, is the caption. Lines without a tab get an
        empty caption, and blank lines are skipped.
        """
        self.obj_paths = []
        self.captions = []
        with open(caption_path, "r", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                obj_id, _, caption = line.partition("\t")
                # Existence is intentionally not checked here.
                self.obj_paths.append(os.path.join(self.root_dir, obj_id))
                self.captions.append(caption)

    @staticmethod
    def _to_return_type(array, return_type):
        """Return ``array`` unchanged for ``"np"``, or as a torch tensor for ``"pt"``."""
        if return_type == "np":
            return array
        if return_type == "pt":
            return torch.from_numpy(array)
        raise NotImplementedError

    def load_image(self, img_path, bg_color, rescale=True, return_type="np"):
        """Load an image resized to ``img_wh`` as float32 values in [0, 1].

        RGBA images are alpha-composited over ``bg_color`` and returned as RGB.
        Values are divided by 255, which assumes 8-bit images. PIL is used
        instead of cv2 because cv2 may load in uint16 format. ``rescale`` is
        currently unused (the [-1, 1] rescale is disabled) and is kept only
        for interface compatibility.
        """
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.resize(self.img_wh))
        img = img.astype(np.float32) / 255.0
        if img.shape[-1] == 4:
            alpha = img[..., 3:4]
            img = img[..., :3] * alpha + bg_color * (1 - alpha)
        return self._to_return_type(img, return_type)

    def load_normal(self, normal_path, return_type="np"):
        """Load a normal map resized to ``img_wh`` and remapped from [0, 1] to [-1, 1].

        For RGBA inputs, the RGB channels are multiplied by alpha before the
        remap. Fully transparent pixels therefore end up at -1, not 0.
        """
        with Image.open(normal_path) as pil_img:
            img = np.array(pil_img.resize(self.img_wh))
        img = img.astype(np.float32) / 255.0
        if img.shape[-1] == 4:
            img = img[..., :3] * img[..., 3:4]
        img = img * 2.0 - 1.0
        return self._to_return_type(img, return_type)

    def load_depth(self, depth_path, return_type="np"):
        """Load a depth map and its validity mask, resized with nearest-neighbour.

        Values > 1000 are treated as invalid and set to 0. The file encodes
        invalid pixels as 65535. ``return_type`` is ignored: torch tensors are
        always returned.

        Returns:
            depth: ``(1, H, W)`` tensor taken from the first channel.
            mask: float tensor of 1 (valid) / 0 (invalid), shape
                ``(1, H, W, C)``, with one entry per input channel.

        Raises:
            FileNotFoundError: If OpenCV cannot read ``depth_path``.
        """
        img = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read depth map: {depth_path}")
        img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_NEAREST)
        if img.ndim == 2:
            # Single-channel maps: add a channel axis so the indexing below works.
            img = img[..., None]
        invalid = img > 1000
        img[invalid] = 0
        # NOTE(review): the mask keeps the channel axis while depth does not.
        # Confirm downstream consumers expect this asymmetry.
        mask = torch.from_numpy(~invalid).to(torch.float32).unsqueeze(0)
        return torch.from_numpy(img[None, :, :, 0]), mask

    def __len__(self):
        return len(self.obj_paths) * self.repeat

    @staticmethod
    def _parse_fov(meta):
        """Return the horizontal field of view described by ``meta``.

        Prefers ``meta["K_matrix"]``, giving fov = 2 * atan(cx / fx) in
        radians; this assumes the principal point is at the image centre.
        Falls back to ``meta["camera_angle_x"]`` when the matrix is missing
        or unusable.
        """
        try:
            intrinsics = meta["K_matrix"]
            fx = intrinsics[0][0]
            cx = intrinsics[0][2]
            return 2 * math.atan(cx / fx)
        except (KeyError, IndexError, TypeError, ZeroDivisionError):
            return meta["camera_angle_x"]

    def get_cameras(self, c2w, meta):
        """Build a batched ``Cameras`` object from ``(Nv, 4, 4)`` camera-to-world matrices.

        All views share one focal length derived from ``meta``. The principal
        point is the image centre, and images are treated as square
        (height = width = ``img_wh[0]``).
        """
        w2c = torch.linalg.inv(c2w).to(torch.float)
        fov = self._parse_fov(meta)
        num_views = c2w.shape[0]
        width = self.img_wh[0]
        fx = torch.tensor(
            [fov2focal(fov=fov, pixels=width)],
            dtype=torch.float32,
        ).expand(num_views)
        # TODO: allow different fy
        fy = torch.clone(fx)
        width = torch.tensor([width], dtype=torch.float32).expand(num_views)
        height = torch.clone(width)
        cameras = Cameras(
            R=w2c[:, :3, :3],
            T=w2c[:, :3, 3],
            fx=fx,
            fy=fy,
            cx=width / 2,
            cy=height / 2,
            width=width,
            height=height,
            appearance_id=torch.zeros_like(width),
            normalized_appearance_id=torch.zeros_like(width),
            distortion_params=None,
            camera_type=torch.zeros_like(width),
        )
        return cameras

    def _load_views(self, obj_path, locations, bg_color):
        """Load image, normal, depth, mask and c2w for each location, stacked on dim 0."""
        images, normals, depths, masks, c2ws = [], [], [], [], []
        for loc in locations:
            frames = loc["frames"]
            img = self.load_image(
                os.path.join(obj_path, frames[0]["name"]), bg_color, return_type="pt"
            )
            images.append(img.permute(2, 0, 1))
            normal = self.load_normal(
                os.path.join(obj_path, frames[1]["name"]), return_type="pt"
            )
            normals.append(normal.permute(2, 0, 1))
            depth, mask = self.load_depth(
                os.path.join(obj_path, frames[2]["name"]), return_type="pt"
            )
            depths.append(depth)
            masks.append(mask)
            c2ws.append(torch.tensor(loc["transform_matrix"]))
        return (
            torch.stack(images, dim=0),
            torch.stack(normals, dim=0),
            torch.stack(depths, dim=0),
            torch.stack(masks, dim=0),
            torch.stack(c2ws, dim=0),
        )

    @staticmethod
    def _to_relative_pose(c2w_matrixs):
        """Re-express poses in the first view's camera frame, shifted by its distance.

        Every pose is left-multiplied by the inverse of the first pose. A
        translation of ``-||t_0||`` along z is then applied, where ``t_0`` is
        the first camera's position.
        """
        src_w2c = torch.inverse(c2w_matrixs[:1])  # (1, 4, 4)
        src_distance = c2w_matrixs[0, :3, 3].norm()
        canonical_c2w = torch.matmul(src_w2c, c2w_matrixs)  # (Nv, 4, 4)
        shift = torch.eye(4, dtype=canonical_c2w.dtype, device=canonical_c2w.device)
        shift[2, 3] = -src_distance
        return torch.matmul(shift, canonical_c2w)

    def _sample_point_cloud(self, c2w_matrixs, cameras, depth_tensors):
        """Back-project valid depth pixels to world space and subsample them.

        At most 1e6 points are first drawn uniformly at random. They are then
        reduced to K = 1024 with farthest point sampling, returned as
        ``(K, 3)``.
        NOTE(review): FPS over up to 1e6 points runs per item inside the data
        loader and may be slow.
        """
        fps_num = 2 ** 10
        random_num = 10 ** 6
        # Flatten with the image column as the outer index ("(w h)"). This is
        # the index ray_sample assigns to x in its pixel grid.
        depth_flat = repeat(depth_tensors, 'b 1 h w -> b (w h) c', c=3)
        pc_coords = ray_sample(
            c2w_matrixs,
            fx=cameras.fx,
            fy=cameras.fy,
            cx=cameras.cx,
            cy=cameras.cy,
            sk=torch.zeros_like(cameras.fx),
            resolution=self.img_wh[0],  # square images assumed
            depth_tensor=depth_flat,
        )
        # load_depth zeroes invalid pixels; keep only positive-depth points.
        pc_coords = pc_coords[depth_flat > 0].reshape(-1, 3)
        pc_random = point_random_sample(pc_coords, random_num)
        pc_fps, _ = sample_farthest_points(points=pc_random.view(1, -1, 3), K=fps_num)
        return pc_fps.view(-1, 3)

    def __getitem__(self, index):
        """Sample ``src_views + target_views`` distinct random views of one object.

        Returns a dict with these entries:
            - ``view_ids``
            - ``images`` ``(Nv, C, H, W)``
            - ``normals`` ``(Nv, C, H, W)``
            - ``depths`` ``(Nv, 1, H, W)``
            - ``cameras``
            - ``masks``
            - ``src_indices``
            - ``target_indices``
            - ``pc_fps``: a ``(1, 3)`` zero placeholder unless ``return_pc``
            - ``prompt``
        """
        index = index % len(self.obj_paths)
        obj_path = self.obj_paths[index]
        with open(os.path.join(obj_path, "meta.json"), "r") as f:
            meta = json.load(f)
        caption = self.captions[index]

        num_views_all = len(meta["locations"])
        view_ids = np.random.choice(
            num_views_all, self.src_views + self.target_views, replace=False
        )
        locations = [meta["locations"][i] for i in view_ids]

        bg_color = self.get_bg_color(self.bg_color)
        img_tensors, normal_tensors, depth_tensors, masks, c2w_matrixs = self._load_views(
            obj_path, locations, bg_color
        )

        # Blender -> OpenCV: negate the camera's y and z axis columns.
        c2w_matrixs[:, :3, 1:3] *= -1
        if self.relative_pose:
            c2w_matrixs = self._to_relative_pose(c2w_matrixs)
        cameras = self.get_cameras(c2w_matrixs, meta)

        src_indices = torch.arange(self.src_views)
        if self.include_src:
            target_indices = torch.arange(self.src_views + self.target_views)
        else:
            target_indices = torch.arange(self.src_views, self.src_views + self.target_views)

        if self.return_pc:
            pc_fps = self._sample_point_cloud(c2w_matrixs, cameras, depth_tensors)
        else:
            # Placeholder so the output dict always contains `pc_fps`.
            pc_fps = torch.zeros((1, 3), device=c2w_matrixs.device)

        return {
            "view_ids": torch.tensor(view_ids),
            "images": img_tensors,
            "normals": normal_tensors,
            "depths": depth_tensors,
            "cameras": cameras,
            "masks": masks,
            "src_indices": src_indices,
            "target_indices": target_indices,
            "pc_fps": pc_fps,
            "prompt": caption,
        }


class MultiViewDataModule(LightningDataModule):
    """Lightning data module serving the multi-view datasets.

    Every loader uses ``collect_fn``, so per-sample ``cameras`` objects are
    returned as a list instead of being collated. ``val_dataset`` and
    ``test_dataset`` may be a single dataset or a list/dict of datasets. Both
    OmegaConf ``ListConfig``/``DictConfig`` and plain Python containers are
    accepted; for a collection, a list of loaders is returned.

    Args:
        train_dataset / val_dataset / test_dataset: Datasets for each stage.
        train_batch_size / val_batch_size / test_batch_size: Per-stage batch sizes.
        num_workers: Worker processes per loader. ``None`` defaults to
            ``2 * train_batch_size``; ``0`` loads in the main process.
        pin_memory: Forwarded to every ``DataLoader``.
        real_dataset: Optional extra dataset served by ``real_dataloader``.
        real_batch_size: Batch size used by ``real_dataloader``.
    """

    def __init__(
        self,
        train_dataset: Dataset[Any],
        val_dataset: Dataset[Any],
        test_dataset: Dataset[Any],
        train_batch_size: int = 1,
        val_batch_size: int = 1,
        test_batch_size: int = 1,
        num_workers: Optional[int] = None,
        pin_memory: bool = True,
        real_dataset = None,
        real_batch_size = -1,
    ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.data_train = train_dataset
        self.data_val = val_dataset
        self.data_test = test_dataset
        self.data_real = real_dataset

        # Compare against None, not truthiness, so an explicit num_workers=0
        # (load in the main process) is honoured instead of being overridden.
        self.num_workers = num_workers if num_workers is not None else train_batch_size * 2

    def prepare_data(self) -> None:
        # TODO: check if data is available
        pass

    def _dataloader(self, dataset: Dataset, batch_size: int, shuffle: bool) -> DataLoader[Any]:
        """Build a DataLoader with the module's shared worker, pinning and collate settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collect_fn
        )

    def _eval_dataloaders(self, datasets, batch_size: int):
        """Return one non-shuffled loader, or a list of them for list/dict inputs."""
        if isinstance(datasets, (ListConfig, list, tuple)):
            return [self._dataloader(dataset, batch_size, False) for dataset in datasets]
        if isinstance(datasets, (DictConfig, dict)):
            return [self._dataloader(dataset, batch_size, False) for _, dataset in datasets.items()]
        return self._dataloader(datasets, batch_size, False)

    def real_dataloader(self):
        """Non-shuffled loader over ``real_dataset`` using ``real_batch_size``."""
        return self._dataloader(self.data_real, self.hparams.real_batch_size, False)

    def train_dataloader(self) -> DataLoader[Any]:
        return self._dataloader(self.data_train, self.hparams.train_batch_size, True)

    def val_dataloader(self) -> DataLoader[Any]:
        return self._eval_dataloaders(self.data_val, self.hparams.val_batch_size)

    def test_dataloader(self) -> DataLoader[Any]:
        return self._eval_dataloaders(self.data_test, self.hparams.test_batch_size)


if __name__ == "__main__":
    # Manual smoke test: load one item from a local synthetic dataset and print
    # its first camera. `seed_everything` comes from `lightning.pytorch`, the
    # package this module already imports, rather than the separate
    # `pytorch_lightning` distribution.
    from lightning.pytorch import seed_everything

    seed_everything(42)
    dataset = MultiViewDataset(
        root_dir="data/nerf/my_synthetic",
        src_views=1,
        target_views=6,
        bg_color="black",
        img_wh=(384, 384),
        caption_path="data/nerf/my_synthetic/caption.txt",
        relative_pose=True,
        return_pc=True,
    )
    item = dataset[0]
    print(item['cameras'][0])


