import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import algos
import envs
import plots
from typing import List
from experiments.em_helpers import (
    cycle_lengths_geom,
    ada_eps_descending,
    create_const_lr,
    create_lr,
    indices_for_pairs_values_only,
)

from experiment_executor import execute_experiment
import argparse

# Command-line interface of the experiment script.  Every option registered
# here is later read through ``args.<dest>``.
parser = argparse.ArgumentParser(description="Run ADP experiments for GridWorld.")
parser.add_argument(
    "--project-name",
    dest="project_name",
    type=str,
    # The project name is concatenated with the subproject label when output
    # paths are built.  Without a value it would be None and that
    # concatenation would raise a TypeError, so the option is mandatory and
    # the script fails at argument parsing instead.
    required=True,
    help="Project name (used in output paths).",
)
parser.add_argument(
    "--base-folder", type=str, default="results", help="Base results folder."
)
parser.add_argument(
    "--output-folder", type=str, default="plots", help="Folder for plots."
)
# Parsed as float so that scientific notation (e.g. 2.5e4) is accepted; the
# script converts the value to int after parsing.
parser.add_argument(
    "--K_init", type=float, default=25000, help="Initial cycle length K_0."
)
parser.add_argument("--n", type=int, default=100, help="Number of cycles.")
parser.add_argument(
    "--gamma", type=float, default=0.9, help="Discount factor in (0,1)."
)
# NOTE(review): no default is given, so the value is None when the flag is
# omitted -- confirm that the experiment executor accepts that.
parser.add_argument(
    "--num-steps",
    dest="num_steps_single",
    type=int,
    help="Total training steps per algorithm.",
)
# Both switches write to the same destination.  argparse would otherwise take
# the default of the first registered action (store_true -> False), which
# contradicts the documented "(default)"; the explicit set_defaults() call
# below makes uniform sampling the actual default.
parser.add_argument(
    "--uniform-sampling",
    dest="uniform_sampling",
    action="store_true",
    help="Use uniform state-action sampling (default).",
)
parser.add_argument(
    "--no-uniform-sampling",
    dest="uniform_sampling",
    action="store_false",
    help="Disable uniform state-action sampling.",
)
parser.set_defaults(uniform_sampling=True)
parser.add_argument(
    "--num-runs",
    dest="num_runs_single",
    type=int,
    default=1,
    help="Number of runs per algorithm.",
)
# The three learning-rate options are forwarded, in this order, to the
# learning-rate factory used for every algorithm.
# NOTE(review): --lr-s and --lr-c default to None -- confirm the factory
# tolerates missing values.
parser.add_argument(
    "--lr-s",
    dest="lr_initial",
    type=float,
    help="Learning-rate parameter 'lr_initial' (first argument of the learning-rate factory).",
)
parser.add_argument(
    "--lr-c",
    dest="lr_const",
    type=float,
    help="Learning-rate parameter 'lr_const' (second argument of the learning-rate factory).",
)
parser.add_argument(
    "--lr-warmup",
    dest="lr_warmup",
    type=int,
    default=0,
    help="Learning-rate parameter 'lr_warmup' (third argument of the learning-rate factory).",
)

args = parser.parse_args()

###############################################################################
# Experiment configuration
###############################################################################

# High-level organizational parameters
project_name, base_folder, output_folder = (
    args.project_name,
    args.base_folder,
    args.output_folder,
)

# --K_init is parsed as a float on the command line; cycle lengths are used
# as integers from here on.
K_init, n = int(args.K_init), int(args.n)
gamma = args.gamma

print(K_init)

# One entry per subproject; most configuration lists below are replicated to
# this length.
subproject_labels = ["testing"]

num_runs = [args.num_runs_single for _ in subproject_labels]
num_steps = [args.num_steps_single for _ in subproject_labels]

######### ENVIRONMENT:
environments = [envs.GridWorld] * len(subproject_labels)

# Constructor keyword arguments for the 4x4 GridWorld.  The very same
# dictionary object is referenced once per subproject.
# NOTE(review): the boolean next to each coordinate list presumably marks the
# state type as terminal -- confirm against envs.GridWorld.
_gridworld_params = {
    "grid_size": (4, 4),
    "state_type_loc": {
        "goal": ([(3, 3)], True),
        "start": ([(0, 0)], False),
        "stoch region": ([(2, 0), (3, 0), (2, 1), (3, 1)], False),
        # Every cell of column 2 except the one in row 1.
        "fake goal": ([(0, 2), (2, 2), (3, 2)], True),
    },
    "rewards": {
        # "choice" entries list the candidate rewards "a" with probabilities "p".
        "default": ["choice", {"a": [-0.08, 0.05], "p": [0.5, 0.5]}],
        "stoch region": ["choice", {"a": [-2.1, 2], "p": [0.5, 0.5]}],
        "goal": ["choice", {"a": [0.5, 1.5], "p": [0.5, 0.5]}],
        "fake goal": -3,
    },
}
environments_specific_parameters = [_gridworld_params] * len(subproject_labels)

