# simple_act_policy.py — use agent.th without PyMARL env, ε=0 (greedy)

import os
import numpy as np
import torch as th
from src.agents.qmix.rnn_agent import RNNAgent  # RNN agent class from PyMARL
from argparse import Namespace

# ==========================================================
# MODEL AND ENVIRONMENT CONFIGURATION
# ==========================================================
# Path to the trained agent weights (.th). Change to match your experiment.
# Example alternative: r"results\models\qmix__2025-08-08_01-02-10\96496\agent.th"
AGENT_PATH = r"models/baseline/qmix/agent.th"

# Number of agents in the environment.
N_AGENTS = 2

# Number of possible actions per agent.
N_ACTIONS = 4

# Observation dimension per agent (adjust to your environment).
OBS_DIM = 66

# Hidden dimension of the RNN (must match the training config).
RNN_HIDDEN_DIM = 64

# Whether training appended the one-hot "last action" to each observation.
OBS_LAST_ACTION = True

# Whether training appended the one-hot "agent ID" to each observation.
OBS_AGENT_ID = True
# ==========================================================

# Build the agent and load weights.
# The network input is the raw observation, optionally followed by the
# one-hot last action and the one-hot agent ID.
INPUT_SHAPE = OBS_DIM
if OBS_LAST_ACTION:
    INPUT_SHAPE += N_ACTIONS
if OBS_AGENT_ID:
    INPUT_SHAPE += N_AGENTS

# Minimal stand-in for the training config: only the two fields set here
# are passed to the agent.
args = Namespace(rnn_hidden_dim=RNN_HIDDEN_DIM, n_actions=N_ACTIONS)

agent = RNNAgent(input_shape=INPUT_SHAPE, args=args)

# Load the checkpoint onto the CPU regardless of the device it was saved from.
state = th.load(AGENT_PATH, map_location="cpu")
agent.load_state_dict(state)

# Switch the module to evaluation mode. This does not disable gradient
# tracking by itself; `act` does that via its `th.no_grad()` decorator.
agent.eval()

# Per-agent recurrent hidden state: one zero tensor of shape (1, RNN_HIDDEN_DIM) each.
hidden = [th.zeros((1, RNN_HIDDEN_DIM)) for _agent in range(N_AGENTS)]

# Per-agent previous action as a one-hot vector; all zeros before the first step.
last_actions = [
    np.zeros((N_ACTIONS,), dtype=np.float32) for _agent in range(N_AGENTS)
]

def reset_policy():
    """
    Reinitialise the per-agent recurrent state for a new episode.

    Rebinds the module-level ``hidden`` list to fresh zero tensors of shape
    (1, RNN_HIDDEN_DIM) and ``last_actions`` to fresh all-zero float32
    vectors of length N_ACTIONS, with one entry per agent in each list.
    Call this at the beginning of each episode.
    """
    global hidden, last_actions

    fresh_hidden = []
    fresh_last_actions = []
    for _ in range(N_AGENTS):
        fresh_hidden.append(th.zeros(1, RNN_HIDDEN_DIM))
        fresh_last_actions.append(np.zeros(N_ACTIONS, dtype=np.float32))

    hidden = fresh_hidden
    last_actions = fresh_last_actions

@th.no_grad()
def act(obs_list, avail_actions=None):
    """
    Select greedy actions (ε=0) for all agents and advance the recurrent state.

    For each agent the network input is the concatenation, in this order, of
    the observation, the agent's previous action as a one-hot vector (only if
    OBS_LAST_ACTION) and the agent's one-hot ID (only if OBS_AGENT_ID).

    The module-level ``hidden`` and ``last_actions`` lists are updated only
    after every agent's action has been computed. If anything raises while
    processing any agent (e.g. a missing observation or a mask of the wrong
    shape), the stored state is left exactly as it was before the call, so
    the caller can fix the input and retry without desynchronising the policy.

    Parameters
    ----------
    obs_list : list[np.array]
        List of observations per agent, each of shape (OBS_DIM,).
        Only the first N_AGENTS entries are read.
    avail_actions : list[np.array] or None
        Optional list of binary masks per agent of shape (N_ACTIONS,).
        If provided, invalid actions (0 entries) are masked out.
        If None, all actions are assumed available.

    Returns
    -------
    actions : list[int]
        Selected actions (one per agent).

    Notes
    -----
    There is no check that a mask leaves at least one action available; with
    an all-zero mask every Q-value is set to the same sentinel and the argmax
    over those equal values is returned.
    """
    actions = []
    next_hidden = []
    next_last_actions = []

    # Phase 1: compute everything into local lists; no module state is written.
    for i in range(N_AGENTS):
        # Build input vector
        parts = [np.asarray(obs_list[i], dtype=np.float32)]
        if OBS_LAST_ACTION:
            parts.append(last_actions[i])
        if OBS_AGENT_ID:
            onehot = np.zeros(N_AGENTS, dtype=np.float32)
            onehot[i] = 1.0
            parts.append(onehot)

        x = th.tensor(np.concatenate(parts), dtype=th.float32).unsqueeze(0)  # (1, input_shape)

        # Forward pass through agent; it returns the Q-values and the new hidden state
        q, h = agent(x, hidden[i])   # q expected to be (1, N_ACTIONS)
        q = q.squeeze(0)             # (N_ACTIONS,)

        # Apply action mask if given: unavailable actions get a very low Q-value
        if avail_actions is not None:
            mask = th.tensor(avail_actions[i], dtype=th.bool)
            q[~mask] = -1e9

        # Select greedy action
        a = int(q.argmax().item())

        # Prepare the one-hot last action that will feed the next step
        if OBS_LAST_ACTION:
            la = np.zeros(N_ACTIONS, dtype=np.float32)
            la[a] = 1.0
            next_last_actions.append(la)

        actions.append(a)
        next_hidden.append(h)

    # Phase 2: every agent succeeded, so commit the new state in place.
    for i in range(N_AGENTS):
        hidden[i] = next_hidden[i]
        if OBS_LAST_ACTION:
            last_actions[i] = next_last_actions[i]

    return actions
