#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

from argparse import ArgumentParser, Namespace
import sys
import os


class GroupParams:
    """Empty attribute bag; ``ParamGroup.extract`` sets one attribute per option."""


def _parse_literal(text):
    """Parse one command-line token as a Python literal, else keep the raw string.

    Used for the elements of list-valued options: the element type cannot be
    inferred from a (possibly empty) default list, so ``"7000"`` becomes
    ``7000`` while ``"abc"`` stays ``"abc"``.
    """
    import ast

    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        return text


def _parse_dict(text):
    """Parse a command-line string such as ``"{'a': 1}"`` into a dict.

    Raises:
        ValueError: if ``text`` is not a Python dict literal (argparse reports
            this as a normal usage error).
    """
    import ast

    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError) as err:
        raise ValueError("not a dict literal: {!r}".format(text)) from err
    if not isinstance(parsed, dict):
        raise ValueError("not a dict literal: {!r}".format(text))
    return parsed


class ParamGroup:
    """Base class that turns its instance attributes into argparse options.

    Subclasses assign their defaults as instance attributes and then call
    ``super().__init__(parser, name, fill_none)``. Every attribute becomes a
    ``--<name>`` option inside an argument group titled ``name``. An attribute
    whose name starts with ``_`` is exposed without the underscore and also
    gets a one-letter short flag ``-<first letter>``.
    """

    def __init__(self, parser: ArgumentParser, name: str, fill_none=False):
        """Register one option per instance attribute in a new argument group.

        Args:
            parser: parser to extend.
            name: title of the argument group (shown in ``--help``).
            fill_none: if True, every option defaults to None instead of the
                attribute's value.

        The option kind follows the type of each attribute's value:
        ``bool`` -> ``store_true`` flag (a True default additionally gets a
        ``--no-<name>`` flag, since ``store_true`` alone could never turn it
        off); ``list`` -> zero or more values, each parsed as a Python
        literal; ``dict`` -> one Python dict literal; anything else -> a single
        value converted with that type.
        """
        group = parser.add_argument_group(name)
        for attr, value in vars(self).items():
            shorthand = attr.startswith("_")
            key = attr[1:] if shorthand else attr
            flags = ["--" + key]
            if shorthand:
                flags.append("-" + key[0:1])
            t = type(value)
            default = None if fill_none else value
            if t is bool:
                group.add_argument(*flags, default=default, action="store_true")
                if value:
                    # Without this, a flag whose default is True is a no-op.
                    group.add_argument("--no-" + key, dest=key,
                                       default=default, action="store_false")
            elif t is list:
                # type=list would split a single token into characters.
                group.add_argument(*flags, default=default, nargs="*",
                                   type=_parse_literal)
            elif t is dict:
                # type=dict cannot build a dict from a command-line string.
                group.add_argument(*flags, default=default, type=_parse_dict)
            else:
                group.add_argument(*flags, default=default, type=t)

    def extract(self, args):
        """Collect the options of ``args`` that belong to this group.

        Returns a ``GroupParams`` with one attribute per matching option
        (named without the leading ``_``). Options missing from ``args`` are
        not set.
        """
        own = vars(self)
        group = GroupParams()
        for key, value in vars(args).items():
            if key in own or ("_" + key) in own:
                setattr(group, key, value)
        return group


class ModelParams(ParamGroup):
    """Model/data loading options, registered under "Loading Parameters".

    Attributes with a leading underscore also get a one-letter short flag
    (``-s`` source_path, ``-m`` model_path, ``-i`` images, ``-r`` resolution,
    ``-w`` white_background).
    """

    def __init__(self, parser, sentinel=False):
        """Define the defaults and register them on ``parser``.

        Args:
            parser: argparse parser to extend.
            sentinel: forwarded to ParamGroup as ``fill_none``; when True every
                option defaults to None, so only values given explicitly on
                the command line are non-None (None values are skipped when
                ``get_combined_args`` merges with a saved config).
        """
        self.sh_degree = 1
        self._source_path = ""
        self._model_path = ""
        self._images = "images"
        self._resolution = -1
        self._white_background = False
        self.data_device = "cuda"
        self.eval = True
        self.rot_type = "6d"
        self.view_dependent = True
        self.data_type = "Tanks"
        self.depth_model_type = "dpt"
        self.mode = "train"
        self.traj_opt = "bspline"
        self.inversion_prompt = 'A horse sculpture on the sunny day'
        self.seg_prompt = ''
        self.mask_attr = 0
        self.downsample = False
        self.sd_version = '2.1'  # ['1.5', '2.0', '2.1', 'ControlNet', 'depth']
        self.save_steps = 50
        self.batch_size = 40
        self.seed = 1
        self.n_inversion_steps = 50
        self.steps = 50
        self.optim_guidance_scale = 5.5
        self.enable_spatial = True

        # Dict-valued settings are stored as one option each.
        self.guidance = dict(guidance_scale=7.5, n_timesteps=50, n_timesteps_sp=50,
                             prompt='A horse sculpture on the snowy day',
                             negative_prompt='ugly, blurry, low res, unrealistic, unaesthetic, unconsistent, bad anatomy, bad hands')

        self.guidance_prompt = ''
        self.num_inference_steps = 20
        self.pnp_params = dict(pnp_attn_t=0.5, pnp_f_t=0.8)
        self.test_prompt = ''
        super().__init__(parser, "Loading Parameters", sentinel)

    def extract(self, args):
        """Return this group's values from ``args`` with ``source_path`` made absolute.

        With ``sentinel=True``, ``source_path`` can be None or absent from
        ``args`` (no command-line value and no saved config). In that case it
        is left as-is instead of crashing in ``os.path.abspath``.
        """
        g = super().extract(args)
        source_path = getattr(g, "source_path", None)
        if source_path is not None:
            g.source_path = os.path.abspath(source_path)
        return g


