# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_ataripy


import os
import sys
import inspect

# Put the parent of this file's directory at the front of sys.path so that the
# project-local imports further down can be resolved when this file is run
# directly as a script (presumably the parent directory is the repository root
# that holds `utils`, `scheduling`, `solvers` and `evaluation` -- confirm).
# This must run before those imports are executed.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 

import random
import time
from dataclasses import dataclass
from typing import Tuple

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from utils.utils import set_seed

from scheduling.environment import SchedulingEnvironment, generate_environment, make_env

from solvers.hetgat_solver_individual import HetGatSolverIndividual
from solvers.edf import EarliestDeadlineFirstAgent
from solvers.improved_edf import ImprovedEarliestDeadlineFirstAgent
from solvers.milp_solver import MILP_Solver

from evaluation.get_checkpoints import get_checkpoints

@dataclass
class Args:
    """Configuration of the evaluation script.

    An instance is built from the command line with ``tyro.cli(Args)``. Every
    field is followed by a string literal that describes it.
    """

    exp_name: str = "ppo"
    """the name of this experiment"""
    seed: int = 10
    """seed of the experiment"""
    torch_deterministic: bool = True
    """the value assigned to `torch.backends.cudnn.deterministic`"""
    cuda: bool = False
    """if toggled, cuda is used when it is available"""
    track: bool = True
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "Task Allocation and Scheduling with Path Planning using GNNs"
    """the wandb's project name"""
    wandb_entity: str = "wandb_repo"
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""

    # Algorithm specific arguments
    env_id: str = "evaluation"
    """the id of the environment"""
    env_location: str = "data/problem_set_r5_t10_s10_f80_w50_euc_200"
    """the location of the environment"""
    # NOTE: this f-string is evaluated once, while the class body executes, so
    # the default always points into the *default* `env_location`; it does not
    # follow an `env_location` overridden on the command line.
    # NOTE(review): "evalutation" looks like a typo for "evaluation"; it is left
    # untouched because other tooling may already read this exact path.
    store_location: str = f"{env_location}/evalutation_results.log"
    """the location to store the results"""

    # Model specific arguments
    test_model: bool = True
    """if toggled, the model will be tested"""
    checkpoint_location: str = "checkpoints"
    """the directory that holds the model checkpoints"""
    model_name: str = "attention_gb_1"
    """the name of the model"""
    start_cp: int = 0
    """the starting checkpoint number"""
    end_cp: int = 2000
    """the ending checkpoint number"""
    step_cp: int = 50
    """the step between checkpoints"""
    ensemble_size: int = 8
    """the number of environment copies evaluated per seed with the same model"""
    greedy: bool = False
    """if toggled, the greedy baseline will be evaluated along with the ensemble"""
    graph_mode: str = 'attention'
    """the mode of the graph (no_edge, edge, attention, attention_simple)"""
    # Immutable tuple default (a list default would be rejected by @dataclass).
    baselines: Tuple[str, ...] = ('edf_baseline', 'improved_edf_baseline')
    """the baselines to evaluate (edf_baseline, improved_edf_baseline, milp_baseline)"""

    # The three fields below need a type annotation to be real dataclass
    # fields; without one they are plain class attributes that are missing from
    # `__init__`, from `vars(args)` and from the generated command line.
    # They are declared last so that the positional order of every other field
    # stays exactly as it was.
    start_id: int = 1
    """the starting problem number"""
    end_id: int = 200
    """the ending problem number"""
    seeds: Tuple[int, ...] = (10, 11, 12, 13, 14)
    """the seeds to evaluate the model"""
    
