import os
from pyexpat import features

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical


class PolicyNet(torch.nn.Module):
    ''' Policy (actor) network: an MLP mapping states to action probabilities.

    With the default arguments the architecture is
    Linear(42, 128) -> ReLU -> Linear(128, 64) -> ReLU -> Linear(64, 19) -> softmax
    (two hidden layers). Layers are stored in ``actor_fc`` in the same order as
    before, so existing ``state_dict`` checkpoints keep loading.

    Args:
        state_dim: size of the input state vector.
        action_dim: number of discrete actions (size of the output distribution).
        hidden_dims: widths of the hidden layers, each followed by a ReLU.
    '''

    def __init__(self, state_dim=42, action_dim=19, hidden_dims=(128, 64)):
        super(PolicyNet, self).__init__()
        layers = []
        in_dim = state_dim
        for width in hidden_dims:
            layers += [nn.Linear(in_dim, width), nn.ReLU()]
            in_dim = width
        layers.append(nn.Linear(in_dim, action_dim))
        self.actor_fc = nn.Sequential(*layers)

    def forward(self, state):
        '''Return action probabilities for ``state``.

        Args:
            state: tensor, array or sequence of shape (state_dim,) or
                (batch, state_dim). Non-tensor input is converted to a float32
                tensor on the same device as this network's parameters.

        Returns:
            Tensor of shape (batch, action_dim) whose rows sum to 1. A 1-D
            input is treated as a batch of one.
        '''
        if not isinstance(state, torch.Tensor):
            # Build the tensor where the weights live so a CUDA-resident
            # network can be called directly with numpy arrays / lists.
            device = next(self.parameters()).device
            state = torch.as_tensor(state, dtype=torch.float32, device=device)
        if state.dim() == 1:
            state = state.unsqueeze(0)

        logits = self.actor_fc(state)
        return F.softmax(logits, dim=-1)


class VNet(torch.nn.Module):
    ''' Value (critic) network: an MLP mapping states to a scalar state value.

    With the default arguments the architecture is
    Linear(42, 128) -> ReLU -> Linear(128, 64) -> ReLU -> Linear(64, 1)
    (two hidden layers). Layers are stored in ``critic_fc`` in the same order
    as before, so existing ``state_dict`` checkpoints keep loading.

    Args:
        state_dim: size of the input state vector.
        hidden_dims: widths of the hidden layers, each followed by a ReLU.
    '''

    def __init__(self, state_dim=42, hidden_dims=(128, 64)):
        super(VNet, self).__init__()
        layers = []
        in_dim = state_dim
        for width in hidden_dims:
            layers += [nn.Linear(in_dim, width), nn.ReLU()]
            in_dim = width
        layers.append(nn.Linear(in_dim, 1))
        self.critic_fc = nn.Sequential(*layers)

    def forward(self, state):
        '''Return the estimated value of ``state``.

        Args:
            state: tensor, array or sequence of shape (state_dim,) or
                (batch, state_dim). Non-tensor input is converted to a float32
                tensor on the same device as this network's parameters.

        Returns:
            Tensor of shape (batch, 1). A 1-D input is treated as a batch of one.
        '''
        if not isinstance(state, torch.Tensor):
            # Build the tensor where the weights live so a CUDA-resident
            # network can be called directly with numpy arrays / lists.
            device = next(self.parameters()).device
            state = torch.as_tensor(state, dtype=torch.float32, device=device)
        if state.dim() == 1:
            state = state.unsqueeze(0)

        return self.critic_fc(state)


