import copy
import json
import numpy as np
import os
import torch
import torch.nn.functional as F
from typing import Tuple, Optional, Any

import logging as log

from .base_dataset import BaseDataset
from .data_loading import parallel_load_images
from .intrinsics import Intrinsics
from .ray_utils import get_ray_directions, generate_hemispherical_orbit, get_rays


class SyntheticNerfDataset(BaseDataset):
    """Ray dataset built from ``transforms_{split}.json`` scene descriptions.

    Splits other than ``'render'`` load frames, images and intrinsics from the
    matching transforms file. ``'render'`` reuses the test split only to get
    intrinsics and reference poses, from which a 120-frame hemispherical orbit
    is generated; rays are cast from that orbit and no target images exist.
    Near/far bounds are fixed to ``[2.0, 6.0]``.
    """

    def __init__(self,
                 datadir,
                 split: str,
                 batch_size: Optional[int] = None,
                 downsample: float = 1.0,
                 max_frames: Optional[int] = None,
                 scales=(1,)):
        """
        Args:
            datadir: scene directory holding the ``transforms_*.json`` files.
            split: ``'train'``, ``'test'``, ``'render'`` or another split name
                with a matching transforms file.
            batch_size: forwarded to ``BaseDataset``.
            downsample: image downsampling factor passed to the loaders.
            max_frames: forwarded to ``load_360_frames`` to limit the number
                of frames loaded.
            scales: image downscale factors for multi-scale rays; only used
                for the ``'train'`` split (other splits use a single scale).
        """
        self.downsample = downsample
        self.max_frames = max_frames
        self.near_far = [2.0, 6.0]

        # 'render' has no transforms file of its own: it reuses the test split.
        source_split = 'test' if split == 'render' else split
        frames, transform = load_360_frames(datadir, source_split, self.max_frames)
        imgs, poses = load_360_images(frames, datadir, source_split, self.downsample)
        intrinsics = load_360_intrinsics(
            transform, img_h=imgs[0].shape[0], img_w=imgs[0].shape[1],
            downsample=self.downsample)
        if split == 'render':
            # Rays must come from the synthetic orbit, not from the test
            # cameras it was derived from; there are no ground-truth images.
            poses = generate_hemispherical_orbit(poses, n_frames=120)
            self.poses = poses
            imgs = None

        rays_o, rays_d, radii, imgs = create_360_rays(
            imgs, poses, merge_all=split == 'train', intrinsics=intrinsics,
            scales=scales if split == 'train' else (1,))
        super().__init__(
            datadir=datadir,
            split=split,
            scene_bbox=get_360_bbox(datadir, is_contracted=False),
            is_ndc=False,
            is_contracted=False,
            batch_size=batch_size,
            imgs=imgs,
            rays_o=rays_o,
            rays_d=rays_d,
            radii=radii,
            intrinsics=intrinsics,
        )
        log.info(f"SyntheticNerfDataset. Loaded {split} set from {datadir}. "
                 f"{len(poses)} images of shape {self.img_h}x{self.img_w}. "
                 f"Images loaded: {imgs is not None}. "
                 f"Sampling without replacement={self.use_permutation}. {intrinsics}")

    def __getitem__(self, index):
        """Return the base sample with composited pixels, background and bounds.

        Adds ``bg_color`` (random per call for training, white otherwise),
        composites RGBA pixels over it, and sets ``near_fars`` from
        ``self.near_far``.
        """
        out = super().__getitem__(index)
        pixels = out["imgs"]

        if pixels is not None:
            dtype, device = pixels.dtype, pixels.device
        else:
            # No images (e.g. render split): nothing to inherit dtype/device
            # from. Prefer the GPU but don't crash on CPU-only machines.
            dtype = torch.float32
            device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        if self.split == 'train':
            # Random background discourages the model from baking a fixed
            # background colour into semi-transparent regions.
            bg_color = torch.rand((1, 3), dtype=dtype, device=device)
        else:
            bg_color = torch.ones((1, 3), dtype=dtype, device=device)

        # Alpha compositing (only RGBA pixels carry an alpha channel).
        if pixels is not None and pixels.shape[-1] == 4:
            alpha = pixels[..., 3:]
            pixels = pixels[..., :3] * alpha + bg_color * (1.0 - alpha)
        out["imgs"] = pixels
        out["bg_color"] = bg_color
        out["near_fars"] = torch.tensor([self.near_far])
        return out


