import numpy as np
import imageio
from torch.utils.data import Dataset
import json
import random
import os

class IsaacGymDataset(Dataset):
    """Multi-view scene dataset read from an on-disk directory tree.

    For a scene index ``s`` and view index ``v`` the following files are read
    relative to ``path``:

    * ``images/{s:05d}_{v}.png``    -- colour image (first 3 channels are used)
    * ``masks/{s:05d}_{v}_fg.png``  -- mask, binarised with ``> 64``
    * ``masks/{s:05d}_{v}_ag.png``  -- mask, binarised with ``> 64``
    * ``scenes/{s:05d}_5.json``     -- ``["views"][str(v)]`` holds the ``"ext"``
      and ``"int"`` camera matrices of every view

    Each item randomly picks ``num_views`` input views and ``num_target_views``
    disjoint target views, and sub-samples ``points_per_item`` pixels/rays from
    the target views.
    """

    # Scene index range (start inclusive, end exclusive) of every split.
    _SPLITS = {'train': (1, 30001),
               'val': (30001, 30201),
               'test': (30001, 30201)}

    # View indices are drawn from range(_VIEWS_PER_SCENE).
    _VIEWS_PER_SCENE = 6

    def __init__(self, path, mode, points_per_item=2048, max_len=None):
        """Set up the index range for the requested split.

        Args:
            path: Root directory containing ``images/``, ``masks/``, ``scenes/``.
            mode: One of ``'train'``, ``'val'`` or ``'test'``.
            points_per_item: Number of target pixels/rays sampled per item.
            max_len: If not None, value reported by ``len()`` instead of the
                number of scenes in the split.

        Raises:
            KeyError: If ``mode`` is not a known split name.
        """
        super().__init__()

        self.path = path
        self.mode = mode
        self.points_per_item = points_per_item
        self.max_len = max_len

        self.max_num_entities = 3
        self.num_views = 3
        self.num_target_views = 2
        self.num_objs = 3

        try:
            self.start_idx, self.end_idx = self._SPLITS[mode]
        except KeyError:
            # Same exception type as a plain dict lookup, but with a message
            # that tells the caller which values are accepted.
            raise KeyError(
                f'Unknown mode {mode!r}; expected one of {sorted(self._SPLITS)}'
            ) from None

        self.idxs = np.arange(self.start_idx, self.end_idx)
        dataset_name = 'IsaacGym3D'

        print(f'Initialized {dataset_name} {mode} set, {len(self.idxs)} examples')
        print(self.idxs)

    def __len__(self):
        """Return ``max_len`` when it is set, else the number of scenes."""
        if self.max_len is not None:
            return self.max_len
        return len(self.idxs)

    def __getitem__(self, idx, noisy=True):
        """Load one scene with randomly chosen input and target views.

        Args:
            idx: Item index; wrapped modulo the number of scenes, so values
                beyond the split size (possible when ``max_len`` is larger)
                are still valid.
            noisy: Unused; kept so existing callers keep working.

        Returns:
            dict of numpy arrays, where V = ``num_views``,
            T = ``num_target_views``, p = ``points_per_item`` and (h, w) is
            the image size; see the inline comments on the result dict for
            the shape of each entry.
        """
        scene_idx = self.idxs[idx % len(self.idxs)]

        # Pick disjoint input / target views. The order of the RNG calls
        # (two random.sample, then one np.random.choice) is kept stable so
        # seeded runs stay reproducible.
        all_views = list(range(self._VIEWS_PER_SCENE))
        input_index = random.sample(all_views, self.num_views)
        remaining_views = [v for v in all_views if v not in input_index]
        target_index = random.sample(remaining_views, self.num_target_views)

        # Images
        input_images = self._load_images(scene_idx, input_index)
        target_images = self._load_images(scene_idx, target_index)

        # Masks (target views only)
        fg_masks = self._load_masks(scene_idx, target_index, 'fg')
        ag_masks = self._load_masks(scene_idx, target_index, 'ag')

        # Camera parameters
        with open(os.path.join(self.path, 'scenes', f'{scene_idx:05d}_5.json'), 'r') as f:
            cam_info = json.load(f)["views"]

        # Ray grids follow the resolution of the images actually loaded, so
        # pixels and rays always line up when they are flattened below.
        in_h, in_w = input_images.shape[-2:]
        tgt_h, tgt_w = target_images.shape[-2:]

        input_camera_pos, input_rays = self._rays_for_views(
            cam_info, input_index, in_h, in_w)
        target_camera_pos_ori, target_rays_ori = self._rays_for_views(
            cam_info, target_index, tgt_h, tgt_w)

        # Sample points from target views
        num_points = tgt_h * tgt_w
        target_pixels = np.reshape(target_images.transpose(0, 2, 3, 1),
                                   (self.num_target_views, -1, 3))
        target_rays = np.reshape(target_rays_ori,
                                 (self.num_target_views, num_points, 3))

        # Sample with replacement only when more points are requested than
        # there are pixels.
        replace = num_points < self.points_per_item
        sampled_idxs = np.random.choice(np.arange(num_points),
                                        size=(self.points_per_item,),
                                        replace=replace)
        target_rays = target_rays[:, sampled_idxs]
        target_pixels = target_pixels[:, sampled_idxs]
        # Every pixel of a view shares that view's camera position, so the
        # position is repeated p times instead of tiling h*w copies and
        # indexing into them.
        target_camera_pos = np.repeat(target_camera_pos_ori[:, np.newaxis],
                                      self.points_per_item, axis=1)
        # The same pixel indices are used for every target view.
        sampled_idxs = np.tile(np.expand_dims(sampled_idxs, axis=0),
                               (self.num_target_views, 1))

        result = {
            'input_images':             input_images,           # [V, 3, h, w]
            'input_camera_pos':         input_camera_pos,       # [V, 3]
            'input_rays':               input_rays,             # [V, h, w, 3]

            'fg_mask':                  fg_masks,               # [T, h, w, 1]
            'ag_mask':                  ag_masks,               # [T, h, w, 1]

            'target_pixels':            target_pixels,          # [T, p, 3]
            'target_camera_pos':        target_camera_pos,      # [T, p, 3]
            'target_rays':              target_rays,            # [T, p, 3]

            'target_pixels_ori':        target_images,          # [T, 3, h, w]
            'target_camera_pos_ori':    target_camera_pos_ori,  # [T, 3]
            'target_rays_ori':          target_rays_ori,        # [T, h, w, 3]

            'sceneid':                  scene_idx,
            'target_index':             sampled_idxs,           # [T, p] flat pixel indices
        }

        return result

    def _load_images(self, scene_idx, view_ids):
        """Read the images of ``view_ids`` as a float32 [N, 3, h, w] array in [0, 1]."""
        frames = []
        for v in view_ids:
            file_path = os.path.join(self.path, 'images', f'{scene_idx:05d}_{v}.png')
            pixels = np.asarray(imageio.imread(file_path))
            # Keep the first three channels only and scale 0..255 to 0..1.
            rgb = pixels[..., :3].astype(np.float32) / 255
            frames.append(np.transpose(rgb, (2, 0, 1)))
        return np.stack(frames, axis=0)

    def _load_masks(self, scene_idx, view_ids, suffix):
        """Read the ``suffix`` masks of ``view_ids`` as a binary float32 [N, h, w, 1] array.

        Assumes each mask file decodes to a 2-D (h, w) array.
        """
        masks = []
        for v in view_ids:
            file_path = os.path.join(self.path, 'masks', f'{scene_idx:05d}_{v}_{suffix}.png')
            raw = np.asarray(imageio.imread(file_path))
            masks.append((raw > 64).astype(np.float32))
        return np.stack(masks, axis=0)[..., np.newaxis]

    def _rays_for_views(self, cam_info, view_ids, height, width):
        """Return float32 camera positions [N, 3] and ray grids [N, height, width, 3]."""
        positions = []
        rays = []
        for v in view_ids:
            view = cam_info[str(v)]
            extrinsic = np.asarray(view["ext"], dtype=np.float32)
            intrinsic = np.asarray(view["int"], dtype=np.float32)
            cam_pos, cam_rays = self.get_camera_ray(extrinsic, intrinsic, height, width)
            positions.append(cam_pos)
            rays.append(cam_rays)
        positions = np.stack(positions, axis=0).astype(np.float32)
        rays = np.stack(rays, axis=0).astype(np.float32)
        return positions, rays

    def get_camera_ray(self, camera_view_matrix, camera_intrinsic_matrix, H, W):
        """Compute the camera position and one ray direction per pixel.

        Args:
            camera_view_matrix: (4, 4) numpy array (world-to-camera pose).
            camera_intrinsic_matrix: (3, 3) numpy array.
            H, W: image height and width.

        Returns:
            cam_pos: (3,) camera position in world coordinates.
            rays_world: (H, W, 3) float32 unit ray directions in world space.
        """
        # Invert the view matrix to get camera-to-world; its translation
        # column is the camera position.
        cam_to_world = np.linalg.inv(camera_view_matrix)
        cam_pos = cam_to_world[:3, 3]

        # Homogeneous pixel-centre coordinates (x + 0.5, y + 0.5, 1).
        i, j = np.meshgrid(np.arange(W), np.arange(H), indexing='xy')
        pixels = np.stack([i + 0.5, j + 0.5, np.ones_like(i)], axis=-1)  # (H, W, 3)

        # Back-project through the inverse intrinsics: camera-space directions.
        inv_K = np.linalg.inv(camera_intrinsic_matrix)
        dirs_cam = pixels @ inv_K.T  # (H, W, 3)

        # Rotate into world space. The sign flip is part of the original
        # convention of this dataset and is kept as-is.
        R = cam_to_world[:3, :3]
        dirs_world = - (dirs_cam @ R.T)

        # Normalise to unit length.
        dirs_world = dirs_world / np.linalg.norm(dirs_world, axis=-1, keepdims=True)

        return cam_pos, dirs_world.astype(np.float32)