
import os
import sys
import inspect

# Put the parent of this file's directory at the front of sys.path so the
# project-local imports further down (scheduling, solvers, models, ...) can be
# resolved from there — presumably the repository root; verify against layout.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 

import copy
import random
import time
from dataclasses import dataclass
from typing import Optional

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
import tyro
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from scheduling.environment import SchedulingEnvironment, generate_environment, make_env

# from solvers.hetgat_solver_individual import HetGatSolverIndividual
from solvers.hetgat_solver_simultaneous import HetGatSolverSimultaneous
from models.graph_scheduler import GraphSchedulerCritic

from solvers.edf import EarliestDeadlineFirstAgent
from solvers.improved_edf import ImprovedEarliestDeadlineFirstAgent
from solvers.genetic_algorithm import GeneticAlgorithm
from solvers.milp_solver import MILP_Solver, solve_with_MILP
from training.replay_buffer import ReplayBuffer

from utils.utils import set_seed
from get_checkpoints import get_checkpoints


# Create and display a boxplot with the greedy, sampled and combined models on the x axis and the feasible count on the y axis
import matplotlib
import matplotlib.pyplot as plt


# Maps internal solver / model identifiers to short display labels.
# Not referenced in the part of the file shown up to get_box_plot's first
# lines; presumably used when labelling plots — verify against the rest.
label_map = {
    "edf": "EDF",
    "improved_edf": "ImpEDF",
    "milp_solver": "MILP",
    "hetgat": "HetGAT",
    "hetgat_resnet": "HetGAT ResNet",
    "hgt": "HGT",
    "hgt_edge": "HGT Edge",
    "hgt_edge_resnet": "HGT Edge ResNet",
    "hgt_edge_resnet_bb": "HGT Edge ResNet BB"
}

# The values below are forwarded to get_checkpoints() by get_evaluation().
env_prefix = "paper neurips25"
# Alternative single-model configuration, kept for reference:
# seeds = [10]
# model_name = ["hetgat"]
# models = ["hetgat"]
# problem_sets = ["data/problem_set_r10_t20_s0_f10_w25_euc_2000_uni"]
# problem_sets_prefixes = ["10r20t0s10f25w"]
# reward_model = ["makespan-atari"]


seeds = [10, 11, 12]
# NOTE(review): get_evaluation() builds its own model lists and baseline
# agents, so the module-level `models`, `model_name` and `baselines` are not
# what it passes on — confirm whether anything else reads them.
models = ["hgt_simultaneous", "hgt_edge_simultaneous", "hgt_edge_resnet_simultaneous"]
model_name = ["hgt_simultaneous", "hgt_edge_simultaneous", "hgt_edge_resnet_simultaneous"]
# NOTE(review): two problem sets but a single prefix — confirm that
# get_checkpoints() does not expect these two lists to line up one-to-one.
problem_sets = ["data/problem_set_r2_t5_s10_f30_w50_euc_2000", "data/problem_set_r5_t20_s10_f30_w50_euc_2000"]
problem_sets_prefixes = ["10r20t0s10f25w"]
reward_model = [""]
# Keeps only the first entry (a no-op for the single-element list above).
reward_model = reward_model[:1]
baselines = ['milp_reward', 'improved_edf_reward', 'edf_reward'] # , 'max_sample_reward', 'mean_sample_reward', 'min_sample_reward']

# Each keyword is appended as "__<keyword>" to the name of the evals output
# directory and is also handed to get_checkpoints().
key_words = ['reward-greedy-final1']

# Same values as the Args.num_heads / Args.num_layers defaults.
num_heads = 1
num_layers = 4

