import os
import torch
import numpy as np
import imageio
import json
import skimage

def trans_t(t):
    """
    Build a homogeneous 4x4 translation by ``t`` along the z axis.

    :param t: translation distance along z
    :return: float32 torch tensor of shape (4, 4)
    """
    return torch.Tensor([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1]]).float()

def rot_phi(phi):
    """
    Build a homogeneous 4x4 rotation by ``phi`` radians about the x axis.

    :param phi: rotation angle in radians
    :return: float32 torch tensor of shape (4, 4)
    """
    return torch.Tensor([
        [1, 0, 0, 0],
        [0, np.cos(phi), -np.sin(phi), 0],
        [0, np.sin(phi), np.cos(phi), 0],
        [0, 0, 0, 1]]).float()

def rot_theta(th):
    """
    Build a homogeneous 4x4 rotation by ``th`` radians about the y axis.

    Note the sign convention: the sine terms are placed so that this matrix is the
    transpose of the standard right-handed R_y(th), i.e. it equals R_y(-th).

    :param th: rotation angle in radians
    :return: float32 torch tensor of shape (4, 4)
    """
    return torch.Tensor([
        [np.cos(th), 0, -np.sin(th), 0],
        [0, 1, 0, 0],
        [np.sin(th), 0, np.cos(th), 0],
        [0, 0, 0, 1]]).float()

def rot_z(z):
    """
    Build a homogeneous 4x4 rotation by ``z`` radians about the z axis.

    :param z: rotation angle in radians
    :return: float32 torch tensor of shape (4, 4)
    """
    return torch.Tensor([
        [np.cos(z), -np.sin(z), 0, 0],
        [np.sin(z), np.cos(z), 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]]).float()


def pose_spherical(theta, phi, radius):
    """
    Build a 4x4 camera-to-world matrix for a camera at distance ``radius`` from the origin.

    The transform is composed right-to-left:
    1. translate by ``radius`` along +z (``trans_t``),
    2. rotate by ``phi`` degrees about the x axis (``rot_phi``),
    3. rotate by ``theta`` degrees about the y axis (``rot_theta``),
    4. apply a fixed change of basis that negates x and swaps y with z.

    :param theta: rotation about the y axis, in degrees
    :param phi: rotation about the x axis, in degrees
    :param radius: distance of the camera from the origin
    :return: torch float tensor of shape (4, 4)
    """
    # Fixed basis change applied last: x -> -x, y <-> z.
    axis_swap = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]))

    phi_rad = phi / 180. * np.pi
    theta_rad = theta / 180. * np.pi

    orbit = rot_theta(theta_rad) @ (rot_phi(phi_rad) @ trans_t(radius))
    return axis_swap @ orbit


def _load_split(basedir, meta, skip, rot_z_mat):
    """Read every ``skip``-th frame of one split; return ``(imgs, poses)`` as float32 arrays."""
    imgs = []
    poses = []
    for frame in meta['frames'][::skip]:
        fname = os.path.join(basedir, frame['file_path'] + '.png')
        imgs.append(imageio.imread(fname))
        pose = np.array(frame['transform_matrix'])
        if rot_z_mat is not None:
            pose = rot_z_mat @ pose
        poses.append(pose)
    # Divide by 255 (assumes 8-bit PNGs) and keep every channel (RGBA for this dataset).
    imgs = (np.array(imgs) / 255.).astype(np.float32)
    poses = np.array(poses).astype(np.float32)
    return imgs, poses


def _downscale_images(imgs, factor, downscale_filter):
    """Shrink an image stack of shape (N, H, W, C) by integer ``factor`` along H and W, keeping dtype."""
    n, h, w, c = imgs.shape
    h_small, w_small = h // factor, w // factor
    if downscale_filter == 'area':
        # Box filter: average each factor x factor tile, vectorized over the whole stack.
        return imgs.reshape(n, h_small, factor, w_small, factor, c).mean(axis=(2, 4))
    if downscale_filter == 'antialias':
        # Explicit submodule import: older scikit-image releases do not expose
        # `skimage.transform` after a bare `import skimage`.
        import skimage.transform
        out = np.empty((n, h_small, w_small, c), dtype=imgs.dtype)
        for i, img in enumerate(imgs):
            out[i] = skimage.transform.resize(img, (h_small, w_small), anti_aliasing=True)
        return out
    raise ValueError(f'Unknown interpolation filter: {downscale_filter}')


