import numpy as np
from sklearn.preprocessing import normalize as norm_rep
from numpy import dot
from numpy import linalg as LA
import os
import torch
import pandas as pd

# Model family whose saved representations are compared.
MODE = "roberta" # "tiny", "bert", "roberta"

# NOTE(review): every REPRESENTATIONS_*_ROOT assignment below had no value
# (a bare `NAME =`), which is a syntax error and made the whole module
# unimportable. The real directories are not recoverable from this file, so
# empty-string placeholders are used: with an empty root,
# os.path.join(root, file_name) resolves to file_name relative to the current
# working directory. TODO: set the actual directories before running.
#
# REPRESENTATIONS_VICTIM_ROOT      -> directory holding "<MODE>_victim_*.pt"
# REPRESENTATIONS_INDEPENDENT_ROOT -> directory holding "<MODE>_independent_*.pt"
# REPRESENTATIONS_STOLEN_ROOT      -> directory holding "<MODE>_stolen_*.pt"
# APPEND is a suffix inserted before ".pt" when the victim representations are
# reloaded for the comparison against independent models.
if MODE == "tiny":
    REPRESENTATIONS_VICTIM_ROOT = ""
    REPRESENTATIONS_INDEPENDENT_ROOT = ""
    REPRESENTATIONS_STOLEN_ROOT = ""
    APPEND = ""

elif MODE == "bert":
    REPRESENTATIONS_VICTIM_ROOT = ""
    REPRESENTATIONS_INDEPENDENT_ROOT = ""
    REPRESENTATIONS_STOLEN_ROOT = ""
    APPEND = "_ours"

elif MODE == "roberta":
    REPRESENTATIONS_VICTIM_ROOT = ""
    REPRESENTATIONS_INDEPENDENT_ROOT = ""
    REPRESENTATIONS_STOLEN_ROOT = ""
    APPEND = "_ours"

else:
    # Fail at import time on an unknown MODE rather than later on a missing file.
    raise Exception("No such models.")



def cosine_similarity(x, y, normalize=True, center=True):
    """
    Calculate the row-wise cosine similarity between two representations.

    Row ``i`` of ``x`` is compared with row ``i`` of ``y``. The similarity is
    signed: no absolute value is taken.
    ----------
    x, y: (n, d) ndarray
        n samples from d-dimensional representations (shapes must match)
    normalize: bool
        normalize each row with ``norm_rep`` (sklearn's ``normalize``)
    center: bool
        standardize each row across its d features (mean 0, var 1);
        a row with zero variance becomes all zeros instead of NaN
    Returns:
    --------
    mean: float
        mean of the per-row similarities, rounded to 2 decimals
    std: float
        standard deviation of the per-row similarities, rounded to 2 decimals
    sims: (n,) ndarray
        unrounded similarity of every row pair; 0.0 where either row has
        zero norm (the cosine is undefined there)
    """

    assert x.shape == y.shape

    def _standardize(rows):
        # Per-row (per-sample) statistics, kept as (n, 1) columns for broadcasting.
        row_mean = np.mean(rows, axis=1, keepdims=True)
        row_std = np.std(rows, axis=1, keepdims=True)
        # A constant row has std == 0; dividing by it would produce NaN, which
        # then poisons the mean/std (and is rejected by sklearn's normalize).
        # Dividing by 1 instead leaves such a row as all zeros.
        return (rows - row_mean) / np.where(row_std == 0, 1.0, row_std)

    if center:
        x = _standardize(x)
        y = _standardize(y)
    if normalize:
        x = norm_rep(x)
        y = norm_rep(y)

    # Vectorized similarity between corresponding rows.
    numerators = np.sum(x * y, axis=1)
    denominators = np.linalg.norm(x, axis=1) * np.linalg.norm(y, axis=1)
    # Zero-norm rows have no direction; report 0.0 rather than 0/0 = NaN.
    degenerate = denominators == 0
    sims = np.where(
        degenerate,
        0.0,
        numerators / np.where(degenerate, 1.0, denominators),
    )
    return np.round(sims.mean(), 2), np.round(sims.std(), 2), sims

def _load_representations(root, file_name):
    """Load the tensor saved at ``root/file_name`` and return it as a NumPy array."""
    # map_location="cpu": the tensor is converted to NumPy immediately, so it
    # never needs to live on a GPU, and files saved from CUDA tensors remain
    # loadable on CPU-only machines.
    reps = torch.load(os.path.join(root, file_name), map_location="cpu")
    return reps.cpu().detach().numpy()


# Script entry point: for each victim dataset, score the victim representations
# against (1) themselves, (2) each stolen model and (3) each independent model,
# then print one LaTeX table row per victim dataset.
if __name__ == "__main__":
    DATA = ["nli", "qqp", "flickr"]

    if MODE == "tiny":
        VIC_DATA = ["nli", "qqp", "flickr"]
    else:
        VIC_DATA = ["nli"] # the bigger models don't have stolen ones.

    # Informational only: all representations are processed on the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)

    # One list of (mean, std) pairs per victim dataset; each list becomes a
    # row of the final table.
    scores_by_victim = {"nli": [], "qqp": [], "flickr": []}

    for victim_data in VIC_DATA:
        print(f"evaluating for victim: {victim_data}")
        if victim_data not in scores_by_victim:
            raise Exception("No such data")
        our_list = scores_by_victim[victim_data]

        # Baseline: the victim representations compared with themselves.
        orig_reps = _load_representations(
            REPRESENTATIONS_VICTIM_ROOT,
            f"{MODE}_victim_{victim_data}_on_{victim_data}_data.pt",
        )
        our_list.append(cosine_similarity(orig_reps, orig_reps)[:2])

        # Victim vs. each model stolen from it.
        for steal_dataset in DATA:
            print(f"on stolen data: {steal_dataset}")
            stolen_reps = _load_representations(
                REPRESENTATIONS_STOLEN_ROOT,
                f"{MODE}_stolen_{victim_data}_with_{steal_dataset}_on_{victim_data}_data.pt",
            )
            our_list.append(cosine_similarity(orig_reps, stolen_reps)[:2])

        # for all but the tiny bert, we need to replace the representations of their model - with our model
        # The file does not depend on the independent model, so it is loaded
        # once here instead of once per loop iteration.
        if MODE != "tiny":
            orig_reps = _load_representations(
                REPRESENTATIONS_VICTIM_ROOT,
                f"{MODE}_victim_{victim_data}_on_{victim_data}_data{APPEND}.pt",
            )

        # Victim vs. each independently trained model.
        for independent in DATA:
            print(f"on independent models: {independent}")
            indep_reps = _load_representations(
                REPRESENTATIONS_INDEPENDENT_ROOT,
                f"{MODE}_independent_{independent}_on_{victim_data}_data.pt",
            )
            our_list.append(cosine_similarity(orig_reps, indep_reps)[:2])

    # Only the evaluated victim datasets get a row ("tiny": all three,
    # otherwise just "nli"), so every column list has the same length.
    d = {name: scores_by_victim[name] for name in VIC_DATA}
    df = pd.DataFrame(d)
    print(df.T.to_latex())