class Distributed_Agent(torch.nn.Module):
    '''PPO (clipped surrogate objective) agent with separate actor and critic.

    Network inputs are the column-wise (dim 1) concatenation of a per-row state
    tensor and an extra feature tensor.

    Args:
        dic_path: mapping that must contain ``"PATH_TO_MODEL"``, the directory
            used by ``save_network`` / ``load_network``.
    '''

    def __init__(self, dic_path):
        super().__init__()
        self.dic_path = dic_path

        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.gamma = 0.98     # discount factor
        self.lmbda = 0.95     # GAE lambda
        self.epochs = 10      # optimisation passes over each batch of transitions
        self.eps = 0.2        # PPO clipping range for the probability ratio
        self.batch_size = 20  # minibatch size used by update()

        self.actor = PolicyNet().to(self.device)
        self.critic = VNet().to(self.device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=0.001)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=0.001)

    def choose_action_with_features(self, x_state, x_freq):
        '''Sample one action per row of the combined state.

        Args:
            x_state: array-like of states, converted to a float32 tensor on
                ``self.device``; must be 2-D for the dim-1 concatenation.
            x_freq: tensor of extra features; it is moved to ``self.device``,
                repeated 3 times along dim 0 and concatenated to ``x_state``
                along dim 1.

        Returns:
            ``(action_list, attention_input)``: the sampled action indices as a
            Python list, and a tensor concatenating the action probabilities
            with the critic's state value along the last dim.
        '''
        x_state = torch.as_tensor(x_state, dtype=torch.float32, device=self.device)
        # NOTE(review): the repeat count 3 presumably equals the number of rows
        # (agents) in x_state -- confirm against callers.
        x_freq = x_freq.to(self.device).repeat(3, 1)

        with torch.no_grad():
            state = torch.cat([x_state, x_freq], dim=1)
            action_probs = self.actor(state)
            state_value = self.critic(state)
            dist = Categorical(action_probs)
            action = dist.sample()
            action_list = action.cpu().numpy().tolist()

            attention_input = torch.cat((action_probs, state_value), dim=-1)

        return action_list, attention_input

    def compute_advantage(self, gamma, lmbda, td_delta):
        ''' Generalised Advantage Estimation (GAE).

        Scans ``td_delta`` backwards along dim 0 computing
        ``A_t = delta_t + gamma * lmbda * A_{t+1}``.

        NOTE(review): ``dones`` do not reset the running advantage here, and all
        rows are treated as one time-ordered sequence -- confirm this matches
        how transitions are collected.

        Args:
            gamma: discount factor.
            lmbda: GAE lambda.
            td_delta: tensor of TD errors (any device), time along dim 0.

        Returns:
            float32 CPU tensor of advantages, shaped like ``td_delta``.
        '''
        # .cpu() so that a CUDA tensor can be converted to numpy.
        deltas = td_delta.detach().cpu().numpy()
        advantage_list = []
        advantage = 0.0
        for delta in deltas[::-1]:
            advantage = gamma * lmbda * advantage + delta
            advantage_list.append(advantage)
        advantage_list.reverse()
        return torch.tensor(np.array(advantage_list), dtype=torch.float)

    @staticmethod
    def _minibatch_size(batch_size, n_samples):
        '''Return ``batch_size`` capped at ``n_samples`` and never below 1.'''
        return max(1, min(batch_size, n_samples))

    def update(self, transition_dict):
        '''Run ``self.epochs`` passes of minibatch PPO updates on one batch.

        Args:
            transition_dict: mapping with keys ``'states'``, ``'features'``
                (a list of tensors concatenated along dim 0), ``'actions'``,
                ``'rewards'``, ``'next_states'`` and ``'dones'``.
        '''
        x_states = torch.tensor(np.array(transition_dict['states']), dtype=torch.float).to(self.device)
        features = torch.cat(transition_dict['features'], dim=0).to(self.device)
        states = torch.cat((x_states, features), dim=1)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
        x_next_states = torch.tensor(np.array(transition_dict['next_states']), dtype=torch.float).to(self.device)
        # NOTE(review): next states reuse the current transitions' features.
        next_states = torch.cat((x_next_states, features), dim=1)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device)

        # TD targets, advantages and the old policy's log-probabilities are
        # fixed before optimisation starts, so no autograd graph is needed.
        with torch.no_grad():
            td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
            td_delta = td_target - self.critic(states)
            old_log_probs = torch.log(self.actor(states).gather(1, actions))
        advantage = self.compute_advantage(self.gamma, self.lmbda, td_delta).to(self.device)

        n_samples = states.size(0)
        # Previously the condition was inverted, which always produced a single
        # full batch; cap the configured minibatch size at the sample count.
        batch_size = self._minibatch_size(self.batch_size, n_samples)
        beta = 0.01  # entropy coefficient

        for _ in range(self.epochs):
            perm = torch.randperm(n_samples).to(self.device)

            for start in range(0, n_samples, batch_size):
                batch_indices = perm[start:start + batch_size]

                batch_states = states[batch_indices]
                batch_actions = actions[batch_indices]
                batch_old_log_probs = old_log_probs[batch_indices]
                batch_advantage = advantage[batch_indices]
                batch_td_target = td_target[batch_indices]

                # One actor forward pass serves both the ratio and the entropy.
                action_probs = self.actor(batch_states)
                new_log_probs = torch.log(action_probs.gather(1, batch_actions))

                # Clipped surrogate objective.
                ratio = torch.exp(new_log_probs - batch_old_log_probs)
                surr1 = ratio * batch_advantage
                surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * batch_advantage

                # Entropy bonus; 1e-10 avoids log(0).
                entropy = torch.sum(-action_probs * torch.log(action_probs + 1e-10), dim=1).mean()

                actor_loss = -torch.min(surr1, surr2).mean() - beta * entropy
                critic_loss = F.mse_loss(self.critic(batch_states), batch_td_target)

                # Actor and critic are separate networks, so their graphs are
                # independent and both losses can be backpropagated in turn.
                self.actor_optimizer.zero_grad()
                self.critic_optimizer.zero_grad()
                actor_loss.backward()
                critic_loss.backward()
                self.actor_optimizer.step()
                self.critic_optimizer.step()

    def save_network(self, file_name):
        '''Save actor and critic state dicts as ``actor_<file_name>.pth`` and
        ``critic_<file_name>.pth`` under ``dic_path["PATH_TO_MODEL"]``.'''
        model_dir = self.dic_path["PATH_TO_MODEL"]
        torch.save(self.actor.state_dict(),
                   os.path.join(model_dir, f"actor_{file_name}.pth"))
        torch.save(self.critic.state_dict(),
                   os.path.join(model_dir, f"critic_{file_name}.pth"))
        print(f"Successfully saved model {file_name}")

    def load_network(self, file_name):
        '''Load actor and critic state dicts written by ``save_network``.

        ``map_location`` remaps stored tensors onto ``self.device`` so that
        checkpoints saved on a GPU can be loaded on a CPU-only machine.
        '''
        model_dir = self.dic_path["PATH_TO_MODEL"]
        self.actor.load_state_dict(
            torch.load(os.path.join(model_dir, f"actor_{file_name}.pth"),
                       map_location=self.device))
        self.critic.load_state_dict(
            torch.load(os.path.join(model_dir, f"critic_{file_name}.pth"),
                       map_location=self.device))
        print(f"Successfully loaded model {file_name}")














