import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import torch_scatter
from torch_geometric.nn import MessagePassing, GATConv
from torch_geometric.utils import dense_to_sparse

class GraphMemoryModule(nn.Module):
    """Graph memory module that builds an intent-oriented star graph for every agent.

    For each (batch, time step, agent) it produces ``num_nodes`` node embeddings,
    edge features and an ``edge_index``:

    * Node 0 is the agent itself.
    * Nodes 1.. are the slots for the other agents.
    * Edge features are encoded from relative 2-D positions.
    * ``edge_index`` links node 0 to every other node.
    """
    def __init__(self, obs_dim, node_dim=128, edge_dim=16, hidden_dim=256, time_attn_heads=4):
        super(GraphMemoryModule, self).__init__()

        self.obs_dim = obs_dim
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.hidden_dim = hidden_dim

        # Node feature encoder: obs_dim -> node_dim
        self.node_encoder = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, node_dim)
        )

        # Edge feature encoder: 2-D relative position -> edge_dim
        self.edge_encoder = nn.Sequential(
            nn.Linear(2, hidden_dim // 4),
            nn.ReLU(),
            nn.Linear(hidden_dim // 4, edge_dim)
        )

        # Temporal (causal) self-attention applied per node
        self.time_attention = nn.MultiheadAttention(
            embed_dim=node_dim,
            num_heads=time_attn_heads,
            batch_first=True
        )

        # Inter-agent attention projections
        self.agent_query_proj = nn.Linear(obs_dim, node_dim)
        self.agent_key_proj = nn.Linear(obs_dim, node_dim)
        self.agent_value_proj = nn.Linear(obs_dim, node_dim)

        # Merges [own encoding || attention context] back to node_dim
        self.combine_layer = nn.Sequential(
            nn.Linear(2 * node_dim, node_dim),
            nn.ReLU(),
            nn.Linear(node_dim, node_dim)
        )

        # Predicts a 2-D position from an observation when positions are not supplied
        self.position_predictor = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 2)
        )

    def pos_loss_compute(self, pred_positions, true_positions):
        """Return the mean-squared error between predicted and true positions."""
        return F.mse_loss(pred_positions, true_positions)

    def _time_position_encoding(self, time_steps, device):
        """Return a sinusoidal position encoding of shape [time_steps, node_dim]."""
        steps = torch.arange(time_steps, dtype=torch.long, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.node_dim, 2, device=device) * -(math.log(10000.0) / self.node_dim)
        )
        encoding = torch.zeros(time_steps, self.node_dim, device=device)
        encoding[:, 0::2] = torch.sin(steps * div_term)
        # For an odd node_dim there is one fewer cosine column than sine column.
        encoding[:, 1::2] = torch.cos(steps * div_term[: self.node_dim // 2])
        return encoding

    def forward(self, obs, num_nodes, positions=None, use_pos=True):
        """Build one star graph per (batch, time step, agent).

        Args:
            obs: agent observations, shape [batch_size, time_step, num_agents, obs_dim].
            num_nodes: number of nodes in each agent's graph. It must equal
                num_agents: one self node plus one node per other agent.
            positions: agent positions, shape [batch_size, time_step, num_agents, 2].
                They are predicted from ``obs`` instead when ``use_pos`` is False
                or when this is None.
            use_pos: whether to use the supplied ``positions``.

        Returns:
            A tuple ``(graph, positions)``.

            ``graph`` is a dict with the following keys:

            * ``"nodes"``: shape [B, T, A, num_nodes, node_dim].
            * ``"edges"``: shape [B, T, A, num_nodes-1, edge_dim].
            * ``"edge_index"``: long tensor of shape [B, T, A, num_nodes-1, 2].
              Column 0 is node 0 and column 1 is node 1..num_nodes-1.
            * ``"predicted_positions"``: the predicted positions, or None when
              the supplied positions were used.

            ``positions`` is the position tensor actually used.

        Raises:
            ValueError: if ``num_nodes != num_agents``.
        """
        batch_size, time_steps, num_agents, obs_dim = obs.shape
        device = obs.device

        if num_nodes != num_agents:
            raise ValueError(
                f"num_nodes ({num_nodes}) must equal num_agents ({num_agents}): "
                "each graph holds one self node plus one node per other agent"
            )

        # Decide this BEFORE `positions` is overwritten; otherwise a caller passing
        # positions=None with use_pos=True would get predicted_positions=None.
        predict_positions = not use_pos or positions is None
        if predict_positions:
            predicted = self.position_predictor(obs.reshape(-1, obs_dim))
            positions = predicted.reshape(batch_size, time_steps, num_agents, 2)

        # 1. Encode every agent observation: [B, T, A, node_dim].
        encoded_obs = self.node_encoder(obs.reshape(-1, obs_dim)).reshape(
            batch_size, time_steps, num_agents, self.node_dim
        )

        # 2. Inter-agent scaled dot-product attention. Each agent attends only to
        #    the others because the diagonal is masked out.
        flat_bt = batch_size * time_steps
        queries = self.agent_query_proj(obs).reshape(flat_bt, num_agents, self.node_dim)
        keys = self.agent_key_proj(obs).reshape(flat_bt, num_agents, self.node_dim)
        values = self.agent_value_proj(obs).reshape(flat_bt, num_agents, self.node_dim)

        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.node_dim ** 0.5)
        self_mask = torch.eye(num_agents, dtype=torch.bool, device=device)
        scores = scores.masked_fill(self_mask, -1e9)
        weighted_values = torch.bmm(F.softmax(scores, dim=2), values).reshape(
            batch_size, time_steps, num_agents, self.node_dim
        )

        # 3. Node 0 is the agent's own encoding. Nodes 1.. are
        #    combine(own encoding, own attention context).
        # NOTE(review): neither input depends on which other agent a node stands
        # for, so nodes 1..num_nodes-1 are identical copies (this reproduces the
        # previous per-agent loop). If each node should describe one specific other
        # agent, index encoded_obs with the off-diagonal mask instead. Confirm intent.
        context = self.combine_layer(torch.cat([encoded_obs, weighted_values], dim=-1))
        other_nodes = context.unsqueeze(3).expand(-1, -1, -1, num_agents - 1, -1)
        nodes = torch.cat([encoded_obs.unsqueeze(3), other_nodes], dim=3)

        # 4. Causal self-attention over time, independently for every node.
        node_sequences = nodes.permute(0, 2, 3, 1, 4).reshape(
            batch_size * num_agents * num_nodes, time_steps, self.node_dim
        )
        node_sequences = node_sequences + self._time_position_encoding(time_steps, device).unsqueeze(0)
        # Upper triangle (future steps) is -inf; the diagonal and below are 0.
        causal_mask = torch.triu(
            torch.full((time_steps, time_steps), float('-inf'), device=device), diagonal=1
        )
        attended_nodes, _ = self.time_attention(
            node_sequences, node_sequences, node_sequences,
            attn_mask=causal_mask,
            need_weights=False
        )
        attended_nodes = attended_nodes.reshape(
            batch_size, num_agents, num_nodes, time_steps, self.node_dim
        ).permute(0, 3, 1, 2, 4)
        nodes = nodes + attended_nodes

        # 5. Edge features from relative positions, rel[b, t, n, j] = pos_j - pos_n.
        #    The boolean off-diagonal mask keeps, for each agent n, the other agents
        #    in ascending index order, matching nodes 1..num_nodes-1.
        rel_positions = positions.unsqueeze(2) - positions.unsqueeze(3)
        rel_others = rel_positions[:, :, ~self_mask].reshape(-1, 2)
        edges = self.edge_encoder(rel_others).reshape(
            batch_size, time_steps, num_agents, num_nodes - 1, self.edge_dim
        )

        # 6. Star topology, identical for every graph: (0, k) for k = 1..num_nodes-1.
        targets = torch.arange(1, num_nodes, dtype=torch.long, device=device)
        star = torch.stack([torch.zeros_like(targets), targets], dim=-1)
        edge_index = star.expand(batch_size, time_steps, num_agents, -1, -1).contiguous()

        return {
            "nodes": nodes,
            "edges": edges,
            "edge_index": edge_index,
            "predicted_positions": positions if predict_positions else None
        }, positions


class GNNProcessor(nn.Module):
    """GAT-based processor for the star graphs produced by GraphMemoryModule.

    Each (batch, time step, agent) graph is read out at its central node 0 and
    projected to ``output_dim``.
    """
    def __init__(self, node_dim=128, edge_dim=16, output_dim=512, hidden_dim=256, num_gnn_layers=2, attention_heads=4):
        super(GNNProcessor, self).__init__()

        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.output_dim = output_dim
        self.num_gnn_layers = num_gnn_layers

        # Layers after the first expect hidden_dim edge features; raw edges are
        # projected with this layer before being fed to them.
        self.edge_proj = nn.Linear(edge_dim, hidden_dim)

        self.gnn_layers = nn.ModuleList()
        for i in range(num_gnn_layers):
            in_channels = node_dim if i == 0 else hidden_dim
            self.gnn_layers.append(
                GATConv(
                    in_channels=in_channels,
                    out_channels=hidden_dim // attention_heads,
                    heads=attention_heads,
                    concat=True,
                    edge_dim=edge_dim if i == 0 else hidden_dim
                )
            )

        self.output_proj = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, graph_data):
        """Run the GAT stack over every star graph and return intent features.

        Args:
            graph_data: dict with the following keys:

                * ``"nodes"``: shape [B, T, A, N, node_dim].
                * ``"edges"``: shape [B, T, A, N-1, edge_dim], or None.
                * ``"edge_index"``: shape [B, T, A, N-1, 2], holding local
                  (source, target) node ids within each graph.

        Returns:
            Tensor of shape [B, T, A, output_dim], read out at node 0 of each graph.
        """
        nodes = graph_data["nodes"]
        edges = graph_data["edges"]
        edge_index = graph_data["edge_index"]

        batch_size, time_steps, num_agents, num_nodes, node_dim = nodes.shape
        device = nodes.device

        if edges is None:
            # NOTE(review): output_proj expects hidden_dim inputs, so this fallback
            # only works when node_dim == hidden_dim. Confirm it is ever reached.
            avg_nodes = nodes.mean(dim=3)
            return self.output_proj(avg_nodes)

        num_graphs = batch_size * time_steps * num_agents
        num_edges = num_nodes - 1

        # Merge all graphs into one disjoint graph by offsetting local node ids.
        # GAT attention is normalized per target node, so no information crosses
        # graph boundaries; this is equivalent to processing the graphs one by one.
        x = nodes.reshape(num_graphs * num_nodes, node_dim)
        offsets = (torch.arange(num_graphs, device=device) * num_nodes).view(-1, 1, 1)
        global_index = (edge_index.reshape(num_graphs, num_edges, 2) + offsets).reshape(-1, 2).t()
        edge_attr = edges.reshape(num_graphs * num_edges, self.edge_dim)

        # The star edges point from node 0 to the others (PyG aggregates at the
        # target). Node 0 is the readout node, so without the reversed edges it
        # would never receive neighbor messages.
        global_index = torch.cat([global_index, global_index.flip(0)], dim=1)
        edge_attr = torch.cat([edge_attr, edge_attr], dim=0)

        # Project the edge features once for all layers built with edge_dim=hidden_dim.
        projected_attr = self.edge_proj(edge_attr) if self.num_gnn_layers > 1 else None

        for i, gnn_layer in enumerate(self.gnn_layers):
            layer_attr = edge_attr if i == 0 else projected_attr
            x = F.relu(gnn_layer(x, global_index, layer_attr))

        central_features = x.reshape(num_graphs, num_nodes, -1)[:, 0]
        output = self.output_proj(central_features)
        return output.reshape(batch_size, time_steps, num_agents, self.output_dim)


