from contrastive import ContrastiveNetwork, ContrastiveNetworkMixer, train_contrastive_network, get_embs, seed
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.stats import kendalltau
from tqdm import tqdm
import torch
from einops import rearrange, reduce
import sklearn
import lightgbm as lgb 
from sklearn.ensemble import RandomForestRegressor
from sklearn.datasets import make_regression


# --- Experiment configuration -------------------------------------------------
# All paths are placeholders that must be filled in before running the script.
# "a" and "b" denote the two search spaces; a always comes first wherever the
# two are concatenated (see CombinedData and `params` below).

ea = "" # path to the cached embeddings ("epdjms") for search space a; passed as emb_path_a to CombinedData
eb = "" # path to the cached embeddings ("epdjms") for search space b; passed as emb_path_b to CombinedData

aa = '' # path to the accuracies for search space a; passed as val_path_a to CombinedData
ab = '' # path to the accuracies for search space b; passed as val_path_b to CombinedData

ap = "" # path to the number of parameters in search space a
bp = "" # path to the number of parameters in search space b

# NOTE: `ap` / `bp` are rebound here from path strings to the loaded arrays.
ap = np.load(ap)
bp = np.load(bp)


# Parameter counts for both spaces, a first then b. Later indexed as
# params[:, None, None], so this is expected to be 1-D with one entry per
# architecture -- TODO confirm against the files actually used.
params = np.concatenate((ap, bp))

# `seed` is imported from `contrastive`; presumably it seeds the RNGs used
# below -- confirm in that module.
seed(seed = 42)
# Apply seaborn's default theme to the matplotlib figures produced later.
sns.set()

from torch.utils.data import Dataset
class CachedProjectedJacobians(Dataset):
    """Map-style dataset over a cached ``.npy`` array of projected Jacobians.

    The array is memory-mapped read-only. Each entry ``data[i]`` is read as
    ``[augmentation, jacobian, projection]`` (the first three axes are
    inspected in ``__init__``). Every lookup draws ``num_augs`` distinct
    augmentations at random and keeps the first ``num_jacs`` Jacobians of each.

    Args:
        num_augs: Number of augmentations returned per item; must not exceed
            the number of augmentations stored in the file.
        emb_path: Path to the ``.npy`` file holding the cached embeddings.
        val_path: Path to a ``.npy`` file of target values indexed by item, or
            ``None`` for an unlabeled dataset.
        num_jacs: Number of leading Jacobians to keep; ``-1`` keeps all.

    Raises:
        ValueError: If ``num_augs`` exceeds the stored augmentation count.
    """

    def __init__(self, num_augs = 2, emb_path = "data/imageNet_128_128_4_10_False_False_301_split.npy", val_path = "./Bench301/pred_split.npy", num_jacs = -1):
        # mmap_mode='r' keeps the (potentially large) array on disk.
        self.data = np.load(emb_path, mmap_mode='r')

        first_shape = self.data[0].shape
        self.data_augs = first_shape[0]
        self.proj_size = first_shape[2]
        self.num_jacs = num_jacs if num_jacs != -1 else first_shape[1]

        self.num_augs = num_augs
        self.val = np.load(val_path) if val_path is not None else None

        # Explicit check instead of `assert`, which is stripped under -O.
        # Sampling without replacement in __getitem__ needs this to hold.
        if self.num_augs > self.data_augs:
            raise ValueError(
                f"num_augs={self.num_augs} exceeds the {self.data_augs} "
                f"augmentations stored in {emb_path!r}"
            )

    def __len__(self):
        """Return the number of items in the cached array."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``num_augs`` randomly chosen augmentations of item ``index``.

        Returns:
            The sampled array alone when no values were loaded, otherwise a
            pair ``(array, values)`` where ``values`` is a list repeating the
            item's value once per returned augmentation.
        """
        chosen = np.random.choice(self.data_augs, self.num_augs, replace=False)
        # Fancy indexing copies out of the read-only memmap, so `e` is writable.
        e = self.data[index][chosen, :self.num_jacs]

        # A single NaN anywhere invalidates the whole sample: it is zeroed
        # entirely, not just at the NaN positions.
        if np.isnan(e).any():
            e.fill(0)

        # NOTE: L2-normalising `e` along its last axis was tried here and is
        # currently disabled.

        if self.val is None:
            return e
        return e, [self.val[index]] * self.num_augs


