import os
import argparse
import json
import shutil
# from utils import ensure_dirs


def ensure_dir(path):
    """
    create directory `path` (including missing parent directories) if nothing exists there yet.
    If `path` already exists (directory or otherwise) this is a no-op.
    :param path: directory path to create
    :return: None
    """
    if not os.path.exists(path):
        # exist_ok=True avoids a FileExistsError if another process creates
        # the directory between the existence check above and this call.
        os.makedirs(path, exist_ok=True)


def ensure_dirs(paths):
    """
    create one or several directories, skipping those that already exist
    :param paths: a single path, or a list/tuple of paths
    :return: None
    """
    # a str is never a list/tuple, so a single path always takes the else-branch
    if isinstance(paths, (list, tuple)):
        for path in paths:
            ensure_dir(path)
    else:
        ensure_dir(paths)


class Config(object):
    """Experiment configuration built from command-line arguments.

    Constructing an instance parses ``sys.argv``, copies every parsed argument
    onto the instance as an attribute, derives and creates the experiment
    directories (``exp_dir``, ``log_dir``, ``model_dir``, ``model_save_dir``),
    sets ``CUDA_VISIBLE_DEVICES`` from ``--gpu_ids`` and, unless ``--continue``
    is given, writes ``config.json`` and a snapshot of the ``*.py`` files of the
    current working directory into the experiment directory.
    """
    def __init__(self):
        # init hyperparameters and parse from command-line
        _parser, args = self.parse()

        # expose every argument as an attribute, echoing it to stdout
        print("----Experiment Configuration-----")
        for k, v in args.__dict__.items():
            print("{0:20}".format(k), v)
            setattr(self, k, v)

        # experiment paths
        self.exp_dir = os.path.join(self.proj_dir, self.exp_name)
        if args.ckpt is None and os.path.exists(self.exp_dir):
            # only a warning: the existing directory is reused, nothing is deleted
            print('Experiment log/model already exists.')

        self.log_dir = os.path.join(self.exp_dir, 'log')
        self.model_dir = os.path.join(self.exp_dir, 'model')
        self.model_save_dir = os.path.join(self.exp_dir, 'model', self.tag)
        ensure_dirs([self.log_dir, self.model_dir, self.model_save_dir])

        # GPU usage: restrict visible devices via the environment.
        # str() is needed because the --gpu_ids default is the int 0.
        if args.gpu_ids is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_ids)

        # save this configuration and a code snapshot (fresh runs only)
        if not args.cont:
            with open(os.path.join(self.exp_dir, 'config.json'), 'w') as f:
                json.dump(args.__dict__, f, indent=2)
            copy_code_dir = os.path.join(self.exp_dir, "code")
            ensure_dirs(copy_code_dir)
            self._backup_code(copy_code_dir)

    @staticmethod
    def _backup_code(dst_dir):
        """Copy the ``*.py`` files of the current working directory into ``dst_dir``.

        Replaces a former ``os.system("cp *.py ...")`` call: no shell is
        involved (paths with spaces/metacharacters are safe) and copy errors
        raise instead of being silently ignored.
        """
        import glob  # only needed for this one-off snapshot
        # glob, like the shell, skips dot-files; directories named *.py are skipped
        for src in sorted(glob.glob("*.py")):
            if os.path.isfile(src):
                shutil.copy(src, dst_dir)

    @staticmethod
    def _str2bool(value):
        """argparse ``type`` for boolean options that take an explicit value.

        ``type=bool`` cannot be used for this: ``bool("False")`` is True because
        any non-empty string is truthy. Accepts true/false, t/f, yes/no, y/n,
        1/0 (case-insensitive).

        :raises argparse.ArgumentTypeError: for any other spelling
        """
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))

    def parse(self):
        """Initialize the argument parser, register all hyperparameter groups and parse ``sys.argv``.

        :return: tuple ``(parser, args)`` with the parser and the parsed namespace
        """
        parser = argparse.ArgumentParser()

        # basic configuration
        self._add_basic_config_(parser)

        # model configuration
        self._add_network_config_(parser)

        # training or testing configuration
        self._add_training_config_(parser)

        args = parser.parse_args()
        return parser, args

    def _add_basic_config_(self, parser):
        """add general hyperparameters (paths, experiment name, devices)"""
        group = parser.add_argument_group('basic')
        group.add_argument('--proj_dir', type=str, default="./project_log/neuralARAP",
            help="path to project folder where models and logs will be saved")
        # default: name of the current working directory (portable across path separators)
        group.add_argument('--exp_name', type=str, default=os.path.basename(os.getcwd()), help="name of this experiment")
        group.add_argument('--tag', type=str, default="")
        # NOTE: the default is the int 0 (argparse does not apply `type` to non-string defaults)
        group.add_argument('-g', '--gpu_ids', type=str, default=0, help="gpu to use, e.g. 0  0,1,2. CPU not supported.")
        group.add_argument('--num_cpu', type=int, default=0)

    def _add_network_config_(self, parser):
        """add hyperparameters for network architecture"""
        group = parser.add_argument_group('network')
        group.add_argument('--network', type=str, default='siren', choices=['siren', 'grid'])
        group.add_argument('--num_hidden_layers', type=int, default=3)
        group.add_argument('--hidden_features', type=int, default=256)
        group.add_argument('--nonlinearity', type=str, default='sine')

    def _add_training_config_(self, parser):
        """add training / simulation hyperparameters"""
        group = parser.add_argument_group('training')
        group.add_argument('--continue', dest='cont', action='store_true', help="continue training from checkpoint")
        group.add_argument('--ckpt', type=str, default=None, required=False, help="desired checkpoint to restore")
        group.add_argument('--save_frequency', type=int, default=1000, help="save models every x steps")
        group.add_argument('--vis_frequency', type=int, default=500, help="visualize output every x iterations")
        # NOTE(review): help says "epochs" but the option name suggests iterations — confirm
        group.add_argument('--max_n_iters', type=int, default=2000, help='number of epochs to train per scale')
        group.add_argument('--lr', type=float, default=1e-4, help='learning rate (default: %(default)s)')
        group.add_argument('--early_stop', action='store_true', help="early_stopping")

        group.add_argument('--dim', type=int, default=2)
        group.add_argument('--dt', type=float, default=0.01)
        group.add_argument('-T', '--n_timesteps', type=int, default=100)
        group.add_argument('--stage', type=str, default=None, choices=['init', 'simulate'], required=True)
        group.add_argument('--sample_resolution', type=int, default=128)
        group.add_argument('--vis_resolution', type=int, default=50)

        group.add_argument('--energy', type=str, nargs='*',
                            default=['arap', 'kinematics', 'external', 'constraint'],
                            help='The energy to be used.')
        group.add_argument('--sample', type=str, nargs='*',
                            default=['random', 'uniform'],
                            help='The sampling strategy to be used during the training.')
        group.add_argument('-T_ext', '--external_force_timesteps', type=int, default=5)
        group.add_argument('-f_ext_x', '--external_force_x', type=float, default=0)
        group.add_argument('-f_ext_y', '--external_force_y', type=float, default=0)
        group.add_argument('-f_ext_z', '--external_force_z', type=float, default=0)
        group.add_argument('--ratio_constraint', type=float, default=1e3)
        group.add_argument('--ratio_volume', type=float, default=1e1)
        group.add_argument('--ratio_arap', type=float, default=1e0)
        group.add_argument('--ratio_collide', type=float, default=1e0)
        group.add_argument('--ratio_kinematics', type=float, default=1e0)
        # boolean options with an explicit value: `-W_pc False` must yield False
        group.add_argument('-W_pc', '--write_pointcloud', type=self._str2bool, default=False)
        group.add_argument('-save', '--save_model', type=self._str2bool, default=False)

        group.add_argument('-fix_right_x', '--constraint_right_offset_x', type=float, default=1e0)
        group.add_argument('-fix_right_y', '--constraint_right_offset_y', type=float, default=0)
        group.add_argument('-fix_right_z', '--constraint_right_offset_z', type=float, default=0)

        group.add_argument('--use_mesh', type=self._str2bool, default=False)
        group.add_argument('--mesh_path', type=str, default="./data/woody.obj", help="path to the mesh")

        group.add_argument('--plane_height', type=float, default=-2)

        group.add_argument('-collide_circle_x', '--collide_circle_x', type=float, default=0)
        group.add_argument('-collide_circle_y', '--collide_circle_y', type=float, default=-2e0)
        group.add_argument('-collide_circle_z', '--collide_circle_z', type=float, default=0)
        group.add_argument('-collide_circle_r', '--collide_circle_radius', type=float, default=1)