from functools import partial
from itertools import product
import numpy as np
import torch
from torch import Tensor
from robustopt_torch.StochasticSolver import StochasticSolver
from robustopt_torch.distributions import is_discrete, get_particles_and_weights
from robustopt_torch.costs import eucl_norm_sq

def _weighted_kernel_sum(kernel_vals, weights):
    """Contract the leading (sample) axis of ``kernel_vals`` against ``weights``.

    ``weights`` is moved to ``kernel_vals``' device and both operands are cast to
    their promoted dtype, because ``torch.einsum`` does not type-promote (e.g.
    float32 weights against a float64 kernel output would otherwise raise).
    Any trailing axes of ``kernel_vals`` are kept in the result.
    """
    weights = torch.as_tensor(weights, device=kernel_vals.device)
    dtype = torch.promote_types(kernel_vals.dtype, weights.dtype)
    return torch.einsum("i...,i...", kernel_vals.to(dtype), weights.to(dtype))


def mmd_ifunc(x, iterate_samps, iterate_weights, target_samps, target_weights,
              kernel, broadcastable=False):
    """Evaluate the MMD influence function (unnormalized witness function) at ``x``.

    Computes::

        sum_i iterate_weights[i] * kernel(iterate_samps[i], x)
            - sum_j target_weights[j] * kernel(target_samps[j], x)

    Args:
        x: Evaluation point(s); passed unchanged as the second kernel argument.
        iterate_samps: Samples of the current (iterate) distribution.
        iterate_weights: Weights of ``iterate_samps`` (same length).
        target_samps: Samples of the target distribution.
        target_weights: Weights of ``target_samps`` (same length).
        kernel: Two-argument kernel, called as ``kernel(samples, x)``.
        broadcastable: If true, ``kernel`` is called once per sample set and its
            output's leading axis must index the samples; weights are then
            promoted to the kernel output's dtype/device. If false, ``kernel`` is
            called once per individual sample and the products are summed.

    Returns:
        The weighted kernel difference. With ``broadcastable=True``, trailing
        axes of the kernel output are preserved.

    Raises:
        ValueError: If a sample set and its weights differ in length.
    """
    if len(iterate_samps) != len(iterate_weights) or len(target_samps) != len(target_weights):
        raise ValueError("Incorrect dimensions for samples and weights!")

    if broadcastable:
        return (_weighted_kernel_sum(kernel(iterate_samps, x), iterate_weights)
                - _weighted_kernel_sum(kernel(target_samps, x), target_weights))

    iter_ker = sum(kernel(iter_samp, x) * wt
                   for iter_samp, wt in zip(iterate_samps, iterate_weights))
    target_ker = sum(kernel(target_samp, x) * wt
                     for target_samp, wt in zip(target_samps, target_weights))
    return iter_ker - target_ker

