# This file is here just to define MlpPolicy/CnnPolicy
# that work for PPO
from typing import Any, Dict, List, Optional, Type

import torch as th
from gymnasium import spaces
from torch import nn

from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
from stable_baselines3.ppo.ppo import PPO

from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    CombinedExtractor,
    FlattenExtractor,
    NatureCNN,
    create_mlp,
)
from stable_baselines3.common.type_aliases import Schedule
import pdb

from stable_baselines3.common.vec_env import DummyVecEnv


# Policy aliases for PPO: the generic actor-critic policy classes imported
# from stable_baselines3.common.policies are re-exported under shorter names.
MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
MultiInputPolicy = MultiInputActorCriticPolicy


# TODO: the GRU network and policy below are still to be modified.
class GRUNetwork(nn.Module):
    """Single-layer GRU followed by a linear head.

    The GRU consumes a batch-first sequence and the linear layer maps the GRU
    output at the final time step to ``output_dim`` values.

    :param input_dim: Size of each per-step input vector.
    :param gru_hidden_size: Hidden size of the GRU (and input size of the head).
    :param output_dim: Size of the vector produced by the linear head.
    """

    def __init__(self, input_dim: int, gru_hidden_size: int, output_dim: int) -> None:
        super().__init__()
        self.gru = nn.GRU(input_dim, gru_hidden_size, batch_first=True)
        self.fc = nn.Linear(gru_hidden_size, output_dim)

    def forward(self, x: th.Tensor) -> th.Tensor:
        """Run the GRU over ``x`` and project its last step through the head.

        :param x: Tensor of shape ``(batch_size, seq_len, input_dim)``. A 2-D
            tensor of shape ``(batch_size, input_dim)`` is also accepted and is
            treated as a batch of length-1 sequences.
        :return: Tensor of shape ``(batch_size, output_dim)``.
        """
        if x.dim() == 2:
            # nn.GRU would read a 2-D tensor as one unbatched sequence, and the
            # 3-index slice below would then fail. Insert a seq_len axis of 1
            # so flat (batch, input_dim) inputs work.
            x = x.unsqueeze(1)
        gru_out, _ = self.gru(x)
        # Use only the last output of the GRU
        return self.fc(gru_out[:, -1, :])

class GRUPolicy(ActorCriticPolicy):
    """Actor-critic policy that additionally owns a ``GRUNetwork``.

    The parent constructor runs unchanged; afterwards a ``GRUNetwork`` is built
    with ``observation_space.shape[0]`` inputs, a hidden size of 128 and
    ``action_space.n`` outputs.

    :param observation_space: Observation space; ``shape[0]`` is used as the
        GRU input size.
    :param action_space: Action space; must expose ``n``, which is used as the
        GRU network's output size.
    :param lr_schedule: Learning-rate schedule forwarded to the parent class.
    :param kwargs: Extra keyword arguments forwarded to the parent class.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        **kwargs: Any,
    ) -> None:
        super().__init__(observation_space, action_space, lr_schedule, **kwargs)

        self.gru_net = GRUNetwork(
            input_dim=observation_space.shape[0],
            gru_hidden_size=128,
            output_dim=action_space.n,
        )

        # The parent constructor creates the optimizer from the parameters that
        # exist at that time, so the GRU weights created above are not part of
        # it. Register them as an extra param group so they can be trained.
        optimizer = getattr(self, "optimizer", None)
        if optimizer is not None:
            optimizer.add_param_group({"params": list(self.gru_net.parameters())})

    def forward(self, obs: th.Tensor) -> th.Tensor:
        """Return the output of ``self.gru_net`` applied to ``obs``.

        NOTE(review): per the SB3 documentation the parent
        ``ActorCriticPolicy.forward`` returns ``(actions, values, log_prob)``;
        this override returns only the GRU network output — confirm that
        callers expect this.

        :param obs: Observation tensor passed straight to ``self.gru_net``.
        :return: The tensor produced by ``self.gru_net``.
        """
        return self.gru_net(obs)