##
## (c) Anonymous authors (2026)
##
## > Informed asymmetric actor-critic
##
##

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

from .elman_rnn import ElmanRNN


class PolicyNet(nn.Module):
    """

    Recurrent policy network

    An Elman-type RNN encodes the input history and a linear layer turns
    every RNN output into one logit per action.

    input_dim: input dimension
    hidden_dim: width of the network (feature size consumed by the readout)
    num_actions: number of logits produced by the readout layer

    """

    def __init__(self, input_dim, hidden_dim, num_actions):
        super().__init__()

        # History encoder (Elman-type RNN)
        self.rnn = ElmanRNN(input_dim, hidden_dim)

        # Per-step linear map from encoder features to action logits
        self.readout = nn.Linear(hidden_dim, num_actions)

    def forward(self, inputs):
        """
        Forward pass

        inputs: sequence passed unchanged to the RNN encoder

        Returns the log-probabilities over actions (log-softmax of the
        readout logits along the last dimension).
        """
        # The encoder returns a pair; only its first element is read out.
        encoded, _ = self.rnn(inputs)

        # Normalise in log-space: the result is log-probabilities, not probabilities.
        return F.log_softmax(self.readout(encoded), dim=-1)


class RecurrentActor(nn.Module):
    """

    Recurrent actor

    A GRU processes the observation sequence and a linear head maps each
    GRU output to a probability distribution over actions.

    obs_dim: dimension of one observation (GRU input size)
    action_dim: number of outputs of the policy head
    hidden_dim: GRU hidden size

    """

    def __init__(self, obs_dim, action_dim, hidden_dim):
        super().__init__()

        # batch_first=True: batched input is laid out as (batch, time, feature)
        self.rnn = nn.GRU(
            input_size=obs_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )

        # Linear map from GRU features to action logits
        self.policy_head = nn.Linear(in_features=hidden_dim, out_features=action_dim)

    def forward(self, obs_seq, h0=None):
        """
        Forward pass

        obs_seq: observation sequence fed to the GRU
        h0: optional initial hidden state for the GRU (None lets the GRU
            use its default initial state)

        Returns a pair: the action probabilities (softmax over the last
        dimension) for every step, and the final hidden state returned by
        the GRU.
        """
        features, last_hidden = self.rnn(obs_seq, h0)
        action_probs = F.softmax(self.policy_head(features), dim=-1)
        return action_probs, last_hidden


class AsymmetricCriticLinearReadout(nn.Module):
    """

    Linear readout for asymmetric critic

    hidden_dim: dimension of the latent history representation
    latent_dim: dimension of the privileged information

    """

    def __init__(self, hidden_dim, latent_dim):
        super().__init__()

        # Linear readout layer with concatenated input (latent history representation + privileged information)
        self.readout = nn.Linear(hidden_dim + latent_dim, 1)

    def forward(self, h_t, i_t):
        """
        Forward pass

        h_t: latent h_t representation
        i_t: privileged information

        Returns the output of the linear readout applied to the
        concatenation of h_t and i_t along the last dimension; the last
        dimension of the result has size 1.

        A 1-D h_t or i_t is treated as a single sample and receives a
        leading batch dimension of size 1. Apart from that, both inputs
        must already agree on every dimension except the last.
        """

        # Promote unbatched vectors to a batch of one. Both inputs are handled
        # so that a pair of 1-D tensors can be concatenated (torch.cat needs
        # operands with the same number of dimensions).
        if i_t.dim() == 1:
            i_t = i_t.unsqueeze(0)
        if h_t.dim() == 1:
            h_t = h_t.unsqueeze(0)

        # Linear readout
        return self.readout(torch.cat([h_t, i_t], dim=-1))


