
import os
import sys
import inspect

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 

import random
import copy
import time
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
import tyro
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from scheduling.environment import SchedulingEnvironment, generate_environment, make_env

# from solvers.hetgat_solver_individual import HetGatSolverIndividual

from solvers.hetgat_solver_simultaneous import HetGatSolverSimultaneous
from models.graph_scheduler import GraphSchedulerCritic

from solvers.edf import EarliestDeadlineFirstAgent
from solvers.improved_edf import ImprovedEarliestDeadlineFirstAgent
from solvers.milp_solver import MILP_Solver, solve_with_MILP
from solvers.genetic_algorithm import GeneticAlgorithm
from training.replay_buffer import ReplayBuffer

from utils.utils import set_seed
from get_checkpoints import get_checkpoints

@dataclass
class Args:
    """Command-line options for timing the simultaneous HetGAT scheduler and baselines."""

    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 10
    """seed of the experiment"""
    torch_deterministic: bool = True
    """whether to request deterministic cuDNN kernels (`torch.backends.cudnn.deterministic`)"""
    cuda: bool = False
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "Task Allocation and Scheduling with Path Planning using GNNs"
    """the wandb's project name"""
    wandb_entity: str = "wandb_repo"
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""

    save_every: int = 10
    """flush the collected timing results to disk every `save_every` problems"""

    # Algorithm specific arguments
    env_id: str = "evaluation"
    """the id of the environment"""
    env_location: str = "data/problem_set_r5_t20_s10_f30_w50_euc_2000"
    """the location of the environment"""
    num_problems: int = 20
    """the number of problems to evaluate"""

    total_timesteps: int = 10000000
    """total timesteps of the experiments"""
    num_envs: int = 8
    """the number of parallel game environments"""
    num_steps: int = 128
    """the number of steps to run in each environment per policy rollout"""

    # Environment specific arguments
    start_problem: int = 1
    """the starting problem number"""
    end_problem: int = 20
    """the ending problem number"""

    # Training specific arguments
    num_iterations: int = 120000
    """the number of iterations"""
    batch_size: int = 0
    """the batch size forwarded to the batched action sampler (0 keeps the un-suffixed result filenames)"""
    # to be filled in runtime
    minibatch_size: int = 0
    """the mini-batch size (computed in runtime)"""

    num_heads: int = 1
    """the number of heads in the GAT layer"""
    num_layers: int = 4
    """the number of GAT layers"""

    # Graph specific arguments
    model: str = 'individual'
    """the model to use (individual, critic)"""
    graph_mode: str = 'hgt_edge_resnet'
    """the mode of the graph (no_edge, edge, attention, attention_simple)"""
    critic_mode: str = 'attention'
    """the mode of the critic (no_edge, edge, attention, attention_simple)"""

    # Checkpoints. `save_location` used to be declared twice; the later
    # declaration ("final_checkpoints_sim") was the effective default, so it is
    # kept here, at the field's original position in the dataclass.
    save_location: str = "final_checkpoints_sim"
    """the directory holding the trained checkpoints (forwarded to get_checkpoints)"""
    model_name: str = "sample_cp/cp"
    """the model name (presumably a checkpoint path prefix)"""
    start_checkpoint: int = 0
    """the starting checkpoint"""
    end_checkpoint: int = 120000
    """the ending checkpoint"""
    checkpoint_step: int = 500
    """the step size of the checkpoints"""

    seeds: tuple[int, ...] = (10, 11, 12)
    """the seeds to run the experiments"""

    # This is the step size forwarded to get_checkpoints by the evaluation script.
    checkpoint_steps: int = 250
    """the step size of the checkpoints"""

    rerun: bool = False
    """if toggled, the experiment will rerun"""

