import time

import matplotlib.pyplot as plt
import numpy as np

from envs.Bandits_gen import LinB_sample_multi, MAB_sample_multi
from envs.DP_gen_cts import DP_cts_sample_multi
from envs.NV_gen_cts import NV_ctx_sample_multi_cts

# Run configurations for the "DP_continuous" task, keyed by run name.
# Each entry is unpacked as keyword arguments into the sampler (see
# ``samp_function(**task)`` in ``main``), so it is spelled in keyword form.
DP_continuous_dict = {
    "example_run": dict(
        exp_name="example_run",
        gen_method="Perturb",
        dim_num=4,
        action_ub=10,
        num_samp=10000,
        horizon=100,
        err_std=0.2,
        finite_envs_num=0,
        seed=123,
        need_square=False,
    ),
}


# Run configurations for the "NV_continuous" task, keyed by run name.
# Each entry is unpacked as keyword arguments into the sampler (see
# ``samp_function(**task)`` in ``main``), so it is spelled in keyword form.
NV_continuous_dict = {
    "example_run": dict(
        exp_name="example_run",
        gen_method="Perturb",
        dim_num=4,
        action_ub=30,
        num_samp=10000,
        horizon=100,
        perishable=True,
        censor=False,
        finite_envs_num=4,
        seed=123,
    ),
}

# Run configurations for the "MAB" task, keyed by run name.
# Each entry is unpacked as keyword arguments into the sampler (see
# ``samp_function(**task)`` in ``main``), so it is spelled in keyword form.
MAB_dict = {
    "example_run": dict(
        exp_name="example_run",
        gen_method="Perturb",
        dim_num=20,
        num_samp=10000,
        horizon=100,
        finite_envs_num=4,
        err_std=0.2,
        finenvs_seed=123,
    ),
}

# Run configurations for the "LinB" task, keyed by run name.
# Each entry is unpacked as keyword arguments into the sampler (see
# ``samp_function(**task)`` in ``main``), so it is spelled in keyword form.
LinB_dict = {
    "example_run": dict(
        exp_name="example_run",
        gen_method="Perturb",
        dim_num=2,
        num_samp=10000,
        horizon=100,
        finite_envs_num=4,
        err_std=0.2,
        finenvs_seed=123,
    ),
}


# Sampler dispatch table: task type name -> data-generation function.
# ``main`` looks the sampler up with the same task type keys that are used
# in ``task_pool``.
samp_function_pool = dict(
    DP_continuous=DP_cts_sample_multi,
    NV_continuous=NV_ctx_sample_multi_cts,
    MAB=MAB_sample_multi,
    LinB=LinB_sample_multi,
)

# Tasks to generate, grouped by task type. Each list holds the config dicts
# to run for that type; every key must also be present in
# ``samp_function_pool`` because ``main`` looks the sampler up by it.
task_pool = {
    "DP_continuous": [DP_continuous_dict["example_run"]],
    "NV_continuous": [NV_continuous_dict["example_run"]],
    "MAB": [MAB_dict["example_run"]],
    "LinB": [LinB_dict["example_run"]],
}


def _action_error(task_type, act_values, best_actions):
    """Return the gap between the sampled actions and the optimal actions.

    Args:
        task_type: Task type key; "MAB" and "LinB" get dedicated metrics,
            every other type uses the plain absolute difference.
        act_values: Actions returned by the sampler. For "MAB" and "LinB"
            it is reduced along axis 2, so it must have at least three
            dimensions there.
        best_actions: Optimal actions returned by the sampler.

    Returns:
        Array of action errors (indexed as 2-D by the caller).
    """
    if task_type == "LinB":
        # compute the L2 distance between the action and the optimal action
        return np.linalg.norm(act_values - best_actions, axis=2)
    if task_type == "MAB":
        # switch the one-hot encoding to the original action
        act_values = np.argmax(act_values, axis=2)
    return np.abs(act_values - best_actions.squeeze())


def _figure_name(task_type, task, task_idx, num_tasks):
    """Return the PNG file name for one run's action-error plot.

    A task type with a single run keeps the historical
    ``<task_type>_action_error.png`` name. With several runs, the run's
    ``exp_name`` (when present) and its position in the list are added so
    later runs do not overwrite the figures of earlier ones.
    """
    if num_tasks == 1:
        return task_type + "_action_error.png"
    run_tag = task.get("exp_name", "run")
    return "{}_{}_{}_action_error.png".format(task_type, run_tag, task_idx)


def main():
    """Generate data for every configured task and plot its mean action error.

    For each task type in ``task_pool`` the matching sampler is taken from
    ``samp_function_pool`` and called with the task dict unpacked as keyword
    arguments. For every run this prints the task description and the
    generation time, prints the first five rows of the action error, and
    plots the action error averaged over axis 0. The plot is saved as a PNG
    in the current working directory, shown, and closed.

    Raises:
        KeyError: If a task type in ``task_pool`` has no entry in
            ``samp_function_pool``.
    """
    for task_type, task_args in task_pool.items():
        samp_function = samp_function_pool[task_type]

        for task_idx, task in enumerate(task_args):
            print("####################################")
            print("task_type:", task_type)
            print("task_info:", task)
            # perf_counter is monotonic, so the measured duration cannot be
            # distorted by system clock adjustments during a long generation.
            tik = time.perf_counter()
            # The first returned value is not used by this script.
            _regs, act_values, best_actions = samp_function(**task)
            print("##############finish#################")
            # Reported in minutes.
            print("generation time:", (time.perf_counter() - tik) / 60)

            action_error = _action_error(task_type, act_values, best_actions)
            print("action errors:", action_error[:5, :])

            plt.plot(np.mean(action_error, axis=0), label="mean_act_error")
            plt.legend()
            plt.savefig(_figure_name(task_type, task, task_idx, len(task_args)))
            plt.show()
            plt.close()


if __name__ == "__main__":
    main()
