from functools import reduce
import numpy as np
import torch
from joblib import Parallel, delayed


class LowRankGreedyAcquisition:
    """Stateful front-end that performs one greedy acquisition step per ``next()`` call.

    The object stores the fixed problem inputs (kernel matrix and scoring
    hyper-parameters) together with the evolving state: the list of indices
    picked so far and the context holding the running matrices.
    """

    def __init__(self, k_vv, lambd, sigma, theta_norm_guess, sample_size, batch_size, n_gpus):
        """Store the configuration and start from an empty selection.

        All arguments are forwarded unchanged to ``low_rank_greedy_acquisition``
        on every step; ``k_vv.dtype`` is used for the initial context.
        """
        # Fixed inputs.
        self._k_vv = k_vv
        self._lambd, self._sigma = lambd, sigma
        self._theta_norm_guess = theta_norm_guess
        self._sample_size, self._batch_size = sample_size, batch_size
        self._n_gpus = n_gpus
        # None is passed through to the step function, which then builds its
        # own joblib ``Parallel`` executor.
        self._parallel = None

        # Evolving state.
        self._context = LowRankGreedyContext(dtype=k_vv.dtype)
        self._xi = []

    def next(self):
        """Run one acquisition step.

        Returns:
            A pair ``(index, k_inv)``: the index just appended to the selection
            and the context's ``A`` matrix with axis 0 squeezed.
        """
        step = low_rank_greedy_acquisition(
            self._k_vv, self._context, self._xi,
            self._lambd, self._sigma, self._theta_norm_guess,
            self._sample_size, self._batch_size, self._n_gpus,
            self._parallel)
        self._xi, self._context = step
        latest_pick = self._xi[-1]
        return latest_pick, self._context.A.squeeze(0)

    @property
    def xi(self):
        """Indices selected so far, in selection order."""
        return self._xi

    @property
    def k_inv(self):
        """The context's ``A`` matrix with axis 0 squeezed."""
        current = self._context.A
        return current.squeeze(0)


class LowRankGreedyContext:
    """Mutable holder for the seven running matrices carried between acquisition steps.

    Every matrix starts empty (shape ``(0, 0)``). ``compute_low_rank_score``
    produces the updated values that are written back into ``A`` .. ``F`` after
    each step.
    """

    def __init__(self, dtype, device=None):
        """Create the empty matrices.

        Args:
            dtype: torch dtype shared by all matrices.
            device: optional torch device for the matrices. ``None`` (the
                default) keeps torch's default placement, which is what this
                constructor did before the argument existed.
        """
        self.A = torch.zeros(0, 0, dtype=dtype, device=device)
        self.AS = torch.zeros(0, 0, dtype=dtype, device=device)
        self.B = torch.zeros(0, 0, dtype=dtype, device=device)
        self.C = torch.zeros(0, 0, dtype=dtype, device=device)
        self.D = torch.zeros(0, 0, dtype=dtype, device=device)
        self.E = torch.zeros(0, 0, dtype=dtype, device=device)
        self.F = torch.zeros(0, 0, dtype=dtype, device=device)

    def pad(self):
        """Grow every matrix in place by one zero row and one zero column."""
        # Compute all padded copies before assigning, so a failure part-way
        # through leaves the context untouched.
        self.A, self.AS, self.B, self.C, self.D, self.E, self.F = \
            pad(self.A), pad(self.AS), pad(self.B), pad(self.C), pad(self.D), pad(self.E), pad(self.F)


