import torch
import os
from src.models.task_vectors import NonLinearTaskVector
from src.models.model_utils import load_pretrained_model

def magmax_merge(task_vectors):
    """
    Implement MagMax merge algorithm
    For each parameter position, select the parameter value with the largest absolute value from all task vectors

    When several task vectors share the same largest absolute value at a
    position, the value from the earliest one in ``task_vectors`` is kept
    (a later value only wins if its magnitude is strictly larger).

    Args:
        task_vectors: Task vectors to merge. Each element must expose a
            ``vector`` mapping of parameter key -> tensor. The keys of the
            first task vector decide which parameters are merged; every
            other task vector must provide the same keys.

    Returns:
        merged_vector: Dict mapping each parameter key to the merged tensor.
            Tensors live on CUDA when available, otherwise on CPU, and never
            alias the input tensors.

    Raises:
        ValueError: If ``task_vectors`` is empty.
    """
    # Materialize once so any iterable is accepted and can be indexed/sliced.
    vectors = list(task_vectors)
    if not vectors:
        raise ValueError("magmax_merge requires at least one task vector")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Performing MagMax merge, merging {len(vectors)} task vectors")

    first, remaining = vectors[0], vectors[1:]
    merged_vector = {}

    for key in first.vector:
        # clone() keeps the result independent of the input even when .to()
        # is a no-op (tensor already on the target device).
        merged_param = first.vector[key].to(device).clone()
        merged_abs = torch.abs(merged_param)

        # Stream the remaining task vectors one at a time instead of holding
        # every candidate tensor for this key on the device simultaneously.
        for tv in remaining:
            param = tv.vector[key].to(device)
            param_abs = torch.abs(param)
            # Strict comparison: ties keep the value chosen so far.
            larger = param_abs > merged_abs
            merged_param = torch.where(larger, param, merged_param)
            merged_abs = torch.where(larger, param_abs, merged_abs)

        merged_vector[key] = merged_param

    return merged_vector

def recover_task_vectors_from_finetuned_models(tasks, experiment_dir, args):
    """
    Reconstruct task vectors from saved fine-tuned models

    For each task, the checkpoint
    ``<experiment_dir>/finetunedModels/finetuned_magmax_<task>.pt`` is loaded
    and combined with the pretrained weights into a NonLinearTaskVector.
    A task whose checkpoint cannot be loaded or converted is reported on
    stdout and skipped, so the result may contain fewer entries than ``tasks``.

    Args:
        tasks: List of tasks to recover
        experiment_dir: Experiment results directory
        args: Run configuration. This function reads ``args.device`` (target
            device for all tensors) and ``args.model`` (passed to
            NonLinearTaskVector), and forwards ``args`` to load_pretrained_model.

    Returns:
        recovered_vectors: Recovered task vectors dictionary {task_name: task_vector},
            containing only the tasks that were recovered successfully
    """
    import gc

    print("Recovering task vectors from saved fine-tuned models for MagMax method...")

    # Define fine-tuned models directory
    finetune_models_dir = os.path.join(experiment_dir, "finetunedModels")
    pretrained_state_dict = load_pretrained_model(args)
    pretrained_on_device = {k: v.to(args.device) for k, v in pretrained_state_dict.items()}

    # Initialize result dictionary
    recovered_vectors = {}

    # Process each task to recover
    for task in tasks:
        finetuned_path = os.path.join(finetune_models_dir, f"finetuned_magmax_{task}.pt")

        try:
            # Load fine-tuned model state dictionary.
            # NOTE(review): depending on the torch version, torch.load may
            # unpickle arbitrary objects - only load checkpoints from a
            # trusted experiment directory.
            finetuned_state_dict = torch.load(finetuned_path, map_location=args.device)
            finetuned_on_device = {k: v.to(args.device) for k, v in finetuned_state_dict.items()}

            # Calculate task vector (MagMax always uses pretrained model as base)
            task_vector = NonLinearTaskVector(args.model, pretrained_on_device, finetuned_on_device)

            # Save to result dictionary
            recovered_vectors[task] = task_vector

        except Exception as e:
            # Deliberate best-effort: one bad checkpoint must not abort the
            # recovery of the remaining tasks.
            print(f"  Error recovering task vector for '{task}': {e}")

    # Report results
    print(f"Task vector recovery completed: Successfully recovered {len(recovered_vectors)}/{len(tasks)} task vectors")

    # Release memory: drop both the loaded weights and their on-device copy
    # before collecting, otherwise the device copy is still referenced when
    # the CUDA cache is emptied and cannot be handed back.
    # (Tensors still referenced by the returned task vectors stay alive.)
    del pretrained_state_dict, pretrained_on_device
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return recovered_vectors