# Environment randomization is switched off for every subproject.
env_randomization = [False] * len(subproject_labels)
env_randomization_kwargs = [{}] * len(subproject_labels)
env_randomization_schedule = [[-1]] * len(subproject_labels)
# Hard-coded reference Q-values keyed by (state, action).  Further down they
# are mapped to flat indices and handed to the board plots as "real_value"
# reference lines labelled $Q^*$.  States 2, 10, 14 and 15 only list action 0,
# matching the single-action states in `which_state_actions_focus`.
# NOTE(review): the values are presumably the optimal Q-function of the
# GridWorld configured above for the default --gamma (0.9); they are not
# recomputed when --gamma is changed -- TODO confirm.
q_star_dict = {
    (0, 0): 0.39951752152306896,
    (0, 1): 0.46060694613685,
    (0, 2): 0.46060694613685,
    (0, 3): 0.39951752152306896,
    (1, 0): 0.46060694613685,
    (1, 1): -2.715028730000096,
    (1, 2): 0.5284840845966067,
    (1, 3): 0.39951752152306896,
    (2, 0): -3.0,
    (3, 0): 0.6039031273296697,
    (3, 1): 0.6039031273296697,
    (3, 2): 0.6877020636997397,
    (3, 3): -2.715028730000096,
    (4, 0): 0.39951752152306896,
    (4, 1): 0.5284840845966067,
    (4, 2): 0.36962638852354773,
    (4, 3): 0.46060694613685,
    (5, 0): 0.46060694613685,
    (5, 1): 0.6039031273296697,
    (5, 2): 0.4307158131373288,
    (5, 3): 0.46060694613685,
    (6, 0): -2.715028730000096,
    (6, 1): 0.6877020636997397,
    (6, 2): -2.715028730000096,
    (6, 3): 0.5284840845966067,
    (7, 0): 0.6039031273296697,
    (7, 1): 0.6877020636997397,
    (7, 2): 0.7808119929998175,
    (7, 3): 0.6039031273296697,
    (8, 0): 0.42739457613738197,
    (8, 1): 0.3975034431378608,
    (8, 2): 0.2545315166721077,
    (8, 3): 0.3364140185240797,
    (9, 0): 0.4952717145971387,
    (9, 1): -2.7482410999995643,
    (9, 2): 0.3095119988245107,
    (9, 3): 0.3364140185240797,
    (10, 0): -3.0,
    (11, 0): 0.6877020636997397,
    (11, 1): 0.7808119929998175,
    (11, 2): 0.884267469999904,
    (11, 3): -2.715028730000096,
    (12, 0): 0.3364140185240797,
    (12, 1): 0.3095119988245107,
    (12, 2): 0.2545315166721077,
    (12, 3): 0.2545315166721077,
    (13, 0): 0.3975034431378608,
    (13, 1): -2.7482410999995643,
    (13, 2): 0.3095119988245107,
    (13, 3): 0.2545315166721077,
    (14, 0): -3.0,
    (15, 0): 0.999218,
}
# Build one environment instance only to enumerate its states and the actions
# allowed in each of them.
env = envs.GridWorld(**environments_specific_parameters[0])

# Manual Q-function initialization: every allowed (state, action) pair starts
# at 0.
initial_q_fct = dict.fromkeys(
    (
        (state, action)
        for state in range(env.num_states)
        for action in env.allowed_actions[state]
    ),
    0,
)

######### ALGORITHMS:

# Three ADP configurations are compared in every subproject.
algorithms = [[algos.ADP, algos.ADP, algos.ADP]] * len(subproject_labels)
algo_labels = [
    [
        f"ADP, $TUI={K_init}$",
        f"ADP, $ITUI={K_init}$",
        f"ADP, $TUI={K_init}$, min_steps=1",
    ],
]

uniform_sampling_bool = args.uniform_sampling

