import numpy as np
from make_algorithm.make_algo_config import config as make_config, action_specs
from test_algorithm.test_algo_config import config as test_config
from make_algorithm.baselines import ADAPTIVE_BASELINE_CONFIGS, DEFAULT_BASELINE_CONFIGS
import argparse
from make_algorithm.make_algorithm_main import adapter as make_adapter
from test_algorithm.test_algorithm_main import adapter as test_adapter
from test_algorithm.test_algorithm_main import matrix_loss
from make_algorithm.make_algo_grid_config import name_to_type
from sorcerun.sacred_utils import run_sacred_experiment
from tqdm import tqdm
import matplotlib.pyplot as plt
import pickle
from plot import aggregate
# Key (and default legend label) of the eigendecomposition-based reference
# baseline that the main script times and plots as a single point.
EIGH = "torch.linalg.eigh"

# Global matplotlib style for the large, camera-ready figures produced below.
plt.rcParams.update(
    {
        "font.size": 30,
        "figure.titlesize": 46,
        "figure.figsize": (15, 12),
        "lines.linewidth": 3,
        "xtick.labelsize": 30,
        "ytick.labelsize": 30,
    }
)

# Matrix-function / algorithm identifiers. NOTE: SIGN, SQRT, INV, PROOT and
# MCTS are imported again from ``globals`` further down, which rebinds these
# names; the literals here are what the two display tables are keyed by.
SIGN, SQRT, INV, PROOT = "sign", "sqrt", "inv", "proot"
MCTS = "mcts"

# Figure-title text for each matrix function.
CAMERA_READY_FUNC = dict(
    [
        (SIGN, "Matrix Sign"),
        (SQRT, "Matrix Square Root"),
        (INV, "Matrix Inverse"),
        (PROOT, "Matrix pth-root"),
    ]
)

# Figure-title text for each matrix distribution name.
CAMERA_READY_DISTS = dict(
    wishart="Wishart",
    wishart_unif="Wishart Uniform",
    CIFAR="CIFAR-10",
    quartic_saddle="Hessian of Quartic",
    unif="Uniform",
    Erdos_Renyi="Erdos-Renyi",
)

from sorcerun.git_utils import (
    is_dirty,
    get_repo,
    get_commit_hash,
    get_time_str,
    get_tree_hash,
)

from sorcerun.incense_utils import (
    exps_to_xarray,
    load_filesystem_expts_by_config_keys,
    get_latest_single_and_grid_exps,
)
import sys

repo = get_repo()
ROOT = repo.working_dir
# Put the repository root on the import path (presumably so repo-root modules
# such as ``globals``, imported below, resolve regardless of the cwd).
if ROOT not in sys.path:
    sys.path.append(ROOT)

# Tree hashes of the two experiment packages; they tag new runs and are used
# to look existing runs up on disk.
make_algorithm_hash, test_algorithm_hash = (
    get_tree_hash(repo, package) for package in ("make_algorithm", "test_algorithm")
)

from globals import SIGN, SQRT, INV, PROOT
import os
import json
from collections import defaultdict

from make_algorithm.baselines import DEFAULT_BASELINE_CONFIGS, ADAPTIVE_BASELINE_CONFIGS
from globals import MATRIX_FUNCTIONS, MCTS

all_baseline_names = sorted([*DEFAULT_BASELINE_CONFIGS, *ADAPTIVE_BASELINE_CONFIGS])

# Group the baseline names by the matrix function whose name appears in them.
# This is a plain substring test, so one name could land in several groups.
baseline_names_by_matrix_function = {}
for func in MATRIX_FUNCTIONS:
    baseline_names_by_matrix_function[func] = [
        name for name in all_baseline_names if func in name
    ]

# Matplotlib's default color cycle: index 2 is reserved for MCTS and removed
# from the pool that the baselines draw from.
default_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
ALGORITHM_COLOR = {MCTS: default_colors[2]}
default_colors = [*default_colors[:2], *default_colors[3:]]

# Within each matrix function the baselines (alphabetical) take pool colors in
# order; the index restarts at 0 for every function.
for func, names in baseline_names_by_matrix_function.items():
    for i, name in enumerate(sorted(names)):
        if i >= len(default_colors):
            raise ValueError(
                f"Not enough colors in default color cycle for {func} algorithms"
            )
        ALGORITHM_COLOR[name] = default_colors[i]

