import torch
from collections import defaultdict
from tqdm import tqdm
import numpy as np

def compute_singular_value_decomposition(task_matrix):
    """Compute the reduced singular value decomposition of a matrix.

    The factors follow the ``torch.svd`` convention on every code path:
    ``task_matrix == U @ diag(S) @ V^H``, i.e. the third return value is
    ``V`` itself, not its (conjugate) transpose.

    Parameters:
        task_matrix (torch.Tensor): Matrix to decompose.

    Returns:
        tuple: ``(U, S, V)`` where the columns of ``U`` and ``V`` are the
        left and right singular vectors and ``S`` holds the singular values.

    Raises:
        RuntimeError: If the decomposition fails and the input is not a
            real floating-point tensor (so no higher-precision retry is
            possible), or if the retry fails as well.
    """
    try:
        # torch.linalg.svd is the maintained replacement for torch.svd;
        # full_matrices=False matches torch.svd's default reduced output.
        U, S, Vh = torch.linalg.svd(task_matrix, full_matrices=False)
    except RuntimeError as e:
        print(f"SVD computation error: {e}")
        if not task_matrix.is_floating_point():
            # Casting integer/complex data through float64 and back would
            # corrupt the factors, so surface the original failure instead.
            raise
        # Retry on CPU in double precision (covers dtypes the SVD kernels
        # reject, such as half precision), then restore the caller's
        # device and dtype so downstream matmuls keep working.
        U, S, Vh = torch.linalg.svd(
            task_matrix.to(device="cpu", dtype=torch.float64),
            full_matrices=False,
        )
        U, S, Vh = (
            factor.to(device=task_matrix.device, dtype=task_matrix.dtype)
            for factor in (U, S, Vh)
        )
    # torch.linalg.svd yields V^H; convert back to V for a stable contract.
    return U, S, Vh.transpose(-2, -1).conj()


def compute_subspace_alignment_ratio(src_matrix, trg_matrix, k_M):
    """Compute the subspace alignment ratio (SAR) of a source matrix.

    The source matrix is projected onto the span of the first ``k_M`` left
    singular vectors of the target matrix; the SAR is the Frobenius norm of
    that projection divided by the Frobenius norm of the source matrix.

    Parameters:
        src_matrix (torch.Tensor): 2-D matrix being projected.
        trg_matrix (torch.Tensor): 2-D matrix whose leading left singular
            vectors define the subspace.
        k_M (int): Number of leading left singular vectors to keep.

    Returns:
        float: The SAR value. The result is NaN when ``src_matrix`` has
        zero norm, because the ratio is then 0 / 0.
    """
    U_trg, _, _ = compute_singular_value_decomposition(trg_matrix)

    # Leading k_M left singular vectors of the target span the subspace.
    basis = U_trg[:, :k_M]

    # Apply the projector as basis @ (basis^T @ src) rather than building
    # the dense (rows x rows) matrix basis @ basis^T first: the result is
    # the same, but no intermediate is larger than the source itself.
    projected_src = basis @ (basis.T @ src_matrix)

    projected_norm = torch.norm(projected_src, p='fro')
    src_norm = torch.norm(src_matrix, p='fro')

    return (projected_norm / src_norm).item()


def determine_optimal_k(merged_matrix, epsilon=0.05):
    """Determine the subspace dimension k_M for a matrix.

    Returns the smallest ``k`` for which the relative energy left outside
    the first ``k`` singular values,
    ``1 - sum(S[:k] ** 2) / sum(S ** 2)``, is at most ``epsilon``.

    Parameters:
        merged_matrix (torch.Tensor): 2-D matrix to analyse.
        epsilon (float): Maximum tolerated relative energy error.

    Returns:
        int: The selected ``k``. When no ``k`` meets the bound (for
        example when the total energy is zero and every ratio is NaN),
        the number of singular values is returned.
    """
    _, S, _ = compute_singular_value_decomposition(merged_matrix)

    # The reduced SVD yields one singular value per usable dimension.
    max_k = S.shape[0]

    energy = S ** 2
    total_energy = torch.sum(energy)

    # A single cumulative sum gives every prefix energy at once instead of
    # re-summing S[:k] on each loop iteration.
    errors = 1.0 - torch.cumsum(energy, dim=0) / total_energy

    satisfied = torch.nonzero(errors <= epsilon)
    if satisfied.numel() > 0:
        # nonzero() lists indices in ascending order; +1 converts the
        # zero-based index of the first hit into a count of singular values.
        return int(satisfied[0, 0].item()) + 1

    return max_k