@dataclass
class Args:
    """Command-line configuration for the evaluation script.

    The string literal following each field is its help text. Several fields
    (e.g. ``total_timesteps``, ``num_steps``) are not read by the evaluation
    code visible in this module.
    """

    exp_name: str = 'evaluation' # os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 10
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = False
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "Task Allocation and Scheduling with Path Planning using GNNs"
    """the wandb's project name"""
    wandb_entity: str = "wandb_repo"
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""

    save_every: int = 1
    """save the model every `save_every` iterations"""

    # Algorithm specific arguments
    env_id: str = "evaluation_simultaneous"
    """the id of the environment"""
    # env_location: str = "data/problem_set_r5_t20_s10_f30_w50_euc_2000"
    env_location: str = "data/problem_set_r10_t20_s0_f10_w25_euc_200_test"
    """the location of the environment"""
    total_timesteps: int = 10000000
    """total timesteps of the experiments"""
    num_envs: int = 8
    """the number of parallel game environments"""
    num_steps: int = 128
    """the number of steps to run in each environment per policy rollout"""

    # Environment specific arguments
    start_problem: int = 1
    """the starting problem number"""
    end_problem: int = 200
    """the ending problem number"""

    # Training specific arguments
    num_iterations: int = 120000
    """the number of iterations"""
    # to be filled in runtime
    batch_size: int = 0
    """the batch size (computed in runtime)"""
    minibatch_size: int = 0
    """the mini-batch size (computed in runtime)"""

    num_heads: int = 1
    """the number of heads in the GAT layer"""
    num_layers: int = 4
    """the number of layers in the GAT layer"""

    # Graph specific arguments
    model: str = 'individual'
    """the model to use (individual, critic)"""
    graph_mode: str = 'hgt_edge_resnet'
    """the mode of the graph (no_edge, edge, attention, attention_simple)"""
    critic_mode: str = 'attention'
    """the mode of the critic (no_edge, edge, attention, attention_simple)"""
    # Save Checkpoints
    save_location: str = "final_checkpoints"
    """the location to save the checkpoints"""
    model_name: str = "sample_cp/cp"
    """name/prefix of the model checkpoint (presumably; not read by the code in this module)"""
    start_checkpoint: int = 0
    """the starting checkpoint"""
    end_checkpoint: int = 120000
    """the ending checkpoint"""
    checkpoint_steps: int = 250
    """the step size of the checkpoints"""

    seeds: tuple[int, ...] = (10, 11, 12)
    """the seeds to run the experiments"""

    initial_performance: bool = False
    """run evaluation on the initial start performance"""
    target_location: str = "final_results"
    """the location to save the final results"""
    # Optional because "unset" (None) is the default and a meaningful value.
    run_mode: Optional[str] = None
    """restrict the run to the single baseline or model with this name; leave unset to evaluate everything"""
    
    
def load_model(model, path):
    """Load a saved state dict into ``model`` and return the same model object.

    Args:
        model: a module exposing ``load_state_dict``.
        path: checkpoint file; a relative path is resolved against ``parentdir``.

    Returns:
        ``model`` itself, with its parameters overwritten in place.
    """
    if not os.path.isabs(path):
        path = os.path.join(parentdir, path)
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    # Tensors are mapped to CPU so that checkpoints written on a GPU machine
    # can be read on a CPU-only host; load_state_dict then copies the values
    # into the model's existing parameters, wherever those live.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model

