import argparse

parser = argparse.ArgumentParser(description='State Space Time Series')

# Time Series Task
## Forecasting
parser.add_argument('--lag', type=int, default=1)
parser.add_argument('--horizon', type=int, default=1)
parser.add_argument('--scale', type=float, default=1)
parser.add_argument('--task_norm', type=str, default='mean',
                    help='How to normalize input data. See `state-spaces/src/tasks/tasks.py`')
# Data setup
parser.add_argument('--nobs_per_ts', type=int, default=1000)
parser.add_argument('--n_ts', type=int, default=1)
## Loading
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--num_workers', type=int, default=4)

# Dataset - ARIMA
parser.add_argument('--dataset', type=str, default='arima')
parser.add_argument('--p', type=int, default=1)
parser.add_argument('--d', type=int, default=0)
parser.add_argument('--q', type=int, default=1)
parser.add_argument('--c', type=int, default=0)
## Seasonality -> Assume just month for now
parser.add_argument('--S', type=int, default=0)
parser.add_argument('--P', type=int, default=1)
parser.add_argument('--D', type=int, default=0)
parser.add_argument('--Q', type=int, default=1)
parser.add_argument('--C', type=int, default=0)

# Training hparams
parser.add_argument('--max_epochs', type=int, default=25)
# Optimizer
parser.add_argument('--optimizer', type=str, default='adamw')
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=0.0)
# Scheduler
parser.add_argument('--scheduler', type=str, default='timm_cosine')

# Model
parser.add_argument('--model', type=str, default='s4')
parser.add_argument('--d_model', type=int, default=64, help='Number of hippos / heads')
parser.add_argument('--n_layers', type=int, default=1, help='Number of layers in the entire model')
parser.add_argument('--model_dropout', type=float, default=0.0)
parser.add_argument('--model_norm', type=str, default='none')
parser.add_argument('--prenorm', type=int, default=0,
                    help='Whether to normalize input before layer forward pass')
parser.add_argument('--residual', type=str, default='R')

# Layer
parser.add_argument('--model_layer', type=str, default='s4_simple')
parser.add_argument('--d_state', type=int, default=64, help='Hidden state size')
parser.add_argument('--channels', type=int, default=1)
parser.add_argument('--layer_lr', type=float, default=0.01)
parser.add_argument('--layer_dropout', type=float, default=0.0)
parser.add_argument('--activation', type=str, default='gelu')
## Simple S4
parser.add_argument('--use_initial', action='store_true', default=False)
parser.add_argument('--learn_a', action='store_true', default=False)
parser.add_argument('--unconstrained_a', action='store_true', default=False)
parser.add_argument('--learn_theta', action='store_true', default=False)
parser.add_argument('--trap_rule', action='store_true', default=False)
parser.add_argument('--zero_order_hold', action='store_true', default=False)
parser.add_argument('--theta_scale', action='store_true', default=False)

# Saving
parser.add_argument('--log_dir', type=str, default='./logs')

# Misc.
parser.add_argument('--no_cuda', action='store_true', default=False)
parser.add_argument('--no_wandb', action='store_true', default=False)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--dataset_seed', type=int, default=0)

argparse_args = parser.parse_args()