class MMDIFunc:
    """Build MMD influence functions and squared-MMD estimates against a target sampler.

    ``kernel`` is always called with a ``bandwidth`` keyword argument. With
    ``broadcastable_kernel=True`` it receives whole sample batches at once;
    otherwise ``get_kern_mat`` calls it once per pair of individual samples.

    Args:
        kernel: Kernel function ``kernel(a, b, bandwidth=...)``.
        target_sampler: Zero-argument callable returning a batch of target samples.
        iterate_bsize: Value passed to the iterate sampler on every draw (default
            ``float("inf")``; presumably "use every available sample" -- confirm
            against the samplers in use).
        broadcastable_kernel: Whether ``kernel`` accepts whole batches.
        bandwidth: Kernel bandwidth.
        bandwidth_adj: If true, ``get_ifunc`` overwrites ``bandwidth`` from the
            freshly drawn batches before building the influence function.
    """

    def __init__(self, kernel, target_sampler, iterate_bsize = float("inf"),
                 broadcastable_kernel = True, bandwidth = 1.0, bandwidth_adj = True):
        self.kernel = kernel
        self.broadcastable_kernel = broadcastable_kernel
        self.target_sampler = target_sampler
        self.iterate_bsize = iterate_bsize
        # Mutable state: get_ifunc overwrites this when bandwidth_adj is true.
        self.bandwidth = bandwidth
        self.bandwidth_adj = bandwidth_adj

    @staticmethod
    def _uniform_weights(samps):
        """Return a length-``len(samps)`` vector filled with ``1 / len(samps)``.

        When ``samps`` is a floating-point tensor, the weights share its dtype and
        device, because ``torch.einsum`` does not type-promote or move devices.
        Otherwise torch's default dtype on the CPU is used.
        """
        n = len(samps)
        if isinstance(samps, Tensor) and samps.is_floating_point():
            return torch.ones(n, dtype=samps.dtype, device=samps.device) / n
        return torch.ones(n) / n

    @staticmethod
    def _bilinear(u, mat, v):
        """Return ``u^T mat v`` with all operands promoted to one dtype on ``mat``'s device."""
        dtype = torch.promote_types(torch.promote_types(u.dtype, mat.dtype), v.dtype)
        u, mat, v = (t.to(device=mat.device, dtype=dtype) for t in (u, mat, v))
        return torch.einsum("i,ij,j->", u, mat, v)

    def get_ifunc(self, sampler):
        """Return the MMD influence function for a freshly drawn pair of batches.

        Draws ``sampler(self.iterate_bsize)`` and ``self.target_sampler()``, gives
        each batch uniform weights and binds everything into ``mmd_ifunc``. The
        result is a one-argument callable ``ifunc(x)``.

        Side effect: when ``bandwidth_adj`` is true, ``self.bandwidth`` is replaced
        by the median of ``eucl_norm_sq(iterate_samps, target_samps)``. This is
        presumably the pairwise squared-distance median heuristic; confirm against
        ``robustopt_torch.costs``.
        """
        iterate_samps = sampler(self.iterate_bsize)
        iterate_weights = self._uniform_weights(iterate_samps)
        target_samps = self.target_sampler()
        target_weights = self._uniform_weights(target_samps)

        if self.bandwidth_adj:
            self.bandwidth = eucl_norm_sq(iterate_samps, target_samps).median().item()
        # Freeze the current bandwidth so later adjustments don't affect this ifunc.
        ifunc_kern = partial(self.kernel, bandwidth=self.bandwidth)
        return partial(mmd_ifunc,
                       iterate_samps=iterate_samps,
                       iterate_weights=iterate_weights,
                       target_samps=target_samps,
                       target_weights=target_weights,
                       kernel=ifunc_kern,
                       broadcastable=self.broadcastable_kernel)

    def get_mmd(self, sampler):
        """Estimate squared MMD between a fresh iterate batch and a target batch.

        Returns the plug-in (V-statistic) estimate
        ``w^T K_xx w + v^T K_yy v - 2 w^T K_xy v`` with uniform weights ``w`` and
        ``v``, so diagonal kernel terms are included. The current
        ``self.bandwidth`` is used as-is; no bandwidth adjustment happens here.
        """
        iterate_samps = sampler(self.iterate_bsize)
        iterate_weights = self._uniform_weights(iterate_samps)
        target_samps = self.target_sampler()
        target_weights = self._uniform_weights(target_samps)

        i_mag = self._bilinear(iterate_weights,
                               self.get_kern_mat(iterate_samps, iterate_samps),
                               iterate_weights)
        t_mag = self._bilinear(target_weights,
                               self.get_kern_mat(target_samps, target_samps),
                               target_weights)
        it_mag = self._bilinear(iterate_weights,
                                self.get_kern_mat(iterate_samps, target_samps),
                                target_weights)
        return i_mag + t_mag - 2.0 * it_mag

    def get_kern_mat(self, x_samps, y_samps):
        """Return the ``(len(x_samps), len(y_samps))`` kernel matrix at ``self.bandwidth``.

        A broadcastable kernel is called once on the full batches. Otherwise the
        kernel is evaluated on every ``(x, y)`` pair in row-major order, and the
        values are stacked and reshaped.
        """
        if self.broadcastable_kernel:
            kern_matrix = self.kernel(x_samps, y_samps, bandwidth=self.bandwidth)
        else:
            # torch.stack (unlike torch.cat) accepts the 0-d values a per-pair
            # kernel naturally returns; the reshape below flattens any extra
            # singleton axes.
            kern_matrix = torch.stack([
                torch.as_tensor(self.kernel(x_samp, y_samp, bandwidth=self.bandwidth))
                for x_samp, y_samp in product(x_samps, y_samps)
            ])
        return kern_matrix.reshape(len(x_samps), len(y_samps))
