# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_ataripy


import os
import sys
import inspect

# Put the parent of this file's directory at the front of sys.path so the
# project-local imports below (utils, scheduling, solvers, models, training)
# can be resolved when the script is run directly. Presumably that parent is
# the repository root -- confirm if the layout changes.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 

import random
import copy
import time
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
import tyro
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from utils.utils import set_seed
from scheduling.environment import SchedulingEnvironment, generate_environment, make_env

from solvers.hetgat_solver_individual import HetGatSolverIndividual, HetGatSolverIndividualTaskFirst
from models.graph_scheduler import GraphSchedulerCritic

from solvers.edf import EarliestDeadlineFirstAgent
from solvers.improved_edf import ImprovedEarliestDeadlineFirstAgent
from solvers.milp_solver import MILP_Solver
from training.replay_buffer import ReplayBuffer

@dataclass
class Args:
    """Command-line configuration for this training script, parsed by ``tyro.cli(Args)``.

    The string literal under each field is that field's help text.
    """

    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 10
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = False
    """if toggled, cuda will be enabled by default"""
    track: bool = True
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "Task Allocation and Scheduling with Path Planning using GNNs"
    """the wandb's project name"""
    wandb_entity: str = "wandb_repo"
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""

    # Algorithm specific arguments
    env_id: str = "advantage wrt same step greedy"
    """the id of the environment"""
    env_location: str = "data/problem_set_r5_t10_s10_f50_w50_euc_2000"
    """the location of the environment"""
    total_timesteps: int = 10000000
    """total timesteps of the experiments"""
    learning_rate: float = 1e-3
    """the learning rate of the optimizer"""
    num_envs: int = 8
    """the number of parallel game environments"""
    num_steps: int = 128
    """the number of steps to run in each environment per policy rollout"""
    anneal_lr: bool = True
    """Toggle learning rate annealing for policy and value networks"""
    gamma: float = 0.99
    """the discount factor gamma"""
    tau: float = 0.99
    """target smoothing coefficient (1.0 copies the critic into the target network)"""
    # Declared exactly once; 8 is the default the script has effectively been using.
    batch_size: int = 8
    """the number of transitions sampled from the replay buffer per update"""
    partition_learning: bool = True
    """if toggled, the partition learning will be enabled, the critic will learn up to learning_starts and the actor will learn after that"""
    learning_starts: int = 200
    """timestep to start learning"""
    policy_lr: float = 1e-3
    """the learning rate of the policy network optimizer"""
    q_lr: float = 1e-3
    """the learning rate of the Q network network optimizer"""
    update_frequency: int = 4
    """the frequency of training updates"""
    target_network_frequency: int = 8000
    """the frequency of updates for the target networks"""
    alpha: float = 1.0
    """Entropy regularization coefficient."""
    autotune: bool = False
    """automatic tuning of the entropy coefficient"""
    target_entropy_scale: float = 0.89
    """coefficient for scaling the autotune entropy target"""

    # Environment specific arguments
    start_problem: int = 1
    """the starting problem number"""
    end_problem: int = 20
    """the ending problem number"""

    # Training specific arguments
    num_iterations: int = 1500
    """the number of iterations"""
    # to be filled in runtime
    minibatch_size: int = 0
    """the mini-batch size (computed in runtime)"""
    num_baseline_update: int = 10
    """the number of baseline updates"""

    num_heads: int = 8
    """the number of heads in the GAT layer"""
    num_layers: int = 4
    """the number of layers in the GAT layer"""
    # Both default to `num_layers` as evaluated in the class body (i.e. 4).
    agent_layers: int = num_layers
    """the number of layers in the agent GAT layer"""
    task_layers: int = num_layers
    """the number of layers in the critic GAT layer"""

    # Graph specific arguments
    graph_mode: str = 'attention'
    """the mode of the graph (no_edge, edge, attention, attention_simple)"""
    critic_mode: str = 'attention'
    """the mode of the critic (no_edge, edge, attention, attention_simple)"""
    # Save Checkpoints
    save_location: str = "final_checkpoints"
    """the location to save the checkpoints"""
    checkpoint_step: int = 250
    """the number of iterations between saved checkpoints"""
    continue_training: int = 1
    """the iteration to continue training from (values <= 1 start a fresh run)"""

    reward_mode: str = "base"
    """the reward mode for training [base/feasible]"""

    baseline_boosting: bool = True
    """if toggled, the baselines will be boosted"""

    num_qf: int = 2
    """the number of Q-functions"""
    temperature: float = 1.0
    """The Temperature Parameter for SoftMax for the sampler"""
    adaptive_temperature: bool = False
    """if toggled, the temperature will be adapted"""
    entropy_coeff: float = 0.01
    """the entropy coefficient"""
    exp_mode: str = "baseline"
    """ Experiment for Critic modes: [baseline/random] to allow random actions for critic training at start."""
    
