import sys
import os
from scene import Scene, MixGaussianModel
import torch
import numpy as np
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig

from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, OptimizationParams, ModelHiddenParams
from tqdm import tqdm
from utils.image_utils import save_video
from utils.general_utils import safe_state
from gaussian_renderer import differentiable_render_with_hexplane
from loguru import logger

def render_sets(dataset, pipe, load_iteration, skip_train, skip_test,
                gaussian_dim, time_duration, num_pts, num_pts_ratio, rot_4d, force_sh_3d,
                target_cam, target_ts_train, target_ts_test, render_flow, render_depth, hyper=None, opt=None):
    """Restore a trained model from a checkpoint and render the test cameras to a video.

    The scene and the Gaussian model are built from ``dataset``, the checkpoint
    found under ``dataset.model_path`` is loaded into the model, and the test
    cameras are handed to ``differentiable_render_set``. Everything runs under
    ``torch.no_grad()``.

    Args:
        dataset: Model parameters; this function reads ``frame_ratio``, ``Ntime``,
            ``sh_degree``, ``model_path`` and ``white_background``.
        pipe: Pipeline parameters; ``eval_shfs_4d`` selects ``sh_degree_t`` (2 or 0).
        load_iteration: ``-1`` loads ``chkpnt_best.pth``; any other value loads
            ``chkpnt{load_iteration}.pth``.
        gaussian_dim, rot_4d, force_sh_3d: Forwarded to ``MixGaussianModel``.
        time_duration: Two-element sequence; both bounds are divided by
            ``dataset.frame_ratio`` before use.
        num_pts, num_pts_ratio, target_cam: Forwarded to ``Scene``.
        target_ts_train, target_ts_test: Timestamp indices; each is multiplied
            by the minimum timestep before being forwarded to ``Scene``.
        hyper: Forwarded to ``MixGaussianModel`` as ``distortion_args``.
        skip_train, skip_test, render_flow, render_depth, opt: Accepted but not
            read by this function.
    """
    with torch.no_grad():
        frame_ratio = dataset.frame_ratio
        start, end = time_duration[0], time_duration[1]
        time_duration = [start / frame_ratio, end / frame_ratio]

        # Smallest time step: one frame, rescaled by the frame ratio.
        min_timestep = 1.0 / dataset.Ntime / dataset.frame_ratio
        train_times = [ts * min_timestep for ts in target_ts_train]
        test_times = [ts * min_timestep for ts in target_ts_test]

        scene = Scene(
            dataset,
            num_pts=num_pts,
            num_pts_ratio=num_pts_ratio,
            time_duration=time_duration,
            target_cam=target_cam,
            target_time_train=train_times,
            target_time_test=test_times,
            shuffle=False,
            cached_dataset=False,
        )

        sh_degree_t = 2 if pipe.eval_shfs_4d else 0
        gaussians = MixGaussianModel(
            dataset.sh_degree,
            gaussian_dim=gaussian_dim,
            time_duration=time_duration,
            rot_4d=rot_4d,
            force_sh_3d=force_sh_3d,
            sh_degree_t=sh_degree_t,
            distortion_args=hyper,
            frame_ratio=dataset.frame_ratio,
            Ntime=dataset.Ntime,
        )

        # The "best" checkpoint has an underscore in its file name; numbered ones do not.
        checkpoint_name = "chkpnt_best.pth" if load_iteration == -1 else f"chkpnt{load_iteration}.pth"
        checkpoint = os.path.join(dataset.model_path, checkpoint_name)

        model_params, first_iter = torch.load(checkpoint)
        gaussians.restore(model_params, None, scene.getTrainDifferentiableCameras())
        logger.info(f"Loaded trained model from {dataset.model_path} at iteration {first_iter}")

        if dataset.white_background:
            bg_color = [1, 1, 1]
        else:
            bg_color = [0, 0, 0]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

        differentiable_render_set(
            dataset.model_path,
            "test",
            first_iter,
            scene.getTestCameras(),
            scene.getTestDifferentiableCameras(),
            gaussians,
            pipe,
            dataset,
            background,
        )

