import os
from transformers import AutoModel, AutoTokenizer, AutoConfig
import argparse
from utils import apply_pooler, dotdict
import logging

from simcse.models import BertForCL
from stealing.steal_MLP import load_BERT_for_CL, get_sent_features, trainer_eval, load_ROBERTA_for_CL

# Checkpoint paths to evaluate; fill in before running. The path name picks
# the backbone in main(): "roberta" -> RoBERTa-large, "big" -> BERT-base,
# anything else -> BERT-tiny.
MODEL_NAMES = []

# Hugging Face hub identifiers from which the tokenizer and base config are loaded.
TOKENIZER_TINY = "prajjwal1/bert-tiny"
TOKENIZER_BERT = "bert-base-uncased"
TOKENIZER_ROBERTA = "roberta-large"


def _select_backbone(model_path):
    """Return ``(config, tokenizer, loader)`` for the family implied by *model_path*.

    The family is chosen by substring match on the path, checked in order:
    ``"roberta"`` -> ``TOKENIZER_ROBERTA`` (width 1024, ``load_ROBERTA_for_CL``),
    ``"big"`` -> ``TOKENIZER_BERT`` (width 768, ``load_BERT_for_CL``), and
    anything else -> ``TOKENIZER_TINY`` (width 128, ``load_BERT_for_CL``).
    Both ``hidden_size`` and ``output_size`` on the loaded config are
    overwritten with that width.
    """
    if "roberta" in model_path:
        base_name, width, loader = TOKENIZER_ROBERTA, 1024, load_ROBERTA_for_CL
    elif "big" in model_path:
        base_name, width, loader = TOKENIZER_BERT, 768, load_BERT_for_CL
    else:
        base_name, width, loader = TOKENIZER_TINY, 128, load_BERT_for_CL

    config = AutoConfig.from_pretrained(base_name)
    tokenizer = AutoTokenizer.from_pretrained(base_name)
    config.hidden_size = width
    config.output_size = width
    return config, tokenizer, loader


def _build_eval_args(model_path):
    """Build the ``dotdict`` of arguments passed to the model loader and ``trainer_eval``."""
    args = dotdict()
    # outputdir/logdir point at the checkpoint directory itself.
    args.outputdir = model_path
    args.logdir = model_path
    args.seed = 42
    args.model_to_load = model_path
    args.pooler = "cls"
    return args


def _write_results(model_path, results):
    """Log each result and write sorted ``key = value`` lines to ``<model_path>/eval_results.txt``."""
    output_eval_file = os.path.join(model_path, "eval_results.txt")
    with open(output_eval_file, "w", encoding="utf-8") as writer:
        # NOTE: the root logger defaults to WARNING unless configured elsewhere,
        # so these INFO lines may not be shown; the file is the durable record.
        logging.info("***** Eval results *****")
        for key, value in sorted(results.items()):
            logging.info(f"  {key} = {value}")
            writer.write(f"{key} = {value}\n")


def main():
    """Evaluate every checkpoint listed in ``MODEL_NAMES`` and save its results.

    For each path: choose the tokenizer/config/loader family from the path
    name, load the model, run ``trainer_eval`` with
    ``eval_senteval_transfer=True``, and write ``eval_results.txt`` into the
    checkpoint directory.
    """
    if not MODEL_NAMES:
        # MODEL_NAMES ships empty; make a no-op run visible instead of silent.
        logging.warning("MODEL_NAMES is empty; nothing to evaluate.")
        return

    for model_path in MODEL_NAMES:
        # A single helper decides the family, so the loader always matches
        # the tokenizer/config it is given.
        config, tokenizer, loader = _select_backbone(model_path)
        args = _build_eval_args(model_path)
        model = loader(model_path, args, config, tokenizer)

        print("done loading model, starting evaluation")
        results = trainer_eval(args, model, tokenizer, batcher_type='standard',
                               eval_senteval_transfer=True)
        _write_results(model_path, results)


# Run the evaluation sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
