import argparse
import pickle

import numpy as np
from joblib import Parallel, delayed

from Algorithms.CTS import play_CTS
from Algorithms.GenCTS import play_GenCTS
from Algorithms.GenLBINFV import play_GenLBINFV
from Algorithms.LBINFV import play_LBINFV
from CreateAction.action_list_maker import return_action_list
from utils import generate_loss
import os


def get_args(argv=None):
    """Parse the command-line options of the experiment script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            calling ``get_args()`` behaves exactly as before.

    Returns:
        argparse.Namespace: Namespace with the attributes ``num_exp``,
        ``regime``, ``time_horizon``, ``action_list_making_iter`` and
        ``change_dist_ratio``.

    Raises:
        SystemExit: Raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(description="argparse script")
    parser.add_argument(
        "-num_exp",
        "--num_exp",
        type=int,
        default=100,
        help="The number of times we run the experiment",
    )
    parser.add_argument(
        "-regime",
        "--regime",
        type=str,
        choices=[
            "Stochastic",
            "StochasticWithCorruption",
        ],
        default="Stochastic",
        help="Which regime to consider.",
    )
    parser.add_argument(
        "-time_horizon",
        "--time_horizon",
        type=int,
        default=2000,
        # The previous help text ("The value of sigma.") described a
        # different option; this flag sets the time horizon.
        help="The number of time steps played in each experiment.",
    )
    parser.add_argument(
        "-action_list_making_iter",
        "--action_list_making_iter",
        type=int,
        default=200,
        help="The number of iterations to create the action list.",
    )
    parser.add_argument(
        "-change_dist_ratio",
        "--change_dist_ratio",
        type=float,
        default=0.1,
        help="The timing when the distribution changes.",
    )
    return parser.parse_args(argv)


def pairwise_min_matrix(supplier, demander):
    """Return the float matrix whose (i, j) entry is min(supplier[i], demander[j]).

    Args:
        supplier: 1-D array of supplier values.
        demander: 1-D array of demander values.

    Returns:
        numpy.ndarray: Float array of shape ``(len(supplier), len(demander))``.
    """
    return np.minimum.outer(supplier, demander).astype(float)


def load_results(file_name):
    """Load previously saved experiment results, or start with an empty dict.

    Args:
        file_name: Path of the pickled result file.

    Returns:
        dict: The unpickled results, or ``{}`` when the file does not exist.
    """
    if not os.path.exists(file_name):
        return {}
    # NOTE: pickle.load can execute arbitrary code embedded in the file; only
    # resume from result files that this script produced itself.
    with open(file_name, "rb") as f:
        return pickle.load(f)


def save_results(file_name, results):
    """Pickle ``results`` to ``file_name``, creating the directory if needed.

    The data is first written to a temporary sibling file and then moved into
    place, so an interruption during the dump cannot truncate the results
    accumulated by earlier experiments.

    Args:
        file_name: Destination path of the pickled result file.
        results: Picklable object to store.
    """
    directory = os.path.dirname(file_name)
    if directory:
        # Without this the very first save fails when "results/" is missing.
        os.makedirs(directory, exist_ok=True)
    tmp_name = file_name + ".tmp"
    with open(tmp_name, "wb") as f:
        pickle.dump(results, f)
    os.replace(tmp_name, file_name)


def main():
    """Run ``args.num_exp`` experiments and append their results to disk.

    Each experiment draws a random cost matrix, builds the action list and the
    loss sequence, constructs the six algorithm players, runs them in parallel
    and stores what their ``run()`` methods return under the experiment index.
    The result file is rewritten after every experiment so that a later
    invocation continues numbering from the experiments already stored.
    """
    args = get_args()
    cost_min, cost_max = 0.1, 0.5
    file_name = "results/regret_result_{}_{}_{}_{}.txt".format(
        args.regime,
        cost_min,
        cost_max,
        args.time_horizon,
    )
    result_list_all = load_results(file_name)

    # Continue the numbering after the experiments already present in the file.
    first_exp = len(result_list_all)
    for exp in range(first_exp, first_exp + args.num_exp):
        print("We start the {}-th experiment.".format(exp + 1))

        supplier = np.array([1, 2, 3, 4])
        demander = np.array([1, 2, 3, 4])

        # Generate cost matrix
        cost_matrix = np.random.uniform(
            cost_min,
            cost_max,
            (len(supplier), len(demander)),
        )
        n_i = pairwise_min_matrix(supplier, demander)

        action_list = return_action_list(
            suppliers=supplier,
            demanders=demander,
            iter=args.action_list_making_iter,
        )

        # Generate losses for each time step t
        loss_list = generate_loss(
            cost_matrix=cost_matrix,
            supplier=supplier,
            demander=demander,
            time_horizon=args.time_horizon,
        )

        if args.regime == "StochasticWithCorruption":
            change_dist_ratio = args.change_dist_ratio
        else:
            change_dist_ratio = 1.0  # No change for the loss distribution

        # Keyword arguments accepted by every player constructor.
        common = dict(
            action_list=action_list,
            cost_matrix=cost_matrix,
            regime=args.regime,
            loss_list=loss_list,
            time_horizon=args.time_horizon,
            change_dist_ratio=change_dist_ratio,
        )
        # Extra keyword arguments of the non-"Gen" players.
        sides = dict(supplier=supplier, demander=demander)

        # Players are constructed in the same order as before
        # (GenLBINFV, LBINFV, GenCTS, CTS) in case construction is
        # order-sensitive (e.g. consumes random numbers).
        print("GenLBINFV")
        gen_lbinfv_ls = play_GenLBINFV(
            m_estimation="LeastSquare", n_i=n_i, **common
        )
        gen_lbinfv_gd = play_GenLBINFV(
            m_estimation="GradientDescent", n_i=n_i, **common
        )

        print("LBINFV")
        lbinfv_ls = play_LBINFV(m_estimation="LeastSquare", **sides, **common)
        lbinfv_gd = play_LBINFV(
            m_estimation="GradientDescent", **sides, **common
        )

        print("GenCTS")
        gen_cts = play_GenCTS(n_i=n_i, **common)

        print("CTS")
        cts = play_CTS(**sides, **common)

        # Result label -> player, in the order the runs are dispatched and
        # stored. Pairing labels with players here keeps the two in sync
        # instead of relying on positional indices into the result list.
        players = {
            "LBINFV (LS)": lbinfv_ls,
            "LBINFV (GD)": lbinfv_gd,
            "GenLBINFV (LS)": gen_lbinfv_ls,
            "GenLBINFV (GD)": gen_lbinfv_gd,
            "CTS": cts,
            "GenCTS": gen_cts,
        }
        results_paralleled = Parallel(n_jobs=-1)(
            [delayed(player.run)() for player in players.values()]
        )
        result_list_all[exp] = dict(zip(players, results_paralleled))

        # Save result
        save_results(file_name, result_list_all)


if __name__ == "__main__":
    main()
