# Import stuff

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from collections import Counter

def flatten(grad):
    """Concatenate a sequence of tensors into a single 1-D tensor.

    Args:
        grad: Iterable of tensors, e.g. the tuple returned by
            ``torch.autograd.grad`` (one entry per parameter).

    Returns:
        A 1-D tensor holding every element of every entry, in iteration
        order.
    """
    # reshape(-1) instead of view(-1): view raises on non-contiguous tensors
    # (e.g. a transposed gradient), reshape only copies when it has to.
    return torch.cat([grad_layer.reshape(-1) for grad_layer in grad])

def normalize(preds):
    """Min-max rescale along the last axis, then make each row sum to 1.

    Each row is shifted and scaled to the range [0, 1] and then divided by
    its sum, so the smallest entry of a row always maps to 0.

    A row whose entries are all equal has no spread to rescale; it is mapped
    to the uniform distribution instead of the NaN row that 0/0 would give.

    Args:
        preds: Tensor whose last axis indexes the values to normalize.

    Returns:
        Tensor of the same shape as ``preds`` whose last axis sums to 1.
    """
    min_vals = torch.min(preds, dim=-1, keepdim=True)[0]
    max_vals = torch.max(preds, dim=-1, keepdim=True)[0]
    span = max_vals - min_vals

    # Constant rows: divide by 1 instead of 0, then overwrite with ones so the
    # final division by the row sum yields 1/n for every entry.
    flat_row = span == 0
    safe_span = torch.where(flat_row, torch.ones_like(span), span)
    rescaled = (preds - min_vals) / safe_span
    rescaled = torch.where(flat_row, torch.ones_like(rescaled), rescaled)

    return rescaled / torch.sum(rescaled, dim=-1, keepdim=True)

def compute_entropy(preds):
    """Return the mean entropy of the normalized predictions.

    ``preds`` is first passed through ``normalize``; the entropy is taken
    over the last axis and averaged over all remaining axes.

    Args:
        preds: Tensor of raw predictions; the last axis indexes classes.

    Returns:
        A 0-dim tensor with the mean entropy.
    """
    probs = normalize(preds)
    # The 1e-8 offset keeps the log finite where a probability is exactly 0.
    neg_entropy_per_row = torch.sum(probs * torch.log(probs + 1e-8), dim=-1)
    return -torch.mean(neg_entropy_per_row)

def compute_grad_map(dataset, learner, device, batch_size=1):
    """Compute the flattened parameter gradient of the output for every sample.

    Args:
        dataset: Indexable collection; ``dataset[i][0]`` must be the input
            tensor of sample ``i``.
        learner: Module called on a stacked batch of inputs. Assumes it
            returns one single-element output per sample, since
            ``torch.autograd.grad`` is called without ``grad_outputs``.
        device: Device the input batches are moved to.
        batch_size: Number of samples pushed through ``learner`` at once.

    Returns:
        Dict mapping each dataset index to a detached 1-D CPU tensor holding
        the gradient of that sample's output w.r.t. all learner parameters.
    """
    n_samples = len(dataset)
    random_indices = np.arange(n_samples)
    params = list(learner.parameters())

    grad_map = {}
    for i in range(0, n_samples, batch_size):
        # The last batch may be shorter; use a local so the caller-supplied
        # batch_size is never overwritten.
        batch_ids = random_indices[i:i + min(n_samples - i, batch_size)]
        inputs = torch.stack([dataset[idx][0] for idx in batch_ids]).to(device)

        logits = learner(inputs)
        for j, idx in enumerate(batch_ids):
            # retain_graph: the forward graph is shared by every row of the
            # batch. No create_graph: the result is detached right away, so a
            # double-backward graph would only cost time and memory.
            grad = torch.autograd.grad(logits[j], params, retain_graph=True)
            grad_map[idx] = flatten(grad).detach().cpu()

        print('computeing {}/{} is done.'.format(i, len(dataset)))

    assert len(dataset) == len(grad_map)
    return grad_map

