import os
import torch
import copy

# Module-level registry of cumulative (running-average) state dicts.
# Keys are method names; each value is the state dict accumulated so far for
# that method. swa_merge creates and updates entries, save_swa_weights writes
# them to disk, and load_swa_weights repopulates them from disk.
SWA_CUMULATIVE_WEIGHTS = {}

def swa_merge(finetuned_model_state_dict, method_name="swa"):
    """
    Merge a fine-tuned state dict into the running average kept for ``method_name``.

    Core formula: θ_cum ← 1/2 * θ_cum + 1/2 * θ_current

    Because the mixing factor is a fixed 1/2, this is an exponential moving
    average: after three calls the result is 1/4 θ_1 + 1/4 θ_2 + 1/2 θ_3, not
    a uniform mean over all models seen.

    The first call for a given ``method_name`` stores a deep copy of the
    incoming state dict unchanged. Later calls update only keys that are
    present in both the stored and the incoming state dict; keys that exist
    only in the incoming dict are ignored and keys that exist only in the
    stored dict are left untouched.

    Non-floating-point entries (integer counters, boolean masks, ...) are not
    averaged, since ``0.5 * tensor`` would silently promote them to float.
    They take the newest value instead, cast to the stored dtype.

    During an update, tensors are moved to CUDA when it is available
    (otherwise CPU), and the stored running average stays on that device.

    Args:
        finetuned_model_state_dict: Current fine-tuned model state dictionary (θ_current)
        method_name: Key selecting which running average in
            SWA_CUMULATIVE_WEIGHTS to create or update

    Returns:
        merged_state_dict: A new dict holding clones of the updated cumulative
            tensors, so callers may modify it without touching the registry
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Performing SWA weight averaging merge...")

    if method_name not in SWA_CUMULATIVE_WEIGHTS:
        print(f"Initializing cumulative weights for {method_name}")
        # On first use there is nothing to average with: seed the running
        # average with an independent copy of the incoming weights.
        SWA_CUMULATIVE_WEIGHTS[method_name] = copy.deepcopy(finetuned_model_state_dict)
    else:
        print(f"Updating cumulative weights for {method_name}")
        cumulative = SWA_CUMULATIVE_WEIGHTS[method_name]
        # no_grad: if the inputs require grad, averaging with autograd on
        # would chain a graph across every merge and keep old tensors alive.
        with torch.no_grad():
            for key, running in cumulative.items():
                if key not in finetuned_model_state_dict:
                    continue
                current = finetuned_model_state_dict[key].to(device)
                if torch.is_floating_point(running) or torch.is_complex(running):
                    cumulative[key] = 0.5 * running.to(device) + 0.5 * current
                else:
                    # Keep the stored dtype; clone so the registry never
                    # aliases the caller's tensor when .to() is a no-op.
                    cumulative[key] = current.to(running.dtype).clone()

    # Hand back clones so the caller cannot mutate the registry in place.
    return {
        key: value.clone()
        for key, value in SWA_CUMULATIVE_WEIGHTS[method_name].items()
    }


# SWA application function integrated with merge_tasks_incremental function
def apply_swa_merge(current_model, finetuned_model_state_dict, method_name="swa"):
    """
    Run swa_merge and load its result into a copy of ``current_model``.

    The cumulative weights for ``method_name`` are updated first (via
    swa_merge); ``current_model`` is then deep-copied and the merged state
    dict is loaded into that copy. ``current_model`` itself is not modified.

    Args:
        current_model: Model object used as the template for the result; it
            must support ``load_state_dict``
        finetuned_model_state_dict: Fine-tuned model state dictionary to merge in
        method_name: Key of the running average to update

    Returns:
        A deep copy of ``current_model`` carrying the merged weights
    """
    # Update the running average before cloning, matching the order in which
    # side effects (and any errors) surface.
    averaged_weights = swa_merge(finetuned_model_state_dict, method_name)

    model_with_average = copy.deepcopy(current_model)
    model_with_average.load_state_dict(averaged_weights)
    return model_with_average


# Helper functions for saving and loading cumulative weights
def save_swa_weights(save_dir):
    """
    Write every cumulative state dict in SWA_CUMULATIVE_WEIGHTS to ``save_dir``.

    The directory is created if it does not exist. Each method's state dict
    goes into its own file named ``swa_weights_<method_name>.pt``.

    Args:
        save_dir: Directory that receives one ``.pt`` file per method name
    """
    os.makedirs(save_dir, exist_ok=True)
    for method_name in SWA_CUMULATIVE_WEIGHTS:
        filename = f"swa_weights_{method_name}.pt"
        save_path = os.path.join(save_dir, filename)
        torch.save(SWA_CUMULATIVE_WEIGHTS[method_name], save_path)
        print(f"Saved SWA cumulative weights for {method_name} to {save_path}")


def load_swa_weights(load_dir):
    """
    Load SWA cumulative weights from ``load_dir`` into SWA_CUMULATIVE_WEIGHTS.

    Every file named ``swa_weights_<method_name>.pt`` in ``load_dir`` is read
    and stored under ``<method_name>``, replacing any entry already held for
    that name. Other files are ignored. Files are visited in sorted order so
    the load sequence is deterministic.

    Tensors are mapped to CPU while loading, so checkpoints written from a
    CUDA device can also be restored on a CPU-only host.

    Args:
        load_dir: Directory containing ``swa_weights_*.pt`` files

    Raises:
        OSError: If ``load_dir`` cannot be listed (e.g. it does not exist)
    """
    prefix = "swa_weights_"
    suffix = ".pt"
    for filename in sorted(os.listdir(load_dir)):
        if not (filename.startswith(prefix) and filename.endswith(suffix)):
            continue
        # Slice off only the leading prefix and trailing suffix; a blanket
        # str.replace would also strip those substrings from inside the
        # method name itself (e.g. "ckpt.pth" -> "ckpth").
        method_name = filename[len(prefix):-len(suffix)]
        load_path = os.path.join(load_dir, filename)
        # NOTE(review): torch.load unpickles the file, so only point this at
        # trusted directories; consider weights_only=True if the minimum
        # supported torch version provides it.
        SWA_CUMULATIVE_WEIGHTS[method_name] = torch.load(load_path, map_location="cpu")
        print(f"Loaded SWA cumulative weights for {method_name}")