import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os
import pickle
from torch.distributions import Categorical
import time


def get_components(step):
    """Map a step index to a fixed-length (19-entry) list of candidate history lengths.

    The list is left-padded with zeros and ends with the powers of two
    1, 2, 4, ..., 2**k, where k = floor(log2(step)) - 1 clamped to [0, 6].
    When step + 1 <= 3 every entry is 0.
    """
    n_slots = 19
    elapsed = step + 1
    if elapsed > 3:
        top_exponent = int(np.floor(np.log2(elapsed - 1))) - 1
        top_exponent = min(6, max(0, top_exponent))
        powers = [1 << k for k in range(top_exponent + 1)]
    else:
        powers = []
    return [0] * (n_slots - len(powers)) + powers


def _load_pickled_states(history_path):
    """Return the ``'state'`` entry of every pickled record in ``history_path``, in file order."""
    # SECURITY: pickle.load can execute arbitrary code; only read history files
    # produced by this program, never untrusted input.
    states = []
    with open(history_path, 'rb') as f:
        while True:
            try:
                record = pickle.load(f)
            except EOFError:
                break
            states.append(record['state'])
    return states


def get_frequency_features(history_path, step=None):
    """Load a pickled state history and return its per-agent FFT magnitude spectrum.

    Args:
        history_path: file of consecutive pickled dicts, each with a ``'state'``
            entry holding one value per agent.
        step: step index passed to ``get_components``. If ``None`` it is inferred
            as ``len(history) - 1`` (assumes one record per elapsed step — TODO confirm).

    Returns:
        dict with
          * ``'features'``: array of shape (3, 64, 26) filled with 1e-10, whose
            first ``max(components)`` frequency bins (clamped to the number of
            available bins) are overwritten with the spectrum;
          * ``'components'``: 19-entry list from ``get_components`` (all zeros
            when the history is empty);
          * ``'combined_features'``: the full magnitude spectrum stacked over
            agents (the padded array itself when the history is empty).
    """
    states_by_time = _load_pickled_states(history_path)

    num_agents = 3
    max_freq_steps = 64
    epsilon = 1e-10
    # Small non-zero floor instead of exact zeros for the padding.
    padded_features = np.zeros((num_agents, max_freq_steps, 26)) + epsilon

    if not states_by_time:
        return {
            'features': padded_features,
            'components': [0] * 19,
            'combined_features': padded_features,
        }

    if step is None:
        step = len(states_by_time) - 1

    # Transpose time-major records into one time series per agent, then take the
    # FFT magnitude along the time axis (axis 0) of each agent's series.
    states_by_agent = list(map(list, zip(*states_by_time)))
    freq_features = [
        np.abs(np.fft.fft(np.array(agent_states), axis=0))
        for agent_states in states_by_agent
    ]
    combined_features = np.stack(freq_features, axis=0)

    components = get_components(step)
    # Clamp to the number of bins actually available: a history shorter than
    # max(components) would otherwise fail the slice assignment with a shape error.
    n_components = min(max(components), combined_features.shape[1], max_freq_steps)
    padded_features[:, :n_components, :] = combined_features[:, :n_components, :]

    return {
        'features': padded_features,
        'components': components,
        'combined_features': combined_features,
    }


def compute_weighted_reward(agent_inputs, central_input, multihead_attn=None):
    """Return multi-head attention weights of ``central_input`` over ``agent_inputs``.

    Args:
        agent_inputs: tensor, presumably (num_agents, 20) — TODO confirm. It is
            unsqueezed to (num_agents, 1, 20), i.e. (S, N, E) for the keys/values,
            and moved to ``central_input``'s device.
        central_input: tensor of shape (1, 20); unsqueezed to (1, 1, 20) as the query.
        multihead_attn: optional ``nn.MultiheadAttention`` (embed_dim=20,
            num_heads=4, default sequence-first layout) placed on
            ``central_input``'s device. Pass one to get reproducible weights.
            When omitted, a freshly and randomly initialised module is built on
            every call, so the weights are random and differ between calls.

    Returns:
        numpy array of attention weights with singleton dims squeezed
        (shape (num_agents,) for the inputs described above).
    """
    embed_dim = 20
    num_heads = 4

    device = central_input.device

    keys = agent_inputs.to(device).unsqueeze(1)
    # central_input has shape (1, 20) -> query (1, 1, 20)
    query = central_input.unsqueeze(0)

    if multihead_attn is None:
        multihead_attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads).to(device)

    # Only the weights are returned, so no autograd graph is needed.
    with torch.no_grad():
        _, weights = multihead_attn(query, keys, keys)

    return weights.squeeze().cpu().numpy()