def compute_kernels(grad_map, indices1, indices2, device, batch_size=256):
    """Build the matrix of gradient inner products between two index sets.

    Entry ``(a, b)`` of the result is the dot product of
    ``grad_map[indices1[a]]`` and ``grad_map[indices2[b]]``. The matrix is
    filled block by block so that at most ``batch_size`` gradients per side
    are stacked at a time.

    Args:
        grad_map: Mapping from sample index to a 1-D gradient tensor.
        indices1: Sequence of keys for the rows.
        indices2: Sequence of keys for the columns.
        device: Unused here; the result is allocated with ``torch.zeros``
            on the default device.
        batch_size: Side length of the blocks.

    Returns:
        Tensor of shape ``(len(indices1), len(indices2))``.
    """
    def _stack_gradients(keys, start, stop):
        return torch.stack([grad_map[keys[pos]] for pos in range(start, stop)])

    n_rows, n_cols = len(indices1), len(indices2)
    kernel = torch.zeros((n_rows, n_cols))

    for row_lo in range(0, n_rows, batch_size):
        row_hi = min(row_lo + batch_size, n_rows)
        row_block = _stack_gradients(indices1, row_lo, row_hi)
        for col_lo in range(0, n_cols, batch_size):
            col_hi = min(col_lo + batch_size, n_cols)
            col_block = _stack_gradients(indices2, col_lo, col_hi)
            kernel[row_lo:row_hi, col_lo:col_hi] = torch.mm(row_block, col_block.t())

    return kernel