# Settings shared by all three ADP configurations.
_shared_algo_params = {
    "lr_per_cycle": True,
    "uniform_state_action_sampling": uniform_sampling_bool,
    "q_fct_manual_init": True,
    "initial_q_fct": initial_q_fct,
}

algorithms_specific_parameters = [
    [
        # 1) a single fixed cycle length
        {"cycle_lengths": [K_init], **_shared_algo_params},
        # 2) cycle lengths produced by cycle_lengths_geom(K_init, n, gamma)
        {
            "cycle_lengths": cycle_lengths_geom(K_init, n, gamma),
            **_shared_algo_params,
        },
        # 3) fixed cycle length plus the adaptive-sync options
        {
            "cycle_lengths": [K_init],
            **_shared_algo_params,
            "adaptive_sync": True,
            "adaptive_eps": ada_eps_descending(n),
            "adaptive_min_cycle_steps": 1,
        },
    ]
] * len(subproject_labels)

# Extra per-algorithm logging: one independent kwargs dict per algorithm.
algo_special_logs = True
algo_special_logs_kwargs = [
    [{"updated_q_values": True} for _ in algorithms[0]]
] * len(subproject_labels)

# Every algorithm gets its own learning-rate configuration built from the
# same command-line values.
learning_rate_kwargs = [
    [
        create_lr(args.lr_initial, args.lr_const, args.lr_warmup)
        for _ in algorithms[0]
    ],
]

learning_rate_state_action_wise = [False] * len(learning_rate_kwargs[0])

# Behaviour policy used during training.
policy = algos.BasePolicy
policy_specific_params = {
    "policy_mode": "offpolicy",
    "policy_mode_kwargs": {"type": "uniform_random", "kwargs": {}},
}

# Evaluation settings
eval_steps = 1
eval_freq = 100
eval_policy_choice = "greedy"
eval_policy_choice_kwargs = {}
max_steps_per_epoch = -1

# How the reference Q-function is obtained (one kwargs dict shared by all
# subprojects).
correct_act_q_fct_mode = "value_iteration"
correct_act_q_fct_mode_kwargs = [
    {
        "n_max": 10000,
        "tol": 1e-10,
        "env_mean_rewards": {},
        "env_mean_rewards_mc_runs": 1000000,
    }
] * len(subproject_labels)

# Logging of bias metrics and of selected state-action pairs
bias_estimation = True
focus_state_actions = True
# States 2, 10, 14 (index 4 * i + 2 for i in 0, 2, 3) and 15 only expose
# action 0; every other state exposes actions 0-3.
_single_action_states = (15, 2, 10, 14)
which_state_actions_focus = (
    list(range(16)),
    [
        [0] if state in _single_action_states else [0, 1, 2, 3]
        for state in range(16)
    ],
)
correct_action_log = True
correct_action_log_which = [1, 4, 6, 7]

# Seeding and training mode
eval_reseeding = False
eval_seed_schedule = [-1]
training_mode = "steps"
training_reseeding = False
training_seed_schedule = [-1]


# Plot functions and parameters to be applied
figsize = (4, 4)
loc = "best"
grid = True
show = False
save = True

# No single plots are requested; the kwargs list is indexed in parallel with
# the key list.
plotkeys_for_single_plots = []
plotkeys_for_single_kwargs = [{}]

# Board plots: num_rows, plotkeys_for_boards and plotkeys_for_boards_kwargs
# are indexed in parallel.
num_rows = [4, 4, 2]
plotkeys_for_boards = [
    "Mean Q function values at multiple chosen at evals",
    "Mean average special logs multiple at evals",
    "Mean bias metrics multiple at evals",
]

# (state, action) pairs shown on the boards
pairs = [
    (0, 2), (4, 1), (4, 2), (5, 1), (5, 2), (6, 1),
    (6, 3), (15, 0), (9, 0), (9, 1), (9, 3), (10, 0),
]

idx_list = indices_for_pairs_values_only(which_state_actions_focus, pairs)

# Translate every reference value from its (state, action) key to the index
# returned by indices_for_pairs_values_only, then keep only the plotted ones.
_q_star_by_index = {
    indices_for_pairs_values_only(which_state_actions_focus, [(s, a)])[0]: value
    for (s, a), value in q_star_dict.items()
}
q_star_dict_indices = {
    idx: value for idx, value in _q_star_by_index.items() if idx in idx_list
}
label_dict = dict.fromkeys(q_star_dict_indices, "$Q^*$")