class PolicyNet(torch.nn.Module):
    """Policy network: a 4-layer MLP mapping flattened features to action probabilities.

    Each sample is flattened to 4992 values (e.g. a 3 x 64 x 26 feature array)
    and mapped to a softmax distribution over 19 actions.
    """

    def __init__(self):
        super(PolicyNet, self).__init__()
        self.actor_fc = nn.Sequential(
            nn.Linear(4992, 1024),
            nn.ReLU(),
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 19),
        )

    def forward(self, state):
        """Return action probabilities of shape (batch, 19).

        ``state`` may be a tensor or array-like. A 1-D input is treated as one
        already-flattened sample. Otherwise dim 0 is the batch and all remaining
        dims are flattened; they must total 4992.
        """
        if not isinstance(state, torch.Tensor):
            # Build the tensor on the model's device to avoid a CPU/GPU mismatch.
            state = torch.as_tensor(state, dtype=torch.float32,
                                    device=self.actor_fc[0].weight.device)
        if state.dim() == 1:
            state = state.unsqueeze(0)

        x = state.reshape(state.size(0), -1)
        logits = self.actor_fc(x)
        return F.softmax(logits, dim=-1)


class VNet(torch.nn.Module):
    """Value network: a 4-layer MLP mapping flattened features to a scalar state value.

    Each sample is flattened to 4992 values (e.g. a 3 x 64 x 26 feature array).
    """

    def __init__(self):
        super(VNet, self).__init__()
        self.critic_fc = nn.Sequential(
            nn.Linear(4992, 1024),
            nn.ReLU(),
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, state):
        """Return state values of shape (batch, 1).

        ``state`` may be a tensor or array-like. A 1-D input is treated as one
        already-flattened sample. Otherwise dim 0 is the batch and all remaining
        dims are flattened; they must total 4992.
        """
        if not isinstance(state, torch.Tensor):
            # Build the tensor on the model's device to avoid a CPU/GPU mismatch.
            state = torch.as_tensor(state, dtype=torch.float32,
                                    device=self.critic_fc[0].weight.device)
        if state.dim() == 1:
            state = state.unsqueeze(0)

        x = state.reshape(state.size(0), -1)
        return self.critic_fc(x)