def differentiable_render_set(model_path, name, iteration, views, differentiable_cams, gaussians, pipeline, dataset, background, fps=15):
    """Render every view in ``views`` and save the frames as ``spiral.mp4``.

    The video is written to ``<model_path>/<name>/ours_<iteration>/videos``;
    the directory is created if it does not exist. When ``views`` is empty no
    video is written and the function returns after creating the directory.

    Args:
        model_path: Root output directory of the trained model.
        name: Sub-directory name for this render set (e.g. ``"test"``).
        iteration: Iteration label embedded in the output directory name.
        views: Sized, index-able collection; each item is unpacked as a
            3-tuple whose middle element is the viewpoint (must provide
            ``.cuda()`` and ``.camera_id``).
        differentiable_cams: Object providing ``get_camera(camera_id)``.
        gaussians: Gaussian model passed to the renderer.
        pipeline: Pipeline parameters passed to the renderer.
        dataset: Unused; kept for interface compatibility.
        background: Background colour tensor passed to the renderer.
        fps: Frame rate of the saved video (default 15).
    """
    video_path = os.path.join(model_path, name, f"ours_{iteration}", "videos")
    os.makedirs(video_path, exist_ok=True)
    num_views = len(views)
    print(f"render {num_views} images to {video_path}")

    if num_views == 0:
        # torch.cat raises on an empty list, so there is nothing to encode.
        print(f"no views to render, skipping video export to {video_path}")
        return

    renders = []
    # Fetch views one at a time instead of materialising them all up front.
    for idx in tqdm(range(num_views), desc="Rendering progress"):
        _, viewpoint, _ = views[idx]
        viewpoint = viewpoint.cuda()
        differentiable_cam = differentiable_cams.get_camera(viewpoint.camera_id)

        render_pkg = differentiable_render_with_hexplane(
            viewpoint, differentiable_cam, gaussians, pipeline, background, apply_distortion=False
        )
        image = torch.clamp(render_pkg["render"], 0.0, 1.0)
        # Move each frame to the CPU right away so frames do not accumulate on the GPU.
        renders.append(image[None].detach().cpu())

    # Stack the frames and move axis 1 last
    # (presumably (N, C, H, W) -> (N, H, W, C); confirm against save_video).
    renders = torch.cat(renders).permute(0, 2, 3, 1)
    save_video(renders, os.path.join(video_path, "spiral.mp4"), target_fps=fps)


if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Rendering script parameters")
    lp = ModelParams(parser)
    op = OptimizationParams(parser)
    pp = PipelineParams(parser)
    hp = ModelHiddenParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--skip_train", action="store_true")
    parser.add_argument("--skip_test", action="store_true")
    parser.add_argument("--render_flow", action="store_true")
    parser.add_argument("--render_depth", action="store_true")
    parser.add_argument("--quiet", action="store_true")
    # The config file is loaded unconditionally below, so fail early with a
    # clear usage error when it is missing.
    parser.add_argument("--config", type=str, required=True)

    parser.add_argument("--gaussian_dim", type=int, default=3)
    parser.add_argument("--time_duration", nargs=2, type=float, default=[-0.5, 0.5])
    parser.add_argument('--num_pts', type=int, default=100_000)
    parser.add_argument('--num_pts_ratio', type=float, default=1.0)
    parser.add_argument("--rot_4d", action="store_true")
    parser.add_argument("--force_sh_3d", action="store_true")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--seed", type=int, default=6666)
    parser.add_argument("--exhaust_test", action="store_true")
    parser.add_argument("--target_ts_train", nargs="+", type=int, default=[-1], help="target timestamp for training views, -1 means loading all timestamp")
    parser.add_argument("--target_ts_test", nargs="+", type=int, default=[-1], help="target timestamp for testing views, -1 means loading all timestamp")
    parser.add_argument("--target_cam", nargs="+", type=str, default=[])
    parser.add_argument("--extra_model_path", type=str, default="")

    args = parser.parse_args(sys.argv[1:])
    cfg = OmegaConf.load(args.config)

    def recursive_merge(key, host):
        """Copy ``host[key]`` onto ``args``, flattening nested config sections.

        ``kplanes_config`` is copied as a whole; any other nested section is
        descended into, and each leaf overrides the same-named attribute on ``args``.
        """
        if key == "kplanes_config":
            setattr(args, key, host[key])
        elif isinstance(host[key], DictConfig):
            for key1 in host[key].keys():
                recursive_merge(key1, host[key])
        else:
            # Explicit raise instead of assert: asserts are stripped under
            # ``python -O`` and unknown keys would then be set silently.
            if not hasattr(args, key):
                raise ValueError(f"Unknown config key: {key}")
            setattr(args, key, host[key])

    for k in cfg.keys():
        recursive_merge(k, cfg)
    # Initialize system state (RNG)
    safe_state(args.quiet)

    # An existing --extra_model_path overrides the model path from the config.
    if os.path.exists(args.extra_model_path):
        args.model_path = args.extra_model_path

    # Render along the spiral camera path stored next to the original transforms file.
    args.transforms_file = os.path.join(os.path.dirname(args.transforms_file), "spiral_render_path.json")
    render_sets(lp.extract(args), pp.extract(args), args.iteration, args.skip_train, args.skip_test,
                args.gaussian_dim, args.time_duration, args.num_pts, args.num_pts_ratio, args.rot_4d, args.force_sh_3d,
                args.target_cam, args.target_ts_train, args.target_ts_test, args.render_flow, args.render_depth,
                hyper=hp.extract(args), opt=op.extract(args))
