class optimizer_config:
    """Training / optimisation hyperparameters.

    This is the root of the config hierarchy: ``base_config`` extends it and
    every encoder-specific config extends ``base_config``.
    """

    def __init__(self):
        # optimizer config
        self.max_grad_norm = 1
        # the generic, train and dev batch sizes all share one value
        self.batch_size = self.train_batch_size = self.dev_batch_size = 128
        self.bucket_size_factor = 5
        self.DataParallel = False
        self.num_workers = 6

        # optimiser / learning-rate schedule
        self.weight_decay = 1e-2
        self.lr = 1e-3
        self.epochs = 100
        self.early_stop_patience = 4
        self.scheduler = "ReduceLROnPlateau"
        self.scheduler_patience = 2
        self.scheduler_reduce_factor = 0.5
        self.optimizer = "Ranger"

        # checkpoint selection / reporting
        self.save_by = "accuracy"
        self.metric_direction = 1  # NOTE(review): sign convention lives in the trainer; confirm there
        self.different_betas = False
        self.chunk_size = -1
        self.display_metric = "accuracy"


class base_config(optimizer_config):
    """Model-level defaults shared by every encoder config shown here.

    Adds embedding/width settings and a set of feature switches on top of the
    optimiser settings inherited from ``optimizer_config``.
    """

    def __init__(self):
        super().__init__()
        # embedding / input handling
        self.word_embd_freeze = False
        self.initial_transform = False
        self.batch_pair = True
        self.parse_trees = False

        # widths: embedding, input and hidden all share one value
        self.embd_dim = self.input_size = self.hidden_size = 200
        self.rao_k = 10
        self.classifier_hidden_size = 200

        # feature switches, off by default; subclasses turn individual ones on
        self.rao = False
        self.stochastic = False
        self.test_time_stochastic = False
        self.treedrop = False
        # these two default to True
        self.global_state_only = self.global_state_return = True
        self.gumbel_diff = False





class GumbelTreeLSTM_config(base_config):
    """Config for the ``"GumbelTreeLSTM"`` encoder."""

    def __init__(self):
        super().__init__()
        # input, inner and output dropout share one rate
        self.in_dropout = self.dropout = self.out_dropout = 0.1
        self.conv_decision = False
        self.encoder_type = "GumbelTreeLSTM"
        self.model_name = "(GumbelTreeLSTM)"


class BiCell_config(base_config):
    """Config for the ``"BiCell"`` encoder with ``bidirectional=True``.

    Several other configs here (BalancedTreeCell, CRvNN, GumbelTreeCell, ...)
    extend this one to inherit its dropout settings.
    """

    def __init__(self):
        super().__init__()
        self.in_dropout = self.dropout = self.out_dropout = 0.1
        self.bidirectional = True
        self.encoder_type = "BiCell"
        self.model_name = "(BiCell)"

class RCell_config(base_config):
    """Unidirectional variant: reuses the ``"BiCell"`` encoder with
    ``bidirectional=False`` under the model name ``"(RCell)"``."""

    def __init__(self):
        super().__init__()
        self.in_dropout = self.dropout = self.out_dropout = 0.1
        self.bidirectional = False
        self.encoder_type = "BiCell"  # same encoder as BiCell_config, by design
        self.model_name = "(RCell)"

class CYKCell_config(base_config):
    """Config for the ``"CYKCell"`` encoder."""

    def __init__(self):
        super().__init__()
        self.in_dropout = self.dropout = self.out_dropout = 0.1
        self.encoder_type = "CYKCell"
        self.model_name = "(CYKCell)"


class BSRPCell_config(base_config):
    """Config for the ``"BSRPCell"`` encoder with a beam of 5."""

    def __init__(self):
        super().__init__()
        self.in_dropout = self.dropout = self.out_dropout = 0.1
        self.beam_size = 5
        self.encoder_type = "BSRPCell"
        self.model_name = "(BSRPCell)"

