import os
import torch
import gc
from transformers import AutoModelForCausalLM
from merging.model_utils import apply_task_vector_dict

# Module-level accumulator for the MagMax task vector: a dict mapping parameter
# names to tensors holding, per element, the largest-magnitude value seen so far.
# It is None until magmax_merge() initialises it or
# load_magmax_vector_from_disk() restores it from a saved file.
MAGMAX_TASK_VECTOR = None

def magmax_merge(base_model_path, current_vector_dict,
                 task_index=None, task_count=1, scaling_coef=0.8,
                 cache_dir="/root/autodl-tmp/huggingface"):
    """
    Merge models using the MagMax method, one task vector at a time.

    Principle: for each parameter position, keep the value with the largest
    absolute magnitude among all task vectors seen so far. The running result
    lives in the module-level ``MAGMAX_TASK_VECTOR`` so that only one
    accumulated vector (plus the current one) has to be held in memory.

    Side effects: updates the global ``MAGMAX_TASK_VECTOR``. The caller's
    ``current_vector_dict`` and its tensors are never modified.

    Args:
        base_model_path: Path or identifier passed to
            ``AutoModelForCausalLM.from_pretrained``.
        current_vector_dict: Pre-computed task vector for the current task
            (mapping of parameter name to tensor).
        task_index: Index of the current task. ``0`` resets the accumulator
            to the current task vector.
        task_count: Number of tasks processed so far; only used for logging.
        scaling_coef: Scaling coefficient forwarded to
            ``apply_task_vector_dict``. ``None`` falls back to 0.8.
        cache_dir: Root cache directory; the model cache is its
            ``transformers`` subdirectory.

    Returns:
        The value returned by ``apply_task_vector_dict`` for the base model
        and the accumulated MagMax task vector (the merged model).

    Raises:
        ValueError: If a tensor in ``current_vector_dict`` has a different
            shape than the accumulated tensor stored under the same key.
    """
    global MAGMAX_TASK_VECTOR

    print(f"Executing MagMax continual merging, current task index: {task_index}, accumulated tasks: {task_count}")

    if scaling_coef is None:
        scaling_coef = 0.8
        print(f"Using MagMax default scaling coefficient: {scaling_coef}")

    # Free as much memory as possible before the base model is loaded.
    _release_memory()

    # Load base model to CPU
    print(f"Loading base model: {base_model_path}")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="cpu",
        low_cpu_mem_usage=True,
        cache_dir=os.path.join(cache_dir, "transformers")
    )

    if task_index == 0 or MAGMAX_TASK_VECTOR is None:
        # First task (or nothing accumulated yet): the current vector is the
        # accumulator. Clone so later in-place updates never touch the
        # caller's tensors.
        print("First task or accumulated vector not initialized, directly use current task vector")
        MAGMAX_TASK_VECTOR = {k: v.clone() for k, v in current_vector_dict.items()}
    else:
        print("Merge current task vector with accumulated vector using MagMax")
        _magmax_accumulate(MAGMAX_TASK_VECTOR, current_vector_dict)

    # Hand a copy to apply_task_vector_dict so that whatever it does with the
    # tensors cannot corrupt the global accumulator.
    combined_vector_dict = {k: v.clone() for k, v in MAGMAX_TASK_VECTOR.items()}

    # Drop this frame's reference to the current task vector; the memory is
    # only reclaimed if the caller does not hold on to it as well.
    del current_vector_dict
    _release_memory()

    # Apply merged vector to base model and create final model
    print(f"Applying MagMax accumulated task vector to base model (scaling_coef={scaling_coef})...")
    merged_model = apply_task_vector_dict(
        combined_vector_dict, base_model, scaling_coef=scaling_coef)

    # Release merged vector
    del combined_vector_dict
    _release_memory()

    return merged_model


def _release_memory():
    """Run the garbage collector and, when CUDA is available, empty its cache."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _magmax_accumulate(accumulated, incoming):
    """Fold ``incoming`` into ``accumulated`` in place, keeping per element the larger magnitude."""
    # Validate every shared key first so a malformed vector cannot leave the
    # accumulator half-updated.
    for key, new_vec in incoming.items():
        if key in accumulated and accumulated[key].shape != new_vec.shape:
            raise ValueError(
                f"MagMax shape mismatch for '{key}': accumulated "
                f"{tuple(accumulated[key].shape)} vs current {tuple(new_vec.shape)}"
            )

    # Keys present only in the accumulator need no work, so iterating the
    # incoming vector covers every case.
    for key, new_vec in incoming.items():
        if key not in accumulated:
            # Parameter seen for the first time: adopt it as-is.
            accumulated[key] = new_vec.clone()
            continue

        kept = accumulated[key]
        # Align dtype and device on a local tensor; writing the converted
        # tensor back into `incoming` would mutate the caller's dict.
        new_vec = new_vec.to(device=kept.device, dtype=kept.dtype)

        # True where the new task's value has strictly larger magnitude.
        mask = new_vec.abs() > kept.abs()
        kept[mask] = new_vec[mask]


def save_magmax_vector_to_disk(directory):
    """Save the current MagMax task vector to disk.

    Writes the global ``MAGMAX_TASK_VECTOR`` to
    ``<directory>/magmax_task_vector.pt`` with ``torch.save``. The directory
    is created if it does not exist yet.

    Args:
        directory: Directory that will contain ``magmax_task_vector.pt``.

    Returns:
        The path of the written file, or ``None`` when there is no
        accumulated vector to save.
    """
    if MAGMAX_TASK_VECTOR is None:
        return None

    # torch.save does not create missing parent directories.
    os.makedirs(directory, exist_ok=True)
    vector_path = os.path.join(directory, "magmax_task_vector.pt")
    torch.save(MAGMAX_TASK_VECTOR, vector_path)
    print(f"Saved MagMax task vector to: {vector_path}")
    return vector_path


def load_magmax_vector_from_disk(directory):
    """Load the MagMax task vector from disk into the module-level accumulator.

    Reads ``<directory>/magmax_task_vector.pt`` and stores its content in the
    global ``MAGMAX_TASK_VECTOR``. Tensors are mapped to the CPU so a file
    saved from GPU tensors can be restored on a host without that device.

    Args:
        directory: Directory expected to contain ``magmax_task_vector.pt``.

    Returns:
        ``True`` if the file was found and loaded, ``False`` otherwise (the
        global is left untouched in that case).
    """
    global MAGMAX_TASK_VECTOR
    vector_path = os.path.join(directory, "magmax_task_vector.pt")
    if not os.path.isfile(vector_path):
        print(f"MagMax task vector file not found: {vector_path}")
        return False

    # NOTE(security): torch.load is pickle-based; only load files from a
    # trusted source.
    MAGMAX_TASK_VECTOR = torch.load(vector_path, map_location="cpu")
    print(f"Loaded MagMax task vector: {vector_path}")
    return True