class AsymmetricCriticNonLinearReadout(nn.Module):
    """

    Non-linear readout for asymmetric critic

    A one-hidden-layer MLP (linear -> ReLU -> linear) applied to the
    concatenation of the latent history representation and the privileged
    information.

    hidden_dim: dimension of the latent history representation
    latent_dim: dimension of the privileged information
    mlp_dim: number of units in the hidden layer (defaults to 256)

    """

    def __init__(self, hidden_dim, latent_dim, mlp_dim=256):
        super().__init__()
        input_dim = hidden_dim + latent_dim

        # Hidden layer on the concatenated input
        self.fc1 = nn.Linear(input_dim, mlp_dim)
        # Linear readout layer
        self.fc2 = nn.Linear(mlp_dim, 1)

    def forward(self, h_t, i_t):
        """
        Forward pass

        h_t: latent h_t representation
        i_t: privileged information

        Returns the MLP output for the concatenation of h_t and i_t along
        the last dimension; the last dimension of the result has size 1.

        A 1-D h_t or i_t is treated as a single sample and receives a
        leading batch dimension of size 1.
        """

        # Promote unbatched vectors to a batch of one so both operands of
        # torch.cat have the same number of dimensions
        if i_t.dim() == 1:
            i_t = i_t.unsqueeze(0)
        if h_t.dim() == 1:
            h_t = h_t.unsqueeze(0)

        x = torch.cat([h_t, i_t], dim=-1)

        # ReLU activation after first hidden layer
        x = F.relu(self.fc1(x))

        # Linear readout
        return self.fc2(x)


class RecurrentCritic(nn.Module):
    """

    Recurrent critic

    A GRU consumes a sequence whose feature size is obs_dim + action_dim
    and a linear head produces one scalar value per GRU output.

    obs_dim: dimension of one observation
    action_dim: dimension of one action
    hidden_dim: GRU hidden size

    """

    def __init__(self, obs_dim, action_dim, hidden_dim):
        super().__init__()
        gru_input_dim = obs_dim + action_dim

        # batch_first=True: batched input is laid out as (batch, time, feature)
        self.rnn = nn.GRU(gru_input_dim, hidden_dim, batch_first=True)

        # Kept inside nn.Sequential so the parameter names stay "head.0.*"
        value_layer = nn.Linear(hidden_dim, 1)
        self.head = nn.Sequential(value_layer)

    def forward(self, x, h0=None):
        """
        Forward pass

        x: input sequence fed to the GRU
        h0: optional initial hidden state for the GRU

        Returns a pair: the value estimates with the trailing size-1
        dimension squeezed away, and the GRU outputs for every step (the
        final hidden state returned by the GRU is not returned).
        """
        features, _ = self.rnn(x, h0)
        values = self.head(features).squeeze(-1)
        return values, features


class AsymmetricRecurrentCritic(nn.Module):
    """

    Asymmetric recurrent critic

    A GRU consumes a sequence whose feature size is obs_dim + action_dim.
    Its outputs are concatenated with the privileged information along the
    last dimension and a linear head produces one scalar value per step.

    obs_dim: dimension of one observation
    action_dim: dimension of one action
    hidden_dim: GRU hidden size
    latent_dim: dimension of the privileged information

    """

    def __init__(self, obs_dim, action_dim, hidden_dim, latent_dim):
        super().__init__()
        gru_input_dim = obs_dim + action_dim
        head_input_dim = hidden_dim + latent_dim

        # batch_first=True: batched input is laid out as (batch, time, feature)
        self.rnn = nn.GRU(gru_input_dim, hidden_dim, batch_first=True)

        # Kept inside nn.Sequential so the parameter names stay "head.0.*"
        value_layer = nn.Linear(head_input_dim, 1)
        self.head = nn.Sequential(value_layer)

    def forward(self, x, i, h0=None):
        """
        Forward pass

        x: input sequence fed to the GRU
        i: privileged information; it is concatenated with the GRU outputs
           along the last dimension, so it must match them on all other
           dimensions
        h0: optional initial hidden state for the GRU

        Returns a pair: the value estimates with the trailing size-1
        dimension squeezed away, and the GRU outputs for every step (the
        final hidden state returned by the GRU is not returned).
        """
        features, _ = self.rnn(x, h0)
        head_input = torch.cat((features, i), dim=-1)
        values = self.head(head_input).squeeze(-1)
        return values, features