class CombinedData():
    """Two ``CachedProjectedJacobians`` datasets exposed as one sequence.

    Items ``0 .. len(a) - 1`` come from dataset ``a``; the remaining indices
    map onto dataset ``b``. Dataset ``b`` is built with ``a``'s ``num_jacs`` so
    both are sliced to the same number of Jacobians.

    Args:
        num_augs: Augmentations returned per item (forwarded to both datasets).
        emb_path_a: Embedding file for dataset ``a``.
        val_path_a: Value file for dataset ``a`` or ``None``.
        emb_path_b: Embedding file for dataset ``b``.
        val_path_b: Value file for dataset ``b`` or ``None``.
    """

    def __init__(self, num_augs = 2, emb_path_a = "data/imageNet_128_128_4_10_False_False_301.npy", val_path_a = "./Bench301/rt.npy", emb_path_b = "data/imageNet_128_128_4_10_False_False_201.npy", 
    val_path_b = './Bench201/data/cifar10_test_accs.npy'):
        self.a = CachedProjectedJacobians(num_augs, emb_path_a, val_path_a)
        self.b = CachedProjectedJacobians(num_augs, emb_path_b, val_path_b, num_jacs = self.a.num_jacs)

    def __len__(self):
        """Return the combined number of items in both datasets."""
        return len(self.a) + len(self.b)

    def __getitem__(self, index):
        """Return item ``index`` from ``a`` or ``b``.

        Negative indices count from the end of the combined sequence.

        Raises:
            IndexError: If ``index`` falls outside the combined range.
        """
        total = len(self)
        position = index + total if index < 0 else index
        if not 0 <= position < total:
            raise IndexError(f"index {index} out of range for {total} items")

        boundary = len(self.a)
        if position < boundary:
            return self.a[position]
        return self.b[position - boundary]



# --- Contrastive pre-training -------------------------------------------------
# Unlabeled view of both search spaces: with val_path_* = None every item is a
# bare array (no value list), which is why `.shape` can be read off data_set[0].
data_set = CombinedData(num_augs = 2, 
    emb_path_a = ea,  val_path_a = None, 
    emb_path_b = eb, val_path_b = None)

# The network is sized from the last two axes of one sample. Each data_set[0]
# call draws a fresh random set of augmentations (and so advances NumPy's
# global RNG), but the shape is the same either way.
net = ContrastiveNetworkMixer(data_set[0].shape[-1], emb_size = 512, projection_head_out_size=1024, channels = data_set[0].shape[-2])
# Requires a CUDA device; there is no CPU fallback.
net.cuda()

# `barlow = True` presumably selects a Barlow-Twins-style objective -- confirm
# in `contrastive.train_contrastive_network`.
train_contrastive_network(net, data_set, batch_size=512, epochs=30, lr = 0.5e-3, barlow = True)


# Labeled view with 4 augmentations per item; used further down to read the
# sizes of the two search spaces (get_embs_ builds its own identical copy).
data_set = CombinedData(num_augs = 4, 
    emb_path_a = ea, val_path_a = aa, 
    emb_path_b = eb, val_path_b = ab)