def low_rank_greedy_acquisition(k_vv, context, xi, lambd, sigma, theta_norm_guess, sample_size, batch_size, n_gpus,
                                parallel=None):
    """Pick the next index greedily and update the running matrices.

    Candidates are the indices in ``range(k_vv.shape[0])`` that are not yet in
    ``xi``. At most ``sample_size`` of them are scored (a uniform random subset
    when there are more); the candidate with the smallest score wins.

    Args:
        k_vv: square 2-D tensor; rows and columns are indexed by candidate.
        context: object carrying the matrices ``A, AS, B, C, D, E, F`` from the
            previous step. It is updated in place and also returned.
        xi: list of indices selected so far (not modified; a new list is returned).
        lambd, sigma, theta_norm_guess: scoring hyper-parameters, forwarded to
            ``compute_low_rank_score``.
        sample_size: maximum number of candidates to score in this step.
        batch_size: number of candidates dispatched to the executor at a time.
        n_gpus: number of shards each batch is split into; also the ``n_jobs``
            of the executor created when ``parallel`` is ``None``.
        parallel: optional joblib ``Parallel`` instance to reuse. When ``None``
            a thread-based one is created for this call.

    Returns:
        ``(new_xi, context)`` where ``new_xi`` is ``xi`` plus the winning index.

    Raises:
        AssertionError: if every index is already selected.
    """
    n = k_vv.shape[0]
    a0, as0, b0, c0, d0, e0, f0 = (pad(context.A), pad(context.AS), pad(context.B),
                                   pad(context.C), pad(context.D), pad(context.E), pad(context.F))
    if parallel is None:
        parallel = Parallel(n_jobs=n_gpus, prefer="threads")

    def score_on_device(device, ks):
        # Only the score is needed while searching; the updated matrices are
        # recomputed once for the winner below.
        return compute_low_rank_score(a0, as0, b0, c0, d0, e0, f0, k_vv, xi, ks,
                                      lambd, sigma, theta_norm_guess, device)[0]

    delayed_compute_low_rank_score = delayed(score_on_device)

    pool = list(set(range(n)) - set(xi))
    if len(pool) == 0:
        # Explicit raise (same exception type as the former assert) so the
        # check survives ``python -O``.
        raise AssertionError("no candidates left: every index is already in xi")
    # Compare against the *remaining* pool, not n: the pool shrinks as xi grows
    # and sampling without replacement fails when asked for more than it holds.
    if sample_size < len(pool):
        pool = np.random.choice(pool, sample_size, replace=False).tolist()

    # Unscored entries stay at +inf so argmin can never select them.
    score = torch.ones(n, dtype=k_vv.dtype) * float('inf')
    batches = list(chunks(pool, batch_size))
    for batch in batches:
        # Ceil-divide the batch into at most n_gpus shards.
        local_batch_size = int((len(batch) - 1) / n_gpus) + 1
        ks_per_device = list(chunks(batch, local_batch_size))
        # The shard position is handed to the scorer as its device argument
        # (presumably a GPU ordinal -- confirm against the deployment setup).
        score_per_device = parallel(
            delayed_compute_low_rank_score(device, ks) for device, ks in (enumerate(ks_per_device)))
        score[batch] = torch.cat([s.to('cpu') for s in score_per_device])

    k_win = int(torch.argmin(score))
    # Recompute for the winner alone on the CPU to obtain the updated matrices.
    score_win, context.A, context.AS, context.B, context.C, context.D, context.E, context.F = \
        compute_low_rank_score(a0, as0, b0, c0, d0, e0, f0, k_vv, xi, [k_win], lambd, sigma, theta_norm_guess,
                               'cpu')

    new_xi = xi + [k_win]
    return new_xi, context


