import torch
import pandas as pd


### Idea: compare my QQP-trained and Flickr-trained (BERT) models with each other and report their cosine similarity.
#  It should not be zero.
from torch.nn import DataParallel
from transformers import AutoTokenizer, AutoModel

from DatasetInference import return_di_embeddings
from stealing.calculcate_cosine_similarity import cosine_similarity
from stealing.steal_MLP import get_sent_features
from utils import dotdict
import numpy as np



# Hugging Face identifier of the base checkpoint; the tokenizer is loaded from it.
BASE_MODEL = "bert-base-uncased"

# NOTE: these three locations were left blank in this file (a bare `name =` does not
# parse). They must be filled in before the script can run.
model1_path = ""  # TODO: checkpoint of the first model to compare
model2_path = ""  # TODO: checkpoint of the second model to compare
NLI_PATH = ""  # TODO: headerless, comma-separated file with the sentences to embed

if not (model1_path and model2_path and NLI_PATH):
    # Fail early with a clear message instead of an obscure loader error further down.
    raise ValueError("model1_path, model2_path and NLI_PATH must be set before running this script")

NUM_SAMPLES = 20000  # number of rows drawn from the data file
SEED = 42  # random_state for the row sampling, so both models see the same sentences
BATCH_SIZE = 100  # sentences embedded per forward pass
OUTPUT_DIMS = 768  # width of one embedding row in reps1 / reps2
# Floor division: a remainder that does not fill a whole batch is not counted.
NUM_BATCHES = NUM_SAMPLES // BATCH_SIZE

# Pre-allocated embedding buffers, one row per sampled sentence and one buffer per model.
reps1 = np.zeros((NUM_SAMPLES, OUTPUT_DIMS))
reps2 = np.zeros((NUM_SAMPLES, OUTPUT_DIMS))

# Read the file, draw a reproducible sample of NUM_SAMPLES rows, and turn it into a
# plain list of strings. (Presumably the file has a single text column, so squeeze()
# yields a Series; verify against the actual data.)
NLI_DATA = (
    pd.read_csv(NLI_PATH, sep=',', header=None)
    .sample(NUM_SAMPLES, random_state=SEED)
    .squeeze()
    .astype(dtype=str)
    .to_list()
)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)


# Both models are fed inputs produced with the tokenizer of the shared base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Options object handed to get_sent_features; populated in a fixed order.
# (The exact meaning of each option is defined by that helper, not in this file.)
args = dotdict()
for option, value in (
    ("outputdir", ""),
    ("logdir", ""),
    ("seed", 42),
    ("model_to_load", ""),
    ("pooler", "cls"),
    ("seqlength", 32),
    ("device", device),
):
    setattr(args, option, value)

def _fill_representations(model_path, out):
    """Embed every sentence of NLI_DATA with the checkpoint at *model_path*.

    Loads the checkpoint, wraps it in DataParallel, moves it to ``device`` and
    writes the embeddings into *out* batch by batch.

    Args:
        model_path: Location passed to ``AutoModel.from_pretrained``.
        out: Pre-allocated array with one row per sentence of NLI_DATA;
            modified in place.

    The model is a local variable, so its reference is dropped when the
    function returns (the original script used explicit ``del`` statements).
    """
    model = AutoModel.from_pretrained(model_path)
    model = DataParallel(model)
    model = model.to(device)
    # Inference only: freeze every parameter and skip autograd bookkeeping.
    model.requires_grad_(False)
    with torch.no_grad():
        # Step by start offset rather than by a precomputed batch count, so a trailing
        # partial batch is embedded too instead of being left as all-zero rows.
        for start in range(0, len(NLI_DATA), BATCH_SIZE):
            x_batch = NLI_DATA[start:start + BATCH_SIZE]
            sent_features = get_sent_features(args, x_batch, tokenizer, sent_features_type="standard")
            embeddings = return_di_embeddings(model, sent_features, use_pooler=True,
                                              return_sent_emb=False)  # the API does not use our stuff
            out[start:start + len(x_batch)] = embeddings.detach().cpu().numpy()


# evaluate the respective victim model
_fill_representations(model1_path, reps1)
_fill_representations(model2_path, reps2)

# Only the first two entries of what cosine_similarity returns are kept and printed.
score = cosine_similarity(reps1, reps2)[:2]
print(score)



