import torch
import os

# Global registry of cumulative task vectors, keyed by ``method_name``.
# Each value maps a parameter name to the running sum of every task vector
# passed to ``task_arithmetic_merge`` for that method. It is mutated in place
# by ``task_arithmetic_merge`` / ``load_cumulative_vectors`` and written to
# disk by ``save_cumulative_vectors``.
CUMULATIVE_VECTORS = {}

def task_arithmetic_merge(task_vectors, method_name="task_arithmetic"):
    """
    Implement task arithmetic model merging method - based on cumulative task vectors

    Core formula: θ_merged = θ_0 + α * (τ_cum + τ_current)

    This function computes only the bracketed term. It adds the current task
    vectors into the running cumulative vector stored in
    ``CUMULATIVE_VECTORS[method_name]`` and returns a copy of the updated
    total. Neither the α scaling nor the addition of θ_0 happens here
    (presumably the caller does both; confirm at the call site).

    The cumulative state is mutated on every call, so passing the same task
    vectors twice counts them twice.

    Args:
        task_vectors: Current task vector list. Each element must expose a
            ``.vector`` mapping of parameter name -> tensor. Parameter names
            are taken from the first element, and every element must contain
            them. May be empty.
        method_name: Method name, used to distinguish cumulative vectors

    Returns:
        merged_vector: Dict of cloned tensors equal to the updated cumulative
        vector (previous cumulative sum + sum of the current task vectors).
        Parameters seen for the first time are added to the cumulative vector.
        If ``task_vectors`` is empty, the existing cumulative vector is
        returned unchanged, or ``{}`` if none exists for ``method_name``.
    """
    global CUMULATIVE_VECTORS
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Performing improved task arithmetic model merging, adding {len(task_vectors)} new task vectors to cumulative vector")

    if not task_vectors:
        # Nothing to add: avoid indexing task_vectors[0] and leave state as-is.
        existing = CUMULATIVE_VECTORS.get(method_name, {})
        return {key: value.clone() for key, value in existing.items()}

    # Per-parameter sum of the current task vectors on the target device.
    # sum() starts from 0, so each result is a fresh tensor and never aliases
    # a caller-owned tensor (important because it may be stored below).
    current_task_sum = {
        key: sum(tv.vector[key].to(device) for tv in task_vectors)
        for key in task_vectors[0].vector
    }

    if method_name not in CUMULATIVE_VECTORS:
        print(f"Initializing cumulative task vectors for {method_name}")
        CUMULATIVE_VECTORS[method_name] = current_task_sum
    else:
        print(f"Updating cumulative task vectors for {method_name}")
        cumulative = CUMULATIVE_VECTORS[method_name]
        for key, delta in current_task_sum.items():
            if key in cumulative:
                # Vectors restored from disk may live on another device;
                # align before the in-place add.
                cumulative[key] = cumulative[key].to(device)
                cumulative[key] += delta
            else:
                # A parameter seen for the first time starts its own sum
                # instead of being dropped.
                cumulative[key] = delta

    # Clone so callers can modify the result without corrupting the registry.
    return {key: value.clone() for key, value in CUMULATIVE_VECTORS[method_name].items()}


# Helper functions for saving and loading cumulative vectors
def save_cumulative_vectors(save_dir):
    """Write every registered cumulative vector to its own file.

    For each ``method_name`` in ``CUMULATIVE_VECTORS``, the stored dict is
    written with ``torch.save`` to
    ``<save_dir>/cumulative_vector_<method_name>.pt``. ``save_dir`` is
    created first (including parents) if it does not already exist.

    Args:
        save_dir: Destination directory.
    """
    os.makedirs(save_dir, exist_ok=True)
    for name, payload in CUMULATIVE_VECTORS.items():
        file_name = f"cumulative_vector_{name}.pt"
        target = os.path.join(save_dir, file_name)
        torch.save(payload, target)
        print(f"Saved cumulative vectors for {name} to {target}")


def load_cumulative_vectors(load_dir):
    """Load all cumulative vectors from disk into ``CUMULATIVE_VECTORS``.

    Every file in ``load_dir`` named ``cumulative_vector_<method_name>.pt``
    is read with ``torch.load`` and stored under ``<method_name>``, replacing
    any entry already registered for that method. This is the naming used by
    ``save_cumulative_vectors``. Other files are ignored, and the directory
    is not scanned recursively.

    Args:
        load_dir: Directory to scan.

    Raises:
        FileNotFoundError: If ``load_dir`` does not exist (from ``os.listdir``).
    """
    global CUMULATIVE_VECTORS
    prefix, suffix = "cumulative_vector_", ".pt"
    # Load tensors onto the same device task_arithmetic_merge computes on.
    # This lets a file saved on a GPU machine load on a CPU-only one, and
    # keeps later in-place updates from mixing devices.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # sorted() makes the load order (and log output) deterministic.
    for filename in sorted(os.listdir(load_dir)):
        if not (filename.startswith(prefix) and filename.endswith(suffix)):
            continue
        # Slice off exactly one prefix and one suffix. str.replace would also
        # strip matches inside the method name itself (e.g. "v1.pt_run").
        method_name = filename[len(prefix):-len(suffix)]
        load_path = os.path.join(load_dir, filename)
        # NOTE(review): torch.load unpickles arbitrary objects, so only load
        # from trusted directories. Consider weights_only=True where the
        # installed torch version supports it.
        CUMULATIVE_VECTORS[method_name] = torch.load(load_path, map_location=device)
        print(f"Loaded cumulative vectors for {method_name}")