plotkeys_for_boards_kwargs = [
    {"which": pairs},
    {
        "index": idx_list,
        "real_value": q_star_dict_indices,
        "real_value_label": label_dict,
    },
    {
        "squared": [False, False],
        "log_scale": [False, True],
        "normalized": [False, False],
        "best_arms": [False, False],
        "sup_norm": [True, True],
        "conv_int": [True, True],
        "conv_int_n": num_runs[0],
    },
]

# Cosmetic parameters.
safe_mode = True
verbose = True
progress = True
progress_single_games = True
runtime_estimation = True

###########################################################################################################################################################################################################################
###########################################################################################################################################################################################################################
###########################################################################################################################################################################################################################

print("The experiment will be initialized ... ")

# Initialize dictionary for keeping the result paths:
# subproject label -> {unique algorithm label -> path of the saved results}
resultspath_dict = {name: {} for name in subproject_labels}
num_of_experiments = len(algorithms) * len(algorithms[0])

# For each experiment, get the data
print("Executing the experiments ... \n\n")
exp_num = 1
for label_index, label in enumerate(subproject_labels):
    for algo_index, algo in enumerate(algorithms[label_index]):
        print(f"Run {exp_num} of {num_of_experiments}:")

        # Collect the keyword arguments grouped by topic, then run the
        # experiment and remember where its results were written.
        experiment_kwargs = dict(
            # organisation / output
            base_folder=base_folder,
            project_name=project_name,
            safe_mode=safe_mode,
            verbose=verbose,
            progress=progress,
            progress_single_games=progress_single_games,
            runtime_estimation=runtime_estimation,
            # environment
            env=environments[label_index],
            env_specific_params=environments_specific_parameters[label_index],
            env_randomization=env_randomization[label_index],
            env_randomization_kwargs=env_randomization_kwargs[label_index],
            env_randomization_schedule=env_randomization_schedule[label_index],
            # algorithm, learning rate and policy
            algo=algo,
            algo_specific_params=algorithms_specific_parameters[label_index][
                algo_index
            ],
            algo_special_logs=algo_special_logs,
            algo_special_logs_kwargs=algo_special_logs_kwargs[label_index][algo_index],
            gamma=gamma,
            learning_rate_kwargs=learning_rate_kwargs[label_index][algo_index],
            learning_rate_state_action_wise=learning_rate_state_action_wise[algo_index],
            policy=policy,
            policy_specific_params=policy_specific_params,
            # training
            num_runs=num_runs[label_index],
            num_steps=num_steps[label_index],
            max_steps_per_epoch=max_steps_per_epoch,
            training_mode=training_mode,
            training_reseeding=training_reseeding,
            training_seed_schedule=training_seed_schedule,
            # evaluation
            eval_steps=eval_steps,
            eval_freq=eval_freq,
            eval_policy_choice=eval_policy_choice,
            eval_policy_choice_kwargs=eval_policy_choice_kwargs,
            eval_reseeding=eval_reseeding,
            eval_seed_schedule=eval_seed_schedule,
            # additional logging
            bias_estimation=bias_estimation,
            focus_state_actions=focus_state_actions,
            which_state_actions_focus=which_state_actions_focus,
            correct_act_q_fct_mode=correct_act_q_fct_mode,
            correct_act_q_fct_mode_kwargs=correct_act_q_fct_mode_kwargs[label_index],
            correct_action_log=correct_action_log,
            correct_action_log_which=correct_action_log_which,
        )
        saved_path = execute_experiment(**experiment_kwargs)
        print("\n")
        exp_num += 1

        # The same algorithm class may occur several times, so append
        # "_1", "_2", ... until the label is unused in this subproject.
        base_label = algo().__str__()
        unique_label = base_label
        suffix = 0
        while unique_label in resultspath_dict[label]:
            suffix += 1
            unique_label = base_label + "_" + str(suffix)
        resultspath_dict[label][unique_label] = saved_path

print("All experiments have been executed.\n")

# Initialize dictionary for keeping the aggregated result paths
# (subproject label -> path, None until the aggregation has run)
aggregated_resultspath_dict = dict.fromkeys(subproject_labels)

# Aggregate the data algorithm-wise into appropriately named files
print("Aggregating data ... ")
# Conditional plot data exists as soon as any of the optional logs was on.
conditional_plots = bool(bias_estimation or focus_state_actions or correct_action_log)
special_plots = bool(algo_special_logs)