class FrequencyFeatureProcessor:
    """Embed flattened frequency-domain features into 16-dim vectors with a fixed MLP.

    NOTE(review): this is a plain class, not an ``nn.Module``, so ``self.fc`` is not
    exposed through ``parameters()``/``state_dict()``. CentralAgent only builds
    optimizers for its actor and critic, so these weights appear to stay at their
    random initialisation — confirm this is intended.
    """

    def __init__(self):
        super(FrequencyFeatureProcessor, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Fully connected layers over the flattened frequency features (4992 inputs).
        self.fc = nn.Sequential(
            nn.Linear(4992, 1024),
            nn.ReLU(),
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
        ).to(self.device)

    def process_freq_features(self, freq_features):
        """Flatten ``freq_features`` per sample and return embeddings of shape (N, 16).

        Accepted ranks:
          * 5-D (batch, extra, agents, freq_steps, feat): the first two dims are
            merged, so N = batch * extra;
          * 4-D (batch, agents, freq_steps, feat) or 3-D (batch, freq_steps, feat):
            N = batch.
        The flattened size per sample must be 4992 (e.g. 3 * 64 * 26).
        Inputs (arrays or tensors) are cast to float32 and moved to ``self.device``.

        Raises:
            ValueError: for any other rank.
        """
        freq_features = torch.as_tensor(freq_features, dtype=torch.float32, device=self.device)

        rank = freq_features.dim()
        if rank == 5:
            n_rows = freq_features.shape[0] * freq_features.shape[1]
        elif rank in (3, 4):
            n_rows = freq_features.shape[0]
        else:
            raise ValueError(
                f"expected a 3-D, 4-D or 5-D input, got shape {tuple(freq_features.shape)}")

        return self.fc(freq_features.reshape(n_rows, -1))


class CentralAgent:
    """PPO agent that chooses how much frequency-domain history to expose.

    ``choose_history_length`` samples a discrete action from ``self.actor``. The
    action indexes the ``components`` list returned by ``get_frequency_features``
    to pick how many leading frequency bins are kept before they are embedded by
    ``self.freq_processor``. ``update`` trains the actor and critic with the PPO
    clipped objective, GAE advantages and an entropy bonus.
    """

    def __init__(self, dic_agent_conf, dic_path, cnt_round):
        self.dic_path = dic_path
        self.dic_agent_conf = dic_agent_conf
        self.cnt_round = cnt_round
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.gamma = 0.98  # discount factor
        self.lmbda = 0.95  # GAE lambda
        self.epochs = 10  # passes over the data per update() call
        self.eps = 0.2  # PPO clipping range
        self.batch_size = 20  # mini-batch size used by update()

        self.actor = PolicyNet().to(self.device)
        self.critic = VNet().to(self.device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=0.001)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=0.001)
        self.freq_processor = FrequencyFeatureProcessor()

    def choose_history_length(self, history_path, current_step):
        """Sample a history-length action and build the matching embedded features.

        Returns:
            processed_features: ``self.freq_processor`` output for the truncated spectrum.
            action: the sampled action tensor (one element).
            query: actor probabilities concatenated with the critic value along the last dim.
            features: the padded spectrum from ``get_frequency_features`` (the PPO state).
        """
        result = get_frequency_features(history_path, current_step)
        freq_tensor = torch.FloatTensor(result['features']).unsqueeze(0).to(self.device)

        with torch.no_grad():
            action_probs = self.actor(freq_tensor)
            state_values = self.critic(freq_tensor)
            action = Categorical(action_probs).sample()

            max_freq_steps = 64
            epsilon = 1e-10
            num_agents = 3
            combined = result['combined_features']
            # Keep only the first `actual_steps` frequency bins chosen by the
            # action; the rest stays at the same epsilon floor used for padding.
            actual_steps = min(combined.shape[1], result['components'][int(action.item())])
            selected_features = np.zeros((num_agents, max_freq_steps, 26)) + epsilon
            selected_features[:, :actual_steps, :] = combined[:, :actual_steps, :]
            processed_features = self.freq_processor.process_freq_features(
                np.expand_dims(selected_features, axis=0))

            query = torch.cat([action_probs, state_values], dim=-1)

        return processed_features, action, query, result['features']

    def compute_advantage(self, gamma, lmbda, td_delta):
        """Generalized Advantage Estimation (GAE) over a sequence of TD errors.

        Returns a float tensor with the same shape as ``td_delta``.
        NOTE(review): the recursion does not reset at episode boundaries; if a
        batch spans several episodes, later episodes leak into earlier advantages.
        """
        td_delta = td_delta.detach().cpu().numpy()
        advantage_list = []
        advantage = 0.0
        for delta in td_delta[::-1]:
            advantage = gamma * lmbda * advantage + delta
            advantage_list.append(advantage)
        advantage_list.reverse()
        return torch.tensor(np.array(advantage_list), dtype=torch.float)

    def update(self, transition_dict):
        """Run ``self.epochs`` PPO passes in mini-batches over one batch of transitions.

        ``transition_dict`` must provide equal-length sequences under 'states',
        'actions', 'rewards', 'next_states' and 'dones'. An empty batch is a no-op.
        """
        if len(transition_dict['states']) == 0:
            return

        states = torch.tensor(np.array(transition_dict['states']), dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(np.array(transition_dict['next_states']), dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device)

        # Targets, advantages and the old policy are fixed before any update, so
        # no autograd graph is needed for them.
        with torch.no_grad():
            td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
            td_delta = td_target - self.critic(states)
            advantage = self.compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device)
            old_log_probs = torch.log(self.actor(states).gather(1, actions))

        n_samples = states.size(0)
        # Mini-batch size capped at the number of samples. (The previous
        # expression selected n_samples whenever n_samples >= self.batch_size,
        # so updates always ran on a single full batch.)
        batch_size = min(self.batch_size, n_samples)
        beta = 0.01  # entropy coefficient

        for _ in range(self.epochs):
            perm = torch.randperm(n_samples).to(self.device)

            for start in range(0, n_samples, batch_size):
                batch_indices = perm[start:start + batch_size]

                batch_states = states[batch_indices]
                batch_actions = actions[batch_indices]
                batch_old_log_probs = old_log_probs[batch_indices]
                batch_advantage = advantage[batch_indices]
                batch_td_target = td_target[batch_indices]

                # One actor pass feeds both the ratio and the entropy bonus.
                action_probs = self.actor(batch_states)
                new_log_probs = torch.log(action_probs.gather(1, batch_actions))

                ratio = torch.exp(new_log_probs - batch_old_log_probs)
                surr1 = ratio * batch_advantage
                surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * batch_advantage

                entropy = torch.sum(-action_probs * torch.log(action_probs + 1e-10), dim=1).mean()

                actor_loss = -torch.min(surr1, surr2).mean() - beta * entropy
                critic_loss = F.mse_loss(self.critic(batch_states), batch_td_target)

                self.actor_optimizer.zero_grad()
                self.critic_optimizer.zero_grad()
                actor_loss.backward()
                critic_loss.backward()
                self.actor_optimizer.step()
                self.critic_optimizer.step()

    def save_central_network(self, file_name):
        """Save actor weights to '<file_name>.pth' and critic weights to 'critic_<file_name>.pth'."""
        torch.save(self.actor.state_dict(),
                   os.path.join(self.dic_path["PATH_TO_MODEL"], f"{file_name}.pth"))
        torch.save(self.critic.state_dict(),
                   os.path.join(self.dic_path["PATH_TO_MODEL"], f"critic_{file_name}.pth"))
        print(f"Successfully saved model {file_name}")

    def load_central_network(self, file_name):
        """Load actor/critic weights from ``dic_path["PATH_TO_MODEL"]``.

        The actor is read from 'actor_<file_name>.pth' when that file exists,
        otherwise from '<file_name>.pth', the name ``save_central_network`` writes.
        Tensors are mapped onto ``self.device``.
        """
        model_dir = self.dic_path["PATH_TO_MODEL"]
        actor_path = os.path.join(model_dir, f"actor_{file_name}.pth")
        if not os.path.exists(actor_path):
            actor_path = os.path.join(model_dir, f"{file_name}.pth")
        self.actor.load_state_dict(torch.load(actor_path, map_location=self.device))
        self.critic.load_state_dict(
            torch.load(os.path.join(model_dir, f"critic_{file_name}.pth"), map_location=self.device))
        print(f"Successfully loaded model {file_name}")


