import glob
import json
import os
import torchvision.transforms
import dearpygui.dearpygui as dpg
from scipy.spatial.transform import Rotation as R
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from gaussian_renderer import render_fn_dict
from scene import GaussianModel
from utils.general_utils import safe_state
from utils.camera_utils import Camera, JSON_to_camera
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams
from utils.system_utils import searchForMaxIteration
from scene.derect_light_sh import DirectLightEnv
from utils.graphics_utils import focal2fov, hdr2ldr
from scene.gamma_trans import LearningGammaTransform
import time

def safe_normalize(x, eps=1e-20):
    """Scale ``x`` to unit length along its last dimension.

    The squared norm is clamped from below by ``eps`` before the square
    root, so an all-zero vector yields zeros instead of NaN.
    """
    squared_norm = (x * x).sum(dim=-1, keepdim=True)
    return x / squared_norm.clamp(min=eps).sqrt()


class OrbitCamera:
    """Camera that orbits around a look-at point.

    The camera state is a rotation (``rot``), a distance (``radius``) and a
    look-at point (``center``); ``pose`` and ``view`` are derived from them
    on every access.
    """

    def __init__(self, W, H, fovy=60, near=0.1, far=10, rot=None, translate=None, center=None):
        """
        Args:
            W, H: image width and height in pixels.
            fovy: vertical field of view, in degrees.
            near, far: stored as-is; not used by this class itself.
            rot: 3x3 rotation matrix, or None for the built-in default.
            translate: vector whose norm (doubled) becomes the orbit radius;
                None gives a radius of 2.
            center: look-at point; None means the origin. The value is
                copied, so later ``pan`` calls never modify the caller's array.
        """
        self.W = W
        self.H = H
        if translate is None:
            self.radius = 1
        else:
            self.radius = np.linalg.norm(translate)
        self.radius *= 2
        self.fovy = fovy  # in degree
        self.near = near
        self.far = far

        if center is None:
            self.center = np.array([0, 0, 0], dtype=np.float32)  # look at this point
        else:
            # Private float32 copy: pan() updates the center in place, and the
            # caller may reuse the same array (e.g. for a second camera).
            self.center = np.array(center, dtype=np.float32)

        if rot is None:
            # NOTE(review): this matrix has determinant -1 (a reflection, not a
            # rotation); recent SciPy releases appear to reject such input in
            # Rotation.from_matrix - confirm before relying on this default.
            self.rot = R.from_matrix(np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]]))  # looking back to z axis
        else:
            self.rot = R.from_matrix(rot)

        # self.up = np.array([0, -1, 0], dtype=np.float32)  # need to be normalized!
        # The up axis is fixed at construction time (negated second column of
        # the initial rotation) and is not updated by orbit().
        self.up = -self.rot.as_matrix()[:3, 1]

    # pose
    @property
    def pose(self):
        """4x4 float32 matrix built from radius, rotation and center."""
        # first move camera to radius
        res = np.eye(4, dtype=np.float32)
        res[2, 3] = self.radius
        # rotate
        rot = np.eye(4, dtype=np.float32)
        rot[:3, :3] = self.rot.as_matrix()
        res = rot @ res
        # translate
        res[:3, 3] -= self.center
        return res

    # view
    @property
    def view(self):
        """Inverse of ``pose``."""
        return np.linalg.inv(self.pose)

    # intrinsics
    @property
    def intrinsics(self):
        """Return ``[focal, focal, W // 2, H // 2]`` as float32, focal derived from ``fovy`` and ``H``."""
        focal = self.H / (2 * np.tan(np.radians(self.fovy) / 2))
        return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32)

    def orbit(self, dx, dy):
        """Rotate by ``-0.05 * dx`` degrees about ``up`` and ``-0.05 * dy`` degrees about the side axis."""
        # rotate along camera up/side axis!
        side = self.rot.as_matrix()[:3, 0]  # first column of the rotation matrix, already normalized
        rotvec_x = self.up * np.radians(-0.05 * dx)
        rotvec_y = side * np.radians(-0.05 * dy)
        self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot

    def scale(self, delta):
        """Zoom: multiply the radius by ``1.1 ** (-delta)``."""
        self.radius *= 1.1 ** (-delta)

    def pan(self, dx, dy, dz=0):
        """Shift the look-at point by ``0.0005 * rot @ [-dx, -dy, dz]``."""
        # pan in camera coordinate system (careful on the sensitivity!)
        self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([-dx, -dy, dz])


