import os
import json
import numpy as np
import torch
import trimesh
from PIL import Image
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
from functools import cache
from torch.utils.data import Dataset, DataLoader
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.renderer import PerspectiveCameras

# from tqdm.auto import tqdm
from pytorch3d.renderer.camera_utils import join_cameras_as_batch

# from config.structured import Pix3DConfig, DataloaderConfig, ProjectConfig

"""
Collect `num_frame` PerspectiveCameras for a given model name from the other dataset samples (images) that share that model. For example, if `num_frame` is 10, return 10 PerspectiveCameras.

"""


class Pix3D(Dataset):
    """
    Pix3D dataset.

    Each item pairs one Pix3D image (square-cropped around its bounding box and
    resized to ``img_size`` for raw data) with the normalized point cloud of its
    3D model and a PyTorch3D ``PerspectiveCameras`` built from the annotation.

    Args:
        root_dir: Root of the raw Pix3D data. The processed variant is read from
            ``root_dir`` with every "pix3d" replaced by "pix3d_processed".
        split: "train" (first 80% of the category's samples, in file order) or
            "test" (the remaining 20%).
        sample_size: Points sampled from each mesh when ``processed`` is False.
        img_size: Side length in pixels of the square output image.
        pc_dict: Annotation JSON file name, relative to ``root_dir``.
        category: Pix3D category to keep, e.g. "chair".
        subset_ratio: Fraction of the training split to keep (ignored for "test").
        processed: Read processed point clouds/images instead of raw meshes/images.
        prior_frame_num: Stored on the instance; not used by ``__getitem__``.
        device: Device the camera tensors are created on. Defaults to "cuda",
            the previously hard-coded value; pass "cpu" on CPU-only machines or
            when DataLoader workers must not initialise CUDA.
    """

    def __init__(
        self,
        root_dir: str = "/data/datasets/pix3d",
        split: str = "train",
        sample_size: int = 4096,
        img_size: int = 224,
        pc_dict: str = "pix3d.json",
        category: str = "chair",
        subset_ratio: float = 1.0,
        processed: bool = True,
        prior_frame_num: int = 10,
        device: str = "cuda",
    ):
        self.prior_frame_num = prior_frame_num
        self.category = category
        # Misspelled alias kept so existing readers of `.categroy` still work.
        self.categroy = category
        self.device = device

        # Load the annotation file, closing the handle deterministically.
        json_path = os.path.join(root_dir, pc_dict)
        with open(json_path, "r") as f:
            annotations = json.load(f)

        # Filter samples by category
        cat_json = [x for x in annotations if x["category"] == category]
        print(f"Found {len(cat_json)} samples for category {category}")

        # Deterministic 80/20 split in annotation-file order.
        split_idx = int(len(cat_json) * 0.8)
        if split == "train":
            samples = cat_json[:split_idx]
            if subset_ratio != 1.0:
                samples = samples[: int(len(samples) * subset_ratio)]
            print(f"Using {len(samples)} samples for training")
        elif split == "test":
            samples = cat_json[split_idx:]
            print(f"Using {len(samples)} samples for testing")
        else:
            raise ValueError("split must be 'train' or 'test'")

        self.data = samples
        self.root_dir = root_dir
        self.processed = processed
        self.processed_root_dir = root_dir.replace("pix3d", "pix3d_processed")
        self.sample_size = sample_size
        self.img_size = img_size
        # Per-instance memo for collect_perspective_cameras (functools.cache on a
        # method would key on `self` and keep every dataset instance alive).
        self._camera_cache = {}
        print(f"Using {'processed' if self.processed else 'raw'} data")

        # Cameras of every sample, grouped by model path.
        self.cameras = self._collect_all_cameras()

    def _data_dir(self) -> str:
        """Directory that sample-relative "model"/"img" paths are joined onto."""
        return self.processed_root_dir if self.processed else self.root_dir

    def _collect_all_cameras(self) -> Dict[str, List["PerspectiveCameras"]]:
        """
        Compute the camera of every sample and group them by model path.

        Samples whose "model" path does not contain the category name are skipped.

        Returns:
            Dict mapping the model path (the same string __getitem__ returns as
            "sequence_point_cloud_path") to the cameras of all samples using that
            model, in dataset order.
        """
        cameras_dict: Dict[str, List["PerspectiveCameras"]] = {}
        # Many images share one model; load each model's points only once.
        pts_by_model: Dict[str, np.ndarray] = {}
        for sample in self.data:
            if self.category not in sample["model"]:
                continue
            model_path = os.path.join(self._data_dir(), sample["model"])
            if model_path not in pts_by_model:
                pts_by_model[model_path] = self._load_pointcloud(sample)
            # Raw points on purpose (same as __getitem__): _get_camera folds their
            # mean/std into the extrinsics so the camera matches the normalized
            # cloud. Passing already-normalized points silently dropped that.
            camera = self._get_camera(sample, pts_by_model[model_path])
            cameras_dict.setdefault(model_path, []).append(camera)
        return cameras_dict

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Load sample ``idx`` and return the dict built by _create_return_dict."""
        sample = self.data[idx]

        # Load and normalize point cloud
        pts = self._load_pointcloud(sample)
        pts_v1 = self._normalize_points(pts)

        # The camera is built from the raw points; it absorbs the normalization.
        camera = self._get_camera(sample, pts)

        img_cropped_tensor = self._process_image(sample)

        return self._create_return_dict(sample, pts_v1, camera, img_cropped_tensor)

    def collect_perspective_cameras(
        self, model_name: str, num_frame: int
    ) -> List["PerspectiveCameras"]:
        """
        Collect `num_frame` PerspectiveCameras for the specified model name.

        A sample matches when ``model_name`` is a substring of its "model" path;
        the first ``num_frame`` matches in dataset order are used. Results are
        memoized per ``(model_name, num_frame)`` on this instance, and a fresh
        list is returned on every call so callers may mutate it safely.

        Args:
            model_name (str): Substring identifying the model.
            num_frame (int): The number of PerspectiveCameras to collect.

        Returns:
            List[PerspectiveCameras]: A list of PerspectiveCameras instances.

        Raises:
            ValueError: If fewer than ``num_frame`` samples match.
        """
        # setdefault: instances built without __init__ have no cache yet.
        cache = self.__dict__.setdefault("_camera_cache", {})
        key = (model_name, num_frame)
        if key not in cache:
            matching_samples = [
                sample for sample in self.data if model_name in sample["model"]
            ]
            if len(matching_samples) < num_frame:
                raise ValueError(
                    f"Requested {num_frame} frames, but only found {len(matching_samples)} matching samples."
                )
            # Raw points, consistent with __getitem__ (see _get_camera).
            cache[key] = [
                self._get_camera(sample, self._load_pointcloud(sample))
                for sample in matching_samples[:num_frame]
            ]
        return list(cache[key])

    def _load_pointcloud(self, sample: Dict) -> np.ndarray:
        """Load point cloud from processed or raw data.

        Processed: the vertices of the file at ``processed_root_dir/model``.
        Raw: ``sample_size`` points sampled from the OBJ mesh at ``root_dir/model``.
        """
        if self.processed:
            pointcloud = trimesh.load(
                os.path.join(self.processed_root_dir, sample["model"])
            )
            return np.array(pointcloud.vertices)
        else:
            mesh = load_objs_as_meshes([os.path.join(self.root_dir, sample["model"])])
            pointcloud = sample_points_from_meshes(mesh, self.sample_size).squeeze()
            return np.array(pointcloud)

    def _normalize_points(self, pts: np.ndarray) -> np.ndarray:
        """Normalize points to Pix3D format.

        Subtracts the per-axis mean, divides by one scalar std taken over all
        coordinates, then maps each point (x, y, z) -> (-z, y, x).
        """
        m = pts.mean(axis=0)
        s = pts.reshape(1, -1).std(axis=1)
        pts_norm = (pts - m) / s
        v2_to_v1 = np.array([[0, 0, -1], [0, 1, 0], [1, 0, 0]])
        return (v2_to_v1 @ pts_norm.T).T

    @staticmethod
    def _crop_box(sample: Dict) -> Tuple[float, float, float, float]:
        """Return the square crop (left, top, right, bottom) centred on the bbox.

        The side equals the longer bbox side. Shared by _get_camera and
        _process_image so the intrinsics always describe the cropped image.
        """
        x0, y0, x1, y1 = sample["bbox"]
        cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
        half_w = max(y1 - y0, x1 - x0) / 2
        return cx - half_w, cy - half_w, cx + half_w, cy + half_w

    def _compute_intrinsics(self, sample: Dict) -> Tuple[float, float, float, float]:
        """Return (fx, fy, px, py) in pixels of the cropped, resized image.

        Composes the original-image intrinsics (focal length converted from mm,
        principal point at the image centre) with the crop-and-resize described
        by _crop_box.
        """
        w, h = sample["img_size"]
        left, top, right, _ = self._crop_box(sample)
        f = sample["focal_length"] * w / 32  # width of sensor is 32mm
        s = self.img_size / (right - left)
        # Translate by the *square crop* corner, not the bbox corner: for
        # non-square bboxes they differ, which used to shift the principal point.
        affine = np.array([[s, 0, -s * left], [0, s, -s * top], [0, 0, 1]])
        proj = affine @ np.array([[f, 0, w / 2], [0, f, h / 2], [0, 0, 1]])
        return proj[0, 0], proj[1, 1], proj[0, 2], proj[1, 2]

    def _get_camera(self, sample: Dict, pts: np.ndarray) -> "PerspectiveCameras":
        """Build the PyTorch3D camera for ``sample``.

        Args:
            sample: Annotation dict (reads img_size, bbox, focal_length,
                rot_mat, trans_mat).
            pts: The model's *raw* (un-normalized) points. Their mean and std
                are folded into the extrinsics so the camera applies to the
                output of _normalize_points.
        """
        fx, fy, px, py = self._compute_intrinsics(sample)
        R_v1, t_v1 = self._convert_extrinsics(sample, pts)
        return PerspectiveCameras(
            focal_length=torch.tensor([fx, fy], dtype=torch.float32)[None],
            principal_point=torch.tensor((px, py), dtype=torch.float32)[None],
            R=torch.tensor(R_v1, dtype=torch.float32)[None],
            T=torch.tensor(t_v1, dtype=torch.float32)[None],
            in_ndc=False,
            image_size=torch.tensor([self.img_size, self.img_size])[None],
            # Instances created without __init__ fall back to the old default.
            device=getattr(self, "device", "cuda"),
        )

    def _convert_extrinsics(
        self, sample: Dict, pts: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Convert extrinsics to PyTorch3D format.

        Returns (R_v1, t) such that ``pts_v1 @ R_v1 + t`` equals
        ``pts @ rot_mat.T + trans_mat`` for the raw points ``pts``, where
        ``pts_v1 = _normalize_points(pts)``.
        """
        R = np.array(sample["rot_mat"])
        t = np.array(sample["trans_mat"])
        s = pts.reshape(1, -1).std(axis=1)
        R_norm = R * s
        t_norm = t + pts.mean(axis=0) @ R.T
        # Inverse of the axis permutation applied in _normalize_points.
        convert = np.array([[0, 0, 1], [0, 1, 0], [-1, 0, 0]])
        R_v1 = (R_norm @ convert).T
        return R_v1, t_norm

    def _process_image(self, sample: Dict) -> torch.Tensor:
        """Load the sample's image as a float tensor of shape (3, H, W) in [0, 1].

        Raw images are cropped with _crop_box and resized to ``img_size``;
        processed images are used as stored (presumably pre-cropped — confirm).
        Every image is converted to RGB so grayscale/palette files also give
        three channels.
        """
        img_path = os.path.join(self._data_dir(), sample["img"])
        with Image.open(img_path) as img:
            if not self.processed:
                img = img.crop(self._crop_box(sample))
                img = img.resize((self.img_size, self.img_size))
            img = img.convert("RGB")
            pixels = np.array(img)
        return torch.from_numpy(pixels / 255.0).permute(2, 0, 1).float()

    def _create_return_dict(
        self,
        sample: Dict,
        pts_v1: np.ndarray,
        camera: "PerspectiveCameras",
        img_cropped_tensor: torch.Tensor,
    ) -> Dict:
        """Create the dictionary for __getitem__ return."""
        data_dir = self._data_dir()
        rt = OrderedDict()
        rt["frame_number"] = sample["img"].split("/")[-1].split(".")[0]
        rt["sequence_name"] = sample["model"].split("/")[-2] + "_" + rt["frame_number"]
        rt["sequence_category"] = sample["category"]
        # Original image size, stored as (height, width).
        rt["image_size_hw"] = torch.tensor(
            [sample["img_size"][1], sample["img_size"][0]]
        ).long()
        rt["effective_image_size_hw"] = torch.tensor(
            [self.img_size, self.img_size]
        ).long()
        rt["image_path"] = os.path.join(data_dir, sample["img"])
        rt["image_rgb"] = img_cropped_tensor
        rt["camera"] = camera
        rt["sequence_point_cloud_path"] = os.path.join(data_dir, sample["model"])
        rt["sequence_point_cloud"] = torch.tensor(pts_v1).float()
        return rt