class PyGGNNProcessor(nn.Module):
    """Batched GNN processor that runs CustomGNNLayer over all star graphs at once."""
    def __init__(self, node_dim=128, edge_dim=16, output_dim=512, hidden_dim=256, num_gnn_layers=2, attention_heads=4):
        super(PyGGNNProcessor, self).__init__()

        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.output_dim = output_dim
        self.num_gnn_layers = num_gnn_layers

        # Layers after the first expect hidden_dim edge features.
        self.edge_proj = nn.Linear(edge_dim, hidden_dim)

        self.gnn_layers = nn.ModuleList()
        for i in range(num_gnn_layers):
            in_channels = node_dim if i == 0 else hidden_dim
            self.gnn_layers.append(
                CustomGNNLayer(
                    in_channels=in_channels,
                    out_channels=hidden_dim,
                    edge_dim=edge_dim if i == 0 else hidden_dim
                )
            )

        self.output_proj = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, graph_data):
        """Run the GNN stack over every star graph and return intent features.

        Args:
            graph_data: dict with the following keys:

                * ``"nodes"``: shape [B, T, A, N, node_dim].
                * ``"edges"``: shape [B, T, A, N-1, edge_dim], or None.
                * ``"edge_index"``: shape [B, T, A, N-1, 2], holding local
                  (source, target) node ids within each graph.

        Returns:
            Tensor of shape [B, T, A, output_dim], read out at node 0 of each graph.
        """
        nodes = graph_data["nodes"]
        edges = graph_data["edges"]
        edge_index = graph_data["edge_index"]

        batch_size, time_steps, num_agents, num_nodes, node_dim = nodes.shape
        device = nodes.device

        if edges is None:
            # NOTE(review): output_proj expects hidden_dim inputs, so this fallback
            # only works when node_dim == hidden_dim. Confirm it is ever reached.
            avg_nodes = nodes.mean(dim=3)
            return self.output_proj(avg_nodes)

        num_graphs = batch_size * time_steps * num_agents
        num_edges = num_nodes - 1

        x = nodes.reshape(num_graphs * num_nodes, node_dim)
        batch_idx = torch.arange(num_graphs, device=device).repeat_interleave(num_nodes)

        # Shift every graph's local node ids by its offset (vectorized instead of a
        # per-graph Python loop) so that all graphs form one disjoint graph.
        offsets = (torch.arange(num_graphs, device=device) * num_nodes).view(-1, 1, 1)
        edge_index_tensor = (edge_index.reshape(num_graphs, num_edges, 2) + offsets).reshape(-1, 2).t()
        edge_attr_tensor = edges.reshape(num_graphs * num_edges, self.edge_dim)

        # The star edges point from node 0 to the others, and messages are
        # aggregated at the target. Node 0 is the readout node, so add the
        # reversed edges or it never receives neighbor messages.
        edge_index_tensor = torch.cat([edge_index_tensor, edge_index_tensor.flip(0)], dim=1)
        edge_attr_tensor = torch.cat([edge_attr_tensor, edge_attr_tensor], dim=0)

        # Project once. Re-applying edge_proj at every layer would feed
        # hidden_dim features into Linear(edge_dim, ...) from the third layer on.
        projected_attr = self.edge_proj(edge_attr_tensor) if self.num_gnn_layers > 1 else None

        for i, gnn_layer in enumerate(self.gnn_layers):
            layer_attr = edge_attr_tensor if i == 0 else projected_attr
            x = F.relu(gnn_layer(x, edge_index_tensor, layer_attr, batch_idx))

        central_features = x.reshape(num_graphs, num_nodes, -1)[:, 0]
        output = self.output_proj(central_features)
        return output.reshape(batch_size, time_steps, num_agents, self.output_dim)


