import torch
from opt_einsum import contract


def tail_matrix_product(A, B):
    """Contract the last axis of ``A`` with the last axis of ``B``.

    Two layouts of ``B`` are supported:

    * ``B`` is 2-D, or 3-D with a leading axis of length 1 (that axis is
      dropped first).  The output keeps every leading axis of ``A`` followed
      by the leading axis of ``B``.
    * Otherwise ``B`` must be 3-D with ``B.shape[0] == A.shape[0]``.  The
      first axis of both operands is shared, and the output keeps ``A``'s
      leading axes followed by ``B``'s middle axis.

    Raises:
        AssertionError: if ``B`` matches neither layout.
    """
    first = ord('a')

    def subscripts(offset, count):
        # ``count`` consecutive einsum letters starting ``offset`` after 'a'.
        return ''.join(map(chr, range(first + offset, first + offset + count)))

    a_rank = len(A.shape)
    unit_batch = B.ndim == 3 and B.shape[0] == 1

    if not (B.ndim == 2 or unit_batch):
        # Shared leading axis: 'a' labels it on both operands, 'z' labels the
        # contracted trailing axis.
        assert B.ndim == 3 and B.shape[0] == A.shape[0]
        a_mid = subscripts(0, a_rank)[1:-1]
        b_mid = subscripts(a_rank - 1, len(B.shape))[1:-1]
        return contract(f'a{a_mid}z,a{b_mid}z->a{a_mid}{b_mid}', A, B)

    if unit_batch:
        B = B[0]
    # B's letters start where A's last letter would be, so the kept axes of
    # the two operands never share a subscript; 'z' is the contracted axis.
    a_lead = subscripts(0, a_rank)[:-1]
    b_lead = subscripts(a_rank - 1, len(B.shape))[:-1]
    return contract(f'{a_lead}z,{b_lead}z->{a_lead}{b_lead}', A, B)


class Fourier_basis:
    """Random cosine feature map ``x -> output_scale * cos(x @ W.T + b)``.

    ``W`` has shape ``(num_bases, in_dims)`` and is drawn from a standard
    normal; ``b`` has shape ``(num_bases,)`` and is drawn uniformly from
    ``[0, pi)``.  ``set_parms`` divides ``W`` by the lengthscales and stores
    the variance used by ``output_scale``.

    Args:
        num_bases: number of random features.
        in_dims: size of the last axis of the inputs.
    """

    def __init__(
            self,
            num_bases,
            in_dims,
    ):
        self.num_bases = num_bases
        self.in_dims = in_dims
        # Unit variance until ``set_parms`` is called, so the basis can be
        # evaluated right after construction instead of raising
        # ``AttributeError`` from ``output_scale``.
        self.variance = 1.0
        self.re_init()

    def set_parms(self, lengthscales, variance):
        """Apply kernel hyperparameters to the current random draw.

        Args:
            lengthscales: tensor that divides the frequency matrix; it also
                fixes the dtype/device the weights and bias are moved to.
            variance: signal variance read by ``output_scale``.
        """
        self.variance = variance
        # Always rescale the *unscaled* draw.  Dividing the previously scaled
        # weights again would compound the lengthscale on every call.
        self._weights = self._raw_weights.to(lengthscales) / (lengthscales + 1e-6)
        self._bias = self._bias.to(lengthscales)

    def re_init(self):
        """Draw fresh phases and frequencies.

        The new frequencies are unscaled; call ``set_parms`` again to apply
        lengthscales to them.
        """
        num_bases = self.num_bases
        in_dims = self.in_dims
        self._bias = torch.rand([num_bases]) * torch.pi
        self._raw_weights = torch.randn((num_bases, in_dims))
        self._weights = self._raw_weights

    def __call__(self, x):
        """Return the features of ``x``: last axis ``in_dims`` -> ``num_bases``."""
        bias = self._bias
        weights = self._weights
        output_scale = self.output_scale

        proj = contract('...d,bd->...b', x, weights)
        feat = torch.cos(proj + bias)

        return output_scale * feat

    @property
    def output_scale(self):
        """``sqrt(2 * variance / num_bases)`` with negative variance clamped to 0."""
        variance = self.variance * (self.variance >= 0)
        variance = torch.as_tensor(variance)
        return torch.sqrt(2 * variance / self.num_bases)