def get_360_bbox(datadir, is_contracted=False):
    """Return the axis-aligned scene box as a ``[[min xyz], [max xyz]]`` tensor.

    The box is a cube centred at the origin. Its half-extent is 2 for
    contracted scenes, 1.5 when ``"ship"`` appears in ``datadir`` and 1.3
    otherwise.
    """
    if is_contracted:
        half_extent = 2
    else:
        half_extent = 1.5 if "ship" in datadir else 1.3
    lower_corner = [-half_extent] * 3
    upper_corner = [half_extent] * 3
    return torch.tensor([lower_corner, upper_corner])


def _scale_intrinsics(intrinsics: Intrinsics, scale) -> Intrinsics:
    """Return a copy of ``intrinsics`` for the image downscaled by ``scale``.

    Height/width are truncated to ints (pixel counts). Focal lengths and the
    principal point stay fractional: truncating them would perturb every ray.
    """
    scaled = copy.deepcopy(intrinsics)
    scaled.height = int(scaled.height / scale)
    scaled.width = int(scaled.width / scale)
    scaled.center_x = scaled.center_x / scale
    scaled.center_y = scaled.center_y / scale
    scaled.focal_x = scaled.focal_x / scale
    scaled.focal_y = scaled.focal_y / scale
    return scaled