# NOTE(review): the reference baseline takes the LAST pool color, which a
# function with len(default_colors) baselines would also use -- confirm.
ALGORITHM_COLOR[EIGH] = default_colors[-1]

# One palette (list of algorithm names) per matrix function, plus one for MCTS.
pallets = {**baseline_names_by_matrix_function, MCTS: [MCTS]}
pallet_keys = sorted(pallets.keys())

def make_camera_ready_name(name):
    """
    Turn an internal algorithm name into a legend label.

    Every entry of MATRIX_FUNCTIONS is deleted from ``name`` wherever it
    occurs, underscores become spaces, the result is title-cased and stripped,
    and then "Db" -> "DB", "Nsv" -> "NSV", "Ns" -> "NS" are applied in that
    order. E.g. (with "sqrt" in MATRIX_FUNCTIONS) "sqrt_scaled_db" ->
    "Scaled DB".
    """
    for func in MATRIX_FUNCTIONS:
        # str.replace is a no-op when ``func`` does not occur.
        name = name.replace(func, "")

    label = name.replace("_", " ").title().strip()
    for raw, pretty in (("Db", "DB"), ("Nsv", "NSV"), ("Ns", "NS")):
        label = label.replace(raw, pretty)
    return label


# Legend label for every algorithm that has a plot color; MCTS and the
# reference baseline then get hand-written labels.
CAMERA_READY_NAMES = dict(
    (alg_name, make_camera_ready_name(alg_name)) for alg_name in ALGORITHM_COLOR
)
CAMERA_READY_NAMES[MCTS] = "MatRL (Ours)"
CAMERA_READY_NAMES[EIGH] = "torch.linalg.eigh"

