import torch
import math

def _rebalance_singular_values(matrix, balance_factor):
    """Return ``matrix`` with its singular values pulled toward their mean.

    Each singular value ``s`` is replaced by ``mean + (s - mean) * balance_factor``:
    a factor of 1.0 leaves the spectrum unchanged, a factor of 0.0 sets every
    singular value to the mean.
    """
    # torch.linalg.svd returns the right factor already transposed (Vh), so
    # U @ diag(S) @ Vh reconstructs the input.
    U, S, Vh = torch.linalg.svd(matrix, full_matrices=False)
    S_mean = S.mean()
    S_balanced = S_mean + (S - S_mean) * balance_factor
    # Scaling the columns of U by S_balanced equals U @ diag(S_balanced)
    # without materialising the dense diagonal matrix.
    return (U * S_balanced) @ Vh


def SAIM(current_task_vectors, merged_model_state_dict, base_model_state_dict, task_count=1, beta=1.0):
    """
    Adaptive ISO-C merge whose singular-value balancing strength depends on the
    number of merged tasks, with a beta weight between historical and new updates.

    For every key of the first task vector, the task vectors are averaged. Then:

    * 2D entries (except keys containing "text_projection") that exist in both
      state dicts: the historical delta ``merged - base`` is combined with the
      average as ``(2 - beta) * delta + beta * average``; the singular values of
      that sum are moved toward their mean by the factor ``1 / sqrt(task_count)``.
    * 2D entries without a historical delta: all singular values of the average
      are replaced by their mean.
    * Everything else: the average is returned unchanged (the historical delta
      is not used for these entries).

    Args:
        current_task_vectors: Non-empty sequence of objects exposing a ``vector``
            mapping from parameter name to tensor; the keys of the first one
            determine which parameters are processed
        merged_model_state_dict: State dictionary of current merged model
        base_model_state_dict: State dictionary of base model
        task_count: Number of merged tasks (including current task); must be
            positive whenever a historical delta is used
        beta: Coefficient controlling new vs old task weights, higher beta means higher new task weight

    Returns:
        Dict mapping each key of the first task vector to its merged tensor,
        located on CUDA when available and on CPU otherwise.

    Raises:
        ValueError: If ``current_task_vectors`` is empty, or if ``task_count``
            is not positive when a historical delta has to be balanced.
    """
    if not current_task_vectors:
        raise ValueError("current_task_vectors must contain at least one task vector")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Performing adaptive ISO-C merge, current task count: {task_count}, beta value: {beta}")

    num_vectors = len(current_task_vectors)
    new_vector = {}
    for key in current_task_vectors[0].vector:
        # Average of the current batch of task vectors for this parameter
        averaged = sum(tv.vector[key].to(device) for tv in current_task_vectors) / num_vectors

        # Non-matrix layers (and text_projection) use the current vector directly
        if averaged.dim() != 2 or "text_projection" in key:
            new_vector[key] = averaged
            continue

        if key in merged_model_state_dict and key in base_model_state_dict:
            if task_count <= 0:
                raise ValueError(f"task_count must be a positive number, got {task_count}")
            # The historical delta is computed only for the keys that need it, so
            # unrelated state-dict entries (e.g. bool/int buffers, for which
            # subtraction may be unsupported) are never touched or copied.
            cumulative = merged_model_state_dict[key].to(device) - base_model_state_dict[key].to(device)
            # (2-beta)*historical vector + beta*current vector
            weighted = (2 - beta) * cumulative + beta * averaged
            # Balance coefficient decreases as task count increases, i.e. the
            # spectrum is flattened more strongly the more tasks were merged
            new_vector[key] = _rebalance_singular_values(weighted, 1.0 / math.sqrt(task_count))
        else:
            # No history for this key: fully flatten the spectrum (plain ISO-C)
            new_vector[key] = _rebalance_singular_values(averaged, 0.0)

    return new_vector