from argparse import ArgumentParser, Namespace
import sys
import os

class GroupParams:
    """Plain attribute container returned by ``ParamGroup.extract``."""

class ParamGroup:
    """Register a subclass's instance attributes as command-line options.

    A subclass assigns its defaults as instance attributes and then calls
    this ``__init__``. Each attribute becomes a ``--<name>`` option inside an
    argument group titled *name*. An attribute whose name starts with ``_``
    is registered without the underscore and additionally receives a
    one-letter short flag ``-<first letter>``. ``bool`` attributes become
    ``store_true`` switches; every other attribute is parsed with
    ``type(default)``.

    NOTE(review): because bools use ``store_true``, an attribute that
    defaults to ``True`` cannot be switched off from the command line.
    """

    def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
        """Add one option per instance attribute to a new group on *parser*.

        Args:
            parser: Parser that receives the new argument group.
            name: Title of the argument group.
            fill_none: When true, every option defaults to ``None`` instead
                of the attribute value, so options the user did not pass
                stay ``None`` after parsing.
        """
        group = parser.add_argument_group(name)
        for attr, default in vars(self).items():
            wants_short_flag = attr.startswith("_")
            option = attr[1:] if wants_short_flag else attr
            flags = ["--" + option]
            if wants_short_flag:
                flags.append("-" + option[0:1])
            value_type = type(default)
            settings = {"default": None if fill_none else default}
            if value_type == bool:
                settings["action"] = "store_true"
            else:
                settings["type"] = value_type
            group.add_argument(*flags, **settings)

    def extract(self, args):
        """Copy the parsed values that belong to this group.

        Args:
            args: Parsed ``Namespace``, possibly holding options from several
                groups.

        Returns:
            A ``GroupParams`` carrying every option declared by this group,
            named without any leading underscore.
        """
        declared = vars(self)
        picked = GroupParams()
        for key, val in vars(args).items():
            if key in declared or "_" + key in declared:
                setattr(picked, key, val)
        return picked

class ModelParams(ParamGroup):
    """Options registered under the "Loading Parameters" argument group.

    Attributes with a leading underscore also get a one-letter short flag
    from ``ParamGroup``: ``-s``, ``-m``, ``-i``, ``-r``, ``-w`` and ``-k``.
    """

    def __init__(self, parser, sentinel=False):
        """Declare the defaults and register them on *parser*.

        Args:
            parser: ``ArgumentParser`` to extend.
            sentinel: Forwarded as ``fill_none``; when true every option
                defaults to ``None``.
        """
        # Assignment order is the order options are registered; keep it stable.
        self.sh_degree = 3
        self._source_path = ""
        self._model_path = ""
        self._images = "images"
        self._resolution = -1
        self._white_background = False
        self.data_device = "cuda"
        self.eval = False
        self.train_ours = False
        self.train_GOF = False
        self._kernel_size = 0.0
        self.ray_jitter = False
        self.resample_gt_image = False
        self.load_allres = False
        self.sample_more_highres = False
        self.use_decoupled_appearance = False
        self.load2gpu_on_the_fly = False
        # NOTE(review): the next two default to True but are registered as
        # store_true flags, so they cannot be disabled from the CLI.
        self.load_time_camera = True
        self.use_mask = False
        self.use_udf = True
        self.color_filter = False
        self.extract_fid = 0.0
        self.init_num = False
        self.use_axis = False
        self.reset_mask_iteration = 600
        self.no_filter = False
        self.use_canonical = False
        self.not_reset_mask = False
        super().__init__(parser, "Loading Parameters", fill_none=sentinel)

    def extract(self, args):
        """Return this group's values with ``source_path`` made absolute.

        ``os.path.abspath`` resolves a relative path against the current
        working directory (an empty string resolves to the cwd itself).
        NOTE(review): raises ``TypeError`` if ``source_path`` is ``None``,
        e.g. with ``sentinel=True`` and no value supplied from anywhere.
        """
        loaded = super().extract(args)
        loaded.source_path = os.path.abspath(loaded.source_path)
        return loaded

class PipelineParams(ParamGroup):
    """Switches registered under the "Pipeline Parameters" argument group.

    All four default to ``False`` and are therefore ``store_true`` flags.
    """

    def __init__(self, parser):
        """Declare the defaults and register them on *parser*."""
        self.convert_SHs_python = False
        self.compute_cov3D_python = False
        self.compute_view2gaussian_python = False
        self.debug = False
        super().__init__(parser, name="Pipeline Parameters")


class SegParams(ParamGroup):
    """Options registered under the "Segline Parameters" argument group.

    Each option's CLI type follows its default: ``gpu``, ``category`` and
    ``root`` are strings, ``normal`` is a ``store_true`` flag, and the rest
    are ints. NOTE(review): what each option controls is decided by code that
    consumes these values, which is not visible here.
    """

    def __init__(self, parser):
        """Declare the defaults and register them on *parser*."""
        self.gpu = '0'
        self.num_point = 20000
        self.normal = False
        self.num_votes = 3
        self.category = 'USB'
        self.root = ""
        self.num_classes = 1
        self.num_part = 3
        super().__init__(parser, name="Segline Parameters")