def load_model(model, path, map_location=None):
    """Load a saved ``state_dict`` from *path* into *model* and return *model*.

    Relative paths are resolved against ``parentdir`` (the parent of this
    script's directory), so the script does not depend on the current working
    directory.

    Args:
        model: an ``nn.Module`` whose parameters match the saved state dict.
        path: absolute path, or a path relative to ``parentdir``.
        map_location: forwarded to ``torch.load``; e.g. ``"cpu"`` allows a
            checkpoint saved on GPU to be loaded on a CPU-only machine.
            ``None`` keeps ``torch.load``'s default behaviour.

    Returns:
        The same *model* instance with its weights replaced.
    """
    # os.path.isabs is portable (e.g. "C:\\..." on Windows), unlike checking
    # for a leading "/".
    if not os.path.isabs(path):
        path = os.path.join(parentdir, path)
    model.load_state_dict(torch.load(path, map_location=map_location))
    return model

def run_evaluation_time(agent, env, batch_size=0):
    """Run one batched schedule generation of *agent* on *env*.

    This is the unit of work the evaluation script times. The agent is attached
    to the environment, the environment is reset, and a batch of schedules is
    sampled under ``torch.no_grad()``. The schedules themselves are discarded.

    Exceptions are reported and swallowed on purpose, so that one failing
    problem does not abort a long timing sweep.

    Args:
        agent: object providing ``set_environment(env, rerun=...)`` and
            ``get_action_and_memory_batched(state, batch_size=...)``.
        env: environment whose ``reset()`` result is indexable, with element 0
            being the initial state.
        batch_size: forwarded to ``get_action_and_memory_batched``.

    Returns:
        True if the schedules were generated, False if an exception was caught.
    """
    try:
        agent.set_environment(env, rerun=True)
        state = env.reset()[0]
        with torch.no_grad():
            agent.get_action_and_memory_batched(state, batch_size=batch_size)
    except Exception as e:  # deliberate best-effort: keep the timing sweep going
        print(f"Error: {type(e).__name__}: {e}", file=sys.stderr)
        return False
    return True

def _read_timings(filename):
    """Return the per-problem timings (seconds) stored one per line in *filename*.

    Blank lines are skipped; a missing file yields an empty list.
    """
    if not os.path.exists(filename):
        return []
    with open(filename, 'r') as f:
        return [float(line) for line in f if line.strip()]


def _append_timings(filename, timings):
    """Append *timings* to *filename*, one value per line (creates the file if needed)."""
    with open(filename, 'a') as f:
        f.writelines(f"{t}\n" for t in timings)


def _prepare_results_file(filename, rerun):
    """Ensure the results directory exists and return how many problems are already recorded.

    With *rerun* the file is truncated first, so the run starts from problem 1
    and stale timings cannot leak into the final summary.
    """
    os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)
    if rerun:
        with open(filename, 'w'):
            pass  # empty the file
        return 0
    return len(_read_timings(filename))


def _time_agent(name, agent, filename, start, args):
    """Time *agent* on problems ``start + 1 .. args.num_problems`` and append results to *filename*.

    Timings are flushed every ``args.save_every`` problems so an interrupted run
    can resume from the file. Only timings not yet written are held in memory,
    so no problem is written twice after a resume.
    """
    pending = []
    for i in range(start + 1, args.num_problems + 1):
        print(f"Running {name} on problem {i}")
        env = make_env(args.env_location, i, 1)
        tic = time.perf_counter()  # monotonic clock, suitable for durations
        run_evaluation_time(agent, env, args.batch_size)
        pending.append(time.perf_counter() - tic)
        env.reset()
        if i % args.save_every == 0:
            _append_timings(filename, pending)
            pending = []
    _append_timings(filename, pending)