def compute_average_sar(task_matrices, merged_matrix, k_M):
    """Compute the average subspace alignment ratio (SAR_avg).

    Parameters:
        task_matrices (dict): Maps task name to that task's matrix.
        merged_matrix (torch.Tensor): Matrix whose leading subspace the
            task matrices are compared against.
        k_M (int): Subspace dimension passed to the per-task SAR.

    Returns:
        tuple: ``(sar_avg, sar_values)`` where ``sar_values`` maps each
        task name to its SAR and ``sar_avg`` is their arithmetic mean.
        With no tasks, ``sar_avg`` is NaN and ``sar_values`` is empty.
    """
    sar_values = {
        task_name: compute_subspace_alignment_ratio(task_matrix, merged_matrix, k_M)
        for task_name, task_matrix in task_matrices.items()
    }

    if not sar_values:
        # Nothing to average; report NaN instead of dividing by zero.
        return float('nan'), sar_values

    sar_avg = sum(sar_values.values()) / len(sar_values)
    return sar_avg, sar_values


def extract_task_matrices(pretrained_model, task_models, device):
    """Extract task matrices (task weights minus pretrained weights).

    Only entries whose name contains ``'weight'`` and whose pretrained
    tensor is 2-D are considered.

    Parameters:
        pretrained_model (dict): Pretrained state dictionary
            (layer name -> tensor).
        task_models (dict): Maps task name to that task's state dictionary.
        device: Device the tensors are moved to before subtracting.

    Returns:
        defaultdict: ``result[layer_name][task_name]`` is the difference
        matrix. A task is omitted for a layer when its state dictionary
        lacks that layer or holds a tensor of a different shape.
    """
    task_matrices_by_layer = defaultdict(dict)

    for layer_name, pretrained_param in pretrained_model.items():
        if 'weight' not in layer_name or pretrained_param.dim() != 2:
            continue

        pretrained_weight = pretrained_param.to(device)

        for task_name, task_model in task_models.items():
            if layer_name not in task_model:
                print(f"Warning: Layer {layer_name} is missing from task {task_name}, skipping this task")
                continue

            task_weight = task_model[layer_name].to(device)
            if task_weight.shape != pretrained_weight.shape:
                # Subtracting mismatched shapes would either raise or
                # silently broadcast into a meaningless difference.
                print(f"Warning: Shape mismatch for task {task_name} in layer {layer_name}, skipping this task")
                continue

            task_matrices_by_layer[layer_name][task_name] = task_weight - pretrained_weight

    return task_matrices_by_layer


def extract_merged_matrix(pretrained_model, merged_model, device):
    """Extract difference matrices between merged and pretrained weights.

    Only entries whose name contains ``'weight'`` and whose pretrained
    tensor is 2-D are considered.

    Parameters:
        pretrained_model (dict): Pretrained state dictionary
            (layer name -> tensor).
        merged_model (dict): Merged-model state dictionary.
        device: Device the tensors are moved to before subtracting.

    Returns:
        dict: Maps layer name to ``merged_weight - pretrained_weight``.
        Layers absent from ``merged_model``, or present with a different
        shape, are left out.
    """
    merged_matrices = {}

    for layer_name, pretrained_param in pretrained_model.items():
        if 'weight' not in layer_name or pretrained_param.dim() != 2:
            continue

        if layer_name not in merged_model:
            print(f"Warning: Layer {layer_name} is missing from merged model, skipping this layer")
            continue

        pretrained_weight = pretrained_param.to(device)
        merged_weight = merged_model[layer_name].to(device)

        if merged_weight.shape != pretrained_weight.shape:
            # Subtracting mismatched shapes would either raise or silently
            # broadcast into a meaningless difference.
            print(f"Warning: Shape mismatch for layer {layer_name} in merged model, skipping this layer")
            continue

        merged_matrices[layer_name] = merged_weight - pretrained_weight

    return merged_matrices


def _has_nan_or_inf(matrix):
    """Return True when the tensor contains any NaN or Inf entry."""
    return bool(torch.isnan(matrix).any() or torch.isinf(matrix).any())


def _is_near_zero(matrix, tol=1e-6):
    """Return True when the tensor's Frobenius norm is below ``tol``."""
    return bool(torch.norm(matrix, p='fro') < tol)


def _collect_task_sars(layer_name, task_matrices, basis):
    """Compute per-task SAR values for one layer, skipping unusable tasks.

    ``basis`` holds the leading left singular vectors of the layer's merged
    matrix as columns. Tasks with NaN/Inf entries, near-zero norm, a failing
    computation, or a non-finite SAR are reported and left out.
    """
    layer_sar_values = {}

    for task_name, task_matrix in task_matrices.items():
        if _has_nan_or_inf(task_matrix):
            print(f"Warning: Task {task_name} in layer {layer_name} contains NaN or Inf values, skipping this task")
            continue

        if _is_near_zero(task_matrix):
            print(f"Warning: Norm of task {task_name} in layer {layer_name} is close to zero, skipping this task")
            continue

        try:
            # Same quantity as compute_subspace_alignment_ratio, but reusing
            # the basis computed once per layer by the caller.
            projected = basis @ (basis.T @ task_matrix)
            sar = (torch.norm(projected, p='fro') / torch.norm(task_matrix, p='fro')).item()
        except Exception as e:
            # Best effort: one bad task must not abort the whole layer.
            print(f"Warning: Error computing SAR for task {task_name} in layer {layer_name}: {e}")
            continue

        if np.isnan(sar) or np.isinf(sar):
            print(f"Warning: SAR computation for task {task_name} in layer {layer_name} resulted in {sar}, skipping this task")
            continue

        layer_sar_values[task_name] = sar

    return layer_sar_values