class BigBSRPCell_config(base_config):
    """Same ``"BSRPCell"`` encoder as ``BSRPCell_config`` but with a beam of 8."""

    def __init__(self):
        super().__init__()
        self.in_dropout = self.dropout = self.out_dropout = 0.1
        self.beam_size = 8
        self.encoder_type = "BSRPCell"
        self.model_name = "(BigBSRPCell)"

class BalancedTreeCell_config(BiCell_config):
    """``"BalancedTreeCell"`` encoder; dropout/bidirectional come from BiCell_config."""

    def __init__(self):
        super().__init__()
        self.encoder_type, self.model_name = "BalancedTreeCell", "(BalancedTreeCell)"

class CRvNN_config(BiCell_config):
    """``"CRvNN"`` encoder; every other setting is inherited from BiCell_config."""

    def __init__(self):
        super().__init__()
        self.encoder_type, self.model_name = "CRvNN", "(CRvNN)"

class OrderedMemory_config(BiCell_config):
    """Config for the ``"OrderedMemory"`` encoder.

    Re-states several inherited values explicitly (batch_pair, dropouts,
    hidden_size) and adds memory-specific settings.
    """

    def __init__(self):
        super().__init__()
        self.batch_pair = True
        # dropout rates, all 0.1
        self.dropout = self.memory_dropout = 0.1
        self.in_dropout = self.out_dropout = 0.1
        self.memory_slots = 12
        self.hidden_size = 200
        self.encoder_type = "OrderedMemory"
        self.model_name = "(ordered_memory)"


class GumbelTreeCell_config(BiCell_config):
    """``"GumbelTreeCell"`` encoder; base class of the beam-search configs below."""

    def __init__(self):
        super().__init__()
        self.conv_decision = False
        self.encoder_type, self.model_name = "GumbelTreeCell", "(GumbelTreeCell)"

class NDR_config(base_config):
    """Config for the ``"ndr_geometric_stack"`` encoder (model name ``"(NDR)"``).

    NOTE(review): unlike most sibling configs this one sets no in_dropout /
    dropout / out_dropout; confirm the encoder does not read them.
    """

    def __init__(self):
        super().__init__()
        self.hidden_size = 200
        self.batch_pair = False
        # the same depth limit is used for training and testing
        self.train_max_depth = self.test_max_depth = 15
        # the feed-forward width is derived from hidden_size (4x)
        self.ff_dim = self.hidden_size * 4
        self.heads = 8
        self.encoder_type = "ndr_geometric_stack"
        self.model_name = "(NDR)"


class GoldTreeCell_config(BiCell_config):
    """``"GoldTreeCell"`` encoder; turns on parse_trees and off batch_pair."""

    def __init__(self):
        super().__init__()
        self.batch_pair, self.parse_trees = False, True
        self.encoder_type = "GoldTreeCell"
        self.model_name = "(GoldTreeCell)"


class RandomTreeCell_config(BiCell_config):
    """Config for the ``"RandomTreeCell"`` encoder."""

    def __init__(self):
        super().__init__()
        self.conv_decision = False
        self.encoder_type, self.model_name = "RandomTreeCell", "(RandomTreeCell)"

class GumbelTreeCelltest_config(BiCell_config):
    """Same settings as GumbelTreeCell_config under a distinct model_name.

    Presumably kept for separate experiment runs; confirm before removing.
    """

    def __init__(self):
        super().__init__()
        self.conv_decision = False
        self.encoder_type, self.model_name = "GumbelTreeCell", "(GumbelTreeCelltest)"

class ConvGumbelTreeCell_config(BiCell_config):
    """``"GumbelTreeCell"`` encoder with ``conv_decision`` switched on."""

    def __init__(self):
        super().__init__()
        self.conv_decision = True
        self.encoder_type, self.model_name = "GumbelTreeCell", "(ConvGumbelTreeCell)"

