import pandas as pd
import torch
import os
import pandas
from torch.nn import DataParallel

from DatasetInference import return_di_embeddings
from stealing.steal_MLP import load_BERT_for_CL, get_sent_features, load_ROBERTA_for_CL
from utils import dotdict

from transformers import AutoTokenizer, AutoConfig, AutoModel



# Encoder family to evaluate. Supported values: "bert", "roberta".
MODE = "roberta"  # "bert", "roberta"

# Per-family settings, in the order:
#   (base checkpoint name, public victim checkpoint, root of our own victim, embedding width)
_MODE_SETTINGS = {
    "bert": (
        "bert-base-uncased",
        "princeton-nlp/sup-simcse-bert-base-uncased",
        "",
        768,
    ),
    "roberta": (
        "roberta-large",
        "princeton-nlp/sup-simcse-roberta-large",
        "",
        1024,
    ),
}

# Refuse to continue with an unknown family instead of leaving the constants undefined.
if MODE not in _MODE_SETTINGS:
    raise Exception("no such model")

BASE_MODEL, VICTIM_ROOT, VICTIM_OURS_ROOT, OUTPUT_DIMS = _MODE_SETTINGS[MODE]

# NOTE: each constant below previously had no right-hand side, which is a SyntaxError
# and kept the whole script from being parsed. They now carry the same empty-string
# placeholder already used for VICTIM_OURS_ROOT.
# TODO: fill in the real locations before running; an empty path will fail at load time.

# Directories that hold the independently trained and the stolen model checkpoints.
INDEPENDENT_ROOT = ""
STOLEN_ROOT = ""

# Directories the extracted representation tensors are written to.
REPRESENTATIONS_VICTIM_ROOT = ""
REPRESENTATIONS_INDEPENDENT_ROOT = ""
REPRESENTATIONS_STOLEN_ROOT = ""

# CSV files the evaluation sentences are sampled from.
NLI_PATH = ""
QQP_PATH = ""
FLICKR_PATH = ""

# Sampling and batching configuration.
SEED = 42            # random_state used when sampling rows from the CSV files
NUM_SAMPLES = 20000  # number of sentences drawn per dataset
BATCH_SIZE = 100     # sentences encoded per forward pass
NUM_BATCHES = NUM_SAMPLES // BATCH_SIZE

# Shared output buffer with one row per sampled sentence; it is overwritten by each
# model evaluation and saved afterwards.
reps = torch.zeros((NUM_SAMPLES, OUTPUT_DIMS))

def _sample_sentences(csv_path):
    """Return NUM_SAMPLES randomly sampled rows of a header-less CSV as a list of str.

    Rows are drawn with ``random_state=SEED``, so repeated runs pick the same rows.
    Assumes the CSV has a single column, so that ``squeeze()`` yields a Series
    -- TODO confirm against the actual data files.
    """
    frame = pd.read_csv(csv_path, sep=',', header=None)
    sampled = frame.sample(NUM_SAMPLES, random_state=SEED)
    as_text = sampled.squeeze().astype(dtype=str)
    return as_text.to_list()


# Sentence lists for the three datasets, loaded in the same order as before.
NLI_DATA = _sample_sentences(NLI_PATH)
QQP_DATA = _sample_sentences(QQP_PATH)
FLICKR_DATA = _sample_sentences(FLICKR_PATH)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Dataset keys: DATA drives the stolen/independent model loops, VIC_DATA the victims.
DATA = ["nli", "qqp", "flickr"]
VIC_DATA = ["nli"]

### preparation for the model
config = AutoConfig.from_pretrained(BASE_MODEL)
# Both size fields are set to the embedding width of the selected model family.
config.hidden_size = config.output_size = OUTPUT_DIMS

# Argument bundle handed to the project's loading / feature helpers.
args = dotdict()
for _arg_name, _arg_value in (
    ("outputdir", ""),
    ("logdir", ""),
    ("seed", 42),
    ("model_to_load", ""),
    ("pooler", "cls"),
    ("seqlength", 32),
    ("device", device),
):
    setattr(args, _arg_name, _arg_value)

