import os
import shutil
import torch
import torchvision
import numpy as np
from tqdm import tqdm
from argparse import ArgumentParser
from gaussian_splatting.scene import Scene
from gaussian_splatting.gaussian_renderer import render, GaussianModel
from gaussian_splatting.utils.general_utils import safe_state
from gaussian_splatting.arguments import ModelParams, PipelineParams, get_combined_args
from scipy.spatial.transform import Rotation
from simulator.api import simulate_events_from_info_txt

# Optional CUDA extension: record whether SparseGaussianAdam can be imported.
# Catch only import failures (missing package / missing symbol) instead of a bare
# `except:`, which would also swallow KeyboardInterrupt and SystemExit.
try:
    from diff_gaussian_rasterization import SparseGaussianAdam
    SPARSE_ADAM_AVAILABLE = True
except ImportError:
    SPARSE_ADAM_AVAILABLE = False


def _reset_dir(path):
    """Delete *path* if it exists, then recreate it as an empty directory."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


def _camera_to_world_pose(view):
    """Return ``(qw, qx, qy, qz, t_wc)`` for the inverted view transform of *view*.

    The transposed ``view.world_view_transform`` is treated as the
    world-to-camera matrix and inverted to get camera-to-world. Its rotation
    block is converted with SciPy, whose ``as_quat`` returns scalar-last
    ``[x, y, z, w]`` order. The result is unpacked here into explicit components.
    """
    T_cw = view.world_view_transform.transpose(0, 1).cpu().numpy()
    T_wc = np.linalg.inv(T_cw)
    qx, qy, qz, qw = Rotation.from_matrix(T_wc[:3, :3]).as_quat()
    return qw, qx, qy, qz, T_wc[:3, 3]


def render_set(model_path, name, iteration, views, gaussians, pipeline, background,
               train_test_exp, separate_sh, interpolate=False, fps=None):
    """Render ``views`` and export frames plus per-frame camera poses.

    Everything is written under ``<model_path>/<name>/ours_<iteration>/``:

    * ``renders/``: rendered frames ``00000.png``, ``00001.png``, ...
    * ``gt/``: ground-truth frames, written only when ``interpolate`` is False.
    * ``colmap/images.txt`` and ``colmap/cameras.txt``: one PINHOLE camera
      (ID 1) that every image line references.
    * ``esim/pose_esim.txt``: one line per frame,
      ``t_ns x y z qx qy qz qw``.
    * ``dvs_input/info.txt``: one line per frame, ``<abs png path> <t_us>``.
    * ``dvs_output/``: created but left untouched (event simulation is disabled).

    ``renders``, ``gt``, ``colmap`` and ``esim`` are deleted and recreated on
    every call.

    Args:
        model_path, name, iteration: select the output directory.
        views: sequence of cameras to render.
        gaussians, pipeline, background, separate_sh: forwarded to ``render``.
        train_test_exp: forwarded to ``render``. When set, images are also
            cropped to the second half of their last dimension.
        interpolate: True for synthetic (interpolated) cameras. No ground truth
            is read or saved for them.
        fps: frame rate used to space timestamps. Defaults to the module-global
            ``args.basic_speed`` parsed by the ``__main__`` block.

    Raises:
        ValueError: if ``fps`` is not a positive number.
    """
    if fps is None:
        fps = args.basic_speed  # global set in __main__
    # float() also accepts a numeric string coming straight from argparse.
    fps = float(fps)
    if not fps > 0:  # also rejects NaN
        raise ValueError(f"fps (--basic_speed) must be positive, got {fps}")

    base_dir = os.path.join(model_path, name, f"ours_{iteration}")
    render_path = os.path.join(base_dir, "renders")
    gts_path = os.path.join(base_dir, "gt")
    colmap_dir = os.path.join(base_dir, "colmap")
    esim_dir = os.path.join(base_dir, "esim")
    for path in (render_path, gts_path, colmap_dir, esim_dir):
        _reset_dir(path)

    base_time_ns = 0
    frame_interval_ns = int(1e9 / fps)

    images_txt = []
    esim_pose_lines = []
    info_txt_lines = []
    last_view = None

    for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
        last_view = view
        rendering = render(view, gaussians, pipeline, background,
                           use_trained_exp=train_test_exp, separate_sh=separate_sh)["render"]
        # Interpolated cameras are not read for ground truth (presumably they
        # carry no captured image).
        gt = None if interpolate else view.original_image[0:3, :, :]
        if train_test_exp:
            # Keep the second half of the last dimension (presumably image width).
            rendering = rendering[..., rendering.shape[-1] // 2:]
            if gt is not None:
                gt = gt[..., gt.shape[-1] // 2:]

        img_name = f"{idx:05d}.png"
        torchvision.utils.save_image(rendering, os.path.join(render_path, img_name))
        if gt is not None:
            torchvision.utils.save_image(gt, os.path.join(gts_path, img_name))

        qw, qx, qy, qz, t_wc = _camera_to_world_pose(view)
        timestamp_ns = base_time_ns + idx * frame_interval_ns

        # COLMAP images.txt (IDs are 1-based).
        images_txt.append(f"{idx+1} {qw} {qx} {qy} {qz} {t_wc[0]} {t_wc[1]} {t_wc[2]} 1 {img_name}")
        # ESIM pose_esim.txt
        esim_pose_lines.append(f"{timestamp_ns} {t_wc[0]:.6f} {t_wc[1]:.6f} {t_wc[2]:.6f} {qx:.6f} {qy:.6f} {qz:.6f} {qw:.6f}")
        # DVS-Voltmeter info.txt uses microsecond timestamps.
        img_path = os.path.abspath(os.path.join(render_path, img_name))
        info_txt_lines.append(f"{img_path} {timestamp_ns // 1000}")

    # === write info.txt to DVS-Voltmeter ===
    dvs_input_dir = os.path.join(base_dir, "dvs_input")
    dvs_output_dir = os.path.join(base_dir, "dvs_output")
    os.makedirs(dvs_input_dir, exist_ok=True)
    os.makedirs(dvs_output_dir, exist_ok=True)
    with open(os.path.join(dvs_input_dir, "info.txt"), "w") as f:
        f.write("\n".join(info_txt_lines))
    # DVS simulation is currently disabled. To re-enable:
    # simulate_events_from_info_txt(
    #     info_txt_path=os.path.join(dvs_input_dir, "info.txt"),
    #     output_txt_path=os.path.join(dvs_output_dir, "events.txt"),
    # )

    # NOTE(review): COLMAP's documented images.txt convention is the
    # world-to-camera pose. These lines carry the inverted (camera-to-world)
    # pose. Confirm what downstream consumers expect before changing.
    with open(os.path.join(colmap_dir, "images.txt"), "w") as f:
        f.write("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, IMAGE_NAME\n")
        for line in images_txt:
            # The blank second line is this image's (empty) POINTS2D list.
            f.write(line + "\n\n")

    # Intrinsics come from the last rendered view. The single camera (ID 1) is
    # referenced by every image line. Skip it when nothing was rendered,
    # instead of failing on an undefined view.
    with open(os.path.join(colmap_dir, "cameras.txt"), "w") as f:
        f.write("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n")
        if last_view is not None:
            width, height = last_view.image_width, last_view.image_height
            fx, fy = last_view.focal_x, last_view.focal_y
            cx, cy = last_view.principal_x, last_view.principal_y
            f.write(f"1 PINHOLE {width} {height} {fx} {fy} {cx} {cy}\n")

    with open(os.path.join(esim_dir, "pose_esim.txt"), "w") as f:
        f.writelines(line + "\n" for line in esim_pose_lines)


def render_sets(dataset: ModelParams, iteration: int, pipeline: PipelineParams,
                skip_train: bool, skip_test: bool, separate_sh: bool, interpolate: bool):
    """Load a trained Gaussian model and render the requested camera sets.

    Without ``interpolate``, the training cameras are rendered into ``train``.
    With ``interpolate``, the trajectory selected by the module-global
    ``args.interp_mode`` is rendered instead, plus novel-view sets when
    ``args.novel_view`` is set. Those settings (and ``args.nums_inserted`` /
    ``args.speed_profile``) come from the ``__main__`` block. Test cameras are
    rendered into ``test`` unless ``skip_test`` is set.

    Raises:
        ValueError: if interpolating and ``args.interp_mode`` is neither a known
            mode nor the ``"None"`` placeholder (which renders no trajectory).
    """
    # mode -> (output sub-directory, camera factory). Factories are lambdas so
    # only the selected mode's cameras are ever built.
    interp_modes = {
        "linear": ("interpolate_linear",
                   lambda sc, n: sc.getLinearInterpolatedCameras(nums_inserted=n)),
        "bspline": ("interpolate_bspline",
                    lambda sc, n: sc.getBspineInterpolatedCameras(nums_inserted=n)),
        "adaptive": ("interpolate_adaptive",
                     lambda sc, n: sc.getAdaptiveInterpolatedCameras(interp_multiplier=n)),
        "edited_adaptive": ("edited_interpolate_adaptive",
                            lambda sc, n: sc.getEditedAdaptiveInterpolatedCameras(interp_multiplier=n)),
        "ada_speed": ("edited_speed_adaptive",
                      lambda sc, n: sc.getEditedSpeedInterpolatedCameras(
                          interp_multiplier=n, speed_profile=args.speed_profile)),
    }

    interp_mode = None
    if interpolate and not skip_train:
        interp_mode = args.interp_mode
        # Validate before loading the scene so a typo fails fast instead of
        # silently rendering nothing. "None" is the CLI default (e.g. --novel_view only).
        if interp_mode not in (None, "None") and interp_mode not in interp_modes:
            raise ValueError(
                f"Unknown --interp_mode {interp_mode!r}; expected one of "
                f"{sorted(interp_modes)} or 'None'")

    with torch.no_grad():
        gaussians = GaussianModel(dataset.sh_degree)
        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
        bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

        def _render(name, views, interp):
            # Always pass `interpolate` explicitly; render_set needs it.
            render_set(dataset.model_path, name, scene.loaded_iter, views, gaussians,
                       pipeline, background, dataset.train_test_exp, separate_sh,
                       interpolate=interp)

        if not skip_train:
            if not interpolate:
                _render("train", scene.getTrainCameras(), False)
            else:
                nums_inserted = args.nums_inserted
                if interp_mode in interp_modes:
                    subdir, make_cameras = interp_modes[interp_mode]
                    _render(subdir, make_cameras(scene, nums_inserted), True)
                if args.novel_view:
                    cam_sets = scene.getNovelViewInterpolatedCameras(
                        interp_multiplier=nums_inserted, speed_profile=args.speed_profile)
                    # NOTE(review): novel-view sets are rendered with
                    # interpolate=False, so render_set reads `original_image`
                    # as ground truth. Confirm these cameras actually have one.
                    for i, cam_views in enumerate(cam_sets):
                        _render(f"interpolate_novel_view_{i}", cam_views, False)

        if not skip_test:
            _render("test", scene.getTestCameras(interpolate=False), False)


if __name__ == "__main__":
    # Note: `args` below is also read as a module global by render_set/render_sets.
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--skip_train", action="store_true")
    parser.add_argument("--skip_test", action="store_true")
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--interpolate", action="store_true")
    parser.add_argument("--interp_mode", default="None",
                        help="linear | bspline | adaptive | edited_adaptive | ada_speed")
    parser.add_argument("--nums_inserted", default=10, type=int)
    # type=float: without it a value given on the command line arrives as a str
    # and the frame-interval division (1e9 / fps) raises TypeError.
    parser.add_argument("--basic_speed", default=2400.0, type=float,
                        help="frame rate (frames/s) used to timestamp rendered frames")
    parser.add_argument("--speed_profile", default='None', type=str)
    parser.add_argument("--novel_view", action="store_true")
    args = get_combined_args(parser)

    print("Rendering " + args.model_path)
    safe_state(args.quiet)
    render_sets(model.extract(args), args.iteration, pipeline.extract(args),
                args.skip_train, args.skip_test, SPARSE_ADAM_AVAILABLE, interpolate=args.interpolate)
    print("Rendering done")

'''
Example invocations:
python _1_render_interpolate.py --model_path output/1d2806d7-c --iteration 30000 --skip_test --interpolate --interp_mode bspline --nums_inserted 100
python _1_render_interpolate.py --model_path output/1d2806d7-c --iteration 30000 --skip_test --interpolate --interp_mode linear --nums_inserted 100
python _1_render_interpolate.py --model_path output/1d2806d7-c --iteration 30000 --skip_test --interpolate --interp_mode adaptive --nums_inserted 50
python _1_render_interpolate.py --model_path output/mipnerf-garden --iteration 30000 --skip_test --interpolate --interp_mode edited_adaptive --nums_inserted 50
python _1_render_interpolate.py --model_path /share/magic_group/gs2e_data/output/1a005c25 --source_path /share/magic_group/gs2e_data/mvimgnet/mv_1k/mv_cuda6/1a005c25 --iteration 30000 --skip_test --interpolate --interp_mode edited_adaptive --nums_inserted 50
python _1_render_interpolate.py --model_path /share/magic_group/gs2e_data/output/0000a21d --source_path /share/magic_group/gs2e_data/mvimgnet/mv_1k/0000a21d --skip_test --interpolate --interp_mode ada_speed --nums_inserted 50 --novel_view

python _1_render_interpolate.py --model_path /share/magic_group/gs2e_data/output/0000a21d --source_path /share/magic_group/gs2e_data/source/0000a21d --skip_test --interpolate --interp_mode ada_speed --nums_inserted 5
python _1_render_interpolate.py --model_path /share/magic_group/gs2e_data/output/0000a21d --source_path /share/magic_group/gs2e_data/mvimgnet/mv_1k/0000a21d --skip_test --interpolate --novel_view --nums_inserted 5

python _1_render_interpolate.py --model_path /share/magic_group/gs2e_data/output/0000a21d --source_path /share/magic_group/gs2e_data/source/0000a21d --skip_test --interpolate --interp_mode linear --nums_inserted 5
python _1_render_interpolate.py --model_path /share/magic_group/GS2Event-Simulator/test_quality/output --source_path /share/magic_group/GS2Event-Simulator/test_quality/source --skip_test --interpolate --interp_mode ada_speed --nums_inserted 5


'''