import os
import math
import json
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset
from .ray_utils import *

from torchvision.transforms.functional import pil_to_tensor


def normalize(v):
    """Return ``v`` divided by its norm, as computed by ``np.linalg.norm``."""
    length = np.linalg.norm(v)
    return v / length


def average_poses(poses):
    """
    Build one "average" pose out of a stack of poses.

    The result is assembled column by column:
    1. center: the mean of the last column (index 3) over all poses.
    2. z axis: the mean of column 2, normalized.
    3. y': the mean of column 1. It is left un-normalized because it only
       serves to derive the x axis.
    4. x axis: the normalized cross product ``z x y'``.
    5. y axis: the cross product ``x x z``.

    y' is not used as the y axis directly because it is not necessarily
    orthogonal to z; deriving y from x and z yields an orthogonal frame.
    Inputs:
        poses: (N_images, 3, 4)
    Outputs:
        pose_avg: (3, 4) the average pose, columns ordered [x, y, z, center]
    """
    # Mean of the y columns; only a helper direction for the cross product.
    y_mean = poses[..., 1].mean(0)  # (3)

    # Final z axis: normalized mean of the z columns.
    z_axis = normalize(poses[..., 2].mean(0))  # (3)

    # x is perpendicular to both z and the averaged y direction.
    x_axis = normalize(np.cross(z_axis, y_mean))  # (3)

    # x and z are unit length and orthogonal, so this is already unit length.
    y_axis = np.cross(x_axis, z_axis)  # (3)

    # Average position of all poses.
    center = poses[..., 3].mean(0)  # (3)

    return np.stack([x_axis, y_axis, z_axis, center], 1)  # (3, 4)


def viewmatrix(z, up, pos):
    """
    Assemble a 4x4 pose matrix from a viewing axis, an up hint and a position.

    Inputs:
        z: (3) direction used as the third axis (normalized here)
        up: (3) up hint; only used inside a cross product
        pos: (3) translation placed in the last column
    Outputs:
        (4, 4) matrix whose top three rows hold the columns
        [-x, y, z, pos] and whose bottom row is [0, 0, 0, 1]
    """
    z_axis = normalize(z)
    x_axis = normalize(np.cross(up, z_axis))
    y_axis = normalize(np.cross(z_axis, x_axis))

    pose = np.eye(4)
    # Note the sign flip on the first axis.
    pose[:3] = np.stack([-x_axis, y_axis, z_axis, pos], 1)
    return pose


def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120):
    """
    Generate a list of poses along a spiral around the pose ``c2w``.

    Inputs:
        c2w: pose whose top-left (3, 4) part is used
        up: (3) up hint forwarded to ``viewmatrix``
        rads: per-axis radii of the spiral (a trailing 1. is appended)
        focal: depth of the point every generated pose is oriented against
        zdelta: accepted for signature compatibility; not used in the body
        zrate: frequency multiplier for the third coordinate
        N_rots: number of full rotations
        N: number of poses to generate
    Outputs:
        list of N (4, 4) matrices produced by ``viewmatrix``
    """
    rads = np.array(list(rads) + [1.])
    pose_34 = c2w[:3, :4]

    # The focus point does not depend on the angle, so compute it once.
    focus_point = np.dot(pose_34, np.array([0, 0, -focal, 1.]))

    def pose_at(theta):
        # Homogeneous offset in the frame of c2w, scaled by the radii.
        offset = np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads
        position = np.dot(pose_34, offset)
        direction = normalize(position - focus_point)
        return viewmatrix(direction, up, position)

    # N + 1 samples with the last dropped: the end angle is excluded.
    angles = np.linspace(0., 2. * np.pi * N_rots, N + 1)[:-1]
    return [pose_at(theta) for theta in angles]