class MCGumbelTreeCell_config(BiCell_config):
    """Config for the ``"MCGumbelTreeCell"`` encoder with ``sample_size`` 5."""

    def __init__(self):
        super().__init__()
        self.conv_decision, self.sample_size = False, 5
        self.encoder_type = "MCGumbelTreeCell"
        self.model_name = "(MCGumbelTreeCell)"

class GumbelRaoTreeCell_config(GumbelTreeCell_config):
    """``"GumbelTreeCell"`` encoder with ``rao`` on (rao_k inherited from base_config)."""

    def __init__(self):
        super().__init__()
        self.conv_decision, self.rao = False, True
        self.encoder_type = "GumbelTreeCell"
        self.model_name = "(GumbelRaoTreeCell)"

class BeamTreeCell_config(GumbelTreeCell_config):
    """``"BeamGumbelTreeCell"`` encoder, beam of 5.

    NOTE(review): this is the only beam config here with ``stochastic=False``;
    the siblings set it to True. Presumably intentional (deterministic beam).
    """

    def __init__(self):
        super().__init__()
        # conv_decision, both diffop switches and stochastic are all off
        self.conv_decision = self.diffop1 = self.diffop2 = self.stochastic = False
        self.beam_size = 5
        self.encoder_type = "BeamGumbelTreeCell"
        self.model_name = "(BeamTreeCell)"


class BeamTreeLSTM_config(GumbelTreeCell_config):
    """``"BeamGumbelTreeLSTM"`` encoder, stochastic, beam of 5."""

    def __init__(self):
        super().__init__()
        self.conv_decision = self.diffop1 = self.diffop2 = False
        self.stochastic = True
        self.beam_size = 5
        self.encoder_type = "BeamGumbelTreeLSTM"
        self.model_name = "(BeamTreeLSTM)"



class ContrastBeamTreeCell_config(GumbelTreeCell_config):
    """``"ContrastBeamGumbelTreeCell"`` encoder, stochastic, beam of 2."""

    def __init__(self):
        super().__init__()
        self.conv_decision = self.diffop1 = self.diffop2 = False
        self.stochastic = True
        self.beam_size = 2
        self.encoder_type = "ContrastBeamGumbelTreeCell"
        self.model_name = "(ContrastBeamTreeCell)"

class DiffBeamTreeCell_config(GumbelTreeCell_config):
    """``"DiffBeamTreeCell"`` encoder, beam of 5, ``gumbel_diff`` off."""

    def __init__(self):
        super().__init__()
        self.conv_decision = self.diffop1 = self.diffop2 = False
        self.stochastic = True
        self.beam_size, self.gumbel_diff = 5, False
        self.encoder_type = "DiffBeamTreeCell"
        self.model_name = "(DiffBeamTreeCell)"

class GumbelDiffBeamTreeCell_config(GumbelTreeCell_config):
    """``"DiffBeamTreeCell"`` encoder, beam of 5, with ``gumbel_diff`` on."""

    def __init__(self):
        super().__init__()
        self.conv_decision = self.diffop1 = self.diffop2 = False
        self.stochastic = True
        self.beam_size, self.gumbel_diff = 5, True
        self.encoder_type = "DiffBeamTreeCell"
        self.model_name = "(GumbelDiffBeamTreeCell)"


class SmallerDiffBeamTreeCell_config(GumbelTreeCell_config):
    """``"DiffBeamTreeCell"`` encoder with a beam of 2 and ``gumbel_diff`` off."""

    def __init__(self):
        super().__init__()
        self.conv_decision = self.diffop1 = self.diffop2 = False
        self.stochastic, self.gumbel_diff = True, False
        self.beam_size = 2
        self.encoder_type = "DiffBeamTreeCell"
        self.model_name = "(SmallerDiffBeamTreeCell)"

class SmallerGumbelDiffBeamTreeCell_config(GumbelTreeCell_config):
    """``"DiffBeamTreeCell"`` encoder with a beam of 2 and ``gumbel_diff`` on."""

    def __init__(self):
        super().__init__()
        self.conv_decision = self.diffop1 = self.diffop2 = False
        self.stochastic = self.gumbel_diff = True
        self.beam_size = 2
        self.encoder_type = "DiffBeamTreeCell"
        self.model_name = "(SmallerGumbelDiffBeamTreeCell)"