def get_dataset_pix3d(cfg):
    """
    Build the Pix3D dataloaders described by ``cfg``.

    Reads ``cfg.dataset`` (root, pc_dict, category, max_points, image_size,
    subset_ratio, processed), ``cfg.dataloader`` (batch_size, num_workers) and
    ``cfg.run.job``. The training loader is only built when ``cfg.run.job``
    does not contain "sample"; otherwise it is ``None``.

    Returns:
        Tuple ``(train_loader, val_loader, val_loader)``: the same test-split
        loader object occupies the second and third positions.
    """
    ds_cfg = cfg.dataset
    dl_cfg = cfg.dataloader

    def _build_dataset(split, **extra):
        # Settings shared by both splits; `extra` carries split-specific ones.
        return Pix3D(
            root_dir=ds_cfg.root,
            pc_dict=ds_cfg.pc_dict,
            category=ds_cfg.category,
            split=split,
            sample_size=ds_cfg.max_points,
            img_size=ds_cfg.image_size,
            processed=ds_cfg.processed,
            **extra,
        )

    def _build_loader(dataset, *, training):
        # Training shuffles and drops the ragged final batch; evaluation does neither.
        return DataLoader(
            dataset,
            batch_size=dl_cfg.batch_size,
            shuffle=training,
            num_workers=int(dl_cfg.num_workers),
            drop_last=training,
            collate_fn=custom_collate,
        )

    train_loader = None
    if "sample" not in cfg.run.job:
        train_set = _build_dataset("train", subset_ratio=ds_cfg.subset_ratio)
        train_loader = _build_loader(train_set, training=True)

    val_loader = _build_loader(_build_dataset("test"), training=False)
    return train_loader, val_loader, val_loader


