# Adapted from the COMBO code base.

import numpy as np
import torch
from collections import OrderedDict
from .base import TestFunction

# Number of stages in the contamination chain. Each stage is one binary decision
# variable, so this is also the input dimensionality of the Contamination benchmark.
CONTAMINATION_N_STAGES = 25


def sample_init_points(n_vertices, n_points, random_seed=None):
    """
    Draw ``n_points`` random configurations of a categorical search space.

    Column ``j`` of every configuration is drawn uniformly from
    ``{0, ..., n_vertices[j] - 1}`` with ``torch.randint``.

    :param n_vertices: 1D array holding the number of categories of each variable.
    :param n_points: number of configurations to draw.
    :param random_seed: if not None, torch's global CPU RNG is seeded with this value
        for the duration of the call. Its previous state is restored afterwards, even if
        sampling raises, so the caller's random stream is left untouched.
    :return: LongTensor of shape ``(n_points, len(n_vertices))``, or an empty 1-D
        LongTensor when ``n_points`` is 0.
    """
    if random_seed is not None:
        rng_state = torch.get_rng_state()
        torch.manual_seed(random_seed)
    try:
        # Draw in point-major, variable-minor order (one randint per entry) so that seeded
        # results stay reproducible. The rows are concatenated once at the end instead of
        # re-growing a tensor on every iteration.
        rows = [torch.cat([torch.randint(0, int(elm), (1, 1)) for elm in n_vertices], dim=1)
                for _ in range(n_points)]
    finally:
        if random_seed is not None:
            torch.set_rng_state(rng_state)
    if not rows:
        return torch.empty(0).long()
    return torch.cat(rows, dim=0)


def _contamination(x, cost, init_Z, lambdas, gammas, U, epsilon):
    """
    Simulate the contamination chain for decision vector ``x`` and return its penalised cost.

    Each simulated contamination level is propagated stage by stage, starting from
    ``init_Z``. At stage ``i`` with ``x[i] == 0``, the clean fraction ``1 - Z`` is
    contaminated at rate ``lambdas[i]``. With ``x[i] == 1``, the current level is scaled
    by ``1 - gammas[i]``.

    :param x: numpy array with ``CONTAMINATION_N_STAGES`` entries (asserted).
    :param cost: per-stage cost multiplied element-wise with ``x``.
    :param init_Z: initial contamination level of each simulation.
    :param lambdas: per-stage, per-simulation contamination rates.
    :param gammas: per-stage, per-simulation restoration rates.
    :param U: contamination threshold.
    :param epsilon: allowed violation rate. The per-stage constraint value is the
        fraction of simulations below ``U`` minus ``1 - epsilon``.
    :return: ``sum(x * cost - constraints)`` as a numpy scalar.
    """
    assert x.size == CONTAMINATION_N_STAGES

    penalty_weight = 1.0
    n_sims = 100

    levels = np.zeros((x.size, n_sims))
    previous = init_Z
    for stage in range(CONTAMINATION_N_STAGES):
        inflow = lambdas[stage] * (1.0 - x[stage]) * (1.0 - previous)
        retained = (1.0 - gammas[stage] * x[stage]) * previous
        levels[stage] = inflow + retained
        previous = levels[stage]

    safe_fraction = (levels < U).mean(axis=1)
    constraint_slack = safe_fraction - (1.0 - epsilon)
    return np.sum(x * cost - penalty_weight * constraint_slack)


def generate_contamination_dynamics(random_seed=None):
    """
    Sample the random parameters of the contamination chain.

    :param random_seed: seed handed to ``np.random.RandomState``.
        NOTE(review): a fresh ``RandomState(random_seed)`` is built for each of the three
        draws. With a fixed seed, all three are therefore generated from the same
        underlying random stream and are likely correlated. Changing this would alter
        every seeded benchmark instance, so it is kept as is.
    :return: ``(init_Z, lambdas, gammas)``:
        - ``init_Z``: Beta(1, 30) samples of shape ``(100,)``.
        - ``lambdas``: Beta(1, 17/3) samples of shape ``(CONTAMINATION_N_STAGES, 100)``.
        - ``gammas``: Beta(1, 3/7) samples of shape ``(CONTAMINATION_N_STAGES, 100)``.
    """
    n_sims = 100
    stage_sim_shape = (CONTAMINATION_N_STAGES, n_sims)

    def _beta(a, b, shape):
        # A new generator per call, reproducing the original seeding scheme exactly.
        return np.random.RandomState(random_seed).beta(a, b, size=shape)

    init_Z = _beta(1.0, 30.0, (n_sims,))
    lambdas = _beta(1.0, 17.0 / 3.0, stage_sim_shape)
    gammas = _beta(1.0, 3.0 / 7.0, stage_sim_shape)
    return init_Z, lambdas, gammas


