import torch
import numpy as np
import logging
from typing import List, Dict, Tuple, Union
from collections import defaultdict
import scipy.stats

logger = logging.getLogger("logger")

def identify_active_dimensions(update_vectors: torch.Tensor, epsilon: float = 1e-3) -> torch.Tensor:
    """
    Build a boolean mask of the dimensions considered "active".

    A dimension (column of ``update_vectors``) is active when the largest
    absolute value found in it, across all rows, is strictly greater than
    ``epsilon``.

    Args:
        update_vectors: Tensor reduced along dim 0; the caller in this file
            passes a 2-D stack with one row per client.
        epsilon: Magnitude threshold a column must exceed to be kept.

    Returns:
        Boolean tensor with one entry per column. If the computation raises,
        the error is logged and an all-True mask of length
        ``update_vectors.shape[1]`` is returned instead.
    """
    try:
        # Largest magnitude seen in every column.
        column_peaks = update_vectors.abs().max(dim=0).values
        keep_mask = column_peaks > epsilon

        kept_count = keep_mask.sum().item()
        total_count = keep_mask.shape[0]
        logger.debug(f"Active dimensions: {kept_count}/{total_count}")
        return keep_mask

    except Exception as e:
        logger.error(f"Active dimension filtering failed: {e}")
        # Fail open: treat every dimension as active.
        return torch.ones(update_vectors.shape[1], dtype=torch.bool, device=update_vectors.device)

