from transformers import PretrainedConfig

from loader.models.simple_model import TransformerForPolynomials

from loader.models.simple_memory_model import MemoryTransformerForPolynomials


def load_model(params, vocab=None, tokenizer=None, *, device="cuda"):
    """Build the model selected by ``params.model`` and move it to ``device``.

    Args:
        params: Namespace-like hyperparameter object. Reads
            ``encoding_method``, ``model``, ``num_variables``,
            ``token_register_size``, ``d_model``, ``nhead``,
            ``num_encoder_layers``, ``num_decoder_layers``,
            ``dim_feedforward``, ``dropout``, ``max_sequence_length``,
            ``positional_encoding`` and ``num_batch``. When
            ``encoding_method`` contains ``"standard"``, this function sets
            ``params.token_register_size = 0`` in place.
        vocab: Required in every configuration. For non-"standard"
            encodings its length gives the vocabulary size and the mapping
            itself is passed as ``special_token_ids``.
        tokenizer: Required when ``encoding_method`` contains ``"standard"``
            and whenever ``params.model == "gpt2"``.
        device: Device handed to ``model.to``. Defaults to ``"cuda"``, which
            matches the previous hard-coded ``.cuda()`` call.

    Returns:
        A ``TransformerForPolynomials`` (``params.model == "bart"``) or a
        ``GPT2LMHeadModel`` (``params.model == "gpt2"``), placed on ``device``.

    Raises:
        ValueError: If a required ``vocab``/``tokenizer`` is missing, or if
            ``params.model`` is not a supported model name.
    """
    encoding_method = params.encoding_method

    # Validate up front with real exceptions (asserts vanish under -O) and
    # before ``params`` is mutated below.
    if vocab is None:
        raise ValueError("load_model: `vocab` is required")

    if "standard" in encoding_method:
        if tokenizer is None:
            raise ValueError(
                "load_model: `tokenizer` is required for standard encodings"
            )

        special_tokens = tokenizer.special_tokens_map
        token_ids = tokenizer.convert_tokens_to_ids(list(special_tokens.values()))
        # Keys are the tokenizer's special-token names with an "_id" suffix.
        special_token_ids = {
            f"{name}_id": token_id
            for name, token_id in zip(special_tokens, token_ids)
        }

        output_dim = tokenizer.vocab_size
        input_dim = params.num_variables
        # The standard embedding path adds no register slots to input_dim;
        # the value is zeroed in place so the config below records 0.
        params.token_register_size = 0
        use_standard_embedding = True
        vocab_size = tokenizer.vocab_size

    else:
        output_dim = len(vocab) + params.num_variables
        input_dim = params.num_variables + params.token_register_size
        use_standard_embedding = False
        vocab_size = len(vocab)
        # The vocab mapping itself doubles as the special-token id table.
        special_token_ids = vocab

    if params.model == "bart":
        config = PretrainedConfig.from_dict(
            {
                "encoding_method": params.encoding_method,
                "d_model": params.d_model,
                "nhead": params.nhead,
                "num_encoder_layers": params.num_encoder_layers,
                "num_decoder_layers": params.num_decoder_layers,
                "dim_feedforward": params.dim_feedforward,
                "dropout": params.dropout,
                "output_dim": output_dim,
                "input_dim": input_dim,
                "token_register_size": params.token_register_size,
                "num_variables": params.num_variables,
                # 'max_number'            : params.gaussian_encoding_upper_bound,
                "use_standard_embedding": use_standard_embedding,
                "special_token_ids": special_token_ids,
                "vocab_size": vocab_size,
                "max_sequence_length": params.max_sequence_length,
                "positional_encoding": params.positional_encoding,
                # "num_memory_tokens": params.num_memory_tokens,
                # "num_batch": params.batch_size,
                "num_batch": params.num_batch,
                # "sparsity_lambda": params.sparsity_lambda,
                # "use_register": True,
            }
        )
        model = TransformerForPolynomials(config)

    elif params.model == "gpt2":
        # GPT-2's BOS/EOS ids come from the tokenizer regardless of encoding,
        # so it must be present even for non-standard encodings.
        if tokenizer is None:
            raise ValueError(
                "load_model: `tokenizer` is required for model 'gpt2'"
            )

        from transformers import GPT2LMHeadModel, GPT2Config

        config = GPT2Config(
            vocab_size=vocab_size,
            n_positions=params.max_sequence_length,
            n_embd=params.d_model,
            n_layer=params.num_encoder_layers,
            n_head=params.nhead,
            n_inner=params.dim_feedforward,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        model = GPT2LMHeadModel(config)
    else:
        raise ValueError("Model not found")

    return model.to(device)
