import pandas as pd
import torch
import os
import pandas

from DatasetInference import return_di_embeddings
from stealing.steal_MLP import load_BERT_for_CL, get_sent_features
from utils import dotdict

from transformers import AutoTokenizer, AutoConfig

### TINY BERT PART


BASE_MODEL = "prajjwal1/bert-tiny"

# Filesystem locations. The original file left these as bare `NAME =`
# placeholders, which is a syntax error. Each value is now read from the
# environment variable of the same name. An unset variable yields "", which
# makes os.path.join resolve relative to the current working directory, so
# export all of them before running.

# Roots of the model checkpoints (victim, independently trained, stolen).
TINY_VICTIM_ROOT = os.environ.get("TINY_VICTIM_ROOT", "")
TINY_INDEPENDENT_ROOT = os.environ.get("TINY_INDEPENDENT_ROOT", "")
TINY_STOLEN_ROOT = os.environ.get("TINY_STOLEN_ROOT", "")

# Header-less CSV files with the sentences to embed.
NLI_PATH = os.environ.get("NLI_PATH", "")
QQP_PATH = os.environ.get("QQP_PATH", "")
FLICKR_PATH = os.environ.get("FLICKR_PATH", "")

# Output directories for the saved representation tensors.
REPRESENTATIONS_VICTIM_ROOT = os.environ.get("REPRESENTATIONS_VICTIM_ROOT", "")
REPRESENTATIONS_INDEPENDENT_ROOT = os.environ.get("REPRESENTATIONS_INDEPENDENT_ROOT", "")
REPRESENTATIONS_STOLEN_ROOT = os.environ.get("REPRESENTATIONS_STOLEN_ROOT", "")

NUM_SAMPLES = 20000  # sentences sampled per dataset
OUTPUT_DIMS = 128  # width of one representation vector
SEED = 42
BATCH_SIZE = 20
# Ceiling division: a trailing partial batch still gets its own iteration, so
# every row of `reps` is refreshed even when NUM_SAMPLES % BATCH_SIZE != 0
# (floor division would silently leave the tail rows from the previous model).
NUM_BATCHES = -(-NUM_SAMPLES // BATCH_SIZE)
# Reusable (NUM_SAMPLES, OUTPUT_DIMS) buffer, filled with one model's
# representations at a time before being saved.
reps = torch.zeros(NUM_SAMPLES, OUTPUT_DIMS)

def _load_sentences(csv_path):
    """Return NUM_SAMPLES rows of a header-less CSV as a list of strings.

    Rows are drawn with ``random_state=SEED``, so every run sees the same
    sentences in the same order.
    """
    frame = pd.read_csv(csv_path, sep=',', header=None)
    picked = frame.sample(NUM_SAMPLES, random_state=SEED)
    # squeeze() collapses a single-column frame to a Series (presumably each
    # CSV holds one sentence column) so that to_list() is available.
    as_text = picked.squeeze().astype(dtype=str)
    return as_text.to_list()


NLI_DATA = _load_sentences(NLI_PATH)
QQP_DATA = _load_sentences(QQP_PATH)
FLICKR_DATA = _load_sentences(FLICKR_PATH)

# Inference is pinned to the CPU. To use a GPU when one is present, switch to
# "cuda" if torch.cuda.is_available() else "cpu".
device = "cpu"
print(device)

# Tokenizer of the base BERT-tiny checkpoint.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Dataset identifiers; the list order fixes the order of evaluation.
DATA = "nli qqp flickr".split()

### preparation for the model
# Base config with the hidden and output sizes overridden to OUTPUT_DIMS.
# `output_size` is not a standard BERT config field; presumably it is read by
# load_BERT_for_CL (TODO confirm).
config = AutoConfig.from_pretrained(BASE_MODEL)
config.hidden_size = OUTPUT_DIMS
config.output_size = OUTPUT_DIMS

# Arguments handed to the project helpers (load_BERT_for_CL,
# get_sent_features). Which fields each helper reads is not visible here.
args = dotdict()
args.outputdir = ""
args.logdir = ""
args.seed = SEED  # reuse the script-wide seed instead of a duplicated literal
args.model_to_load = ""
args.pooler = "cls"
args.seqlength = 32
args.device = device

def _load_frozen_model(model_path):
    """Load the checkpoint at *model_path* with load_BERT_for_CL for inference.

    All parameters are frozen, the model is moved to ``device``, and it is
    switched to eval mode.
    """
    model = load_BERT_for_CL(model_path, args, config, tokenizer)
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)
    # The mode load_BERT_for_CL leaves the model in is not visible here. Force
    # eval() so dropout cannot randomise the extracted representations.
    model.eval()
    return model


def _save_representations(model_path, data, save_path):
    """Embed every sentence in *data* with the model at *model_path* and save them.

    Embeddings are written batch by batch into the shared ``reps`` buffer.
    The buffer is then written to *save_path* with ``torch.save``.
    """
    model = _load_frozen_model(model_path)
    with torch.no_grad():
        # Step over the data itself rather than NUM_BATCHES so that a trailing
        # partial batch is embedded too; otherwise those rows of `reps` would
        # keep the previous model's values.
        for start in range(0, len(data), BATCH_SIZE):
            stop = start + BATCH_SIZE
            sent_features = get_sent_features(args, data[start:stop], tokenizer, sent_features_type="standard")
            embeddings = return_di_embeddings(model, sent_features, use_pooler=True, return_sent_emb=True)
            reps[start:stop] = embeddings
    torch.save(reps, save_path)


_DATA_BY_NAME = {"nli": NLI_DATA, "qqp": QQP_DATA, "flickr": FLICKR_DATA}

for victim_data in DATA:
    print(f"victim: {victim_data}")
    if victim_data not in _DATA_BY_NAME:
        # ValueError subclasses Exception, so existing `except Exception` still works.
        raise ValueError(f"No such data: {victim_data}")
    data = _DATA_BY_NAME[victim_data]

    # evaluate the respective victim model
    victim_name = f"my-sup-simcse-tiny-bert-{victim_data}-10epochs_unconverted"
    victim_path = os.path.join(TINY_VICTIM_ROOT, victim_name)
    save_name = f"tiny_victim_{victim_data}_on_{victim_data}_data.pt"
    _save_representations(victim_path, data, os.path.join(REPRESENTATIONS_VICTIM_ROOT, save_name))

    # evaluate all the stolen models
    for steal_dataset in DATA:
        print(f"stolen with: {steal_dataset}")
        steal_name = f"tiny-{victim_data}-with-{steal_dataset}-API"
        steal_path = os.path.join(TINY_STOLEN_ROOT, steal_name)
        save_name = f"tiny_stolen_{victim_data}_with_{steal_dataset}_on_{victim_data}_data.pt"
        _save_representations(steal_path, data, os.path.join(REPRESENTATIONS_STOLEN_ROOT, save_name))

    # evaluate the three independent models
    for independent in DATA:
        print(f"independent: {independent}")
        independent_name = f"tiny-bert-{independent}-10epochs_unconverted_alternative"
        independent_path = os.path.join(TINY_INDEPENDENT_ROOT, independent_name)
        save_name = f"tiny_independent_{independent}_on_{victim_data}_data.pt"
        _save_representations(independent_path, data, os.path.join(REPRESENTATIONS_INDEPENDENT_ROOT, save_name))
