# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_ataripy


import os
import sys
import inspect

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 

import random
import copy
import time
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
import tyro
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from utils.utils import set_seed
from scheduling.environment import SchedulingEnvironment, generate_environment, make_env

from solvers.hetgat_solver_simultaneous import HetGatSolverSimultaneous
from models.graph_scheduler import GraphSchedulerCritic

from solvers.edf import EarliestDeadlineFirstAgent
from solvers.improved_edf import ImprovedEarliestDeadlineFirstAgent
from solvers.milp_solver import MILP_Solver
from training.replay_buffer import ReplayBuffer

@dataclass
class Args:
    """Command-line configuration of the training script (parsed with ``tyro.cli``).

    The bare strings following each field are the per-option help texts.
    ``batch_size`` used to be declared twice in this class (64, then 8); the
    later declaration silently won, so the single declaration below keeps the
    effective default (8) at the position of the first declaration.
    """

    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 10
    """seed of the experiment"""
    torch_deterministic: bool = True
    """whether cuDNN should run deterministically (intended for `torch.backends.cudnn.deterministic`)"""
    cuda: bool = False
    """if toggled, cuda will be used when it is available"""
    track: bool = False # True # TODO: Debugging purpose
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "Task Allocation and Scheduling with Path Planning using GNNs"
    """the wandb's project name"""
    wandb_entity: str = "wandb_repo"
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""

    # Algorithm specific arguments
    env_id: str = "reward wrt same step greedy"
    """the id of the environment"""
    env_location: str = "data/problem_set_r10_t20_s0_f10_w25_euc_2000_uni"
    """the location of the environment"""
    total_timesteps: int = 10000000
    """total timesteps of the experiments"""
    learning_rate: float = 1e-3
    """the learning rate of the optimizer"""
    num_envs: int = 8
    """the number of parallel game environments"""
    num_steps: int = 128
    """the number of steps to run in each environment per policy rollout"""
    anneal_lr: bool = True
    """Toggle learning rate annealing for policy and value networks"""
    gamma: float = 0.95
    """the discount factor gamma"""
    tau: float = 0.99
    """target smoothing coefficient"""
    batch_size: int = 8
    """the batch size of sample from the replay memory"""
    partition_learning: bool = True
    """if toggled, the partition learning will be enabled, the critic will learn up to learning_starts and the actor will learn after that"""
    learning_starts: int = 200
    """timestep to start learning"""
    policy_lr: float = 1e-3
    """the learning rate of the policy network optimizer"""
    q_lr: float = 1e-3
    """the learning rate of the Q network network optimizer"""
    update_frequency: int = 4
    """the frequency of training updates"""
    target_network_frequency: int = 8000
    """the frequency of updates for the target networks"""
    alpha: float = 1.0
    """Entropy regularization coefficient."""
    autotune: bool = False
    """automatic tuning of the entropy coefficient"""
    target_entropy_scale: float = 0.89
    """coefficient for scaling the autotune entropy target"""

    # Environment specific arguments
    start_problem: int = 1
    """the starting problem number"""
    end_problem: int = 200
    """the ending problem number"""

    # Training specific arguments
    num_iterations: int = 12000
    """the number of iterations"""
    minibatch_size: int = 0
    """the mini-batch size (computed in runtime)"""
    num_baseline_update: int = 10
    """the number of baseline updates"""

    num_heads: int = 1
    """the number of heads in the GAT layer"""
    num_layers: int = 4
    """the number of layers in the GAT layer"""
    # NOTE: the two defaults below are bound to the class-level default of
    # `num_layers` (4) when the class is created; overriding `num_layers` on
    # the command line does not change them.
    agent_layers: int = num_layers
    """the number of layers in the agent GAT layer"""
    task_layers: int = num_layers
    """the number of layers in the critic GAT layer"""

    # Graph specific arguments
    graph_mode: str = 'hgt_edge_resnet'
    """the mode of the graph (e.g. no_edge, edge, attention, attention_simple, hgt_edge_resnet)"""
    critic_mode: str = 'hgt_edge_resnet'
    """the mode of the critic (e.g. no_edge, edge, attention, attention_simple, hgt_edge_resnet)"""
    # Save Checkpoints
    save_location: str = "final_checkpoints_sim"
    """the location to save the checkpoints"""
    checkpoint_step: int = 250
    """a checkpoint of the model is saved every `checkpoint_step` iterations"""
    continue_training: int = 1
    """the iteration to continue training from (values <= 1 start a fresh run)"""

    reward_mode: str = "base"
    """the reward mode for training [base/feasible]"""

    baseline_boosting: bool = True
    """if toggled, the baselines will be boosted"""

    num_qf: int = 2
    """the number of Q-functions"""
    temperature: float = 1.0
    """The Temperature Parameter for SoftMax for the sampler"""
    adaptive_temperature: bool = False
    """if toggled, the temperature will be adapted"""
    entropy_coeff: float = 0.01
    """the entropy coefficient"""
    
