import numpy as np
import imageio
from torch.utils.data import Dataset
import json
import random
import os

class ClevrDataset(Dataset):
    """Multi-view CLEVR3D dataset.

    Each item picks one scene, randomly draws ``num_views`` input views and
    ``num_target_views`` disjoint target views, and returns the input images
    with their camera rays together with a random subset of target
    pixels/rays for supervision.

    Expected layout under ``path`` (as read by this class):
        images/CLEVR_3d_{scene:06d}_{view:02d}.png
        masks/CLEVR_3d_{scene:06d}_{view:02d}/{object:03d}.png
        scenes/CLEVR_3d_{scene:06d}.json   (with metadata['camera']['location'][view])
    """

    def __init__(self, path, mode, points_per_item=2048, max_len=None):
        """Set up the scene index range for the requested split.

        Args:
            path: Root directory of the dataset.
            mode: One of 'train', 'val' or 'test'. 'val' and 'test' share the
                same scene range. Any other value raises KeyError.
            points_per_item: Number of target pixels/rays sampled per view.
            max_len: If not None, overrides the value reported by ``len()``.
        """
        self.path = path
        self.mode = mode
        self.points_per_item = points_per_item
        self.max_len = max_len

        self.max_num_entities = 3
        self.num_views = 3
        self.num_target_views = 2
        self.num_objs = 3
        # Views 0..4 are sampled per scene; presumably each scene is rendered
        # from five cameras -- confirm against the dataset on disk.
        self.num_available_views = 5

        self.start_idx, self.end_idx = {'train': (0, 20000),
                                        'val': (20000, 21000),
                                        'test': (20000, 21000)}[mode]

        self.idxs = np.arange(self.start_idx, self.end_idx)
        dataset_name = 'CLEVR3D'

        print(f'Initialized {dataset_name} {mode} set, {len(self.idxs)} examples')
        print(self.idxs)

    def __len__(self):
        """Return ``max_len`` if set, else number of scenes times ``num_views``.

        Indices beyond the number of scenes wrap around in ``__getitem__``.
        """
        if self.max_len is not None:
            return self.max_len
        return len(self.idxs) * self.num_views

    def _load_images(self, scene_idx, view_indices):
        """Load the RGB images of the given views as a float32 array [V, 3, h, w] in [0, 1]."""
        images = []
        for view in view_indices:
            img_path = os.path.join(self.path, 'images', f'CLEVR_3d_{scene_idx:06d}_{view:02d}.png')
            img = np.asarray(imageio.imread(img_path))
            # Drop any alpha channel, rescale to [0, 1], and go channel-first.
            img = img[..., :3].astype(np.float32) / 255
            images.append(np.transpose(img, (2, 0, 1)))
        return np.stack(images, axis=0)

    def _load_fg_masks(self, scene_idx, view_indices):
        """Build one binary foreground mask per view as the union of its per-object masks."""
        fg_masks = []
        for view in view_indices:
            mask_folder = os.path.join(self.path, 'masks', f'CLEVR_3d_{scene_idx:06d}_{view:02d}')
            masks = [np.asarray(imageio.imread(os.path.join(mask_folder, f'{obj_idx:03d}.png')))
                     for obj_idx in range(self.num_objs)]
            fg_mask = np.zeros_like(masks[0][..., :1], dtype=np.float32)
            for mask in masks:
                # A pixel is foreground if the first channel of any object mask exceeds 64.
                binary_mask = (mask[..., :1] > 64).astype(np.float32)
                fg_mask = np.maximum(fg_mask, binary_mask)
            fg_masks.append(fg_mask)
        return np.stack(fg_masks, axis=0).astype(np.float32)

    def __getitem__(self, idx, noisy=True):
        """Assemble one training example.

        Args:
            idx: Item index; mapped onto a scene via ``idx % len(self.idxs)``.
            noisy: Currently unused -- rays are always generated with
                ``noisy=False``. Kept for interface compatibility.

        Returns:
            Dict of numpy arrays (plus the scene id); see the shape comments
            on the ``result`` dict below.
        """
        scene_idx = self.idxs[idx % len(self.idxs)]

        # Draw disjoint sets of input and target views.
        all_views = list(range(self.num_available_views))
        input_index = random.sample(all_views, self.num_views)
        remaining_views = [view for view in all_views if view not in input_index]
        target_index = random.sample(remaining_views, self.num_target_views)

        # Images
        input_images = self._load_images(scene_idx, input_index)
        target_images = self._load_images(scene_idx, target_index)

        # Masks
        fg_masks = self._load_fg_masks(scene_idx, target_index)

        # Camera parameters
        with open(os.path.join(self.path, 'scenes', f'CLEVR_3d_{scene_idx:06d}.json'), 'r') as f:
            metadata = json.load(f)
        camera_locations = metadata['camera']['location']
        input_camera_pos = np.array([camera_locations[i] for i in input_index], dtype=np.float32)
        target_camera_pos = np.array([camera_locations[i] for i in target_index], dtype=np.float32)

        # Ray grids follow the loaded image resolution so pixels and rays stay aligned.
        in_height, in_width = input_images.shape[-2:]
        tgt_height, tgt_width = target_images.shape[-2:]

        input_rays = np.stack(
            [get_camera_rays(input_camera_pos[i], width=in_width, height=in_height, noisy=False)
             for i in range(self.num_views)],
            axis=0).astype(np.float32)
        target_rays = np.stack(
            [get_camera_rays(target_camera_pos[i], width=tgt_width, height=tgt_height, noisy=False)
             for i in range(self.num_target_views)],
            axis=0).astype(np.float32)

        # Keep the dense (unsampled) versions for full-image evaluation.
        target_camera_pos_ori = target_camera_pos
        target_rays_ori = target_rays

        # Sample points from target views
        num_points = tgt_height * tgt_width
        target_pixels = np.reshape(target_images.transpose(0, 2, 3, 1), (self.num_target_views, -1, 3))
        target_rays = np.reshape(target_rays, (self.num_target_views, num_points, 3))
        target_camera_pos = np.tile(np.expand_dims(target_camera_pos, 1), (1, num_points, 1))

        # Sample with replacement only if more points are requested than exist.
        replace = num_points < self.points_per_item
        sampled_idxs = np.random.choice(np.arange(num_points), size=(self.points_per_item,), replace=replace)
        # The same pixel indices are used for every target view.
        target_rays = target_rays[:, sampled_idxs]
        target_camera_pos = target_camera_pos[:, sampled_idxs]
        target_pixels = target_pixels[:, sampled_idxs]
        sampled_idxs = np.tile(np.expand_dims(sampled_idxs, axis=0), (self.num_target_views, 1))

        result = {
            'input_images':             input_images,               # [V_in, 3, h, w]
            'input_camera_pos':         input_camera_pos,           # [V_in, 3]
            'input_rays':               input_rays,                 # [V_in, h, w, 3]

            'fg_mask':                  fg_masks,                   # [T, h, w, 1] (assumes mask PNGs have a channel axis)

            'target_pixels':            target_pixels,              # [T, p, 3]
            'target_camera_pos':        target_camera_pos,          # [T, p, 3]
            'target_rays':              target_rays,                # [T, p, 3]

            'target_pixels_ori':        target_images,              # [T, 3, h, w]
            'target_camera_pos_ori':    target_camera_pos_ori,      # [T, 3]
            'target_rays_ori':          target_rays_ori,            # [T, h, w, 3]

            'sceneid':                  scene_idx,
            'target_index':             sampled_idxs,               # [T, p]
        }

        return result