def run_evaluation(agent, problem_id, env_location, greedy=False, model_name="hetgat"):
    """Solve one problem instance and report the quality of the resulting schedule.

    Args:
        agent: policy exposing ``set_environment(env)`` and
            ``get_action(state, greedy=...)``; not used when ``model_name`` is
            ``"milp_solver"``.
        problem_id: id of the problem instance handed to ``make_env``.
        env_location: location of the problem set handed to ``make_env``.
        greedy: forwarded to ``agent.get_action``.
        model_name: ``"milp_solver"`` selects the MILP path; any other value
            rolls out ``agent``.

    Returns:
        ``(makespan, infeasible_count, schedule, failed)`` where ``makespan`` is
        ``env.get_raw_score()`` and ``infeasible_count`` is
        ``env.num_infeasible`` after all actions were replayed, ``schedule`` is
        the list of executed actions, and ``failed`` is True only when the MILP
        solver produced no schedule at all (always False for agents).
    """
    print(env_location, model_name, problem_id)
    env = make_env(env_location, problem_id, 1)
    state = env.reset()[0]

    if model_name == "milp_solver":
        # The solver is given a 12 hour time limit.
        feasible, status, _, schedule, duration = solve_with_MILP(env, verbose=True, time_limit=12*60*60)
        print(f"Problem: {problem_id} - Feasible: {feasible} - Status: {status} - Makespan: {env.get_raw_score()} - Duration: {duration}")

        # Record the failure BEFORE substituting a fallback schedule; once the
        # fallback is assigned, `schedule is None` can no longer be observed.
        failed = (not feasible) and schedule is None
        if failed:
            # No schedule at all: every task once, in random order, on a random agent.
            print("Generating Random Schedule since MILP is infeasible")
            tasks = np.random.choice(env.num_tasks, env.num_tasks, replace=False)
            agents = np.random.choice(env.num_agents, env.num_tasks, replace=True)
            schedule = list(zip(tasks, agents))
        elif len(schedule) < env.num_tasks:
            # Partial schedule: append the missing tasks in random order on random agents.
            assigned_tasks = {step[0] for step in schedule}
            unassigned_tasks = np.array([i for i in range(env.num_tasks) if i not in assigned_tasks])
            np.random.shuffle(unassigned_tasks)
            agents = np.random.choice(env.num_agents, len(unassigned_tasks), replace=True)
            # Build a new list so the solver's own schedule object is left untouched.
            schedule = list(schedule) + list(zip(unassigned_tasks, agents))
        elif not feasible:
            print("Partial Schedule - ", schedule)

        # Replay the schedule on a freshly reset environment to score it.
        env.reset()
        for action in schedule:
            env.step(action)
        return env.get_raw_score(), env.num_infeasible, schedule, failed

    agent.set_environment(env)
    # NOTE(review): the first action is computed from the state returned by the
    # reset above, not from this second reset — confirm that this is intended.
    env.reset()
    schedule_data = []
    done = False
    while not done:
        with torch.no_grad():
            action = agent.get_action(state, greedy=greedy)[0]
        schedule_data.append(action)
        print(f"Step: {len(schedule_data)}/{env.num_tasks} - Action: {action}")
        state, _, done, _ = env.step(action)

    return env.get_raw_score(), env.num_infeasible, schedule_data, False

def _format_schedule(schedule):
    """Render a schedule as ``(a, b), (c, d), ...`` using the first two items of each step."""
    return "(" + "), (".join(str(step[0]) + ", " + str(step[1]) for step in schedule) + ")"


def _count_recorded_rows(path):
    """Return the number of non-blank lines in ``path``, or 0 when the file does not exist."""
    if not os.path.exists(path):
        return 0
    with open(path, 'r') as f:
        return sum(1 for line in f if line.strip())


def _flush_results(performance_path, schedule_path, performance, schedules):
    """Append the buffered result rows and schedules to their files, then empty both buffers."""
    # Both files are always opened in append mode so that they exist afterwards
    # even when nothing was pending.
    with open(performance_path, 'a') as f:
        f.writelines(", ".join(map(str, row)) + "\n" for row in performance)
    schedule_lines = [_format_schedule(schedule) for schedule in schedules]
    if schedule_lines:
        print("\n".join(schedule_lines))
    with open(schedule_path, 'a') as f:
        f.writelines(line + "\n" for line in schedule_lines)
    performance.clear()
    schedules.clear()


