import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import numpy as np
from typing import List, Tuple, Dict, Optional
from collections import defaultdict, deque
from modules.DMM import DecisionMod
from modules.COMMSM import CommsMod
from modules.CommsProtocol import CommunicationProtocolAnalyzer

class SpikeAgent(nn.Module):
    """Agent that combines a spiking CommsMod encoder with a DecisionMod head.

    ``comm`` encodes images into temporal spike messages and ``decision``
    scores candidate encodings against a sender message. Target copies
    (``comm_target`` / ``decision_target``) provide bootstrapped Q-targets for
    non-terminal transitions. ``update`` combines a smooth-L1 Q loss, a
    confidence-calibration loss, an entropy bonus and an optional auxiliary
    classification loss.
    """

    # Size of the rolling window of per-decision confidences. Also used when
    # restoring a checkpoint so the window stays bounded.
    _CONFIDENCE_WINDOW = 1000

    def __init__(self, pretrained_commsmod_path: str = None, n_msg: int = 128, n_actions: int = 3,
                 ε: float = 0.1, γ: float = 0.99, lr: float = 1e-3, beta: float = 0.9,
                 freeze_commsmod: bool = True, use_shaped_rewards: bool = True,
                 use_auxiliary_loss: bool = True, auxiliary_weight: float = 0.1):
        """
        Build the agent from a CommsMod encoder and a DecisionMod head.

        Args:
            pretrained_commsmod_path: optional path to a CommsMod state dict; when
                given, both the online and target CommsMod are loaded from it.
            n_msg: message width passed to CommsMod / DecisionMod.
            n_actions: number of actions passed to DecisionMod.
            ε: exploration probability used by ``make_decision``.
            γ: discount factor for non-terminal transitions in ``update``.
            lr: Adam learning rate.
            beta: passed to freshly built CommsMod and DecisionMod instances
                (not to a pretrained CommsMod).
            freeze_commsmod: with a pretrained CommsMod, disable its gradients
                and keep it out of the optimizer. No effect without a path.
            use_shaped_rewards: use ``compute_protocol_aware_reward`` in ``update``.
            use_auxiliary_loss: request and train on DecisionMod auxiliary outputs.
            auxiliary_weight: weight of the auxiliary loss term.
        """
        super(SpikeAgent, self).__init__()
        self.n_msg = n_msg
        self.n_actions = n_actions
        self.ε = ε
        self.γ = γ
        self.use_shaped_rewards = use_shaped_rewards
        self.use_auxiliary_loss = use_auxiliary_loss
        self.auxiliary_weight = auxiliary_weight
        self.freeze_commsmod = freeze_commsmod

        # Communication networks (online + target).
        if pretrained_commsmod_path:
            self.comm = self._load_pretrained_commsmod(pretrained_commsmod_path)
            self.comm_target = self._load_pretrained_commsmod(pretrained_commsmod_path)

            if freeze_commsmod:
                for param in self.comm.parameters():
                    param.requires_grad = False
                for param in self.comm_target.parameters():
                    param.requires_grad = False
        else:
            self.comm = CommsMod(embedding_dim=n_msg, beta=beta)
            self.comm_target = CommsMod(embedding_dim=n_msg, beta=beta)

        # Online and target decision heads share one configuration.
        decision_kwargs = dict(
            n_msg=n_msg,
            n_actions=n_actions,
            beta=beta,
            threshold=0.8,
            num_steps=25,
            hidden_dim=256,
            temperature=1.5,  # start with a higher temperature
        )
        self.decision = DecisionMod(**decision_kwargs)
        self.decision_target = DecisionMod(**decision_kwargs)

        self.protocol_analyzer = CommunicationProtocolAnalyzer(
            n_msg=n_msg,
            num_steps=25
        )

        # Start target networks as exact copies of the online networks.
        self.sync_target()

        # Only the online networks are optimized; a frozen pretrained CommsMod is excluded.
        trainable_params = []
        if not freeze_commsmod or not pretrained_commsmod_path:
            trainable_params.extend(list(self.comm.parameters()))
        trainable_params.extend(list(self.decision.parameters()))

        self.optimizer = optim.Adam(trainable_params, lr=lr)

        # Halve the LR when the monitored metric (maximized) plateaus.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', factor=0.5, patience=10
        )

        # Tracking
        self.message_history = defaultdict(list)
        self.protocol_metrics_history = []
        self.decision_confidence = deque(maxlen=self._CONFIDENCE_WINDOW)
        self.temporal_attention_patterns = []
        self.accuracy_history = deque(maxlen=50)

    def _device(self) -> torch.device:
        """Return the device of the agent's first parameter."""
        return next(self.parameters()).device

    def _load_pretrained_commsmod(self, path: str):
        """Build a CommsMod and load its state dict from ``path`` (on CPU)."""
        # NOTE(review): unlike the non-pretrained path, ``beta`` is not passed
        # here, so CommsMod's default is used — confirm this matches the
        # configuration the checkpoint was trained with.
        model = CommsMod(embedding_dim=self.n_msg)
        model.load_state_dict(torch.load(path, map_location='cpu'))
        return model

    def sync_target(self):
        """Copy online network weights into the target networks."""
        device = self._device()

        self.comm_target.to(device)
        self.decision_target.to(device)

        self.comm_target.load_state_dict(self.comm.state_dict())
        self.decision_target.load_state_dict(self.decision.state_dict())

    def soft_update_target(self, tau: float = 0.001):
        """Polyak-average online weights into the targets: t <- tau*p + (1-tau)*t."""
        device = self._device()

        self.comm_target.to(device)
        self.decision_target.to(device)

        with torch.no_grad():
            for target_param, param in zip(self.comm_target.parameters(), self.comm.parameters()):
                target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)

            for target_param, param in zip(self.decision_target.parameters(), self.decision.parameters()):
                target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)

    @torch.no_grad()
    def send_message(self, img: torch.Tensor, track_protocol: bool = False, label: Optional[int] = None) -> torch.Tensor:
        """Encode one image into a spike message with the online CommsMod.

        A 3-D ``img`` gets a leading batch dimension. When ``track_protocol``
        is set and ``label`` is given, a CPU copy of the message is appended to
        ``message_history[label]``. Returns ``spk_rec.squeeze(1)``.
        """
        device = self._device()

        if img.dim() == 3:
            img = img.unsqueeze(0)

        img = img.to(device)

        embeddings, logits, spk_rec = self.comm(img)

        if track_protocol and label is not None:
            self.message_history[label].append(spk_rec.squeeze(1).detach().cpu())

        return spk_rec.squeeze(1).to(device)  # presumably [T, n_msg] — TODO confirm

    @torch.no_grad()
    def make_decision(self, candidates: List[torch.Tensor], sender_msg: torch.Tensor,
                      candidate_labels: Optional[List[int]] = None) -> Tuple[int, torch.Tensor, Dict]:
        """Pick a candidate index for ``sender_msg`` with ε-greedy exploration.

        Returns ``(action, Qs, info_dict)`` where ``info_dict`` is DecisionMod's
        info extended with confidence statistics, ``is_exploration`` and ``q_gap``.
        """
        device = self._device()

        candidates = [c.to(device) for c in candidates]
        sender_msg = sender_msg.to(device)

        candidate_spikes = self.encode_candidates(candidates, use_target=False)
        sender_spikes = sender_msg.unsqueeze(1)  # add the batch axis at dim 1

        Qs, info_dict = self.decision(
            sender_spikes,
            candidate_spikes,
            return_auxiliary=(candidate_labels is not None and self.use_auxiliary_loss)
        )
        Qs = Qs.squeeze(0)  # drop the batch axis

        if 'sender_attention' in info_dict:
            self.temporal_attention_patterns.append(info_dict['sender_attention'].detach().cpu())

        confidence_scores = info_dict['confidence_scores'].squeeze(0)

        if random.random() < self.ε:
            if random.random() < 0.5:
                # Sample an action with probability proportional to its
                # confidence (torch.multinomial normalizes the weights).
                action = torch.multinomial(confidence_scores, 1).item()
            else:
                action = random.randint(0, len(candidates) - 1)
            exploration_happened = True
        else:
            action = torch.argmax(Qs).item()
            exploration_happened = False

        action_confidence = confidence_scores[action].item()
        self.decision_confidence.append(action_confidence)

        info_dict.update({
            'action_confidence': action_confidence,
            'max_confidence': confidence_scores.max().item(),
            'confidence_entropy': -(confidence_scores * torch.log(confidence_scores + 1e-8)).sum().item(),
            'is_exploration': exploration_happened,
            'q_gap': (Qs.max() - Qs.min()).item()
        })

        return action, Qs, info_dict

    def _target_confidence(self) -> float:
        """Desired confidence for a correct answer, derived from recent accuracy.

        With more than 5 entries in ``accuracy_history`` (treated as percents),
        the mean of the last 10 is scaled to [0, 1] and clamped to [0.6, 0.95];
        otherwise 0.75.
        """
        history = getattr(self, 'accuracy_history', None)
        if history and len(history) > 5:
            recent_accuracy = np.mean(list(history)[-10:]) / 100.0
            return float(min(0.95, max(0.6, recent_accuracy)))
        return 0.75

    def compute_protocol_aware_reward(self, action: int, target_idx: int,
                                      info_dict: Dict, sender_msg: torch.Tensor,
                                      accuracy_history: List[float] = None) -> float:
        """Return a reward in [0, 1.2] that also encourages calibrated confidence.

        Without shaping the reward is 1.0 for ``action == target_idx`` else 0.0.
        With shaping, a correct answer earns up to +0.1 for confidence near
        ``_target_confidence()``, and a wrong answer made with confidence < 0.5
        earns 0.1. ``sender_msg`` and ``accuracy_history`` are accepted for
        interface compatibility but are not used (``self.accuracy_history`` is).
        """
        base_reward = 1.0 if action == target_idx else 0.0

        if not self.use_shaped_rewards:
            return base_reward

        action_confidence = info_dict.get('action_confidence', 0.5)
        target_confidence = self._target_confidence()

        if base_reward > 0:
            calibration_bonus = 0.1 * np.exp(-2 * abs(action_confidence - target_confidence))
            shaped_reward = base_reward + calibration_bonus
        elif action_confidence < 0.5:
            shaped_reward = 0.1  # small reward for "knowing you don't know"
        else:
            shaped_reward = 0.0

        return float(np.clip(shaped_reward, 0, 1.2))

    @staticmethod
    def _find_target_index(target_img: torch.Tensor, candidates: List[torch.Tensor]) -> Optional[int]:
        """Index of the first candidate exactly equal to ``target_img``, else None.

        Both arguments are expected on the same device.
        """
        for idx, cand in enumerate(candidates):
            if torch.equal(cand, target_img):
                return idx
        return None

    def update(self, transition: Tuple) -> float:
        """Run one optimizer step on a single transition; return the total loss.

        ``transition`` is ``(target_img, candidates, action, reward, next_img,
        next_candidates, done, info_dict)`` optionally followed by
        ``candidate_labels`` (9-tuple), which enables the auxiliary loss.
        """
        if len(transition) == 9:
            (target_img, candidates, action, reward, next_img,
             next_candidates, done, info_dict, candidate_labels) = transition
        else:
            (target_img, candidates, action, reward, next_img,
             next_candidates, done, info_dict) = transition
            candidate_labels = None

        device = self._device()
        use_aux = candidate_labels is not None and self.use_auxiliary_loss

        target_img = target_img.to(device)
        candidates = [c.to(device) for c in candidates]

        # Online Q-values for the current state (gradients flow through these).
        current_spikes = self._encode_single_image(target_img, use_target=False)
        current_cand_spikes = self.encode_candidates(candidates, use_target=False)
        Q_pred, current_info = self.decision(
            current_spikes,
            current_cand_spikes,
            return_auxiliary=use_aux
        )
        Q_pred = Q_pred.squeeze(0)
        Q_pred_selected = Q_pred[action]
        confidence = current_info['confidence_scores'].squeeze(0)

        # Correctness is needed on every path (the calibration loss uses it).
        # If the target image cannot be located among the candidates, trust the
        # environment reward instead of assuming the chosen action was right.
        target_idx = self._find_target_index(target_img, candidates)
        if target_idx is None:
            is_correct = float(reward) > 0
            target_idx = action if is_correct else -1  # -1 never matches an action
        else:
            is_correct = action == target_idx

        if info_dict and self.use_shaped_rewards:
            shaped_reward = self.compute_protocol_aware_reward(
                action, target_idx, info_dict, current_spikes.squeeze(1)
            )
        else:
            shaped_reward = reward

        # Moderate reward scaling.
        shaped_reward = float(shaped_reward) * 2.0

        if done:
            Q_target = torch.tensor(shaped_reward, dtype=torch.float32, device=device)
        else:
            # Bootstrapped target from the target networks; no gradients needed.
            with torch.no_grad():
                next_spikes = self._encode_single_image(next_img, use_target=True)
                next_cand_spikes = self.encode_candidates(next_candidates, use_target=True)
                Q_next, _ = self.decision_target(next_spikes, next_cand_spikes)
                Q_next_max = Q_next.squeeze(0).max().item()
            Q_target = torch.tensor(float(shaped_reward + self.γ * Q_next_max),
                                    dtype=torch.float32, device=device)

        q_loss = F.smooth_l1_loss(Q_pred_selected, Q_target)

        # Calibration: pull the chosen action's confidence toward the adaptive
        # target when correct, and toward 0.3 when wrong.
        confidence_value = self._target_confidence() if is_correct else 0.3
        confidence_target = torch.tensor(confidence_value, dtype=torch.float32, device=device)
        calibration_loss = F.mse_loss(confidence[action], confidence_target)

        # Entropy bonus over softmax(Q) to discourage premature collapse.
        Q_probs = F.softmax(Q_pred, dim=0)
        entropy = -(Q_probs * torch.log(Q_probs + 1e-8)).sum()

        total_loss = q_loss + 0.05 * calibration_loss - 0.01 * entropy

        if use_aux and 'class_predictions' in current_info:
            labels_tensor = torch.tensor(candidate_labels, dtype=torch.long, device=device).unsqueeze(0)
            aux_loss = self.decision.compute_auxiliary_loss(
                current_info['class_predictions'],
                labels_tensor
            )
            total_loss = total_loss + self.auxiliary_weight * aux_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
        self.optimizer.step()

        return total_loss.item()

    def _encode_single_image(self, img: torch.Tensor, use_target: bool = False) -> torch.Tensor:
        """Encode one image (3-D input gets a batch dim) and return its ``spk_rec``."""
        device = self._device()

        if img.dim() == 3:
            img = img.unsqueeze(0)

        img = img.to(device)

        comm_net = self.comm_target if use_target else self.comm
        embeddings, logits, spk_rec = comm_net(img)

        return spk_rec.to(device)  # presumably [T, 1, n_msg] — TODO confirm

    def encode_candidates(self, candidates: List[torch.Tensor], use_target: bool = False) -> torch.Tensor:
        """Encode all candidates in one batch; returns ``spk_rec.unsqueeze(1)``."""
        device = self._device()

        candidates = [c.to(device) for c in candidates]
        batch = torch.stack(candidates)

        comm_net = self.comm_target if use_target else self.comm
        embeddings, logits, spk_rec = comm_net(batch)

        return spk_rec.unsqueeze(1).to(device)  # presumably [T, 1, K, n_msg] — TODO confirm

    def update_exploration_rate(self, new_epsilon: float):
        """Set the ε used by ``make_decision``."""
        self.ε = new_epsilon

    def update_learning_rate(self, metric: float):
        """Step the ReduceLROnPlateau scheduler with a metric to maximize."""
        self.scheduler.step(metric)

    @staticmethod
    def _empty_protocol_metrics() -> Dict:
        """Zero-valued metrics returned when there is too little data."""
        return {
            'within_class_similarity': 0.0,
            'between_class_similarity': 0.0,
            'protocol_discriminability': 0.0,
            'attention_consistency': 0.0,
            'attention_entropy': 0.0
        }

    def _attention_metrics(self) -> Dict[str, float]:
        """Consistency/entropy of the last 50 recorded sender-attention patterns."""
        if not self.temporal_attention_patterns:
            return {'attention_consistency': 0.0, 'attention_entropy': 0.0}

        recent_attention = [
            att.cpu() if isinstance(att, torch.Tensor) else att
            for att in self.temporal_attention_patterns[-50:]
        ]
        attention_tensor = torch.stack(recent_attention)

        # torch.std with Bessel's correction is NaN for a single sample.
        if attention_tensor.shape[0] > 1:
            attention_std = torch.std(attention_tensor, dim=0).mean().item()
        else:
            attention_std = 0.0
        metrics = {'attention_consistency': 1.0 / (1.0 + attention_std)}

        avg_attention = attention_tensor.mean(0)
        if avg_attention.dim() > 0:
            avg_attention = F.softmax(avg_attention.squeeze(), dim=-1)
            metrics['attention_entropy'] = -(avg_attention * torch.log(avg_attention + 1e-8)).sum().item()
        else:
            metrics['attention_entropy'] = 0.0
        return metrics

    def analyze_protocol_development(self) -> Dict:
        """Analyze how discriminable the tracked messages are per label.

        Uses up to the last 20 messages per label (time-averaged when
        multi-dimensional) and adds attention consistency/entropy.
        """
        if not self.message_history:
            return self._empty_protocol_metrics()

        all_messages = []
        all_labels = []
        for label, messages in self.message_history.items():
            for msg in messages[-20:]:
                if isinstance(msg, torch.Tensor):
                    msg = msg.cpu()
                all_messages.append(msg.mean(0) if msg.dim() > 1 else msg)
                all_labels.append(label)

        if len(all_messages) < 2:
            return self._empty_protocol_metrics()

        protocol_metrics = self.protocol_analyzer.analyze_protocol(
            all_messages, all_labels
        )
        protocol_metrics.update(self._attention_metrics())
        return protocol_metrics

    def get_communication_statistics(self) -> Dict:
        """Collect confidence, exploration, LR, protocol and attention statistics."""
        stats = {
            'avg_decision_confidence': np.mean(self.decision_confidence) if self.decision_confidence else 0,
            'exploration_rate': self.ε,
            'learning_rate': self.optimizer.param_groups[0]['lr']
        }

        stats.update(self.analyze_protocol_development())

        if self.temporal_attention_patterns:
            recent_attention = torch.stack(self.temporal_attention_patterns[-50:])
            stats['peak_attention_time'] = recent_attention.mean(0).argmax().item()
            stats['attention_spread'] = torch.std(recent_attention.mean(0)).item()

        if len(self.decision_confidence) > 10:
            recent_conf = list(self.decision_confidence)[-100:]
            # Correlation with time is undefined (NaN) for a constant series.
            if np.std(recent_conf) > 0:
                confidence_trend = np.corrcoef(range(len(recent_conf)), recent_conf)[0, 1]
            else:
                confidence_trend = 0.0
            stats.update({
                'confidence_std': np.std(recent_conf),
                'confidence_trend': confidence_trend,
                'low_confidence_rate': np.mean([c < 0.5 for c in recent_conf]),
                'high_confidence_rate': np.mean([c > 0.8 for c in recent_conf])
            })

        return stats

    def clear_message_history(self):
        """Drop tracked messages and keep only the last 1000 attention patterns."""
        self.message_history.clear()
        self.temporal_attention_patterns = self.temporal_attention_patterns[-1000:]

    def save_agent(self, path: str):
        """Save online networks, analyzer, optimizer, scheduler and tracking state."""
        torch.save({
            'comm_state_dict': self.comm.state_dict(),
            'decision_state_dict': self.decision.state_dict(),
            'protocol_analyzer_state_dict': self.protocol_analyzer.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'epsilon': self.ε,
            # Stored as a plain list so the checkpoint holds only basic types.
            'decision_confidence': list(self.decision_confidence),
            'protocol_metrics_history': self.protocol_metrics_history
        }, path)

    def load_agent(self, path: str):
        """Restore state written by ``save_agent`` and re-sync the target networks."""
        # nn.Module has no ``device`` attribute; use the parameters' device.
        checkpoint = torch.load(path, map_location=self._device())
        self.comm.load_state_dict(checkpoint['comm_state_dict'])
        self.decision.load_state_dict(checkpoint['decision_state_dict'])
        if 'protocol_analyzer_state_dict' in checkpoint:
            self.protocol_analyzer.load_state_dict(checkpoint['protocol_analyzer_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.ε = checkpoint['epsilon']
        # Keep the confidence window bounded, whatever container was saved.
        self.decision_confidence = deque(
            checkpoint.get('decision_confidence', []),
            maxlen=self._CONFIDENCE_WINDOW
        )
        self.protocol_metrics_history = checkpoint.get('protocol_metrics_history', [])
        self.sync_target()
