import torch
import torch.nn as nn
import torch.nn.functional as F
import collections
import numpy as np


class AgentDetectionNetwork(nn.Module):
    """Network to detect if there are other agents in the observation.

    A three-layer MLP whose last layer is a sigmoid, so every input row is
    mapped to a single value between 0 and 1.

    Args:
        observation_shape: Either the input width itself, or a tuple/list
            whose first entry is the input width.
        hidden_dim: Width of the two hidden layers.
    """
    def __init__(self, observation_shape, hidden_dim=256):
        super(AgentDetectionNetwork, self).__init__()

        # Shapes loaded from configs frequently arrive as lists, so lists are
        # unwrapped the same way tuples are.
        if isinstance(observation_shape, (tuple, list)):
            input_dim = observation_shape[0]
        else:
            input_dim = observation_shape

        self.detector = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return the detection probability for ``x``.

        The linear stack acts on the last axis of ``x``, which must therefore
        have length ``input_dim``.
        """
        # 4-D inputs get their last axis moved to position 1 before the MLP.
        # NOTE(review): this looks like a leftover from an image/conv variant;
        # confirm that 4-D inputs are still expected here.
        if x.dim() == 4:
            x = x.permute(0, 3, 1, 2)

        return self.detector(x)  # Output is the probability of containing other agents


class AgentEncodingNetwork(nn.Module):
    """Network to encode observation into a feature vector describing other agents.

    A three-layer MLP with no output activation.

    Args:
        observation_shape: Either the input width itself, or a tuple/list
            whose first entry is the input width.
        encoding_dim: Width of the produced encoding.
        hidden_dim: Width of the two hidden layers. Defaults to 256, the
            value that used to be hard-coded.
    """
    def __init__(self, observation_shape, encoding_dim=16, hidden_dim=256):
        super(AgentEncodingNetwork, self).__init__()

        self.encoding_dim = encoding_dim

        # Shapes loaded from configs frequently arrive as lists, so lists are
        # unwrapped the same way tuples are.
        if isinstance(observation_shape, (tuple, list)):
            input_dim = observation_shape[0]
        else:
            input_dim = observation_shape

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, encoding_dim)
        )

    def forward(self, x):
        """Encode ``x``; the last axis of ``x`` must have length ``input_dim``."""
        return self.encoder(x)  # Output is the encoded vector


class TransitionCalibrationNetwork(nn.Module):
    """Shift a transition ``log_sigma`` by an intention-dependent correction.

    The correction is scaled by ``1 - exp(-||weighted_intention||)``: it is
    exactly zero for a zero intention vector (the input ``log_sigma`` is
    returned unchanged) and grows towards full strength as the norm grows.
    """
    def __init__(self, sigma_dim=256, intention_dim=16, hidden_dim=64):
        super().__init__()

        def build_head(output_activation):
            # Both heads share one layout: intention_dim -> hidden_dim -> sigma_dim,
            # differing only in the final squashing function.
            return nn.Sequential(
                nn.Linear(intention_dim, hidden_dim),
                nn.LeakyReLU(),
                nn.Linear(hidden_dim, sigma_dim),
                output_activation,
            )

        # Per-dimension weight in (0, 1) that gates the correction.
        self.intention_projector = build_head(nn.Sigmoid())
        # Per-dimension signed correction in (-1, 1).
        self.modulation_network = build_head(nn.Tanh())

    def forward(self, transition_log_sigma, weighted_intention):
        """
        Args:
            transition_log_sigma: log_sigma of the transition model
                [batch, time_steps, sigma_dim]
            weighted_intention: weighted intention feature
                [batch, time_steps, intention_dim]

        Returns:
            calibrated_log_sigma: calibrated log_sigma
                [batch, time_steps, sigma_dim]
        """
        gate = self.intention_projector(weighted_intention)
        shift = self.modulation_network(weighted_intention)

        # Influence strength: 0 when the intention vector is all zeros,
        # approaching 1 as its L2 norm grows.
        magnitude = weighted_intention.norm(dim=-1, keepdim=True)
        blend = 1.0 - torch.exp(-magnitude)

        # With blend == 0 the sum below reduces to transition_log_sigma itself.
        return transition_log_sigma + blend * gate * shift


class AgentMemoryModule(nn.Module):
    """Other agent memory module, containing detection, encoding, and memory storage functions.

    Owns the detection, encoding and calibration sub-networks and, in
    ``forward``, computes a per-agent auxiliary loss from already-computed
    intention features and detection probabilities.

    Args:
        observation_shape: Forwarded to the detection and encoding networks.
        detection_threshold: Detection probabilities strictly above this value
            enable the intention loss for that agent.
        encoding_dim: Width of the encoding / intention feature.
        hidden_dim: Hidden width of the detection and calibration networks.
        device: Device the sub-networks are moved to; CUDA when available,
            otherwise CPU, if left as ``None``.
        distance_threshold: Distance used to build the detection label in
            ``forward``. Defaults to 0.02, the value that used to be
            hard-coded.
    """

    # Class-level fallback so instances restored without running __init__
    # (e.g. objects pickled before this attribute existed) keep the old value.
    distance_threshold = 0.02

    def __init__(self, observation_shape, detection_threshold=0.5,
                 encoding_dim=256, hidden_dim=256, device=None,
                 distance_threshold=0.02):
        super(AgentMemoryModule, self).__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.observation_shape = observation_shape
        self.detection_threshold = detection_threshold
        self.distance_threshold = distance_threshold

        # Detection and encoding networks
        self.detection_network = AgentDetectionNetwork(observation_shape, hidden_dim).to(self.device)
        self.encoding_network = AgentEncodingNetwork(observation_shape, encoding_dim=encoding_dim).to(self.device)

        # Calibration network: maps (encoding + one extra scalar) back to an encoding
        self.calibration_network = nn.Sequential(
            nn.Linear(encoding_dim + 1, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, encoding_dim)
        ).to(self.device)

        # Transition adjustment network
        self.transition_calibration = TransitionCalibrationNetwork(
            sigma_dim=hidden_dim,  # Dimension of transition log_sigma
            intention_dim=encoding_dim,  # Dimension of intention feature
            hidden_dim=hidden_dim//2
        ).to(self.device)

    def forward(self, intention_feature, detection_prob, use_calibration, current_feature):
        """Return the summed intention-modelling and detection loss per agent.

        Args:
            intention_feature: Tensor whose first axis has length
                ``bs * num_agents`` and whose second axis is time; it is
                reshaped to ``(bs, time_steps, num_agents, -1)``.
            detection_prob: Probabilities in [0, 1] holding
                ``bs * time_steps * num_agents`` elements.
            use_calibration: Mapping with the keys ``"num_agents"`` (int) and
                ``"position"`` (tensor reshaped like the features).
            current_feature: Tensor reshaped like ``intention_feature``; each
                agent's intention feature is regressed onto the mean of the
                other agents' rows of this tensor.

        Returns:
            Tensor of shape ``(bs, time_steps, num_agents, 1)`` containing
            ``intention_loss + detection_loss`` (unreduced).
        """
        n_agents = use_calibration["num_agents"]
        position = use_calibration["position"]
        bs = intention_feature.shape[0] // n_agents
        time_steps = intention_feature.shape[1]

        # NOTE(review): this is a plain reshape with no transpose, so it relies
        # on the caller's memory layout already matching
        # (bs, time_steps, n_agents, ...) - confirm against the caller.
        per_agent = (bs, time_steps, n_agents, -1)
        intention_feature = intention_feature.reshape(per_agent)
        detection_prob = detection_prob.reshape(bs, time_steps, n_agents, 1)
        current_feature = current_feature.reshape(per_agent)
        position = position.reshape(per_agent)

        # --- Intention modelling loss -------------------------------------
        # Mean feature of all *other* agents. "Total minus own" gives the same
        # sum as masking the diagonal of an n_agents x n_agents expansion, but
        # never materialises that quadratic intermediate.
        other_agents_sum = current_feature.sum(dim=2, keepdim=True) - current_feature
        # With a single agent there are no others; divide by 1 so the mean is 0.
        other_agents_mean = other_agents_sum / max(n_agents - 1, 1)
        intention_diff = F.mse_loss(intention_feature, other_agents_mean, reduction='none')
        # Only agents whose detection probability exceeds the threshold
        # contribute; the mask follows the loss dtype to avoid promotion.
        detected = (detection_prob > self.detection_threshold).to(intention_diff.dtype)
        intention_loss = (intention_diff * detected).mean(dim=-1, keepdim=True)  # [bs, time_steps, n_agents, 1]

        # --- Detection accuracy loss --------------------------------------
        offsets = position.unsqueeze(3) - position.unsqueeze(2)
        dist = torch.sqrt((offsets ** 2).sum(dim=-1) + 1e-8)  # [bs, time_steps, n_agents, n_agents]
        # Exclude each agent's distance to itself from the nearest-neighbour search.
        is_self = torch.eye(n_agents, dtype=torch.bool, device=position.device)
        min_dist = dist.masked_fill(is_self, float("inf")).min(dim=3, keepdim=True)[0]
        # NOTE(review): the label is 1 when the nearest other agent is FARTHER
        # than the threshold. If detection_prob is meant to be "another agent
        # is present", this comparison looks inverted - confirm the intent.
        dist_label = (min_dist > self.distance_threshold).to(detection_prob.dtype)  # [bs, time_steps, n_agents, 1]
        detection_loss = F.binary_cross_entropy(detection_prob, dist_label, reduction='none')  # [bs, time_steps, n_agents, 1]

        # The two losses are returned as a single element-wise sum.
        return intention_loss + detection_loss