def compute_low_rank_score(a0, as0, b0, c0, d0, e0, f0, k_vv, xi, ks, lambd, sigma, theta_norm_guess, device):
    """Score each candidate in ``ks`` and return the correspondingly updated matrices.

    The Q, W and Z operators are built for the current selection ``xi``,
    updated for the candidates ``ks``, and their contributions are added to the
    incoming matrices. The score is
    ``theta_norm_guess**2 * (-2*tr(d) + tr(e)) + sigma**2 * tr(f)``.

    Args:
        a0, as0, b0, c0, d0, e0, f0: matrices from the previous step (the
            caller in this file pads them by one row and column first).
        k_vv: kernel matrix the operators read from.
        xi: list of already selected indices.
        ks: list of candidate indices to evaluate.
        lambd, sigma, theta_norm_guess: scalar hyper-parameters.
        device: torch device on which the computation runs.

    Returns:
        Tuple ``(score, a, as_, b, c, d, e, f)``; ``score`` has one entry per
        candidate.
    """
    a0, as0, b0, c0, d0, e0, f0 = (m.to(device) for m in (a0, as0, b0, c0, d0, e0, f0))

    q = OperatorQ(k_vv, xi, lambd, a0, device)
    w = OperatorW(k_vv, xi, device)
    z = OperatorZ(k_vv, xi, device)
    for operator in (q, w, z):
        operator.update(ks)

    a = a0 + q.get()
    b = b0 + w.get()
    c = c0 + z.get()
    d = d0 + q(b0).t().get() + w(a).get()

    # r collects the terms q a0 + (q a0)^T + q q that are added to as0.
    q_a0 = q(a0)
    r = OperatorsSum([q_a0, q_a0.t(), q(q)])
    as_ = as0 + r.get()

    b0r = r.t()(b0).t()
    was = w(as_)
    # Terms are summed strictly left to right, exactly as written.
    e = e0 + b0r(c).get() + z(as0)(b0).t().get() + was(c).get()
    f = f0 + b0r.get() + was.get()

    bias = theta_norm_guess ** 2 * (-2 * trace(d) + trace(e))
    variance = sigma ** 2 * trace(f)
    return bias + variance, a, as_, b, c, d, e, f


class LowRankOperator:
    """Matrix product ``left @ right`` kept in factored form.

    Applying the operator to something only rewrites the right factor, so the
    full product is not formed until ``get()`` is called.
    """

    def __init__(self, left, right):
        # Private copies: in-place edits the caller makes afterwards must not
        # leak into this operator.
        self._left, self._right = left.clone(), right.clone()

    def __call__(self, a):
        """Return the operator representing ``self @ a`` (or ``a * self`` for other values)."""
        if isinstance(a, torch.Tensor):
            new_right = self._right @ a
        elif isinstance(a, LowRankOperator):
            new_right = self._right @ a._left @ a._right
        elif isinstance(a, OperatorsSum):
            new_right = self._right @ a.get()
        else:
            # Anything else is treated as a multiplicative factor.
            new_right = a * self._right
        return LowRankOperator(self._left, new_right)

    def t(self):
        """Return the transposed operator: factors swapped and each transposed."""
        return LowRankOperator(t(self._right), t(self._left))

    def get(self):
        """Materialize the product of the two factors."""
        return torch.matmul(self._left, self._right)


class OperatorsSum:
    """Lazy sum of operator terms.

    Each term may be a ``LowRankOperator``, another ``OperatorsSum`` or an
    already materialized value such as a tensor. ``get()`` evaluates the sum.
    """

    def __init__(self, operators):
        """Store the list of terms (may be ``None`` until ``set`` is called)."""
        self._operators = operators

    def __call__(self, a):
        """Apply every term to ``a`` and return the sum of the results, still lazy."""
        return OperatorsSum([o(a) for o in self._operators])

    def t(self):
        """Return the sum of the transposed terms, still lazy."""
        return OperatorsSum([o.t() for o in self._operators])

    def get(self):
        """Materialize every term and add them left to right.

        Unlike a bare ``reduce`` over the raw terms, this also materializes the
        term when the sum holds exactly one operator, so the result is always
        the evaluated value rather than the operator object itself.

        Raises:
            TypeError: if no terms have been set or the list is empty.
        """
        def materialize(x):
            # Tensors are already concrete values.
            if isinstance(x, torch.Tensor):
                return x
            if isinstance(x, OperatorsSum) or isinstance(x, LowRankOperator):
                return x.get()
            return x

        return reduce(lambda acc, term: acc + term, map(materialize, self._operators))

    def set(self, operators):
        """Replace the list of terms."""
        self._operators = operators


