import os
from argparse import ArgumentParser
from os import makedirs

import numpy as np
import torch
import torchvision
from tqdm import tqdm

from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer import render
from scene import Scene
from scene.planar_model import GaussianModelPlanes
from utils.general_utils import safe_state
from utils.plane_utils import quaternion_to_rotation_matrix
from utils.render_utils import generate_path, create_videos


# --- Provided functions for mapping world coordinates to plane coordinates ---
def xyz_world_to_planes(xyz: torch.Tensor, R: torch.Tensor, t: torch.Tensor):
    """Express world-space points in a plane's local frame.

    Args:
        xyz: World-space point(s); the last dimension holds the coordinates.
        R: Plane-to-world rotation matrix. Its transpose is applied as the
            world-to-plane rotation.
        t: Plane position, subtracted from ``xyz`` before rotating.

    Returns:
        The rotated, re-centred points (one output row per input point).
    """
    world_to_plane = R.transpose(0, 1)
    # Treat every point as a column vector so a single matmul rotates them all.
    centered_columns = (xyz - t)[..., None]
    return (world_to_plane @ centered_columns)[..., 0]


def xyz_world_to_planes_batch(
    xyz: torch.Tensor, R_batch: torch.Tensor, t_batch: torch.Tensor
):
    """Batched world-to-plane transform, one rotation per batch entry.

    Args:
        xyz: World-space points, one per batch entry.
        R_batch: Stack of rotation matrices (3-D tensor). Each matrix is
            transposed over its last two dimensions before being applied.
        t_batch: Translations subtracted from ``xyz`` before rotating.

    Returns:
        The points after subtracting ``t_batch`` and applying the transposed
        rotations, one row per batch entry.
    """
    centered_columns = (xyz - t_batch)[..., None]
    world_to_plane = R_batch.transpose(1, 2)
    # bmm yields (B, 3, 1); drop the trailing singleton column dimension.
    return world_to_plane.bmm(centered_columns)[..., 0]


def project_to_plane_2d(points_3d, plane_origin, R):
    """
    Projects 3D points to 2D coordinates in the plane using rotation R.

    NOTE: this module defines ``project_to_plane_2d`` a second time further
    down; that later definition is the one bound to the name after import.
    This version is kept behaviourally aligned with it (transpose as the
    inverse, first two coordinates returned) so the two cannot drift apart.

    Args:
        points_3d: (N, 3) tensor of 3D points in world space.
        plane_origin: (3,) tensor of plane position.
        R: (3,3) rotation matrix defining plane basis.

    Returns:
        (N, 2) tensor of 2D coordinates.
    """
    points_local = points_3d - plane_origin

    # R is assumed to be an orthonormal basis, so its transpose is its inverse;
    # this avoids a general matrix inversion.
    R_inverse = R.transpose(0, 1)
    points_plane = torch.matmul(R_inverse, points_local.unsqueeze(-1)).squeeze(-1)

    # Keep only the two in-plane coordinates, as documented.
    return points_plane[:, :2]


def intersection_camera_to_world(intersection, view):
    """
    Map points from the camera frame back to the world frame.

    Args:
        intersection (torch.Tensor): Points in the camera frame, shape (N, 3).
        view: Object exposing ``world_view_transform``; that matrix is
            inverted and applied to the points as row vectors.

    Returns:
        torch.Tensor: The points in the world frame, shape (N, 3).
    """
    # Lift to homogeneous coordinates by appending a column of ones -> (N, 4).
    num_points = intersection.shape[0]
    ones_column = torch.ones((num_points, 1), device=intersection.device)
    points_h = torch.cat([intersection, ones_column], dim=1)

    # The stored transform maps world -> view; its inverse maps view -> world.
    world_h = points_h @ torch.inverse(view.world_view_transform)

    # De-homogenise: divide xyz by the trailing w component.
    xyz, w = world_h[:, :3], world_h[:, 3:4]
    return xyz / w


# --- Helper: convert quaternion to a 2D rotation angle ---
def quaternion_to_angle(q: torch.Tensor):
    """Extract a single rotation angle from a quaternion.

    The quaternion is read as ``[w, x, y, z]`` and assumed to describe a
    planar rotation where only ``w`` and ``y`` matter, i.e.
    ``q = [cos(angle / 2), 0, sin(angle / 2), 0]``. Adjust the component
    indices if a different convention is in use.

    Args:
        q: Quaternion tensor; elements 0 and 2 are read.

    Returns:
        The angle ``2 * atan2(y, w)`` in radians.
    """
    cos_half, sin_half = q[0].item(), q[2].item()
    return 2 * np.arctan2(sin_half, cos_half)


