import numpy as np
from localglobal.test_funcs.base import TestFunction
import subprocess
import random
import ast

# Uncertainty-penalty types available for the offline RL runs.
# The categorical search variable is used as an index into this list, so the
# order of the entries matters.
uncert_types = [
    "mopo_default",
    "ensemble_var",
    "ensemble_std",
    # Currently disabled: "ensemble_var_rew", "ensemble_var_comb"
    "mopo_paper",  # Slightly different mopo penalty from the paper.
    "lompo",
    "m2ac",
    "morel",
]


class OfflineRL(TestFunction):
    """Mixed-variable objective for tuning an offline model-based RL run.

    Every evaluation launches ``train.py`` as a subprocess and reads the score
    from the last line the child writes to stderr.

    Layout of one input row ``x``:

    * ``x[0]`` - categorical, index into ``uncert_types``
      (passed as ``--mopo_penalty_type``)
    * ``x[1]`` - continuous in [1, 50], truncated to int (``--steps_k``)
    * ``x[2]`` - continuous in [0, 100] (``--mopo_lam``)
    * ``x[3]`` - continuous in [5, 15], truncated to int (``--num_models``)
    """

    problem_type = 'mixed'

    def __init__(self, normalize=False, yaml_file=r"args_yml/bo_test_rig_hoppermed.yml", n_epochs=500):
        """Set up the search space.

        Args:
            normalize: Forwarded unchanged to ``TestFunction.__init__``.
            yaml_file: Config path handed to ``train.py`` via ``--yaml_file``.
            n_epochs: Value handed to ``train.py`` via ``--offline_epochs``.
        """
        super(OfflineRL, self).__init__(normalize)
        # Only one categorical here - the uncertainty type.
        self.categorical_dims = np.arange(0, 1)
        self.continuous_dims = np.arange(1, 1 + 3)
        self.dim = len(self.continuous_dims) + len(self.categorical_dims)
        # Number of choices per categorical dimension (first = uncert type).
        self.n_vertices = np.array([len(uncert_types)])
        self.config = self.n_vertices
        # Bounds of the continuous variables: model k, mopo lambda, num models.
        self.lb = np.array([1, 0, 5])
        self.ub = np.array([50, 100, 15])

        self.yaml_file = yaml_file
        self.n_epochs = n_epochs

    @staticmethod
    def _tuned_args(x, seed):
        """Return the seed plus the optimised hyperparameters of ``x`` as CLI arguments."""
        return [
            "--seed",
            seed,
            "--mopo_penalty_type",
            uncert_types[int(x[0])],
            "--steps_k",
            str(max(int(x[1]), 1)),  # never fewer than 1 step
            "--mopo_lam",
            str(x[2]),
            "--num_models",
            str(max(int(x[3]), 5)),  # never fewer than 5 models
        ]

    def compute(self, X, normalize=None):
        """Evaluate one or more configurations, one training subprocess each.

        All subprocesses are started first and only then waited on, so the
        rows of ``X`` are evaluated concurrently.

        Args:
            X: A single configuration (1-D) or a batch of configurations (2-D,
                one per row). The input is not modified.
            normalize: Unused; kept for interface compatibility.

        Returns:
            1-D ``np.ndarray`` with one value per row of ``X``: the last stderr
            line of the run parsed as a float, negated and divided by 10, or
            ``0`` when that line is missing or not a number.
        """
        # Copy so that rounding the categorical columns below does not write
        # through to the caller's array (a reshape alone would return a view).
        X = np.array(X)
        if X.ndim == 1:
            X = X.reshape(1, -1)
        # To make sure there is no cheating, round the discrete variables
        # before calling the function.
        X[:, self.categorical_dims] = np.round(X[:, self.categorical_dims])

        procs = []
        for x in X:
            # Potential two further options:
            # "--mopo_uncertainty_target"
            # "--num_elites"
            seed = str(random.randint(0, 1000))
            # Built once so the log line and the command cannot drift apart.
            tuned_args = self._tuned_args(x, seed)
            command = [
                "python",
                r"train.py",
                "--yaml_file",
                self.yaml_file,
                "--offline_epochs",
                str(self.n_epochs),
                "--uuid",
                "null",
            ] + tuned_args
            print(*tuned_args)
            procs.append(
                subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            )

        # NOTE: this changes numpy's floating-point error handling for the
        # whole process and is not restored afterwards.
        np.seterr(all="raise")

        # Wait on each process.
        y = []
        for proc in procs:
            result, err = proc.communicate()
            try:
                # The score is taken from the final stderr line.
                # NOTE(review): the sign flip and /10 scaling presumably turn a
                # return into a minimisation target - confirm against train.py.
                val = -float(err.splitlines()[-1]) / 10
            except (IndexError, ValueError):
                # No stderr output, or the last line is not a number.
                # Sometimes run into GPU memory problems with 4 per GPU...
                print(result.splitlines()[-20:])
                val = 0

            print(val)
            y.append(val)

        return np.array(y)

    def sample_normalize(self, size=None):
        """Evaluate randomly sampled configurations and summarise the results.

        Args:
            size: Number of random configurations to evaluate, one after
                another. Defaults to ``2 * self.dim + 1``.

        Returns:
            Tuple ``(mean, std)`` of the objective values obtained.
        """
        from localglobal.bo.localbo_utils import latin_hypercube, from_unit_cube
        if size is None:
            size = 2 * self.dim + 1
        y = []
        for _ in range(size):
            # One uniformly random choice per categorical dimension.
            x_cat = np.array([
                np.random.choice(self.config[d])
                for d in range(self.categorical_dims.shape[0])
            ])
            # One point in the unit cube, rescaled to [lb, ub].
            x_cont = latin_hypercube(1, self.continuous_dims.shape[0])
            x_cont = from_unit_cube(x_cont, self.lb, self.ub).flatten()
            x = np.hstack((x_cat, x_cont))
            y.append(self.compute(x, normalize=False))
        y = np.array(y)

        return np.mean(y), np.std(y)


if __name__ == '__main__':
    # Smoke test: evaluate a single random configuration and print the
    # resulting (mean, std) pair.
    objective = OfflineRL()
    stats = objective.sample_normalize(1)
    print(stats)