def save_model(model, path):
    """Serialize ``model.state_dict()`` to ``path`` with ``torch.save``.

    Only the parameters and buffers are written, not the module object itself.
    """
    weights = model.state_dict()
    torch.save(weights, path)
    
def load_model(model, path, map_location=None):
    """Load a ``state_dict`` saved with ``torch.save`` from ``path`` into ``model``.

    Args:
        model: module whose parameters/buffers are overwritten in place.
        path: checkpoint file containing a ``state_dict`` (as written by ``save_model``).
        map_location: forwarded to ``torch.load``. When ``None`` (default), tensors are
            mapped onto the device of ``model``'s first parameter (CPU for a module
            without parameters), so checkpoints written on a GPU machine also load on
            a CPU-only machine.

    Returns:
        The same ``model`` object, to allow call-chaining.
    """
    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else "cpu"
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict)
    return model

# Max L2 norm passed to `nn.utils.clip_grad_norm_` for both the critic and the actor updates in the training loop.
CLIP_GRAD: float = 1.0

if __name__ == "__main__":
    # Entry point: trains the HetGatSolverIndividual* actor together with `num_qf`
    # GraphSchedulerCritic Q-functions. Every iteration it also runs a greedy rollout
    # and the EDF / improved-EDF / MILP baselines for comparison.
    args = tyro.cli(Args)
    print(args.env_id)
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"

    # A run is a resume when it continues from a saved iteration > 1. In that case,
    # checkpoints and the wandb run id are loaded instead of created.
    resuming = args.continue_training is not None and args.continue_training > 1

    # File that stores the wandb run id, so a resumed run can re-attach to the same wandb run.
    wandb_code_dir = f"{args.save_location}/wandb_access_codes"
    wandb_code_file = f"{wandb_code_dir}/{args.exp_name}__{'_'.join(args.env_id.split(' '))}__{args.num_heads}__{args.num_layers}__{args.graph_mode}_s{args.seed}_code.txt"

    if args.track:
        import wandb

        wandb_id = None
        if resuming:
            try:
                with open(wandb_code_file, "r") as f:
                    wandb_id = f.read().strip() or None
            except FileNotFoundError:
                print(f"No wandb ID found at {wandb_code_file}; starting a new wandb run")
            print(run_name, wandb_id)

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
            # The saved id is what lets resume="allow" re-attach to the previous run.
            id=wandb_id,
            resume="allow",
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    if args.track:
        # Persist the wandb run id so a later resumed run can pick it up.
        os.makedirs(wandb_code_dir, exist_ok=True)
        with open(wandb_code_file, "w") as f:
            print(f"{wandb.run.id}", file=f)
        print(f"Wandb ID is saved to {wandb_code_file}")

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    set_seed(args.seed, device)

    layer_data = {
        "agent": args.num_layers,
        "task": args.num_layers
    }
    # Initial env setup (samples + greedy + baselines). Before the loop only envs[0] is
    # read (autotune entropy target); the training loop rebuilds the envs every iteration.
    envs = [make_env(args.env_location, 1, 1, reward_mode=args.reward_mode) for _ in range(args.num_envs + 4)]

    task_first = ('task_first' in args.graph_mode)
    if not task_first:
        actor = HetGatSolverIndividual(args.graph_mode, args.num_heads, args.num_layers, layer_data, args.temperature).to(device)
    else:
        actor = HetGatSolverIndividualTaskFirst(args.graph_mode, args.num_heads, args.num_layers, layer_data, args.temperature).to(device)

    actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr, eps=1e-4, weight_decay=1e-1)

    qf = []
    qf_target = []
    qf_optimizers = []
    for _ in range(args.num_qf):
        # Re-seeding before each construction presumably gives each critic and its target
        # the same initial weights.
        # NOTE(review): it also makes all `num_qf` critics start identical, which weakens
        # the min-over-critics estimate -- confirm this is intended.
        set_seed(args.seed, device)
        qf_ = GraphSchedulerCritic(num_heads=args.num_heads, num_layers=args.num_layers, graph=args.critic_mode).to(device)
        set_seed(args.seed, device)
        qf_target_ = GraphSchedulerCritic(num_heads=args.num_heads, num_layers=args.num_layers, graph=args.critic_mode).to(device)
        qf_target_.load_state_dict(qf_.state_dict())

        qf.append(qf_)
        qf_target.append(qf_target_)
        # TRY NOT TO MODIFY: eps=1e-4 increases numerical stability
        qf_optimizers.append(optim.Adam(list(qf_.parameters()), lr=args.q_lr, eps=1e-4, weight_decay=1e-1))

    os.makedirs(args.save_location, exist_ok=True)
    env_id = "_".join(args.env_id.split(" "))
    checkpoint_base = os.path.join(args.save_location, f"{args.exp_name}__{env_id}__{args.seed}__{args.graph_mode}__")

    if not resuming:
        initial_tag = str(0).zfill(5)
        save_model(actor, checkpoint_base + f"{initial_tag}.pt")
        for i in range(args.num_qf):
            save_model(qf[i], checkpoint_base + f"qf_{str(i).zfill(5)}_{initial_tag}.pt")
            save_model(qf_target[i], checkpoint_base + f"qf_target_{str(i).zfill(5)}_{initial_tag}.pt")
    else:
        resume_tag = str(args.continue_training).zfill(5)
        actor = load_model(actor, checkpoint_base + f"{resume_tag}.pt")
        for i in range(args.num_qf):
            qf[i] = load_model(qf[i], checkpoint_base + f"qf_{str(i).zfill(5)}_{resume_tag}.pt")
            target_path = checkpoint_base + f"qf_target_{str(i).zfill(5)}_{resume_tag}.pt"
            if os.path.exists(target_path):
                # Restore the saved target network rather than discarding its state.
                qf_target[i] = load_model(qf_target[i], target_path)
            else:
                qf_target[i].load_state_dict(qf[i].state_dict())

    baselines = {
        "edf": EarliestDeadlineFirstAgent(),
        "improved_edf": ImprovedEarliestDeadlineFirstAgent(),
        "milp_solver": MILP_Solver()
    }

    replay_buffer = ReplayBuffer(12 * 20)
    # Automatic entropy tuning
    if args.autotune:
        target_entropy = -args.target_entropy_scale * torch.log(1 / torch.tensor((envs[0].num_tasks)) + 1 / torch.tensor((envs[0].num_agents)))
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha = log_alpha.exp().item()
        a_optimizer = optim.Adam([log_alpha], lr=args.q_lr, eps=1e-4)
    else:
        alpha = args.alpha

    start_time = time.time()
    # Fresh runs start at iteration 1; resumed runs restart at the checkpointed iteration.
    start = args.continue_training if resuming else 1

    qf1_loss = torch.tensor(0.0).to(device)
    qf2_loss = torch.tensor(0.0).to(device)
    qf_loss_ = torch.tensor(0.0).to(device)

    # Training Loop. Env layout per iteration: 0 = greedy actor, 1..num_envs = sampled actor
    # rollouts, num_envs+1 = EDF, num_envs+2 = improved EDF, num_envs+3 = MILP.
    for global_step in range(start, args.num_iterations + 1):
        # TRY NOT TO MODIFY: set the iterations for the environment
        envs = [make_env(args.env_location, global_step, args.start_problem, args.end_problem, reward_mode=args.reward_mode) for _ in range(args.num_envs + len(baselines) + 1)]
        obs = [env.reset()[0] for env in envs]

        baselines["milp_solver"].set_environment(envs[-1])
        max_steps = max([env.num_tasks for env in envs])
        print("MaxSteps", max_steps)
        print(f"Step {global_step}: Training - {'no_critic' not in args.reward_mode} - {args.partition_learning} - {global_step} - {args.learning_starts}")
        print("="*90)
        print(f"{global_step}:\tGreedy\t" + "\t".join([str(i) for i in range(args.num_envs)]) + "\tEDF\tImpEDF\tMILP")
        print("-"*90)
        scores = torch.zeros(len(envs)).to(device)
        final_rewards = torch.zeros(len(envs)).to(device)

        # One list of (env_idx, obs, action, next_obs, reward, done) tuples per env step.
        tmp_outputs = []
        # Selection probabilities of the greedy actor (env 0) for this iteration only;
        # used for entropy logging.
        agent_probabilities = []
        task_probabilities = []
        for step in range(0, max_steps):
            actions = [None for _ in range(len(envs))]
            with torch.no_grad():
                for i, env in enumerate(envs):
                    if obs[i] is None:
                        continue
                    if i <= args.num_envs and args.exp_mode in ['random'] and args.partition_learning and global_step < args.learning_starts:
                        actions[i] = env.sample_action_space()
                    elif i == 0:
                        actions[i], _, _, _, select_probs = actor.get_action_and_memory(obs[i], greedy=True)
                        agent_probs, task_probs = select_probs
                        agent_probabilities.append(agent_probs)
                        task_probabilities.append(task_probs)
                    elif i <= args.num_envs:
                        actions[i], _, _, _, _ = actor.get_action_and_memory(obs[i], adaptive_temperature=args.adaptive_temperature, greedy=False)
                    elif i == args.num_envs + 1:
                        actions[i], _, _, _ = baselines["edf"].get_action(obs[i])
                    elif i == args.num_envs + 2:
                        actions[i], _, _, _ = baselines["improved_edf"].get_action(obs[i])
                    elif i == args.num_envs + 3:
                        actions[i], _, _, _ = baselines["milp_solver"].get_action(obs[i])
                    else:
                        raise ValueError(f"Invalid Environment ID: {i}")

            print(f"{step+1}/{max_steps} \t" + "\t".join([str(a) for a in actions]))

            # TRY NOT TO MODIFY: execute the game and log data.
            # Envs whose observation is None are not stepped; they get a placeholder output.
            outputs = [env.step(actions[i]) if obs[i] is not None else (None, +1, True, {}) for i, env in enumerate(envs)]

            next_obs, step_rewards, terminations, feasibles = zip(*outputs)
            print(f"- {step_rewards}")
            done_mask = torch.tensor(terminations, dtype=torch.float32).to(device).view(-1)
            tmp_output = []
            for i, (n_obs, reward, terminal, _feasible) in enumerate(outputs):
                # Skip already-finished envs *before* recording final_rewards, so the
                # placeholder (+1) output cannot overwrite a real terminal reward.
                if obs[i] is None:
                    continue
                if terminal:
                    final_rewards[i] = reward
                # Baseline envs (index > num_envs) are kept out of the replay buffer when
                # baseline_boosting is on; actor envs 0..num_envs are always stored.
                if args.baseline_boosting and i > args.num_envs:
                    continue
                tmp_output.append((i, obs[i], actions[i], n_obs, reward, done_mask[i]))

            tmp_outputs.append(tmp_output)

            obs = copy.deepcopy(next_obs)
            if all(terminations):
                break

        if tmp_outputs:
            print(len(tmp_outputs), len(tmp_outputs[0]))

        # Credit every stored transition of an env with the reward of that env's last
        # recorded step (its terminal reward when it finished). Keying by env index keeps
        # this correct when envs terminate at different steps.
        last_reward = {}
        for tmp_output in tmp_outputs:
            for env_idx, _, _, _, reward, _ in tmp_output:
                last_reward[env_idx] = reward
        for tmp_output in tmp_outputs:
            for env_idx, t_obs, t_action, t_next_obs, _, t_done in tmp_output:
                replay_buffer.append(t_obs, t_action, t_next_obs, last_reward[env_idx], t_done)

        infeasible_counts = []
        for i, env in enumerate(envs):
            scores[i] = env.get_raw_score()
            infeasible_counts.append(env.num_infeasible)

        print("-"*90)
        print(f"Reward:" + '\t'.join([f"{r.item():.3f}" for r in scores]))
        print(f"Inf:   " + '\t'.join([f"{r}" for r in infeasible_counts]))
        print("="*90)
        # ALGO Logic: Training the policy
        size = min(replay_buffer.size(), args.batch_size)
        # NOTE(review): with partition_learning=True the critic is never trained and the actor
        # always is, which does not match the `partition_learning` help text ("critic learns up
        # to learning_starts, actor after"). Kept unchanged; confirm the intended schedule.
        train_critic = 'no_critic' not in args.reward_mode and args.partition_learning is False
        train_actor = args.partition_learning is not False or global_step > args.learning_starts

        for _update in range(3):
            data = replay_buffer.sample(size)
            if train_critic:
                q_targets = []
                qf_values = [[] for _ in range(args.num_qf)]
                # CRITIC training: soft Bellman targets from the target critics.
                for b_obs, b_action, b_next_obs, b_reward, b_done in data:
                    with torch.no_grad():
                        reward_t = torch.as_tensor(b_reward, dtype=torch.float32, device=device)
                        if not b_done:
                            next_action, next_state_log_pi, _, _ = actor.get_action_and_probability(b_next_obs)
                            qf_next_targets = torch.stack([qf_target[j](b_next_obs, next_action) for j in range(args.num_qf)])
                            min_qf_next_target = torch.min(qf_next_targets, dim=0).values
                            # Soft value: subtract the entropy term of the sampled next action.
                            min_qf_next_target = min_qf_next_target - alpha * next_state_log_pi
                            target = reward_t + args.gamma * min_qf_next_target
                        else:
                            target = reward_t
                        q_targets.append(target.to(device).squeeze())
                    for j in range(args.num_qf):
                        qf_values[j].append(qf[j](b_obs, b_action).squeeze())

                # One target per sampled transition; every critic regresses onto the full vector.
                q_targets = torch.stack(q_targets).to(device).squeeze()
                qf_values = [torch.stack(qf_value).to(device).squeeze() for qf_value in qf_values]

                qf_losses = []
                for j in range(args.num_qf):
                    q_loss_ = F.mse_loss(qf_values[j], q_targets)
                    qf_losses.append(q_loss_.detach())
                    qf_optimizers[j].zero_grad()
                    q_loss_.backward()
                    nn.utils.clip_grad_norm_(qf[j].parameters(), max_norm=CLIP_GRAD, norm_type=2)
                    qf_optimizers[j].step()

                qf1_loss = qf_losses[0]
                qf2_loss = qf_losses[1] if args.num_qf > 1 else torch.tensor(0.0).to(device)
                qf_loss_ = sum(qf_losses) / args.num_qf
                print(f"QF1 Loss: {qf1_loss.item()} QF2 Loss: {qf2_loss.item()} QF Loss: {qf_loss_.item()}")
            else:
                print(f"Step {global_step}: Skipping Critic Training")

            actor_loss = torch.tensor(0.0).to(device)
            alpha_loss = torch.tensor(0.0).to(device)
            # ACTOR training
            if train_actor:
                log_pi = [None for _ in range(size)]
                action_probs = [None for _ in range(size)]
                min_qf_values = [None for _ in range(size)]
                entropy = [None for _ in range(size)]
                if 'bc' in args.reward_mode:
                    bc_loss = torch.tensor(0.0).to(device)

                for i, (b_obs, b_action, b_next_obs, b_reward, b_done) in enumerate(data):
                    _, log_pi[i], action_probs[i], entropy[i] = actor.get_action_and_probability(b_obs, action=b_action)
                    if 'no_critic' in args.reward_mode or b_done:
                        # No critic, or terminal transition: use the stored reward directly.
                        min_qf_values[i] = torch.tensor([[b_reward]], device=device)
                    else:
                        with torch.no_grad():
                            qf_vals = torch.stack([qf[j](b_obs, b_action) for j in range(args.num_qf)])
                            min_qf_values[i] = torch.min(qf_vals, dim=0).values

                    # Behavior Cloning on the Improved EDF
                    if 'bc' in args.reward_mode:
                        bc_action, _, _, _ = baselines["improved_edf"].get_action(b_obs)
                        # task
                        task_ground_truth = torch.zeros(action_probs[i][0].shape).to(device)
                        task_ground_truth[bc_action[0]] = 1.0
                        task_prob_loss = F.binary_cross_entropy(action_probs[i][0], task_ground_truth)
                        # agent
                        agent_ground_truth = torch.zeros(action_probs[i][1].shape).to(device)
                        agent_ground_truth[bc_action[1]] = 1.0
                        agent_prob_loss = F.binary_cross_entropy(action_probs[i][1], agent_ground_truth)
                        bc_loss += task_prob_loss + agent_prob_loss

                log_pi = torch.stack(log_pi)
                print(min_qf_values)
                min_qf_values = torch.stack(min_qf_values).squeeze()
                entropy = torch.stack(entropy).squeeze()

                if 'bc' in args.reward_mode and global_step < args.learning_starts:
                    actor_loss = bc_loss
                elif 'atari' in args.reward_mode:
                    # Only this branch needs the stacked action probabilities.
                    action_probability = torch.stack(action_probs).squeeze()
                    actor_loss = action_probability * ((alpha * log_pi) - min_qf_values)
                    if 'neg' in args.reward_mode:
                        actor_loss = -1.0 * actor_loss
                # NOTE(review): the 'continuous' and 'pg' losses below *minimize* Q / Q*log_pi;
                # confirm the sign is intended (e.g. rewards are costs).
                elif 'continuous' in args.reward_mode:
                    actor_loss = (min_qf_values - alpha * log_pi)
                elif 'pg' in args.reward_mode:
                    actor_loss = (min_qf_values * log_pi)
                else:
                    raise NotImplementedError(f'reward mode {args.reward_mode} not implemented for actor set using --reward_mode')
                actor_loss = actor_loss.mean()
                actor_optimizer.zero_grad()
                actor_loss.backward()
                nn.utils.clip_grad_norm_(actor.parameters(), max_norm=CLIP_GRAD, norm_type=2)
                actor_optimizer.step()

                if args.autotune:
                    # re-use action probabilities for temperature loss
                    alpha_loss = (-log_alpha.exp() * (log_pi.detach() + target_entropy).detach()).mean()

                    a_optimizer.zero_grad()
                    alpha_loss.backward()
                    a_optimizer.step()
                    alpha = log_alpha.exp().item()
            else:
                print(f"Step {global_step}: Skipping Actor Training")

        if args.track:
            # Entropy of each greedy-step distribution, H(p) = -sum p*log(p);
            # xlogy(p, p) is 0 where p == 0, avoiding 0 * -inf = NaN.
            agent_entropy_values = [-torch.xlogy(p, p).sum().item() for p in agent_probabilities]
            task_entropy_values = [-torch.xlogy(p, p).sum().item() for p in task_probabilities]
            agent_entropies = {f'agent_entropies_{i}': h for i, h in enumerate(agent_entropy_values)}
            task_entropies = {f'task_entropies_{i}': h for i, h in enumerate(task_entropy_values)}

            print(f"Agent Entropy: {agent_entropies}")
            print(f"Task Entropy: {task_entropies}")

            agent_entropies['agent_mean_entropy'] = np.array(agent_entropy_values).mean()
            agent_entropies['agent_std_entropy'] = np.array(agent_entropy_values).std()
            task_entropies['task_mean_entropy'] = np.array(task_entropy_values).mean()
            task_entropies['task_std_entropy'] = np.array(task_entropy_values).std()

            logged = {
                    "steps": global_step,
                    "qf1_loss": qf1_loss.item(),
                    "qf2_loss": qf2_loss.item(),
                    "qf_loss": qf_loss_.item(),
                    "actor_loss": actor_loss.item(),
                    "mean_sample_reward": scores[1:args.num_envs + 1].mean().item(),
                    "std_sample_reward": scores[1:args.num_envs + 1].std().item(),
                    "max_sample_reward": scores[1:args.num_envs + 1].max().item(),
                    "min_sample_reward": scores[1:args.num_envs + 1].min().item(),
                    "greedy_reward": scores[0].item(),
                    "edf_reward": scores[args.num_envs + 1].item(),
                    "improved_edf_reward": scores[args.num_envs + 2].item(),
                    "milp_reward": scores[args.num_envs + 3].item(),

                    "mean_sample_reward_last": final_rewards[1:args.num_envs + 1].mean().item(),
                    "std_sample_reward_last": final_rewards[1:args.num_envs + 1].std().item(),
                    "max_sample_reward_last": final_rewards[1:args.num_envs + 1].max().item(),
                    "min_sample_reward_last": final_rewards[1:args.num_envs + 1].min().item(),
                    "greedy_reward_last": final_rewards[0].item(),
                    "edf_reward_last": final_rewards[args.num_envs + 1].item(),
                    "improved_edf_reward_last": final_rewards[args.num_envs + 2].item(),
                    "milp_reward_last": final_rewards[args.num_envs + 3].item(),

                    "greedy_infeasible": infeasible_counts[0],
                    "mean_infeasible": np.array(infeasible_counts[1:args.num_envs + 1]).mean().item(),
                    "max_infeasible": np.array(infeasible_counts[1:args.num_envs + 1]).max().item(),
                    "min_infeasible": np.array(infeasible_counts[1:args.num_envs + 1]).min().item(),
                    "edf_infeasible": infeasible_counts[args.num_envs + 1],
                    "improved_edf_infeasible": infeasible_counts[args.num_envs + 2],
                    "milp_infeasible": infeasible_counts[args.num_envs + 3],
                }

            log = agent_entropies | task_entropies | logged
            wandb.log(log)

        # Everything below runs whether or not wandb tracking is enabled.
        # Soft-update the target critics (tau = 1.0 would be a hard copy).
        if global_step % args.target_network_frequency == 0:
            for qf_, qf_target_ in zip(qf, qf_target):
                for param, target_param in zip(qf_.parameters(), qf_target_.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

        if global_step % 100 == 0:
            writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
            writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
            # qf_loss_ is already the mean over all critics.
            writer.add_scalar("losses/qf_loss", qf_loss_.item(), global_step)
            writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
            writer.add_scalar("losses/alpha", alpha, global_step)
            # Iterations per second over the iterations run by this process (correct on resume).
            sps = int((global_step - start + 1) / (time.time() - start_time))
            print("SPS:", sps)
            writer.add_scalar("charts/SPS", sps, global_step)
            if args.autotune:
                writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)

        if global_step % args.checkpoint_step == 0:
            save_model(actor, checkpoint_base + f"{str(global_step).zfill(5)}.pt")
            for i in range(args.num_qf):
                save_model(qf[i], checkpoint_base + f"qf_{str(i).zfill(5)}_{str(global_step).zfill(5)}.pt")
                save_model(qf_target[i], checkpoint_base + f"qf_target_{str(i).zfill(5)}_{str(global_step).zfill(5)}.pt")

    final_tag = str(args.num_iterations).zfill(5)
    save_model(actor, checkpoint_base + f"{final_tag}.pt")
    for i in range(args.num_qf):
        save_model(qf[i], checkpoint_base + f"qf_{str(i).zfill(5)}_{final_tag}.pt")
        save_model(qf_target[i], checkpoint_base + f"qf_target_{str(i).zfill(5)}_{final_tag}.pt")
    # `envs` is a plain Python list (not a vector env), so close each env that exposes close().
    for env in envs:
        close = getattr(env, "close", None)
        if callable(close):
            close()
    writer.close()