"""
Config initialization
"""
from os.path import join
from omegaconf import OmegaConf  
from utils.config import print_config

from models.spacetime.network import process_network_config


def initialize_configs(args, root_dir='./configs/', verbose=False):
    """
    Build the full experiment config from YAML files plus command-line args.

    Loads the dataset, model, model-layer, optimizer, scheduler, trainer and
    loader YAMLs from `root_dir`, optionally loads a network config, and then
    applies per-section overrides taken from `args`. Every override step is
    best-effort: a failure is printed and the remaining steps still run.

    Args:
        args: Parsed arguments (presumably an argparse.Namespace — confirm).
            Mutated in place: `dataset` is normalized, `variant` is set (and
            `nval` for Monash datasets), and `d_state` is set to `lag` when
            `big_d` is truthy.
        root_dir: Directory containing the config YAML subfolders.
        verbose: If True, print the final config.

    Returns:
        The assembled OmegaConf config. `config.network_config` is None when
        no network config was requested or when it failed to load or process.
    """
    _parse_dataset_name(args)

    if args.big_d:
        args.d_state = args.lag

    config = _load_base_configs(args, root_dir)
    _attach_network_config(config, args, root_dir)

    # Dataset-specific overrides: at most one branch applies.
    informer_datasets = ('ett', 'ecl', 'weather', 'ili', 'exchange', 'traffic')
    if args.dataset == 'arima':
        _apply_overwrite(_overwrite_arima_config, config, args,
                         'ARIMA dataset config exception:')
    elif any(name in args.dataset for name in informer_datasets):
        _apply_overwrite(_overwrite_informer_config, config, args,
                         f'{args.dataset} dataset config exception:')
    elif 'monash' in args.dataset:
        _apply_overwrite(_overwrite_monash_config, config, args,
                         f'Monash dataset ({args.dataset}) config exception:')

    # Per-section overrides, applied in a fixed order. The tuple is built at
    # call time, so the helpers may be defined later in the module.
    section_overwrites = (
        (_overwrite_model_config, 'Model config exception:'),
        (_overwrite_layer_config, 'Model layer config exception:'),
        (_overwrite_optimizer_config, 'Optimizer config exception:'),
        (_overwrite_scheduler_config, 'Scheduler config exception:'),
        (_overwrite_trainer_config, 'Trainer config exception:'),
        (_overwrite_loader_config, 'Loader config exception:'),
    )
    for overwrite_fn, error_prefix in section_overwrites:
        _apply_overwrite(overwrite_fn, config, args, error_prefix)

    if verbose:
        print_config(config)
    return config


def _parse_dataset_name(args):
    """Split `args.dataset` into a base name plus `args.variant` (in place)."""
    if 'ett' in args.dataset:  # etth1, etth2, ettm1, ettm2
        # The trailing digit is the variant key, e.g. 'etth1' -> ('etth', 1).
        args.variant = int(args.dataset[-1])
        args.dataset = args.dataset[:-1]
    elif 'monash' in args.dataset:
        # e.g. 'monash_solar_weekly-nval_2' -> variant 'solar_weekly', nval 2
        args.variant = args.dataset.split('monash_')[-1].split('-')[0]
        if 'nval' in args.dataset:
            args.nval = int(args.dataset.split('nval_')[-1].split('-')[0])
        else:
            args.nval = None
        args.dataset = 'monash'
    else:
        args.variant = 0


def _load_base_configs(args, root_dir, trainer_name='default', loader_name='torch'):
    """Create a config holding the per-section YAMLs selected by `args`."""
    def load(relpath):
        return OmegaConf.load(join(root_dir, relpath))

    config = OmegaConf.create()
    config.dataset = load(f'dataset/{args.dataset}.yaml')
    config.model = load(f'model/{args.model}.yaml')
    config.model.layer = load(f'model/layer/{args.model_layer}.yaml')
    config.optimizer = load(f'optimizer/{args.optimizer}.yaml')
    config.scheduler = load(f'scheduler/{args.scheduler}.yaml')
    config.trainer = load(f'trainer/{trainer_name}.yaml')
    config.loader = load(f'loader/{loader_name}.yaml')
    return config


def _attach_network_config(config, args, root_dir):
    """Load, process and attach `config.network_config`; use None if absent or on failure."""
    try:
        if args.network_config != '':
            # Keep the original order: attach the raw YAML to `config` first,
            # then process the attached node and store the result back.
            config.network_config = OmegaConf.load(
                join(root_dir, f'model/network_configs/{args.network_config}.yaml'))
            config.network_config = process_network_config(config.network_config, args)
            print_config(config.network_config)
        else:
            config.network_config = None
    except Exception as e:  # best-effort: fall back to "no network config"
        print('Network config exception:', e)
        config.network_config = None


def _apply_overwrite(overwrite_fn, config, args, error_prefix):
    """Run one override hook; print (rather than raise) any failure."""
    try:
        overwrite_fn(config, args)
    except Exception as e:  # best-effort: one bad override must not abort setup
        print(error_prefix, e)



def _overwrite_monash_config(config, args):
    """
    Apply Monash-specific dataset overrides.

    `args.variant` (the parsed Monash series name) is stored as
    `config.dataset.dataset_name`. The validation fraction is forced to 0.0
    and processed-data saving is turned off. `config.dataset.nval` is written
    only when `args.nval` is not None.
    """
    dataset_cfg = config.dataset
    dataset_cfg.dataset_name = args.variant
    # Hard-coded defaults (the previously commented alternatives were 0.1 / True)
    dataset_cfg.val_frac = 0.0
    dataset_cfg.save_processed = False

    nval = args.nval
    if nval is None:
        return
    dataset_cfg.nval = nval
            