def create_basis_given_normal(normal: torch.Tensor):
    """Build two unit vectors from cross products with ``normal``.

    A helper axis is chosen first: the y axis when ``|normal[0]|`` exceeds
    ``|normal[1]|``, otherwise the x axis.

    Args:
        normal: 1-D tensor with three components.

    Returns:
        Tuple ``(basis1, basis2)`` of unit-length tensors, where
        ``basis1 = normalize(normal x helper)`` and
        ``basis2 = normalize(normal x basis1)``.
    """
    x_dominates_y = torch.abs(normal[0]) > torch.abs(normal[1])
    helper_components = [0, 1, 0] if x_dominates_y else [1, 0, 0]
    helper_axis = torch.tensor(
        helper_components, dtype=normal.dtype, device=normal.device
    )

    first = torch.cross(normal, helper_axis, dim=0)
    first = first / torch.norm(first)

    second = torch.cross(normal, first, dim=0)
    second = second / torch.norm(second)

    return first, second


def get_plane_camera(view, plane):
    """Return the plane equation ``(A, B, C, D)`` expressed in the camera frame.

    The plane's position (relative to the camera centre) and its normal are
    both multiplied, as row vectors, by the upper-left 3x3 block of
    ``view.world_view_transform``. ``D`` is chosen so that points ``p`` on the
    plane satisfy ``A*p.x + B*p.y + C*p.z + D = 0``.

    Args:
        view: Object exposing ``camera_center`` and ``world_view_transform``.
        plane: Object exposing ``get_translation`` and ``get_normal``
            (1-D, three components each).

    Returns:
        1-D tensor ``[A, B, C, D]`` on the same device as the inputs.
    """
    rotation = view.world_view_transform[:3, :3]

    translation_view = (plane.get_translation - view.camera_center) @ rotation
    normal_view = plane.get_normal @ rotation

    D = -torch.dot(translation_view, normal_view)

    # Assemble on the inputs' own device instead of rebuilding the tensor
    # element by element on a hard-coded "cuda" device. detach() mirrors the
    # original torch.tensor(...) construction, whose result carried no
    # autograd history.
    return torch.cat([normal_view[:3], D.reshape(1)]).detach()


def get_rays(view, device="cuda"):
    """Build one unit-length ray direction per pixel of ``view``.

    Pixel offsets are sampled with ``linspace`` over
    ``[-size / 2, size / 2]`` along each image axis, divided by the focal
    lengths ``view.Fx`` / ``view.Fy``, given a z component of 1 and then
    normalised.

    NOTE(review): the samples include both end points of the interval rather
    than pixel centres; confirm this matches the rasteriser's convention.

    Args:
        view: Object exposing ``image_height``, ``image_width``, ``Fx``, ``Fy``.
        device: Device on which the rays are created. Defaults to ``"cuda"``,
            which was previously hard-coded.

    Returns:
        Tensor of shape (image_height, image_width, 3) holding unit vectors.
    """
    height, width = view.image_height, view.image_width

    y = torch.linspace(-height / 2, height / 2, height, device=device)
    x = torch.linspace(-width / 2, width / 2, width, device=device)

    # "ij" is the layout the implicit default produced (rows follow y, columns
    # follow x); passing it explicitly avoids the meshgrid deprecation warning.
    yy, xx = torch.meshgrid(y, x, indexing="ij")

    rays = torch.stack([xx / view.Fx, yy / view.Fy, torch.ones_like(xx)], dim=-1)
    return rays / torch.norm(rays, dim=-1, keepdim=True)


def get_ray_plane_intersection(parameters, rays):
    """Intersect rays through the origin with the plane ``Ax + By + Cz + D = 0``.

    Args:
        parameters: Indexable holding ``A, B, C, D`` at positions 0..3.
        rays: (H, W, 3) tensor of ray directions.

    Returns:
        (H, W, 3) tensor ``rays * t`` where ``t = -D / (n . ray)``. Rays
        parallel to the plane divide by zero and yield non-finite values.
    """
    a, b, c, d = parameters[0], parameters[1], parameters[2], parameters[3]

    # Dot product of the plane normal with every ray direction.
    normal_dot_ray = a * rays[:, :, 0] + b * rays[:, :, 1] + c * rays[:, :, 2]
    distance = -d / normal_dot_ray

    return rays * distance[..., None]


