import torch
from contrastive import ContrastiveNetwork,ContrastiveNetworkMixer, train_contrastive_network, get_embs


path =  "data/201_epdm.npy" # cached .npy array ("epdm" data); memory-mapped read-only by CachedProjectedJacobians for both training and embedding below
from torch.utils.data import Dataset
import numpy as np
from einops import rearrange

class CachedProjectedJacobians(Dataset):
    """Dataset over a cached ``.npy`` array that yields random augmentation subsets.

    The array at ``path`` is memory-mapped read-only. Axis 0 indexes items and
    axis 1 indexes stored augmentations. Each ``__getitem__`` draws
    ``num_augs`` distinct augmentations of the requested item at random, using
    the global NumPy RNG. It returns them with all size-1 axes squeezed away.
    An item that contains any NaN is replaced by an all-zero array of the same
    shape and dtype.

    Args:
        num_augs: number of distinct augmentations drawn per item.
        path: path to the ``.npy`` file. It must be at least 4-dimensional;
            axis 3 is recorded as ``proj_size``. Presumably the layout is
            ``(items, augs, channels, proj)`` — TODO confirm.
        valid_indices: optional sequence of integer item indices to expose.
            When given, the dataset has ``len(valid_indices)`` items and
            index ``k`` maps to ``data[valid_indices[k]]``. ``None`` (default)
            exposes every item.

    Raises:
        ValueError: if ``num_augs`` exceeds the number of stored augmentations.
    """

    def __init__(self, num_augs=2, path="../Bench201/cifar100_64_256_4_100.npy", valid_indices=None):
        # Read-only memory map: items are paged in from disk on access
        # instead of loading the whole array into RAM.
        self.data = np.load(path, mmap_mode="r")

        # Same values as data[0].shape[0] / data[0].shape[2], read from the
        # header without touching item 0.
        self.data_augs = self.data.shape[1]
        self.proj_size = self.data.shape[3]
        self.num_augs = num_augs
        if num_augs > self.data_augs:
            # A real exception rather than `assert`, so the check survives
            # `python -O`. Sampling without replacement would fail anyway.
            raise ValueError(
                f"num_augs={num_augs} exceeds the {self.data_augs} "
                f"augmentations stored in {path!r}"
            )

        self.valid_indices = (
            None if valid_indices is None else np.asarray(valid_indices, dtype=np.intp)
        )

    def __len__(self):
        if self.valid_indices is None:
            return len(self.data)
        return len(self.valid_indices)

    def __getitem__(self, index):
        if self.valid_indices is not None:
            index = self.valid_indices[index]

        # Read the item once; it serves both the NaN check and the selection.
        sample = self.data[index]

        # Draw before branching so that exactly one RNG draw happens per call
        # on both paths.
        # NOTE(review): this uses the global NumPy RNG. With forked DataLoader
        # workers, every worker may start from the same state unless a
        # worker_init_fn reseeds it — verify in train_contrastive_network.
        picks = np.random.choice(self.data_augs, self.num_augs, replace=False)

        if np.isnan(sample).any():
            # Corrupt item: return zeros shaped like a normal selection
            # without reading any more data from disk.
            ret = np.zeros((self.num_augs,) + sample.shape[1:], dtype=sample.dtype)
        else:
            ret = sample[picks]

        return ret.squeeze()

# Script section: train the contrastive network on the cached data, then
# export embeddings.
import os

EMB_PATH = "embs/embs_201_[0.87-1.1428].npy"

# Create the output directory before training. A missing `embs/` folder
# would otherwise make np.save fail only after all training epochs finish.
_emb_dir = os.path.dirname(EMB_PATH)
if _emb_dir:
    os.makedirs(_emb_dir, exist_ok=True)

# Training set: 2 random augmentations per item (presumably the two views
# consumed by the contrastive loss — confirm in train_contrastive_network).
data_set = CachedProjectedJacobians(path=path, num_augs=2)

# Draw one sample once; its last two axes size the network input.
sample = data_set[0]
print(sample.shape)

net = ContrastiveNetworkMixer(
    sample.shape[-1],
    emb_size=512,
    projection_head_out_size=1024,
    channels=sample.shape[-2],
)
net.cuda()

train_contrastive_network(net, data_set, batch_size=512, epochs=30, barlow=True, lr=1e-3, val=False)

# Embedding pass draws 4 augmentations per item.
data_set = CachedProjectedJacobians(path=path, num_augs=4)
embs = get_embs(net, data_set)

np.save(EMB_PATH, embs)