def orthogonal_decomposition(update_vector: torch.Tensor,
                             consensus_vector: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Split ``update_vector`` into a part along ``consensus_vector`` and a remainder.

    The parallel part is the projection of ``update_vector`` onto
    ``consensus_vector``; the orthogonal part is what is left after
    subtracting that projection.

    Args:
        update_vector: Vector to decompose.
        consensus_vector: Direction to project onto.

    Returns:
        ``(parallel_component, orthogonal_component)``. When the squared norm
        of ``consensus_vector`` is below 1e-10, or when the computation
        raises (the error is logged), the result is
        ``(zeros_like(update_vector), update_vector.clone())``.
    """
    try:
        squared_length = (consensus_vector ** 2).sum()

        # Degenerate direction: nothing to project onto, and dividing by the
        # squared length would blow up.
        if squared_length < 1e-10:
            return torch.zeros_like(update_vector), update_vector.clone()

        scale = (update_vector * consensus_vector).sum() / squared_length
        along_consensus = scale * consensus_vector
        remainder = update_vector - along_consensus
        return along_consensus, remainder

    except Exception as e:
        logger.error(f"Orthogonal decomposition failed: {e}")
        return torch.zeros_like(update_vector), update_vector.clone()

def compute_tdf_scores(update_vectors: torch.Tensor,
                       active_dims: torch.Tensor) -> torch.Tensor:
    """
    Compute the TDF anomaly score of every client.

    For each client ``i`` the update (restricted to ``active_dims``) is
    decomposed against the mean update of all *other* clients
    (leave-one-out consensus). Two quantities are derived:

    * ``phi_task``: norm of the parallel component divided by the norm of the
      update (0 when the update norm is not above 1e-10).
    * ``phi_dev``: z-score of the orthogonal-component norm across clients
      (all zeros when the standard deviation is not above 1e-10).

    The returned score is ``(1 - phi_task) + max(phi_dev, 0)``.

    NOTE(review): ``phi_task`` is a ratio of norms, so an update pointing
    exactly opposite to the consensus scores the same as an aligned one.
    Confirm this is intended.

    Args:
        update_vectors: 2-D tensor with one row per client.
        active_dims: Index/mask applied to the columns of ``update_vectors``.

    Returns:
        1-D tensor with one score per client, on the device of
        ``update_vectors``. All zeros when there are no active dimensions or
        when the computation raises (the error is logged).
    """
    try:
        N = update_vectors.shape[0]
        device = update_vectors.device

        # Nothing to score; avoids reducing over empty tensors further down.
        if N == 0:
            return torch.zeros(0, device=device)

        # Restrict to active dimensions
        active_updates = update_vectors[:, active_dims]

        if active_updates.shape[1] == 0:
            logger.warning("No active dimensions, returning zero scores")
            return torch.zeros(N, device=device)

        task_scores = torch.zeros(N, device=device)
        orthogonal_norms = torch.zeros(N, device=device)

        # Reused leave-one-out selector: entry i is switched off only while
        # client i is being processed.
        others = torch.ones(N, dtype=torch.bool, device=device)

        # Single pass: both the task-contribution ratio and the orthogonal
        # norm come from the same decomposition.
        for i in range(N):
            client_update = active_updates[i]

            if N == 1:
                # No peers to form a consensus from.
                consensus_vector = torch.zeros_like(client_update)
            else:
                others[i] = False
                consensus_vector = active_updates[others].mean(dim=0)
                others[i] = True

            parallel_comp, orthogonal_comp = orthogonal_decomposition(
                client_update, consensus_vector
            )

            update_norm = torch.norm(client_update)
            if update_norm > 1e-10:
                task_scores[i] = torch.norm(parallel_comp) / update_norm
            # else: stays 0 for a (near-)zero update

            orthogonal_norms[i] = torch.norm(orthogonal_comp)

        # Deviation strength score (z-score of the orthogonal norms). The
        # sample standard deviation is undefined for a single client, so that
        # case maps straight to zero deviation.
        phi_dev = torch.zeros_like(orthogonal_norms)
        if N > 1:
            spread = torch.std(orthogonal_norms)
            if spread > 1e-10:
                phi_dev = (orthogonal_norms - torch.mean(orthogonal_norms)) / spread

        # Combined score; only above-average deviation is penalised.
        phi_scores = (1 - task_scores) + torch.clamp(phi_dev, min=0)

        logger.debug(f"TDF score range: [{phi_scores.min():.4f}, {phi_scores.max():.4f}]")
        return phi_scores

    except Exception as e:
        logger.error(f"TDF score computation failed: {e}")
        return torch.zeros(update_vectors.shape[0], device=update_vectors.device)

def compute_alignment_score(vec1: torch.Tensor, vec2: torch.Tensor) -> float:
    """
    Return the cosine similarity between ``vec1`` and ``vec2``.

    Args:
        vec1: First tensor.
        vec2: Second tensor, multiplied element-wise with ``vec1``.

    Returns:
        Sum of the element-wise product divided by the product of the two
        norms, as a Python number. ``0.0`` is returned when either norm is
        below 1e-10 or when the computation raises (the error is logged).
    """
    try:
        length_a = torch.norm(vec1)
        length_b = torch.norm(vec2)

        # A (near-)zero vector has no direction to compare.
        if length_a < 1e-10 or length_b < 1e-10:
            return 0.0

        inner_product = (vec1 * vec2).sum()
        cosine = inner_product / (length_a * length_b)
        return cosine.item()

    except Exception as e:
        logger.error(f"Alignment score computation failed: {e}")
        return 0.0

def extract_classification_heads_from_list(update_dict: List[Dict]) -> List[torch.Tensor]:
    """
    Extract the classification-head update of every client in a list.

    Each entry's ``"weight"`` tensor is used as that client's head and moved
    to CUDA when available, otherwise CPU. A client without ``"weight"`` gets
    an all-zero placeholder (and a warning is logged). The placeholder copies
    the shape and dtype of the first ``"weight"`` found in ``update_dict`` so
    that all returned heads can be stacked together; only when no client has
    a ``"weight"`` does it fall back to the historical ``(10, 512)`` default.

    Args:
        update_dict: One mapping per client (despite the name, a list).

    Returns:
        One tensor per client, in input order. If extraction raises, the
        error is logged and a list of zero placeholders is returned.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Historical default, used only when no client supplies a real head.
    placeholder_shape = (10, 512)  # Assume 10 classes, 512 dim features
    placeholder_dtype = None  # None -> torch default dtype

    try:
        # Borrow shape/dtype from the first real head so placeholders stay
        # compatible with the other clients' tensors.
        for client_update in update_dict:
            if "weight" in client_update:
                placeholder_shape = tuple(client_update["weight"].shape)
                placeholder_dtype = client_update["weight"].dtype
                break

        classification_heads = []
        for i, client_update in enumerate(update_dict):
            if "weight" in client_update:
                # Use weight directly as classification head update
                classification_heads.append(client_update["weight"].to(device))
            else:
                classification_heads.append(
                    torch.zeros(placeholder_shape, dtype=placeholder_dtype, device=device)
                )
                logger.warning(f"Client {i} has no weight update, using default zero update")

        return classification_heads

    except Exception as e:
        logger.error(f"Classification head extraction failed: {e}")
        return [
            torch.zeros(placeholder_shape, dtype=placeholder_dtype, device=device)
            for _ in range(len(update_dict))
        ]

def extract_classification_heads(update_dict: Dict, agent_name_keys: List) -> List[torch.Tensor]:
    """
    Extract one classification-head weight per agent from an update mapping.

    For every name in ``agent_name_keys`` the first element of
    ``update_dict[name]`` is scanned, in iteration order, for the first
    2-D parameter whose name contains ``head``, ``fc`` or ``classifier``
    (case-insensitive) and also contains ``weight``. That tensor is returned
    as-is (no device move).

    An agent without such a parameter gets an all-zero placeholder and a
    warning is logged. The placeholder mirrors the first real head found
    (same shape, dtype and device) so the returned tensors stay mutually
    compatible; when no agent has a head it falls back to the historical
    ``(10, 64)`` default on CUDA if available, otherwise CPU.

    Args:
        update_dict: Mapping from agent name to a sequence whose first
            element is a ``{param_name: tensor}`` mapping.
        agent_name_keys: Agent names to extract, in output order.

    Returns:
        One tensor per agent. If extraction raises, the error is logged and
        a list of ``(10, 64)`` zero tensors is returned.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _find_head(client_update):
        # First parameter that looks like a 2-D classification-head weight.
        for param_name, param_tensor in client_update.items():
            lowered = param_name.lower()
            looks_like_head = 'head' in lowered or 'fc' in lowered or 'classifier' in lowered
            if looks_like_head and 'weight' in param_name and param_tensor.dim() == 2:
                return param_tensor
        return None

    try:
        found_heads = [_find_head(update_dict[agent_name][0]) for agent_name in agent_name_keys]

        # Real head (if any) that placeholders should imitate.
        template = next((head for head in found_heads if head is not None), None)

        classification_heads = []
        for agent_name, head in zip(agent_name_keys, found_heads):
            if head is None:
                if template is not None:
                    head = torch.zeros_like(template)
                else:
                    # Create default small matrix if nothing better is known
                    head = torch.zeros(10, 64, device=device)
                logger.warning(f"Client {agent_name} has no classification head weight, using default")
            classification_heads.append(head)

        return classification_heads

    except Exception as e:
        logger.error(f"Classification head extraction failed: {e}")
        # Same device rule as the normal path (the original fallback was CPU-only).
        return [torch.zeros(10, 64, device=device) for _ in agent_name_keys]

def compute_chif_scores(classification_heads: List[torch.Tensor]) -> torch.Tensor:
    """
    Compute the CHIF (suppression-amplification) score of every client.

    For client ``i`` a leave-one-out consensus head is built as the mean of
    all other clients' heads (all zeros when there is only one client). The
    per-class alignment between the client's row and the consensus row is
    computed once per class. For every candidate target class ``t``:

        sa(t) = -alignment[t] + mean(alignment[c] for c != t)

    The client's raw score is ``max(0, max_t sa(t))``. Raw scores are then
    divided by their L2 norm across clients (left at zero if that norm is
    not above 1e-10). Progress is printed to stdout.

    Args:
        classification_heads: One 2-D tensor per client; rows are indexed as
            classes. All heads are expected to share one shape.

    Returns:
        1-D tensor with one normalised score per client, on the device of
        the first head. An empty tensor is returned for an empty input, and
        zeros are returned when the computation raises (the error is logged).
    """
    try:
        N = len(classification_heads)

        # Guard before touching classification_heads[0].
        if N == 0:
            return torch.zeros(0)

        device = classification_heads[0].device
        num_classes = classification_heads[0].shape[0]
        feature_dim = classification_heads[0].shape[1]

        print(f"Classification head dimensions: [{N}, {num_classes}, {feature_dim}]")
        print(f"Number of classes: {num_classes}")

        if num_classes == 0:
            return torch.zeros(N, device=device)

        psi_scores = torch.zeros(N, device=device)

        # Compute CHIF score for each client
        for i in range(N):
            client_head = classification_heads[i]  # indexed by class along dim 0

            # Build Leave-One-Out consensus
            if N == 1:
                consensus_heads = torch.zeros_like(client_head)
            else:
                other_heads = [classification_heads[k] for k in range(N) if k != i]
                consensus_heads = torch.stack(other_heads).mean(dim=0)

            # One alignment per class; every target-class candidate below
            # reuses these values instead of recomputing them.
            alignment_scores = [
                compute_alignment_score(client_head[c], consensus_heads[c])
                for c in range(num_classes)
            ]

            max_score = float('-inf')
            best_target_class = -1

            # Compute CHIF score for each possible target class t
            for t in range(num_classes):
                # Amplification: alignment on the candidate target class.
                amplification_score = alignment_scores[t]

                # Suppression: mean alignment over all other classes.
                if num_classes > 1:
                    suppression_score = np.mean(alignment_scores[:t] + alignment_scores[t + 1:])
                else:
                    suppression_score = 0.0

                sa_score = -amplification_score + suppression_score
                if sa_score > max_score:
                    max_score = sa_score
                    best_target_class = t

            psi_scores[i] = max(0.0, max_score)  # Ensure non-negative score

            print(f"Client {i}: alignment = {[f'{score:.4f}' for score in alignment_scores[:5]]}{'...' if num_classes > 5 else ''}")
            print(f"          best target class {best_target_class}, suppression-amplification score = {psi_scores[i]:.4f}")

        # Normalization: divide by the L2 norm of the score vector
        psi_norm = torch.norm(psi_scores)
        if psi_norm > 1e-10:
            psi_scores = psi_scores / psi_norm
            print(f"Suppression-amplification scores after normalization: {psi_scores.tolist()}")
        else:
            psi_scores = torch.zeros_like(psi_scores)
            print("All suppression-amplification scores are zero, keeping zero after normalization")

        logger.debug(f"SAFD score range: [{psi_scores.min():.4f}, {psi_scores.max():.4f}]")
        return psi_scores

    except Exception as e:
        logger.error(f"SAFD score computation failed: {e}")
        # Derive the size here rather than relying on N, which is unbound if
        # len() itself was what failed.
        fallback_count = len(classification_heads) if classification_heads else 0
        fallback_device = classification_heads[0].device if classification_heads else torch.device('cpu')
        return torch.zeros(fallback_count, device=fallback_device)

def _collect_client_update_vectors(update_dict: List[Dict], device: torch.device) -> List[torch.Tensor]:
    """Flatten each client's "weight"/"bias"/"Prompt_Tokens" tensors into one 1-D vector on ``device``."""
    param_keys = ("weight", "bias", "Prompt_Tokens")

    vectors = []
    for client_update in update_dict:
        pieces = [client_update[key].flatten() for key in param_keys if key in client_update]
        vectors.append(torch.cat(pieces).to(device) if pieces else None)

    # A client with none of the recognised keys gets a zero vector. It must
    # have the same length as its peers' vectors, otherwise stacking them
    # later raises; 100 is kept only as a last resort when no peer has any
    # parameters either.
    template = next((vec for vec in vectors if vec is not None), None)
    for i, vec in enumerate(vectors):
        if vec is None:
            logger.warning(f"Client {i} has no valid parameters, using zero vector")
            vectors[i] = torch.zeros_like(template) if template is not None else torch.zeros(100, device=device)

    return vectors


def tdf_chif_defense(dataset_name, client_num, server_model, client_models, update_dict: List[Dict],
                     lambda1: float = 0.6,
                     lambda2: float = 0.4,
                     epsilon: float = 1e-2,
                     mode: str = 'both') -> Tuple[List[Dict], List, List]:
    """
    TDF+CHIF joint defense main function.

    Steps: flatten every client's update into one vector, keep only active
    dimensions, compute TDF scores (phi) and CHIF scores (psi), fuse them
    into an anomaly score and classify each client against two thresholds.
    Progress is printed to stdout.

    Args:
        dataset_name: Only used in the progress output.
        client_num: Only used in the progress output.
        server_model: Unused here; kept for interface compatibility.
        client_models: Unused here; kept for interface compatibility.
        update_dict: One mapping per client; the keys ``"weight"``,
            ``"bias"`` and ``"Prompt_Tokens"`` are read when present.
        lambda1: Weight of the TDF score in the fused score.
        lambda2: Weight of the CHIF score in the fused score.
        epsilon: Threshold passed to the active-dimension filter.
        mode: ``'tdf'`` scores with TDF only (thresholds 2.0/1.0),
            ``'safd'`` scores with CHIF only (thresholds 0.4/0.2); any other
            value uses ``lambda1 * phi + lambda2 * psi`` (thresholds 0.8/0.4).

    Returns:
        original_update_dict: ``update_dict``, unchanged.
        benign_indices: Indices of clients classified benign or suspicious.
        d_weight: Per-client weight: 0 for malicious, 1 for benign, and
            ``1 - score / ||scores||`` for suspicious clients, where the norm
            is taken over the benign and suspicious scores.

    NOTE: any exception inside this function is logged and makes the defense
    fail open, i.e. every client is returned as benign with weight 1.
    """
    try:
        print("=== Starting TDF+CHIF Joint Defense ===")
        print(f"Parameters: λ1={lambda1} (TDF weight), λ2={lambda2} (CHIF weight), ε={epsilon}, mode={mode}")
        print(f"Dataset: {dataset_name}, Client count: {client_num}")

        N = len(update_dict)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # 1. Extract and concatenate all client update vectors
        update_vectors = _collect_client_update_vectors(update_dict, device)

        if not update_vectors:
            logger.error("No valid client updates found")
            return update_dict, list(range(N)), [1] * N

        update_vectors = torch.stack(update_vectors)  # one row per client
        print(f"Update vector dimensions: {update_vectors.shape}")

        # 2. Active dimension filtering
        active_dims = identify_active_dimensions(update_vectors, epsilon)

        # 3. Compute TDF scores
        phi_scores = compute_tdf_scores(update_vectors, active_dims)

        # 4. Compute CHIF scores
        print("=== CHIF Module ===")
        classification_heads = extract_classification_heads_from_list(update_dict)
        print(f"Extracted {len(classification_heads)} classification heads, dimensions: {classification_heads[0].shape if classification_heads else 'N/A'}")
        psi_scores = compute_chif_scores(classification_heads)

        # 5. Final anomaly score fusion
        if mode == 'tdf':
            anomaly_scores = phi_scores
            threshold_high = 2.0
            threshold_low = 1.0
            print("[Ablation] Using TDF scores only, thresholds 2.0/1.0")
        elif mode == 'safd':
            anomaly_scores = psi_scores
            threshold_high = 0.4
            threshold_low = 0.2
            print("[Ablation] Using SAFD scores only, thresholds 0.4/0.2")
        else:
            anomaly_scores = lambda1 * phi_scores + lambda2 * psi_scores
            threshold_high = 0.8
            threshold_low = 0.4

        print("=== Score Results ===")
        for i in range(N):
            print(f"Client {i}: TDF={phi_scores[i]:.4f}, CHIF={psi_scores[i]:.4f}, Total={anomaly_scores[i]:.4f}")

        # 6. Classify clients. Suspicious clients are kept (they appear in
        # benign_indices) but receive a reduced weight below.
        benign_indices = []
        suspicious_indices = set()
        d_weight = [0] * N

        for i, score in enumerate(anomaly_scores):
            if score > threshold_high:
                d_weight[i] = 0  # Malicious
                print(f"Client {i} -> Malicious (score: {score:.4f})")
            elif score > threshold_low:
                suspicious_indices.add(i)
                benign_indices.append(i)
                print(f"Client {i} -> Suspicious (score: {score:.4f})")
            else:
                benign_indices.append(i)
                print(f"Client {i} -> Benign (score: {score:.4f})")

        # 7. Weight allocation: normalise the scores of the kept clients
        # (benign + suspicious) by their joint L2 norm.
        norm_indices = benign_indices
        if norm_indices:
            norm_scores = anomaly_scores[norm_indices]
            norm = torch.norm(norm_scores)
            if norm > 1e-10:
                normed_scores = norm_scores / norm
            else:
                normed_scores = torch.zeros_like(norm_scores)

            for idx, i in enumerate(norm_indices):
                if i in suspicious_indices:
                    d_weight[i] = 1 - normed_scores[idx].item()
                    print(f"Client {i} -> Suspicious weight after normalization: {d_weight[i]:.4f}")
                else:
                    d_weight[i] = 1  # Benign
        # Malicious d_weight is already 0

        print(f"Detection complete: {len(benign_indices)}/{N} benign or suspicious clients, {N - len(benign_indices)} malicious clients")
        print(f"Final result: Keep {len(benign_indices)}/{N} benign or suspicious clients, remove {N - len(benign_indices)} malicious clients")
        print("=== TDF+CHIF Joint Defense Complete ===")

        return update_dict, benign_indices, d_weight

    except Exception as e:
        logger.error(f"TDF+CHIF defense failed: {e}")
        # Return all clients as benign in case of exception
        return update_dict, list(range(len(update_dict))), [1] * len(update_dict)

def defense_tdf_chif(dataset_name, client_num, server_model, client_models, update_dict, **kwargs):
    """
    Public entry point for the TDF+CHIF joint defense.

    Forwards all positional arguments and any keyword overrides unchanged to
    ``tdf_chif_defense`` and returns its result as-is.
    """
    defense_result = tdf_chif_defense(
        dataset_name,
        client_num,
        server_model,
        client_models,
        update_dict,
        **kwargs,
    )
    return defense_result
