import os
import random
import time
from random import shuffle

import numpy as np
import torch
import tqdm

from k_centers_greedy_acquisition import KCentersGreedyAcquisition
from low_rank_greedy_acquisition import LowRankGreedyAcquisition

# ---------------------------------------------------------------------------
# Experiment configuration.
# ---------------------------------------------------------------------------
# Seed handed to the numpy and torch RNGs further down in this script.
seed = 10
# Passed to LowRankGreedyAcquisition; also scales the identity matrix that is
# added to the kernel sub-block before it is inverted explicitly in the
# evaluation loop (i.e. it acts as a ridge term there).
lambd = 0.0
# Passed to LowRankGreedyAcquisition only; its meaning is defined there.
sigma = 0.0
# Passed to LowRankGreedyAcquisition only; its meaning is defined there.
theta_norm_guess = 1
# Passed to both the low-rank and the k-centers acquisition constructors.
sample_size = 2000
# Passed to LowRankGreedyAcquisition and used to scale batch_size below.
n_gpus = 3
# Passed to LowRankGreedyAcquisition (presumably 110 items per GPU -- confirm).
batch_size = 110 * n_gpus
# Upper bound of the acquisition loop's progress range.
n_acquisitions = 2000
# Strategy selector; the values handled below are 'low_rank', 'k_centers'
# and 'random'.
acquisition_type = 'low_rank'  # 'k_centers' # 'random'

# Precomputed kernel matrices, opened as read-only memory maps so they are
# not read into RAM at load time.
data_root = f"{os.environ['HOME']}/data/"
# Square-indexed below as k_dd[:, xi][xi, :] and converted to the k_vv tensor;
# presumably the train/train kernel -- TODO confirm.
k_dd = np.load(f'{data_root}/k_dd.2020-08-23_15:27:15.ntk.npy', mmap_mode='r')
# NOTE(review): the variable is named k_dt but the file is named k_td. It is
# used as k_dt[:, xi] with xi indexing acquired points, so columns presumably
# correspond to training points and rows to test points -- confirm.
k_dt = np.load(f'{data_root}/k_td.2020-08-24_02:40:42.ntk.npy', mmap_mode='r')


def timestr():
    """Return the current UTC time formatted as ``YYYY-MM-DD_HH:MM:SS``."""
    utc_now = time.gmtime()
    return time.strftime("%Y-%m-%d_%H:%M:%S", utc_now)


# Only the last two arrays stored in the archive (the train and test labels)
# are used by this script; the first two are discarded (presumably the input
# images -- confirm against whatever wrote the archive). The archive is opened
# in a `with` block so its file handle is closed as soon as the arrays have
# been read.
with np.load(f'{data_root}/mnist.2020-08-25_11:35:50.npz') as mnist_archive:
    _, _, y_train, y_test = mnist_archive.values()

# Seed every RNG this script draws from. The stdlib `random` module must be
# seeded too: RandomAcquisition shuffles with `random.shuffle`, which is not
# affected by the numpy or torch seeds.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


class RandomAcquisition:
    """Baseline acquisition that picks indices at random without replacement.

    The candidate indices ``0 .. n-1`` are shuffled once at construction with
    the stdlib ``random.shuffle``; the order is therefore governed by the
    state of the stdlib ``random`` module, not by numpy's or torch's RNG.

    Args:
        n: Number of candidate indices to draw from.
    """

    def __init__(self, n):
        self._shuffled = list(range(n))
        shuffle(self._shuffled)
        # Indices acquired so far, in acquisition order.
        self._xi = []

    def next(self):
        """Acquire one more index and return it.

        Returns:
            The newly acquired index (also appended to ``xi``).

        Raises:
            IndexError: If all ``n`` candidates have already been acquired.
        """
        chosen = self._shuffled.pop()
        self._xi.append(chosen)
        return chosen

    @property
    def xi(self):
        """List of acquired indices, oldest first (the live internal list)."""
        return self._xi


# Kernel matrix as a float64 tensor. Passing `dtype` directly avoids creating
# an intermediate tensor in the source dtype before converting.
k_vv = torch.tensor(k_dd, dtype=torch.float64)

# Build the acquisition strategy selected by `acquisition_type`.
if acquisition_type == 'low_rank':
    acquisition = LowRankGreedyAcquisition(k_vv, lambd, sigma, theta_norm_guess, sample_size, batch_size, n_gpus)
elif acquisition_type == 'random':
    acquisition = RandomAcquisition(k_dd.shape[0])
elif acquisition_type == 'k_centers':
    acquisition = KCentersGreedyAcquisition(k_vv, sample_size)
else:
    # Fail here with a clear message instead of a NameError on `acquisition`
    # further down the script.
    raise ValueError(
        f"Unknown acquisition_type {acquisition_type!r}; "
        "expected 'low_rank', 'random' or 'k_centers'"
    )

# Acquire points one at a time; after each acquisition, fit kernel regression
# on the acquired points, score it on the test labels, and append one
# "<new index>,<accuracy>" row to a timestamped CSV file.
acc = None
result_root = f'{os.environ["HOME"]}/tmp/'
# Create the output directory up front so the run does not die on open().
os.makedirs(result_root, exist_ok=True)
# The `with` block guarantees the CSV is closed even if an iteration raises.
with open(f'{result_root}/mnist_lenet_oed.' + timestr() + '.csv', 'w') as result_file:
    # NOTE(review): the header has a space after the comma while data rows do
    # not; left unchanged in case downstream readers depend on it.
    result_file.write('xi, acc\n')
    # Start from however many points the strategy already holds.
    trange = tqdm.trange(len(acquisition.xi), n_acquisitions)
    for i in trange:
        # Shows the accuracy measured at the end of the previous iteration.
        trange.set_description(f'acc={acc}')
        acquisition.next()
        xi = acquisition.xi
        if hasattr(acquisition, 'k_inv'):
            # The strategy maintains its own inverse of the acquired block.
            k_inv = acquisition.k_inv
        else:
            # Gather rows first, then columns: the resulting sub-matrix is the
            # same as k_dd[:, xi][xi, :], but only len(xi) rows of the memmap
            # are touched instead of every row.
            k_xx = torch.tensor(k_dd[xi, :][:, xi])
            k_inv = (k_xx + torch.eye(len(xi)) * lambd).inverse()
        # Multiply right-to-left: k_inv @ y is (m x m)(m x c), which is far
        # cheaper than first forming the (n_test x m)(m x m) product.
        dual_coef = k_inv.numpy() @ y_train[xi, :]
        y_hat = k_dt[:, xi] @ dual_coef
        acc = (y_hat.argmax(axis=1) == y_test.argmax(axis=1)).mean()
        new_ind = xi[-1]
        result_file.write(f'{new_ind},{acc}\n')
        # Flush every row so partial results survive an interrupted run.
        result_file.flush()
        # Disabled checkpointing, kept for reference:
        # if i % 100 == 0:
        #     acquisition._k_vv = None
        #     torch.save(acquisition, f'{result_root}/checkpoint.{i}.'+timestr())
        #     acquisition._k_vv = k_vv