def init_indices(ways, init_num, adaptation_indices, adaptation_labels):
    """Randomly draw ``init_num`` starting samples spread evenly over classes.

    The classes ``0 .. ways-1`` are visited in a random order. Every class
    gets ``init_num // ways`` samples and the first ``init_num % ways``
    classes in that order get one more. Samples of a class are drawn without
    replacement from the adaptation indices carrying that label.

    Args:
        ways: Number of classes.
        init_num: Total number of samples to draw.
        adaptation_indices: Tensor of sample indices to draw from.
        adaptation_labels: Labels aligned with ``adaptation_indices``.

    Returns:
        NumPy array with the drawn sample indices, grouped by class in the
        random visiting order.
    """
    class_order = np.arange(ways)
    np.random.shuffle(class_order)

    pool = adaptation_indices.detach().cpu().numpy()

    chosen = []
    for position, task_class in enumerate(class_order):
        # Early classes in the shuffled order absorb the remainder.
        quota = int(init_num // ways + (1 if position < init_num % ways else 0))
        picks = np.random.choice(
            pool[adaptation_labels == task_class], size=quota, replace=False)
        chosen.extend(picks.tolist())

    return np.array(chosen)

def collect(candidate_indices, candidate_labels, indices):
    """Look up the labels of the given sample indices.

    Args:
        candidate_indices: Array of sample indices.
        candidate_labels: Array of labels aligned with ``candidate_indices``.
        indices: Iterable of sample indices to look up.

    Returns:
        Flat NumPy array with, for each entry of ``indices`` in order, the
        labels of all candidates whose index equals it.
    """
    matches = [candidate_labels[candidate_indices == wanted] for wanted in indices]
    return np.array(matches).reshape(-1)

def update(candidate_to_index_map, candidate_indices, candidate_labels,
          images_per_class, new_selected_indices, new_train_labels):
    """Account for newly selected samples and retire exhausted candidates.

    For every newly selected sample the remaining quota of its class is
    decremented. If the quota reaches zero, every candidate of that class is
    removed from ``candidate_to_index_map``; otherwise only the selected
    sample itself is removed. Both mappings are modified in place.

    Args:
        candidate_to_index_map: Dict from sample index to candidate position.
        candidate_indices: Array of all candidate sample indices.
        candidate_labels: Labels aligned with ``candidate_indices``.
        images_per_class: Mapping from label to the number of samples still
            to be selected for that class.
        new_selected_indices: Sample indices that were just selected.
        new_train_labels: Labels aligned with ``new_selected_indices``.

    Returns:
        The (mutated) ``candidate_to_index_map`` and ``images_per_class``.
    """
    for i, train_label in enumerate(new_train_labels):
        images_per_class[train_label] -= 1
        assert images_per_class[train_label] >= 0
        if images_per_class[train_label] == 0:
            # Class quota is used up: no candidate of this class may be
            # picked again.
            retired = candidate_indices[candidate_labels == train_label]
        else:
            retired = new_selected_indices[i:i+1]
        for idx in retired:
            # Entries may already have been removed by an earlier call; only
            # a missing key is tolerated, any other error propagates.
            candidate_to_index_map.pop(idx, None)
    return candidate_to_index_map, images_per_class

def extend(k_test_train, k_train_train, Y_train, k_test_cand, k_cand_cand,
           candidate_labels, selected_indices, idx, ways, device, expected_label=None):
    """Grow the train kernels and targets by one candidate sample.

    Args:
        k_test_train: Kernel between test and current train samples.
        k_train_train: Kernel among the current train samples.
        Y_train: Current target matrix, one row per train sample.
        k_test_cand: Kernel between test samples and all candidates.
        k_cand_cand: Kernel among all candidates.
        candidate_labels: Label of every candidate position.
        selected_indices: Candidate positions of the current train samples.
        idx: Candidate position of the sample to add.
        ways: Number of classes (width of the target rows).
        device: Device the new target row is moved to.
        expected_label: Label to encode for the new sample; if ``None`` the
            entry ``candidate_labels[idx]`` is used.

    Returns:
        Tuple ``(new_k_test_train, new_k_train_train, new_Y_train)`` with one
        extra column, one extra row and column, and one extra row
        respectively.
    """
    column = slice(idx, idx + 1)

    # Block layout of the enlarged train kernel:
    #   [ k_train_train   cross    ]
    #   [ cross^T         self_sim ]
    cross = k_cand_cand[selected_indices, column]
    self_sim = k_cand_cand[column, column]
    upper = torch.cat((k_train_train, cross), dim=1)
    lower = torch.cat((cross.T, self_sim), dim=1)
    new_k_train_train = torch.cat((upper, lower), dim=0)

    new_k_test_train = torch.cat((k_test_train, k_test_cand[:, column]), dim=1)

    # Target row: -0.1 everywhere except 0.1 * (ways - 1) at the label, so
    # the row sums to zero.
    label = candidate_labels[idx] if expected_label is None else expected_label
    target_row = np.full((1, ways), -0.1)
    target_row[0][label] = 0.1 * (ways-1)
    new_Y_train = torch.cat((Y_train, torch.from_numpy(target_row).to(device)), dim=0)

    return new_k_test_train, new_k_train_train, new_Y_train


def _kernel_solution_operator(k_train_train, time, device):
    """Return the matrix M used as ``k_x_train @ M @ Y_train`` for predictions.

    With ``K_inv`` the inverse of ``k_train_train + 0.001 * I``:
    ``M = K_inv - K_inv @ expm(-time * k_train_train)`` when
    ``time >= 1e-10``, and ``M = K_inv`` otherwise.
    """
    k_train_train_inv = torch.inverse(
        k_train_train + 0.001*torch.eye(k_train_train.shape[0]).to(device))
    if time >= 1e-10:
        decay = expm_one(-1*time*k_train_train.to('cpu')).to(device)
        return k_train_train_inv - torch.mm(k_train_train_inv, decay)
    return k_train_train_inv


def actively_select_batch(grad_map, dataset, batch, ways, train_shots, test_shots,
                          device, init_num=4, time=10):
    """Greedily re-pick the adaptation samples of a task with kernel regression.

    The batch is laid out as ``ways`` consecutive groups of
    ``train_shots + test_shots`` entries; the first ``train_shots`` entries
    of every group are the adaptation slots, the rest are evaluation samples.
    Starting from ``init_num`` random adaptation samples, candidates returned
    by ``dataset.load_candidates`` are added one at a time until every class
    has as many samples as it had adaptation slots. Each candidate is scored
    by hypothetically adding it with every possible label and counting the
    evaluation samples the kernel predictor then classifies correctly,
    weighted by the candidate's own normalized predicted probability for that
    label. The highest-scoring candidate is added with its real label.

    Args:
        grad_map: Mapping from sample index to flattened gradient tensor.
        dataset: Object providing ``load_candidates(adaptation_indices,
            evaluation_indices, adaptation_labels)`` returning
            ``(candidate_indices, candidate_labels)``, and ``dataset[i]``
            returning ``(data, _)``.
        batch: Tuple ``(data, labels, indices)`` of tensors. The adaptation
            slots of all three are overwritten in place.
        ways: Number of classes in the task.
        train_shots: Adaptation samples per class.
        test_shots: Evaluation samples per class.
        device: Device used for the kernel algebra.
        init_num: Number of randomly drawn starting samples; must be > 0.
        time: Training-time parameter of the kernel predictor. Values below
            1e-10 select the plain regularized inverse.

    Returns:
        Tuple ``(batch, misc)``: the updated batch and a dict with the final
        train kernel, the selected sample indices / candidate positions, the
        train labels, and ``grad_map``.
    """
    data, labels, indices = batch

    # Split the batch into adaptation slots and evaluation samples.
    adaptation_indices_bool = np.zeros(len(indices), dtype=bool)
    selection = np.arange(ways) * (train_shots + test_shots)
    for offset in range(train_shots):
        adaptation_indices_bool[selection + offset] = True
    evaluation_indices_bool = ~adaptation_indices_bool
    adaptation_labels, adaptation_indices = labels[adaptation_indices_bool], indices[adaptation_indices_bool]
    evaluation_labels = labels[evaluation_indices_bool]
    evaluation_indices = indices[evaluation_indices_bool].detach().cpu().numpy()

    # Remaining quota per class; decremented by update() on every selection.
    images_per_class = Counter(adaptation_labels.detach().cpu().numpy().tolist())

    # find all candidate points
    (candidate_indices, candidate_labels) = dataset.load_candidates(
        adaptation_indices, evaluation_indices, adaptation_labels)
    candidate_to_index_map = {idx: i for i, idx in enumerate(candidate_indices)}
    index_to_candidate_map = {i: idx for i, idx in enumerate(candidate_indices)}

    # obtain init indices and corresponding kernels
    assert init_num > 0
    selected_cand_indices = init_indices(ways, init_num, adaptation_indices, adaptation_labels)
    k_train_train = compute_kernels(grad_map, selected_cand_indices, selected_cand_indices, device).to(device)
    k_test_train = compute_kernels(grad_map, evaluation_indices, selected_cand_indices, device).to(device)
    train_labels = collect(candidate_indices, candidate_labels, selected_cand_indices)

    n_test = evaluation_labels.shape[0]
    n_train = train_labels.shape[0]
    # Targets: -0.1 everywhere, 0.1 * (ways - 1) at the true class.
    Y_train = np.ones((n_train, ways)) * -0.1
    for i in range(n_train):
        Y_train[i][train_labels[i]] = 0.1 * (ways-1)
    Y_train = torch.from_numpy(Y_train).to(device)

    # get selected indices (candidate positions) from selected cand indices
    selected_indices = np.array(
        [candidate_to_index_map[idx] for idx in selected_cand_indices])

    # update candidate indices and labels
    candidate_to_index_map, images_per_class = update(
        candidate_to_index_map, candidate_indices, candidate_labels,
        images_per_class, selected_cand_indices, train_labels)

    # compute a kernel for all candidate points
    k_cand_cand = compute_kernels(grad_map, candidate_indices, candidate_indices, device).to(device)
    k_test_cand = compute_kernels(grad_map, evaluation_indices, candidate_indices, device).to(device)

    # for each candidate point, compute its score; stop once all quotas are 0
    while sum(images_per_class.values()) > 0:
        # Retired candidates keep this very low score and are never the argmax
        # as long as at least one live candidate remains.
        uncertainties = -10000*np.ones(k_cand_cand.shape[0])

        # The train kernel is fixed for the whole round, so the (expensive)
        # inverse / matrix exponential is computed once and shared by the
        # test predictions and by every candidate's own prediction.
        solution_op = _kernel_solution_operator(k_train_train, time, device)
        targets = Y_train.float()

        prev_preds = torch.mm(torch.mm(k_test_train, solution_op), targets)
        prev_pred_labels = torch.argmax(prev_preds, dim=1).detach().cpu()
        prev_test_accuracy = torch.sum(prev_pred_labels==evaluation_labels).float() / n_test
        print('With {} samples, test accuracy = {}'.format(Y_train.shape[0], prev_test_accuracy))

        for candidate_idx, idx in candidate_to_index_map.items():
            # Current prediction for the candidate itself, turned into
            # weights over its possible labels.
            k_tmp_train = compute_kernels(grad_map, np.array([candidate_idx]), selected_cand_indices, device).to(device)
            tmp_preds = torch.mm(torch.mm(k_tmp_train, solution_op), targets)
            tmp_probs = normalize(tmp_preds)

            uncertainty = 0.
            readout = None
            for expected_label in range(ways):
                # extend kernels
                new_k_test_train, new_k_train_train, new_Y_train = extend(
                    k_test_train, k_train_train, Y_train, k_test_cand, k_cand_cand,
                    candidate_labels, selected_indices, idx, ways, device, expected_label=expected_label)

                # Only the target row depends on the hypothetical label; the
                # extended kernels are the same for every label, so the
                # test-to-weights matrix is built once per candidate.
                if readout is None:
                    readout = torch.mm(
                        new_k_test_train,
                        _kernel_solution_operator(new_k_train_train, time, device))

                # compute curr_preds
                curr_preds = torch.mm(readout, new_Y_train.float())
                curr_pred_labels = torch.argmax(curr_preds, dim=1).detach().cpu()

                # compute score
                # TODO: put expectation
                #uncertainty = torch.sum(torch.norm(normalize(curr_preds) - normalize(prev_preds), dim=-1))
                #uncertainty += tmp_probs[0][expected_label] * torch.sum(prev_pred_labels != curr_pred_labels)
                uncertainty += tmp_probs[0][expected_label] * torch.sum(curr_pred_labels==evaluation_labels)
                #uncertainty += tmp_probs[0][expected_label] * -1.0 * compute_entropy(curr_preds)

            # store score
            uncertainties[idx] = uncertainty

        # get new index
        selected_idx = np.array([np.argmax(uncertainties)])
        selected_cand_idx = np.array(
            [index_to_candidate_map[idx] for idx in selected_idx])
        selected_train_labels = np.array(
            [candidate_labels[idx] for idx in selected_idx])

        # updated candidate_to_index_map, and images_per_class
        candidate_to_index_map, images_per_class = update(
            candidate_to_index_map, candidate_indices, candidate_labels,
            images_per_class, selected_cand_idx, selected_train_labels)

        # update k_train_train, k_test_train, and Y_train (real label this time)
        k_test_train, k_train_train, Y_train = extend(
                k_test_train, k_train_train, Y_train, k_test_cand, k_cand_cand,
                candidate_labels, selected_indices, selected_idx[0], ways, device)

        # update selected indices, cand indices and train labels
        selected_indices = np.concatenate((selected_indices, selected_idx))
        selected_cand_indices = np.concatenate((selected_cand_indices, selected_cand_idx))
        train_labels = np.concatenate((train_labels, selected_train_labels))

    # get new batch = (data, labels, indices); samples are sorted by label
    # before being written into the adaptation slots
    batch_data = []
    batch_labels = []
    sorting_indices = np.argsort(train_labels)
    for label, cand_idx in zip(
        train_labels[sorting_indices], selected_cand_indices[sorting_indices]):
        data_idx, _ = dataset[cand_idx]
        batch_data.append(data_idx)
        batch_labels.append(label)

    data[adaptation_indices_bool] = torch.stack(batch_data)
    labels[adaptation_indices_bool] = torch.tensor(batch_labels).long()
    indices[adaptation_indices_bool] = torch.from_numpy(selected_cand_indices[sorting_indices])
    batch = (data, labels, indices)

    # Report the accuracy reached with the final selection.
    solution_op = _kernel_solution_operator(k_train_train, time, device)
    prev_preds = torch.mm(torch.mm(k_test_train, solution_op), Y_train.float())
    prev_pred_labels = torch.argmax(prev_preds, dim=1).detach().cpu()
    prev_test_accuracy = torch.sum(prev_pred_labels==evaluation_labels).float() / n_test
    print('With {} samples, test accuracy = {}'.format(Y_train.shape[0], prev_test_accuracy))

    misc = {'k_train_train': k_train_train, 'selected_cand_indices': selected_cand_indices,
            'selected_indices': selected_indices, 'train_labels': train_labels, 'grad_map': grad_map}
    return batch, misc

def compute_grads(net, dataset, device, batch_size=4):
    """Stack the flattened per-sample parameter gradients of ``net``.

    Args:
        net: Module evaluated on slices of ``dataset``; ``logits[j]`` of each
            slice is differentiated w.r.t. all parameters.
        dataset: Tensor of inputs, sliced along its first axis.
        device: Device ``dataset`` is moved to.
        batch_size: Number of inputs per forward pass.

    Returns:
        Tensor with one row per sample. The gradients are taken with
        ``create_graph=True`` and are not detached.
    """
    dataset = dataset.to(device)

    per_sample = []
    for start in range(0, dataset.shape[0], batch_size):
        chunk = dataset[start:start + batch_size]
        logits = net(chunk)
        # One backward pass per row; the shared forward graph must survive.
        per_sample.extend(
            flatten(torch.autograd.grad(
                logits[row], net.parameters(), retain_graph=True, create_graph=True))
            for row in range(chunk.shape[0])
        )

    return torch.stack(per_sample)


def kernel_matrices(net, trainset, testset, device, kernels='both', batch_size=4):
    """Compute gradient-inner-product kernels for a train and a test set.

    Args:
        net: Module whose per-sample parameter gradients define the kernel.
        trainset: Tensor of training inputs.
        testset: Tensor of test inputs; only used when the test-vs-train
            kernel is requested.
        device: Device the inputs are moved to.
        kernels: ``'both'``, ``'trainvtrain'`` or ``'testvtrain'``.
        batch_size: Number of inputs per forward pass.

    Returns:
        ``(K_testvtrain, K_trainvtrain)`` for ``'both'``, otherwise the single
        requested matrix.

    Raises:
        ValueError: If ``kernels`` is not one of the three accepted values.
    """
    # Fail before any gradient is computed instead of silently returning None.
    if kernels not in ('both', 'trainvtrain', 'testvtrain'):
        raise ValueError(
            "kernels must be 'both', 'trainvtrain' or 'testvtrain', got {!r}".format(kernels))

    train_grads = compute_grads(net, trainset, device, batch_size=batch_size)
    if kernels == 'trainvtrain':
        return torch.mm(train_grads, train_grads.T)

    test_grads = compute_grads(net, testset, device, batch_size=batch_size)
    K_testvtrain = torch.mm(test_grads, train_grads.T)
    if kernels == 'testvtrain':
        # The train-vs-train product is not needed for this request.
        return K_testvtrain

    K_trainvtrain = torch.mm(train_grads, train_grads.T)
    return K_testvtrain, K_trainvtrain

#def kernel_mats(net, gamma_train, gamma_test, device, kernels='both', batch_size=4):
#    n_pts = len(gamma_test)
#    n = len(gamma_train)
#    # the following computes the gradients with respect to all parameters
#    grad_list = []
#    for i in range(0, n, batch_size):
#        gamma = gamma_train[i:i+batch_size].to(device)
#        batch = gamma.shape[0]
#        loss = net(gamma)
#        for j in range(batch):
#            grad = torch.autograd.grad(loss[j], net.parameters(), retain_graph = True, create_graph = True )
#            #if batch <= 1:
#            #    grad = torch.autograd.grad(loss.squeeze(), net.parameters(), retain_graph = True, create_graph = True )
#            #else:
#            #    grad = torch.autograd.grad(loss.squeeze(), net.parameters(), basis_vectors[j], retain_graph = True, create_graph = True )
#            ##grad = flatten(grad)
#            grad_list.append(grad)
#
#    # testvstrain kernel
#    if kernels=='both' or kernels=='testvtrain':
#        K_testvtrain = torch.zeros((n_pts,n))
#        for i, gamma in enumerate(gamma_test):
#            gamma = gamma.to(device)
#            gamma = gamma.unsqueeze(0)
#            loss = net(gamma)
#            grads = torch.autograd.grad(loss,net.parameters(), retain_graph = True, create_graph = True ) # extract NN gradients
#
#            for j in range(len(grad_list)):
#                pt_grad = grad_list[j] # the gradients at the jth (out of 4) data point
#                #K_testvtrain[i, j] = sum([torch.sum(torch.mul(grads[u], pt_grad[u])) for u in range(len(grads))])
#                K_testvtrain[i, j] = torch.dot(flatten(grads), flatten(pt_grad))
#
#    # trainvstrain kernel
#    if kernels=='both' or kernels=='trainvtrain':
#        K_trainvtrain = torch.zeros((n, n))
#        for i in range(n):
#            grad_i = grad_list[i]
#            for j in range(i+1):
#                grad_j = grad_list[j]
#                #K_trainvtrain[i, j] = sum([torch.sum(torch.mul(grad_i[u], grad_j[u])) for u in range(len(grad_j))])
#                K_trainvtrain[i, j] = torch.dot(flatten(grad_i), flatten(grad_j))
#                K_trainvtrain[j, i] = K_trainvtrain[i, j]
#
#    if kernels=='both':
#        return K_testvtrain, K_trainvtrain
#    elif kernels=='trainvtrain':
#        return K_trainvtrain
#    elif kernels=='testvtrain':
#        return K_testvtrain

def kernel_mats(net, gamma_train, gamma_test, device, kernels='both', batch_size=4):
    """Compute gradient-inner-product kernels one sample at a time.

    Every sample is passed through ``net`` on its own (with a leading batch
    axis of size 1) and the output is differentiated w.r.t. all parameters.
    A kernel entry is the sum over parameters of the element-wise product of
    two samples' gradients.

    Args:
        net: Module whose parameter gradients define the kernel.
        gamma_train: Iterable of training inputs (supports ``len``).
        gamma_test: Iterable of test inputs (supports ``len``).
        device: Device each input is moved to.
        kernels: ``'both'`` (default), ``'testvtrain'`` or ``'trainvtrain'``.
        batch_size: Unused; kept for signature compatibility.

    Returns:
        ``(K_testvtrain, K_trainvtrain)`` for ``'both'``, otherwise the single
        requested matrix. ``K_testvtrain`` is ``n_test x n_train`` and
        ``K_trainvtrain`` is ``n_train x n_train``.

    Raises:
        ValueError: If ``kernels`` is not one of the three accepted values.
    """
    # Fail before any gradient is computed instead of silently returning None.
    if kernels not in ('both', 'trainvtrain', 'testvtrain'):
        raise ValueError(
            "kernels must be 'both', 'trainvtrain' or 'testvtrain', got {!r}".format(kernels))

    def _param_grads(sample):
        # Gradients of the single-sample output w.r.t. every parameter.
        output = net(sample.to(device).unsqueeze(0))
        return torch.autograd.grad(
            output, net.parameters(), retain_graph=True, create_graph=True)

    def _grad_dot(grads_a, grads_b):
        # Inner product of two per-parameter gradient tuples.
        return sum([torch.sum(torch.mul(g_a, g_b)) for g_a, g_b in zip(grads_a, grads_b)])

    n_train = len(gamma_train)
    n_pts = len(gamma_test)

    grad_list = [_param_grads(gamma) for gamma in gamma_train]

    # testvstrain kernel
    if kernels == 'both' or kernels == 'testvtrain':
        K_testvtrain = torch.zeros((n_pts, n_train))
        for i, gamma in enumerate(gamma_test):
            test_grads = _param_grads(gamma)
            for j, train_grads in enumerate(grad_list):
                K_testvtrain[i, j] = _grad_dot(test_grads, train_grads)

    # trainvstrain kernel (symmetric: fill the lower triangle and mirror it)
    if kernels == 'both' or kernels == 'trainvtrain':
        K_trainvtrain = torch.zeros((n_train, n_train))
        for i in range(n_train):
            for j in range(i+1):
                K_trainvtrain[i, j] = _grad_dot(grad_list[i], grad_list[j])
                K_trainvtrain[j, i] = K_trainvtrain[i, j]

    if kernels == 'both':
        return K_testvtrain, K_trainvtrain
    if kernels == 'trainvtrain':
        return K_trainvtrain
    return K_testvtrain

# def kernel_mats(net, gamma_train, gamma_test, device, n_train, kernels='both'):
#     # for a given net, this function computes the K_testvtrain (n_test by n_train) and the
#     # K_trainvtrain (n_train by n_train) kernels.
#     # You can choose which one to return by the parameter 'kernels', with values 'both' (default), 'testvtrain' or 'trainvtrain'
#
#     # suppose cuda available
#     n_pts = len(gamma_test)
#     # the following computes the gradients with respect to all parameters
#     grad_list = []
#     y_train = net(gamma_train.to(device))
#     y_test = net(gamma_test.to(device))
#
#     # testvstrain kernel
#     if kernels == 'both' or kernels == 'testvtrain':
#         # K_testvtrain = torch.zeros((n_pts, n_train))
#         K_testvtrain = torch.exp(-torch.sum(torch.square(torch.unsqueeze(y_test, -2) - torch.unsqueeze(y_train, -3)), -1))
#
#     # trainvstrain kernel
#     if kernels == 'both' or kernels == 'trainvtrain':
#         K_trainvtrain = torch.exp(-torch.sum(torch.square(torch.unsqueeze(y_train, -2) - torch.unsqueeze(y_train, -3)), -1))
#
#
#     if kernels == 'both':
#         return K_testvtrain, K_trainvtrain
#     elif kernels == 'trainvtrain':
#         return K_trainvtrain
#     elif kernels == 'testvtrain':
#         return K_testvtrain

@torch.jit.script
def torch_pade13(A):  # pragma: no cover
    # Evaluate the two matrix polynomials U (odd powers of A) and V (even
    # powers of A) built from the 14 coefficients b0..b13 below, using only
    # A^2, A^4 and A^6. The caller combines them as solve(V - U, V + U).
    #
    # The coefficients are plain float locals (not a tensor that is indexed),
    # which avoids torch select operations.
    b0 = 64764752532480000.0
    b1 = 32382376266240000.0
    b2 = 7771770303897600.0
    b3 = 1187353796428800.0
    b4 = 129060195264000.0
    b5 = 10559470521600.0
    b6 = 670442572800.0
    b7 = 33522128640.0
    b8 = 1323241920.0
    b9 = 40840800.0
    b10 = 960960.0
    b11 = 16380.0
    b12 = 182.0
    b13 = 1.0

    eye = torch.eye(A.shape[1], dtype=A.dtype, device=A.device)
    A2 = torch.matmul(A, A)
    A4 = torch.matmul(A2, A2)
    A6 = torch.matmul(A4, A2)

    # Odd part: U = A @ (A6 @ (b13 A6 + b11 A4 + b9 A2) + b7 A6 + b5 A4 + b3 A2 + b1 I)
    odd_high = torch.matmul(A6, b13 * A6 + b11 * A4 + b9 * A2)
    odd_poly = odd_high + b7 * A6 + b5 * A4 + b3 * A2 + b1 * eye
    U = torch.matmul(A, odd_poly)

    # Even part: V = A6 @ (b12 A6 + b10 A4 + b8 A2) + b6 A6 + b4 A4 + b2 A2 + b0 I
    even_high = torch.matmul(A6, b12 * A6 + b10 * A4 + b8 * A2)
    V = even_high + b6 * A6 + b4 * A4 + b2 * A2 + b0 * eye

    return U, V


@torch.jit.script
def matrix_2_power(x, p):  # pragma: no cover
    # Square x repeatedly, int(p) times in total, i.e. return x^(2^p).
    # p arrives as a tensor; a non-positive p leaves x unchanged.
    remaining = int(p)
    while remaining > 0:
        x = torch.matmul(x, x)
        remaining -= 1
    return x


@torch.jit.script
def expm_one(A):  # pragma: no cover
    # Matrix exponential of a single square matrix via scaling and squaring
    # with the Pade polynomials from torch_pade13.
    # Private implementation: no input checks, A is expected to be a matrix.

    # Scaling step: choose the number of squarings s as the smallest
    # non-negative integer with ||A|| / 2^s <= 5.371920351148152, where ||A||
    # is torch.norm(A). Dividing the log by 0.6931471805599453 (= ln 2) turns
    # it into a base-2 logarithm.
    # NOTE(review): the threshold looks like the theta_13 constant of
    # Higham's scaling-and-squaring method - confirm.
    norm_A = torch.norm(A)
    log2_ratio = torch.log(norm_A / 5.371920351148152).div(0.6931471805599453)
    n_squarings = torch.clamp(torch.ceil(log2_ratio), min=0)
    Ascaled = A / (2.0 ** n_squarings)

    # Pade 13 approximation of exp(Ascaled): solve (V - U) R = (U + V).
    U, V = torch_pade13(Ascaled)
    R = torch.linalg.solve(V - U, U + V)

    # Undo the scaling: exp(A) = R^(2^s).
    return matrix_2_power(R, n_squarings)