def str2bool(v):
    """
    Parse a yes/no command-line value into a bool.

    Bools pass through unchanged. Strings are matched case-insensitively:
    yes/true/t/1 -> True, no/false/f/0 -> False.

    Raises
    ------
    argparse.ArgumentTypeError
        For any other string (so it can be used as an argparse ``type=``).
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in {"yes", "true", "t", "1"}:
        return True
    if lowered in {"no", "false", "f", "0"}:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

# Helpers that fill run-specific fields into make/test experiment configs.

def prepare_config(config, matrix_function, distribution_name, c, d, epsilon, precision, custom_loss, device):
    """
    Fill the run-specific fields of a make_algorithm config.

    ``config`` is modified IN PLACE and also returned. Keys written (existing
    keys of the same name are overwritten, all other keys are kept):
    matrix_function, init_mat_name (= distribution_name),
    init_mat_config = {"c": c, "d": d, "eps": epsilon}, precision,
    custom_loss (meant to be a bool: whether to use the custom loss), device.
    """
    config.update(
        matrix_function=matrix_function,
        init_mat_name=distribution_name,
        init_mat_config=dict(c=c, d=d, eps=epsilon),
        precision=precision,
        custom_loss=custom_loss,
        device=device,
    )
    return config

def prepare_test_config(config, matrix_function, distribution_name, c, d, epsilon, precision, custom_loss, device):
    """
    Fill the run-specific fields of a test_algorithm config.

    Mirrors ``prepare_config`` but writes the test-matrix keys. ``config`` is
    modified IN PLACE and also returned. Keys written: matrix_function,
    test_mat_name (= distribution_name),
    test_mat_config = {"c": c, "d": d, "eps": epsilon}, precision,
    custom_loss (meant to be a bool), device.
    """
    config["matrix_function"] = matrix_function
    config["test_mat_name"] = distribution_name
    config["test_mat_config"] = {"c": c, "d": d, "eps": epsilon}
    config["precision"] = precision
    config["custom_loss"] = custom_loss
    config["device"] = device
    return config

def run_grid(adapter, configs):
    """
    Run ``adapter`` once per config via ``run_sacred_experiment``, sequentially.

    Parameters
    ----------
    adapter : the experiment adapter passed straight to run_sacred_experiment.
    configs : sized sequence of config dicts; one run per entry, in order.

    Progress is shown with a tqdm bar. The per-run banners go through
    ``tqdm.write`` so they print above the bar instead of breaking it.
    """
    total_num_params = len(configs)
    progress = tqdm(enumerate(configs), total=total_num_params)
    for i, conf in progress:
        tqdm.write(
            "-" * 5
            + "GRID RUN INFO: "
            + f"Starting run {i+1}/{total_num_params}"
            + "-" * 5
        )

        run_sacred_experiment(adapter, conf)

        tqdm.write(
            "-" * 5
            + "GRID RUN INFO: "
            + f"Completed run {i+1}/{total_num_params}"
            + "-" * 5
        )

def pick_best_by_auc(curves) -> int:
    """
    Given a list of (t_grid, loss) pairs, return the index of the curve
    with the smallest area under the loss-vs-time curve.

    Parameters
    ----------
    curves : sequence of (t_grid, loss) pairs
        - t_grid: 1D array-like of time points
        - loss:   1D array-like of loss values, same length as t_grid
          (may contain NaNs)

    Returns
    -------
    best_idx : int
        Index of the curve with the minimal *signed* AUC (trapezoidal rule).
        Points where the time or the loss is NaN are dropped first. Curves
        with fewer than two remaining points, and curves whose AUC comes out
        NaN, are treated as +inf AUC. Ties resolve to the lowest index.

    Raises
    ------
    ValueError
        If ``curves`` is empty.
    """
    # NumPy 2.0 renamed ``trapz`` to ``trapezoid`` and deprecated the old name.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    aucs = []
    for t, loss in curves:
        # asarray so plain lists can be boolean-masked below.
        t = np.asarray(t, dtype=float)
        loss = np.asarray(loss, dtype=float)
        mask = ~(np.isnan(t) | np.isnan(loss))
        if mask.sum() < 2:
            # Not enough data to form a curve -> never pick it.
            aucs.append(np.inf)
            continue
        with np.errstate(invalid="ignore"):
            auc = trapezoid(loss[mask], t[mask])
        # A NaN area (e.g. +inf and -inf losses on one curve) must not win:
        # np.argmin returns the index of the first NaN it sees.
        aucs.append(np.inf if np.isnan(auc) else float(auc))

    if not aucs:
        raise ValueError("pick_best_by_auc needs at least one curve")

    # Signed comparison: ranking by abs() would treat a strongly negative area
    # (e.g. of log-losses below zero) as worst instead of best.
    return int(np.argmin(aucs))

def plot_raw_curves(results, t_limit, y_low, y_high, width, algname):
    """
    Plot every run of one algorithm plus their pointwise geometric mean.

    Each (tvals, lvals) run is drawn semi-transparently on a log y axis. The
    runs are then log-transformed, linearly interpolated onto 1000 evenly
    spaced times in [0, t_limit] and averaged; ``exp`` of that mean is drawn
    opaque and labelled ``algname``.

    Parameters
    ----------
    results : iterable of (tvals, lvals) pairs. np.interp expects each tvals
        to be increasing; lvals should be positive (it is log-transformed).
    t_limit : right end of the time grid and of the x axis.
    y_low, y_high : y-axis limits in loss units (the axis is log-scaled).
    width : line width.
    algname : legend label and key into ALGORITHM_COLOR.
    """
    results = list(results)  # iterated twice below; materialise generators
    if not results:
        return  # nothing to plot (np.vstack would raise on an empty list)

    color = ALGORITHM_COLOR[algname]

    for tvals, lvals in results:
        plt.semilogy(tvals, lvals, color=color, linewidth=width, alpha=0.5)

    common_time = np.linspace(0, t_limit, 1000)
    log_losses = np.vstack(
        [np.interp(common_time, tvals, np.log(lvals)) for tvals, lvals in results]
    )
    mean_log_loss = log_losses.mean(axis=0)
    plt.semilogy(common_time, np.exp(mean_log_loss), color=color, linewidth=width, label=algname)

    # The limits do not depend on the loop; set them once after drawing.
    plt.xlim(0, t_limit)
    plt.ylim(y_low, y_high)
    

def align_and_plot_times(results, t_limit, y_low, y_high, algname, width):
    """
    Plot the mean log10 loss of several runs of one algorithm with a 95% band.

    Each run's loss is shifted by ``y_low`` (so zero losses stay finite),
    log10-transformed and linearly interpolated onto 100 evenly spaced times
    in [0, t_limit]. The mean and a 1.96 * std / sqrt(N) band are drawn in
    ``ALGORITHM_COLOR[algname]``.

    Parameters
    ----------
    results : iterable of (tvals, lvals) pairs; np.interp expects each tvals
        to be increasing.
    t_limit : right end of the time grid and of the x axis.
    y_low, y_high : positive loss bounds; the y axis spans their log10.
    algname : legend label and key into ALGORITHM_COLOR.
    width : line width.

    NOTE: ``algname`` and ``width`` are in the opposite order to
    ``plot_raw_curves``; pass them by keyword when using both as callbacks.
    """
    results = list(results)
    if not results:
        return  # nothing to plot (np.vstack would raise on an empty list)

    color = ALGORITHM_COLOR[algname]
    common_time = np.linspace(0, t_limit, 100)
    # log10 (not the natural log) so the data agree with the log10 y-limits.
    log_losses = np.vstack(
        [
            np.interp(common_time, tvals, np.log10(np.asarray(lvals, dtype=float) + y_low))
            for tvals, lvals in results
        ]
    )

    mean_loss = log_losses.mean(axis=0)
    std_loss = log_losses.std(axis=0)
    # 95% confidence interval of the mean across runs.
    ci95 = 1.96 * std_loss / np.sqrt(log_losses.shape[0])

    plt.plot(common_time, mean_loss, color=color, linewidth=width, label=algname)
    plt.xlim(0, t_limit)
    plt.ylim(np.log10(y_low), np.log10(y_high))
    plt.fill_between(common_time, mean_loss - ci95, mean_loss + ci95, alpha=0.3, color=color)

def plot(result_step_dict, result_time_dict, args, name, method):
    """
    Draw all algorithms' loss-vs-time results on one figure and save it as
    ``<args.result_dir>/<name>.png``.

    Parameters
    ----------
    result_step_dict : mapping whose keys choose which algorithms to plot
        (its values are not read here).
    result_time_dict : mapping algorithm name -> iterable of (times, losses).
    args : namespace providing matrix_function, plot_t_limit, plot_y_low,
        plot_y_high and result_dir.
    name : output file stem.
    method : per-algorithm plotting callback such as ``plot_raw_curves`` or
        ``align_and_plot_times``. ``width`` and ``algname`` are passed by
        keyword because those two callbacks declare them in opposite orders.
    """
    plt.figure(figsize=(10, 6))
    plt.xlabel("Time (seconds)")
    plt.ylabel("Loss (log value)")
    plt.title(f"Computing matrix {args.matrix_function}, Loss vs Time")

    width = 2
    for alg in result_step_dict:
        method(
            result_time_dict[alg],
            args.plot_t_limit,
            args.plot_y_low,
            args.plot_y_high,
            width=width,
            algname=alg,
        )

    plt.legend(loc="upper left", bbox_to_anchor=(1.05, 1))
    plt.tight_layout()
    os.makedirs(args.result_dir, exist_ok=True)
    plt.savefig(os.path.join(args.result_dir, f"{name}.png"))

if __name__ == "__main__":

    # Parse command-line arguments. Yes/no flags are converted with str2bool at
    # parse time so they arrive as real booleans (a raw string such as "False"
    # would otherwise be truthy wherever it is used directly).
    parser = argparse.ArgumentParser(description="Process configuration parameters.")
    parser.add_argument("--matrix_function", type=str, required=True, help="Matrix function to use.")
    parser.add_argument("--distribution_name", type=str, required=True, help="Name of the distribution.")
    parser.add_argument("--c", type=float, required=True, help="Parameter c.")
    parser.add_argument("--d", type=int, required=True, help="Parameter d.")
    parser.add_argument("--epsilon", type=float, required=True, help="Epsilon value.")
    parser.add_argument("--precision", type=str, required=True, help="Precision, float or double")
    parser.add_argument("--custom_loss", type=str2bool, required=True, help="Use custom loss or not.")
    parser.add_argument("--device", type=str, required=True, help="Device to use (e.g., 'cpu', 'cuda').")
    parser.add_argument("--make_algorithm", type=str2bool, required=True, help="Will you generate algorithm? if not, read algo_lists for exp id")
    # NOTE: --algorithm_id and --use_recent are parsed but not read anywhere in
    # this script. (type=list would have split "123" into ["1", "2", "3"].)
    parser.add_argument("--algorithm_id", type=int, nargs="+", required=False, help="if make_algorithm = False, give list of id")
    parser.add_argument("--use_recent", type=str2bool, required=False, default=False, help="Use the most recent algorithm configuration.")
    parser.add_argument("--mcts_c_ucb", type=float, required=False, default=5, help="Exploration constant for MCTS UCB.")
    parser.add_argument("--mcts_alpha_pw", type=float, required=False, default=0.4, help="Alpha parameter for progressive widening.")
    parser.add_argument("--mcts_explore_k", type=int, required=False, default=5, help="Number of exploration steps for MCTS.")
    parser.add_argument("--mcts_budget", type=int, required=False, default=int(3e5), help="Num of steps for the algorithm.")
    parser.add_argument("--mcts_repeats", type=int, required=False, default=10, help="Number of repetitions for MCTS.")
    parser.add_argument("--test_algorithm", type=str2bool, required=True, help="Will you test the algorithm? if not, use most recent")
    parser.add_argument("--test_repeats", type=int, required=False, default=100, help="Number of test repeats.")
    parser.add_argument("--test_distribution_name", type=str, required=False, help="Name of the test distribution.")
    parser.add_argument("--test_c", type=float, required=False, help="Parameter c for the test distribution.")
    parser.add_argument("--test_d", type=int, required=False, help="Parameter d for the test distribution.")
    parser.add_argument("--test_epsilon", type=float, required=False, help="Epsilon value for the test distribution.")
    parser.add_argument("--plot_t_limit", type=float, required=False, default=100, help="Time limit for the algorithm.")
    parser.add_argument("--plot_y_low", type=float, required=False, default=1e-14, help="Lower bound for y-axis in plots.")
    parser.add_argument("--plot_y_high", type=float, required=False, default=1e5, help="Upper bound for y-axis in plots.")
    parser.add_argument("--result_dir", type=str, required=False, default="results", help="Directory to save results.")
    args = parser.parse_args()

    # Test-distribution parameters fall back to the make-algorithm ones.
    for test_attr, make_attr in (
        ("test_c", "c"),
        ("test_d", "d"),
        ("test_epsilon", "epsilon"),
        ("test_distribution_name", "distribution_name"),
    ):
        if getattr(args, test_attr) is None:
            setattr(args, test_attr, getattr(args, make_attr))

    # For the inverse, the reference timed below is torch.linalg.inv, not eigh.
    if args.matrix_function == INV:
        CAMERA_READY_NAMES[EIGH] = "torch.linalg.inv"

    # Prepare the base (MCTS) make_algorithm config. prepare_config mutates
    # and returns the imported ``make_config`` dict.
    config = prepare_config(
        config=make_config,
        matrix_function=args.matrix_function,
        distribution_name=args.distribution_name,
        c=args.c,
        d=args.d,
        epsilon=args.epsilon,
        precision=args.precision,
        custom_loss=str2bool(args.custom_loss),
        device=args.device,
    )
    config.update({"make_algorithm_hash": make_algorithm_hash})

    ######### Generate baselines that compute our matrix function #########

    algs_id_list = []
    if str2bool(args.make_algorithm):
        # One make_algorithm config per baseline: default ones first, then the
        # adaptive ones. custom_loss is converted with str2bool here as well;
        # a raw CLI string like "False" would be truthy.
        baseline_items = [
            *DEFAULT_BASELINE_CONFIGS.items(),
            *ADAPTIVE_BASELINE_CONFIGS.items(),
        ]
        configs = [
            prepare_config(
                config=dict(
                    experiment_name="make_algorithm",
                    algorithm_name=name,
                    algorithm_config={
                        **conf,
                        "size": config["init_mat_config"]["d"],
                        "device": config["device"],
                    },
                ),
                matrix_function=name_to_type(name),
                distribution_name=args.distribution_name,
                c=args.c,
                d=args.d,
                epsilon=args.epsilon,
                precision=args.precision,
                custom_loss=str2bool(args.custom_loss),
                device=args.device,
            )
            for name, conf in baseline_items
        ]

        ## MCTS config
        config["algorithm_config"].update(
            {
                "c_ucb": args.mcts_c_ucb,
                "alpha_pw": args.mcts_alpha_pw,
                "EXPLORE_K": args.mcts_explore_k,
                "budget": args.mcts_budget,
                "actions": action_specs[args.matrix_function],
                "device": args.device,
            }
        )
        # The same MCTS config is repeated; the runs differ only by seed.
        configs.extend([config] * args.mcts_repeats)

        # Seeds are assigned over the full list, before the filter below.
        seeds = 42 + np.arange(len(configs))
        commit_hash = get_commit_hash(repo)
        time_str = get_time_str()
        dirty = is_dirty(repo)
        grid_id = f"{time_str}--{commit_hash}--dirty={dirty}"
        # Keep only the runs for the requested matrix function.
        configs = [
            {
                **cfg,
                "seed": int(seed),
                "make_algorithm_hash": make_algorithm_hash,
                "commit_hash": commit_hash,
                "time_str": time_str,
                "dirty": dirty,
                "grid_id": grid_id,
            }
            for cfg, seed in zip(configs, seeds)
            if cfg["matrix_function"] == args.matrix_function
        ]

        ## Generate baselines and MCTS algorithms
        import time

        start = time.time()
        run_grid(adapter=make_adapter, configs=configs)
        with open("times.txt", "a") as f:
            f.write(f"Finished making algorithm in {time.time() - start} seconds\n")

    # Load the latest make_algorithm grid: the one just created above, or a
    # previous one when --make_algorithm is false.
    make_exps = load_filesystem_expts_by_config_keys(
        f"{repo.working_dir}/file_storage/runs",
        make_algorithm_hash=make_algorithm_hash,
        experiment_name="make_algorithm",
    )
    _, make_exp_grid = get_latest_single_and_grid_exps(make_exps)
    algs_id_list.extend(me.id for me in make_exp_grid)

    if str2bool(args.test_algorithm):
        print(algs_id_list)
        print(len(algs_id_list))

        ### Test every algorithm of the latest make grid, args.test_repeats times.
        # Test matrices use the --test_* distribution parameters, which fall back
        # to the make-algorithm values when not given on the command line.
        test_config = prepare_test_config(
            config=test_config,
            matrix_function=args.matrix_function,
            distribution_name=args.test_distribution_name,
            c=args.test_c,
            d=args.test_d,
            epsilon=args.test_epsilon,
            precision=args.precision,
            custom_loss=str2bool(args.custom_loss),
            device=args.device,
        )
        print(test_config)

        test_configs = [
            {**test_config, "make_algorithm_run_id": e, "repeat": r}
            for e in algs_id_list
            for r in range(args.test_repeats)
        ]

        seeds = 42 + np.arange(len(test_configs))
        commit_hash = get_commit_hash(repo)
        time_str = get_time_str()
        dirty = is_dirty(repo)
        grid_id = f"{time_str}--{commit_hash}--dirty={dirty}"
        test_configs = [
            {
                **cfg,
                "seed": int(seed),
                "test_algorithm_hash": test_algorithm_hash,
                "commit_hash": commit_hash,
                "time_str": time_str,
                "dirty": dirty,
                "grid_id": grid_id,
            }
            for cfg, seed in zip(test_configs, seeds)
        ]

        run_grid(adapter=test_adapter, configs=test_configs)

    # Map each algorithm name to its run ids in the latest make_algorithm grid.
    # (This used to be overwritten by a hard-coded id mapping; the plot now
    # always reflects the grid loaded above.)
    algo_to_ids = defaultdict(list)
    for me in make_exp_grid:
        algo_to_ids[me.config.algorithm_name].append(me.id)
    algo_to_ids = dict(algo_to_ids)  # plain dict: unknown names raise KeyError

    ## Load the latest test_algorithm grid for plotting.
    test_exps = load_filesystem_expts_by_config_keys(
        f"{repo.working_dir}/file_storage/runs",
        test_algorithm_hash=test_algorithm_hash,
        experiment_name="test_algorithm",
    )
    _, test_exp_grid = get_latest_single_and_grid_exps(test_exps)
    _, test_grid = exps_to_xarray(test_exp_grid)

    # Plot order: MatRL (MCTS) first, then the others alphabetically.
    algo_names = sorted(algo_to_ids)
    if MCTS in algo_names:
        algo_names.remove(MCTS)
        algo_names.insert(0, MCTS)

    ######### Reference baseline: time and accuracy of eigh (or inv) #########
    from matrix_distributions.matrix_distributions import MATRIX_DISTRIBUTIONS
    import time
    import torch

    def _eigh_apply(M, f):
        """Return U @ diag(f(D)) @ U.T where (D, U) = torch.linalg.eigh(M)."""
        D, U = torch.linalg.eigh(M)
        return U @ torch.diag(f(D)) @ U.T

    reference_funcs = {
        INV: torch.linalg.inv,
        SIGN: lambda M: _eigh_apply(M, torch.sign),
        SQRT: lambda M: _eigh_apply(M, torch.sqrt),
        # NOTE: p is hard-coded to 3 for the pth root.
        PROOT: lambda M: _eigh_apply(M, lambda D: D ** (1 / 3)),
    }
    if args.matrix_function not in reference_funcs:
        raise ValueError(
            f"No reference implementation for matrix function {args.matrix_function!r}"
        )
    reference_func = reference_funcs[args.matrix_function]

    use_cuda = args.device == "cuda"
    test_mat_name = test_exp_grid[0].config.test_mat_name
    test_mat_config = test_exp_grid[0].config.test_mat_config

    if use_cuda:
        # Warm up the GPU once so the first timed call does not pay for
        # CUDA initialisation. CUDA events time GPU work on the stream.
        warmup = torch.randn(1000, 1000, device="cuda")
        for _ in range(10):
            torch.matmul(warmup, warmup)
        torch.cuda.synchronize()
        start_gpu = torch.cuda.Event(enable_timing=True)
        end_gpu = torch.cuda.Event(enable_timing=True)

    times = []
    losses = []
    repeats = 50
    for _ in tqdm(range(repeats)):
        A = MATRIX_DISTRIBUTIONS[test_mat_name](**test_mat_config)
        A = A.cuda() if use_cuda else A.cpu()
        if args.precision == "double":
            A = A.double()

        if use_cuda:
            start_gpu.record()
            X = reference_func(A)
            end_gpu.record()
            torch.cuda.synchronize()
            elapsed = start_gpu.elapsed_time(end_gpu) / 1000  # ms -> s
        else:
            t0 = time.perf_counter()
            X = reference_func(A)
            elapsed = time.perf_counter() - t0

        loss = (matrix_loss(X, A, args.matrix_function) / torch.linalg.norm(A)).item()
        times.append(elapsed)
        losses.append(loss)

    times = np.array(times)
    losses = np.array(losses)

    mean_time = np.mean(times)
    std_time = np.std(times)
    print(f"Mean time: {mean_time:.4f} ± {std_time:.4f} s")

    mean_loss = np.mean(losses)
    std_loss = np.std(losses)
    print(f"Mean loss: {mean_loss:.4e} ± {std_loss:.4e}")
    mcts_labeled = False

    # Choose the MCTS run whose mean log10-loss-vs-time curve has the smallest area.
    if MCTS not in algo_to_ids:
        raise ValueError(
            f"No {MCTS!r} runs in the latest make_algorithm grid; cannot pick a best MCTS run."
        )
    mcts_ids = algo_to_ids[MCTS]
    mcts_losses = []
    for mid in mcts_ids:
        tests = test_grid.sel(make_algorithm_run_id=mid)
        x = tests.sel(metric="time").values.copy().T
        y = tests.sel(metric="loss").values.copy().T
        t_grid, mean, _, _ = aggregate(
            x,
            np.log10(y),
            num_pts=256,
            error="std",
            grid="linear",
        )
        # +20 shifts the log10 losses (assumed > -20) above zero, so the
        # compared areas are positive.
        mcts_losses.append((t_grid, 20 + mean))

    best_mcts_idx = pick_best_by_auc(mcts_losses)
    best_mcts_mid = mcts_ids[best_mcts_idx]
    print(best_mcts_mid)

    # Copy the chosen run's cout.txt into the results directory.
    best_mcts_id_path = os.path.join(repo.working_dir, f"file_storage/runs/{best_mcts_mid}/cout.txt")
    if os.path.exists(best_mcts_id_path):
        with open(best_mcts_id_path, "r") as file:
            best_mcts_id_content = file.read()

        os.makedirs(args.result_dir, exist_ok=True)
        results_path = os.path.join(args.result_dir, "best_mcts_id.txt")
        with open(results_path, "w") as result_file:
            result_file.write(best_mcts_id_content)
    else:
        print(f"File {best_mcts_id_path} does not exist.")

    plt.figure()

    # Axis limits: the x range and the top of y come from the best MCTS run;
    # the bottom of y is the smallest log10 loss of any plotted run.
    tlim = -np.inf
    ylowlim = np.inf
    yhighlim = -np.inf
    for name in algo_names:
        # Baselines use their first run; MCTS uses the AUC-selected run.
        mid = best_mcts_mid if name == MCTS else algo_to_ids[name][0]
        tests = test_grid.sel(make_algorithm_run_id=mid)
        x = tests.sel(metric="time").values.T
        y = tests.sel(metric="loss").values.T
        log_y = np.log10(y)
        ylowlim = min(ylowlim, np.nanmin(log_y))

        if mid == best_mcts_mid:
            tlim = np.nanmax(x)
            yhighlim = np.nanmax(log_y)

        t_grid, mean, lo, hi = aggregate(
            x,
            log_y,
            num_pts=256,
            error="std",
            grid="linear",
        )

        # MCTS gets a legend entry only once; every other algorithm always.
        label = CAMERA_READY_NAMES[name]
        if name == MCTS:
            if mcts_labeled:
                label = None
            mcts_labeled = True

        color = ALGORITHM_COLOR[name]
        plt.plot(
            t_grid,
            mean,
            linewidth=8 if name == MCTS else 4,
            label=label,
            color=color,
        )
        plt.fill_between(t_grid, lo, hi, alpha=0.2, color=color)

    plt.xlabel("Time (s)")
    plt.ylabel("Log Relative Error")

    # Include the reference point in the y range and leave room for it on x.
    ylowlim = min(ylowlim, np.log10(mean_loss))
    plt.xlim(0, max(tlim * 2, mean_time + tlim))
    plt.ylim(ylowlim - 0.5, yhighlim)

    init_mat_name = test_exp_grid[0].config.test_mat_name
    # Read the size off one sample of the test distribution. No .cuda(): only
    # the shape is needed, so this also works on CPU-only machines.
    init_mat_size = MATRIX_DISTRIBUTIONS[init_mat_name](
        **test_exp_grid[0].config.test_mat_config
    ).shape[0]

    # The title names the distribution that was actually tested, as the file name does.
    func_title = CAMERA_READY_FUNC.get(args.matrix_function, args.matrix_function)
    dist_title = CAMERA_READY_DISTS.get(init_mat_name, init_mat_name)
    plt.title(f"{func_title}, {dist_title} (d={init_mat_size})")

    # Reference baseline as a single point with error bars.
    eigh_color = ALGORITHM_COLOR[EIGH]
    log_mean_loss = np.log10(mean_loss)
    plt.plot(
        mean_time,
        log_mean_loss,
        "o",
        color=eigh_color,
        label=CAMERA_READY_NAMES[EIGH],
        markersize=20,
    )
    # NOTE(review): yerr is the std of the *linear* loss while the y axis is
    # log10, so the vertical bar is not in axis units -- confirm intent.
    plt.errorbar(
        mean_time,
        log_mean_loss,
        xerr=std_time,
        yerr=std_loss,
        fmt="o",
        color=eigh_color,
        markersize=20,
    )
    # Legend outside the plot area.
    leg = plt.legend(
        loc="upper left",
        bbox_to_anchor=(1.02, 1),
        fontsize=30,
        title="Algorithm",
        title_fontsize=30,
        frameon=False,
    )
    # Bold the first legend entry (MatRL, which is plotted first).
    leg.texts[0].set_fontweight("bold")

    plt.grid()
    plot_name = f"{args.matrix_function}_{init_mat_name}_d={init_mat_size}"

    os.makedirs(args.result_dir, exist_ok=True)
    for ext in ("png", "pdf", "eps"):
        plt.savefig(
            os.path.join(args.result_dir, f"{plot_name}.{ext}"),
            dpi=300,
            bbox_inches="tight",
            bbox_extra_artists=[leg],
        )
    plt.show()


    
    