def save_model(model, path):
    """Write the parameters (``state_dict``) of ``model`` to ``path``.

    Only the state dict is serialized, not the module object itself, so the
    file must later be restored into a model of the same architecture.
    """
    weights = model.state_dict()
    torch.save(weights, path)
    
def load_model(model, path):
    """Restore the parameters stored at ``path`` into ``model`` and return it.

    The checkpoint is deserialized onto the CPU first so that a file written
    from a CUDA run can also be opened on a CPU-only host;
    ``load_state_dict`` then copies the values into the model's existing
    parameters, which stay on whatever device the model already lives on.

    Args:
        model: module whose architecture matches the saved state dict.
        path: file previously written with ``torch.save(model.state_dict(), path)``.

    Returns:
        The same ``model`` object, updated in place.
    """
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model

def get_discounted_reward(reward, num_tasks, gamma):
    """Return the discounted reward-to-go of a per-step reward sequence.

    For ``i`` from ``num_tasks - 2`` down to ``0`` the result satisfies
    ``out[i] = reward[i] + gamma * out[i + 1]``; entries from index
    ``num_tasks - 1`` onwards are returned unchanged.

    Args:
        reward: array-like of per-step rewards with at least ``num_tasks``
            entries (when ``num_tasks >= 2``).
        num_tasks: number of leading steps to accumulate over.
        gamma: discount factor.

    Returns:
        A new floating-point ``np.ndarray``. The input is left untouched;
        previously the caller's array was overwritten in place, and an
        integer-typed array silently truncated the discounted values.
    """
    # np.array copies, so the caller's buffer is never modified.
    discounted_reward = np.array(reward)
    if not np.issubdtype(discounted_reward.dtype, np.floating):
        # Integer/bool rewards must be promoted, otherwise the fractional
        # `gamma * ...` terms are truncated on assignment.
        discounted_reward = discounted_reward.astype(np.float64)
    # Walk backwards so that each step can reuse the already-accumulated
    # return of its successor.
    for i in range(num_tasks - 2, -1, -1):
        discounted_reward[i] = discounted_reward[i] + gamma * discounted_reward[i + 1]
    return discounted_reward

# Maximum L2 norm of the actor gradients; passed to clip_grad_norm_ before
# every optimizer step.
CLIP_GRAD = 1.0