def get_evaluation(args):
    """Evaluate the baseline solvers and the selected checkpoints, resuming partial runs.

    Results are appended to text files below ``<args.target_location>/evals``
    (with ``__<keyword>`` appended for every entry of ``key_words``). Each
    result row is ``makespan, infeasible_count, failed, problem_id``; checkpoint
    rows carry four more values per sampled (non-greedy) run. A companion
    ``*_schedule.txt`` file stores one greedy schedule per line. When a results
    file already exists, evaluation continues after the number of rows it holds.

    Args:
        args: an ``Args`` instance. ``args.run_mode`` (if set) restricts the run
            to the single baseline or model of that name.

    Returns:
        dict mapping each baseline name to its results file path, and each
        ``"<graph_mode>[_nc]__<problem_size>"`` name to a ``{seed: path}`` dict.
    """
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )
    # Nothing else is logged below, so flush and release the event file now.
    writer.close()

    # TRY NOT TO MODIFY: seeding
    set_seed(args.seed)

    # Baseline solvers keyed by the name that run_evaluation() dispatches on.
    baseline_agents = {
        # "edf": EarliestDeadlineFirstAgent(),
        # "improved_edf": ImprovedEarliestDeadlineFirstAgent(),
        "milp_solver": MILP_Solver(),
    }
    if args.run_mode is not None:
        if args.run_mode in baseline_agents:
            # only run the selected baseline
            baseline_agents = {args.run_mode: baseline_agents[args.run_mode]}
        else:
            baseline_agents = {}

    # Output directory shared by the baseline and checkpoint results.
    eval_dir = os.path.join(args.target_location, "evals")
    print(f"Creating Evaluation Location: {eval_dir}")
    os.makedirs(eval_dir, exist_ok=True)
    if key_words is not None:
        eval_dir += "".join(f"__{word}" for word in key_words)
        os.makedirs(eval_dir, exist_ok=True)

    evaluation_data_location = {}

    # ---- baselines ---------------------------------------------------------
    env_tag = '_'.join(os.path.split(args.env_location))
    for baseline_name, baseline in baseline_agents.items():
        stem = f"{args.exp_name}__{env_tag}__{args.seed}__{args.start_problem}__{baseline_name}"
        performance_path = os.path.join(eval_dir, f"{stem}.txt")
        schedule_path = os.path.join(eval_dir, f"{stem}_schedule.txt")
        print(f"Baseline: {baseline_name} - Location: {performance_path}")
        evaluation_data_location[baseline_name] = performance_path

        # Only the number of rows already on disk matters for resuming.
        completed = _count_recorded_rows(performance_path)
        print(f"Baseline: {baseline_name} - Problem: {args.start_problem} - Start: {completed}")

        performance, pending_schedules = [], []
        for problem_id in range(args.start_problem + completed, args.end_problem + 1):
            print(f"Baseline: {baseline_name} - Problem: {problem_id}")
            makespan, infeasible_count, schedule, failed = run_evaluation(baseline, problem_id, args.env_location, greedy=True, model_name=baseline_name)
            performance.append([makespan, infeasible_count, failed, problem_id])
            pending_schedules.append(schedule)
            print(f"Baseline: {baseline_name} - Problem: {problem_id} - Makespan: {makespan} - Infeasible: {infeasible_count} - Failed: {failed}")
            if problem_id % args.save_every == 0:
                _flush_results(performance_path, schedule_path, performance, pending_schedules)
        # Write whatever is left and make sure both files exist.
        _flush_results(performance_path, schedule_path, performance, pending_schedules)

    # ---- learned models ----------------------------------------------------
    models = ["hgt", "hgt_edge", "hgt_edge_resnet"]
    model_name = ["hgt_simultaneous", "hgt_edge_simultaneous", "hgt_edge_resnet_simultaneous"]
    if args.run_mode is not None:
        # only run the selected model
        models = [args.run_mode] if args.run_mode in models else []
        # NOTE(review): with run_mode set, model_name loses its "_simultaneous"
        # suffix — confirm that get_checkpoints() expects this.
        model_name = models
    # we are picking the best performing policy pre-hoc
    # NOTE(review): the dict of baseline agents built above is what gets passed
    # as the baselines argument; the module-level `baselines` list of reward
    # names may have been intended — confirm against get_checkpoints().
    checkpoints = get_checkpoints(args.save_location, env_prefix, seeds, model_name, models, problem_sets, problem_sets_prefixes, reward_model, baseline_agents, run_name='sac_wip_simultaneous', key_words=key_words, initial_performance=args.initial_performance, checkpoint_steps=args.checkpoint_steps, num_heads=args.num_heads, num_layers=args.num_layers)

    env_location = "_".join(args.env_location.split("/"))
    for key, checkpoint in checkpoints.items():
        print(f"Key: {key} - Checkpoint: {checkpoint}")
        # Keys are space separated: graph mode, seed, problem size (+ one extra
        # field for the 'nc' variant).
        is_nc = 'nc' in key
        if is_nc:
            graph_mode, seed, problem_size, _ = key.split(' ')
        else:
            graph_mode, seed, problem_size = key.split(' ')
        # checkpoint ends with checkpoint id in the form of XXXXX.pt
        checkpoint_id = int(checkpoint.split('__')[-1][:-3])

        # Build the actor with the same head/layer counts that were used to
        # select the checkpoint, then load its weights.
        layer_data = None
        temperature = 1.0
        actor = HetGatSolverSimultaneous(graph_mode, args.num_heads, args.num_layers, layer_data, temperature)
        actor = load_model(actor, checkpoint)

        name = f"{graph_mode}_nc__{problem_size}" if is_nc else f"{graph_mode}__{problem_size}"
        stem = f"simultaneous_{name}__{seed}__{env_location}__{args.start_problem}__{args.end_problem}__{checkpoint_id}"
        performance_path = os.path.join(eval_dir, f"{stem}.txt")
        schedule_path = os.path.join(eval_dir, f"{stem}_schedule.txt")
        print(f"Creating Evaluation Location: {performance_path}")

        completed = _count_recorded_rows(performance_path)
        performance, pending_schedules = [], []

        # iterate through the missing problems and evaluate the performance
        for problem_id in range(args.start_problem + completed, args.end_problem + 1):
            makespan, infeasible_count, schedule, failure = run_evaluation(actor, problem_id, args.env_location, greedy=True, model_name=graph_mode)
            performance.append([makespan, infeasible_count, failure, problem_id])
            # only store greedy schedule
            pending_schedules.append(schedule)

            for _ in range(args.batch_size):
                # run_evaluation returns four values; the sampled schedule is discarded.
                makespan, infeasible_count, _, failure = run_evaluation(actor, problem_id, args.env_location, greedy=False, model_name=graph_mode)
                performance[-1].extend([makespan, infeasible_count, failure, problem_id])
            print(f"Problem: {problem_id} - {' '.join([str(x) for x in performance[-1]])}")

            if problem_id % args.save_every == 0:
                _flush_results(performance_path, schedule_path, performance, pending_schedules)
        _flush_results(performance_path, schedule_path, performance, pending_schedules)

        evaluation_data_location.setdefault(name, {})[seed] = performance_path

    return evaluation_data_location