class DiffSortBeamTreeCell_config(GumbelTreeCell_config):
    """``"DiffSortBeamTreeCell"`` encoder, stochastic, beam of 5."""

    def __init__(self):
        super().__init__()
        self.conv_decision = self.diffop1 = self.diffop2 = False
        self.stochastic = True
        self.beam_size = 5
        self.encoder_type = "DiffSortBeamTreeCell"
        self.model_name = "(DiffSortBeamTreeCell)"

class TDBeamTreeCell_config(GumbelTreeCell_config):
    """``"BeamGumbelTreeCell"`` encoder with ``treedrop`` on and a beam of 2."""

    def __init__(self):
        super().__init__()
        self.conv_decision = self.diffop1 = self.diffop2 = False
        self.treedrop = self.stochastic = True
        self.beam_size = 2
        self.encoder_type = "BeamGumbelTreeCell"
        self.model_name = "(TDBeamTreeCell)"

class StochasticBeamTreeCell_config(GumbelTreeCell_config):
    """``"BeamGumbelTreeCell"`` encoder, stochastic in training only, beam of 5."""

    def __init__(self):
        super().__init__()
        self.conv_decision = self.diffop1 = self.diffop2 = False
        self.stochastic, self.test_time_stochastic = True, False
        self.beam_size = 5
        self.encoder_type = "BeamGumbelTreeCell"
        self.model_name = "(StochasticBeamTreeCell)"

class BigBeamTreeCell_config(GumbelTreeCell_config):
    """Same as StochasticBeamTreeCell_config but with a beam of 8."""

    def __init__(self):
        super().__init__()
        self.conv_decision = self.diffop1 = self.diffop2 = False
        self.stochastic, self.test_time_stochastic = True, False
        self.beam_size = 8
        self.encoder_type = "BeamGumbelTreeCell"
        self.model_name = "(BigBeamTreeCell)"

class Beam1GumbelTreeCell_config(GumbelTreeCell_config):
    """``"BeamGumbelTreeCell"`` encoder, beam of 5, with ``diffop1`` on."""

    def __init__(self):
        super().__init__()
        self.conv_decision = False
        self.diffop1, self.diffop2 = True, False
        self.stochastic, self.test_time_stochastic = True, False
        self.beam_size = 5
        self.encoder_type = "BeamGumbelTreeCell"
        self.model_name = "(Beam1GumbelTreeCell)"

class Beam1GumbelRaoTreeCell_config(GumbelTreeCell_config):
    """``diffop1`` beam config (beam of 5) with ``rao`` on and ``rao_k`` 10.

    NOTE(review): this is the only config in this section that sets
    ``temperature``; confirm whether the other rao configs need it too.
    """

    def __init__(self):
        super().__init__()
        self.conv_decision = False
        self.diffop1, self.diffop2 = True, False
        self.stochastic, self.test_time_stochastic = True, False
        self.rao, self.rao_k = True, 10
        self.temperature = 1
        self.beam_size = 5
        self.encoder_type = "BeamGumbelTreeCell"
        self.model_name = "(Beam1GumbelRaoTreeCell)"


class BeamDisTreeCell_config(GumbelTreeCell_config):
    """``"BeamGumbelDisentangledTreeCell"`` encoder, beam of 5, ``rao`` explicitly off."""

    def __init__(self):
        super().__init__()
        self.conv_decision = self.diffop1 = self.diffop2 = False
        self.stochastic, self.test_time_stochastic = True, False
        self.rao = False
        self.beam_size = 5
        self.encoder_type = "BeamGumbelDisentangledTreeCell"
        self.model_name = "(BeamDisTreeCell)"


