"""
Helpers for setting up experiments  
"""
import os
import copy
import torch

from os.path import join

def get_updated_config(args, base_config, exp_args):
    """Return a deep copy of ``base_config`` updated with values from ``exp_args``.

    Args:
        args: Unused; kept so existing call sites keep working.
        base_config: Config object exposing ``model`` (a mapping with
            ``update`` and a ``layer`` sub-mapping) and ``optimizer``.
            It is deep-copied and never modified.
        exp_args: Dict of experiment settings. NOTE: mutated in place --
            ``d_model`` is set to ``n_hippos``, and when only shift kernels
            are used (``num_shift_kernel > 0`` and ``num_diagonal_kernel == 0``)
            ``learn_a`` and ``learn_theta`` are forced to ``False``.

    Returns:
        The updated copy of ``base_config``.
    """
    model_keys = ('d_model', 'n_layers')
    layer_keys = ('n_hippos',
                  'd_state',
                  'linear',
                  'identity',
                  'feedforward',
                  'num_diagonal_kernel',
                  'num_shift_kernel',
                  # Diagonal
                  'learn_a',
                  'learn_theta',
                  'trap_rule',
                  'zero_order_hold',
                  'unconstrained_a',
                  'skip_connection',
                  'skip_connection_companion',
                  'skip_connection_companion_fixed')
    optimizer_keys = ('lr', 'weight_decay')

    # With shift kernels only (no diagonal kernels), switch off the
    # diagonal-kernel learning flags.
    if exp_args['num_shift_kernel'] > 0 and exp_args['num_diagonal_kernel'] == 0:
        exp_args['learn_a'] = False
        exp_args['learn_theta'] = False

    _config = copy.deepcopy(base_config)

    # Model width is tied to the number of HiPPO channels.
    exp_args['d_model'] = exp_args['n_hippos']
    _config.model.update({k: exp_args[k] for k in model_keys})
    _config.model.layer.update({k: exp_args[k] for k in layer_keys})
    _config.optimizer.update({k: exp_args[k] for k in optimizer_keys})
    return _config


def get_updated_config_from_argparse(args, base_config):
    """Return a deep copy of ``base_config`` updated from parsed CLI args.

    Args:
        args: ``argparse.Namespace`` (or any object usable with ``vars``)
            holding the model/layer/optimizer options listed below.
            NOTE: mutated in place -- ``n_hippos`` is set to ``d_model``,
            ``d_ffn`` is zeroed when ``feedforward`` is falsy, and
            ``learn_a``/``learn_theta`` are forced to ``False`` when only
            shift kernels are used.
        base_config: Config object exposing ``model`` (a mapping with
            ``update`` and a ``layer`` sub-mapping) and ``optimizer``.
            It is deep-copied and never modified.

    Returns:
        The updated copy of ``base_config``.
    """
    model_keys = ('d_model', 'n_layers')
    layer_keys = ('n_hippos',
                  'd_state',
                  'mimo',
                  'linear',
                  'identity',
                  'feedforward',
                  'd_ffn',
                  'num_diagonal_kernel',
                  'num_shift_kernel',
                  # Diagonal
                  'learn_a',
                  'learn_theta',
                  'trap_rule',
                  'zero_order_hold',
                  'unconstrained_a',
                  'skip_connection',
                  'skip_connection_companion',
                  'skip_connection_companion_fixed',
                  'discrete',
                  # Other
                  'replicate')
    optimizer_keys = ('lr', 'weight_decay')

    # With shift kernels only (no diagonal kernels), switch off the
    # diagonal-kernel learning flags.
    if args.num_shift_kernel > 0 and args.num_diagonal_kernel == 0:
        args.learn_a = False
        args.learn_theta = False

    # Zero the FFN width when the feedforward block is disabled. A truthiness
    # test (instead of `is False`) also handles integer flags such as 0.
    if not args.feedforward:
        args.d_ffn = 0

    _config = copy.deepcopy(base_config)
    # Number of HiPPO channels follows the model width.
    args.n_hippos = args.d_model
    arg_values = vars(args)
    _config.model.update({k: arg_values[k] for k in model_keys})
    _config.model.layer.update({k: arg_values[k] for k in layer_keys})
    _config.optimizer.update({k: arg_values[k] for k in optimizer_keys})
    return _config
    