def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120):
    """
    Compute a spiral render path around the average of the given poses.

    Inputs:
        c2ws_all: stack of poses, indexed as ``[:, :3, k]`` per column
        near_fars: array of near/far depths; only its min and max are used
        rads_scale: multiplier applied to the spiral radii
        N_views: number of poses to generate
    Outputs:
        (N_views, 4, 4) array of stacked poses from ``render_path_spiral``
    """
    # Pose the spiral is built around.
    avg_c2w = average_poses(c2ws_all)

    # Up hint: normalized sum of the y columns of all poses.
    up = normalize(c2ws_all[:, :3, 1].sum(0))

    # Focus depth: a weighted blend, in inverse depth, of a slightly
    # shrunken nearest depth and an enlarged farthest depth.
    dt = 0.75
    nearest = near_fars.min()
    close_depth = nearest * 0.9
    inf_depth = near_fars.max() * 5.0
    focal = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth))

    # Spiral radii: 90th percentile of the absolute translations per axis.
    zdelta = nearest * .2
    translations = c2ws_all[:, :3, 3]
    rads = np.percentile(np.abs(translations), 90, 0) * rads_scale

    spiral = render_path_spiral(avg_c2w, up, rads, focal, zdelta, zrate=.5, N=N_views)
    return np.stack(spiral)