def project_to_plane_2d(points_3d, origin, R):
    """
    Express 3D points in a plane's frame and keep the first two coordinates.

    Args:
        points_3d: (N, 3) tensor of 3D points in world space.
        origin: (3,) tensor of plane position.
        R: (3,3) rotation matrix defining plane basis; assumed orthonormal,
            so its transpose is used as the inverse.

    Returns:
        (N, 2) tensor of 2D coordinates.
    """
    world_to_plane = R.transpose(0, 1)

    # Column-vector form so one matmul rotates every re-centred point.
    offsets = (points_3d - origin).unsqueeze(-1)
    in_plane_frame = torch.matmul(world_to_plane, offsets).squeeze(-1)

    # Drop the third (out-of-plane) coordinate.
    return in_plane_frame[:, :2]


def checkerboard_colors(xy_coord, scale):
    """Colour a grid of 2D coordinates with a two-tone checkerboard.

    A cell is dark (0.4) where ``sin(omega * x) * sin(omega * y) < 0`` and
    light (0.8) elsewhere, with ``omega = 2 * pi / scale``. Entries whose x
    coordinate equals ``+inf`` are painted black.

    Args:
        xy_coord: (H, W, 2) tensor of in-plane coordinates.
        scale: Spatial period of the sine pattern, in the units of ``xy_coord``.

    Returns:
        (3, H, W) float32 tensor of RGB values on ``xy_coord``'s device.
    """
    # Follow the input's device instead of assuming CUDA.
    device = xy_coord.device

    def _const(values):
        return torch.tensor(values, device=device, dtype=torch.float32)

    omega = 2 * np.pi / scale
    x = xy_coord[:, :, 0]
    y = xy_coord[:, :, 1]

    # Replace +inf before the sine so those entries do not turn into NaN.
    zero = _const(0.0)
    x_safe = torch.where(x == torch.inf, zero, x)
    y_safe = torch.where(y == torch.inf, zero, y)

    sin_x = torch.sin(omega * x_safe)
    sin_y = torch.sin(omega * y_safe)

    # The (H, W, 1) masks broadcast against the (3,) colour constants.
    is_dark = ((sin_x * sin_y) < 0).unsqueeze(-1)
    colors = torch.where(is_dark, _const([0.4, 0.4, 0.4]), _const([0.8, 0.8, 0.8]))

    x_is_inf = (x == torch.inf).unsqueeze(-1)
    colors = torch.where(x_is_inf, _const([0.0, 0.0, 0.0]), colors)

    return colors.permute(2, 0, 1)


def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
    """Render each view with a checkerboard texture blended over every plane.

    For every view the scene is rendered once. Then, for each plane that has
    at least one visible Gaussian, the per-pixel rays are intersected with
    the plane, the hits are converted to in-plane 2D coordinates, and a
    checkerboard evaluated at those coordinates is blended into the image
    using a second render in which the plane's Gaussians are white and all
    others black. Results are saved as
    ``<model_path>/<name>/ours_<iteration>/renders_planes/<idx:05d>.png``.

    Args:
        model_path: Root output directory of the model.
        name: Sub-directory for this camera set.
        iteration: Iteration number used in the ``ours_<iteration>`` folder.
        views: Iterable of cameras to render.
        gaussians: Model exposing ``planes`` and ``plane_ids``.
        pipeline: Pipeline parameters forwarded to ``render``.
        background: Background colour tensor forwarded to ``render``.
    """
    output_root = os.path.join(model_path, name, f"ours_{iteration}")
    render_path = os.path.join(output_root, "renders_planes")
    masks_path = os.path.join(output_root, "renders_mask")
    makedirs(render_path, exist_ok=True)
    # Nothing below writes masks; the directory is still created so the
    # on-disk layout stays the same as before.
    makedirs(masks_path, exist_ok=True)

    # Override colours for the mask pass: plane Gaussians white, rest black.
    plane_color = torch.tensor([1.0, 1.0, 1.0], device="cuda", dtype=torch.float32)
    other_color = torch.tensor([0.0, 0.0, 0.0], device="cuda", dtype=torch.float32)

    for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
        render_pkg = render(view, gaussians, pipeline, background)
        img = render_pkg["render"]
        vis_filter = render_pkg["visibility_filter"].squeeze(-1)

        # Rays depend only on the view, so build them once per view rather
        # than once per plane.
        rays = get_rays(view)

        for idx_plane, plane in enumerate(gaussians.planes):
            condition = gaussians.plane_ids == idx_plane

            # Skip planes with no visible Gaussian in this view.
            if (vis_filter & condition).sum() == 0:
                continue

            R = quaternion_to_rotation_matrix(plane.get_rotation)
            plane_origin = plane.get_translation

            # Per-pixel ray/plane hit in the camera frame, then world frame,
            # then 2D coordinates inside the plane.
            parameters = get_plane_camera(view, plane)
            intersection = get_ray_plane_intersection(parameters, rays)
            H, W = intersection.shape[0], intersection.shape[1]
            intersection_world = intersection_camera_to_world(
                intersection.reshape(-1, 3), view
            )
            intersection_plane_2d = project_to_plane_2d(
                intersection_world, plane_origin, R
            )

            checkerboard_texture = checkerboard_colors(
                xy_coord=intersection_plane_2d.reshape(H, W, 2),
                scale=0.3,
            )

            override_color = torch.where(
                condition.unsqueeze(1), plane_color, other_color
            ).unsqueeze(0)

            # NOTE(review): the mask pass reuses `background`; if the renderer
            # composites onto it, a non-black background would leak into the
            # mask -- confirm against the renderer.
            mask = render(
                view, gaussians, pipeline, background, override_color=override_color
            )["render"]

            img = mask * checkerboard_texture + (1 - mask) * img

        torchvision.utils.save_image(
            img, os.path.join(render_path, f"{idx:05d}.png")
        )


