from algorithms import *
from instances import *
from util import *

"""
Solve a single instance with a single algorithm
"""
def solve_instance_with_alg(
        instance: OnlineBipartiteInstance,
        alg: FractionalMatchingAlgorithm
    ) -> float:
    """Run ``alg`` on ``instance`` and return the ratio of its two objectives.

    The algorithm is (re)initialised with the instance's size and
    ``L_weights``. Then, for each step ``t`` in ``range(n)``, the instance is
    asked for the next arrival (given ``t`` and the algorithm's current
    ``get_X_so_far()``) and that arrival is handed to ``alg.match``.

    Args:
        instance: Supplies ``get_n()``, ``get_L_weights()`` and
            ``get_next_arrival_with_advice(t, X_so_far)``.
        alg: Algorithm object; it is mutated by this call.

    Returns:
        ``alg.compute_obtained_online_matching()`` divided by
        ``alg.compute_offline_optimal()``.
    """
    num_arrivals = instance.get_n()
    alg.init_for_new_instance(num_arrivals, instance.get_L_weights())

    for step in range(num_arrivals):
        # The instance sees the algorithm's state *before* this step's match.
        arrival = instance.get_next_arrival_with_advice(step, alg.get_X_so_far())
        alg.match(arrival)

    # Offline value is computed first, then the online value, as before.
    offline_value = alg.compute_offline_optimal()
    online_value = alg.compute_obtained_online_matching()
    return online_value / offline_value

"""
Worker function for the multithread path (mapped over the task list by pool.imap in solve_all)
"""
def worker_fn(args):
    """Unpack one ``(alg_idx, instance, alg)`` task and solve it.

    Takes a single tuple argument so it can be mapped directly over a list of
    tasks. Returns ``(alg_idx, ratio)`` so the caller can tell which algorithm
    the ratio belongs to.
    """
    task_idx, task_instance, task_alg = args
    ratio = solve_instance_with_alg(task_instance, task_alg)
    return task_idx, ratio

"""
Solve a bunch of instances with a bunch of algorithms
"""
def solve_all(
        instances: list[OnlineBipartiteInstance],
        algs: list[FractionalMatchingAlgorithm],
        multithread: bool = True
    ) -> dict:
    """Solve every instance with every algorithm, caching results on disk.

    For each instance a pickle file under ``results/<instance.folder>/`` is
    used as a cache. Its name combines ``instance.name``, a hash of
    ``instance.L_weights`` and a hash of the algorithm names. If the file
    exists it is loaded; otherwise the instance is solved with every algorithm
    and the per-instance results are written to that file.

    Args:
        instances: Instances to solve.
        algs: Algorithms to run on each instance.
        multithread: If true, the algorithms for one instance are dispatched
            through a ``Pool`` of workers; otherwise they run sequentially.

    Returns:
        A dict keyed by ``(alg_idx, instance.noise_param)``. Each value is a
        NumPy array holding one ratio per instance that produced that key, in
        the order the instances were given.
    """
    # The algorithm line-up is identical for every instance, so hash it once.
    alg_hash = short_hash_from_string(" ".join([alg.name for alg in algs]))

    result_dict = dict()
    for instance in instances:
        weight_hash = short_hash_from_string(str(instance.L_weights))
        result_pickle_fname = f"results/{instance.folder}/{instance.name}_{weight_hash}_{alg_hash}.pickle"

        if os.path.isfile(result_pickle_fname):
            # NOTE: pickle.load executes arbitrary code from the file; only
            # point this at cache files written by this function.
            with open(result_pickle_fname, "rb") as f:
                inst_result_dict = pickle.load(f)
        else:
            print(f"{result_pickle_fname} not found. Solving from scratch")
            # Create the output folder *before* the expensive solve so a
            # missing directory cannot discard finished results at write time.
            os.makedirs(os.path.dirname(result_pickle_fname), exist_ok=True)

            inst_result_dict = dict()
            if multithread:
                # Multi-thread version
                arg_list = [(alg_idx, instance, alg) for alg_idx, alg in enumerate(algs)]
                with Pool() as pool:
                    results = list(
                        tqdm(pool.imap(worker_fn, arg_list), total=len(arg_list), desc=f"{instance.name}")
                    )
                for alg_idx, ratio in results:
                    inst_result_dict[(alg_idx, instance.noise_param)] = ratio
            else:
                # Single thread version
                for alg_idx in tqdm(range(len(algs)), desc=f"{instance.name}"):
                    ratio = solve_instance_with_alg(instance, algs[alg_idx])
                    inst_result_dict[(alg_idx, instance.noise_param)] = ratio

            # The freshly computed dict is used directly below; there is no
            # need to read the file back after writing it.
            with open(result_pickle_fname, "wb") as f:
                pickle.dump(inst_result_dict, f)

        # Group ratios from all instances under their (alg_idx, noise) key.
        for key, val in inst_result_dict.items():
            result_dict.setdefault(key, []).append(val)

    for key, val in result_dict.items():
        result_dict[key] = np.array(val)
    return result_dict

