import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
from .agent_memory import AgentMemoryModule
from .graph_memory import GraphMemoryModule, PyGGNNProcessor, GNNProcessor

def reshape_batch_agents(tensor, num_agents=None):
    """Fold dim 2 of ``tensor`` into dim 0.

    The result has shape ``(shape[0] * num_agents, shape[1], *shape[3:])``.
    This is a plain ``reshape`` -- no axes are permuted -- so elements keep
    their row-major order.

    Args:
        tensor: Tensor to reshape; presumably laid out as
            (batch, time, agents, ...) -- confirm against the caller.
        num_agents: Factor folded into dim 0. Defaults to ``tensor.shape[2]``.

    Returns:
        The reshaped tensor.
    """
    dims = tensor.shape
    agents = dims[2] if num_agents is None else num_agents
    merged_shape = (dims[0] * agents, dims[1]) + tuple(dims[3:])
    return tensor.reshape(merged_shape)
    
def reshape_agents_to_batch(tensor, batch_size, num_agents, orig_shape=None):
    """Unfold a merged leading dim back into separate dims.

    When ``orig_shape`` is given the tensor is reshaped to exactly that
    shape. Otherwise the target shape is
    ``(batch_size, shape[1], num_agents, *shape[2:])``. Only ``reshape`` is
    used, so the row-major element order is unchanged.

    Args:
        tensor: Tensor whose dim 0 presumably holds batch * agents -- confirm
            against the caller.
        batch_size: Size of the restored dim 0 (ignored if ``orig_shape`` is set).
        num_agents: Size of the restored dim 2 (ignored if ``orig_shape`` is set).
        orig_shape: Optional explicit target shape.

    Returns:
        The reshaped tensor.
    """
    target_shape = orig_shape
    if target_shape is None:
        dims = tuple(tensor.shape)
        target_shape = (batch_size, dims[1], num_agents) + dims[2:]
    return tensor.reshape(target_shape)

