import gymnasium as gym
import torch.optim as optim
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import random
from torch.distributions import Normal
from collections import deque
import matplotlib.pyplot as plt
"""
Global constants
"""
# Seed shared by all three random number generators seeded below.
SEED = 42
# Step limit (presumably per episode -- confirm against the training loop).
MAX_STEPS = 500

# Default state / action vector sizes.
# NOTE(review): the smoke test at the bottom of this file builds its Actor
# with state_dim=4, not STATE_DIM -- confirm which environment these target.
STATE_DIM, ACTION_DIM = 6, 2

# Seed Python's, NumPy's and PyTorch's generators with the same value.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

#%%
"""
Networks
"""
class Actor(nn.Module):
    """Policy network mapping a state to a probability vector over actions.

    Three fully connected layers (state_dim -> 64 -> 64 -> action_dim) with
    tanh on the two hidden layers and a softmax on the output. All weights
    and biases are drawn from a normal distribution (std 1.0 by default), and
    each layer's output is divided by sqrt(in_features) in ``forward`` before
    the non-linearity, so the scaling is applied in the forward pass rather
    than baked into the initial parameter values.
    """

    def __init__(self, state_dim, action_dim):
        """Create the layers and randomise their parameters.

        Args:
            state_dim: Size of the last dimension of the state input.
            action_dim: Length of the returned probability vector.
        """
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim)
        self._init_weights()

    def _init_weights(self, mean=0.0, std=1.0):
        """Re-draw every linear layer's weight and bias from N(mean, std**2).

        Args:
            mean: Mean of the normal distribution.
            std: Standard deviation of the normal distribution.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=mean, std=std)
                # Biases are randomised with the same distribution (not zeroed).
                nn.init.normal_(m.bias, mean=mean, std=std)

    @staticmethod
    def _scaled(layer, x):
        """Return ``layer(x)`` divided by sqrt(layer.in_features)."""
        # A Python float divisor avoids allocating a tensor and calling
        # torch.sqrt on every forward pass, and works on any device.
        return layer(x) / layer.in_features ** 0.5

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Compute action probabilities for ``state``.

        Args:
            state: Float tensor whose last dimension has size ``state_dim``;
                leading (batch) dimensions, if any, are preserved.

        Returns:
            Tensor with last dimension ``action_dim`` whose entries along
            that dimension are non-negative and sum to 1.
        """
        x = torch.tanh(self._scaled(self.fc1, state))
        x = torch.tanh(self._scaled(self.fc2, x))
        logits = self._scaled(self.fc3, x)
        return F.softmax(logits, dim=-1)

class Critic(nn.Module):
    """MLP (state_dim -> 64 -> 64 -> 1, tanh activations) giving one scalar per state."""

    def __init__(self, state_dim):
        """Assemble the network.

        Args:
            state_dim: Size of the last dimension of the state input.
        """
        super().__init__()
        hidden = 64
        # Layers are created in input-to-output order and stored under
        # ``self.net`` as a Sequential.
        stack = [
            nn.Linear(state_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, state):
        """Return the network output for ``state`` without its trailing size-1 axis."""
        values = self.net(state)
        return values.squeeze(-1)

class Reward(nn.Module):
    """MLP mapping a (state, action) pair to a single scalar.

    Architecture: (state_dim + action_dim) -> 64 -> 64 -> 1 with tanh
    activations on the hidden layers.
    """

    def __init__(self, state_dim, action_dim):
        """Assemble the network.

        Args:
            state_dim: Size of the last dimension of the state input.
            action_dim: Size of the last dimension of the action input.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Evaluate the network on concatenated state and action features.

        Args:
            state: Tensor whose last dimension has size ``state_dim``.
            action: Tensor whose last dimension has size ``action_dim`` and
                whose leading dimensions match those of ``state``.

        Returns:
            Tensor with the same leading dimensions as the inputs and a
            trailing dimension of size 1 (it is not squeezed).
        """
        # Concatenate along the feature (last) axis. For 2-D batches this is
        # the same as dim=1, and it also accepts unbatched 1-D inputs.
        features = torch.cat((state, action), dim=-1)
        return self.net(features)

#%%
if __name__ == '__main__':
    # Smoke test: run a freshly initialised Actor on the first CartPole
    # observation and keep the resulting action probabilities in `policy`.
    env = gym.make('CartPole-v1')
    try:
        # Size the network from the environment instead of hard-coding 4 / 2,
        # so the two cannot drift apart.
        obs_dim = env.observation_space.shape[0]
        n_actions = int(env.action_space.n)
        actor = Actor(state_dim=obs_dim, action_dim=n_actions)

        # Seed the reset so the first observation is reproducible, in line
        # with the module-level RNG seeding.
        state, _ = env.reset(seed=SEED)
        state = torch.tensor(state, dtype=torch.float32)
        policy = actor(state)
    finally:
        # Release the environment's resources even if the checks above fail.
        env.close()