def _overwrite_informer_config(config, args):
    """
    Apply overrides for the Informer-style dataset configs.

    Sets the dataset variant, the window `size` (only when both `lag` and
    `horizon` are non-zero), `scale`, `inverse` (only when `inverse_data` is
    given) and `features`.
    """
    dataset_cfg = config.dataset
    dataset_cfg.variant = args.variant

    if args.lag != 0 and args.horizon != 0:
        # NOTE(review): the guard tests `horizon`, but the window is built
        # from `evaluate_horizon` — confirm this is intended.
        lag = args.lag
        eval_horizon = args.evaluate_horizon
        dataset_cfg['size'] = [lag, eval_horizon, eval_horizon]

    dataset_cfg['scale'] = bool(args.scale)

    inverse = args.inverse_data
    if inverse is not None:
        dataset_cfg['inverse'] = bool(inverse)

    dataset_cfg['features'] = args.features

            
def _overwrite_arima_config(config, args):
    """
    Copy the synthetic ARIMA generation settings from args into config.dataset.

    When `args.S` > 0, a seasonal entry is built under key 'M' from the
    integer-cast `P`, `D`, `Q`, `C` args. Otherwise `seasonal` is None.
    The non-seasonal orders, `initial_x` and `dataset_seed` are integer-cast.
    """
    seasonal = None
    if int(args.S) > 0:
        # Only a monthly ('M') seasonal entry is supported for now.
        seasonal = OmegaConf.create({'M': {}})
        for key in ('P', 'D', 'Q', 'C'):
            seasonal.M[key] = int(getattr(args, key))

    dataset_cfg = config.dataset
    for key in ('p', 'd', 'q', 'c', 'initial_x', 'dataset_seed'):
        dataset_cfg[key] = int(getattr(args, key))

    dataset_cfg.seed = args.dataset_seed
    dataset_cfg.seasonal = seasonal
    dataset_cfg.nobs_per_ts = args.nobs_per_ts
    dataset_cfg.horizon = args.horizon
    dataset_cfg.lag = args.lag
    dataset_cfg.scale = args.scale
    
    
def _overwrite_model_config(config, args):
    """Copy the top-level model hyperparameters from args into config.model."""
    model_cfg = config.model
    # (config key, args attribute), applied in this order
    key_map = (
        ('prenorm', 'prenorm'),
        ('n_layers', 'n_layers'),
        ('d_model', 'd_model'),
        ('residual', 'residual'),
        ('norm', 'model_norm'),
        ('dropout', 'model_dropout'),
    )
    for cfg_key, arg_name in key_map:
        setattr(model_cfg, cfg_key, getattr(args, arg_name))
    
    
def _overwrite_layer_config(config, args):
    """
    Copy the per-layer hyperparameters from args into config.model.layer.

    Covers general layer settings, the diagonal-kernel and shift-kernel
    options, the `replicate` debugging flag, and the horizon/lag lengths.
    """
    layer_cfg = config.model.layer
    # (config key, args attribute), applied in this order
    key_map = (
        ('d_state', 'd_state'),
        ('channels', 'channels'),
        ('lr', 'layer_lr'),
        ('dropout', 'layer_dropout'),
        ('activation', 'activation'),
        # Diagonal kernel
        ('num_diagonal_kernel', 'num_diagonal_kernel'),
        ('use_initial', 'use_initial'),
        ('learn_a', 'learn_a'),
        ('learn_theta', 'learn_theta'),
        ('trap_rule', 'trap_rule'),
        ('zero_order_hold', 'zero_order_hold'),
        ('theta_scale', 'theta_scale'),
        # Shift kernel
        ('num_shift_kernel', 'num_shift_kernel'),
        ('skip_connection_companion', 'skip_connection_companion'),
        ('skip_connection_companion_fixed', 'skip_connection_companion_fixed'),
        # Testing / debugging replicate
        ('replicate', 'replicate'),
        # Horizon + lag
        ('horizon', 'horizon'),
        ('lag', 'lag'),
    )
    for cfg_key, arg_name in key_map:
        setattr(layer_cfg, cfg_key, getattr(args, arg_name))
    
    
def _overwrite_optimizer_config(config, args):
    """
    Copy optimizer settings from args into config.optimizer.

    `lr` and `weight_decay` are always copied. `momentum` is copied only when
    `args` defines it; a missing attribute is skipped silently.

    Raises:
        AttributeError: if `args` lacks `lr` or `weight_decay`.
    """
    optimizer_cfg = config.optimizer
    optimizer_cfg.lr = args.lr
    optimizer_cfg.weight_decay = args.weight_decay

    # `momentum` is optional on args. Only tolerate its absence: the previous
    # bare `except:` also swallowed KeyboardInterrupt/SystemExit and any
    # unrelated error raised while setting the value.
    try:
        momentum = args.momentum
    except AttributeError:
        return
    optimizer_cfg.momentum = momentum
    
def _overwrite_scheduler_config(config, args):
    """
    Scheduler override hook; currently a no-op.

    `config.scheduler` is left exactly as loaded from YAML. The
    `(config, args)` parameters keep the signature uniform with the other
    `_overwrite_*` helpers.
    """
    return None


def _overwrite_trainer_config(config, args):
    """Set `config.trainer.max_epochs` from `args.max_epochs`."""
    max_epochs = args.max_epochs
    trainer_cfg = config.trainer
    trainer_cfg.max_epochs = max_epochs


def _overwrite_loader_config(config, args):
    """Copy the loader batch size and worker count from args into config.loader."""
    loader_cfg = config.loader
    for key in ('batch_size', 'num_workers'):
        setattr(loader_cfg, key, getattr(args, key))
