import gymnasium as gym
import torch.optim as optim
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import random
from torch.distributions import Normal
from collections import deque
import matplotlib.pyplot as plt
"""
Global constants
"""
# --- Reproducibility -------------------------------------------------------
SEED = 42

# --- Hyperparameters -------------------------------------------------------
# NOTE(review): the names follow the usual PPO conventions (discounting, GAE,
# ratio clipping, loss weights, batching); the code consuming them is not in
# this part of the file, so confirm exact usage against the training loop.
GAMMA = 0.99
GAE_LAMBDA = 0.95
CLIP_EPS = 0.2
LR = 3e-4
ENTROPY_COEF = 0.0
VF_COEF = 0.5
MAX_GRAD_NORM = 0.5
BATCH_SIZE = 2048
MINI_BATCH_SIZE = 64
PPO_EPOCHS = 10
MAX_STEPS = 1000

# Seed every RNG this module imports so that runs are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Environment dimensions and action bounds ------------------------------
STATE_DIM = 11
ACTION_DIM = 3
# Per-dimension bounds: +1 / -1 for every action component (float32 tensors).
ACTION_HIGH = torch.ones(ACTION_DIM, dtype=torch.float32)
ACTION_LOW = -ACTION_HIGH

#%%
"""
Networks
"""
class Actor(nn.Module):
    """MLP that maps a state to an action mean plus a learned, state-independent std.

    The network is two 64-unit tanh hidden layers followed by a linear output
    of size ``action_dim``. The log standard deviation is a free parameter
    (one entry per action dimension) initialised to -0.5.
    NOTE(review): presumably the outputs parameterise a Normal policy
    distribution; confirm against the caller.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        # Layers are created in input-to-output order so that seeded
        # initialisation matches the layout of the Sequential container.
        layers = [
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, action_dim),
        ]
        self.net = nn.Sequential(*layers)
        # Learned log-std, shared across all states.
        self.log_std = nn.Parameter(torch.full((action_dim,), -0.5))

    def forward(self, state):
        """Return ``(mean, std)`` where ``std = exp(log_std)``."""
        return self.net(state), self.log_std.exp()

class Critic(nn.Module):
    """MLP that maps a state to a single scalar output per input row.

    Architecture: two 64-unit tanh hidden layers and a 1-unit linear output.
    """

    def __init__(self, state_dim):
        super().__init__()
        layers = [
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Return the network output with its trailing size-1 dimension removed."""
        values = self.net(state)
        return values.squeeze(dim=-1)

class Reward(nn.Module):
    """MLP that maps a (state, action) pair to a single scalar output.

    Architecture: the state and action are concatenated along their last
    dimension and passed through two 64-unit tanh hidden layers followed by
    a 1-unit linear output.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

    def forward(self, state, action):
        """Return the network output for the given state/action tensors.

        Args:
            state: tensor whose last dimension has size ``state_dim``.
            action: tensor whose last dimension has size ``action_dim``;
                all leading dimensions must match those of ``state``.

        Returns:
            Tensor with the same leading dimensions as the inputs and a
            trailing dimension of size 1 (it is not squeezed).
        """
        # Concatenate on the feature (last) axis rather than axis 1: this is
        # identical for batched 2-D inputs and additionally supports a single
        # unbatched (1-D) pair or extra leading dimensions.
        features = torch.cat((state, action), dim=-1)
        return self.net(features)
#%%
if __name__ == '__main__':
    # Quick smoke check: build an actor, count its parameters and run one
    # forward pass on a random batch of 4 states. Results are left in
    # `count`, `mean` and `std` for interactive inspection (nothing is printed).
    actor = Actor(state_dim=STATE_DIM, action_dim=ACTION_DIM)
    count = sum(param.numel() for param in actor.parameters())
    # Use STATE_DIM rather than a literal so the check tracks the constant.
    state = torch.randn(4, STATE_DIM)
    mean, std = actor(state)