def calculate_sar_metrics(pretrained_check, task_models, merged_model, device):
    """
    Compute SAR metrics for ISO-C merged model, skip layers that may cause NaN

    For every layer with task matrices, the subspace dimension k_M is chosen
    from the merged difference matrix, each task matrix is projected onto
    the merged matrix's first k_M left singular vectors, and the resulting
    SAR values are averaged per layer, per task, and overall.

    Parameters:
        pretrained_check (dict): Pretrained model state dictionary
        task_models (dict): Task model state dictionaries
        merged_model (dict): Merged model state dictionary
        device (str): Device

    Returns:
        dict: Dictionary containing all SAR metrics
            - "layer_specific_sar": layer name -> {task name -> SAR}
            - "layer_avg_sar": layer name -> mean SAR over its valid tasks
            - "task_avg_sar": task name -> mean SAR over its valid layers
            - "overall_avg_sar": mean of the finite per-task averages,
              NaN when there are none
            - "k_M_values": layer name -> subspace dimension used
            - "skipped_layers": names of layers that produced no result
            - "valid_layer_count": number of layers with at least one SAR
    """
    task_matrices_by_layer = extract_task_matrices(pretrained_check, task_models, device)
    merged_matrices = extract_merged_matrix(pretrained_check, merged_model, device)

    sar_values_by_layer = {}
    sar_avg_by_layer = {}
    k_M_values = {}
    skipped_layers = []

    print("Computing SAR values for each layer...")
    for layer_name, task_matrices in tqdm(task_matrices_by_layer.items()):
        try:
            if layer_name not in merged_matrices:
                continue

            merged_matrix = merged_matrices[layer_name]

            if _has_nan_or_inf(merged_matrix):
                print(f"Warning: Layer {layer_name} contains NaN or Inf values, skipping this layer")
                skipped_layers.append(layer_name)
                continue

            if _is_near_zero(merged_matrix):
                print(f"Warning: Norm of layer {layer_name} is close to zero, skipping this layer")
                skipped_layers.append(layer_name)
                continue

            # Clamp k_M to the range [1, min_dim - 1].
            min_dim = min(merged_matrix.shape)
            k_M = max(1, min(determine_optimal_k(merged_matrix), min_dim - 1))
            k_M_values[layer_name] = k_M

            # The merged matrix's singular vectors do not depend on the
            # task, so decompose once per layer instead of once per task.
            U_merged, _, _ = compute_singular_value_decomposition(merged_matrix)
            basis = U_merged[:, :k_M]

            layer_sar_values = _collect_task_sars(layer_name, task_matrices, basis)

            # Only keep the layer when at least one task produced a value.
            if layer_sar_values:
                sar_values_by_layer[layer_name] = layer_sar_values
                sar_avg_by_layer[layer_name] = sum(layer_sar_values.values()) / len(layer_sar_values)
            else:
                skipped_layers.append(layer_name)
                print(f"Warning: Layer {layer_name} has no valid SAR values, skipping this layer")

        except Exception as e:
            # Best effort: a failing layer is recorded and the run continues.
            print(f"Warning: Error processing layer {layer_name}: {e}")
            skipped_layers.append(layer_name)

    print(f"\nTotal skipped {len(skipped_layers)} layers")
    if len(skipped_layers) > 0:
        print(f"Skipped layers include: {', '.join(skipped_layers[:5])}{'...' if len(skipped_layers) > 5 else ''}")

    # Regroup the per-layer values by task to average across layers.
    sars_by_task = defaultdict(list)
    for layer_sars in sar_values_by_layer.values():
        for task_name, sar in layer_sars.items():
            sars_by_task[task_name].append(sar)

    # Every list holds at least one value, so the division is always safe.
    task_avg_sar = {
        task_name: sum(values) / len(values)
        for task_name, values in sars_by_task.items()
    }

    valid_sars = [sar for sar in task_avg_sar.values() if not (np.isnan(sar) or np.isinf(sar))]
    overall_avg_sar = sum(valid_sars) / len(valid_sars) if valid_sars else float('nan')

    return {
        "layer_specific_sar": sar_values_by_layer,
        "layer_avg_sar": sar_avg_by_layer,
        "task_avg_sar": task_avg_sar,
        "overall_avg_sar": overall_avg_sar,
        "k_M_values": k_M_values,
        "skipped_layers": skipped_layers,
        "valid_layer_count": len(sar_values_by_layer)
    }