class OperatorQ(OperatorsSum):
    """Sum of three rank-one terms built from ``k_inv`` and one new column of ``k_vv`` per candidate.

    NOTE(review): the terms have the form of a block-inverse update, assuming
    ``k_inv`` is the (padded) regularized inverse for the current selection --
    confirm against the derivation.
    """

    def __init__(self, k_vv, xi, lambd, k_inv, device):
        """Remember the inputs and move ``k_inv`` to ``device``; terms are set by ``update``."""
        super().__init__(None)
        self._k_vv, self._xi, self._lambd = k_vv, xi, lambd
        self._device = device
        self._k_inv = k_inv.to(device)

    def update(self, ks):
        """Rebuild the three terms for the candidate indices ``ks`` (one batch entry each)."""
        dev = self._device
        n_candidates = len(ks)
        size = len(self._xi) + 1

        # Per candidate: the column k_vv[xi, k], extended by one zero row.
        columns = [pad(self._k_vv[self._xi, k:k + 1], 1, 0).unsqueeze(0) for k in ks]
        q = torch.cat(columns, dim=0).to(dev)
        # Per candidate: the diagonal entry k_vv[k, k] as a 1x1 block.
        diagonals = [self._k_vv[k:k + 1, k:k + 1].unsqueeze(0) for k in ks]
        xtx = torch.cat(diagonals, dim=0).to(dev)

        r = 1 / (xtx + self._lambd - t(q) @ self._k_inv @ q)
        u = self._k_inv @ q
        ru = r * u
        neg_ru = -ru
        # Same as -r*u except that the last row carries r itself.
        neg_ru_with_r = neg_ru.clone()
        neg_ru_with_r[:, -1, 0] = r[:, 0, 0]

        # LowRankOperator copies its factors, so one indicator can be shared.
        last = delta(n_candidates, size, u.dtype, dev)
        super().set([LowRankOperator(ru, t(u)),
                     LowRankOperator(last, t(neg_ru)),
                     LowRankOperator(neg_ru_with_r, t(last))])


class OperatorW(OperatorsSum):
    """Two rank-one terms that fill the last column and last row of a matrix one size larger.

    The entries are inner products between a candidate's column of ``k_vv`` and
    the columns of the already selected indices (plus the candidate itself in
    the corner).
    """

    def __init__(self, k_vv, xi, device):
        """Keep ``k_vv`` and a device copy of its columns ``xi``; terms are set by ``update``."""
        super().__init__(None)
        self._k_vv = k_vv
        self._device = device
        self._a = self.slice(k_vv, xi).to(device)

    def update(self, ks):
        """Rebuild the two terms for the candidate indices ``ks`` (one batch entry each)."""
        dev, dtype = self._device, self._k_vv.dtype
        n_candidates = len(ks)
        n_selected = self._a.shape[1]
        size = n_selected + 1

        # Full column of k_vv for every candidate, stacked along a new batch axis.
        s = torch.cat([self._k_vv[:, k:k + 1].unsqueeze(0) for k in ks], dim=0).to(dev)

        if n_selected > 0:
            # Inner products with the selected columns, then a trailing zero.
            last_row = pad(self.mult(self._a, s), 0, 1)
        else:
            last_row = torch.zeros(n_candidates, 1, 1, dtype=dtype, device=dev)

        # The column is the transposed row with the corner set to s^T s.
        last_col = t(last_row).clone()
        last_col[:, -1, 0] = self.mult(s, s)[:, 0, 0]

        # LowRankOperator copies its factors, so one indicator can be shared.
        last = delta(n_candidates, size, dtype, dev)
        super().set([LowRankOperator(last_col, t(last)),
                     LowRankOperator(last, last_row)])

    @staticmethod
    def mult(a, b):
        """Return ``b^T @ a`` with the transpose taken over the last two axes of ``b``."""
        return torch.matmul(t(b), a)

    @staticmethod
    def slice(k_vv, xi):
        """Return the columns ``xi`` of ``k_vv``."""
        return k_vv[:, xi]


