import json
import os
import random
import string
import uuid
import shutil

from typing import Optional, Union
from pprint import pprint

import configargparse
import sys
from contextlib import contextmanager, redirect_stderr, redirect_stdout
import torch.nn as nn

from bo import gp_initialize_model, gp_optimize_acqf_and_get_observation
import numpy as np

@contextmanager
def suppress_output():
    """
        A context manager that redirects stdout and stderr to devnull
        https://stackoverflow.com/a/52442331
    """
    with open(os.devnull, 'w') as fnull:
        with redirect_stderr(fnull) as err, redirect_stdout(fnull) as out:
            yield (err, out)

with suppress_output():
    import design_bench

    from design_bench.datasets.discrete.tf_bind_8_dataset import TFBind8Dataset
    from design_bench.datasets.discrete.tf_bind_10_dataset import TFBind10Dataset
    from design_bench.datasets.discrete.cifar_nas_dataset import CIFARNASDataset
    from design_bench.datasets.discrete.chembl_dataset import ChEMBLDataset

    from design_bench.datasets.continuous.ant_morphology_dataset import AntMorphologyDataset
    from design_bench.datasets.continuous.dkitty_morphology_dataset import DKittyMorphologyDataset
    from design_bench.datasets.continuous.superconductor_dataset import SuperconductorDataset
    # from design_bench.datasets.continuous.hopper_controller_dataset import HopperControllerDataset

from util import TASKNAME2TASK, configure_gpu, set_seed, get_weights

task = design_bench.make(TASKNAME2TASK['ant'])
# dataset = task.dataset
original_x = task.x[:10]
original_y = task.y[:10]

#task.map_normalize_x()
task.map_normalize_y()

#task_x = task.x[:10]
#print(task_x.max(), task_x.min(), original_x.max(), original_x.min())
#task_x = task.denormalize_x(task_x)
#task.map_denormalize_x()
task.map_denormalize_y()
task_x = task.x[:10]
task_pred = task.predict(task_x)
#task_pred = task.denormalize_y(task_pred)
print(original_y)
print(task_pred)

print('max', np.abs(original_y - task_pred).max())
print('mean', np.abs(original_y-task_pred).mean())