def load_blender_data(
        basedir, image_downscale_factor=1, image_downscale_filter='area',
        testskip=1, scene_scale=2 / 3, scene_rot_z_deg=0
):
    """
    Load a Blender (NeRF paper synthetic) dataset.

    Reads ``transforms_{train,val,test}.json`` from ``basedir`` and the ``<file_path>.png``
    images they reference, then concatenates all three splits.

    :param basedir: dataset directory containing the three transforms JSON files
    :param image_downscale_factor: 1 keeps full resolution; otherwise an int power of two
        (>= 2) that divides both image height and width. The focal length is divided by it too.
    :param image_downscale_filter: 'area' (average over factor x factor pixel tiles) or
        'antialias' (``skimage.transform.resize`` with ``anti_aliasing=True``); only used
        when ``image_downscale_factor != 1``
    :param testskip: keep every ``testskip``-th frame of the val/test splits (0 or 1 keeps
        all of them); the train split is never subsampled
    :param scene_scale: multiplier applied to the camera translations of both the dataset
        poses and the render poses. 2/3 is taken from svox2 (Plenoxels) source code for
        this dataset
    :param scene_rot_z_deg: if non-zero, every dataset pose is left-multiplied by a rotation
        of this many degrees about the z axis (render poses are not rotated)
    :return: ``(imgs, poses, render_poses, [H, W, focal], i_split)`` where ``imgs`` is a
        float32 array of all frames (pixel values / 255, all channels kept; float32 also
        after downscaling), ``poses`` is a float32 array of per-frame transform matrices,
        ``render_poses`` is a torch tensor stacking 40 ``pose_spherical(angle, -30, 4)``
        poses (translations then scaled by ``scene_scale``), ``[H, W, focal]`` are the
        (possibly downscaled) intrinsics, and ``i_split`` holds the frame index arrays for
        train, val and test in that order
    :raises ValueError: on an invalid downscale factor or an unknown downscale filter
    """
    splits = ['train', 'val', 'test']
    metas = {}
    for s in splits:
        with open(os.path.join(basedir, f'transforms_{s}.json'), 'r') as fp:
            metas[s] = json.load(fp)

    # The same z-rotation applies to every frame, so build it once.
    rot_z_mat = rot_z(scene_rot_z_deg / 180. * np.pi).numpy() if scene_rot_z_deg != 0 else None

    all_imgs = []
    all_poses = []
    counts = [0]
    for s in splits:
        skip = 1 if s == 'train' or testskip == 0 else testskip
        imgs, poses = _load_split(basedir, metas[s], skip, rot_z_mat)
        counts.append(counts[-1] + imgs.shape[0])
        all_imgs.append(imgs)
        all_poses.append(poses)

    i_split = [np.arange(counts[i], counts[i + 1]) for i in range(len(splits))]

    imgs = np.concatenate(all_imgs, 0)
    poses = np.concatenate(all_poses, 0)

    H, W = imgs[0].shape[:2]
    # Intrinsics are read from the last split loaded (test), as the original loader did.
    camera_angle_x = float(metas[splits[-1]]['camera_angle_x'])
    focal = .5 * W / np.tan(.5 * camera_angle_x)

    render_poses = torch.stack(
        [pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, 40 + 1)[:-1]], 0)

    poses[:, :3, 3] *= scene_scale
    render_poses[:, :3, 3] *= scene_scale

    if image_downscale_factor != 1:
        if type(image_downscale_factor) is not int or image_downscale_factor < 2 \
                or image_downscale_factor & (image_downscale_factor - 1) != 0:
            raise ValueError(f'Invalid {image_downscale_factor=}')
        if H % image_downscale_factor != 0 or W % image_downscale_factor != 0:
            raise ValueError(f'Invalid {image_downscale_factor=} with {W=}, {H=}')

        imgs = _downscale_images(imgs, image_downscale_factor, image_downscale_filter)
        H = H // image_downscale_factor
        W = W // image_downscale_factor
        focal = focal / image_downscale_factor

    return imgs, poses, render_poses, [H, W, focal], i_split