class Contamination(TestFunction):
    """
    Contamination-control benchmark with CONTAMINATION_N_STAGES binary decision variables.

    This is the "simplest graph" variant adapted from COMBO.

    Keep ``random_seed_pair`` fixed. Its first entry seeds the sampled contamination
    dynamics, so a different seed yields a different objective and hence different
    solutions.
    """
    def __init__(self, lamda,
                 normalize=True,
                 random_seed_pair=(0, 0)):
        """
        :param lamda: weight of the penalty ``lamda * sum(x)`` added to every evaluation.
        :param normalize: passed to the base class and used as the default of
            ``compute(..., normalize=None)``. When truthy, ``self.mean`` and ``self.std``
            are set from ``self.sample_normalize()``; otherwise both are None.
            NOTE(review): ``sample_normalize`` is not defined in this class. It is
            presumably inherited from TestFunction; confirm.
        :param random_seed_pair: only ``random_seed_pair[0]`` is used here, to seed the
            sampled dynamics.
        """
        super(Contamination, self).__init__(normalize)
        self.lamda = lamda
        self.n_vertices = np.array([2] * CONTAMINATION_N_STAGES)
        self.n_grid = 2
        self.config = self.n_vertices
        self.dim = CONTAMINATION_N_STAGES
        self.categorical_dims = np.arange(self.dim)
        # The dynamics are sampled once here, so every evaluation uses the same values.
        self.init_Z, self.lambdas, self.gammas = generate_contamination_dynamics(random_seed_pair[0])
        if self.normalize:
            self.mean, self.std = self.sample_normalize()
        else:
            self.mean, self.std = None, None

    def compute(self, x, normalize=None, ):
        """
        Evaluate the objective on one or more configurations.

        :param x: a torch tensor, or anything ``np.asarray`` accepts (numpy array, nested
            list). Shape is ``(CONTAMINATION_N_STAGES,)`` or ``(n, CONTAMINATION_N_STAGES)``.
            Non-tensor input is cast to int.
        :param normalize: standardise the result with ``self.mean`` / ``self.std``.
            Defaults to ``self.normalize``.
        :return: tensor with one value per configuration.
        :raises Exception: if ``x`` cannot be converted to a tensor.
        """
        if normalize is None:
            normalize = self.normalize
        if not isinstance(x, torch.Tensor):
            try:
                x = torch.tensor(np.asarray(x).astype(int))
            except (TypeError, ValueError, OverflowError) as err:
                raise Exception('Unable to convert x to a pytorch tensor!') from err
        if x.dim() == 1:
            x = x.unsqueeze(0)
        res = torch.cat([self._evaluate_single(x[i]) for i in range(x.size(0))], dim=0)
        # Tiny Gaussian jitter drawn from torch's global RNG (presumably for tie-breaking).
        # Because of it, compute() is not deterministic.
        res += 1e-6 * torch.randn_like(res)

        if normalize:
            assert self.mean is not None and self.std is not None
            res = (res - self.mean) / self.std

        return res

    def _evaluate_single(self, x):
        """
        Evaluate a single configuration.

        :param x: integer tensor of shape ``(CONTAMINATION_N_STAGES,)`` or
            ``(1, CONTAMINATION_N_STAGES)``.
        :return: float tensor of shape ``(1,)``, created via ``x.new_ones``.
        """
        # Squeeze before validating so that a single (1, n) row is accepted.
        if x.dim() == 2:
            x = x.squeeze(0)
        assert x.dim() == 1
        assert x.numel() == len(self.n_vertices)
        evaluation = _contamination(x=x.cpu().numpy(), cost=np.ones(x.numel()),
                                    init_Z=self.init_Z, lambdas=self.lambdas,
                                    gammas=self.gammas, U=0.1, epsilon=0.05)
        # Penalty proportional to the number of active stages, weighted by lamda.
        evaluation += self.lamda * float(torch.sum(x))
        return evaluation * x.new_ones((1,)).float()


if __name__ == '__main__':
    # Smoke test: evaluate the all-ones and all-zeros decision vectors.
    problem = Contamination(lamda=1e-2, normalize=True)
    for probe in (np.ones((CONTAMINATION_N_STAGES, )), np.zeros((CONTAMINATION_N_STAGES, ))):
        print(problem.compute(probe))