import argparse

parser = argparse.ArgumentParser(description='State Space Time Series')

# Time Series Task
## Forecasting
parser.add_argument('--lag', type=int, default=0)  # Loads from the dataset config yaml
parser.add_argument('--horizon', type=int, default=0)  # Loads from the dataset config yaml
parser.add_argument('--scale', type=float, default=1)
parser.add_argument('--inverse_data', type=int, default=None) 
parser.add_argument('--task_norm', type=str, default='none',
                    choices=['mean', 'last', 'first', 'none', 'ts_mean', 'ts_last'],
                    help='Transform input data to deal with batch-wise mean shifts')
parser.add_argument('--multihorizon', nargs='+', type=int, default=0)
parser.add_argument('--mask_horizon', default=False, action='store_true')

# Data setup
parser.add_argument('--nobs_per_ts', type=int, default=600)
parser.add_argument('--n_ts', type=int, default=1)
## Data loading
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--num_workers', type=int, default=4)

# Dataset - ARIMA
parser.add_argument('--dataset', type=str, default='arima')
parser.add_argument('--dataset_seed', type=int, default=4)
parser.add_argument('--features', type=str, default='S')
parser.add_argument('--no_standardize', action='store_true',
                    default=False)
parser.add_argument('--p', type=int, default=5)
parser.add_argument('--d', type=int, default=0)
parser.add_argument('--q', type=int, default=0)
parser.add_argument('--c', type=int, default=0)
parser.add_argument('--initial_x', type=float, default=0,
                    help='Starting value')

## Seasonality -> Assume just month for now
parser.add_argument('--S', type=int, default=0)
parser.add_argument('--P', type=int, default=1)
parser.add_argument('--D', type=int, default=0)
parser.add_argument('--Q', type=int, default=1)
parser.add_argument('--C', type=int, default=0)

# Training hparams
parser.add_argument('--max_epochs', type=int, default=25)
# Optimizer
parser.add_argument('--optimizer', type=str, default='adamw')
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=0.0)
# Scheduler
parser.add_argument('--scheduler', type=str, default='timm_cosine')

# Loss objective
parser.add_argument('--loss', type=str, default='rmse')  # currently does root of this

# Model
parser.add_argument('--model', type=str, default='spacetime')
# Network config
parser.add_argument('--network_config', type=str, default='')

parser.add_argument('--d_model', type=int, default=64, help='Number of hippos / heads')
parser.add_argument('--n_layers', type=int, default=1, help='Number of layers in the entire model')
parser.add_argument('--model_dropout', type=float, default=0.0)
parser.add_argument('--model_norm', type=str, default='none')
parser.add_argument('--prenorm', type=int, default=0,
                    help='Whether to normalize input before layer forward pass')
parser.add_argument('--residual', type=str, default='R')
parser.add_argument('--encoder_decoder', action='store_true', default=False)

# SSM Layers
parser.add_argument('--model_layer', type=str, default='spacetime')
parser.add_argument('--d_state', type=int, default=None, help='Hidden state size')
parser.add_argument('--big_d', action='store_true', default=False)

parser.add_argument('--channels', type=int, default=1)
parser.add_argument('--layer_lr', type=float, default=0.01)
parser.add_argument('--layer_dropout', type=float, default=0.0)
parser.add_argument('--activation', type=str, default='gelu')

# SSM Kernels
## MIMO
parser.add_argument('--mimo', type=int, default=1)
parser.add_argument('--memory_norm', type=int, default=0)

## Diagonal 
parser.add_argument('--num_diagonal_kernel', type=int, default=1)
parser.add_argument('--use_initial', action='store_true', default=False)
parser.add_argument('--learn_a', action='store_true', default=False)
parser.add_argument('--learn_theta', action='store_true', default=False)
parser.add_argument('--trap_rule', action='store_true', default=False)
parser.add_argument('--zero_order_hold', action='store_true', default=False)
parser.add_argument('--theta_scale', action='store_true', default=False)
parser.add_argument('--skip_connection', action='store_true', default=False)
parser.add_argument('--unconstrained_a', action='store_true', default=False)

### Discrete Diagonal 
parser.add_argument('--discrete', action='store_true', default=False)

## Shift + Companion
parser.add_argument('--num_shift_kernel', type=int, default=1)
parser.add_argument('--skip_connection_companion', action='store_true', default=False)
parser.add_argument('--skip_connection_companion_fixed', action='store_true', default=False)
parser.add_argument('--learn_b', action='store_true', default=False,
                    help='Only for learning b in shift matrix')
parser.add_argument('--learn_c', action='store_true', default=False,
                    help='Learn c or not')
parser.add_argument('--c_bias', action='store_true', default=False,
                    help='Learn with bias in c')
parser.add_argument('--ground_truth_c', action='store_true', default=False)
parser.add_argument('--ground_truth_b', action='store_true', default=False)

### Companion-specific
parser.add_argument('--learn_p', action='store_true', default=False,
                    help='Only for learning p in companion matrix')
parser.add_argument('--ground_truth_p', action='store_true', default=False)

### Chebyshev representation for Companion
parser.add_argument('--chebyshev', action='store_true', default=False)

# Layer after SSM kernel
parser.add_argument('--linear', action='store_true', default=False)
parser.add_argument('--identity', action='store_true', default=False)

# FFN after SSM layer
parser.add_argument('--feedforward', action='store_true', default=False)
parser.add_argument('--d_ffn', type=int, default=1024)


# Saving
parser.add_argument('--log_dir', type=str, default='./logs')
parser.add_argument('--checkpoint_path', type=str, default='./checkpoints')

# Misc.
parser.add_argument('--no_cuda', action='store_true', default=False)
parser.add_argument('--no_wandb', action='store_true', default=False)
parser.add_argument('--verbose', action='store_true', default=False)
parser.add_argument('--replicate', type=int, default=0)
parser.add_argument('--seed', type=int, default=0)

parser.add_argument('--bash', action='store_true', default=False,
                    help='Indicate whether run was executed with bash')


argparse_args = parser.parse_args()
