# Modified from the COMBO codebase

import math
import numpy as np
from localglobal.test_funcs.base import TestFunction

import torch


class Branin(TestFunction):
    """Branin function evaluated on a discretised 51 x 51 grid.

    Every input dimension is an integer index into ``n_vertices[d]`` evenly
    spaced points.  Indices are mapped onto the standard Branin domain
    x1 in [-5, 10], x2 in [0, 15] before the function is evaluated.
    """

    def __init__(self, normalize=True):
        super(Branin, self).__init__(normalize)
        self.n_vertices = np.array([51, 51])
        self.dim = 2
        self.config = self.n_vertices
        self.n_factors = len(self.n_vertices)
        # Left empty: the COMBO graph-Laplacian / Fourier setup that used to
        # populate it was disabled; the attribute is kept for code that may
        # still read it.
        self.adjacency_mat = []
        # Initialise the statistics first so compute() can check for them
        # even while sample_normalize() is estimating them.
        self.mean, self.std = None, None
        if self.normalize:
            self.mean, self.std = self.sample_normalize()

    def compute(self, x, normalize=None, ):
        """Evaluate the Branin function at one or more grid points.

        :param x: integer grid indices, either one point of shape ``(2,)`` or
            a batch of shape ``(n, 2)``.  May be a torch tensor, a numpy
            array or a nested sequence; non-integer values are truncated.
        :param normalize: True to standardise the output with the ``mean`` /
            ``std`` estimated at construction, False for the raw value, and
            None (default) to follow ``self.normalize`` whenever those
            statistics are available.
        :return: a 0-d tensor for a single point, otherwise a tensor of
            shape ``(n,)``.
        :raises ValueError: on a wrong number of dimensions, an index outside
            ``[0, n_vertices[d] - 1]``, or normalisation requested before the
            statistics exist.
        """
        if not isinstance(x, torch.Tensor):
            x = torch.as_tensor(np.asarray(x).astype(int))
        flat = x.dim() == 1
        if flat:
            x = x.reshape(1, -1)
        ndim = x.size(1)
        if ndim != len(self.n_vertices):
            raise ValueError('Expected %d input dimensions, got %d' % (len(self.n_vertices), ndim))

        idx = x.long()
        upper = torch.as_tensor(self.n_vertices, dtype=torch.long, device=idx.device)
        # Negative indices would otherwise wrap around silently when used to
        # index the grid below.
        if ((idx < 0) | (idx >= upper)).any():
            raise ValueError('Grid indices must lie in [0, n_vertices - 1]')

        n_repeat = ndim // 2
        n_dummy = ndim % 2

        # Map each index to an evenly spaced point in [-1, 1] ...
        x_e = torch.empty(idx.shape)
        for d, n_v in enumerate(self.n_vertices):
            x_e[:, d] = torch.linspace(-1, 1, int(n_v))[idx[:, d]]
        # ... then scale/shift into x1 in [-5, 10] and x2 in [0, 15].
        shift = torch.cat([torch.tensor([2.5, 7.5]).repeat(n_repeat), torch.zeros(n_dummy)])
        x_e = x_e * 7.5 + shift

        a = 1
        b = 5.1 / (4 * math.pi ** 2)
        c = 5.0 / math.pi
        r = 6
        s = 10
        t = 1.0 / (8 * math.pi)
        output = torch.zeros(idx.size(0))
        for i in range(n_repeat):
            x1 = x_e[:, 2 * i]
            x2 = x_e[:, 2 * i + 1]
            output += a * (x2 - b * x1 ** 2 + c * x1 - r) ** 2 + s * (1 - t) * torch.cos(x1) + s
        output /= float(n_repeat)

        if normalize is None:
            # Defer to the instance flag, but only once statistics exist.
            normalize = bool(self.normalize) and getattr(self, 'mean', None) is not None
        if normalize:
            if getattr(self, 'mean', None) is None or getattr(self, 'std', None) is None:
                raise ValueError('Normalisation requested but mean/std are not available')
            output = (output - self.mean) / self.std

        return output.squeeze(0) if flat else output

if __name__ == '__main__':
    import matplotlib.pyplot as plt

    # The plot shows raw values, so skip sampling normalisation statistics.
    f = Branin(normalize=False)
    n1, n2 = (int(n) for n in f.n_vertices)
    # Evaluate every grid point in one batched call; row i / column j of
    # the image corresponds to index pair (i, j).
    ii, jj = np.meshgrid(np.arange(n1), np.arange(n2), indexing='ij')
    grid = np.stack([ii.ravel(), jj.ravel()], axis=1)
    res = f.compute(grid, normalize=False).numpy().reshape(n1, n2)
    plt.imshow(res)
    plt.colorbar()
    plt.show()