"""
Run the set of experiments for a fixed graph type and save the resulting plot
"""
def run_experiment(
        experiment_name: str,
        instances: list[OnlineBipartiteInstance],
        algs: list[FractionalMatchingAlgorithm],
        noise_params: list[float],
        multithread=True,
        plot_and_save=True
    ) -> dict:
    """Solve all instances with all algorithms and optionally plot the ratios.

    Args:
        experiment_name: Used as the plot title and, with spaces replaced by
            underscores, as the PNG file name under ``plots/``.
        instances: Instances passed on to ``solve_all``.
        algs: Algorithms passed on to ``solve_all``; their ``name`` attribute
            labels the plotted lines.
        noise_params: X-axis values. Each must appear as the second component
            of a ``(alg_idx, noise_param)`` key in the results.
        multithread: Forwarded to ``solve_all``.
        plot_and_save: If true, draw mean ± one standard deviation per
            algorithm and save the figure.

    Returns:
        The result dict produced by ``solve_all``.

    Raises:
        KeyError: If plotting is requested and some ``(alg_idx, noise_param)``
            combination is missing from the results.
    """
    # Function-scope import keeps this block self-contained.
    import os

    # Collect results
    result_dict = solve_all(instances, algs, multithread)

    # Process results
    if plot_and_save:
        num_algs = len(algs)
        means = [
            np.array([
                np.mean(result_dict[(i, noise_param)])
                for noise_param in noise_params
            ])
            for i in range(num_algs)
        ]
        stds = [
            np.array([
                np.std(result_dict[(i, noise_param)])
                for noise_param in noise_params
            ])
            for i in range(num_algs)
        ]

        # Plot and save
        plt.figure()
        for i in range(num_algs):
            # Line style depends on the algorithm's position in `algs`:
            # indices 0-1 use the default style, 2-5 dashed, the rest dotted.
            if i <= 1:
                style = {}
            elif i <= 5:
                style = {"ls": "--"}
            else:
                style = {"ls": ":"}
            line_handle, = plt.plot(noise_params, means[i], label=algs[i].name, **style)

            # Shade the ±1 S.D. band in a 50/50 blend of the line colour and white.
            base_color = to_rgb(line_handle.get_color())
            fill_color = 0.5 * np.array(base_color) + 0.5 * np.array([1.0, 1.0, 1.0])
            plt.fill_between(noise_params, means[i] - stds[i], means[i] + stds[i], color=fill_color, alpha=0.5)
        plt.xlabel("Noise Parameter")
        plt.ylabel("Obtained competitive ratio (Mean ± 1 S.D.)")
        plt.legend(ncol=2, fontsize='small')
        plt.title(f"{experiment_name}")

        # Make sure the output folder exists so savefig does not fail on a
        # fresh checkout.
        os.makedirs("plots", exist_ok=True)
        # Build the file stem outside the f-string: reusing double quotes
        # inside a double-quoted f-string is a SyntaxError before Python 3.12.
        plot_stem = experiment_name.replace(" ", "_")
        plt.savefig(f"plots/{plot_stem}.png", dpi=300)
        # plt.show()

    return result_dict
