import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import numpy as np
# Torch device handle: "cuda" when torch.cuda.is_available() is true, otherwise "cpu".
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from torch.distributions import Beta

class Actor(nn.Module):
    """Stochastic policy network producing actions in the open interval (0, 1).

    An MLP trunk feeds two linear heads that predict the mean and the log
    standard deviation of a diagonal Gaussian. A sample from that Gaussian
    is squashed with a sigmoid to obtain the action, and log-probabilities
    include the matching change-of-variables term.

    Args:
        state_dim: Size of the last axis of the state input.
        action_dim: Number of action components.
        hiddens_sac: Widths of the hidden layers of the trunk. May be empty,
            in which case both heads read the state directly.
    """

    # Clamp range applied to the predicted log standard deviation.
    _LOG_STD_MIN = -20
    _LOG_STD_MAX = 2
    # Added inside the log of the sigmoid derivative so it never sees 0.
    _EPS = 1e-6

    def __init__(self, state_dim, action_dim=2, hiddens_sac=(128, 128)):
        super().__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim

        net_layers = []
        in_dim = state_dim
        for h_dim in hiddens_sac:
            net_layers.append(nn.Linear(in_dim, h_dim))
            net_layers.append(nn.ReLU())
            in_dim = h_dim
        self.net = nn.Sequential(*net_layers)

        # `in_dim` now holds the trunk's output width; it is still
        # `state_dim` when no hidden layers were requested, so an empty
        # `hiddens_sac` no longer fails on `hiddens_sac[-1]`.
        self.mean = nn.Linear(in_dim, action_dim)
        self.log_std = nn.Linear(in_dim, action_dim)
        self._init_weights()

    def _init_weights(self):
        """Initialise all linear layers.

        Trunk layers use Xavier-uniform weights with zero biases. The two
        heads use Xavier-uniform with gain 0.01 and zero biases, so their
        initial outputs are close to zero.
        """
        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

        nn.init.xavier_uniform_(self.mean.weight, gain=0.01)
        nn.init.zeros_(self.mean.bias)

        nn.init.xavier_uniform_(self.log_std.weight, gain=0.01)
        nn.init.zeros_(self.log_std.bias)

    def _std(self, features):
        """Return the Gaussian standard deviation for trunk output `features`."""
        log_std = torch.clamp(
            self.log_std(features), self._LOG_STD_MIN, self._LOG_STD_MAX
        )
        return log_std.exp()

    def _squash_correction(self, action):
        """Return log(action * (1 - action) + eps), the sigmoid Jacobian term."""
        return torch.log(action * (1 - action) + self._EPS)

    def forward(self, state, deterministic=False, return_log_prob=True):
        """Compute an action for `state`.

        Args:
            state: Tensor whose last axis has size `state_dim`.
            deterministic: If true, return `sigmoid(mean)` without sampling.
            return_log_prob: If true (and not deterministic), also return the
                log-probability of the sampled action.

        Returns:
            `sigmoid(mean)` alone when `deterministic` is true (regardless of
            `return_log_prob`); otherwise the sampled action, or the tuple
            `(action, log_prob)` when `return_log_prob` is true. `log_prob`
            is summed over the last (action) axis with that axis kept.
        """
        x = self.net(state)
        mean = self.mean(x)

        if deterministic:
            # The standard deviation is not needed on this path.
            return torch.sigmoid(mean)

        normal = Normal(mean, self._std(x))
        x_t = normal.rsample()  # reparameterised sample keeps gradients
        action = torch.sigmoid(x_t)

        if not return_log_prob:
            return action

        log_prob = normal.log_prob(x_t) - self._squash_correction(action)
        # Reduce over the last axis (the action components), matching
        # `log_prob()`; this also supports unbatched inputs.
        return action, log_prob.sum(-1, keepdim=True)

    def log_prob(self, state, action):
        """Return the log-probability of `action` under the policy at `state`.

        `action` is clamped to [eps, 1 - eps] before the sigmoid is inverted
        so that the logit stays finite.

        Returns:
            Tensor of per-sample log-probabilities, summed over the last axis
            with that axis kept.
        """
        x = self.net(state)
        normal = Normal(self.mean(x), self._std(x))

        a_clamped = torch.clamp(action, self._EPS, 1 - self._EPS)
        # Inverse sigmoid (logit) recovers the pre-squash Gaussian sample.
        x_t = torch.log(a_clamped) - torch.log(1 - a_clamped)

        log_prob = normal.log_prob(x_t) - self._squash_correction(a_clamped)
        return log_prob.sum(-1, keepdim=True)


class Critic(nn.Module):
    """Pair of independent Q-networks evaluated on the same (state, action).

    Each network is an MLP that takes the state and action concatenated
    along the last axis and outputs a single value.

    Args:
        state_dim: Size of the last axis of the state input.
        action_dim: Size of the last axis of the action input.
        hiddens_sac: Widths of the hidden layers used by both Q-networks.
    """

    def __init__(self, state_dim, action_dim=2, hiddens_sac=(128, 128)):
        super().__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim

        input_dim = state_dim + action_dim
        self.q1 = self._build_q_net(input_dim, hiddens_sac)
        self.q2 = self._build_q_net(input_dim, hiddens_sac)
        self._init_weights()

    @staticmethod
    def _build_q_net(in_dim, hiddens):
        """Build one Linear/ReLU stack ending in a single-output Linear layer."""
        q_layers = []
        for h_dim in hiddens:
            q_layers.append(nn.Linear(in_dim, h_dim))
            q_layers.append(nn.ReLU())
            in_dim = h_dim
        q_layers.append(nn.Linear(in_dim, 1))
        return nn.Sequential(*q_layers)

    def _init_weights(self):
        """Initialise every linear layer with Xavier-uniform weights, zero bias."""
        for net in (self.q1, self.q2):
            for m in net.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
                    nn.init.zeros_(m.bias)

    def forward(self, state, action):
        """Return the two Q-value estimates for `(state, action)`.

        Args:
            state: Tensor whose last axis has size `state_dim`.
            action: Tensor whose last axis has size `action_dim`; leading
                axes must match those of `state`.

        Returns:
            Tuple `(q1, q2)`, each with a trailing axis of size 1.
        """
        # Join along the feature (last) axis so unbatched inputs and inputs
        # with extra leading axes work as well as the usual 2-D batch.
        sa = torch.cat([state, action], dim=-1)
        return self.q1(sa), self.q2(sa)