import argparse
import glob
import json
import os
import shutil

from utils import ensure_dirs


class Config(object):
    """Base class of Config: collects hyperparameters from the command line.

    On construction the command-line arguments are parsed (the accepted flags
    depend on ``phase``), every parsed argument is set as an attribute of the
    instance, and the experiment directories
    ``<proj_dir>/<exp_name>/{log,model,results}`` are created.

    * ``phase == "train"``: the ``*.py`` files of the current working
      directory are copied into ``<exp_dir>/backup`` and the parsed arguments
      are written to ``<exp_dir>/config.json``.
    * any other phase: ``<exp_dir>/config.json`` from a previous training run
      is loaded, and each saved value becomes an attribute unless the testing
      command line already provided it.
    """

    def __init__(self, phase='train'):
        """Parse arguments and prepare the experiment directory.

        Args:
            phase: ``"train"`` for training; any other value loads the saved
                training configuration (testing).

        Raises:
            FileNotFoundError: if not training and the experiment has no saved
                ``config.json``.
        """
        self.is_train = phase == "train"

        # init hyperparameters and parse from command-line
        parser, args = self.parse()

        # set as attributes
        print("----Experiment Configuration-----")
        for k, v in vars(args).items():
            print(f"{k:20}", v)
            setattr(self, k, v)

        # experiment paths
        self.exp_dir = os.path.join(self.proj_dir, self.exp_name)
        self.log_dir = os.path.join(self.exp_dir, 'log')
        self.model_dir = os.path.join(self.exp_dir, 'model')
        self.results_dir = os.path.join(self.exp_dir, 'results')
        config_path = os.path.join(self.exp_dir, 'config.json')

        # Both checks must happen BEFORE ensure_dirs(): afterwards exp_dir
        # always exists, which made the old existence checks meaningless.
        exp_existed = os.path.exists(self.exp_dir)
        if not self.is_train and not os.path.exists(config_path):
            # fail early instead of creating empty dirs for a missing experiment
            raise FileNotFoundError(
                f"No saved config found at {config_path}; "
                f"has experiment '{self.exp_name}' been trained?")
        ensure_dirs([self.log_dir, self.model_dir, self.results_dir])

        # GPU usage
        if args.gpu_ids is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_ids)

        # load saved config if not training
        if not self.is_train:
            self._load_saved_config_(config_path)
            return

        # NOTE(review): -1 is the argparse default of --ckpt and presumably
        # means "no checkpoint to restore" (None kept for safety) — confirm
        # against the trainer.
        no_ckpt = args.ckpt is None or str(args.ckpt) == "-1"
        if no_ckpt and exp_existed:
            print('Experiment log/model already exists. Overwrite.')

        # save this configuration for backup
        self._save_backup_(args, config_path)

    def _load_saved_config_(self, config_path):
        """Set attributes from a saved ``config.json``.

        Values already present on the instance (testing command-line
        arguments and experiment paths) take precedence over saved ones.
        """
        print(f"Load saved config from {config_path}")
        with open(config_path, 'r', encoding='utf-8') as f:
            saved_args = json.load(f)
        for k, v in saved_args.items():
            if not hasattr(self, k):
                setattr(self, k, v)

    def _save_backup_(self, args, config_path):
        """Copy ``*.py`` from the cwd into ``<exp_dir>/backup`` and dump ``args`` as JSON."""
        backup_dir = os.path.join(self.exp_dir, "backup")
        ensure_dirs([backup_dir])
        # glob, like the shell, skips dot-files; no subprocess/shell involved
        for src in glob.glob("*.py"):
            shutil.copy(src, backup_dir)
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(vars(args), f, indent=2)

    def parse(self):
        """Initialize the argument parser, define default hyperparameters and
        collect them from command-line arguments.

        Returns:
            ``(parser, args)`` — the ``argparse.ArgumentParser`` and the parsed
            namespace. Training and testing phases register different flags.
        """
        parser = argparse.ArgumentParser()

        # basic configuration
        self._add_basic_config_(parser)

        if self.is_train:
            # model configuration
            self._add_network_config_(parser)

            # training configuration
            self._add_training_config_(parser)
        else:
            self._add_testing_config_(parser)

        args = parser.parse_args()
        return parser, args

    def _add_basic_config_(self, parser):
        """Add general hyperparameters shared by training and testing."""
        group = parser.add_argument_group('basic')
        group.add_argument('--proj_dir', type=str, default="checkpoints",
            help="path to project folder where models and logs will be saved")
        # basename is portable, unlike splitting on '/'
        group.add_argument('--exp_name', type=str, default=os.path.basename(os.getcwd()),
            help="name of this experiment")
        # string default: argparse does not apply `type` to non-string
        # defaults, so an int 0 would give gpu_ids an inconsistent type
        group.add_argument('-g', '--gpu_ids', type=str, default="0",
            help="gpu to use, e.g. 0  0,1,2. CPU not supported.")

    def _add_network_config_(self, parser):
        """Add hyperparameters for the network architecture."""
        group = parser.add_argument_group('network')
        group.add_argument('--network', type=str, default='siren', choices=['siren', 'grid'])
        group.add_argument('--sdim', type=int, default=1, help='spatial dimension')
        group.add_argument('--num_hidden_layers', type=int, default=3)
        group.add_argument('--hidden_features', type=int, default=256)
        group.add_argument('--nonlinearity', type=str, default='sine')

    def _add_training_config_(self, parser):
        """Add training configuration."""
        group = parser.add_argument_group('training')
        group.add_argument('--ckpt', type=str, default=-1, required=False, help="desired checkpoint to restore")
        group.add_argument('--vis_frequency', type=int, default=500, help="visualize output every x iterations")
        group.add_argument('--max_n_iters', type=int, default=2000, help='number of iterations to train every time step')
        group.add_argument('--lr', type=float, default=1e-4, help='initial learning rate')
        group.add_argument('--early_stop', action='store_true', help="early_stopping")

        group.add_argument('--dt', type=float, default=0.05, help='time step size')
        group.add_argument('-T', '--n_timesteps', type=int, default=30, help='number of time steps')
        group.add_argument('-O', '--offset', type=float, default=0, help='offset center for initial condition')
        group.add_argument('-L', '--length', type=float, default=1.0, help='half of field length')
        group.add_argument('--vel', type=float, default=0.25, help='constant velocity value')
        group.add_argument('-sr', '--sample_resolution', type=int, default=128, help='number of samples per iterations')
        group.add_argument('-vr', '--vis_resolution', type=int, default=500, help='number of samples used for visualization')
        group.add_argument('--fps', type=int, default=10)

        group.add_argument('--src', type=str, default=None, help='which example to use', required=True)
        group.add_argument('--time_integrator', type=str, default='midpoint', choices=['explicit', 'implicit', 'midpoint'])
        group.add_argument('--boundary_cond', type=str, default='zero', choices=['zero', 'none'])

        group.add_argument('--save_h5', action='store_true', help="save grid values as h5 file")

    def _add_testing_config_(self, parser):
        """Add testing configuration."""
        group = parser.add_argument_group('testing')
        group.add_argument('-vr', '--vis_resolution', type=int, default=32)
        group.add_argument('--fps', type=int, default=10)