class Prior_random_fourier():
    """Sampler that combines random feature weights with a ``Fourier_basis``.

    Holds a standard-normal weight tensor of shape
    ``list(sample_shape) + [num_bases]`` and a ``Fourier_basis`` with the
    same ``num_bases``.  Calling the object evaluates the basis on ``x`` and
    contracts the feature axis against the weights.

    Args:
        sample_shape: leading shape of the weight tensor.
        num_bases: number of random features.
        in_dims: expected size of the last axis of the inputs.
        basis, weights, **kwargs: accepted but not read by this class.
    """

    def __init__(
            self,
            sample_shape,
            num_bases,
            in_dims,
            basis=None,
            weights=None,
            **kwargs
    ):
        self.in_dims, self.sample_shape, self.num_bases = (
            in_dims, sample_shape, num_bases
        )
        self.basis = Fourier_basis(num_bases=num_bases, in_dims=in_dims)
        self.re_init()

    def set_parms(self, lengthscales, variance):
        """Forward hyperparameters to the basis and move the weights.

        The weights are converted with ``Tensor.to(lengthscales)``.
        """
        self.basis.set_parms(lengthscales, variance)
        moved = self.weights.to(lengthscales)
        self.weights = moved

    def re_init(self):
        """Redraw the weights, then redraw the basis."""
        weight_shape = [*self.sample_shape, self.num_bases]
        self.weights = torch.randn(weight_shape)
        self.basis.re_init()

    def __call__(self, x, mean_func=None):
        """Evaluate the samples at ``x``.

        Args:
            x: inputs whose last axis has size ``in_dims``.
            mean_func: optional callable; ``mean_func(x)`` is added to the
                result.  When omitted, 0 is added.

        Returns:
            The weight/feature contraction with a trailing singleton axis
            appended, plus the mean value.

        Raises:
            AssertionError: if ``x.shape[-1] != in_dims``.
        """
        assert x.shape[-1] == self.in_dims

        features = self.basis(x)  # presumably [sample seq bases] - TODO confirm
        offset = 0 if mean_func is None else mean_func(x)
        projected = tail_matrix_product(self.weights, features)
        return projected.unsqueeze(-1) + offset


def sample_from_prior(x, kernel, num_samples=10000):
    """Build a 2048-feature ``Prior_random_fourier`` and evaluate it at ``x``.

    Args:
        x: input tensor; its last axis gives the prior's ``in_dims``.
        kernel: currently unused.
            NOTE(review): the kernel's hyperparameter interface is not
            visible here, so unit lengthscales and unit variance are used.
            Call ``set_parms`` on the returned prior to apply real values.
        num_samples: number of sample functions, used as the sample shape.

    Returns:
        Tuple ``(values, prior)``: the prior evaluated at ``x`` and the prior
        object itself.
    """
    prior = Prior_random_fourier(
        sample_shape=[num_samples],
        num_bases=2048,
        in_dims=x.shape[-1],
    )
    # Hyperparameters must be set before the prior is evaluated: the basis
    # reads ``variance``, which only ``set_parms`` assigns.  This also moves
    # the prior's tensors onto ``x``'s dtype and device.
    prior.set_parms(x.new_ones(x.shape[-1]), 1.0)
    return prior(x), prior


def sample_from_gaussian(chov, num_samples=10000, mean=None):
    """Draw Gaussian samples by multiplying standard-normal noise with ``chov``.

    The contraction is ``'szd,dyz->syd'``: noise of shape
    ``(num_samples, z, hidden)`` against ``chov`` of shape ``(hidden, y, z)``.

    Args:
        chov: factor tensor, either 2-D ``(y, z)`` (treated as ``hidden = 1``)
            or 3-D ``(hidden, y, z)``.
        num_samples: number of samples to draw.
        mean: optional mean added to the samples.  For 2-D ``chov`` it must
            be 1-D.  For 3-D ``chov`` it is permuted with ``(1, 0)``, so it
            is presumably laid out ``(hidden, seq)`` - TODO confirm.

    Returns:
        Tensor of shape ``(num_samples, y, hidden)``.

    Raises:
        AssertionError: if ``chov`` is 2-D and ``mean`` is not 1-D.
    """
    if chov.ndim == 2:
        hidden = 1
        chov = chov.unsqueeze(0)
        if mean is not None:
            assert mean.ndim == 1
            # (seq,) -> (1, seq, 1) so it broadcasts over samples and hidden.
            mean = mean.to(chov.device).unsqueeze(0).unsqueeze(-1)
    else:
        hidden = chov.shape[0]
        # ``mean`` is optional here too; touching it unguarded would crash
        # on the default ``mean=None``.
        if mean is not None:
            mean = mean.to(chov.device).permute(1, 0).unsqueeze(0)

    seqz = chov.shape[-1]
    # Draw as before, then match ``chov``'s dtype and device so the
    # contraction does not fail on mixed precision.
    samples = torch.randn((num_samples, seqz, hidden)).to(chov)
    result = contract('szd,dyz->syd', samples, chov)
    if mean is None:
        return result
    return result + mean