def read_data(location, num_problems):
    """Parse a results file into makespan, infeasible-count and failure arrays.

    Every non-empty line is split on ``", "``; field 0 is read as a float
    makespan and field 1 as an int infeasible count. The failure column is not
    parsed, so the third array is always empty. Each array is truncated to the
    first ``num_problems`` entries.

    Returns:
        tuple ``(makespan, infeasible, failed)`` of numpy arrays.
    """
    with open(location, 'r') as handle:
        rows = [row.split(", ") for row in handle.read().split("\n") if row]

    makespans, infeasibles, failures = [], [], []
    for fields in rows:
        makespans.append(float(fields[0]))
        infeasibles.append(int(fields[1]))

    return (
        np.array(makespans[:num_problems]),
        np.array(infeasibles[:num_problems]),
        np.array(failures[:num_problems]),
    )

def get_figure(data, labels, env_name, title_prefix, y_axis="Makespan (Lower Better)"):
    """Draw a notched, colour-filled box plot with one box per entry of ``data``.

    Args:
        data: sequence of datasets handed to ``Axes.boxplot``.
        labels: x tick labels, one per box.
        env_name: currently unused (the plot title is disabled).
        title_prefix: currently unused (the plot title is disabled).
        y_axis: label of the y axis.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure as current figure.

    Note:
        Changes matplotlib's global font settings via ``matplotlib.rc``.
    """
    # Font (applies globally, not just to this figure)
    matplotlib.rc('font', family='Times New Roman', weight='normal', size=14)

    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111)

    # notch / vert are boolean flags; pass real booleans rather than 'True' / 1.
    bp = ax.boxplot(data, patch_artist=True, notch=True, vert=True)

    # One colour per box, sampled from the colormap. The first box uses
    # position 0; box i >= 1 uses (i + 1) / n, so position 1 / n is skipped.
    num_boxes = len(data)
    cmap = plt.get_cmap('gist_rainbow')
    colors = [cmap((i + 1) / num_boxes) if i >= 1 else cmap(i / num_boxes) for i in range(num_boxes)]
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)

    # whiskers: colour, width and dotted style
    for whisker in bp['whiskers']:
        whisker.set(color='#8B008B', linewidth=1.5, linestyle=":")

    # caps: colour and width
    for cap in bp['caps']:
        cap.set(color='#8B008B', linewidth=2)

    # medians: colour and width
    for median in bp['medians']:
        median.set(color='black', linewidth=3)

    # fliers: marker style
    for flier in bp['fliers']:
        flier.set(marker='D', color='#e7298a', alpha=0.5)

    # x-axis labels
    ax.set_xticklabels(labels)

    # plt.title(f"{title_prefix} {env_name}")
    # Keep ticks only on the bottom and left axes.
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()

    plt.ylabel(y_axis)

    fig.autofmt_xdate()
    return plt
    