class PipelineParams(ParamGroup):
    """Pipeline switches, registered under the "Pipeline Parameters" group.

    Each entry of the defaults table below becomes a ``--<name>`` option via
    ParamGroup; the table order is the option registration order.
    """

    def __init__(self, parser):
        defaults = {
            "convert_SHs_python": False,
            "compute_cov3D_python": False,
            "debug": False,
            # No "mode" entry: ModelParams already defines a ``mode`` option,
            # presumably on the same parser, where a second ``--mode`` would
            # clash.
            "use_gt_pcd": False,
            "use_mask": False,
            "use_ref_img": False,
            "init_mode": "rand",
            "use_mono": True,
            "interval": 15,
            "expname": "",
            "use_sampon": False,
            "refine": False,
            "distortion": False,
        }
        for attr_name, default in defaults.items():
            setattr(self, attr_name, default)
        super().__init__(parser, "Pipeline Parameters")


class OptimizationParams(ParamGroup):
    """Optimization hyper-parameters, registered under "Optimization Parameters".

    Only defaults are defined here; every attribute becomes a ``--<name>``
    command-line option through ParamGroup. Each attribute is assigned exactly
    once, so changing a default here always takes effect.
    """

    def __init__(self, parser):
        self.iterations = 30_000
        self.test_iterations = []
        self.max_iterations = 40_000
        # Learning-rate settings (names ending in _lr / lr_*).
        self.position_lr_init = 0.00016
        self.position_lr_final = 0.0000016
        self.position_lr_delay_mult = 0.01
        self.position_lr_max_steps = 30_000
        self.feature_lr = 0.0025
        self.featuret_lr = 0.001
        self.opacity_lr = 0.05
        self.scaling_lr = 0.005
        self.rotation_lr = 0.001
        self.omega_lr = 0.0001
        self.movelr = 3.5
        self.t_lr_init = 0.0008
        self.scaling_t_lr = 0.002
        self.velocity_lr = 0.001
        self.percent_dense = 0.01
        # Weights named lambda_* (presumably loss weights; used by the trainer).
        self.lambda_dssim = 0.2
        self.lambda_depth = 0.05
        self.lambda_flow = 0.00
        self.lambda_opacity_entropy = 0.05
        self.lambda_inv_depth = 0.001
        self.lambda_self_supervision = 0.5
        self.lambda_lpips = 0.2
        self.use_edit = True
        self.use_optim = False
        self.n_views = 1
        self.elevation = 0
        self.radius = 1.0
        self.depth_loss_type = "invariant"
        self.match_method = "dense"
        # Densification / pruning / opacity-reset schedule.
        self.densification_interval = 100
        self.densify_interval = 500
        self.prune_interval = 2000
        self.opacity_reset_interval = 3000
        self.densify_from_iter = 500
        self.densify_until_iter = 15_000
        self.reset_until_iter = 15_000
        self.densify_grad_threshold = 0.0002
        self.densify_grad_t_threshold = 0.002
        self.densify_until_num_points = 3000000
        # Settings prefixed edit_ / mask_ / *_strength.
        self.edit_grad = 5
        self.edit_min_opacity = 0.05
        self.edit_densify_percent = 0.01
        self.mask_thres = 0.8
        self.max_strength = 1.0
        self.min_strength = 0.1
        super().__init__(parser, "Optimization Parameters")


def get_combined_args(parser: ArgumentParser, argv=None):
    """Merge command-line arguments with a model's saved ``cfg_args`` file.

    The command line is parsed first. Then ``<model_path>/cfg_args`` is read
    (if a model path was given and the file can be opened) and evaluated as a
    ``Namespace(...)`` repr. Every command-line value that is not None
    overrides the stored one, so options left at a None default fall back to
    the config file.

    Args:
        parser: fully configured argument parser.
        argv: arguments to parse; defaults to ``sys.argv[1:]``.

    Returns:
        A new ``Namespace`` holding the merged values.
    """
    cmdline_tokens = sys.argv[1:] if argv is None else list(argv)
    args_cmdline = parser.parse_args(cmdline_tokens)

    cfgfile_string = "Namespace()"
    model_path = getattr(args_cmdline, "model_path", None)
    if model_path is None:
        print("Config file not found: no model path given")
    else:
        cfgfilepath = os.path.join(model_path, "cfg_args")
        print("Looking for config file in", cfgfilepath)
        try:
            with open(cfgfilepath) as cfg_file:
                cfgfile_string = cfg_file.read()
        except OSError:
            # Best effort: a missing/unreadable config just means "no defaults".
            print("Config file not found at {}".format(cfgfilepath))
        else:
            print("Config file found: {}".format(cfgfilepath))

    # SECURITY: cfg_args is evaluated as Python code. Builtins are removed and
    # only ``Namespace`` is exposed, which suffices for the Namespace repr the
    # file is expected to hold, but eval is still unsafe on untrusted files --
    # only load cfg_args from model directories you trust.
    args_cfgfile = eval(cfgfile_string, {"__builtins__": {}, "Namespace": Namespace})

    merged_dict = vars(args_cfgfile).copy()
    for key, value in vars(args_cmdline).items():
        if value is not None:
            merged_dict[key] = value
    return Namespace(**merged_dict)
