import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import json
import random
"""
Global constants
"""
SEED = 42
MAX_STEPS = 500

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

STATE_DIM = 17
ACTION_DIM = 6
ACTION_HIGH = torch.FloatTensor(np.ones(ACTION_DIM))
ACTION_LOW = - torch.FloatTensor(np.ones(ACTION_DIM))

#%%
"""
Networks
"""
class Actor(nn.Module):
    """Gaussian policy network mapping states to action means and std devs.

    Every nn.Linear weight and bias is re-initialised from N(mean, std)
    (unit normal by default, see ``_init_weights``). In ``forward`` each
    layer's output is divided by sqrt(in_features) of that layer. The
    standard deviation is a state-independent learnable vector
    ``exp(log_std)``, starting at exp(-0.5) ~= 0.61 for every action dim.

    Args:
        state_dim: Size of the last dimension of the state input.
        action_dim: Number of action dimensions produced.
        hidden_dim: Width of both hidden layers (default 64).
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)
        self._init_weights()
        # Created after _init_weights; it is not an nn.Linear, so the
        # re-initialisation would not touch it anyway.
        self.log_std = nn.Parameter(torch.ones(action_dim) * -0.5)

    def _init_weights(self, mean=0.0, std=1.0):
        """Re-initialise every nn.Linear weight and bias from N(mean, std)."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=mean, std=std)
                nn.init.normal_(m.bias, mean=mean, std=std)

    @staticmethod
    def _scaled(layer, x):
        """Apply ``layer`` to ``x`` and divide by sqrt(layer.in_features)."""
        # A Python scalar avoids allocating a fresh tensor on every call.
        return layer(x) / (layer.in_features ** 0.5)

    def forward(self, state):
        """Compute the policy's action mean and standard deviation.

        Args:
            state: Tensor whose last dimension has size ``state_dim``
                (batched or unbatched).

        Returns:
            ``(mean, std)``. ``mean`` has the input's leading dims with a
            last dim of ``action_dim`` and is unbounded (no output squashing).
            ``std`` has shape ``(action_dim,)`` and is not expanded to the
            batch.
        """
        x = torch.tanh(self._scaled(self.fc1, state))
        x = torch.tanh(self._scaled(self.fc2, x))
        mean = self._scaled(self.fc3, x)
        std = torch.exp(self.log_std)
        return mean, std

class Critic(nn.Module):
    """State-value network producing one scalar per input state.

    Layers: Linear(state_dim, 64) -> Tanh -> Linear(64, 64) -> Tanh
    -> Linear(64, 1), with PyTorch's default nn.Linear initialisation.
    """

    def __init__(self, state_dim):
        super().__init__()
        hidden = 64
        layers = [
            nn.Linear(state_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 1),
        ]
        # A single nn.Sequential keeps state_dict keys as net.0.*, net.2.*,
        # net.4.*.
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Return the value estimate with the trailing size-1 dim removed.

        An input of shape (..., state_dim) yields an output of shape (...).
        """
        value = self.net(state)
        return value.squeeze(-1)

class Reward(nn.Module):
    """Learned reward model mapping a (state, action) pair to a scalar.

    Layers: Linear(state_dim + action_dim, 64) -> Tanh -> Linear(64, 64)
    -> Tanh -> Linear(64, 1), with PyTorch's default nn.Linear init.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
        )

    def forward(self, state, action):
        """Predict the reward for ``(state, action)``.

        Args:
            state: Tensor of shape (..., state_dim).
            action: Tensor of shape (..., action_dim) with the same leading
                dims as ``state``.

        Returns:
            Tensor of shape (..., 1); the trailing size-1 dim is kept.
        """
        # Concatenate on the last (feature) axis so both batched (B, D) and
        # unbatched (D,) inputs work; for 2-D inputs this equals dim=1.
        features = torch.cat((state, action), dim=-1)
        return self.net(features)

#%%
if __name__ == '__main__':
    # Smoke test: build the policy and run one forward pass on the first
    # observation from HalfCheetah-v5.
    actor = Actor(state_dim=STATE_DIM, action_dim=ACTION_DIM)
    env = gym.make('HalfCheetah-v5')
    try:
        # Seed the environment as well, matching the global RNG seeding.
        state, _ = env.reset(seed=SEED)
        state = torch.as_tensor(state, dtype=torch.float32)
        # Inference only, so no autograd graph is needed.
        with torch.no_grad():
            mean, std = actor(state)
        print(f"mean={mean.tolist()}\nstd={std.tolist()}")
    finally:
        # Release environment resources even if the forward pass fails.
        env.close()