class Beam1GumbelRao100TreeCell_config(GumbelTreeCell_config):
    """``diffop1`` beam config (beam of 5) with ``rao`` on and ``rao_k`` raised to 100."""

    def __init__(self):
        super().__init__()
        self.conv_decision = False
        self.diffop1, self.diffop2 = True, False
        self.stochastic, self.test_time_stochastic = True, False
        self.rao, self.rao_k = True, 100
        self.beam_size = 5
        self.encoder_type = "BeamGumbelTreeCell"
        self.model_name = "(Beam1GumbelRao100TreeCell)"


class Beam2GumbelTreeCell_config(GumbelTreeCell_config):
    """``"BeamGumbelTreeCell"`` encoder, beam of 5, with ``diffop2`` on."""

    def __init__(self):
        super().__init__()
        self.conv_decision = False
        self.diffop1, self.diffop2 = False, True
        self.stochastic, self.test_time_stochastic = True, False
        self.beam_size = 5
        self.encoder_type = "BeamGumbelTreeCell"
        self.model_name = "(Beam2GumbelTreeCell)"

class Beam2GumbelRaoTreeCell_config(GumbelTreeCell_config):
    """``diffop2`` beam config (beam of 5) with ``rao`` on (rao_k inherited)."""

    def __init__(self):
        super().__init__()
        self.conv_decision = False
        self.diffop1, self.diffop2 = False, True
        self.stochastic, self.test_time_stochastic = True, False
        self.rao = True
        self.beam_size = 5
        self.encoder_type = "BeamGumbelTreeCell"
        self.model_name = "(Beam2GumbelRaoTreeCell)"


class Beam12GumbelTreeCell_config(GumbelTreeCell_config):
    """``"BeamGumbelTreeCell"`` encoder, beam of 5, with both diffop switches on."""

    def __init__(self):
        super().__init__()
        self.conv_decision = False
        self.diffop1 = self.diffop2 = True
        self.stochastic, self.test_time_stochastic = True, False
        self.beam_size = 5
        self.encoder_type = "BeamGumbelTreeCell"
        self.model_name = "(Beam12GumbelTreeCell)"

class Beam12GumbelRaoTreeCell_config(GumbelTreeCell_config):
    """Both diffop switches on, beam of 5, with ``rao`` on (rao_k inherited)."""

    def __init__(self):
        super().__init__()
        self.conv_decision = False
        self.diffop1 = self.diffop2 = True
        self.stochastic, self.test_time_stochastic = True, False
        self.beam_size, self.rao = 5, True
        self.encoder_type = "BeamGumbelTreeCell"
        self.model_name = "(Beam12GumbelRaoTreeCell)"


class SmallBeamTreeCell_config(GumbelTreeCell_config):
    """Stochastic ``"BeamGumbelTreeCell"`` config with a beam of 3."""

    def __init__(self):
        super().__init__()
        self.conv_decision = self.diffop1 = self.diffop2 = False
        self.stochastic, self.test_time_stochastic = True, False
        self.beam_size = 3
        self.encoder_type = "BeamGumbelTreeCell"
        self.model_name = "(SmallBeamTreeCell)"

class SmallBeam1GumbelRaoTreeCell_config(GumbelTreeCell_config):
    """``diffop1`` + ``rao`` beam config with a beam of 3 (rao_k inherited)."""

    def __init__(self):
        super().__init__()
        self.conv_decision = False
        self.diffop1, self.diffop2 = True, False
        self.stochastic, self.test_time_stochastic = True, False
        self.rao = True
        self.beam_size = 3
        self.encoder_type = "BeamGumbelTreeCell"
        self.model_name = "(SmallBeam1GumbelRaoTreeCell)"


class SmallerBeamTreeCell_config(GumbelTreeCell_config):
    """Stochastic ``"BeamGumbelTreeCell"`` config with a beam of 2."""

    def __init__(self):
        super().__init__()
        self.conv_decision = self.diffop1 = self.diffop2 = False
        self.stochastic, self.test_time_stochastic = True, False
        self.beam_size = 2
        self.encoder_type = "BeamGumbelTreeCell"
        self.model_name = "(SmallerBeamTreeCell)"