class GUI:
    """Dear PyGui viewer that re-renders the scene from an orbit camera every frame.

    ``render_fn`` is called with ``viewpoint_camera=<Camera>`` plus
    ``render_kwargs`` and must return a dict; every tensor entry whose element
    count is a multiple of ``H * W`` becomes a selectable display mode.
    """

    def __init__(self, H, W, fovy, c2w, center, render_fn, render_kwargs, 
                 mode="render", debug=True, use_hdr2ldr=False):
        """
        Args:
            H, W: viewport / texture size in pixels.
            fovy: vertical field of view in radians (converted to degrees for
                the orbit camera).
            c2w: 4x4 matrix; its rotation block and translation column
                initialise the orbit camera.
            center: look-at point of the orbit camera.
            render_fn: rendering callable, see the class docstring.
            render_kwargs: extra keyword arguments for ``render_fn``. The dict
                is kept by reference and its ``'idx'`` entry is updated after
                each render when the render result provides one.
            mode: initial display mode; falls back to the first available
                mode when it is not offered by ``render_fn``.
            debug: show the camera pose in the control window.
            use_hdr2ldr: if the image is HDR, set True to map the "pbr" and
                "render" modes to LDR ([0, 1]) for display; with False the
                displayed values are not restricted to [0, 1].
        """
        # Set first so __del__ stays safe if construction fails early.
        self._dpg_context_created = False

        if use_hdr2ldr:
            print("[GUI INFO] HDR rendered images will be shown in LDR.")

        self.W = W
        self.H = H
        self.debug = debug
        rot = c2w[:3, :3]
        translate = c2w[:3, 3] - center
        self.render_fn = render_fn
        self.render_kwargs = render_kwargs
        self.use_hdr2ldr = use_hdr2ldr
        # Raised by the UI callbacks; not read anywhere in this class.
        self.need_update = False
        
        self.cam = OrbitCamera(self.W, self.H, fovy=fovy * 180 / np.pi, rot=rot, translate=translate, center=center)

        # Same (H, W, 3) layout that get_buffer() produces.
        self.render_buffer = np.zeros((self.H, self.W, 3), dtype=np.float32)
        self.resize_fn = torchvision.transforms.Resize((self.H, self.W), antialias=True)
        self.downsample = 1
        self.start = torch.cuda.Event(enable_timing=True)
        self.end = torch.cuda.Event(enable_timing=True)

        self.menu = None
        self.mode = None
        # First render (with mode=None) only discovers the available modes.
        self.step()
        self.mode = mode if mode in self.menu else self.menu[0]
        dpg.create_context()
        self._dpg_context_created = True
        self.register_dpg()

    def __del__(self):
        # Only tear down a context this instance actually created.
        if getattr(self, "_dpg_context_created", False):
            dpg.destroy_context()

    def get_buffer(self, render_results, mode=None):
        """Convert one entry of ``render_results`` into an (H, W, 3) numpy image.

        Returns an all-ones float32 image when ``render_results`` or ``mode``
        is None. Otherwise the selected tensor is post-processed per mode,
        expanded to three channels, resized to (H, W) if needed and moved to
        the host.
        """
        if render_results is None or mode is None:
            # Plain host-side placeholder; no device round trip needed.
            return np.ones((self.H, self.W, 3), dtype=np.float32)

        output = render_results[mode]

        if mode == "depth":
            depth_min = output.min()
            # Guard the range so a constant depth map gives zeros, not NaN.
            depth_range = (output.max() - depth_min).clamp_min(1e-8)
            output = (output - depth_min) / depth_range
        elif mode == "num_contrib":
            output = output.clamp_max(1000) / 1000
        elif mode == "pbr" or mode == "render":
            if self.use_hdr2ldr:
                output = hdr2ldr(output)

        if len(output.shape) == 2:
            output = output[None]
        if output.shape[0] == 1:
            output = output.repeat(3, 1, 1)
        if "normal" in mode:
            opacity = render_results["opacity"]
            output = output * 0.5 + 0.5 * opacity
        if (self.H, self.W) != tuple(output.shape[1:]):
            # Happens when rendering at a downsampled resolution.
            output = self.resize_fn(output)

        return output.permute(1, 2, 0).contiguous().detach().cpu().numpy()

    @property
    def custom_cam(self):
        """Build a ``Camera`` from the current orbit state and downsample factor."""
        w2c = self.cam.view
        R = w2c[:3, :3].T
        T = w2c[:3, 3]
        down = self.downsample
        H, W = self.H // down, self.W // down
        fovy = self.cam.fovy * np.pi / 180
        # Pinhole relation tan(fovx / 2) = tan(fovy / 2) * W / H; the linear
        # scaling fovy * W / H is only right for square images.
        fovx = float(2 * np.arctan(np.tan(fovy / 2) * W / H))
        custom_cam = Camera(colmap_id=0, R=R, T=-T,
                            FoVx=fovx, FoVy=fovy, fx=None, fy=None, cx=None, cy=None,
                            image=torch.zeros(3, H, W), image_name=None, uid=0)
        return custom_cam

    @torch.no_grad()
    def render(self):
        """Render one frame and let Dear PyGui draw it."""
        self.step()
        dpg.render_dearpygui_frame()

    def step(self):
        """Run ``render_fn`` once, time it, and refresh the display buffer."""
        self.start.record()
        
        render_pkg = self.render_fn(viewpoint_camera=self.custom_cam, **self.render_kwargs)
        # Feed 'idx' back into the next call when the render function
        # reports one; render functions without it are left untouched.
        if 'idx' in render_pkg:
            self.render_kwargs['idx'] = render_pkg['idx']
        self.end.record()
        torch.cuda.synchronize()
        t = self.start.elapsed_time(self.end)

        self.render_buffer = self.get_buffer(render_pkg, self.mode)

        if t == 0:
            fps = 0
        else:
            fps = int(1000 / t)

        if self.menu is None:
            # First call (from __init__, before any Dear PyGui item exists):
            # collect the displayable outputs instead of updating widgets.
            self.menu = [k for k, v in render_pkg.items() if
                         isinstance(v, torch.Tensor) and np.array(v.shape).prod() % (self.H * self.W) == 0]
        else:
            dpg.set_value("_log_infer_time", f'{t:.4f}ms ({fps} FPS)')
            dpg.set_value("_texture", self.render_buffer)

    def register_dpg(self):
        """Create the texture, windows, controls and mouse handlers, then show the viewport."""

        ### register texture

        with dpg.texture_registry(show=False):
            dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")

        ### register window

        # the rendered image, as the primary window
        with dpg.window(tag="_primary_window", width=self.W, height=self.H):

            # add the texture
            dpg.add_image("_texture")

        dpg.set_primary_window("_primary_window", True)

        # control window
        with dpg.window(label="Control", tag="_control_window", width=300, height=200):

            # button theme
            with dpg.theme() as theme_button:
                with dpg.theme_component(dpg.mvButton):
                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)

            with dpg.group(horizontal=True):
                dpg.add_text("Infer time: ")
                dpg.add_text("no data", tag="_log_infer_time")

            # rendering options
            with dpg.collapsing_header(label="Options", default_open=True):
                # mode combo
                def callback_change_mode(sender, app_data):
                    self.mode = app_data
                    self.need_update = True

                dpg.add_combo(self.menu, label='mode', default_value=self.mode, callback=callback_change_mode)

                def callback_set_downsample(sender, app_data):
                    self.downsample = app_data
                    self.need_update = True

                dpg.add_slider_int(label="Downsample", min_value=1, max_value=8, format="x%d",
                                   default_value=self.downsample, callback=callback_set_downsample)

                # fov slider
                def callback_set_fovy(sender, app_data):
                    self.cam.fovy = app_data
                    self.need_update = True

                # The camera stores fovy as a float (radians -> degrees), while
                # the int slider takes an integer default.
                dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg",
                                   default_value=int(round(self.cam.fovy)), callback=callback_set_fovy)

            # debug info
            if self.debug:
                with dpg.collapsing_header(label="Debug"):
                    # pose
                    dpg.add_separator()
                    dpg.add_text("Camera Pose:")
                    dpg.add_text(str(self.cam.pose), tag="_log_pose")

        ### register camera handler

        def callback_camera_drag_rotate(sender, app_data):

            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.orbit(dx, dy)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))

        def callback_camera_wheel_scale(sender, app_data):

            if not dpg.is_item_focused("_primary_window"):
                return

            delta = app_data

            self.cam.scale(delta)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))

        def callback_camera_drag_pan(sender, app_data):

            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.pan(dx, dy)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))

        with dpg.handler_registry():
            dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate)
            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
            dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Right, callback=callback_camera_drag_pan)

        dpg.create_viewport(title='3D Gaussian Rendering Viewer', width=self.W, height=self.H, resizable=False)

        ### global theme
        with dpg.theme() as theme_no_padding:
            with dpg.theme_component(dpg.mvAll):
                # set all padding to 0 to avoid scroll bar
                dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
                dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
                dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)

        dpg.bind_item_theme("_primary_window", theme_no_padding)

        dpg.setup_dearpygui()
        dpg.show_viewport()


