import torch

from utils.base import ScoreEstimator


class SpectralScoreEstimator(ScoreEstimator):
    """Spectral estimator of the score (gradient of the log density).

    The Gram matrix of ``samples`` is eigendecomposed, the leading
    eigenvectors are extended to query points with the Nystrom method, and
    the result is combined with coefficients derived from the averaged
    kernel gradients. The structure appears to follow the Spectral Stein
    Gradient Estimator (Shi et al., 2018) -- NOTE(review): confirm.

    The kernel helpers ``gram``, ``grad_gram`` and ``heuristic_kernel_width``
    are not defined in this class; they are assumed to be provided by
    ``ScoreEstimator``.

    Args:
        n_eigen: Number of largest eigenpairs to keep. ``None`` keeps all of
            them unless ``n_eigen_threshold`` is given.
        eta: Optional value added to the diagonal of the Gram matrix before
            the eigendecomposition.
        n_eigen_threshold: Used only when ``n_eigen`` is ``None``. Keeps the
            eigenpairs whose cumulative share of the batch-averaged
            eigenvalue mass stays below this threshold. It is re-evaluated
            on every call to ``compute_gradients``.
    """

    def __init__(self, n_eigen=None, eta=None, n_eigen_threshold=None):
        self._n_eigen = n_eigen
        self._eta = eta
        self._n_eigen_threshold = n_eigen_threshold
        super().__init__()

    def nystrom_ext(self, samples, x, eigen_vectors, eigen_values, kernel_width):
        """Evaluate the Nystrom extension of the Gram eigenvectors at ``x``.

        Computes ``sqrt(M) * K(x, samples) @ eigen_vectors / eigen_values``.

        Args:
            samples: ``[..., M, x_dim]`` points the Gram matrix was built on.
            x: ``[..., N, x_dim]`` query points.
            eigen_vectors: ``[..., M, n_eigen]``, one eigenvector per column.
            eigen_values: ``[..., n_eigen]``.
            kernel_width: Bandwidth forwarded to ``self.gram``.

        Returns:
            ``[..., N, n_eigen]`` eigenfunction values at ``x``.
        """
        M = samples.shape[-2]
        # Kxq: [..., N, M]
        Kxq = self.gram(x, samples, kernel_width)
        # Nystrom scaling is sqrt(M) with no extra damping factor; this
        # matches the sqrt(M) used for beta in compute_gradients.
        ret = (M ** 0.5) * torch.matmul(Kxq, eigen_vectors)
        # Divide column j by eigen_values[..., j].
        return ret / torch.unsqueeze(eigen_values, -2)

    def pick_n_eigen(self, eigen_values, eigen_vectors, n_eigen):
        """Select the ``n_eigen`` largest eigenpairs.

        Args:
            eigen_values: ``[..., M]``.
            eigen_vectors: ``[..., M, M]`` with one eigenvector per column,
                as returned by ``torch.linalg.eigh``.
            n_eigen: Number of eigenpairs to keep; must satisfy
                ``1 <= n_eigen <= M`` (``torch.topk`` raises otherwise).

        Returns:
            ``(eigen_values, eigen_vectors)`` with shapes ``[..., n_eigen]``
            and ``[..., M, n_eigen]``, ordered by decreasing eigenvalue.
        """
        n_eigen = int(n_eigen)
        # top_k_indices: [..., n_eigen]
        eigen_values, top_k_indices = torch.topk(eigen_values, k=n_eigen, dim=-1)
        # Repeat the selected column indices for every row, then pick those
        # columns out of the eigenvector matrix.
        index = top_k_indices.unsqueeze(-2).expand(
            *eigen_vectors.shape[:-1], n_eigen)
        eigen_vectors = torch.gather(eigen_vectors, -1, index)
        return eigen_values, eigen_vectors

    def _resolve_n_eigen(self, eigen_values, M):
        """Return how many largest eigenpairs to keep, or ``None`` for all.

        An explicit ``n_eigen`` takes precedence. Otherwise
        ``n_eigen_threshold`` is evaluated on ``eigen_values`` (ascending,
        ``[..., M]``) without caching the result on ``self``. The count is
        clamped to ``[1, M]``: a count of 0 would make the ``[..., -0:]``
        slice in ``compute_gradients`` keep every eigenpair.
        """
        n_eigen = self._n_eigen
        if n_eigen is None:
            if self._n_eigen_threshold is None:
                return None
            # Batch-averaged eigenvalues, largest first, as shares of the
            # total; count those whose cumulative share is below threshold.
            eigen_arr = torch.mean(torch.reshape(eigen_values, [-1, M]), dim=0)
            eigen_arr = torch.flip(eigen_arr, [-1])
            eigen_arr = eigen_arr / torch.sum(eigen_arr)
            eigen_cum = torch.cumsum(eigen_arr, dim=-1)
            n_eigen = torch.sum(eigen_cum < self._n_eigen_threshold)
        return max(1, min(int(n_eigen), M))

    def compute_gradients(self, samples, x=None):
        """Estimate the score of the sample distribution at ``x``.

        Args:
            samples: ``[..., M, x_dim]`` samples of the distribution.
            x: ``[..., N, x_dim]`` query points; defaults to ``samples``.

        Returns:
            ``[..., N, x_dim]`` estimated gradients.
        """
        if x is None:
            kernel_width = self.heuristic_kernel_width(samples, samples)
            # TODO: Simplify computation
            x = samples
        else:
            # _samples: [..., N + M, x_dim]
            _samples = torch.cat([samples, x], dim=-2)
            kernel_width = self.heuristic_kernel_width(_samples, _samples)

        M = samples.shape[-2]
        # Kq: [..., M, M]; grad_K1, grad_K2: [..., M, M, x_dim]
        Kq, grad_K1, grad_K2 = self.grad_gram(samples, samples, kernel_width)
        if self._eta is not None:
            # Out-of-place, so tensors autograd may have saved while
            # building Kq are left untouched.
            Kq = Kq + self._eta * torch.eye(M, dtype=Kq.dtype, device=Kq.device)
        # eigh: eigenvalues ascending, eigenvectors as columns.
        # eigen_values: [..., M]; eigen_vectors: [..., M, M]
        eigen_values, eigen_vectors = torch.linalg.eigh(Kq)

        n_eigen = self._resolve_n_eigen(eigen_values, M)
        if n_eigen is not None:
            # The n_eigen largest eigenpairs are the trailing ones.
            eigen_values = eigen_values[..., -n_eigen:]
            eigen_vectors = eigen_vectors[..., -n_eigen:]
        # eigen_ext: [..., N, n_eigen]
        eigen_ext = self.nystrom_ext(
            samples, x, eigen_vectors, eigen_values, kernel_width)
        # grad_K1_avg: [..., M, x_dim]
        grad_K1_avg = torch.mean(grad_K1, dim=-3)
        # beta: [..., n_eigen, x_dim] = V^T @ grad_K1_avg, scaled per
        # eigenpair. The transpose is required: [..., n_eigen, M] @ [..., M, x_dim].
        beta = -(M ** 0.5) * torch.matmul(
            eigen_vectors.transpose(-2, -1), grad_K1_avg
        ) / torch.unsqueeze(eigen_values, -1)
        # grads: [..., N, x_dim]
        return torch.matmul(eigen_ext, beta)