def main():
    """Time baseline solvers and trained simultaneous HetGAT checkpoints, then summarise.

    Per-problem wall-clock times are written under ``final_results/evals/``.
    The mean and std for each solver go to
    ``final_results/time__simultaneous_<problem set>[_batch_<n>].txt``.
    """
    args = tyro.cli(Args)
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    set_seed(args.seed)

    # Checkpoint-selection configuration forwarded to get_checkpoints.
    env_prefix = "paper neurips25"
    seeds = [10]
    models = ["hgt", "hgt_edge", "hgt_edge_resnet"]
    model_name = ["hgt_simultaneous", "hgt_edge_simultaneous", "hgt_edge_resnet_simultaneous"]
    problem_sets = ["data/problem_set_r2_t5_s10_f30_w50_euc_2000", "data/problem_set_r5_t20_s10_f30_w50_euc_2000"]
    problem_sets_prefixes = ["10r20t0s10f25w"]
    reward_model = [""]
    key_words = ['reward-greedy-final1']

    # Non-learned baselines to time, keyed by result name. Available options:
    #   "edf": EarliestDeadlineFirstAgent(),
    #   "improved_edf": ImprovedEarliestDeadlineFirstAgent(),
    #   "milp_solver": MILP_Solver(),
    #   "gen_random": GeneticAlgorithm(),
    #   "gen_edf": GeneticAlgorithm(EarliestDeadlineFirstAgent()),
    #   "gen_random_3": GeneticAlgorithm(num_generations=3),
    #   "gen_edf_3": GeneticAlgorithm(EarliestDeadlineFirstAgent(), num_generations=3),
    #   "gen_random_10": GeneticAlgorithm(num_generations=10),
    #   "gen_edf_10": GeneticAlgorithm(EarliestDeadlineFirstAgent(), num_generations=10),
    baselines = {}

    env_tag = os.path.split(args.env_location)[-1]
    batch_suffix = ".txt" if args.batch_size == 0 else f"_batch_{args.batch_size}.txt"
    results_files = {}  # result name -> per-problem timing file

    for baseline, agent in baselines.items():
        filename = f"final_results/evals/{args.exp_name}__{args.env_id}__{env_tag}__{baseline}__time.txt"
        results_files[baseline] = filename
        start = _prepare_results_file(filename, args.rerun)
        if start >= args.num_problems:
            continue  # already fully evaluated
        _time_agent(baseline, agent, filename, start, args)

    # The best-performing policy per configuration is picked beforehand by get_checkpoints.
    checkpoints = get_checkpoints(args.save_location, env_prefix, seeds, model_name, models, problem_sets, problem_sets_prefixes, reward_model, baselines, 'sac_wip_simultaneous', key_words, initial_performance=False, checkpoint_steps=args.checkpoint_steps, num_heads=args.num_heads, num_layers=args.num_layers)

    for key, checkpoint in checkpoints.items():
        print(f"Key: {key} - Checkpoint: {checkpoint}")
        # Keys are space-separated; presumably "<graph_mode> <seed> <problem size>",
        # with an extra fourth token for 'nc' entries. The 'nc' branch previously
        # bound `graph` instead of `graph_mode`, leaving graph_mode stale/undefined.
        if 'nc' in key:
            graph_mode, _seed, _problem_size, _ = key.split(' ')
        else:
            graph_mode, _seed, _problem_size = key.split(' ')

        filename = f"final_results/evals/simultaneous_{args.exp_name}__{args.env_id}__{env_tag}__{graph_mode}__time{batch_suffix}"
        results_files[graph_mode] = filename
        start = _prepare_results_file(filename, args.rerun)
        if start >= args.num_problems:
            print(f"Skipping {graph_mode} as it has already been evaluated {start}/{args.num_problems} times")
            continue  # no need to rerun

        layer_data = None
        temperature = 1.0
        actor = HetGatSolverSimultaneous(graph_mode, args.num_heads, args.num_layers, layer_data, temperature)
        actor = load_model(actor, checkpoint)
        _time_agent(graph_mode, actor, filename, start, args)

    time_file = f"final_results/time__simultaneous_{env_tag}{batch_suffix}"
    os.makedirs(os.path.dirname(time_file), exist_ok=True)
    with open(time_file, 'w') as f:
        for key, filename in results_files.items():
            data = _read_timings(filename)[:args.num_problems]
            print(data)
            print(f"{key}: {np.mean(data)}")
            f.write(f"{key}: {np.mean(data)}, {np.std(data)}\n")
    writer.close()


if __name__ == "__main__":
    main()