def scene_composition(scene_dict: dict, dataset: ModelParams):
    """Load every entry of ``scene_dict`` and merge them into one GaussianModel.

    Each entry must provide ``"path"`` (a .ply file) and ``"transform"``
    (16 values, reshaped to 4x4 and applied via ``set_transform``).
    After merging, ``_visibility_rest`` gains ``5 ** 2 - 4 ** 2`` zero-filled
    slots along dim 1 and the incident-light coefficients are zeroed.
    """
    loaded_models = []
    for name in scene_dict:
        entry = scene_dict[name]
        model = GaussianModel(dataset.sh_degree, render_type="neilf")
        model.load_ply(entry["path"])
        matrix = torch.tensor(entry["transform"], device="cuda").reshape(4, 4)
        model.set_transform(transform=matrix)
        loaded_models.append(model)

    composite = GaussianModel.create_from_gaussians(loaded_models, dataset)
    n = composite.get_xyz.shape[0]
    print(f"Totally {n} points loaded.")

    # Widen the visibility coefficients with zeros
    # (presumably SH degree 3 -> 4, i.e. 16 -> 25 coefficients; TODO confirm).
    zero_pad = torch.zeros(n, 5 ** 2 - 4 ** 2, 1, device="cuda", dtype=torch.float32)
    widened = torch.cat([composite._visibility_rest.data, zero_pad], dim=1)
    composite._visibility_rest = torch.nn.Parameter(widened.requires_grad_(True))

    # Reset incident lighting on the merged model.
    composite._incidents_dc.data[:] = 0
    composite._incidents_rest.data[:] = 0

    return composite


