# modules/agents/rnn_agent.py
# RNN agent with optional FiLM conditioning on z_t

import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.opponent.contrastive_encoder import FiLM

class RNNAgent(nn.Module):
    """GRU-based per-agent Q network with optional FiLM conditioning on ``z_t``.

    Pipeline: ``fc1`` + ReLU -> ``GRUCell`` -> (optional FiLM(h, z_t)) -> two
    linear Q heads (``fc2``, ``fc3``) reduced with an element-wise minimum.

    Args:
        input_shape: Size of the last dimension of ``inputs``.
        args: Config namespace. Reads ``rnn_hidden_dim`` and ``n_actions``;
            optionally ``use_z_film`` (default ``True``) and ``z_dim``
            (default ``16``), which is only used to build the FiLM layer.
    """

    def __init__(self, input_shape, args):
        super().__init__()
        self.args = args
        self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
        self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)
        # Two independent Q heads; forward() takes their element-wise minimum.
        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions)
        self.fc3 = nn.Linear(args.rnn_hidden_dim, args.n_actions)

        self.use_film = getattr(args, "use_z_film", True)
        self.z_dim = getattr(args, "z_dim", 16)
        if self.use_film:
            self.film = FiLM(args.rnn_hidden_dim, self.z_dim)

    def init_hidden(self):
        """Return a ``(1, rnn_hidden_dim)`` zero hidden state.

        The tensor shares dtype and device with the model parameters, so it
        follows the module across ``.to()`` / ``.cuda()`` calls.
        """
        # new_zeros is the supported replacement for the legacy
        # Tensor.new(...).zero_() constructor; same dtype/device, same values.
        return self.fc1.weight.new_zeros(1, self.args.rnn_hidden_dim)

    def forward(self, inputs, hidden_state, z_t=None):
        """Run one recurrent step.

        Args:
            inputs: ``[B*n_agents, input_dim]`` observations/features.
            hidden_state: Previous hidden state; any shape that reshapes to
                ``[-1, rnn_hidden_dim]``.
            z_t: Optional conditioning vector passed to FiLM. Ignored when
                FiLM is disabled (``args.use_z_film`` is false).

        Returns:
            ``(q, h)``: ``q`` is the element-wise min of the two Q heads,
            ``[B*n_agents, n_actions]``; ``h`` is the (possibly
            FiLM-modulated) hidden state, ``[B*n_agents, rnn_hidden_dim]``.
        """
        x = F.relu(self.fc1(inputs))
        h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim)
        h = self.rnn(x, h_in)
        if self.use_film and z_t is not None:
            # NOTE(review): the FiLM-modulated h is also what gets returned as
            # the hidden state; if callers feed it back as hidden_state (the
            # usual pattern), the modulation compounds through the recurrence.
            # Confirm this is intended rather than modulating only the Q input.
            h = self.film(h, z_t)
        q1 = self.fc2(h)
        q2 = self.fc3(h)
        # Pessimistic (clipped double-Q style) estimate across the two heads.
        return torch.min(q1, q2), h
