import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import json
import random
"""
Global constants
"""
SEED = 42
MAX_STEPS = 500

# Seed the Python, NumPy and PyTorch global RNGs so network initialisation and
# replay-buffer sampling are reproducible across runs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


# Observation / action sizes; they are expected to match the Hopper-v5
# environment created in the __main__ block below.
STATE_DIM = 11
ACTION_DIM = 3
# Symmetric per-dimension action bounds [-1, 1] as float32 tensors.
# torch.ones replaces the legacy torch.FloatTensor(np.ones(...)) constructor;
# the explicit dtype keeps these float32 regardless of the default dtype.
ACTION_HIGH = torch.ones(ACTION_DIM, dtype=torch.float32)
ACTION_LOW = -torch.ones(ACTION_DIM, dtype=torch.float32)

#%%
"""
Networks
"""
class Actor(nn.Module):
    """Gaussian policy network with tanh-squashed actions.

    Every ``nn.Linear`` is initialised from N(mean, std) (weights and biases
    alike), and ``forward`` divides each layer's output by sqrt(in_features),
    so the 1/sqrt(fan_in) scaling is applied at run time rather than in the
    initialiser. The standard deviation is a learnable, state-independent
    parameter (``log_std``).
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim)
        self._init_weights()
        # Registered after the custom init, which only touches nn.Linear
        # modules; the initial std is exp(-0.5) ~= 0.61 in every dimension.
        self.log_std = nn.Parameter(torch.ones(action_dim) * -0.5)

    def _init_weights(self, mean=0.0, std=1.0):
        """Draw every Linear layer's weight AND bias from N(mean, std)."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=mean, std=std)
                nn.init.normal_(m.bias, mean=mean, std=std)

    def forward(self, state):
        """Return the pre-squash Gaussian mean and the per-dimension std.

        Args:
          state: (..., state_dim) tensor.
        Returns:
          mean: (..., action_dim) tensor (before tanh squashing).
          std: (action_dim,) tensor, exp(log_std).
        """
        # Divide by a plain Python float rather than allocating a new tensor
        # on every call; the divisor is the same sqrt(fan_in) value.
        x = torch.tanh(self.fc1(state) / self.fc1.in_features ** 0.5)
        x = torch.tanh(self.fc2(x) / self.fc2.in_features ** 0.5)
        mean = self.fc3(x) / self.fc3.in_features ** 0.5
        std = torch.exp(self.log_std)
        return mean, std

    def sample(self, state):
        """
        Reparameterized sample for SAC:
        u ~ N(mu, std), a = tanh(u)
        Returns:
          action: (batch, action_dim)
          log_prob: (batch, 1)
        """
        mean, std = self.forward(state)  # mean: (B, A), std: (A,)
        std = std.expand_as(mean)
        dist = Normal(mean, std)
        # Explicit reparameterisation so gradients flow through mean and std.
        eps = torch.randn_like(mean)
        pre_tanh = mean + eps * std
        action = torch.tanh(pre_tanh)

        # Log prob with tanh squashing correction (change of variables):
        # log pi(a) = log N(u) - sum log(1 - tanh(u)^2); 1e-6 avoids log(0).
        log_prob = dist.log_prob(pre_tanh).sum(dim=-1, keepdim=True)
        log_prob -= torch.log(1 - action.pow(2) + 1e-6).sum(dim=-1, keepdim=True)

        return action, log_prob

    def deterministic(self, state):
        """Deterministic action (for evaluation): a = tanh(mu)."""
        mean, _ = self.forward(state)
        return torch.tanh(mean)

