import argparse
import pickle
import GPy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython import embed
import logging

from testFunctions.syntheticFunctions import myrosenbrock, mysixhumpcamp, mybeale

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(lineno)d: %(message)s")

# Budget level -> offset added to the objective value when comparing against
# the GP prediction in intergral_loss(). The level itself is fed to the GP as
# the first input column.
budget_dict = {0: 10000.0, 1: 1000.0, 2: 100.0, 3: 10.0, 4: 1.0,
               5: 0.1, 6: 0.01, 7: 0.001, 8: 0.0001, 9: 0.00001}

parser = argparse.ArgumentParser(description="result analysis")
parser.add_argument(
    "--log_dir",
    default="./log/2_lhd_rosenbrock",
    help="directory containing betas.pkl, gp.pkl and the CoCaBO best-values log",
)
# type=int is required: without it a value given on the command line arrives
# as a str, which is not a key of budget_dict and cannot be fed to the GP.
parser.add_argument(
    "--budget",
    type=int,
    default=9,
    choices=sorted(budget_dict),
    help="budget level (a key of budget_dict)",
)


args = parser.parse_args()


def analysis(beta_var: bool, maxval: bool, *, n_runs: int = 5, log_dir=None):
    """Summarise the saved optimisation logs.

    Args:
        beta_var: if True, load ``betas.pkl`` and log the spread ("std") of the
            saved beta values around their mean over the first axis.
        maxval: if True, load the CoCaBO best-values log and log the largest
            value found across runs ``0 .. n_runs-1``.
        n_runs: number of runs read from the best-values log. The default is 5,
            which was previously hard-coded.
        log_dir: directory holding the log files. It defaults to
            ``args.log_dir`` when None.

    Returns:
        ``(beta_std, max_val)``. Each entry is None when the corresponding
        analysis was not requested.
    """
    if log_dir is None:
        log_dir = args.log_dir

    beta_std = None
    if beta_var:
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted logs.
        with open(f"{log_dir}/betas.pkl", "rb") as f:
            betas = np.array(pickle.load(f))
        # np.sum without an axis sums over *all* elements but divides only by the
        # row count. For 2-D betas this is sqrt(total variance), not a per-column std.
        beta_std = np.sqrt(np.sum((betas - betas.mean(axis=0)) ** 2) / betas.shape[0])
        logging.info(f"beta_std:{beta_std:.4f}")

    max_val = None
    if maxval:
        data = pd.read_pickle(f"{log_dir}/CoCaBO_1_best_vals_LCB_ARD_False_mix_-1.0")
        # Assumes runs are stored under integer keys 0..n_runs-1 -- TODO confirm.
        all_data = []
        for i in range(n_runs):
            all_data += data[i].to_list()
        max_val = max(all_data)
        logging.info(f"max val:{max_val}")

    return beta_std, max_val


def intergral_loss(n_grid: int = 20):
    """Sum the absolute error between the saved GP and the shifted Rosenbrock target.

    The GP is loaded from ``{args.log_dir}/gp.pkl`` and evaluated on an
    ``n_grid`` x ``n_grid`` grid over [-1, 1]^2. The first input column is fixed
    to the budget level ``args.budget``. At each grid point the target is
    ``myrosenbrock([x, y]) + budget_dict[budget]``. Errors are summed without
    area weighting, so the result is a grid sum rather than a scaled integral.

    Args:
        n_grid: grid points per axis. The default is 20, the original resolution.

    Returns:
        The accumulated absolute error as a Python float.
    """
    # argparse yields a str if --budget was parsed without type=int.
    budget = int(args.budget)
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(f"{args.log_dir}/gp.pkl", "rb") as f:
        gp = pickle.load(f)

    offset = budget_dict[budget]
    grid = np.linspace(-1, 1, n_grid)
    loss = 0.0
    for x in grid:
        for y in grid:
            pred = gp.predict(np.array([[budget, x, y]]))
            # If gp is a GPy model (GPy is imported above), predict returns a
            # (mean, variance) tuple. Subtracting the whole tuple broadcast the
            # variance into the error, so keep only the mean.
            mean = pred[0] if isinstance(pred, tuple) else pred
            target = myrosenbrock(np.array([x, y])) + offset
            loss += float(np.sum(np.abs(target - mean)))
    return loss


def main():
    """Script entry point: log the beta/max-value summaries, then print the GP grid loss."""
    analysis(beta_var=True, maxval=True)
    total_loss = intergral_loss()
    print(total_loss)


# Run the analysis only when executed as a script. Note that command-line
# parsing (args = parser.parse_args()) happens at module level and so also
# runs on import.
if __name__ == "__main__":
    main()

# log1 = pd.read_pickle("./log/1_budget_func2C/CoCaBO_1_best_vals_LCB_ARD_False_mix_-1.0")
# # log2 = pd.read_pickle("./log/1_budget_func2C/CoCaBO_1_best_vals_LCB_ARD_False_mix_0.0_df_s0")
# # log3 = pd.read_pickle("./log/1_budget_func2C/CoCaBO_1_best_vals_LCB_ARD_False_mix_0.0_df_s1")
# # log4 = pd.read_pickle("./log/1_budget_func2C/CoCaBO_1_best_vals_LCB_ARD_False_mix_0.0_df_s2")


# log21 = pd.read_pickle("./log/1_lhd_budget_func2C/CoCaBO_1_best_vals_LCB_ARD_False_mix_-1.0")

# embed()

# with open("./log/1_budget_func2C/gp.pkl", "rb") as f:
#     gp = pickle.load(f)

# a = np.array([[1.0, 2.0, 3.0, 0.3, -0.3]])
# b = np.array([[0.0, 1.0, 2.0, 0.3, -0.3]])

# var = gp.kern.K(a, b)
# embed()


# def parametric_mean_function(max_iters=100, optimize=True, plot=True):
#     """
#     A linear mean function with parameters that we'll learn alongside the kernel
#     """
#     # create  simple mean function
#     mf = GPy.core.Mapping(1, 1)
#     mf.f = np.sin

#     X = np.linspace(0, 10, 50).reshape(-1, 1)
#     Y = np.sin(X) + 0.5 * np.cos(3 * X) + 0.1 * np.random.randn(*X.shape) + 3 * X

#     mf = GPy.mappings.Linear(1, 1)

#     k = GPy.kern.RBF(1)
#     lik = GPy.likelihoods.Gaussian()
#     m = GPy.core.GP(X, Y, kernel=k, likelihood=lik, mean_function=mf)
#     if optimize:
#         m.optimize(max_iters=max_iters)
#     if plot:
#         m.plot()

#     embed()
#     return m


# if __name__ == "__main__":
#     parametric_mean_function()
#     plt.show()
