import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

class PolicyNet(nn.Module):
    """MLP policy mapping a state vector to a probability distribution over actions.

    Layout: ``preprocess`` Linear(state_dim -> hidden_dim), then ``num_layers``
    blocks of ReLU + Linear(hidden_dim -> hidden_dim), then ReLU and the output
    head ``fc`` Linear(hidden_dim -> action_dim), followed by a softmax over the
    last dimension.

    Args:
        state_dim: size of the input state vector.
        hidden_dim: width of every hidden layer.
        action_dim: number of discrete actions (size of the output distribution).
        num_layers: number of extra hidden ReLU + Linear blocks (may be 0).
    """

    def __init__(self, state_dim, hidden_dim, action_dim, num_layers):
        super(PolicyNet, self).__init__()
        self.preprocess = nn.Linear(state_dim, hidden_dim)
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(nn.Sequential(nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)))
        self.fc = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return action probabilities; the last dimension of the output sums to 1.

        Accepts a batched ``(batch, state_dim)`` tensor or a single ``(state_dim,)``
        state, because the softmax is taken over the last dimension.
        """
        x = self.preprocess(x)
        for layer in self.layers:
            x = layer(x)
        # Activation before the head: without it the last hidden Linear and
        # ``fc`` compose into one linear map (and num_layers=0 is purely linear).
        return F.softmax(self.fc(F.relu(x)), dim=-1)
        


class ValueNet(nn.Module):
    """MLP state-value estimator V(s): one unbounded scalar per input state.

    Layout: ``preprocess`` Linear(state_dim -> hidden_dim), then ``num_layers``
    blocks of ReLU + Linear(hidden_dim -> hidden_dim), then ReLU and the output
    head ``fc`` Linear(hidden_dim -> 1).

    Args:
        state_dim: size of the input state vector.
        hidden_dim: width of every hidden layer.
        num_layers: number of extra hidden ReLU + Linear blocks (may be 0).
    """

    def __init__(self, state_dim, hidden_dim, num_layers):
        super(ValueNet, self).__init__()
        self.preprocess = nn.Linear(state_dim, hidden_dim)
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(nn.Sequential(nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)))
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return the value estimate with shape ``x.shape[:-1] + (1,)``."""
        x = self.preprocess(x)
        for layer in self.layers:
            x = layer(x)
        # Raw linear output: a value is unbounded. (A softmax over this size-1
        # dimension would be identically 1.0 and pass no gradient to the critic.)
        # ReLU before the head keeps the last hidden Linear from collapsing into fc.
        return self.fc(F.relu(x))


class ActorCritic:
    """One-step actor-critic agent over a discrete action space.

    The actor maps states to action probabilities and the critic estimates the
    state value V(s). ``update`` trains both on a batch of transitions, using the
    TD error ``r + gamma * V(s') - V(s)`` as the actor's advantage signal.
    """

    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 gamma, device, num_layers):
        """Build the actor/critic networks on ``device`` with separate Adam optimizers.

        Args:
            state_dim: size of the state vector.
            hidden_dim: hidden width used by both networks.
            action_dim: number of discrete actions.
            actor_lr: Adam learning rate for the actor.
            critic_lr: Adam learning rate for the critic.
            gamma: discount factor used in the TD target.
            device: torch device the networks and all batches are placed on.
            num_layers: ``num_layers`` argument forwarded to both networks.
        """
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim, num_layers).to(device)
        self.critic = ValueNet(state_dim, hidden_dim, num_layers).to(device)

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.device = device

    def _to_tensor(self, data, dtype=torch.float32):
        """Convert a sequence/array (or tensor) to a ``dtype`` tensor on ``self.device``."""
        if isinstance(data, torch.Tensor):
            return data.to(device=self.device, dtype=dtype)
        # Going through np.asarray avoids torch's very slow (and warning-emitting)
        # path for building a tensor from a Python list of ndarrays.
        return torch.as_tensor(np.asarray(data), dtype=dtype, device=self.device)

    def take_action(self, state):
        """Sample an action index from the current policy for a single state.

        Args:
            state: one state vector (list / ndarray / tensor of length state_dim).

        Returns:
            The sampled action as a Python ``int``.
        """
        batch = self._to_tensor(state).unsqueeze(0)  # add batch dim: (1, state_dim)
        with torch.no_grad():  # sampling only; no autograd graph needed
            probs = self.actor(batch)
        action = torch.distributions.Categorical(probs).sample()
        return action.item()

    def _td_target(self, rewards, next_states, dones=None):
        """Return ``r + gamma * V(s')``, with V(s') zeroed where ``dones`` is 1.

        Computed without gradient: the target is a fixed regression label for
        the critic and only enters the actor loss through a detached TD error.
        """
        with torch.no_grad():
            next_values = self.critic(next_states)
        if dones is not None:
            # No bootstrapping past the end of an episode.
            next_values = next_values * (1.0 - dones)
        return rewards + self.gamma * next_values

    def update(self, transition_dict):
        """Run one actor and one critic gradient step on a batch of transitions.

        Args:
            transition_dict: mapping with equal-length sequences under
                ``'states'``, ``'actions'``, ``'rewards'``, ``'next_states'`` and,
                optionally, ``'dones'`` (truthy for terminal transitions). When
                ``'dones'`` is absent every transition is bootstrapped.
        """
        states = self._to_tensor(transition_dict['states'])
        # gather() requires int64 indices.
        actions = self._to_tensor(transition_dict['actions'], torch.long).view(-1, 1)
        rewards = self._to_tensor(transition_dict['rewards']).view(-1, 1)
        next_states = self._to_tensor(transition_dict['next_states'])
        dones = transition_dict.get('dones')
        if dones is not None:
            dones = self._to_tensor(dones).view(-1, 1)

        td_target = self._td_target(rewards, next_states, dones)
        values = self.critic(states)
        td_delta = td_target - values  # temporal-difference error
        # Clamp keeps log() finite if a probability underflows to 0.
        chosen_probs = self.actor(states).gather(1, actions).clamp_min(1e-8)
        log_probs = torch.log(chosen_probs)
        actor_loss = torch.mean(-log_probs * td_delta.detach())
        critic_loss = F.mse_loss(values, td_target)  # already a mean

        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        actor_loss.backward()
        critic_loss.backward()
        self.actor_optimizer.step()
        self.critic_optimizer.step()


if __name__ == '__main__':
    # Smoke check: build a small policy network and print its module tree.
    demo_policy = PolicyNet(state_dim=16, hidden_dim=32, action_dim=3, num_layers=2)
    print(demo_policy)