def _stat_lines(stats, skip=None):
    """Render one ``name: mean, std`` line per entry of ``stats``, omitting ``skip``."""
    return "\n".join(
        f"{name}: {np.mean(values)}, {np.std(values)}"
        for name, values in stats.items()
        if name != skip
    )


def _relative_to_reference(values, reference, as_gap):
    """Compare per-problem ``values`` with ``reference`` (same problem order).

    Returns an array of ``(ref - value) / ref * 100`` when ``as_gap`` is true,
    otherwise of ``value / ref``. Raises ``IndexError`` if ``values`` is longer
    than ``reference``.
    """
    scores = []
    for idx, value in enumerate(values):
        ref = reference[idx]
        if as_gap:
            scores.append((ref - value) / ref * 100)
        else:
            scores.append(value / ref)
    return np.array(scores)


def get_box_plot(args, evaluation_data_location):
    """Aggregate evaluation results, draw box plots and write summary files.

    For every entry of ``evaluation_data_location`` the per-problem makespan,
    infeasible count and failure data are loaded with ``read_data``. Entries
    given as a dict (one file per seed) are reduced across seeds with
    max (makespan), min (infeasible) and max (failed).

    Files written:
        * ``results[/<key_words>_]/<run>_<label>[_nc].txt``: mean +/- std
          makespan per entry.
        * ``figures/<run>_makespan.png`` and ``figures/<run>_infeasible.png``.
        * ``<base>_{makespan,infeasible,failed}.txt`` plus the seedwise,
          optimality-gap and optimality-rate variants, where ``<base>`` is
          derived from ``args.target_location``.

    Args:
        args: Parsed options. Reads ``env_location``, ``end_problem``,
            ``exp_name``, ``seeds``, ``initial_performance``,
            ``target_location`` and ``batch_size``.
        evaluation_data_location: Mapping from model key to either a result
            file path or a dict of per-seed result file paths. It must contain
            a ``"milp_solver"`` entry, which is the reference for the
            optimality gap/rate.

    Raises:
        KeyError: if there is no ``"milp_solver"`` entry (or, when
            ``args.batch_size > 0``, a required batched entry is missing).
    """
    title_prefix = "Performance Comparison of Models on "
    env_name = args.env_location.split("/")[-1]
    # Common file-name stem shared by every output of this run.
    run_tag = f"{args.exp_name}__{'_'.join(args.env_location.split('/'))}_{args.end_problem}"

    data = []     # per-entry makespan arrays, in plotting order
    data_if = []  # per-entry infeasible-count arrays, in plotting order
    labels = []

    ms_data = {}
    if_data = {}
    fail_data = {}

    # Seed-wise data for the ablation study.
    reward_seedwise = {}
    if_seedwise = {}
    standard_deviation_seedwise = {}

    # Directory of the per-entry summaries; key_words is a module-level setting.
    summary_dir = 'results'
    if key_words is not None and len(key_words) > 0:
        summary_dir = os.path.join(summary_dir, "".join(f"{word}_" for word in key_words))
    os.makedirs(summary_dir, exist_ok=True)

    for key, location in evaluation_data_location.items():
        if isinstance(location, dict):
            # One result file per seed: keep the per-seed rows and reduce them.
            per_seed = [read_data(loc, args.end_problem) for loc in location.values()]
            seed_ms = np.array([ms for ms, _, _ in per_seed])
            seed_if = np.array([if_ for _, if_, _ in per_seed])
            seed_fail = np.array([fail_ for _, _, fail_ in per_seed])
            entry_ms = seed_ms.max(axis=0) * 1.0
            entry_if = seed_if.min(axis=0)
            entry_fail = seed_fail.max(axis=0)
        else:
            entry_ms, entry_if, entry_fail = read_data(location, args.end_problem)
            seed_ms = np.array([entry_ms])
            seed_if = np.array([entry_if])
        data.append(entry_ms)
        data_if.append(entry_if)

        fail_sum = np.sum(entry_fail)

        print(f"{key} Makespan {' '.join([str(d) for d in entry_ms])}\n Infeasible {' '.join([str(d) for d in entry_if])}\n Failed {' '.join([str(d) for d in entry_fail])}")

        # NOTE(review): for dict entries this tests the dict keys (the seeds),
        # not the file paths, so it can only match for single-file entries.
        # Confirm whether per-seed paths should be inspected as well.
        if '_nc__' in location:
            key += '_nc'
        if key != "milp_solver":
            key = f"simultaneous_{key}"
        labels.append(key)
        ms_data[key] = np.array(entry_ms)
        if_data[key] = np.array(entry_if)
        fail_data[key] = (fail_sum, len(entry_fail))

        if seed_ms.shape[0] > 1:
            # Rows are assumed to follow the order of args.seeds; enumerating
            # keeps this valid for seeds that are not consecutive integers.
            for row, s in enumerate(args.seeds):
                reward_seedwise[f"{key}_{s}"] = np.array(seed_ms[row])
                if_seedwise[f"{key}_{s}"] = np.array(seed_if[row])

            # Per-problem spread of the makespan across the seeds.
            standard_deviation_seedwise[key] = np.std(seed_ms, axis=0)
        else:
            # Store the single row as a 1-D per-problem array so that the
            # seed-wise optimality statistics below index it by problem.
            reward_seedwise[key] = np.array(seed_ms[0])
            if_seedwise[key] = np.array(seed_if[0])

        # Mean and standard deviation of this entry, one file per entry. The
        # "_nc" suffix is decided by the entry's own key.
        nc_suffix = "_nc" if 'nc' in key else ""
        summary_path = os.path.join(summary_dir, f"{run_tag}_{key.replace(' ', '_')}{nc_suffix}.txt")
        with open(summary_path, 'w') as f:
            f.write(f"{np.mean(entry_ms)} +/- {np.std(entry_ms)}")

    os.makedirs("figures", exist_ok=True)
    pyplot = get_figure(data, labels, env_name, title_prefix, y_axis="Makespan (Lower Better)")
    pyplot.savefig(f"figures/{run_tag}_makespan.png")
    pyplot.cla()
    pyplot = get_figure(data_if, labels, env_name, title_prefix, y_axis="Infeasible Count (Lower Better)")
    pyplot.savefig(f"figures/{run_tag}_infeasible.png")

    # Base path of the aggregated text outputs ("_0" marks initial performance).
    base_name = f"{run_tag}_0" if args.initial_performance else run_tag
    base_location = os.path.join(args.target_location, base_name)
    # replace evaluation with evaluation_simultaneous in base_location
    base_location = base_location.replace("evaluation", "evaluation_simultaneous")
    # The renamed directory may not exist yet.
    os.makedirs(os.path.dirname(base_location) or ".", exist_ok=True)

    with open(f"{base_location}_makespan.txt", 'w') as f:
        f.write(_stat_lines(ms_data, skip='milp_solver'))
        f.write("\n")
    with open(f"{base_location}_infeasible.txt", 'w') as f:
        f.write(_stat_lines(if_data, skip='milp_solver'))
        f.write("\n")
    with open(f"{base_location}_failed.txt", 'w') as f:
        f.write("\n".join([f"{name}: {value[0]} / {value[1]}" for name, value in fail_data.items() if name != 'milp_solver']))
        f.write("\n")

    with open(f"{base_location}_seedwise_makespan.txt", 'w') as f:
        f.write(_stat_lines(reward_seedwise))
    with open(f"{base_location}_seedwise_infeasible.txt", 'w') as f:
        f.write(_stat_lines(if_seedwise))
    with open(f"{base_location}_seedwise_standard_deviation.txt", 'w') as f:
        f.write(_stat_lines(standard_deviation_seedwise))

    # adding simultaneous (batched) results
    sim_data_list = ['hgt', 'hgt_edge', 'hgt_edge_resnet']
    seeds = [10, 11, 12]

    if args.batch_size > 0:
        # Each line holds 1 + batch_size groups of 4 comma-separated fields;
        # field 0 of a group is the reward and field 1 the infeasible count.
        samples_per_line = 1 + args.batch_size
        for key in sim_data_list:
            batch_key = f"simultaneous_{key}_batch_{args.batch_size}"
            per_seed_rows = {}
            for seed in seeds:
                path = evaluation_data_location[f"{key}__10r20t0s10f25w"][str(seed)]
                rows = []
                with open(path, 'r') as f:
                    for line in f.read().split("\n"):
                        if not line:
                            continue
                        fields = line.split(", ")
                        line_reward = [float(fields[i * 4]) for i in range(samples_per_line)]
                        line_infeasible = [int(fields[1 + i * 4]) for i in range(samples_per_line)]
                        # Best sample of the line: highest reward, fewest infeasible.
                        rows.append([np.array(line_reward).max(), np.array(line_infeasible).min()])
                per_seed_rows[seed] = rows

            makespans = np.array([np.array(per_seed_rows[seed])[:, 0] for seed in seeds])
            infeasibles = np.array([np.array(per_seed_rows[seed])[:, 1] for seed in seeds])
            best_makespan = makespans.max(axis=0)
            best_infeasible = infeasibles.min(axis=0)

            print(f"Simultaneous Batch {key} - {makespans.shape} - {infeasibles.shape} - {best_makespan.shape}")
            ms_data[batch_key] = best_makespan
            with open(f"{base_location}_makespan.txt", 'a') as f:
                f.write(f"{batch_key}: {np.mean(best_makespan)}, {np.std(best_makespan)}\n")
            with open(f"{base_location}_infeasible.txt", 'a') as f:
                f.write(f"{batch_key}: {np.mean(best_infeasible)}, {np.std(best_infeasible)}\n")
            # NOTE(review): this reports the summed infeasible counts over the
            # number of seeds, not a failure count; kept as is, please confirm.
            with open(f"{base_location}_failed.txt", 'a') as f:
                f.write(f"{batch_key}: {np.sum(infeasibles)}/{len(infeasibles)}\n")

    # Optimality gap (in %) and rate w.r.t. MILP, mean and std over the problems.
    milp_data = ms_data['milp_solver']
    optimality_gap = {
        name: _relative_to_reference(values, milp_data, as_gap=True)
        for name, values in ms_data.items()
    }
    with open(f"{base_location}_makespan_optimality_gap.txt", 'w') as f:
        f.write(_stat_lines(optimality_gap, skip='milp_solver'))

    optimality_rate = {
        name: _relative_to_reference(values, milp_data, as_gap=False)
        for name, values in ms_data.items()
    }
    with open(f"{base_location}_makespan_optimality_rate.txt", 'w') as f:
        f.write(_stat_lines(optimality_rate, skip='milp_solver'))

    # Same statistics for every individual seed.
    optimality_gap_seedwise = {
        name: _relative_to_reference(values, milp_data, as_gap=True)
        for name, values in reward_seedwise.items()
    }
    with open(f"{base_location}_seedwise_makespan_optimality_gap.txt", 'w') as f:
        f.write(_stat_lines(optimality_gap_seedwise))

    optimality_rate_seedwise = {
        name: _relative_to_reference(values, milp_data, as_gap=False)
        for name, values in reward_seedwise.items()
    }
    with open(f"{base_location}_seedwise_makespan_optimality_rate.txt", 'w') as f:
        f.write(_stat_lines(optimality_rate_seedwise))
        
   # with open(f"{args.target_location}/{args.exp_name}__{'_'.join(args.env_location.split('/'))}_{args.end_problem}_makespan.txt", 'w') as f:
    #     f.write("\n".join([f"{key}: {value}" for key, value in ms_data.items()]))
    # with open(f"{args.target_location}/{args.exp_name}__{'_'.join(args.env_location.split('/'))}_{args.end_problem}_infeasible.txt", 'w') as f:
    #     f.write("\n".join([f"{key}: {value}" for key, value in if_data.items()]))
        
    # show plot
    # plt.show()

if __name__ == "__main__":
    # Parse the command-line options into an Args instance.
    args = tyro.cli(Args)

    # Run (or look up) the evaluations; the result maps each model key to a
    # result file path or to a dict of per-seed result file paths.
    evaluation_data_location = get_evaluation(args)

    # Echo where every result file lives before plotting.
    for key, item in evaluation_data_location.items():
        print(f"Key: {key}")
        if not isinstance(item, dict):
            print(f"\tLocation: {item}")
            continue
        for seed, location in item.items():
            print(f"\tSeed: {seed} - Location: {location}")

    # Generate the box plots and the summary text files.
    get_box_plot(args, evaluation_data_location)