def get_embs_():
    """Embed every item of both search spaces with the trained network.

    Puts the module-level ``net`` into eval mode, builds a labeled
    ``CombinedData`` with 4 augmentations per item, and runs each item through
    the network on the GPU, one item at a time.

    Returns:
        ``(embs, vals)``: NumPy arrays stacking, per item, the first output of
        ``net`` and the value list returned by the dataset.
    """
    net.eval()
    data_set = CombinedData(num_augs = 4,
        emb_path_a = ea, val_path_a = aa,
        emb_path_b = eb, val_path_b = ab)

    embs = []
    vals = []
    # Pure inference: disable autograd so no graph is built per forward pass.
    with torch.no_grad():
        for c in tqdm(range(len(data_set))):
            emb, val = data_set[c]
            # Only the first network output is kept; the second is discarded.
            a, _ = net(torch.tensor(emb).cuda())
            embs.append(a.detach().cpu().numpy())
            vals.append(val)

    return np.array(embs), np.array(vals)

# Embed every architecture once up front; the surrogates below only slice
# these cached arrays.
all_embs, all_vals = get_embs_()

# Give each (architecture, augmentation) pair its architecture's parameter
# count as one extra trailing feature.
num_archs, num_views = all_embs.shape[:2]
params_broadcasted = np.broadcast_to(params[:, None, None], (num_archs, num_views, 1))

print(all_embs.shape, params_broadcasted.shape)
all_embs = np.concatenate([all_embs, params_broadcasted], axis=-1)

def get_embs(indices, num_augs):
    """Return flattened ``(values, embeddings)`` for the given architectures.

    Takes the first ``num_augs`` augmentations of each architecture listed in
    ``indices`` from the module-level ``all_vals`` / ``all_embs`` and merges
    the architecture and augmentation axes into a single leading axis.

    NOTE: this definition shadows the ``get_embs`` imported from
    ``contrastive`` at the top of the file.
    """
    merge_leading = "b augs ... -> (b augs) ..."
    picked_vals = all_vals[indices, :num_augs]
    picked_embs = all_embs[indices, :num_augs]
    return rearrange(picked_vals, merge_leading), rearrange(picked_embs, merge_leading)

# Both search spaces share one index range in all_embs / all_vals:
# space a occupies [0, a_len) and space b follows directly after it.
a_len, b_len = len(data_set.a), len(data_set.b)
a_start, b_start = 0, a_len


def fit_surrogate(indices, num_augs = 4, method = 'bo'):
    """Fit a regression surrogate on the embeddings of ``indices``.

    Args:
        indices: Architecture indices into the cached embeddings/values.
        num_augs: Number of augmentations per architecture used for training.
        method: ``'rf'`` for a scikit-learn random forest or ``'lgb'`` for a
            LightGBM regressor. The default ``'bo'`` is not implemented here,
            so callers must pass one of the supported values.

    Returns:
        The fitted ``RandomForestRegressor`` or LightGBM booster.

    Raises:
        ValueError: If ``method`` is not ``'rf'`` or ``'lgb'``.
    """
    # Reject unsupported methods before doing any work. The previous
    # `assert(False)` vanished under -O and silently returned None.
    if method not in ('rf', 'lgb'):
        raise ValueError(
            f"unsupported surrogate method {method!r}; expected 'rf' or 'lgb'"
        )

    accs, embs = get_embs(indices, num_augs = num_augs)

    if method == 'rf':
        # max_features is a tunable hyper-parameter; a small value keeps
        # fitting fast.
        rf = RandomForestRegressor(max_features = 8)
        rf.fit(embs, accs)
        return rf

    return lgb.train({'objective': 'regression', 'verbosity':-1}, lgb.Dataset(embs, label=accs))

