import json
import os

from marl_eval.plotting_tools.plotting import (
    aggregate_scores,
    performance_profiles,
    plot_single_task,
    probability_of_improvement,
    sample_efficiency_curves,
)
from marl_eval.utils.data_processing_utils import (
    create_matrices_for_rliable,
    data_process_pipeline,
)
from marl_eval.utils.diagnose_data_errors import DiagnoseData

# Placeholders: set data_path to the directory holding the results and
# exp_name to the stem of the JSON results file inside it.
data_path = "path_to_data"
exp_name = "name_of_experiment"
# os.path.join inserts the separator when data_path lacks a trailing one, so
# the directory and experiment names are never glued together
# ("path_to_data" + "name_of_experiment" -> "path_to_dataname_of_experiment").
path_use = os.path.join(data_path, exp_name)

# JSON file holding the raw experiment results.
data_file_path = f"{path_use}.json"

# Output directory for every generated figure (one sub-folder per experiment).
plot_dir = f"plots/{exp_name}/"

##############################
# Read in and process data
##############################
# Metrics passed to marl_eval for normalisation; none for the first section.
METRICS_TO_NORMALIZE: list = []

# JSON is UTF-8 by specification; say so explicitly so a different platform
# default encoding (e.g. cp1252 on Windows) cannot break loading.
with open(data_file_path, "r", encoding="utf-8") as f:
    raw_data = json.load(f)

# Run marl_eval's data diagnostics on the raw results. The returned value is
# kept in diag_results (it is not read anywhere else in this script).
data_diagnose = DiagnoseData(raw_data=raw_data)
diag_results = data_diagnose.check_data()

processed_data = data_process_pipeline(
    raw_data=raw_data, metrics_to_normalize=METRICS_TO_NORMALIZE
)

# Create the folder for storing plots. exist_ok=True makes this idempotent and
# avoids the race of a separate os.path.exists check.
os.makedirs(plot_dir, exist_ok=True)

def _save_figure(fig, output_dir: str, filename: str) -> None:
    """Save the figure wrapped by ``fig`` to ``output_dir/filename``, tightly cropped."""
    fig.figure.savefig(os.path.join(output_dir, filename), bbox_inches="tight")


def _plot_environment(
    processed: dict,
    environment_name: str,
    metrics_to_normalize: list,
    output_dir: str,
) -> None:
    """Generate and save the standard set of return plots for one environment.

    Files written to ``output_dir``, in this order:

    * ``{environment_name}_{task}_test_return.png`` for every task
    * ``{environment_name}_performance_profile.png``
    * ``{environment_name}_return_aggregate_scores.png``
    * ``{environment_name}_return_prob_of_improvement.png``
    * ``{environment_name}_return_sample_effeciency_curve.png``

    Args:
        processed: Output of ``data_process_pipeline``; indexed as
            ``processed[environment_name]`` to enumerate the task names.
        environment_name: Environment key to plot (e.g. "lbforaging").
        metrics_to_normalize: Metric names forwarded to every marl_eval call.
        output_dir: Directory the figures are written into.
    """
    metric = "test_return_mean"

    # One matrix for cross-task comparisons (performance profiles, aggregate
    # scores, probability of improvement) and one for sample-efficiency curves.
    comparison_matrix, sample_efficiency_matrix = create_matrices_for_rliable(
        data_dictionary=processed,
        environment_name=environment_name,
        metrics_to_normalize=metrics_to_normalize,
    )

    # Plot each task of the environment on its own.
    for task in list(processed[environment_name]):
        fig = plot_single_task(
            processed_data=processed,
            environment_name=environment_name,
            task_name=task,
            metric_name=metric,
            metrics_to_normalize=metrics_to_normalize,
        )
        _save_figure(fig, output_dir, f"{environment_name}_{task}_test_return.png")

    # Aggregate data over all environment tasks.
    fig = performance_profiles(
        comparison_matrix,
        metric_name=metric,
        metrics_to_normalize=metrics_to_normalize,
    )
    _save_figure(fig, output_dir, f"{environment_name}_performance_profile.png")

    fig, _, _ = aggregate_scores(  # type: ignore
        comparison_matrix,
        metric_name=metric,
        metrics_to_normalize=metrics_to_normalize,
        save_tabular_as_latex=True,
    )
    _save_figure(fig, output_dir, f"{environment_name}_return_aggregate_scores.png")

    fig = probability_of_improvement(
        comparison_matrix,
        metric_name=metric,
        metrics_to_normalize=metrics_to_normalize,
        algorithms_to_compare=[
            ["iql", "qmix"],
            ["iql", "vdn"],
            ["vdn", "qmix"],
            ["maa2c", "mappo"],
        ],
    )
    _save_figure(
        fig, output_dir, f"{environment_name}_return_prob_of_improvement.png"
    )

    fig, _, _ = sample_efficiency_curves(  # type: ignore
        sample_efficiency_matrix,
        metric_name=metric,
        metrics_to_normalize=metrics_to_normalize,
    )
    _save_figure(
        fig, output_dir, f"{environment_name}_return_sample_effeciency_curve.png"
    )


##############################
# Plot for foraging environment
##############################
_plot_environment(processed_data, "lbforaging", METRICS_TO_NORMALIZE, plot_dir)

##############################
# Plots for rware environment
##############################
# rware returns are normalised before plotting.
METRICS_TO_NORMALIZE = ["test_return_mean"]

# Reload the raw JSON instead of reusing raw_data. Presumably this is so that
# normalisation starts from untouched data in case data_process_pipeline
# modified the earlier object in place. NOTE(review): confirm.
with open(data_file_path, "r", encoding="utf-8") as f:
    raw_data = json.load(f)

processed_data = data_process_pipeline(
    raw_data=raw_data, metrics_to_normalize=METRICS_TO_NORMALIZE
)

_plot_environment(processed_data, "rware", METRICS_TO_NORMALIZE, plot_dir)
