
class optimizer_config:
    """Root configuration object holding optimizer/training defaults.

    Each setting is stored as an ordinary instance attribute, so callers and
    subclasses may overwrite any of them after construction. How the values
    are interpreted is decided by the code that consumes the config.
    """

    def __init__(self):
        # optimizer config
        # The mapping is applied in insertion order, so attributes are created
        # in exactly the sequence listed here.
        defaults = {
            "max_grad_norm": 5,
            "batch_size": 32,
            "train_batch_size": 32,
            "dev_batch_size": 32,
            "bucket_size_factor": 1,
            "DataParallel": False,
            "num_workers": 6,
            "weight_decay": 0,
            "lr": 0.001,
            "epochs": 100,
            "early_stop_patience": 4,
            "scheduler": "ReduceLROnPlateau",
            "scheduler_patience": 2,
            "scheduler_reduce_factor": 0.5,
            "optimizer": "Adam",
            "save_by": "F1",
            "metric_direction": 1,
            "different_betas": False,
            "chunk_size": -1,
        }
        for name, value in defaults.items():
            setattr(self, name, value)


class base_config(optimizer_config):
    """Settings shared by every encoder config, layered on optimizer_config.

    The parent initialiser runs first; the attributes below are then added
    on top of the inherited optimizer settings.
    """

    def __init__(self):
        super().__init__()
        shared_defaults = {
            "word_embd_freeze": False,
            "position_max_len": 5000,
            "embd_dim": 100,
            "hidden_size": 300,
            "global_state_return": False,
            "global_state_only": False,
            "display_metric": "F1",
            "initial_transform": False,
        }
        for name, value in shared_defaults.items():
            setattr(self, name, value)


class LSTM_config(base_config):
    """Configuration whose encoder_type is "LSTM"."""

    def __init__(self):
        super().__init__()
        lstm_defaults = {
            "in_dropout": 0.1,
            "out_dropout": 0.1,
            # Read after super().__init__() so it mirrors the inherited
            # embedding dimension at construction time.
            "input_size": self.embd_dim,
            "encoder_type": "LSTM",
            "model_name": "(LSTM)",
        }
        for name, value in lstm_defaults.items():
            setattr(self, name, value)

class Transformer_config(base_config):
    """Configuration whose encoder_type is "TransformerEncoder".

    Also serves as the parent of the other Transformer-style configs in this
    file, which inherit these values and change only what differs.
    """

    def __init__(self):
        super().__init__()
        transformer_defaults = {
            "encode_layers": 6,
            "in_dropout": 0.1,
            "dropout": 0.1,
            "attn_dropout": 0.0,
            "out_dropout": 0.2,
            "heads": 8,
            "head_dim": 64,
            "attention_type": "Multiheaded_Attention",
            "encoder_type": "TransformerEncoder",
            "model_name": "(Transformer)",
        }
        for name, value in transformer_defaults.items():
            setattr(self, name, value)

class TreeTransformer_config(Transformer_config):
    """Transformer_config with only the encoder type and model name replaced."""

    def __init__(self):
        super().__init__()
        # All other settings are inherited unchanged from Transformer_config.
        self.encoder_type, self.model_name = (
            "TreeTransformerEncoder",
            "(TreeTransformer)",
        )

class UniversalTransformer_config(Transformer_config):
    """Configuration whose encoder_type is "UniversalTransformerEncoder".

    Several values repeat what Transformer_config already sets; they are
    assigned again here so this config pins them explicitly.
    """

    def __init__(self):
        super().__init__()
        universal_defaults = {
            # Explicitly re-pinned (same values as the parent config).
            "in_dropout": 0.1,
            "dropout": 0.1,
            "out_dropout": 0.2,
            "heads": 8,
            "head_dim": 64,
            # Settings introduced by this config.
            "penalty_gamma": 0.001,
            "upperbound_style": "fixed",
            "upperbound": 10,
            "threshold": 0.99,
            # Identity of the attention/encoder implementation.
            "attention_type": "Multiheaded_Attention",
            "encoder_type": "UniversalTransformerEncoder",
            "model_name": "(Universal Transformer)",
        }
        for name, value in universal_defaults.items():
            setattr(self, name, value)

class Structformer_config(Transformer_config):
    """Transformer_config with only the encoder type and model name replaced."""

    def __init__(self):
        super().__init__()
        # All other settings are inherited unchanged from Transformer_config.
        self.encoder_type, self.model_name = "StructFormer", "(Structformer)"

class AnotherTreeTransformer_config(base_config):
    """Configuration whose encoder_type is "AnotherTreeTransformer".

    Derives directly from base_config, so it sets its own dropout and
    attention-head values rather than inheriting them from Transformer_config.
    """

    def __init__(self):
        super().__init__()
        att_defaults = {
            "max_depth": 10,
            "scorer_window_size": 5,
            "stop_threshold": 0.2,
            "in_dropout": 0.1,
            "dropout": 0.1,
            "out_dropout": 0.2,
            "heads": 8,
            "head_dim": 64,
            "encoder_type": "AnotherTreeTransformer",
            "model_name": "(ATT)",
        }
        for name, value in att_defaults.items():
            setattr(self, name, value)



class OrderedMemory_config(base_config):
    """Configuration whose encoder_type is "OrderedMemory".

    In addition to its own settings, this config overwrites several inherited
    values: hidden_size, different_betas, optimizer, weight_decay and
    max_grad_norm.
    """

    def __init__(self):
        super().__init__()
        ordered_memory_defaults = {
            "batch_pair": True,
            "dropout": 0.2,
            "output_last": False,
            "left_padded": False,
            "memory_dropout": 0.2,
            "in_dropout": 0.1,
            "out_dropout": 0.3,
            "memory_slots": 12,
            "double_slots_during_val": True,
            "hidden_size": 200,  # replaces the inherited hidden size
            "bidirection": False,
            "different_betas": True,  # replaces the inherited flag
            "encoder_type": "OrderedMemory",
            "model_name": "(OrderedMemory)",
            # Optimizer-level settings re-specified for this config.
            "optimizer": "Adam",
            "weight_decay": 1.2e-6,
            "max_grad_norm": 1,
        }
        for name, value in ordered_memory_defaults.items():
            setattr(self, name, value)

class CRvNNplus_config(base_config):
    """Configuration whose encoder_type is "CRvNNplus"."""

    def __init__(self):
        super().__init__()
        crvnn_defaults = {
            "max_depth": 20,
            "scorer_window_size": 5,
            "stop_threshold": 0.2,
            "in_dropout": 0.1,
            "dropout": 0.1,
            "out_dropout": 0.3,
            # Learning rate is set explicitly for this config.
            "lr": 0.001,
            "encoder_type": "CRvNNplus",
            "model_name": "(CRvNNplus)",
        }
        for name, value in crvnn_defaults.items():
            setattr(self, name, value)


class RECONzero_config(CRvNNplus_config):
    """CRvNNplus_config variant whose encoder_type is "RECONzero"."""

    def __init__(self):
        super().__init__()
        # Everything else is inherited from CRvNNplus_config; the training
        # batch size is set explicitly for this config.
        self.train_batch_size, self.encoder_type, self.model_name = (
            32,
            "RECONzero",
            "(RECONzero)",
        )