def create_360_rays(imgs: Optional[torch.Tensor], poses: torch.Tensor,
                    merge_all: bool, intrinsics: Intrinsics,
                    scales=(1,)
                    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Cast one ray per pixel for every pose, at one or more image scales.

    Args:
        imgs: optional ``[N, H, W, C]`` images, or ``None`` when there is no
            ground truth (e.g. rendering poses).
        poses: per-frame camera poses; ``poses[i]`` is passed to ``get_rays``.
        merge_all: if True, all rays of all frames and scales are flattened
            along dim 0; otherwise results have one row per frame and each
            scale's rays are appended along dim 1.
        intrinsics: camera intrinsics at scale 1.
        scales: downscale factors; each adds a full set of rays at
            ``H/scale x W/scale`` resolution, with images area-averaged to
            the same resolution.

    Returns:
        ``(rays_o, rays_d, radii, imgs)`` as float32 tensors. Per-frame
        outputs are ``[N, P, 3]``, ``[N, P, 3]``, ``[N, P, 1]``, ``[N, P, C]``.
        ``imgs`` is ``None`` when no images were given.

    Raises:
        ValueError: if ``scales`` is empty.
    """
    if len(scales) == 0:
        raise ValueError("create_360_rays needs at least one scale")
    num_frames = poses.shape[0]
    # Merged: one long list of rays. Per-frame: append each scale's rays to
    # every frame's row, so frames never get mixed across scales.
    cat_dim = 0 if merge_all else 1

    rays_o_parts, rays_d_parts, radii_parts, img_parts = [], [], [], []
    for scale in scales:
        scaled = _scale_intrinsics(intrinsics, scale)
        directions = get_ray_directions(scaled, opengl_camera=True)

        frame_o, frame_d, frame_r = [], [], []
        for i in range(num_frames):
            rays_o, rays_d, radii = get_rays(
                directions, poses[i], ndc=False, normalize_rd=True)  # [H*W, 3]
            frame_o.append(rays_o)
            frame_d.append(rays_d)
            frame_r.append(radii)
        # [n_frames * h * w, ...] for this scale, frame-major
        scale_o = torch.cat(frame_o, 0).to(dtype=torch.float32)
        scale_d = torch.cat(frame_d, 0).to(dtype=torch.float32)
        scale_r = torch.cat(frame_r, 0).to(dtype=torch.float32)
        if not merge_all:
            scale_o = scale_o.view(num_frames, -1, 3)
            scale_d = scale_d.view(num_frames, -1, 3)
            scale_r = scale_r.view(num_frames, -1, 1)
        rays_o_parts.append(scale_o)
        rays_d_parts.append(scale_d)
        radii_parts.append(scale_r)

        if imgs is not None:
            # reshape, not view: the permuted pooling output is not
            # guaranteed to be contiguous.
            pooled = F.adaptive_avg_pool2d(
                imgs.permute(0, 3, 1, 2), (scaled.height, scaled.width)
            ).permute(0, 2, 3, 1)
            pooled = pooled.reshape(-1, imgs.shape[-1]).to(dtype=torch.float32)
            if not merge_all:
                pooled = pooled.view(num_frames, -1, imgs.shape[-1])
            img_parts.append(pooled)

    all_rays_o = torch.cat(rays_o_parts, cat_dim)
    all_rays_d = torch.cat(rays_d_parts, cat_dim)
    all_radii = torch.cat(radii_parts, cat_dim)
    all_imgs = torch.cat(img_parts, cat_dim) if imgs is not None else None
    return all_rays_o, all_rays_d, all_radii, all_imgs


def load_360_frames(datadir, split, max_frames: Optional[int]) -> Tuple[Any, Any]:
    """Read ``transforms_{split}.json`` and pick at most ``max_frames`` frames.

    For ``'train'`` and ``'test'``, frames are taken with a uniform stride of
    ``round(total / max_frames)``, then truncated so that no more than
    ``max_frames`` are returned. Other splits keep the first ``max_frames``
    frames. ``max_frames`` of ``None`` (or 0) keeps every frame.

    Returns:
        ``(frames, meta)``: the selected frame entries and the full parsed
        JSON. ``meta['frames']`` still lists every frame.
    """
    path = os.path.join(datadir, f"transforms_{split}.json")
    with open(path, 'r', encoding='utf-8') as f:
        meta = json.load(f)
    frames = meta['frames']

    tot_frames = len(frames)
    num_frames = min(tot_frames, max_frames or tot_frames)
    if num_frames == 0:
        # Nothing to select (and avoids dividing by zero below).
        return [], meta

    if split == 'train' or split == 'test':
        subsample = int(round(tot_frames / num_frames))
        # A rounded stride can overshoot the cap (e.g. 100 frames, max 30 ->
        # stride 3 -> 34 frames), so truncate to honour max_frames.
        frame_ids = np.arange(tot_frames)[::subsample][:num_frames]
        if subsample > 1:
            log.info(f"Subsampling {split} set to 1 every {subsample} images.")
    else:
        frame_ids = np.arange(num_frames)
    return [frames[i] for i in frame_ids], meta


def load_360_images(frames, datadir, split, downsample) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load the image and pose for every entry of ``frames``.

    Delegates to ``parallel_load_images`` (``dset_type="synthetic"``, no
    explicit output size), which yields one ``(image, pose)`` pair per frame.
    Each half is then stacked along a new leading frame axis.

    Returns:
        ``(imgs, poses)``; images are ``[N, H, W, 3/4]``. The pose shape
        depends on the loader (not established here).
    """
    loaded_pairs = parallel_load_images(
        tqdm_title=f'Loading {split} data',
        dset_type="synthetic",
        data_dir=datadir,
        frames=frames,
        num_images=len(frames),
        downsample=downsample,
        out_w=None,
        out_h=None,
    )
    image_list, pose_list = zip(*loaded_pairs)
    return torch.stack(image_list, 0), torch.stack(pose_list, 0)


def load_360_intrinsics(transform, img_h, img_w, downsample) -> Intrinsics:
    """Build camera intrinsics from a parsed transforms JSON.

    Focal lengths come from explicit ``fl_x``/``fl_y`` (divided by
    ``downsample``), or else from ``camera_angle_x``/``camera_angle_y`` field
    of view in radians together with the already-downsampled image size.
    A missing axis reuses the other axis' value. The principal point defaults
    to the image centre unless ``cx``/``cy`` are given.

    Raises:
        RuntimeError: if neither focal lengths nor field-of-view angles exist.
    """
    has_focal = 'fl_x' in transform or 'fl_y' in transform
    has_fov = 'camera_angle_x' in transform or 'camera_angle_y' in transform
    if not (has_focal or has_fov):
        raise RuntimeError('Failed to load focal length, please check the transforms.json!')

    if has_focal:
        raw_fx = transform['fl_x'] if 'fl_x' in transform else transform['fl_y']
        focal_x = raw_fx / downsample
        raw_fy = transform['fl_y'] if 'fl_y' in transform else transform['fl_x']
        focal_y = raw_fy / downsample
    else:
        # Blender-style FOV; img_w/img_h are already downsampled, so no
        # further division is needed.
        focal_x = None
        if 'camera_angle_x' in transform:
            focal_x = img_w / (2 * np.tan(transform['camera_angle_x'] / 2))
        focal_y = None
        if 'camera_angle_y' in transform:
            focal_y = img_h / (2 * np.tan(transform['camera_angle_y'] / 2))
        focal_x = focal_y if focal_x is None else focal_x
        focal_y = focal_x if focal_y is None else focal_y

    center_x = transform['cx'] / downsample if 'cx' in transform else img_w / 2
    center_y = transform['cy'] / downsample if 'cy' in transform else img_h / 2
    return Intrinsics(height=img_h, width=img_w, focal_x=focal_x, focal_y=focal_y,
                      center_x=center_x, center_y=center_y)

