import numpy as np
import torch
from collections import OrderedDict


class DiscreteBranin:
    """Branin function restricted to a regular ``n_grid`` x ``n_grid`` lattice.

    Both inputs live on ``[0, 15]``; the first coordinate is shifted by ``-5``
    before evaluation so that it covers the usual ``[-5, 10]`` Branin range.
    """

    def __init__(self, n_grid=16, maximise=False, normalize=True):
        """
        :param n_grid: number of discrete points to create per dimension.
        :param maximise: when True, ``compute`` returns negated values.
        :param normalize: stored as ``self.normalize``. NOTE(review): ``compute``
            takes its own ``normalize`` argument and does not read this attribute.
        """
        self.n_grid = n_grid
        self.x1_bounds = np.array([0, 15])
        self.x2_bounds = np.array([0, 15])

        self.x1_values = np.linspace(self.x1_bounds[0], self.x1_bounds[1], n_grid)
        self.x2_values = np.linspace(self.x2_bounds[0], self.x2_bounds[1], n_grid)
        self.maximise = maximise
        self.normalize = normalize

        # Every lattice point, one (x1, x2) pair per row.
        meshgrid = np.meshgrid(self.x1_values, self.x2_values)
        self.values = np.vstack(list(map(np.ravel, meshgrid))).T

        # Normalisation bounds are taken from the *raw* (un-negated) function
        # values. Deriving them from the sign-flipped ``fX`` would mix signs
        # inside ``compute`` and push results outside [0, 1] when maximising.
        raw = self._branin(self.values)
        self._raw_min = float(np.min(raw))
        self._raw_max = float(np.max(raw))
        # Un-normalised objective on the lattice, with the maximise sign applied.
        self.fX = -raw if maximise else raw

        # config specification for bayesmark style evaluation
        self.config = OrderedDict(
            {'x1': {'type': 'cat', 'values': self.x1_values},
             'x2': {'type': 'cat', 'values': self.x2_values}}
        )
        # NOTE(review): looks like this is the number of input dimensions -- confirm.
        self.n_vertices = 2

    @staticmethod
    def _branin(x):
        """Evaluate the un-normalised, un-negated Branin function on rows of ``x``."""
        # Translate the first coordinate from [0, 15] to the [-5, 10] domain.
        x0 = x[:, 0] - 5
        x1 = x[:, 1]

        a = 1.0
        b = 5.1 / (4.0 * np.pi ** 2)
        c = 5.0 / np.pi
        d = 6.0
        e = 10.0
        f = 1.0 / (8.0 * np.pi)

        return a * (x1 - b * x0 ** 2 + c * x0 - d) ** 2 + e * (1 - f) * np.cos(x0) + e

    def compute(self, x, normalize=True):
        """Evaluate the objective at one or more points.

        :param x: array-like of shape ``(n, 2)``, or a single point of shape ``(2,)``.
        :param normalize: rescale using the min/max of the raw function over the
            lattice, so lattice points map into ``[0, 1]`` before any sign flip.
        :return: array of shape ``(n,)``; a scalar when ``x`` was a single 1-D point.
            Values are negated when ``self.maximise`` is True.
        """
        x = np.asarray(x, dtype=float)
        flat = x.ndim == 1
        if flat:
            x = x.reshape(1, -1)

        res = self._branin(x)
        if normalize:
            span = self._raw_max - self._raw_min
            if span > 0:
                res = (res - self._raw_min) / span
            else:
                # Degenerate lattice (e.g. n_grid=1): avoid dividing by zero.
                res = np.zeros_like(res)
        if self.maximise:
            res = -res
        return res[0] if flat else res



class DiscreteHartmann6(object):
    """Discrete Hartmann6 from COMBO.

    Each of the six inputs is an integer index into a 51-point grid; index ``i``
    is mapped to ``i / 50`` in ``[0, 1]`` before the Hartmann-6 function is
    evaluated.
    """

    def __init__(self):
        self.n_vertices = np.array([51] * 6)
        self.n_grid = 51
        self.n_factors = len(self.n_vertices)
        # The graph-Laplacian / Fourier setup that was commented out here is
        # not used by ``compute`` and has been omitted.

        # config specification for bayesmark style evaluation.
        # One category per grid index, so the advertised range always matches
        # the number of grid points that ``compute`` can index.
        self.config = {
            f'x{i + 1}': {'type': 'cat', 'range': tuple(str(v) for v in range(int(n_v)))}
            for i, n_v in enumerate(self.n_vertices)
        }

    def compute(self, x_g):
        """Evaluate Hartmann-6 at grid indices ``x_g``.

        :param x_g: integer grid indices, shape ``(n, 6)`` or ``(6,)``; a torch
            tensor or anything ``np.asarray`` can turn into integers. If the
            number of columns is a larger multiple of 6, each 6-column block is
            evaluated separately and the results are averaged; trailing columns
            that do not fill a block are ignored.
        :return: float32 numpy array of shape ``(n, 1)``, or a 0-d array for a
            single 1-D input.
        :raises ValueError: if fewer than 6 columns are supplied.
        """
        # Indices must be an integer tensor: torch rejects float index tensors.
        if isinstance(x_g, torch.Tensor):
            idx = x_g.long()
        else:
            idx = torch.from_numpy(np.asarray(x_g).astype(np.int64))

        alpha = torch.tensor([1.0, 1.2, 3.0, 3.2])
        # Both tables are transposed to shape (6, 4): input dimension x term.
        A = torch.tensor([[10.0, 3.00, 17.0, 3.50, 1.70, 8.00],
                          [0.05, 10.0, 17.0, 0.10, 8.00, 14.0],
                          [3.00, 3.50, 1.70, 10.0, 17.0, 8.00],
                          [17.0, 8.00, 0.05, 10.0, 0.10, 14.0]]).t()
        P = torch.tensor([[0.1312, 0.1696, 0.5569, 0.0124, 0.8283, 0.5886],
                          [0.2329, 0.4135, 0.8307, 0.3736, 0.1004, 0.9991],
                          [0.2348, 0.1451, 0.3522, 0.2883, 0.3047, 0.6650],
                          [0.4047, 0.8828, 0.8732, 0.5743, 0.1091, 0.0381]]).t()

        flat = idx.dim() == 1
        if flat:
            idx = idx.reshape(1, -1)
        ndata, ndim = idx.size()
        n_factors = self.n_factors
        n_repeat = ndim // n_factors
        if n_repeat == 0:
            raise ValueError(
                'expected at least %d columns, got %d' % (n_factors, ndim)
            )

        # One [-1, 1] grid per factor, built once and reused for every block.
        grids = [torch.linspace(-1, 1, int(n_v)) for n_v in self.n_vertices]
        x_e = torch.ones(idx.size())
        for d in range(n_repeat * n_factors):
            x_e[:, d] = grids[d % n_factors][idx[:, d]]

        # Rescale from [-1, 1] to the [0, 1] domain of the Hartmann function.
        x_e = (x_e + 1) * 0.5

        output = torch.zeros(ndata, 1)
        for i in range(n_repeat):
            x_block = x_e[:, n_factors * i:n_factors * (i + 1)]
            # (ndata, 6, 1) - (6, 4) broadcasts to (ndata, 6, 4).
            sq_dist = (x_block.unsqueeze(2) - P) ** 2
            exponent = (A * sq_dist).sum(1)
            output += -(alpha * torch.exp(-exponent)).sum(1, keepdim=True)
        output /= float(n_repeat)
        output = output.numpy()
        return output.squeeze() if flat else output
