import torch
import torch.nn as nn
import numpy as np

from robustopt_torch.costs import eucl_norm_sq

# Gaussian kernel
def gaussian_kern(x, y, bandwidth=1.0):
    """Gaussian (RBF) kernel ``exp(-eucl_norm_sq(x, y) / bandwidth)``.

    ``bandwidth`` divides the squared distance directly (there is no extra
    factor of 2). ``eucl_norm_sq`` comes from ``robustopt_torch.costs``;
    presumably it returns (pairwise) squared Euclidean distances between
    ``x`` and ``y`` -- confirm against that module.
    """
    exponent = -eucl_norm_sq(x, y) / bandwidth
    return torch.exp(exponent)

# Log gaussian kernel
def log_gaussian_kern(x, y, bandwidth=1.0):
    """Logarithm of ``gaussian_kern``: ``-eucl_norm_sq(x, y) / bandwidth``.

    Working in log-space avoids the ``exp`` and its underflow for large
    distances. ``eucl_norm_sq`` is the project helper from
    ``robustopt_torch.costs`` (presumably squared Euclidean distances --
    confirm there).
    """
    scaled_sq_dist = eucl_norm_sq(x, y) / bandwidth
    return -scaled_sq_dist

# Random Fourier features for a shift-invariant kernel.
#
# To approximate the Gaussian kernel exp(-||x - y||^2 / b) (i.e. gaussian_kern
# with bandwidth=b), draw the random features as:
#   feature_mat = torch.Tensor(num_features, num_outputs).normal_().mul_(np.sqrt(2 / b))
#   feature_bias = torch.Tensor(num_outputs).uniform_(-np.pi, np.pi)
#
# See https://gregorygundersen.com/blog/2019/12/23/random-fourier-features/
# for further details.
def get_embedding(x, feature_mat, feature_bias):
    """Return the random Fourier features ``sqrt(2 / D) * cos(x @ W + b)``.

    Args:
        x: inputs, multiplied on the left of ``feature_mat``.
        feature_mat: projection matrix ``W``; its last dimension ``D`` is the
            number of random features.
        feature_bias: phase offsets ``b`` added after the projection.

    Returns:
        The embedded inputs. The ``sqrt(2 / D)`` scale makes the inner product
        of two embeddings an estimate of the kernel value.
    """
    num_random_features = feature_mat.size(-1)
    phases = torch.matmul(x, feature_mat) + feature_bias
    return torch.cos(phases).mul(np.sqrt(2.0 / num_random_features))

def rff_kern_from_embeddings(x_embed, y_embed):
    """Kernel matrix from precomputed RFF embeddings: ``x_embed @ y_embed.T``.

    Entry ``(i, j)`` is the inner product of row ``i`` of ``x_embed`` with
    row ``j`` of ``y_embed``.
    """
    return torch.matmul(x_embed, y_embed.t())

def rff_kern(x, y, feature_mat, feature_bias):
    """Approximate kernel matrix between ``x`` and ``y`` via random Fourier features.

    Both inputs are embedded with the same ``feature_mat`` / ``feature_bias``
    (which is what makes the inner products comparable). The result is the
    matrix of inner products between the two embeddings.
    """
    x_feat, y_feat = (get_embedding(z, feature_mat, feature_bias) for z in (x, y))
    return rff_kern_from_embeddings(x_feat, y_feat)

# Taken from
# https://github.com/IBM/USD/blob/1cd5ca559a1652adede945bd6634df2cab0f0a8f/modules.py
class RFFEmbedding(nn.Module):
    r"""Random Fourier Features Embedding

    Maps each input row ``x`` to ``z(x) = sqrt(2 / D) * cos(x @ W + b)``.
    Here ``W`` has i.i.d. ``N(0, 2 / sigma**2)`` entries and ``b`` is uniform
    on ``[-pi, pi]``. Then ``z(x) @ z(y)`` estimates the Gaussian kernel
    ``exp(-||x - y||**2 / sigma**2)``, which is ``gaussian_kern`` with
    ``bandwidth = sigma**2``.

    Args
        **num_features** (scalar): number of input features
        **num_outputs** (scalar): number of random Fourier features ``D``
        **sigma** (scalar): kernel bandwidth (length scale)

    Inputs
        **inputs** (batch x num_features): batch of inputs

    Outputs
        **outputs** (batch x num_outputs): batch of embedded inputs

    Note: ``weight`` and ``bias`` are registered as ``nn.Parameter``s. They
    therefore appear in ``parameters()`` / ``state_dict()`` and are updated
    by an optimizer unless explicitly frozen.
    """

    def __init__(self, num_features, num_outputs=100, sigma=1.0):
        super().__init__()

        self.num_features = num_features
        self.num_outputs = num_outputs
        self.sigma = sigma

        # Std sqrt(2)/sigma => variance 2/sigma**2 => kernel exp(-||d||^2/sigma^2).
        weight_scale = np.sqrt(2) / sigma
        init_weight = torch.Tensor(num_features, num_outputs).normal_().mul_(weight_scale)
        init_bias = torch.Tensor(num_outputs).uniform_(-np.pi, np.pi)

        self.weight = nn.Parameter(init_weight)
        self.bias = nn.Parameter(init_bias)

    def forward(self, inputs):
        # The sqrt(2 / D) factor makes z(x) @ z(y) an unbiased kernel estimate.
        output_scale = np.sqrt(2 / self.num_outputs)
        projected = torch.matmul(inputs, self.weight) + self.bias
        return torch.cos(projected).mul(output_scale)

class MMD_RFF(nn.Module):
    r"""MMD computed with Random Fourier Features

    Estimates the Maximum Mean Discrepancy between two (optionally weighted)
    samples. It embeds both samples with one shared ``RFFEmbedding`` and
    returns the Euclidean distance between their mean embeddings.

    Args
        **num_features** (scalar): number of input features
        **num_outputs** (scalar): number of random Fourier features
        **sigma** (scalar): kernel bandwidth, forwarded to ``RFFEmbedding``

    Inputs
        **X** (batch1 x num_features): batch of inputs from distribution X
        **Y** (batch2 x num_features): batch of inputs from distribution Y
        **weights_X** (batch1, optional): weights weighing samples from X
            Weights are normalized so that weights_X.sum() == 1
        **weights_Y** (batch2, optional): weights weighing samples from Y
            Weights are normalized so that weights_Y.sum() == 1
        If a weight vector sums to zero, the normalization divides by zero
        and the result is non-finite.

    Outputs
        **mmd**: 0-dim tensor, the (non-squared) MMD estimate between X and Y
    """
    def __init__(self, num_features, num_outputs=100, sigma=1.0):
        super().__init__()

        self.num_features = num_features
        self.num_outputs = num_outputs

        self.rff_emb = RFFEmbedding(num_features, num_outputs, sigma=sigma)

    @staticmethod
    def _mean_embedding(features, weights):
        """Return the mean of ``features`` over dim 0, weighted by normalized ``weights`` if given."""
        if weights is None:
            return features.mean(0)
        # reshape (unlike view) also handles weight tensors whose memory
        # layout cannot be viewed as a column; it copies if necessary.
        normalized = weights.reshape(-1, 1) / weights.sum()
        return (normalized * features).sum(0)

    def forward(self, X, Y, weights_X=None, weights_Y=None):
        # Both samples share one embedding so their means live in the same feature space.
        fX, fY = self.rff_emb(X), self.rff_emb(Y)

        mu_X = self._mean_embedding(fX, weights_X)
        mu_Y = self._mean_embedding(fY, weights_Y)

        return (mu_X - mu_Y).norm()