def comp_ground(scene_dict: dict, dataset: ModelParams, base_gaussians:GaussianModel):
    """Merge ``base_gaussians`` with the ground model described by ``scene_dict['ground']``.

    Args:
        scene_dict: scene description; only the ``'ground'`` entry is read,
            which must provide ``"path"`` (a .ply file) and ``"transform"``
            (16 values, reshaped to 4x4).
        dataset: model parameters (``sh_degree`` is read here; the whole
            object is forwarded to ``create_from_gaussians``).
        base_gaussians: already loaded model placed first in the composite.

    Returns:
        The merged GaussianModel, with ``_visibility_rest`` widened by
        ``5 ** 2 - 4 ** 2`` zero-filled slots along dim 1 and the
        incident-light coefficients zeroed.
    """
    ground_gaussians = GaussianModel(dataset.sh_degree, render_type="neilf")
    ground_gaussians.load_ply(scene_dict['ground']["path"])
    torch_transform = torch.tensor(scene_dict['ground']["transform"], device="cuda").reshape(4, 4)
    ground_gaussians.set_transform(transform=torch_transform)

    gaussians_list = [base_gaussians, ground_gaussians]

    gaussians_composite = GaussianModel.create_from_gaussians(gaussians_list, dataset)
    n = gaussians_composite.get_xyz.shape[0]
    print(f"Totally {n} points loaded.")

    # Widen the visibility coefficients with zeros
    # (presumably SH degree 3 -> 4, i.e. 16 -> 25 coefficients; TODO confirm).
    gaussians_composite._visibility_rest = (
        torch.nn.Parameter(torch.cat(
            [gaussians_composite._visibility_rest.data,
             torch.zeros(n, 5 ** 2 - 4 ** 2, 1, device="cuda", dtype=torch.float32)],
            dim=1).requires_grad_(True)))

    # Reset incident lighting on the merged model.
    gaussians_composite._incidents_dc.data[:] = 0
    gaussians_composite._incidents_rest.data[:] = 0

    return gaussians_composite