def custom_collate(batch):
    """
    Collate a list of sample dicts into one batch dict.

    Keys are taken from the first sample, in its order. ``PerspectiveCameras``
    values are kept as a plain per-sample list, keys whose first value is
    ``None`` map to ``None``, and everything else goes through PyTorch's
    ``default_collate``.
    """
    head = batch[0]

    def _merge(key):
        probe = head[key]
        if isinstance(probe, PerspectiveCameras):
            return [item[key] for item in batch]
        if probe is None:
            return None
        return torch.utils.data.dataloader.default_collate(
            [item[key] for item in batch]
        )

    return {key: _merge(key) for key in head.keys()}


if __name__ == "__main__":
    # Visual sanity check: for every sample, render its normalized point cloud
    # from up to `max_prior_cameras` cameras of images sharing the same model,
    # plus the sample's own camera, and save the render batch as one image.
    import torchshow
    from pytorch3d.renderer import (
        AlphaCompositor,
        PointsRasterizationSettings,
        PointsRasterizer,
        PointsRenderer,
    )
    from pytorch3d.structures import Pointclouds

    output_dir = (
        "/data/Hypothesis/theorem/3dgen/projection-conditioned-point-cloud-diffusion"
        "/outputs/random_cameras"
    )
    max_prior_cameras = 10

    dataset = Pix3D()
    raster_settings = PointsRasterizationSettings(image_size=512, radius=0.04)
    for idx, sample in enumerate(dataset):
        # Slicing copies, so appending below never mutates dataset.cameras.
        render_cameras = dataset.cameras.get(sample["sequence_point_cloud_path"], [])[
            :max_prior_cameras
        ]
        render_cameras.append(sample["camera"])

        points = sample["sequence_point_cloud"]
        # One point-cloud copy per camera. The count must match the camera batch;
        # the old hard-coded 11 broke for models with fewer than 10 cameras.
        pointcloud = (
            Pointclouds(points=[points], features=[torch.ones_like(points)])
            .extend(len(render_cameras))
            .to("cuda")
        )
        rasterizer = PointsRasterizer(
            cameras=join_cameras_as_batch(render_cameras),
            raster_settings=raster_settings,
        )
        renderer = PointsRenderer(rasterizer=rasterizer, compositor=AlphaCompositor())
        torchshow.save(
            renderer(pointcloud),
            os.path.join(output_dir, f"new_render_{idx}.png"),
        )