class LLFFMultiscaleDataset(Dataset):
    """
    Multiscale LLFF dataset driven by ``<datadir>/metadata.json``.

    On construction every selected image is loaded and its rays, colors,
    scale values and loss multipliers are precomputed and kept in memory.

    Inputs:
        datadir: directory containing ``metadata.json`` and the image files
        split: "train" selects all images except the held-out ones; any
            other value selects the held-out images
        downsample: must be 4 (asserted; not used otherwise)
        is_stack: if False, rays/rgbs/scales are concatenated over all
            images into flat per-ray tensors; if True they stay per image
            and each rgb tensor is reshaped to (h, w, 3)
        hold_every: every ``hold_every``-th group of 4 consecutive images
            is held out
        n_scales: must be <= 5; selects the range of the scale values
    """

    def __init__(self, datadir, split="train", downsample=4, is_stack=False, hold_every=8, n_scales=4):
        super().__init__()
        assert downsample == 4, "only downsample == 4 is supported"
        assert n_scales <= 5, "n_scales must be <= 5"

        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        self.hold_every = hold_every
        self.n_scales = n_scales
        self.white_bg = False

        self._read_meta()
        # Hard-coded depth range and scene bounds.
        self.near_far = [0.0, 1.0]
        self.scene_bbox = torch.tensor([[-1.5, -1.67, -1.0], [1.5, 1.67, 1.0]])

    def _read_meta(self):
        """Load ``metadata.json`` and precompute all per-image / per-ray data."""
        with open(os.path.join(self.root_dir, "metadata.json"), "r") as f:
            self.meta = json.load(f)

        # hard-coded: one scale value per distinct label, evenly spaced
        n_labels = len(np.unique(self.meta["label"]))
        if self.n_scales <= 4:
            img_scales = torch.linspace(1, -1, n_labels)
        else:
            img_scales = torch.linspace(0.75, -0.75, n_labels)

        self.image_paths = self.meta["file_path"]
        self.poses = np.array(self.meta["cam2world"], dtype=np.float32)
        self.near_fars = np.stack([
            np.array(self.meta["near"], dtype=np.float32),
            np.array(self.meta["far"], dtype=np.float32),
        ], axis=-1)

        # build rendering path (every 4 step of poses and near_fars are the
        # same), so the spiral is computed from every 4th entry and each
        # resulting pose is repeated 4 times
        self.render_path = torch.tensor(np.repeat(
            get_spiral(self.poses[::4], self.near_fars[::4], N_views=120),
            repeats=4, axis=0
        ), dtype=torch.float32)

        # Held-out images: groups of 4 consecutive indices, one group
        # every `hold_every * 4` images.
        n_images = len(self.poses)
        i_test = np.concatenate([
            np.arange(begin, begin + 4)
            for begin in np.arange(0, n_images, self.hold_every * 4)
        ])

        if self.split != "train":
            img_list = i_test
        else:
            # setdiff1d returns a sorted array, so the image order is
            # deterministic (a set difference gives no ordering guarantee).
            img_list = np.setdiff1d(np.arange(n_images), i_test)

        # convert to tensor
        self.poses = torch.from_numpy(self.poses)
        self.near_fars = torch.from_numpy(self.near_fars)

        self.all_rays = []
        self.all_rgbs = []
        self.all_scales = []
        self.all_lossmults = []
        self.all_heights = []
        self.all_widths = []

        self.directions = []

        for i in tqdm(img_list):

            # Close the file handle as soon as the pixel data is converted.
            with Image.open(os.path.join(self.root_dir, self.image_paths[i])) as im:
                img = im.convert("RGB")
            img = pil_to_tensor(img).view(3, -1).permute(1, 0) / 255.  # (h*w, 3)
            self.all_rgbs.append(img)

            # image dimensions
            height = self.meta["height"][i]
            width = self.meta["width"][i]
            self.all_heights.append(height)
            self.all_widths.append(width)

            # focal length
            fy = self.meta["focal_y"][i]
            fx = self.meta["focal_x"][i]

            # rays from directions (cx = W/2, cy = H/2), fy == fx
            # NOTE(review): the focal tuple is passed as (fy, fx); harmless
            # only while fy == fx -- confirm the order expected by ray_utils.
            directions = get_ray_directions_blender(height, width, (fy, fx))
            self.directions.append(directions)

            ray_o, ray_d = get_rays(directions, self.poses[i])  # both (h*w, 3)
            ray_o, ray_d = ndc_rays_blender(height, width, fy, 1.0, ray_o, ray_d)
            self.all_rays.append(torch.cat([ray_o, ray_d], 1))  # (h*w, 6)

            # loss multipliers: must be a perfect square
            lossmult = self.meta["lossmult"][i]
            assert math.sqrt(lossmult) % 1 == 0
            lossmult_t = torch.tensor(lossmult, dtype=torch.float32)
            self.all_lossmults.append(torch.broadcast_to(lossmult_t, (height * width, 1)))  # (h*w, 1)

            # scales: the image's label selects its scale value
            label = self.meta["label"][i]
            self.all_scales.append(torch.broadcast_to(img_scales[label], (height * width, 1)))  # (h*w, 1)

        # Loss multipliers are always flattened, regardless of is_stack.
        self.all_lossmults = torch.cat(self.all_lossmults)
        self.all_heights = torch.tensor(self.all_heights, dtype=torch.int64)
        self.all_widths = torch.tensor(self.all_widths, dtype=torch.int64)
        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (total rays, 6)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (total rays, 3)
            self.all_scales = torch.cat(self.all_scales, 0)  # (total rays, 1)
        else:
            # rays and scales stay as per-image lists; rgbs become (h, w, 3)
            self.all_rgbs = [img.view(h, w, 3) for img, h, w in zip(self.all_rgbs, self.all_heights, self.all_widths)]

    def __len__(self):
        """
        Number of samples: rays when data is flattened (``is_stack=False``),
        images when it is kept per image (``is_stack=True``).
        """
        return len(self.all_rgbs)

    def __getitem__(self, idx):
        """
        Return the sample at ``idx`` as a dict.

        For the "train" split the dict holds "rays", "rgbs", "scales" and
        "lossmults". For any other split it holds "rays", "rgbs", "scales",
        "height" and "width"; heights and widths are stored per image, so
        that form is meant for per-image data (``is_stack=True``).
        """
        if self.split == "train":
            sample = {
                "rays": self.all_rays[idx],
                "rgbs": self.all_rgbs[idx],
                "scales": self.all_scales[idx],
                "lossmults": self.all_lossmults[idx]
            }
        else:
            sample = {
                "rays": self.all_rays[idx],
                "rgbs": self.all_rgbs[idx],
                "scales": self.all_scales[idx],
                "height": self.all_heights[idx],
                "width": self.all_widths[idx]
            }
        return sample


if __name__ == "__main__":
    # Manual debugging entry point: build the held-out split of one
    # hard-coded scene and drop into the debugger to inspect it.
    debug_kwargs = {
        "datadir": "/workspace/dataset/nerf_llff_data_multiscale/fern",
        "split": "test",
        "downsample": 4,
        "is_stack": True,
        "hold_every": 8,
        "n_scales": 4,
    }
    dataset = LLFFMultiscaleDataset(**debug_kwargs)

    import pdb
    pdb.set_trace()