def _encode_and_save(model, data, save_path, return_sent_emb):
    """Embed ``data`` with ``model`` and write the embedding matrix to ``save_path``.

    The model is wrapped in ``DataParallel``, moved to ``device`` and all of its
    parameters are frozen. ``data`` is then passed through ``get_sent_features`` and
    ``return_di_embeddings`` in slices of ``BATCH_SIZE``. Each result is detached,
    moved to the CPU and copied into the matching rows of the shared ``reps`` buffer,
    which is finally serialised with ``torch.save``.

    Args:
        model: Model to evaluate (from a ``load_*_for_CL`` helper or ``AutoModel``).
        data: List of input strings; expected to hold ``NUM_SAMPLES`` items.
        save_path: Destination file for the ``reps`` tensor.
        return_sent_emb: Forwarded unchanged to ``return_di_embeddings``.
    """
    model = DataParallel(model)
    model = model.to(device)
    for param in model.parameters():
        param.requires_grad = False
    # Step by start offset rather than by a precomputed batch count, so a trailing
    # partial batch is still encoded when NUM_SAMPLES is not a multiple of BATCH_SIZE
    # (otherwise those rows of ``reps`` would silently keep their old contents).
    for start in range(0, NUM_SAMPLES, BATCH_SIZE):
        x_batch = data[start:start + BATCH_SIZE]
        sent_features = get_sent_features(args, x_batch, tokenizer, sent_features_type="standard")
        embeddings = return_di_embeddings(model, sent_features, use_pooler=True,
                                          return_sent_emb=return_sent_emb)
        reps[start:start + len(x_batch)] = embeddings.detach().cpu()
        del sent_features
        del embeddings
    torch.save(reps, save_path)


# The checkpoint loader depends only on MODE, so pick it once.
if MODE == "bert":
    _load_for_cl = load_BERT_for_CL
else:
    _load_for_cl = load_ROBERTA_for_CL

# Sentence lists keyed by dataset name.
_datasets_by_name = {"nli": NLI_DATA, "qqp": QQP_DATA, "flickr": FLICKR_DATA}

for victim_data in VIC_DATA:
    print(f"victim: {victim_data}")
    if victim_data not in _datasets_by_name:
        raise Exception("No such data")
    # Every model below is evaluated on the victim's own dataset.
    data = _datasets_by_name[victim_data]

    # First get representations for our model
    if MODE == "bert":
        our_name = f"bert-base-uncased-{victim_data}-3epochs_unconverted_alternative"
    else:
        our_name = f"roberta-uncased-{victim_data}-3epochs_unconverted_alternative"
    our_path = os.path.join(VICTIM_OURS_ROOT, our_name)
    model = _load_for_cl(our_path, args, config, tokenizer)
    save_name = f"{MODE}_victim_{victim_data}_on_{victim_data}_data_ours.pt"
    _encode_and_save(model, data, os.path.join(REPRESENTATIONS_VICTIM_ROOT, save_name),
                     return_sent_emb=True)
    # Drop the reference before the next checkpoint is loaded.
    del model

    # Then get the representations from the API: evaluate the respective victim model
    model = AutoModel.from_pretrained(VICTIM_ROOT)
    save_name = f"{MODE}_victim_{victim_data}_on_{victim_data}_data.pt"
    _encode_and_save(model, data, os.path.join(REPRESENTATIONS_VICTIM_ROOT, save_name),
                     return_sent_emb=False)  # the API does not use our stuff
    del model

    # evaluate all the stolen models
    for steal_dataset in DATA:
        print(f"stolen with: {steal_dataset}")
        if MODE == "bert":
            steal_name = f"bert-base-{victim_data}-with-{steal_dataset}-API"
        else:
            steal_name = f"roberta-large-{victim_data}-with-{steal_dataset}-API-interactive/4"
        steal_path = os.path.join(STOLEN_ROOT, steal_name)
        model = _load_for_cl(steal_path, args, config, tokenizer)
        save_name = f"{MODE}_stolen_{victim_data}_with_{steal_dataset}_on_{victim_data}_data.pt"
        _encode_and_save(model, data, os.path.join(REPRESENTATIONS_STOLEN_ROOT, save_name),
                         return_sent_emb=True)
        del model

    # evaluate the three independent models
    for independent in DATA:
        print(f"independent: {independent}")
        if MODE == "bert":
            independent_name = f"bert-base-uncased-{independent}-3epochs_unconverted"
        else:
            independent_name = f"roberta-uncased-{independent}-3epochs_unconverted"
        independent_path = os.path.join(INDEPENDENT_ROOT, independent_name)
        model = _load_for_cl(independent_path, args, config, tokenizer)
        save_name = f"{MODE}_independent_{independent}_on_{victim_data}_data.pt"
        _encode_and_save(model, data, os.path.join(REPRESENTATIONS_INDEPENDENT_ROOT, save_name),
                         return_sent_emb=True)
        del model