if __name__ == "__main__":
    # Evaluation entry point. For every problem id in [start_id, end_id] one
    # environment copy is created per result column (model columns first, then
    # baseline columns); each copy is rolled out with the agent that owns the
    # column and the last negative reward seen per column is printed and logged.
    args = tyro.cli(Args)
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    hyperparameter_rows = "\n".join(f"|{key}|{value}|" for key, value in vars(args).items())
    writer.add_text("hyperparameters", "|param|value|\n|-|-|\n%s" % hyperparameter_rows)

    args.seeds = list(args.seeds)

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # Baselines requested on the command line. This name must not be rebound
    # below, otherwise the request is lost and no baseline agent is created.
    baselines = list(args.baselines)

    # NOTE(review): the inputs of this call are hard-coded and its result is not
    # used anywhere below. The call is kept because it cannot be told from this
    # file whether get_checkpoints has side effects -- confirm and remove if not.
    checkpoint_env_prefix = "paper init"
    checkpoint_seeds = [10, 11, 12]
    checkpoint_model_names = ["hetgat", "hetgat_resnet", "hgt", "hgt_edge", "hgt_edge_resnet", "hgt-edge-resnet_wbb"]
    checkpoint_models = ["hetgat", "hetgat_resnet", "hgt", "hgt_edge", "hgt_edge_resnet", "hgt_edge_resnet_bb"]
    checkpoint_problem_sets = ["data/problem_set_r2_t5_s10_f30_w50_euc_2000", "data/problem_set_r5_t20_s10_f30_w50_euc_2000"]
    checkpoint_problem_set_prefixes = ["2r5t", "5r20t"]
    checkpoint_reward_columns = ['milp_reward', 'improved_edf_reward', 'edf_reward']
    checkpoints = get_checkpoints(
        checkpoint_env_prefix,
        checkpoint_seeds,
        checkpoint_model_names,
        checkpoint_models,
        checkpoint_problem_sets,
        checkpoint_problem_set_prefixes,
        checkpoint_reward_columns,
    )

    # Column bookkeeping:
    #   headers[i]      -> printed / logged name of column i
    #   agent_id_map[i] -> index into `agents` of the agent acting in column i
    #   greedy_ids      -> columns labelled "m_<seed>_g"
    #   ensemble_ids    -> columns labelled "m_<seed>_<k>"
    agents = []
    agent_id_map = {}
    greedy_ids = set()
    ensemble_ids = set()
    headers = []
    if args.test_model:
        for seed in args.seeds:
            agent = HetGatSolverIndividual(args.graph_mode).to(device)
            # Use the loop's seed so every agent loads its own checkpoint
            # instead of all of them loading the one of `args.seed`.
            checkpoint_base = os.path.join(
                args.checkpoint_location,
                f"{args.exp_name}__{args.model_name}__{seed}__{args.graph_mode}__",
            )
            checkpoint = checkpoint_base + f"{str(0).zfill(5)}.pt"
            # map_location lets a checkpoint saved on a GPU load on a CPU-only run.
            agent.load_state_dict(torch.load(checkpoint, map_location=device))

            agent_headers = []
            if args.greedy:
                greedy_ids.add(len(headers) + len(agent_headers))
                agent_headers.append(f"m_{seed}_g")
            if args.ensemble_size:
                first_column = len(headers) + len(agent_headers)
                ensemble_ids.update(range(first_column, first_column + args.ensemble_size))
                agent_headers.extend(f"m_{seed}_{ens}" for ens in range(args.ensemble_size))
            # All columns of this seed share the single agent created above.
            for offset in range(len(agent_headers)):
                agent_id_map[len(headers) + offset] = len(agents)
            headers.extend(agent_headers)
            agents.append(agent)

    # Baseline columns follow the model columns, always in this fixed order.
    # A baseline is only constructed when it was requested.
    for baseline_key, baseline_cls in (
        ("edf_baseline", EarliestDeadlineFirstAgent),
        ("improved_edf_baseline", ImprovedEarliestDeadlineFirstAgent),
        ("milp_baseline", MILP_Solver),
    ):
        if baseline_key not in baselines:
            continue
        baseline_agent = baseline_cls()
        agent_id_map[len(headers)] = len(agents)
        headers.append(baseline_agent.name)
        agents.append(baseline_agent)

    num_envs = len(headers)
    if num_envs == 0:
        # Without this check the first `max(...)` over an empty list of
        # environments fails with an unrelated-looking error.
        raise ValueError("nothing to evaluate: enable test_model or select at least one baseline")

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()

    for iteration in range(args.start_id, args.end_id + 1):
        # One independent copy of problem `iteration` per column.
        envs = [make_env(args.env_location, iteration) for _ in range(num_envs)]
        max_steps = max(env.num_tasks for env in envs)
        next_obs = [env.reset()[0] for env in envs]

        # dones[s][i] == 1 means column i had already finished before step s.
        dones = torch.zeros((max_steps + 1, num_envs)).to(device)
        # Reset per problem so that a reward of an earlier problem can never be
        # reported for this one; 1 means "no negative reward seen yet".
        last_rewards = torch.ones(num_envs)

        # Print the Environment IDs for the Action Space
        print("="*90)
        print("Env:   \t" + "\t".join(headers))

        for step in range(max_steps):
            # if all done, early terminate
            if dones[step].sum() >= num_envs:
                break
            global_step += num_envs

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action = [None] * num_envs
                for i, observation in enumerate(next_obs):
                    if dones[step][i] == 1:
                        continue
                    acting_agent = agents[agent_id_map[i]]
                    if i in greedy_ids:
                        # The "m_<seed>_g" column is the greedy rollout.
                        action[i] = acting_agent.get_action(observation, greedy=True)[0]
                    elif i in ensemble_ids:
                        # Ensemble columns are not greedy; if they were, all
                        # columns of one seed would use the same greedy policy.
                        action[i] = acting_agent.get_action(observation, greedy=False)[0]
                    else:
                        # Baseline agents are called without the `greedy` keyword.
                        action[i] = acting_agent.get_action(observation)[0]

            # print the action space
            print(f"{step+1}/{max_steps} \t" + "\t".join([str(a) for a in action]))

            # TRY NOT TO MODIFY: execute the game and log data.
            # Finished columns are not stepped again; they get a placeholder
            # transition whose reward (+1) is never mistaken for a real one
            # because only negative rewards are recorded below.
            outputs = [
                env.step(action[i]) if dones[step][i] == 0 else (None, +1, True, {})
                for i, env in enumerate(envs)
            ]
            next_obs, reward, terminations, infos = zip(*outputs)
            next_obs = list(next_obs)
            terminations = [term if term is not None else True for term in terminations]
            next_done = torch.tensor(terminations, dtype=torch.float32).to(device).view(-1)
            # Keep, per column, the most recent negative reward; any other
            # reward leaves the previously stored value untouched.
            last_rewards = torch.tensor(
                [float(r) if r < 0 else previous for r, previous in zip(reward, last_rewards.tolist())]
            )

            dones[step+1] = torch.logical_or(next_done, dones[step]).long()

            # `infos` is a tuple with one info object per column, so episode
            # statistics have to be looked up in each of them individually.
            for info in infos:
                episode = info.get("episode") if info else None
                if episode is None:
                    continue
                print(f"global_step={global_step}, episodic_return={episode['r']}")
                writer.add_scalar("charts/episodic_return", episode["r"], global_step)
                writer.add_scalar("charts/episodic_length", episode["l"], global_step)

        print("Rew: \t", "\t".join([f"{r:.3f}" for r in last_rewards.tolist()]))
        print("="*90)

        if args.track:
            log = {
                "steps": global_step,
                "iteration": iteration,
            }
            log.update({header: last_rewards[i].item() for i, header in enumerate(headers)})
            wandb.log(log)

        # Report the same steps-per-second value on stdout and in TensorBoard.
        sps = int(global_step / (time.time() - start_time))
        print("SPS:", sps)
        writer.add_scalar("charts/SPS", sps, global_step)

    writer.close()