def init_experiment(config, args, prefix='s4_simple-d=arima',
                    network_config=None):
    """Build the run name and checkpoint paths for an experiment.

    Mutates ``args`` in place, setting ``device``, ``experiment_name``,
    ``dataset_name``, ``best_train_metric``, ``best_val_metric``,
    ``checkpoint_path`` (narrowed to a per-dataset sub-directory),
    ``best_train_checkpoint_path`` and ``best_val_checkpoint_path``.
    Creates the per-dataset checkpoint directory if it does not exist.

    Args:
        config: Config exposing ``dataset``, ``model`` (with a ``layer``
            sub-mapping), ``loader``, ``optimizer`` mappings and a
            ``network_config`` attribute.
        args: Namespace of run options. Fields read include ``verbose``,
            ``dataset``, ``num_diagonal_kernel``, ``num_shift_kernel``,
            ``no_standardize``, ``encoder_decoder``, ``loss``, ``task_norm``,
            ``bash``, ``mask_horizon``, ``multihorizon``, ``memory_norm``,
            ``replicate``, ``seed``, ``checkpoint_path`` and, when a network
            config is used, the ``ground_truth_*``/``learn_*``/``c_bias``/
            ``chebyshev`` flags.
        prefix: Initial value of the experiment name.
        network_config: Overrides ``config.network_config`` when not None.

    Returns:
        The experiment name (also stored as ``args.experiment_name``).
    """
    args.device = (torch.device('cuda:0') if torch.cuda.is_available()
                   else torch.device('cpu'))
    args.experiment_name = prefix
    args.dataset_name = ''

    def _token(k, v, short_key=None):
        # Build the '-<abbrev>=<value>' name fragment for one config entry,
        # echoing it when verbose. Booleans are rendered as 0/1.
        _k = format_arg(k) if short_key is None else short_key
        if isinstance(v, bool):
            v = int(v)
        if args.verbose:
            print(' - ', k, f'({_k})', v)
        return f'-{_k}={v}'

    # --- Dataset args: recorded in dataset_name (used as the checkpoint
    # sub-directory), not in experiment_name.
    if args.verbose:
        print('\n', '-' * 5, 'Dataset Args', '-' * 5)
    excluded_dataset_keys = ('_name_', 'val_gap', 'test_gap', 'seed', 'seasonal',
                             'target', 'eval_stamp', 'eval_mask', 'timeenc')
    # For these datasets, keys are abbreviated to their first letter only.
    first_letter_datasets = ('etth', 'ettm', 'ecl', 'weather')
    for k, v in config.dataset.items():
        if k in excluded_dataset_keys:
            continue
        short_key = k[0] if args.dataset in first_letter_datasets else None
        token = _token(k, v, short_key)
        # The first fragment is prefixed with the dataset name itself.
        if args.dataset_name == '':
            token = f'{args.dataset}{token}'
        args.dataset_name += token

    args.dataset_name = args.dataset_name.replace(' ', '').replace(',', '_')

    # --- Model args
    if args.verbose:
        print('\n', '-' * 5, 'Model Args', '-' * 5)

    network_config = config.network_config if network_config is None else network_config
    if network_config is not None:
        # Identify the network by its '_name_' entry when it has one,
        # otherwise by the config value itself. Only lookup failures fall
        # back; interrupts and other errors propagate.
        try:
            network_name = network_config["_name_"]
        except (KeyError, TypeError, ValueError):
            network_name = network_config
        args.experiment_name += f'-mc={network_name}'
    else:
        excluded_model_keys = ('defaults', '_name_', 'prenorm',
                               'transposed', 'pool', 'layer',
                               'residual', 'dropout', 'norm')
        for k, v in config.model.items():
            if k not in excluded_model_keys:
                args.experiment_name += _token(k, v)

        if args.verbose:
            print('\n', '-' * 5, 'Model Layer Args', '-' * 5)
        excluded_layer_keys = ('_name_', 'channels', 'bidirectional',
                               'postact', 'initializer', 'weight_norm',
                               'dt_min', 'dt_max',
                               'n_hippos')  # covered by d_model
        # Diagonal-kernel options left out of the name for shift-only models.
        diagonal_only_keys = ('learn_theta', 'learn_a', 'theta_scale',
                              'zero_order_hold', 'unconstrained_a')
        for k, v in config.model.layer.items():
            if k in excluded_layer_keys:
                continue
            # Always echoed (when verbose), even if left out of the name.
            token = _token(k, v)
            shift_only = args.num_diagonal_kernel == 0 and args.num_shift_kernel > 0
            if not (shift_only and k in diagonal_only_keys):
                args.experiment_name += token

    # --- Dataloader args (only batch size is recorded)
    if args.verbose:
        print('\n', '-' * 5, 'Dataloader Args', '-' * 5)
    for k, v in config.loader.items():
        if k == 'batch_size':
            args.experiment_name += _token(k, v)

    # --- Optimizer args
    if args.verbose:
        print('\n', '-' * 5, 'Optimizer Args', '-' * 5)
    for k, v in config.optimizer.items():
        if k != '_name_':
            args.experiment_name += _token(k, v)

    args.experiment_name += f'-std={1 - int(args.no_standardize)}'
    args.experiment_name += f'-ed={int(args.encoder_decoder)}'

    # Shift-kernel B/C/P options and Chebyshev flag are only recorded when a
    # network config is in use.
    if network_config is not None:
        if args.num_shift_kernel != 0:
            args.experiment_name += f'-gtb={int(args.ground_truth_b)}'
            args.experiment_name += f'-lb={int(args.learn_b)}'
            args.experiment_name += f'-gtc={int(args.ground_truth_c)}'
            args.experiment_name += f'-lc={int(args.learn_c)}'
            args.experiment_name += f'-cb={int(args.c_bias)}'
            args.experiment_name += f'-gtp={int(args.ground_truth_p)}'
            args.experiment_name += f'-lp={int(args.learn_p)}'
        if args.chebyshev:
            args.experiment_name += f'-cheb={int(args.chebyshev)}'

    args.experiment_name += f'-loss={args.loss}'
    args.experiment_name += f'-tn={args.task_norm[:2]}'
    args.experiment_name += f'-bash={int(args.bash)}'

    if args.mask_horizon is True:
        args.experiment_name += f'-mask={args.mask_horizon}'

    if args.multihorizon != 0:
        args.experiment_name += f'-mh={"_".join(str(h) for h in args.multihorizon)}'

    if args.memory_norm != 0:
        args.experiment_name += '-mn=1'

    args.experiment_name += f'-r={args.replicate}'
    args.experiment_name += f'-s={args.seed}'
    args.best_train_metric = 1e10  # RMSE
    args.best_val_metric = 1e10  # RMSE

    # Dataset setup: checkpoints live in a per-dataset sub-directory.
    # exist_ok avoids the race between checking for and creating the directory.
    dataset_dir = join(args.checkpoint_path, args.dataset_name)
    os.makedirs(dataset_dir, exist_ok=True)
    args.checkpoint_path = dataset_dir

    args.best_train_checkpoint_path = join(
        args.checkpoint_path,
        f'btrain_cp-{args.experiment_name}.pth')
    args.best_val_checkpoint_path = join(
        args.checkpoint_path,
        f'bval_cp-{args.experiment_name}.pth')

    return args.experiment_name


def format_arg(arg):
    """Abbreviate a config key for use in experiment/dataset names.

    Underscore-separated keys become the first letter of each part
    (``'weight_decay'`` -> ``'wd'``). Keys without an underscore are
    truncated to their first three characters (``'seed'`` -> ``'see'``).
    Keys with an empty part (leading, trailing or doubled underscores,
    e.g. ``'_name_'``) cannot be abbreviated that way. They are printed to
    stdout and returned unchanged.
    """
    if '_' not in arg:
        return arg[:3]
    try:
        return ''.join(part[0] for part in arg.split('_'))
    except IndexError:
        # An empty part (from a leading/trailing/doubled underscore) has no
        # first character; fall back to the full key.
        print(arg)
        return arg
            
    
    