class CustomGNNLayer(MessagePassing):
    """Edge-conditioned message-passing layer.

    The message along edge j -> i is ``message_mlp([x_j || edge_attr])``.
    Messages are reduced with ``aggr`` (sum by default), and node i is then
    updated to ``update_mlp([x_i || aggregated])``.
    """

    def __init__(self, in_channels, out_channels, edge_dim, aggr='add'):
        super().__init__(aggr=aggr)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edge_dim = edge_dim

        self.message_mlp = self._two_layer_mlp(in_channels + edge_dim, out_channels)
        self.update_mlp = self._two_layer_mlp(in_channels + out_channels, out_channels)

    @staticmethod
    def _two_layer_mlp(in_features, out_features):
        """Build Linear -> ReLU -> Linear, both layers projecting to ``out_features``."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.ReLU(),
            nn.Linear(out_features, out_features)
        )

    def forward(self, x, edge_index, edge_attr, batch=None, size=None):
        # `batch` is accepted so callers can pass graph ids, but it is not used.
        # propagate() does not take a bare int size, so that case is replaced by
        # None and the size is inferred from the inputs.
        size = None if isinstance(size, int) else size
        return self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

    def message(self, x_j, edge_attr):
        # Parameter names are significant: PyG maps x_j / edge_attr by name.
        return self.message_mlp(torch.cat((x_j, edge_attr), dim=-1))

    def update(self, aggr_out, x):
        return self.update_mlp(torch.cat((x, aggr_out), dim=-1))
