import pandas as pd
import torch
import os
import pandas

from DatasetInference import return_di_embeddings
from stealing.steal_MLP import load_BERT_for_CL, get_sent_features
from utils import dotdict

from transformers import AutoTokenizer, AutoConfig

### TINY BERT PART


# Model identifier passed to AutoTokenizer.from_pretrained and
# AutoConfig.from_pretrained further down in this script.
BASE_MODEL = "prajjwal1/bert-tiny"

# NOTE(review): every root/path assignment in this section has an empty
# right-hand side, which is a syntax error. The real locations must be filled
# in before the script can run.
# Root directories joined with checkpoint names to locate the victim and
# stolen models. TINY_INDEPENDENT_ROOT is not referenced in the visible part
# of this script.
TINY_VICTIM_ROOT =
TINY_INDEPENDENT_ROOT =
TINY_STOLEN_ROOT =

# NOTE: the dataset paths below were changed from their earlier values.
# Both files are read with pd.read_csv(..., sep=',', header=None).
QQP_PATH =
FLICKR_PATH =

# Output directories for the saved representation tensors.
# REPRESENTATIONS_INDEPENDENT_ROOT is not referenced in the visible part of
# this script.
REPRESENTATIONS_VICTIM_ROOT =
REPRESENTATIONS_INDEPENDENT_ROOT =
REPRESENTATIONS_STOLEN_ROOT =

NUM_SAMPLES = 20000  # rows sampled from each CSV; also the row count of `reps`
OUTPUT_DIMS = 128  # column count of `reps`; also written into the model config
SEED = 42  # random_state used when sampling rows from the CSV files
BATCH_SIZE = 20  # sentences embedded per batch
# Floor division: if NUM_SAMPLES were not a multiple of BATCH_SIZE, the
# trailing samples would never be embedded (20000 / 20 divides evenly here).
NUM_BATCHES = NUM_SAMPLES // BATCH_SIZE
# Single buffer shared by all models: each model overwrites it batch by batch
# and it is then written to disk with torch.save.
reps = torch.zeros(NUM_SAMPLES, OUTPUT_DIMS)  # to be filled at any point

def _load_sentences(csv_path):
    """Return NUM_SAMPLES randomly sampled rows of a header-less CSV as strings.

    The file is read with a comma separator and no header row, sampled with
    ``random_state=SEED``, squeezed, cast to ``str`` and converted to a list.
    """
    frame = pd.read_csv(csv_path, sep=',', header=None)
    sampled = frame.sample(NUM_SAMPLES, random_state=SEED)
    as_text = sampled.squeeze().astype(dtype=str)
    return as_text.to_list()


# Both corpora go through the same sampling so they share size and seed.
QQP_DATA = _load_sentences(QQP_PATH)
FLICKR_DATA = _load_sentences(FLICKR_PATH)

# Use the GPU when torch reports one as available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Tokenizer loaded from the same identifier as the base model config.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Victim model names iterated by the main loop of this script.
MIXED_NAMES = [
    "flickr_mixed_into_qqp",
    "full_flickr_mixed_into_qqp",
    "qqp_mixed_into_flickr",
]

# Evaluation dataset for each victim, index-aligned with MIXED_NAMES:
# the first two need to be evaluated on flickr, the last on qqp.
DATA = [
    "flickr",
    "flickr",
    "qqp",
]

# Dataset tags that appear in the stolen-model checkpoint names.
DATA_NAMES = [
    "nli",
    "qqp",
    "flickr",
]

### Model preparation: base config plus the argument bundle for the loaders.

# Start from the base model's config and set both size fields to OUTPUT_DIMS.
config = AutoConfig.from_pretrained(BASE_MODEL)
config.hidden_size = config.output_size = OUTPUT_DIMS

# Argument container handed to load_BERT_for_CL and get_sent_features.
# Attributes are set one by one, in a fixed order.
args = dotdict()
for _arg_name, _arg_value in {
    "outputdir": "",
    "logdir": "",
    "seed": 42,
    "model_to_load": "",
    "pooler": "cls",
    "seqlength": 32,
    "device": device,
}.items():
    setattr(args, _arg_name, _arg_value)

def _embed_and_save(model_path, sentences, save_path):
    """Embed ``sentences`` with the checkpoint at ``model_path`` and save them.

    The model is loaded with ``load_BERT_for_CL``, its parameters are frozen,
    and it is moved to ``device``. Embeddings are written batch by batch into
    the module-level ``reps`` buffer, which is then stored with ``torch.save``.

    Args:
        model_path: Checkpoint location passed to ``load_BERT_for_CL``.
        sentences: Sequence of input strings; the first
            ``NUM_BATCHES * BATCH_SIZE`` entries are embedded.
        save_path: Destination file passed to ``torch.save``.
    """
    model = load_BERT_for_CL(model_path, args, config, tokenizer)
    # Freeze once, before moving the model; this is what the stolen-model
    # branch always did, and the victim branch merely repeated it after .to().
    for _, param in model.named_parameters():
        param.requires_grad = False
    model = model.to(device)

    for batch_idx in range(NUM_BATCHES):
        start = batch_idx * BATCH_SIZE
        stop = start + BATCH_SIZE
        sent_features = get_sent_features(
            args, sentences[start:stop], tokenizer, sent_features_type="standard")
        # `reps` is reused across models; every row in [0, NUM_BATCHES * BATCH_SIZE)
        # is overwritten here before the buffer is saved.
        reps[start:stop] = return_di_embeddings(
            model, sent_features, use_pooler=True, return_sent_emb=True)

    torch.save(reps, save_path)
    # The local `model` reference is dropped on return, so the previous
    # checkpoint is no longer referenced here when the next one is loaded.


# Maps the dataset tag stored in DATA to the sampled sentences to evaluate on.
_EVAL_SETS = {"qqp": QQP_DATA, "flickr": FLICKR_DATA}

for victim_name, victim_data in zip(MIXED_NAMES, DATA):
    print(f"victim: {victim_name}")

    if victim_data not in _EVAL_SETS:
        raise Exception("No such data")
    data = _EVAL_SETS[victim_data]

    # evaluate the respective victim model
    victim_path = os.path.join(TINY_VICTIM_ROOT, f"{victim_name}-10epochs-unconverted")
    victim_save_path = os.path.join(
        REPRESENTATIONS_VICTIM_ROOT,
        f"tiny_victim_{victim_name}_on_{victim_data}_data.pt",
    )
    _embed_and_save(victim_path, data, victim_save_path)

    # evaluate all the stolen models, on the same data as their victim
    for steal_dataset in DATA_NAMES:
        print(f"stolen with: {steal_dataset}")
        steal_path = os.path.join(
            TINY_STOLEN_ROOT,
            f"steal_{victim_name}_with_{steal_dataset}_unconverted",
        )
        steal_save_path = os.path.join(
            REPRESENTATIONS_STOLEN_ROOT,
            f"tiny_stolen_{victim_name}_with_{steal_dataset}_on_{victim_data}_data.pt",
        )
        _embed_and_save(steal_path, data, steal_save_path)