import torch
import os

# Global dictionary to store cumulative task vectors.
# Keys are method names (the `method_name` argument of `ties_merge`); each
# value is that method's cumulative merged vector, itself a dictionary mapping
# a parameter key to a tensor.
# Written by `ties_merge` and `load_ties_cumulative_vectors`, read by
# `save_ties_cumulative_vectors`. The state lives for the lifetime of the
# process and is shared by every caller of this module.
TIES_CUMULATIVE_VECTORS = {}

def keep_topk_reset_rest_to_zero(vector, k):
    """
    Keep the top k% parameters with the largest absolute values in the vector, set the rest to zero

    The selection is made independently for every tensor in the dictionary
    (per-entry top-k, not a global top-k across all entries). The input
    tensors are never modified; every returned tensor is a new tensor.

    Args:
        vector: Input vector dictionary (parameter key -> tensor)
        k: Percentage to keep (0-100). Values >= 100 keep everything.
           For any smaller value at least one element of each non-empty
           tensor is kept, even when k% of its size rounds down to zero.

    Returns:
        Processed vector dictionary with the same keys and tensor shapes
    """
    if k >= 100:
        return {key: param.clone() for key, param in vector.items()}

    result = {}
    for key, param in vector.items():
        num_params = param.numel()
        if num_params == 0:
            # torch.topk cannot select even one element from an empty tensor,
            # and there is nothing to trim anyway - pass it through as a copy.
            result[key] = param.clone()
            continue

        # Number of parameters to keep; never fewer than one per tensor
        num_keep = max(1, int(num_params * k / 100))

        # Flatten so topk ranks every element of the tensor together
        flat_param = param.reshape(-1)
        _, indices = torch.topk(flat_param.abs(), num_keep)

        # 0/1 mask (same dtype as the parameter) marking the kept positions
        mask = torch.zeros_like(flat_param)
        mask[indices] = 1

        # Apply mask - this is the trimmed task vector
        result[key] = param * mask.reshape(param.shape)

    return result

def ties_merge(task_vectors, k=20, method_name="ties_merge"):
    """
    Implement cumulative version of TIES-MERGING algorithm

    The current batch is trimmed, sign-elected and merged, and the result is
    added to the running total stored in TIES_CUMULATIVE_VECTORS[method_name].
    Calling this function therefore has a side effect on that module-level
    dictionary; repeated calls with the same method_name keep accumulating.

    Args:
        task_vectors: List of NonLinearTaskVector objects for current tasks.
            Only the `.vector` attribute (parameter key -> tensor) is read.
            The keys of the first task vector drive the merge, so every
            other task vector must contain at least those keys.
        k: Percentage of parameters to keep (default 20%)
        method_name: Method name, used to distinguish cumulative vectors

    Returns:
        Merged task vector dictionary: a copy (cloned tensors) of the
        cumulative vector for method_name after this batch was added.
        An empty dictionary is returned when task_vectors is empty; the
        cumulative state is not touched in that case.
    """
    if not task_vectors:
        return {}

    print(f"Improved TIES-MERGE: Accumulating new task vectors, keeping parameter ratio {k}%")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Step 1: Trim redundant parameters in current task vectors
    trimmed_vectors = [
        keep_topk_reset_rest_to_zero(task_vector.vector, k)
        for task_vector in task_vectors
    ]

    # Step 2: Calculate merged vector for current tasks
    current_merged = {}
    for key in trimmed_vectors[0]:
        # The elected sign of each entry is the sign of the summed trimmed values
        current_sum = sum(vec[key].to(device) for vec in trimmed_vectors)
        current_sign = torch.sign(current_sum)

        result = torch.zeros_like(current_sum)

        # Disjoint merging: only entries whose sign agrees with the elected
        # sign contribute. Zero entries never match (product is 0, not > 0).
        # NOTE(review): agreeing values are summed rather than averaged -
        # presumably intentional for this cumulative variant; confirm.
        for vector in trimmed_vectors:
            vec = vector[key].to(device)
            sign_match = (torch.sign(vec) * current_sign) > 0
            result += vec * sign_match.float()

        current_merged[key] = result

    # Step 3: Get or initialize cumulative vector
    if method_name not in TIES_CUMULATIVE_VECTORS:
        print(f"Initializing cumulative TIES vectors for {method_name}")
        TIES_CUMULATIVE_VECTORS[method_name] = current_merged
    else:
        print(f"Updating cumulative TIES vectors for {method_name}")
        cumulative = TIES_CUMULATIVE_VECTORS[method_name]
        for key, merged in current_merged.items():
            if key in cumulative:
                # The stored tensor may live on another device (e.g. state
                # restored from disk); .to() is a no-op when it already
                # matches, so the in-place add is preserved in that case.
                total = cumulative[key].to(device)
                total += merged
                cumulative[key] = total
            else:
                # A parameter key first seen in this batch starts its own
                # running total instead of being silently dropped.
                cumulative[key] = merged

    # Return copy of cumulative vector so callers cannot corrupt the state
    return {
        key: tensor.clone()
        for key, tensor in TIES_CUMULATIVE_VECTORS[method_name].items()
    }


# Helper functions for saving and loading cumulative vectors
def save_ties_cumulative_vectors(save_dir):
    """
    Write every cumulative TIES vector to its own file in save_dir

    The directory is created when it does not exist. One file named
    ties_cumulative_vector_<method_name>.pt is written per entry of
    TIES_CUMULATIVE_VECTORS using torch.save.

    Args:
        save_dir: Directory that receives the files
    """
    os.makedirs(save_dir, exist_ok=True)
    for method_name in TIES_CUMULATIVE_VECTORS:
        file_name = f"ties_cumulative_vector_{method_name}.pt"
        save_path = os.path.join(save_dir, file_name)
        torch.save(TIES_CUMULATIVE_VECTORS[method_name], save_path)
        print(f"Saved TIES cumulative vectors for {method_name} to {save_path}")


def load_ties_cumulative_vectors(load_dir):
    """
    Load all TIES cumulative vectors from disk

    Every file in load_dir named ties_cumulative_vector_<method_name>.pt is
    read with torch.load and stored in TIES_CUMULATIVE_VECTORS[method_name],
    replacing any entry already present under that name. Other files are
    ignored. Tensors are mapped onto CUDA when it is available and onto the
    CPU otherwise (the same device choice ties_merge makes), so state saved
    on a GPU machine can be restored on a CPU-only one.

    Args:
        load_dir: Directory previously written by save_ties_cumulative_vectors
    """
    prefix = "ties_cumulative_vector_"
    suffix = ".pt"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for filename in os.listdir(load_dir):
        if not (filename.startswith(prefix) and filename.endswith(suffix)):
            continue

        # Strip exactly one leading prefix and one trailing suffix. Using
        # str.replace here would also remove occurrences inside the method
        # name itself (e.g. "a.pt_b" would come back as "a_b").
        method_name = filename[len(prefix):-len(suffix)]
        load_path = os.path.join(load_dir, filename)

        # SECURITY: torch.load unpickles the file contents; only point this
        # function at directories whose files come from a trusted source.
        TIES_CUMULATIVE_VECTORS[method_name] = torch.load(load_path, map_location=device)
        print(f"Loaded TIES cumulative vectors for {method_name}")