if __name__ == "__main__":
    # Training entry point: every iteration builds a fresh batch of
    # environments, rolls out one greedy schedule, `num_envs` sampled
    # schedules and three baseline schedules (EDF, improved EDF, MILP), and
    # takes one policy-gradient step on the actor using the discounted
    # per-step rewards of the greedy and sampled schedules.
    args = tyro.cli(Args)
    print(args.env_id)
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    env_id = "_".join(args.env_id.split(" "))

    # A run is resumed only when an iteration > 1 is explicitly requested.
    resuming = args.continue_training is not None and args.continue_training > 1

    # File in which the W&B run id is stored so that a later invocation can
    # resume the same run.
    wandb_code_dir = f"{args.save_location}/wandb_access_codes"
    wandb_code_file = f"{wandb_code_dir}/{args.exp_name}__{env_id}__{args.num_heads}__{args.num_layers}__{args.graph_mode}_s{args.seed}_code.txt"

    if args.track:
        import wandb

        wand_id = None
        if resuming:
            # Recover the id written by the run that is being continued.
            with open(wandb_code_file, "r") as f:
                wand_id = f.read().strip()
            print(run_name, wand_id)

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
            # Without the stored id `resume="allow"` always starts a new run.
            id=wand_id,
            resume="allow",
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    if args.track:
        # Persist the W&B run id for later resumption.
        os.makedirs(wandb_code_dir, exist_ok=True)
        with open(wandb_code_file, "w") as f:
            print(f"{wandb.run.id}", file=f)
        print(f"Wandb ID is saved to {wandb_code_file}")

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    set_seed(args.seed, device)

    layer_data = None
    # env setup (Batchs + Greedy + EDF); these instances are only inspected
    # for the entropy target below, the training loop rebuilds its own.
    envs = [make_env(args.env_location, 1, 1, reward_mode=args.reward_mode) for _ in range(args.num_envs + 4)]

    actor = HetGatSolverSimultaneous(args.graph_mode, args.num_heads, args.num_layers, layer_data, args.temperature).to(device)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr, eps=1e-8, weight_decay=1e-1)

    os.makedirs(args.save_location, exist_ok=True)
    checkpoint_base = os.path.join(args.save_location, f"{args.exp_name}__{env_id}__{args.seed}__{args.graph_mode}__")

    if not resuming:
        # Fresh run: store the untrained weights as checkpoint 00000.
        initial_checkpoint = checkpoint_base + f"{str(0).zfill(5)}.pt"
        save_model(actor, initial_checkpoint)
    else:
        checkpoint = checkpoint_base + f"{str(args.continue_training).zfill(5)}.pt"
        actor = load_model(actor, checkpoint)

    baselines = {
        "edf": EarliestDeadlineFirstAgent(),
        "improved_edf": ImprovedEarliestDeadlineFirstAgent(),
        "milp_solver": MILP_Solver()
    }
    # Baseline assigned to each environment slot after the greedy and the
    # sampled rollouts (slots num_envs + 1, + 2, + 3).
    baseline_order = ("edf", "improved_edf", "milp_solver")

    # Automatic entropy tuning
    # NOTE(review): target_entropy / log_alpha / a_optimizer are created but
    # never used by the loop below; only `alpha` is logged.
    if args.autotune:
        target_entropy = -args.target_entropy_scale * torch.log(1 / torch.tensor((envs[0].num_tasks)) + 1 / torch.tensor((envs[0].num_agents)))
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha = log_alpha.exp().item()
        a_optimizer = optim.Adam([log_alpha], lr=args.q_lr, eps=1e-4)
    else:
        alpha = args.alpha

    start_time = time.time()
    # The original `is not None or ... > 1` test raised a TypeError for None
    # and made the fallback to 1 unreachable.
    start = args.continue_training if resuming else 1

    for global_step in range(start, args.num_iterations + 1):
        # Slot layout: 0 = greedy rollout, 1..num_envs = sampled rollouts,
        # then one environment per baseline.
        envs = [make_env(args.env_location, global_step, args.start_problem, args.end_problem, reward_mode=args.reward_mode) for _ in range(args.num_envs + len(baselines) + 1)]
        obs = [env.reset()[0] for env in envs]

        baselines["milp_solver"].set_environment(envs[-1])
        max_steps = max([env.num_tasks for env in envs])
        print("MaxSteps", max_steps)
        print(f"Step {global_step}: Training - {'no_critic' not in args.reward_mode} - {args.partition_learning} - {global_step} - {args.learning_starts}")
        print("="*90)
        print(f"{global_step}:\tGreedy\t" + "\t".join([str(i) for i in range(args.num_envs)]) + "\tEDF\tImpEDF\tMILP")
        print("-"*90)

        actions = [None for _ in range(len(envs))]
        raw_rewards = [None for _ in range(len(envs))]
        infeasible_counts = [None for _ in range(len(envs))]
        # Only the greedy and sampled rollouts contribute to the loss.
        log_probs = [None for _ in range(args.num_envs + 1)]
        entropies = [None for _ in range(args.num_envs + 1)]
        advantages = [None for _ in range(args.num_envs + 1)]
        for i, env in enumerate(envs):
            if i == 0:
                with torch.no_grad(): # Greedy Schedule for the first environment (0). Does not require gradient
                    schedule, log_prob, entropy, _ = actor.get_action_and_memory(obs[i], greedy=True)
            elif i <= args.num_envs:
                schedule, log_prob, entropy, _ = actor.get_action_and_memory(obs[i], adaptive_temperature=args.adaptive_temperature, greedy=False)
            else:
                baseline_index = i - args.num_envs - 1
                if baseline_index >= len(baseline_order):
                    raise ValueError(f"Invalid Environment ID: {i}")
                solver = baselines[baseline_order[baseline_index]]
                # Baselines act step by step and advance their environment
                # while the schedule is being built.
                schedule = []
                for step in range(0, max_steps):
                    schedule.append(solver.get_action(obs[i])[0])
                    obs[i] = envs[i].step(schedule[-1])[0]

            print(f"{i}:\t{schedule}")
            actions[i] = schedule
            step_rewards = []
            if i <= args.num_envs:
                # The actor returns a complete schedule up front; replay it
                # in the environment to collect the per-step rewards.
                for action in schedule:
                    _, r, _, _ = envs[i].step(action)
                    step_rewards.append(r)
                step_rewards = np.array(step_rewards)
            reward = envs[i].get_raw_score()

            infeasible = envs[i].num_infeasible
            raw_rewards[i] = reward
            infeasible_counts[i] = infeasible

            if i <= args.num_envs:
                advantage = step_rewards # - raw_rewards[0]
                task_allocation_log_prob = log_prob[0]
                sequence_log_prob = log_prob[1]
                sequence_discounted_reward = get_discounted_reward(advantage, envs[i].num_tasks, args.gamma)

                log_prob = task_allocation_log_prob + sequence_log_prob
                log_probs[i] = log_prob
                entropies[i] = entropy[0].mean() + entropy[1]
                advantages[i] = sequence_discounted_reward

        print("="*90)
        # print rewards 2 significant digits
        rewards_string = " ".join([f"{r:.2f}" for r in raw_rewards])
        print(f"Rewards:\t {rewards_string}")
        print(f"Infeasible:\t {' '.join([str(r) for r in infeasible_counts])}")
        print("="*90)

        advantages = np.stack(advantages)
        log_probs = torch.stack(log_probs)
        entropies = torch.stack(entropies)

        # The returns must live on the same device as the log-probabilities,
        # otherwise the product fails when training with --cuda.
        advantages_tensor = torch.tensor(advantages, device=log_probs.device)
        actor_loss = -1.0 * (log_probs * advantages_tensor).mean(1) - (args.entropy_coeff * entropies)

        actor_optimizer.zero_grad()
        actor_loss.mean().backward()
        nn.utils.clip_grad_norm_(actor.parameters(), max_norm=CLIP_GRAD, norm_type=2)
        actor_optimizer.step()

        scores = np.array(raw_rewards)
        final_rewards = np.array(raw_rewards)
        if args.track:
            # Slices covering the sampled rollouts only (slots 1..num_envs).
            sample_scores = scores[1:args.num_envs + 1]
            sample_final_rewards = final_rewards[1:args.num_envs + 1]
            sample_infeasible = np.array(infeasible_counts[1:args.num_envs + 1])
            wandb.log(
                {
                    "steps": global_step,
                    "actor_loss": actor_loss.mean().item(),
                    "mean_sample_reward": sample_scores.mean().item(),
                    "std_sample_reward": sample_scores.std().item(),
                    "max_sample_reward": sample_scores.max().item(),
                    "min_sample_reward": sample_scores.min().item(),
                    "greedy_reward": scores[0].item(),
                    "edf_reward": scores[args.num_envs + 1].item(),
                    "improved_edf_reward": scores[args.num_envs + 2].item(),
                    "milp_reward": scores[args.num_envs + 3].item(),

                    "mean_sample_reward_last": sample_final_rewards.mean().item(),
                    "std_sample_reward_last": sample_final_rewards.std().item(),
                    "max_sample_reward_last": sample_final_rewards.max().item(),
                    "min_sample_reward_last": sample_final_rewards.min().item(),
                    "greedy_reward_last": final_rewards[0].item(),
                    "edf_reward_last": final_rewards[args.num_envs + 1].item(),
                    "improved_edf_reward_last": final_rewards[args.num_envs + 2].item(),
                    "milp_reward_last": final_rewards[args.num_envs + 3].item(),

                    "greedy_infeasible": infeasible_counts[0],
                    "mean_infeasible": sample_infeasible.mean().item(),
                    "max_infeasible": sample_infeasible.max().item(),
                    "min_infeasible": sample_infeasible.min().item(),
                    "edf_infeasible": infeasible_counts[args.num_envs + 1],
                    "improved_edf_infeasible": infeasible_counts[args.num_envs + 2],
                    "milp_infeasible": infeasible_counts[args.num_envs + 3],

                }
            )
        # Periodic TensorBoard logging (every 100 iterations).
        if global_step % 100 == 0:
            writer.add_scalar("losses/actor_loss", actor_loss.mean().item(), global_step)
            writer.add_scalar("losses/alpha", alpha, global_step)
            print("SPS:", int(global_step / (time.time() - start_time)))
            writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

        if global_step % args.checkpoint_step == 0:
            save_model(actor, checkpoint_base + f"{str(global_step).zfill(5)}.pt")

    # Always store the final weights, whether or not the last iteration
    # coincided with a checkpoint step.
    save_model(actor, checkpoint_base + f"{str(args.num_iterations).zfill(5)}.pt")
    writer.close()