import os

import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig

from DatasetInference import return_di_embeddings
from custom_datasets import OurDatasets
from paths import PATH_TO_SENTEVAL
from stealing.steal_MLP import load_BERT_for_CL, get_sent_features
from utils import dotdict

# Checkpoints to compare. Provide them through environment variables of the
# same name (or replace the ``os.environ.get`` calls with literal values).
#   M_CONV   -> passed to ``AutoModel.from_pretrained`` (local dir or Hub id).
#   M_UNVONC -> passed to ``load_BERT_for_CL``; presumably a checkpoint path
#               of the same model before conversion -- TODO confirm.
M_CONV = os.environ.get("M_CONV")
M_UNVONC = os.environ.get("M_UNVONC")
# Hub id used both to load the tokenizer and as the base config of the
# unconverted model.
TOKENIZER = "prajjwal1/bert-tiny"

if not M_CONV or not M_UNVONC:
    # Fail fast with a clear message instead of an opaque loader error later.
    raise SystemExit(
        "Set M_CONV and M_UNVONC (model paths / Hub ids) before running this script."
    )


# Run-wide state: ``args`` is populated further down and is also passed to
# ``get_sent_features`` when featurising the input batch.
args = dotdict()
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Converted model --------------------------------------------------------
# Load through the generic Hugging Face entry point, switch off gradient
# tracking for every parameter (this script only runs forward passes), move
# the weights to the target device, then put the model in evaluation mode
# (disables dropout and similar training-only behaviour).
converted_model = (
    AutoModel.from_pretrained(M_CONV)
    .requires_grad_(False)
    .to(device)
)
converted_model.eval()

# --- Same model, non-converted ---------------------------------------------
# Start from the tokenizer checkpoint's config and pin the hidden / output
# sizes to 128 -- presumably the dimensions the unconverted checkpoint was
# built with; TODO confirm.
config = AutoConfig.from_pretrained(TOKENIZER)
config.hidden_size, config.output_size = 128, 128

# ``args`` is also handed to ``get_sent_features`` later, so these fields may
# influence featurisation as well as model loading.
args.model_to_load = M_UNVONC
args.pooler = "cls"
unconverted_model = load_BERT_for_CL(M_UNVONC, args, config, tokenizer)

# Freeze every weight: only forward-pass outputs are compared.
for _, weight in unconverted_model.named_parameters():
    weight.requires_grad = False

unconverted_model = unconverted_model.to(device)
unconverted_model.eval()

# --- Input data -------------------------------------------------------------
# SICK (Sentences Involving Compositional Knowledge) sentence pairs from the
# SentEval downstream data. Note that the *train* split serves as the probe set.
sick_path = os.path.join(
    PATH_TO_SENTEVAL, "data", "downstream", "SICK", "SICK_train.txt"
)
# Tab-separated file with a header row; columns 1 and 2 are taken as the two
# sentences of each pair (sentence_A / sentence_B in SentEval's copy --
# confirm against OurDatasets). Use ``header=None`` for our curated data,
# which has no header row.
test_dataset = OurDatasets(
    sick_path,
    header=0,
    col_inds=[1, 2],
    sep="\t",
)

# Only the first batch of 50 pairs is used. Pinned host memory only speeds up
# host->GPU copies, so enable it only on CUDA (PyTorch warns otherwise).
test_query_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=50,
    shuffle=False,
    num_workers=1,
    pin_memory=(device == 'cuda'),
    drop_last=True,
)
test_data = next(iter(test_query_loader))
test_features = get_sent_features(
    args,
    batch=test_data,
    tokenizer=tokenizer,
    sent_features_type="standard",
)
test_features = test_features.to(device)

# --- Comparison -------------------------------------------------------------
# Embed the same batch with both models and print the 2-norm of the flattened
# difference (Frobenius norm for a 2-D tensor); 0 means identical outputs.
# The two calls use different ``return_sent_emb`` flags -- presumably because
# the models expose their outputs differently; TODO confirm this is intended.
with torch.no_grad():  # pure inference: no autograd bookkeeping needed
    conv_embeddings = return_di_embeddings(
        converted_model, test_features, use_pooler=True, return_sent_emb=False
    )
    unconv_embeddings = return_di_embeddings(
        unconverted_model, test_features, use_pooler=True, return_sent_emb=True
    )

# Subtracting tensors of different shapes would broadcast silently and give a
# meaningless norm, so insist on matching shapes first.
if conv_embeddings.shape != unconv_embeddings.shape:
    raise ValueError(
        f"Embedding shapes differ: converted {tuple(conv_embeddings.shape)} "
        f"vs unconverted {tuple(unconv_embeddings.shape)}"
    )

diff = torch.linalg.norm(conv_embeddings - unconv_embeddings)
print(diff)