def get_camera_rays(c_pos, width=320, height=240, focal_length=0.035, sensor_width=0.032, noisy=False,
                    vertical=None, track_point=None):
    """Return unit ray directions, one per pixel, for a pinhole camera at c_pos.

    The camera looks from ``c_pos`` towards ``track_point`` (origin by default)
    with ``vertical`` (+z by default) as the up reference. The sensor height is
    derived from ``sensor_width`` and the image aspect ratio. With
    ``noisy=True`` each ray is jittered uniformly within its pixel using the
    global NumPy RNG.

    Returns:
        float32 array of shape (height, width, 3) with unit-length rays.
    """
    look_at = np.array((0., 0., 0.)) if track_point is None else track_point
    up = np.array((0., 0., 1.)) if vertical is None else vertical

    # Orthonormal camera frame: viewing direction plus the two image-plane axes.
    forward = look_at - c_pos
    forward = forward / np.linalg.norm(forward)
    plane_center = c_pos + forward * focal_length

    axis_u = np.cross(forward, up)
    axis_u = axis_u / np.linalg.norm(axis_u)

    axis_v = np.cross(forward, axis_u)
    axis_v = axis_v / np.linalg.norm(axis_v)

    sensor_height = (sensor_width / width) * height

    # Pixel centres are the midpoints between consecutive pixel edges on the sensor.
    u_edges = np.linspace(-1, 1, width + 1) * sensor_width / 2
    v_edges = np.linspace(-1, 1, height + 1) * sensor_height / 2
    u_centers = (u_edges[:-1] + u_edges[1:]) / 2
    v_centers = (v_edges[:-1] + v_edges[1:]) / 2

    # Both grids have shape (height, width).
    u_grid, v_grid = np.meshgrid(u_centers, v_centers)

    if noisy:
        # Jitter by up to half a pixel; horizontal noise is drawn before vertical.
        u_jitter = (np.random.random((height, width)) - 0.5) * (sensor_width / width)
        v_jitter = (np.random.random((height, width)) - 0.5) * (sensor_height / height)
        u_grid = u_grid + u_jitter
        v_grid = v_grid + v_jitter

    u_component = u_grid[:, :, np.newaxis] * np.reshape(axis_u, (1, 1, 3))
    v_component = v_grid[:, :, np.newaxis] * np.reshape(axis_v, (1, 1, 3))

    pixel_positions = (u_component + v_component) + np.reshape(plane_center, (1, 1, 3))

    directions = pixel_positions - np.reshape(c_pos, (1, 1, 3))
    directions = directions / np.linalg.norm(directions, axis=2, keepdims=True)
    return directions.astype(np.float32)