import numpy as np
from sklearn.preprocessing import normalize as norm_rep
from numpy import dot
from numpy import linalg as LA
import os
import torch
import pandas as pd
from calculcate_cosine_similarity import cosine_similarity

MODE = "tiny"

# Directories containing the saved representation files (``*.pt``) for the
# victim, independent and stolen models. They are read from environment
# variables of the same name so machine-specific paths need not be edited
# into the source; when a variable is unset the current working directory is
# used.
REPRESENTATIONS_VICTIM_ROOT = os.environ.get("REPRESENTATIONS_VICTIM_ROOT", ".")
REPRESENTATIONS_INDEPENDENT_ROOT = os.environ.get(
    "REPRESENTATIONS_INDEPENDENT_ROOT", "."
)
REPRESENTATIONS_STOLEN_ROOT = os.environ.get("REPRESENTATIONS_STOLEN_ROOT", ".")


# Evaluation dataset for each victim in MIXED_NAMES, matched by position:
# the two flickr-mixed victims are evaluated on flickr data, the last one on qqp.
DATA = [
    "flickr",
    "flickr",
    "qqp",
]
# Victim model identifiers, used to build the victim/stolen file names.
MIXED_NAMES = [
    "flickr_mixed_into_qqp",
    "full_flickr_mixed_into_qqp",
    "qqp_mixed_into_flickr",
]
# Datasets iterated over for both the stolen and the independent models.
DATA_NAMES = [
    "nli",
    "qqp",
    "flickr",
]


# Report which torch device is available ("cuda" when a GPU is visible).
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# One result list per victim model; the evaluation loop appends scores to them.
flickr_mixed_into_qqp, full_flickr_mixed_into_qqp, qqp_mixed_into_flickr = [], [], []

def _load_representations(root, file_name):
    """Print *file_name*, load ``root/file_name`` and return it as a NumPy array.

    ``map_location="cpu"`` makes ``torch.load`` place the stored tensor on the
    CPU, so files saved during a GPU run can still be read on a machine
    without CUDA. Without it, ``torch.load`` tries to restore CUDA tensors
    onto their original device and fails there.

    NOTE(review): assumes each ``.pt`` file holds a single tensor (the
    ``.cpu().detach().numpy()`` chain requires a tensor-like object) - TODO
    confirm.
    """
    print(file_name)
    path = os.path.join(root, file_name)
    reps = torch.load(path, map_location="cpu")
    return reps.cpu().detach().numpy()


# The evaluation dataset is looked up by position, so both lists must line up.
if len(DATA) != len(MIXED_NAMES):
    raise ValueError("DATA must list one evaluation dataset per entry of MIXED_NAMES")

# Result list for each victim; scores are appended to these lists in place.
_RESULTS_BY_VICTIM = {
    "flickr_mixed_into_qqp": flickr_mixed_into_qqp,
    "full_flickr_mixed_into_qqp": full_flickr_mixed_into_qqp,
    "qqp_mixed_into_flickr": qqp_mixed_into_flickr,
}

# For every victim, compare its representations (computed on victim_data)
# with itself, with each stolen model and with each independent model,
# appending one score per comparison to that victim's result list.
for victim_name, victim_data in zip(MIXED_NAMES, DATA):
    print(f"evaluating for victim: {victim_name}")
    try:
        our_list = _RESULTS_BY_VICTIM[victim_name]
    except KeyError:
        raise Exception("No such data") from None

    orig_reps = _load_representations(
        REPRESENTATIONS_VICTIM_ROOT,
        f"{MODE}_victim_{victim_name}_on_{victim_data}_data.pt",
    )
    # Victim compared with itself.
    # NOTE(review): ``[:2]`` keeps only the first two items of whatever
    # ``cosine_similarity`` returns - TODO document what those two items are.
    our_list.append(cosine_similarity(orig_reps, orig_reps)[:2])

    # One comparison per model stolen from this victim with each dataset.
    for steal_dataset in DATA_NAMES:
        print(f"on stolen data: {steal_dataset}")
        stolen_reps = _load_representations(
            REPRESENTATIONS_STOLEN_ROOT,
            f"{MODE}_stolen_{victim_name}_with_{steal_dataset}_on_{victim_data}_data.pt",
        )
        our_list.append(cosine_similarity(orig_reps, stolen_reps)[:2])

    # One comparison per independent model (``*_independent_*`` files).
    for independent in DATA_NAMES:
        print(f"on independent models: {independent}")
        indep_reps = _load_representations(
            REPRESENTATIONS_INDEPENDENT_ROOT,
            f"{MODE}_independent_{independent}_on_{victim_data}_data.pt",
        )
        our_list.append(cosine_similarity(orig_reps, indep_reps)[:2])


# One DataFrame column per victim model (its list of scores); the transpose
# turns each victim into a row of the printed LaTeX table.
d = dict(
    flickr_mixed_into_qqp=flickr_mixed_into_qqp,
    full_flickr_mixed_into_qqp=full_flickr_mixed_into_qqp,
    qqp_mixed_into_flickr=qqp_mixed_into_flickr,
)

df = pd.DataFrame(d)
latex_table = df.T.to_latex()
print(latex_table)