def predict_surrogate(surrogate, indices, num_augs = 4, method = 'bo'):
    """Predict values for ``indices`` and average over augmentations.

    Args:
        surrogate: Fitted model exposing ``predict`` (as returned by
            ``fit_surrogate``).
        indices: Architecture indices into the cached embeddings/values.
        num_augs: Number of augmentations per architecture to average over.
        method: ``'rf'`` or ``'lgb'``; both are served by
            ``surrogate.predict``. The default ``'bo'`` is not implemented.

    Returns:
        ``(predicted, actual)``: per-architecture means, over the
        augmentations, of the predictions and of the stored values.

    Raises:
        ValueError: If ``method`` is not ``'rf'`` or ``'lgb'``.
    """
    # Explicit error instead of `assert(False)`, which is stripped under -O.
    if method not in ('rf', 'lgb'):
        raise ValueError(
            f"unsupported surrogate method {method!r}; expected 'rf' or 'lgb'"
        )

    accs, embs = get_embs(indices, num_augs = num_augs)
    predicted = surrogate.predict(embs)

    # Undo the (architecture, augmentation) flattening done by get_embs and
    # average each architecture's augmentations.
    mean_predicted = reduce(predicted,  "(b augs)-> b", 'mean', augs = num_augs)
    mean_accs = reduce(accs,  "(b augs)-> b", 'mean', augs = num_augs)
    return mean_predicted, mean_accs

print(all_embs.shape)
print(all_vals.shape)

# --- Transfer experiment --------------------------------------------------------
# For each run, draw disjoint train/validation architectures from both search
# spaces, fit a random-forest surrogate on each space and evaluate it on both,
# recording Pearson correlation and Kendall's tau for every (train, val) pair.
import os  # local import: only needed to create the output directory

num_runs = 10
num_train = 5000
num_val = 3000
na = 4  # augmentations per architecture used for fitting and prediction

# Labels used in log lines and output file names only.
# NOTE(review): CombinedData's default paths mention Bench301/Bench201 -- confirm
# these labels match the search spaces actually configured via ea/eb.
names = ["101", "201"]

out_dir = "figures_paper/transfer/"
# Create the output directory up front so a missing folder cannot make the
# first savefig fail after the surrogates have already been fitted.
os.makedirs(out_dir, exist_ok=True)

# Axes: (training space, validation space, run).
corrs = np.zeros((2, 2, num_runs))
taus  = np.zeros((2, 2, num_runs))

for run in range(num_runs):
    # Sampling without replacement keeps train and validation disjoint; each
    # space must hold at least num_train + num_val architectures.
    indices = np.random.choice(a_len, num_train+num_val, replace = False)
    a_train = indices[:num_train]+a_start
    a_val = indices[num_train:]+a_start

    indices = np.random.choice(b_len, num_train+num_val, replace = False)
    b_train = indices[:num_train]+b_start
    b_val = indices[num_train:]+b_start

    train = [a_train, b_train]
    val  = [a_val, b_val]

    for i, train_space in enumerate(names):
        for j, val_space in enumerate(names):
            m = fit_surrogate(train[i], num_augs = na, method = 'rf')
            predicted, accs = predict_surrogate(m, val[j], num_augs = na, method = 'rf')
            tau = kendalltau(predicted, accs)[0]
            corr = np.corrcoef(predicted, accs)[0,1]
            print(train_space + "--" + val_space," ", corr, tau)
            corrs[i,j, run] = corr
            taus [i,j, run] = tau

            # Scatter plots and raw points are saved for the first run only.
            if run == 0:
                stem = out_dir + train_space + "--" + val_space + "_barlow_5000_new"
                plt.figure(figsize=(5,5))
                plt.scatter(accs,predicted , s=1)
                plt.xlabel("Accuracy", fontsize=14)
                plt.ylabel("Predicted Accuracy", fontsize=14)
                plt.xticks(fontsize=14)
                plt.yticks(fontsize=14)
                plt.tight_layout()
                plt.savefig(stem + ".pdf")
                # Release the figure so open figures do not accumulate.
                plt.close()
                np.save(stem + ".npy", (accs, predicted))

    print("\n")

# Mean over runs for every (train, val) pair, then persist the per-run metrics.
print(corrs.mean(-1), taus.mean(-1))
np.save("figures_paper/transfer/metrics_barlow_5000_new.npy", [corrs, taus])