class Critic(nn.Module):
    """State-value network V(s).

    Two 64-unit tanh hidden layers followed by a scalar linear head, stored in
    ``self.net`` as a single ``nn.Sequential``.
    """

    def __init__(self, state_dim):
        super(Critic, self).__init__()
        hidden = 64
        layers = [
            nn.Linear(state_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Return V(state) with the trailing size-1 dimension removed."""
        value = self.net(state)
        return torch.squeeze(value, dim=-1)

class Reward(nn.Module):
    """Learned reward model r(s, a).

    State and action are concatenated along the last dimension and passed
    through two 64-unit tanh hidden layers to a scalar output.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

    def forward(self, state, action):
        """Return the predicted reward, shape (..., 1).

        Args:
          state: (..., state_dim) tensor.
          action: (..., action_dim) tensor with the same leading dims.
        """
        # Concatenate on the last axis (as QNetwork does) so unbatched 1-D
        # inputs work too; for (B, D) inputs this equals the old dim=1.
        state_action = torch.cat((state, action), dim=-1)
        return self.net(state_action)

# For SAC Pretrain
class QNetwork(nn.Module):
    """Action-value network Q(s, a) used for SAC pretraining.

    Two 64-unit ReLU hidden layers over the concatenated (state, action)
    input and a scalar head; Xavier-uniform weights and zero biases.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform every Linear weight and zero every Linear bias."""
        linears = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linears:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, state, action):
        """Return Q(state, action) as a (B, 1) tensor.

        A 1-D ``state`` or ``action`` is promoted to a batch of one.
        """
        state = state.unsqueeze(0) if state.dim() == 1 else state
        action = action.unsqueeze(0) if action.dim() == 1 else action
        hidden = torch.cat([state, action], dim=-1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

class ReplayBuffer:
    """Fixed-capacity circular buffer of transitions in float32 NumPy arrays.

    Once ``capacity`` transitions have been pushed, each new push overwrites
    the oldest slot. ``sample`` draws indices uniformly with replacement using
    NumPy's global RNG.
    """

    def __init__(self, state_dim, action_dim, capacity=int(1e5)):
        # A non-positive capacity would otherwise fail later with a
        # ZeroDivisionError / IndexError inside push().
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.capacity = capacity
        self.ptr = 0   # next slot to write
        self.size = 0  # number of valid transitions stored

        self.states = np.zeros((capacity, state_dim), dtype=np.float32)
        self.actions = np.zeros((capacity, action_dim), dtype=np.float32)
        self.rewards = np.zeros((capacity, 1), dtype=np.float32)
        self.next_states = np.zeros((capacity, state_dim), dtype=np.float32)
        self.dones = np.zeros((capacity, 1), dtype=np.float32)

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest one when full."""
        idx = self.ptr
        self.states[idx] = state
        self.actions[idx] = action
        self.rewards[idx] = reward
        self.next_states[idx] = next_state
        self.dones[idx] = done

        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        """Return (states, actions, rewards, next_states, dones) batches.

        Each array has ``batch_size`` rows. Raises ValueError if the buffer is
        empty (previously NumPy's opaque "high <= low" error).
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty ReplayBuffer")
        idxs = np.random.randint(0, self.size, size=batch_size)
        return (
            self.states[idxs],
            self.actions[idxs],
            self.rewards[idxs],
            self.next_states[idxs],
            self.dones[idxs],
        )

    def __len__(self):
        """Number of transitions currently stored."""
        return self.size
#%%
if __name__ == '__main__':
    # Smoke test: run the untrained actor on the first Hopper observation.
    actor = Actor(state_dim=STATE_DIM, action_dim=ACTION_DIM)
    env = gym.make('Hopper-v5')
    # STATE_DIM is hard-coded above; fail clearly if the env disagrees rather
    # than with an opaque matmul shape error inside the network.
    obs_shape = env.observation_space.shape
    if obs_shape != (STATE_DIM,):
        raise ValueError(f"Expected observation shape ({STATE_DIM},), got {obs_shape}")
    # Seed the reset as well: an unseeded first reset draws from OS entropy,
    # so the global seeding above would not make the initial state reproducible.
    state, _ = env.reset(seed=SEED)
    state = torch.tensor(state, dtype=torch.float32)
    mean, std = actor(state)
