import os


class Args(object):
    def __init__(self):
        
        # CUDA part

        # self.cuda_visible_devices = "0, 1, 2, 3"
        self.cuda_visible_devices = "0"
        # self.cuda_visible_devices = "2, 3"


        # visual & symbol part

        self.use_resnet = False

        self.speaker_freeze_visual = False
        self.listener_freeze_visual = False

        self.speaker_freeze = False
        self.listener_freeze = False

        self.speaker_use_pretrain_model = False
        self.speaker_pretrain_path = "./dump/SCL_egg_4x10_spo_symbol_r_cs_10k_num/checkpoints/best_epoch.pth"
        self.speaker_pretrain_params_key = "speaker"

        self.listener_use_pretrain_model = False
        self.listener_pretrain_path = ""
        self.listener_pretrain_params_key = "listener"
        
        # hyper-paramerter part
        self.max_pooling_message_embedding = False
        self.mlp_pooling_message_embedding = False

        assert not(self.max_pooling_message_embedding and self.mlp_pooling_message_embedding)


        self.image_embedding_dim = 120
        self.message_embedding_dim = 5 * self.image_embedding_dim if not self.max_pooling_message_embedding and not self.mlp_pooling_message_embedding else 80
        

        self.message_max_len = 4
        self.vocab_size = 15


        self.message_encoder_rnn_hidden_dim = 5 * self.image_embedding_dim if not self.max_pooling_message_embedding and not self.mlp_pooling_message_embedding else 80
        self.message_decoder_rnn_hidden_dim = 5 * self.image_embedding_dim if not self.max_pooling_message_embedding and not self.mlp_pooling_message_embedding else 80

        self.message_encoder_transformer_hidden_dim = 5 * self.image_embedding_dim if not self.max_pooling_message_embedding and not self.mlp_pooling_message_embedding else 80
        self.message_decoder_transformer_hidden_dim = 5 * self.image_embedding_dim if not self.max_pooling_message_embedding and not self.mlp_pooling_message_embedding else 80

        self.cell = 'gru'
        self.length_cost = 0
        self.gumbel_temperature = 1.0
        
        self.message_length_cost = 0

        self.lr = 3e-3
        self.vlr = 5e-3
        self.slr = 1e-4
        self.weight_decay = 0.01
        

        self.rand_perm = False
        self.message_embedding_noise = "none" # gaussian, none

        self.listener_reset = False
        self.listener_reset_times = 50
        self.listener_reset_cycle = 20

        self.symbol_attr_dim = 4
        self.symbol_onehot_dim = 30
        self.symbol_model_hidden_dims = [32,]
        self.rules_dim = 15
        

        self.max_epoches = 200
        # self.max_epoches = 0
        self.dataloader_num_workers = 8
        self.train_batch_size = 512
        # self.train_batch_size = 2048
        self.test_batch_size = 32
        self.validation_batch_size = 32

        self.data_dir = "./data/paper/warmup_%d/" % (self.symbol_onehot_dim)

        self.data_format_str = '%s_visual.pkl'
        
        self.auto_resume = True

        self.agent_type = 'rnn_reinforce' # ["rnn_gs", "rnn_reinforce", "transformer_reinforce"]
        self.rule = True
        # self.rule = False
        self.visual, self.symbol = (False, True)
        # self.visual, self.symbol = (True, False)
        assert self.visual ^ self.symbol
        self.null_message = False
        self.const_message = False
        self.use_message_max_len = False

        self.add_LN = False

        self.use_constrative = False


        self.rnn_gs_speaker_configs = {
            "vocab_size": self.vocab_size, 
            "embed_dim": self.message_embedding_dim, 
            "hidden_size": self.message_encoder_rnn_hidden_dim, 
            "cell": self.cell,
            "max_len": self.message_max_len,
            "temperature": self.gumbel_temperature,
            "trainable_temperature": False,
            "straight_through": False
        }

        self.rnn_gs_listener_configs = {
            "vocab_size": self.vocab_size, 
            "embed_dim": self.message_embedding_dim, 
            "hidden_size": self.message_decoder_rnn_hidden_dim, 
            "cell": self.cell,
        }

        self.rnn_reinforce_speaker_configs = {
            "vocab_size": self.vocab_size, 
            "embed_dim": self.message_embedding_dim, 
            "hidden_size": self.message_encoder_rnn_hidden_dim, 
            "cell": self.cell,
            "max_len": self.message_max_len,
        }

        self.rnn_reinforce_listener_configs = {
            "vocab_size": self.vocab_size, 
            "embed_dim": self.message_embedding_dim, 
            "hidden_size": self.message_decoder_rnn_hidden_dim, 
            "cell": self.cell,
        }

        self.transformer_reinforce_speaker_configs = {
            "vocab_size": self.vocab_size, 
            "embed_dim": self.message_embedding_dim, 
            "max_len": self.message_max_len,
            "num_layers": 1,
            "num_heads": 4,
            "hidden_size": self.message_encoder_transformer_hidden_dim,
            "generate_style": "standard",
            "causal": True,
        }

        self.transformer_reinforce_listener_configs = {
            "vocab_size": self.vocab_size, 
            "max_len": self.message_max_len,
            "embed_dim": self.message_embedding_dim, 
            "num_heads": 4,
            "hidden_size": self.message_decoder_transformer_hidden_dim,
            "num_layers": 1,
            "positional_emb": True,
            "causal": True,
        }
        

        self.visual_shared_mlp_configs = {
            "groups": 10, 
            "group_input_dim": 256, 
            "group_output_dim": 8,
            "hidden_dims": [128], 
            "add_res_block": True, 
            "nr_mlps": 1, 
            "flatten": True,
            "shared": True
        }

        self.discri_analogy_shared_mlp_configs = {
            "groups": self.image_embedding_dim, 
            "group_input_dim": 3, 
            "group_output_dim": 1,
            "hidden_dims": [64, 32], 
            "add_res_block": True, 
            "nr_mlps": 5, 
            "flatten": True,
            "shared": True
        }

        self.recon_analogy_shared_mlp_configs = {
            "groups": self.symbol_attr_dim, # 4
            "group_input_dim": 2 * self.image_embedding_dim // self.symbol_attr_dim, # 40
            "group_output_dim": self.image_embedding_dim // self.symbol_attr_dim, # 20
            "hidden_dims": [64, 32], 
            "add_res_block": True, 
            "nr_mlps": self.rules_dim, # 10 
            "flatten": False,
            "shared": True
        }

        self.constrative_analogy_shared_mlp_configs = {
            "groups": 80, 
            "group_input_dim": 2, 
            "group_output_dim": 1,
            "hidden_dims": [64, 32], 
            "add_res_block": True, 
            "nr_mlps": 5, 
            "flatten": True,
            "shared": True
        }

        # self.execution_id = "test"
        self.execution_id = '%dx%d_warmup_88_%d' % (self.message_max_len, self.vocab_size, self.symbol_onehot_dim)
        # self.execution_id = 'SCL_egg_%dx%d_symbol_rnn_discri_gs_vlenmsg_ptlistener_r_cs_10k_num' % (self.message_max_len, self.vocab_size)
        # self.execution_id = 'SCL_egg_%dx%d_%dx%d_symbol_1e41e-2cs_rule_10k' % (self.message_max_len, self.vocab_size, self.listener_reset_times, self.listener_reset_cycle)
        self.dump_root = './dump_paper/'
        self.dump_dir = os.path.join(self.dump_root, self.execution_id)

        
        self.dump_message = False
        self.discri_game = True
        
    
    def get_execution_id(self):
        pass
