import torch

class Seq2SeqAttrs:
    """Hyper-parameter container for a seq2seq model.

    Every setting is read from ``model_kwargs`` with a default and coerced
    to ``int`` or ``bool``. Numeric strings such as ``"32"`` are therefore
    accepted, and so are boolean strings such as ``"false"``.

    Args:
        level_sizes: Sized collection with one entry per level. It is stored
            as-is, and its length becomes ``level_num``.
        **model_kwargs: Optional overrides for the attributes set below.

    Raises:
        ValueError: If an integer setting cannot be parsed by ``int`` or a
            boolean setting is an unrecognized string.
    """

    _TRUE_STRINGS = frozenset({'1', 'true', 't', 'yes', 'y', 'on'})
    _FALSE_STRINGS = frozenset({'0', 'false', 'f', 'no', 'n', 'off', ''})

    @staticmethod
    def _parse_bool(value):
        """Interpret *value* as a boolean flag.

        ``bool("False")`` is ``True`` in Python, so string values are matched
        against common true/false spellings instead of passing through
        ``bool`` directly. Non-string values keep the plain ``bool`` semantics.
        """
        if isinstance(value, str):
            lowered = value.strip().lower()
            if lowered in Seq2SeqAttrs._TRUE_STRINGS:
                return True
            if lowered in Seq2SeqAttrs._FALSE_STRINGS:
                return False
            raise ValueError(f'Cannot interpret {value!r} as a boolean')
        return bool(value)

    def __init__(self, level_sizes, **model_kwargs):
        # presumably the curriculum-learning decay constant — confirm with the training loop
        self.cl_decay_steps = int(model_kwargs.get('cl_decay_steps', 1000))
        self.layer_num = int(model_kwargs.get('layer_num', 2))
        self.enc_layer_num = int(model_kwargs.get('enc_layer_num', 1))
        self.embed_size = int(model_kwargs.get('embed_size', 32))
        self.input_dim = int(model_kwargs.get('input_dim', 1))
        self.output_dim = int(model_kwargs.get('output_dim', 1))
        self.seq_len = int(model_kwargs.get('seq_len', 12))
        self.location_dim = int(model_kwargs.get('location_dim', 2))
        self.time_dim = int(model_kwargs.get('time_dim', 4))
        self.horizon = int(model_kwargs.get('horizon', 12))
        # Defaults to twice the (already parsed) embed_size, so it must be
        # computed after embed_size above.
        self.ker_embed_size = int(model_kwargs.get('ker_embed_size', 2 * self.embed_size))
        # Flags go through _parse_bool so string values like "False" are not
        # mistakenly treated as True.
        self.cont_dec = self._parse_bool(model_kwargs.get('cont_dec', False))
        self.reconst_loss = self._parse_bool(model_kwargs.get('reconst_loss', False))
        self.dec_layer_num = int(model_kwargs.get('dec_layer_num', 2))
        # Stored unchanged; level_num is derived from its length.
        self.level_sizes = level_sizes
        self.level_num = len(level_sizes)