class OptimizationParams(ParamGroup):
    """Training hyper-parameters ("Optimization Parameters" argument group).

    ``ParamGroup`` derives each option's CLI type from its default value, so
    hyper-parameters that may take fractional values must have float
    defaults; an int default would make argparse reject e.g. ``0.5``.
    """

    def __init__(self, parser):
        """Declare the defaults and register them on *parser*."""
        self.iterations = 40_000
        self.warm_up = 3_000
        self.deform_lr_max_steps = 40_000

        self.position_lr_init = 0.00016
        self.position_lr_final = 0.0000016
        self.position_lr_delay_mult = 0.01
        self.position_lr_max_steps = 30_000
        self.feature_lr = 0.0025
        self.opacity_lr = 0.05
        self.scaling_lr = 0.001  # Deformable-3DGS value (alternative: 0.005)
        self.rotation_lr = 0.001
        self.appearance_embeddings_lr = 0.001
        self.appearance_network_lr = 0.001
        self.percent_dense = 0.01
        self.lambda_dssim = 0.2
        # Float default so the option is parsed with type=float; with the
        # previous int default (100) argparse rejected values like 0.5 or 1e3.
        self.lambda_distortion = 100.0
        self.lambda_depth_normal = 0.05
        self.distortion_from_iter = 15000
        self.depth_normal_from_iter = 15000
        self.densification_interval = 100
        self.opacity_reset_interval = 3000
        self.densify_from_iter = 500
        self.densify_until_iter = 15_000
        # "Ours" setting; alternatives that were tried: 0.0002, 0.002.
        self.densify_grad_threshold = 0.0007
        super().__init__(parser, "Optimization Parameters")

class onestage_OptimizationParams(ParamGroup):
    """Alternative training hyper-parameters, also titled "Optimization Parameters".

    ``ParamGroup`` derives each option's CLI type from its default value, so
    hyper-parameters that may take fractional values must have float
    defaults; an int default would make argparse reject e.g. ``0.5``.

    NOTE(review): this registers the same option names as
    ``OptimizationParams``; instantiating both on one parser would make
    argparse raise a conflicting-option error.
    """

    def __init__(self, parser):
        """Declare the defaults and register them on *parser*."""
        self.iterations = 30_000
        self.position_lr_init = 0.00016
        self.position_lr_final = 0.0000016
        self.position_lr_delay_mult = 0.01
        self.position_lr_max_steps = 30_000
        self.feature_lr = 0.0025
        self.opacity_lr = 0.05
        self.scaling_lr = 0.005
        self.rotation_lr = 0.001
        self.appearance_embeddings_lr = 0.001
        self.appearance_network_lr = 0.001
        self.percent_dense = 0.01
        self.lambda_dssim = 0.2
        # Float default so the option is parsed with type=float; with the
        # previous int default (100) argparse rejected values like 0.5 or 1e3.
        self.lambda_distortion = 100.0
        self.lambda_depth_normal = 0.05
        self.distortion_from_iter = 15000
        self.depth_normal_from_iter = 15000
        self.densification_interval = 100
        self.opacity_reset_interval = 3000
        self.densify_from_iter = 500
        self.densify_until_iter = 15_000
        # "Ours" setting; alternatives that were tried: 0.0007, 0.00002.
        self.densify_grad_threshold = 0.0002
        super().__init__(parser, "Optimization Parameters")

def get_combined_args(parser : ArgumentParser, argv=None):
    """Merge command-line arguments over the ``cfg_args`` file of a model.

    The command line is parsed with *parser*. If it provides ``model_path``
    and ``<model_path>/cfg_args`` exists, that file is evaluated as a
    ``Namespace(...)`` expression and used as the base values. Every
    command-line value that is not ``None`` then overrides the base value of
    the same name. If no model path is given or the file is missing, only
    the command-line values are used.

    Args:
        parser: Fully configured parser; must define ``model_path``.
        argv: Arguments to parse. Defaults to ``sys.argv[1:]``.

    Returns:
        A new ``Namespace`` holding the merged values.
    """
    if argv is None:
        argv = sys.argv[1:]
    args_cmdline = parser.parse_args(argv)

    cfgfile_string = "Namespace()"
    model_path = args_cmdline.model_path
    if model_path is None:
        print("Config file not found: no model path given")
    else:
        cfgfilepath = os.path.join(model_path, "cfg_args")
        print("Looking for config file in", cfgfilepath)
        try:
            with open(cfgfilepath) as cfg_file:
                print("Config file found: {}".format(cfgfilepath))
                cfgfile_string = cfg_file.read()
        except FileNotFoundError:
            print("Config file not found at", cfgfilepath)

    # SECURITY(review): eval executes arbitrary code from the config file.
    # Only load cfg_args files you trust (e.g. ones written by your own
    # training runs); consider an ast-based parser for untrusted inputs.
    args_cfgfile = eval(cfgfile_string)

    merged_dict = vars(args_cfgfile).copy()
    merged_dict.update(
        {k: v for k, v in vars(args_cmdline).items() if v is not None}
    )
    return Namespace(**merged_dict)