class OperatorZ(OperatorsSum):
    """Two rank-one terms that write ``k_vv[xi + [k], k]`` into the last column and last row.

    The corner entry is taken from the column term only; it is zeroed in the
    row term so it is not added twice.
    """

    def __init__(self, k_vv, xi, device):
        """Remember the inputs; the terms are set by ``update``."""
        super().__init__(None)
        self._k_vv, self._xi, self._device = k_vv, xi, device

    def update(self, ks):
        """Rebuild the two terms for the candidate indices ``ks`` (one batch entry each)."""
        dev, dtype = self._device, self._k_vv.dtype
        n_candidates = len(ks)
        size = len(self._xi) + 1

        # Per candidate: entries of column k at the selected rows followed by k itself.
        columns = [self._k_vv[self._xi + [k], k:k + 1].unsqueeze(0) for k in ks]
        last_col = torch.cat(columns, dim=0).to(dev)

        last_row = t(last_col).clone()
        last_row[:, 0, -1] = 0

        # LowRankOperator copies its factors, so one indicator can be shared.
        last = delta(n_candidates, size, dtype, dev)
        super().set([LowRankOperator(last_col, t(last)),
                     LowRankOperator(last, last_row)])


def pad(a, bottom=1, right=1):
    """Return a zero-extended copy of ``a`` with extra rows below and columns to the right.

    Args:
        a: 2-D matrix, or 3-D batch of matrices (axis 0 is left untouched).
        bottom: number of zero rows appended.
        right: number of zero columns appended.

    Returns:
        A new tensor with the dtype and device of ``a``; the input is not modified.
    """
    if a.dim() == 2:
        rows, cols = a.shape
        leading = ()
    else:
        batch, rows, cols = a.shape
        leading = (batch,)
    grown = torch.zeros(*leading, rows + bottom, cols + right, dtype=a.dtype, device=a.device)
    # Copy the original into the top-left corner; the rest stays zero.
    grown[..., :rows, :cols] = a
    return grown


def delta(b, n, dtype, device):
    """Return a ``(b, n, 1)`` batch of column vectors that are 1 in the last row and 0 elsewhere."""
    unit = torch.zeros((b, n, 1), dtype=dtype, device=device)
    # The trailing axis has length one, so this sets exactly one entry per batch item.
    unit[:, -1] = 1
    return unit


def t(a):
    """Swap axes 1 and 2 of ``a`` (transpose of each matrix in a batch)."""
    return a.transpose(1, 2)


def trace(a):
    """Return the trace of each matrix in a batch (diagonal over axes 1 and 2, summed)."""
    diagonals = a.diagonal(dim1=1, dim2=2)
    return diagonals.sum(-1)


def test():
    """Smoke test: run ten greedy acquisition steps on a random 11x11 Gram matrix and print the picks.

    The scoring shards are addressed by integer device index, so this
    presumably needs at least one accelerator to be visible -- confirm before
    running on a CPU-only machine.
    """
    v = torch.rand(11, 10)
    # The context constructor takes only a dtype.
    context = LowRankGreedyContext(dtype=v.dtype)
    k_vv = v @ v.t()
    xi = []
    for _ in range(10):
        # parallel=None lets the step create its own thread-based executor.
        xi, context = low_rank_greedy_acquisition(k_vv, context, xi, lambd=1, sigma=1, theta_norm_guess=1,
                                                  sample_size=1, batch_size=6, n_gpus=3, parallel=None)
    print(f'xi={xi}')


def chunks(l, n):
    """Lazily yield consecutive slices of ``l`` of length ``n`` (the last one may be shorter).

    A chunk size below 1 is treated as 1.
    """
    step = max(1, n)
    starts = range(0, len(l), step)
    return (l[begin:begin + step] for begin in starts)


# Run the smoke test only when this file is executed as a script, not on import.
if __name__ == '__main__':
    test()
