"""
Check the volumes to make sure everything aligns.

"""

import h5py
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import torch

from src.config_utils import Config
from src.shape_carving import create_3d_grid, shift_and_rotate_grid_points
from src.utils import get_cam_params
from src.plots import plot_imgs_and_volume


USAGE = "Usage:\n$ python check_volumes.py <config.json>"


if __name__ == '__main__':
    # Validate the command line explicitly: `assert` is stripped under
    # `python -O`, which would turn a missing argument into an IndexError.
    if len(sys.argv) != 2:
        sys.exit(USAGE)
    config = Config(sys.argv[1])

    img_fn = os.path.join(config.image_directory, "images.h5")
    volume_fn = os.path.join(config.volume_directory, "volumes.h5")

    # Open the HDF5 files with context managers so the handles are released
    # even if something below raises.
    with h5py.File(img_fn, "r") as img_hdf5_file:
        images = img_hdf5_file["images"]

        total_images = len(images)
        print("total_images:", total_images)
        if total_images == 0:
            raise ValueError(f"no images found in {img_fn}")

        # Get center, rotation.
        d = np.load(config.center_rotation_fn)
        centers, angles = d["centers"], d["angles"]
        if len(centers) != total_images:
            raise ValueError(
                f"expected {total_images} centers in "
                f"{config.center_rotation_fn}, found {len(centers)}"
            )

        grid = create_3d_grid(
            config.ell, config.grid_size, volume_idx=config.volume_idx
        )
        ds = config.image_downsample
        W = config.image_width // ds
        H = config.image_height // ds

        # Get the camera parameters.
        intrinsic, extrinsic, Ps = get_cam_params(
            config.camera_fn,
            ds=ds,
            up_fn=config.vertical_lines_fn,
            auto_orient=True,
            load_up_direction=True,
        )
        C = len(Ps)

        with h5py.File(volume_fn, "r") as volume_hdf5_file:
            volumes = volume_hdf5_file["volumes"]

            # Choose a random frame.
            idx = np.random.randint(total_images)
            print("random index:", idx)

            # Indexing the dataset reads the frame into memory, so the data
            # stays valid after the file is closed.
            volume = torch.tensor(volumes[idx]).to(torch.float32) / 255.0
            print("volume:", volume.shape)

        # Get the images for the same frame (also read into memory here).
        imgs = images[idx].astype(np.float32) / 255.0
        print("imgs:", imgs.shape)

    # Shift and rotate the grid using the chosen frame's center and angle.
    grid = shift_and_rotate_grid_points(grid, centers[idx], angles[idx])

    # Plot up to three randomly chosen views; both files are already closed.
    perm = np.random.permutation(C)[:3]
    print("perm:", perm)
    plot_imgs_and_volume(
        imgs[perm], volume, grid, intrinsic[perm], extrinsic[perm], W, H
    )






###