def load_json_config(json_file):
    """Parse ``json_file`` (UTF-8) and return its content, or None if the path does not exist."""
    if os.path.exists(json_file):
        with open(json_file, 'r', encoding='UTF-8') as handle:
            return json.load(handle)
    return None

def custom_cam2(cam):
    """Build a ``Camera`` from an orbit camera at full resolution.

    Args:
        cam: object exposing ``view`` (4x4 matrix), ``H``, ``W`` and ``fovy``
            (degrees), e.g. an ``OrbitCamera``.

    Returns:
        A ``Camera`` with a zero image of shape (3, H, W).
    """
    w2c = cam.view
    rotation = w2c[:3, :3].T
    translation = w2c[:3, 3]
    down = 1  # TODO: expose the downsample factor
    H, W = cam.H // down, cam.W // down
    fovy = cam.fovy * np.pi / 180
    # Pinhole relation tan(fovx / 2) = tan(fovy / 2) * W / H; the linear
    # scaling fovy * W / H is only right for square images.
    fovx = float(2 * np.arctan(np.tan(fovy / 2) * W / H))
    custom_cam = Camera(colmap_id=0, R=rotation, T=-translation,
                        FoVx=fovx, FoVy=fovy, fx=None, fy=None, cx=None, cy=None,
                        image=torch.zeros(3, H, W), image_name=None, uid=0)
    return custom_cam