for label_index, label in enumerate(subproject_labels):
    labels = [] if algo_labels is None else algo_labels[label_index]
    aggregated_resultspath_dict[label] = plots.results_single_to_batch_for_plot(
        result_paths=list(resultspath_dict[label].values()),
        labels=labels,
        output_folder=output_folder + "/.results_to_plot",
        project_name=project_name + "_" + label,
        safe_mode=safe_mode,
        conditional_plots=conditional_plots,
        special_plots=special_plots,
    )

print("Plotting single plots ... \n")

# Tags describing which kinds of data were gathered during the experiments.
# Each registered plot function names the tag it depends on.
activated_plots = ["default_act"]
if max_steps_per_epoch != -1:
    activated_plots.append("max_steps_per_epoch_act")
if bias_estimation:
    activated_plots.append("bias_estimation_act")
if focus_state_actions:
    activated_plots.append("focus_state_actions_act")
if correct_action_log:
    activated_plots.append("correct_action_log_act")
if algo_special_logs:
    activated_plots.append("special_act")


def _resolve_plot_function(plotkey, registries, activated, kind):
    """Return the plot function registered under *plotkey*, or None to skip it.

    Args:
        plotkey: Name of the requested plot.
        registries: Mappings searched in order; each maps a plot key to a
            sequence whose item 0 is the plot function and whose item 1 is
            the data tag the plot needs.
        activated: Data tags that were actually gathered.
        kind: Word used in the "unknown key" warning (e.g. "single").

    Returns:
        The plot function, or None (after printing a warning) when the key is
        unknown or the required data was not gathered.  Always returning an
        explicit value keeps an unknown key from raising a NameError or from
        silently reusing the function resolved for the previous key.
    """
    for registry in registries:
        if plotkey not in registry:
            continue
        if registry[plotkey][1] in activated:
            return registry[plotkey][0]
        print(
            f"Warning: You did not gather the data necessary for the plot with key '{plotkey}', it will be skipped!"
        )
        return None
    print(
        f"Warning: The plotkey '{plotkey}' matches none of the {kind} plot keys, it will be skipped!"
    )
    return None


for label in subproject_labels:
    print(f"Plotting for subproject {label} ... ")
    for plotkey_index, plotkey in enumerate(plotkeys_for_single_plots):

        # Find the plot function (registries are looked up here so that the
        # plots module is only queried when a single plot was requested).
        goalfunc = _resolve_plot_function(
            plotkey,
            (
                plots.ALL_SINGLE_PLOT_KEYNAMES_STEPS,
                plots.ALL_SINGLE_PLOT_KEYNAMES_EPOCHS,
                plots.ALL_SINGLE_PLOT_KEYNAMES_EVAL,
                plots.ALL_SINGLE_PLOT_KEYNAMES_OTHER,
            ),
            activated_plots,
            "single",
        )

        # Plot the data
        if goalfunc is not None:
            goalfunc(
                input_path=aggregated_resultspath_dict[label],
                plot_folder=output_folder,
                project_name=project_name + "_" + label,
                figsize=figsize,
                loc=loc,
                grid=grid,
                show=show,
                save=save,
                mode="single plot",
                safe_mode=safe_mode,
                **plotkeys_for_single_kwargs[plotkey_index],
            )
        print("\n")

    for plotkey_index, plotkey in enumerate(plotkeys_for_boards):

        # Find the plot function
        goalfunc = _resolve_plot_function(
            plotkey,
            (plots.ALL_MULTIPLE_PLOT_KEYNAMES, plots.ALL_BOARD_PLOT_KEYNAMES),
            activated_plots,
            "board",
        )

        # Plot the data
        if goalfunc is not None:
            goalfunc(
                input_path=aggregated_resultspath_dict[label],
                plot_folder=output_folder,
                project_name=project_name + "_" + label,
                individual_figsize=figsize,
                num_rows=num_rows[plotkey_index],
                grid=grid,
                show=show,
                save=save,
                safe_mode=safe_mode,
                **plotkeys_for_boards_kwargs[plotkey_index],
            )
        print("\n")

# Print the locations of interesting stuff: the aggregated file, every single
# result file and (if plots were saved) the plot folder of each subproject.
for subproj, aggregated_path in aggregated_resultspath_dict.items():
    print(
        f"The aggregated results for subproject '{subproj}' can be found here:\n{aggregated_path}\n\nThe single results can be found here:"
    )
    for alg_name, single_path in resultspath_dict[subproj].items():
        print(f"{alg_name}: {single_path}")
    if not save:
        continue
    path = os.path.join(output_folder, (project_name + "_" + subproj + "_plots"))
    print(f"\nThe saved plots can be found here:\n{path}\n")