# Simplify SmallConv and TransitionNetwork
class SmallConv(nn.Module):
    """Convolutional encoder for channels-last image batches.

    Three Conv2d + ReLU stages (4->32, 32->64, 64->64 channels) are followed
    by a flatten and a linear projection to ``feat_dim``. The linear layer
    expects ``64 * 7 * 7`` inputs, which is what the conv stack produces for
    84x84 images.

    Args:
        feat_dim: Size of the output feature vector.
        name: Optional label stored on the module; not used by the computation.
    """

    def __init__(self, feat_dim=512, name=None):
        super().__init__()
        self.name = name
        self.feat_dim = feat_dim

        # (in_channels, out_channels, kernel_size, stride) for each conv stage.
        conv_specs = ((4, 32, 8, 4), (32, 64, 4, 2), (64, 64, 3, 1))
        layers = []
        for in_ch, out_ch, kernel, stride in conv_specs:
            layers.append(nn.Conv2d(in_ch, out_ch, kernel, stride=stride))
            layers.append(nn.ReLU())
        layers.append(nn.Flatten())
        layers.append(nn.Linear(64 * 7 * 7, feat_dim))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` of shape (N, H, W, 4) into (N, feat_dim) features."""
        # Conv2d wants channels first, so move the last axis to position 1.
        channels_first = x.permute(0, 3, 1, 2)
        return self.conv(channels_first)

class TransitionNetwork(nn.Module):
    """Two-layer MLP over concatenated (features, actions).

    The input width is ``512 + action_dim`` and the output width is 512, so
    the feature half of the input must be 512-dimensional.

    Args:
        action_dim: Width of the action vector appended to the features.
        name: Optional label stored on the module; not used by the computation.
    """

    def __init__(self, action_dim=4, name=None):
        super().__init__()
        self.name = name
        width = 512
        self.network = nn.Sequential(
            nn.Linear(width + action_dim, width),
            nn.LeakyReLU(),
            nn.Linear(width, width),
        )

    def forward(self, inputs):
        """Run the MLP on ``inputs``, a ``(features, actions)`` pair.

        The two tensors are concatenated along their last dimension.
        """
        features, actions = inputs
        joined = torch.cat((features, actions), dim=-1)
        return self.network(joined)

def flatten_two_dims(x):
    """Merge the first two dimensions of ``x`` into one.

    A tensor of shape ``(a, b, *rest)`` becomes ``(a * b, *rest)``.
    """
    dims = tuple(x.shape)
    merged_shape = (dims[0] * dims[1],) + dims[2:]
    return x.reshape(merged_shape)


# Simplify CermicModule class initialization
class CermicModule(nn.Module):
    """Curiosity module that yields a training loss and a KL-based intrinsic reward.

    The module owns:
      * ``feature_net`` and a momentum copy ``feature_net_momentum``;
      * ``transition_model``, whose output is split into the mean and
        log-std of a Gaussian latent;
      * ``generative_model``, which maps a latent sample to the mean and
        log-std of a Gaussian over the next-step features;
      * ``agent_memory`` (``AgentMemoryModule``), used by the optional
        calibration path in ``forward``.

    ``TransitionNetwork`` consumes 512-d features and ``generative_model``
    maps a 256-d latent to 1024 outputs, so ``forward`` only lines up when
    ``feat_dim`` is 512.

    Args:
        tau: Mixing coefficient used by ``momentum_update``.
        loss_var_weight: Multiplier on the variance loss term.
        loss_l2_weight: Multiplier applied to the logged reconstruction term.
        aug: Stored on the module; not read anywhere in this class.
        feat_dim: Output width of the feature networks (see note above).
        observation_shape: Tuple (first entry is used as the input width) or
            an int input width. Required.
        action_dim: Width of the action vector fed to the transition model.
        memory_weight: Multiplier on the memory loss.
        gamma2: Numerator of ``beta = sqrt(gamma2 / epsilon)``.
        epsilon: Denominator of ``beta = sqrt(gamma2 / epsilon)``.

    Raises:
        ValueError: If ``observation_shape`` is ``None``.
    """

    def __init__(self, tau=0.005, loss_var_weight=0.1, loss_l2_weight=1.0,
                 aug=True, feat_dim=512, observation_shape=None, action_dim=4, memory_weight=0.2,
                 gamma2=1.0, epsilon=0.1):
        super().__init__()

        # Basic parameters
        # Device handed to AgentMemoryModule. Tensors created inside this
        # class follow their inputs' device instead, because the module may
        # live on a different device than this construction-time guess.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.feat_dim = feat_dim
        self.action_dim = int(action_dim)
        self.observation_shape = observation_shape

        # Loss-related parameters
        self.tau = tau
        self.loss_var_weight = loss_var_weight
        self.loss_l2_weight = loss_l2_weight
        self.aug = aug
        self.gamma2 = gamma2
        self.epsilon = epsilon
        self.beta = torch.sqrt(torch.tensor(gamma2 / epsilon))

        # Memory module parameters
        self.memory_weight = memory_weight

        # Observation statistics (running averages updated in forward)
        self.ob_mean = 0.0
        self.ob_std = 1.0

        # Calibration factors
        self.calibration_UB = 0.0
        self.calibration_LB = 0.0

        # Detect input type (recorded only; the feature nets below are always MLPs)
        self.is_image_obs = isinstance(observation_shape, tuple) and len(observation_shape) > 1

        # Initialize networks
        self._init_networks()

        # Loss and reward records
        self.loss = None
        self.intrinsic_reward = None
        self.loss_info = {}
        self.kl = None

    def _build_feature_net(self, input_dim):
        """Return a fresh input_dim -> 256 -> 512 -> feat_dim MLP."""
        return nn.Sequential(
            nn.Linear(input_dim, 256), nn.ReLU(),
            nn.Linear(256, 512), nn.ReLU(),
            nn.Linear(512, self.feat_dim)
        )

    def _init_networks(self):
        """Build the feature, transition, generative and memory sub-modules.

        Raises:
            ValueError: If ``observation_shape`` was not provided.
        """
        if self.observation_shape is None:
            # Fail early with a readable message instead of an opaque
            # TypeError from nn.Linear(None, ...).
            raise ValueError(
                "observation_shape must be provided (a tuple or an int input dimension)"
            )

        # Initialize feature network
        if isinstance(self.observation_shape, tuple):
            input_dim = self.observation_shape[0]
        else:
            input_dim = self.observation_shape
        self.feature_net = self._build_feature_net(input_dim)
        self.feature_net_momentum = self._build_feature_net(input_dim)

        self.transition_model = TransitionNetwork(action_dim=self.action_dim, name="CERMIC_transition")
        # Input 256 = half of the transition output (the latent mean);
        # output 1024 = mean and log-std of the 512-d next-step features.
        self.generative_model = nn.Sequential(
            nn.Linear(256, 512), nn.LeakyReLU(),
            nn.Linear(512, 1024)
        )

        # Initialize memory module
        # 1. Use AgentMemoryModule
        self.agent_memory = AgentMemoryModule(
            observation_shape=self.observation_shape,
            encoding_dim=self.feat_dim,
            device=self.device
        )

        # 2. Use GraphMemoryModule
        # self.graph_memory = GraphMemoryModule(
        #     obs_dim=self.observation_shape[0],
        #     node_dim=self.feat_dim//32,
        #     edge_dim=self.feat_dim//32,
        #     hidden_dim=self.feat_dim//16
        # )
        # self.gnn = PyGGNNProcessor(
        #     node_dim=self.feat_dim//32,
        #     edge_dim=self.feat_dim//32,
        #     hidden_dim=self.feat_dim//16
        # )

        self._initialize_momentum_parameters()

    def forward(self, obs, next_obs, ac, use_calibration=None):
        """Compute the loss and cache the per-element KL.

        Args:
            obs: Observations; the code treats dim 1 as the step axis --
                assumed (batch * agents, time, obs_dim), TODO confirm.
            next_obs: Observations encoded by the momentum network and
                appended along dim 1 as the final target step.
            ac: Actions, concatenated with the features on the last dim.
            use_calibration: Optional dict. When given, the keys
                ``"num_agents"`` and ``"reward"`` are read here and the whole
                dict is passed on to ``agent_memory.forward``.

        Returns:
            Scalar tensor: the mean of the per-element loss.

        Side effects:
            Updates ``self.ob_mean`` / ``self.ob_std`` and sets ``self.kl``,
            ``self.loss`` and ``self.loss_info``.
        """
        current_feature = self.feature_net(obs) # torch.Size([5, 600, 512])
        next_feature = self.feature_net_momentum(next_obs).detach()
        # Targets: online features shifted by one step, with the momentum
        # features of next_obs appended at the end.
        next_features = torch.cat((current_feature[:, 1:, :], next_feature), dim=1) # torch.Size([5, 600, 512])

        # TODO: GNN_feature & Other group
        if use_calibration is not None:
            num_agents = use_calibration["num_agents"]
            batch_size, time_steps = obs.shape[0] // num_agents, obs.shape[1]
            detection_prob = self.agent_memory.detection_network(obs)
            detection_mask = (detection_prob > self.agent_memory.detection_threshold).float()
            intention_feature = self.agent_memory.encoding_network(obs*detection_mask)*detection_mask

            # --- Added by Yiyuan Pan ---
            # TODO: Graph Memory Implementation.
            # obs_reshaped = reshape_agents_to_batch(obs, batch_size, num_agents)
            # positions = reshape_agents_to_batch(use_calibration["position"], batch_size, num_agents)
            # detection_mask_reshaped = reshape_agents_to_batch(detection_mask, batch_size, num_agents)

            # USE_POS = True
            # graph_feature, pred_positions = self.graph_memory(obs_reshaped * detection_mask_reshaped, num_nodes=num_agents, positions=positions, use_pos=USE_POS)
            # graph_intention = self.gnn(graph_feature) * detection_mask_reshaped
            # graph_intention = reshape_batch_agents(graph_intention, num_agents)
            # intention_feature = graph_intention
            # ------- 2025-04-22 --------

            # Shift intention features and rewards one step along dim 1,
            # padding the first step with zeros.
            prev_intention_feature = torch.cat([torch.zeros_like(intention_feature[:, 0:1, :]), intention_feature[:, :time_steps-1, :]], dim=1)
            prev_reward = torch.cat([torch.zeros_like(use_calibration["reward"][:, 0:1, :]), use_calibration["reward"][:, :time_steps-1, :]], dim=1)
            prev_intention_feature = torch.cat([prev_intention_feature, prev_reward], dim=-1)

            calibration_UB, calibration_LB = self._compute_calibration_nce(intention_feature, prev_intention_feature)
            calibration_UB, calibration_LB = calibration_UB*detection_mask, calibration_LB*detection_mask
            weighted_intention_UB = intention_feature*calibration_UB
            weighted_intention_LB = intention_feature*calibration_LB

        # Feature statistics (exponential moving averages)
        obs_mean = current_feature.mean(dim=0, keepdim=True)
        self.ob_mean = 0.9 * self.ob_mean + 0.1 * obs_mean.mean().item()
        if current_feature.shape[0] > 1:
            # std over a single row is NaN and would permanently poison the
            # running average, so the update is skipped in that case.
            obs_std = current_feature.std(dim=0, keepdim=True) + 0.1
            self.ob_std = 0.9 * self.ob_std + 0.1 * obs_std.mean().item()

        # Transition model
        transition_out = self.transition_model((current_feature, ac)) # network g(s,a)
        transition_mu, transition_log_sigma = transition_out.chunk(2, dim=-1)
        transition_sigma = torch.exp(transition_log_sigma) + 0.1

        # Distribution and KL divergence
        latent_dis = torch.distributions.Normal(transition_mu, transition_sigma) # p(z|s,a)
        # The prior is built on the same device/dtype as the posterior
        # parameters; forcing self.device here broke CPU runs on CUDA hosts.
        prior_dis = torch.distributions.Normal(
            torch.zeros_like(transition_mu),
            torch.ones_like(transition_sigma)
        ) # q(z)
        kl = torch.distributions.kl_divergence(latent_dis, prior_dis)
        kl = kl.sum(dim=-1)
        self.kl = kl
        KL_loss = kl

        log_sigma_UB = transition_log_sigma
        log_sigma_LB = transition_log_sigma
        if use_calibration is not None:
            log_sigma_UB = self.agent_memory.transition_calibration(transition_log_sigma, weighted_intention_UB)
            log_sigma_LB = self.agent_memory.transition_calibration(transition_log_sigma, weighted_intention_LB)

        var_loss_UB = F.relu(-self.beta - (-2 * log_sigma_UB)-5)
        # Only logged below; var_loss_LB does not enter the loss.
        var_loss_LB = -self.beta - (2 * log_sigma_LB)
        VAR_loss = var_loss_UB.mean(dim=-1) * self.loss_var_weight

        # Reconstruct the next-step features from a reparameterized sample.
        latent = latent_dis.rsample()
        rec_params = self.generative_model(latent)
        rec_mu, rec_log_sigma = rec_params.chunk(2, dim=-1)
        rec_sigma = torch.exp(rec_log_sigma) + 0.1
        rec_dis = torch.distributions.Normal(rec_mu, rec_sigma)
        log_prob = rec_dis.log_prob(next_features)
        L2_loss = log_prob.mean(dim=-1)
        rec_log_l2 = log_prob.sum(dim=-1)

        # Total loss
        loss = - L2_loss + KL_loss + VAR_loss

        if use_calibration is not None:
            memory_loss = self.agent_memory.forward(intention_feature, detection_prob, use_calibration, current_feature)
            memory_loss = memory_loss.reshape(memory_loss.shape[0]*memory_loss.shape[2], memory_loss.shape[1])
            # NOTE(review): Mem_loss is computed but never added to `loss` --
            # confirm whether the memory module is meant to be trained here.
            Mem_loss = memory_loss*self.memory_weight

        # TODO: Graph Memory --> Loss
        # if use_calibration is not None and USE_POS:
        #     graph_loss = self.graph_memory.pos_loss_compute(pred_positions, use_calibration["position"])

        # Record loss information
        # NOTE(review): "CERMIC_KLLoss_w" scales by loss_var_weight although
        # the KL term enters `loss` unweighted -- confirm which is intended.
        self.loss_info = {
            "CERMIC_L2Loss": -1.0 * rec_log_l2.mean().item(),
            "CERMIC_L2Loss_w": -1.0 * rec_log_l2.mean().item() * self.loss_l2_weight,
            "CERMIC_KLLoss": kl.mean().item(),
            "CERMIC_KLLoss_w": kl.mean().item() * self.loss_var_weight,
            "CERMIC_VarLoss": VAR_loss.mean().item(),
            "CERMIC_VarLossUB": var_loss_UB.mean().item(),
            "CERMIC_VarLossLB": var_loss_LB.mean().item(),
            "CERMIC_Beta": self.beta.item(),
            "CERMIC_Loss": loss.mean().item()
        }

        self.loss = loss
        return loss.mean()

    def calculate_cermic_reward(self, obs, next_obs, ac, position_info=None, prev_extrinsic_reward=None):
        """Return the KL term of ``forward`` as an intrinsic reward.

        Args:
            obs, next_obs, ac: 4-D (or higher) tensors; dim 2 is folded into
                dim 0 before calling ``forward``.
            position_info: Accepted but not used.
            prev_extrinsic_reward: Accepted but not used.

        Returns:
            Tensor of shape ``(obs.shape[0], obs.shape[1], obs.shape[2], 1)``
            computed without gradients.

        Note:
            ``forward`` is invoked, so its side effects (``self.loss``,
            ``self.loss_info``, running statistics) are overwritten too.
        """
        orig_shape = obs.shape
        # NOTE(review): these are plain reshapes (no permute), so a row of
        # the result is not a single agent's trajectory unless dim 2 has
        # size 1. The KL below is computed independently per element and the
        # final reshape restores the original element order, so the returned
        # reward is unaffected -- but confirm the layout that forward's
        # time-shifted targets are supposed to see.
        obs = obs.reshape(obs.shape[0] * obs.shape[2], obs.shape[1], *obs.shape[3:])
        next_obs = next_obs.reshape(next_obs.shape[0] * next_obs.shape[2], next_obs.shape[1], *next_obs.shape[3:])
        ac = ac.reshape(ac.shape[0] * ac.shape[2], ac.shape[1], *ac.shape[3:]) # [bs*n_agents, time_steps, dim]

        # Process the whole batch at once, without tracking gradients
        with torch.no_grad():
            # Compute the reward: forward() caches the per-element KL in self.kl
            self.forward(obs, next_obs, ac)
            intrinsic_reward = self.kl
            return intrinsic_reward.reshape(orig_shape[0], orig_shape[1], orig_shape[2], 1)

    def _compute_calibration_nce(self, current_feature, prev_feature_with_reward):
        """Compute the NCE-based calibration factors.

        ``prev_feature_with_reward`` is encoded by
        ``agent_memory.calibration_network``. The positive score is the dot
        product of that encoding with ``current_feature``; the 16 negatives
        replace ``current_feature`` with noisy copies of itself. A second,
        "inverse" cross-entropy is computed from the L1-normalized
        complement of the softmax probabilities.

        Args:
            current_feature: Tensor of shape (d0, d1, feature_dim).
            prev_feature_with_reward: Input to the calibration network with
                the same two leading dims.

        Returns:
            ``(calibration_UB, calibration_LB)``, each of shape (d0, d1, 1),
            equal to ``log(num_neg) + nce`` and ``log(num_neg) - inverse_nce``.
        """
        # Record original shape for later recovery
        original_shape = current_feature.shape  # [batch_size, time_steps, feature_dim]

        encoded = self.agent_memory.calibration_network(prev_feature_with_reward)
        flattened_encoded = flatten_two_dims_local(encoded)
        flattened_feature = flatten_two_dims_local(current_feature)

        # Create helper tensors where the data lives rather than on
        # self.device, which may differ from the module's actual device.
        device = flattened_feature.device
        batch_size = flattened_feature.shape[0]
        num_neg = 16
        temperature = 0.1
        log_num_neg = torch.log(torch.tensor(float(num_neg), device=device))

        pos_scores = torch.sum(flattened_feature * flattened_encoded, dim=-1, keepdim=True)

        noise = torch.randn(batch_size, num_neg, flattened_feature.shape[-1], device=device) * temperature
        neg_features = flattened_feature.unsqueeze(1) + noise
        neg_scores = torch.sum(neg_features * flattened_encoded.unsqueeze(1), dim=-1)

        # Column 0 holds the positive pair, hence the all-zero labels.
        all_scores = torch.cat([pos_scores, neg_scores], dim=1) / temperature
        labels = torch.zeros(batch_size, dtype=torch.long, device=device)

        nce_losses = F.cross_entropy(all_scores, labels, reduction='none')

        probs = F.softmax(all_scores, dim=1)
        inv_probs = F.normalize(1.0 - probs + 1e-6, p=1, dim=1)
        inv_scores = torch.log(inv_probs) * temperature
        inv_nce_losses = F.cross_entropy(inv_scores, labels, reduction='none')

        nce_losses = nce_losses.reshape(original_shape[0], original_shape[1]).unsqueeze(-1)
        inv_nce_losses = inv_nce_losses.reshape(original_shape[0], original_shape[1]).unsqueeze(-1)

        calibration_UB = log_num_neg + nce_losses
        calibration_LB = log_num_neg - inv_nce_losses

        return calibration_UB, calibration_LB

    def has_detected_agents(self):
        """Check if other agents are detected (memory queue is non-empty)."""
        return len(self.agent_memory.memory_queue) > 0

    def reset_memory(self):
        """Reset agent memory"""
        self.agent_memory.reset_memory()

    def _initialize_momentum_parameters(self):
        """Copy the main network's parameters into the momentum network and freeze it."""
        with torch.no_grad():
            for param_q, param_k in zip(self.feature_net.parameters(), self.feature_net_momentum.parameters()):
                param_k.copy_(param_q)
                param_k.requires_grad_(False)

    def momentum_update(self):
        """Move the momentum network towards the main network.

        Each momentum parameter becomes
        ``(1 - tau) * momentum + tau * main``, written in place.
        """
        with torch.no_grad():
            for param_q, param_k in zip(self.feature_net.parameters(), self.feature_net_momentum.parameters()):
                param_k.copy_((1.0 - self.tau) * param_k + self.tau * param_q)

    def get_agent_memory(self):
        """Get agent memory module's memory"""
        return self.agent_memory.get_memory()

    def get_agent_memory_stats(self):
        """Get agent memory module's statistics"""
        return self.agent_memory.get_memory_stats()


def flatten_two_dims_local(x):
    """Merge the first two dims of ``x``; same result as ``flatten_two_dims``."""
    return x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])