if __name__ == '__main__':
    # Entry point: compose the scene, run one precompute render pass, then
    # hand the results to the interactive viewer.

    # Set up command line argument parser
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--skip_train", action="store_true")
    parser.add_argument("--skip_test", action="store_true")
    parser.add_argument('-t', '--type', choices=['render','neilf'], default='render')
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("-c", "--checkpoint", type=str, default=None,
                        help="resume from checkpoint")
    parser.add_argument("--scale", type=int, default=1)
    parser.add_argument('--hdr2ldr', action="store_true")
    parser.add_argument('--light_file', type=str, help="light file path")
    parser.add_argument('--light_deg', type=int, help="light sh degree")

    args = parser.parse_args()
    light_shs_file = args.light_file
    light_deg = args.light_deg
    if light_shs_file is None:
        # Machine-specific fallback; pass --light_file to override.
        light_shs_file = 'D:/codes/SphericalHarmonicsLighting/data8/city/coefficients.txt'
    print("Rendering " + args.model_path)

    # Initialize system state (RNG)
    safe_state(args.quiet)
    dataset = model.extract(args)
    pipe = pipeline.extract(args)

    # Raw string: the path contains backslashes that are not valid escapes.
    scene_config_file = r"D:\r3d\Relightable3DGaussian\configs\nerf_syn_lego_our\transform.json"
    scene_dict = load_json_config(scene_config_file)
    if scene_dict is None:
        # Fail early with a clear message instead of crashing while composing.
        raise FileNotFoundError(f"Scene config not found: {scene_config_file}")

    base_gaussians = GaussianModel(dataset.sh_degree, render_type=args.type)

    pbr_kwargs = dict()
    pbr_kwargs['sample_num'] = pipe.sample_num
    checkpoints = glob.glob(os.path.join(args.model_path, "chkpnt*.pth"))
    if args.checkpoint is not None or len(checkpoints) > 0:
        if args.checkpoint is not None:
            checkpoint = args.checkpoint
        else:
            # Latest checkpoint = highest iteration number in the file name.
            checkpoint = sorted(checkpoints, key=lambda x: int(x.split("chkpnt")[-1].split(".")[0]))[-1]
        base_gaussians.create_from_ckpt(checkpoint, restore_optimizer=False)

        # Companion checkpoints share the prefix/suffix of the main one.
        env_checkpoint = checkpoint.split("chkpnt")[0] + "env_light_chkpnt" + checkpoint.split("chkpnt")[-1]
        if os.path.exists(env_checkpoint):
            env_light = DirectLightEnv(dataset.global_shs_degree)
            env_light.create_from_ckpt(env_checkpoint, restore_optimizer=False)

            pbr_kwargs["env_light"] = env_light
        else:
            print("cannot find env_light_checkpoint at {}, and env light will be ignore.".format(env_checkpoint))

        gamma_checkpoint = checkpoint.split("chkpnt")[0] + "gamma_chkpnt" + checkpoint.split("chkpnt")[-1]
        if os.path.exists(gamma_checkpoint):
            gamma_transform = LearningGammaTransform(True)
            # Same keyword as the other create_from_ckpt calls above.
            gamma_transform.create_from_ckpt(gamma_checkpoint, restore_optimizer=False)

            pbr_kwargs['gamma'] = gamma_transform
        else:
            print("cannot find gamma_checkpoint at {}, and gamma transform will be ignore.".format(gamma_checkpoint))

    # The rendered scene always comes from the scene config;
    # comp_ground(scene_dict, dataset, base_gaussians) is the alternative
    # that would build on the checkpointed model.
    gaussians = scene_composition(scene_dict, dataset)

    gaussians.finetune_visibility(iterations=1000)
    render_fn = render_fn_dict['neilf_composite_prt_rotate']
    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    cameras_json = os.path.join(args.model_path, "cameras.json")
    if os.path.exists(cameras_json):
        # Start from the first recorded camera.
        with open(cameras_json, 'r') as file:
            ref_cam = JSON_to_camera(json.load(file)[0])
        c2w = ref_cam.c2w.detach().cpu().numpy()
        H, W = int(ref_cam.image_height / args.scale), int(ref_cam.image_width / args.scale)
        fovy = ref_cam.FoVy

        if fovy is None:
            fovy = focal2fov(ref_cam.fy, ref_cam.image_height)
    else:
        H, W = 800, 800
        fovy = 50 * np.pi / 180
        c2w = np.array([
            [0.0, 0.0, -1.0, 2.0],
            [1.0, 0.0, 0.0, 0.0],
            [0.0, -1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0]
        ])
    center = gaussians.get_xyz.mean(dim=0).detach().cpu().numpy()

    # Arguments common to every render call.
    render_kwargs = {
        "pc": gaussians,
        "pipe": pipe,
        "bg_color": background,
        "is_training": False,
        "dict_params": pbr_kwargs,
    }

    prt = True  # hard-wired: always run the precompute pass before viewing
    if prt:
        prt_kwargs = {
            "bake": False,
            "light_shs_file": light_shs_file,
            "light_deg": light_deg,
        }

        # One pass with precompute=True from the initial viewpoint; its
        # outputs are then fed to every interactive frame.
        rot = c2w[:3, :3]
        translate = c2w[:3, 3] - center
        orbit_cam = OrbitCamera(W, H, fovy=fovy * 180 / np.pi, rot=rot, translate=translate, center=center)
        render_pkg = render_fn(viewpoint_camera=custom_cam2(orbit_cam),
                               light_transport=None, precompute=True, occ=None,
                               **render_kwargs, **prt_kwargs)

        render_kwargs.update(prt_kwargs)
        render_kwargs.update({
            "light_transport": render_pkg["precompute_trans"],
            "precompute": False,
            "idx": 0,
            "occ": render_pkg["occ"],
        })

    windows = GUI(H, W, fovy,
                  c2w=c2w, center=center,
                  render_fn=render_fn, render_kwargs=render_kwargs,
                  mode='pbr', use_hdr2ldr=args.hdr2ldr)

    # Stop rendering once the viewport has been closed.
    while dpg.is_dearpygui_running():
        windows.render()