def render_sets(
    dataset: ModelParams,
    iteration: int,
    pipeline: PipelineParams,
    skip_train: bool,
    skip_test: bool,
    skip_video: bool,
):
    """Load a trained scene and render the requested camera sets.

    Runs under ``torch.no_grad()``. Train and test cameras are rendered with
    ``render_set`` unless skipped. Unless ``skip_video`` is set, a 240-frame
    trajectory interpolated from the training cameras is also rendered and
    passed to ``create_videos``.

    Args:
        dataset: Model parameters; ``sh_degree``, ``white_background`` and
            ``model_path`` are read here.
        iteration: Iteration to load, forwarded to ``Scene`` as
            ``load_iteration``.
        pipeline: Pipeline parameters forwarded to ``render_set``.
        skip_train: Skip rendering the training cameras.
        skip_test: Skip rendering the test cameras.
        skip_video: Skip rendering the trajectory and creating videos.
    """
    with torch.no_grad():
        gaussians = GaussianModelPlanes(dataset.sh_degree)
        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)

        bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

        if not skip_train:
            render_set(
                dataset.model_path,
                "train",
                scene.loaded_iter,
                scene.getTrainCameras(),
                gaussians,
                pipeline,
                background,
            )

        if not skip_test:
            render_set(
                dataset.model_path,
                "test",
                scene.loaded_iter,
                scene.getTestCameras(),
                gaussians,
                pipeline,
                background,
            )

        if not skip_video:
            print("render videos ...")
            # Use the dataset passed in rather than the script-level `args`
            # global, so this also works when called from another module and
            # agrees with the root used by render_set below.
            traj_dir = os.path.join(
                dataset.model_path, "traj", "ours_{}".format(scene.loaded_iter)
            )
            os.makedirs(traj_dir, exist_ok=True)
            n_frames = 240
            cam_traj = generate_path(
                scene.getTrainCameras(), n_frames=n_frames, path_type="interpolated"
            )
            render_set(
                dataset.model_path,
                "traj",
                scene.loaded_iter,
                cam_traj,
                gaussians,
                pipeline,
                background,
            )

            # NOTE(review): render_set writes its frames under
            # <traj_dir>/renders_planes while create_videos is pointed at
            # traj_dir -- confirm create_videos looks in that sub-folder.
            create_videos(
                base_dir=traj_dir,
                input_dir=traj_dir,
                out_name="render_traj",
                num_frames=n_frames,
            )


if __name__ == "__main__":
    # Command line interface: model/pipeline parameter groups plus the
    # script's own switches.
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    for switch in ("--skip_train", "--skip_test", "--skip_video", "--quiet"):
        parser.add_argument(switch, action="store_true")
    args = get_combined_args(parser)
    print("Rendering " + args.model_path)

    # Initialize system state (RNG)
    safe_state(args.quiet)

    render_sets(
        dataset=model.extract(args),
        iteration=args.iteration,
        pipeline=pipeline.extract(args),
        skip_train=args.skip_train,
        skip_test=args.skip_test,
        skip_video=args.skip_video,
    )
