import os
import torch
import math
import gc
from transformers import AutoModelForCausalLM
from merging.model_utils import apply_task_vector_dict


def _balance_factor(task_count):
    """Return ``1 / sqrt(task_count)``, rejecting values that cannot be used."""
    try:
        return 1.0 / math.sqrt(task_count)
    except (TypeError, ValueError, ZeroDivisionError) as err:
        raise ValueError(
            f"task_count must be a positive number, got {task_count!r}"
        ) from err


def _load_cpu_model(base_model_path, cache_dir):
    """Load the causal-LM checkpoint at ``base_model_path`` with ``device_map="cpu"``."""
    return AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="cpu",
        low_cpu_mem_usage=True,
        cache_dir=os.path.join(cache_dir, "transformers"),
    )


def _accumulated_delta(base_model, base_model_path, stored_state, cache_dir, device):
    """Return ``previous merged weights - base weights`` for every shared key."""
    # A second copy of the architecture is instantiated so that
    # load_state_dict() validates the stored state against the model.
    previous_model = _load_cpu_model(base_model_path, cache_dir)
    previous_model.load_state_dict(
        {name: tensor.to(device) for name, tensor in stored_state.items()}
    )

    print("Calculating accumulated deviation vector...")
    # state_dict() builds a fresh mapping on each call, so snapshot each
    # model once instead of once per key.
    base_state = base_model.state_dict()
    previous_state = previous_model.state_dict()
    delta = {
        name: previous_state[name].to(device) - base_tensor.to(device)
        for name, base_tensor in base_state.items()
        if name in previous_state
    }

    # Drop the temporary model before the caller starts the SVD pass.
    del previous_state, previous_model
    gc.collect()
    return delta


def _blend_vectors(cumulative, current, beta, device):
    """Combine the accumulated and the current task vectors key by key.

    Keys present in both dicts are blended as
    ``(2 - beta) * cumulative + beta * current``; keys present in only one
    dict are passed through. Accumulated keys come first, then new keys, so
    the result order is deterministic.
    """
    blended = {}
    for name, old in cumulative.items():
        if name in current:
            # NOTE(review): the two weights sum to 2, not 1 -- confirm that
            # (2 - beta) rather than (1 - beta) is intended.
            blended[name] = (2 - beta) * old + beta * current[name].to(device)
        else:
            blended[name] = old
    for name, new in current.items():
        if name not in cumulative:
            blended[name] = new.to(device)
    return blended


def _balance_singular_values(matrix, balance_factor, device):
    """Shrink the singular values of ``matrix`` toward their mean.

    The matrix is decomposed in float32 on ``device``; each singular value
    ``s`` is replaced by ``mean + (s - mean) * balance_factor`` and the
    matrix is rebuilt and cast back to its original dtype.
    """
    work = matrix.to(device=device, dtype=torch.float32)
    # torch.linalg.svd returns the already-transposed right factor (Vh).
    U, S, Vh = torch.linalg.svd(work, full_matrices=False)
    S_mean = S.mean()
    S_balanced = S_mean + (S - S_mean) * balance_factor
    # Scaling the columns of U equals U @ diag(S_balanced) without
    # materialising the dense diagonal matrix.
    return ((U * S_balanced) @ Vh).to(matrix.dtype)


def SAIM(base_model_path, current_vector_dict,
                       task_index=None, task_count=1, scaling_coef=1.0,
                       previous_model_state=None, beta=0.75, cache_dir="/root/autodl-tmp/huggingface"):
    """
    Merge the current task vector into the running merged model (SAIM).

    The base model is loaded on CPU. For ``task_index > 0``, when
    ``previous_model_state`` holds a ``"current_state"`` entry, the
    accumulated deviation of the previous merged model from the base model
    is blended with ``current_vector_dict``. Every 2-D tensor of the
    resulting vector then has its singular values pulled toward their mean
    by the factor ``1 / sqrt(task_count)``. The result is handed to
    ``apply_task_vector_dict`` together with the base model.

    Args:
        base_model_path: Path or identifier passed to
            ``AutoModelForCausalLM.from_pretrained``.
        current_vector_dict: Mapping of parameter name to tensor for the
            current task vector. Tensors are moved to CPU before use.
        task_index: Index of the current task. ``0`` means first task (no
            accumulated vector is used).
        task_count: Number of tasks processed so far; must be positive.
        scaling_coef: Scaling coefficient forwarded to
            ``apply_task_vector_dict``; ``None`` is treated as ``1.0``.
        previous_model_state: Dict that carries the merged state between
            calls. It is updated in place: ``"current_state"`` is set to a
            copy of the merged model's state dict. When ``None``, a
            throw-away dict is used and the saved state is not visible to
            the caller.
        beta: Weight of the current task vector when blending with the
            accumulated vector (default 0.75); the accumulated vector is
            weighted by ``2 - beta``.
        cache_dir: Root cache directory; models are cached under its
            ``transformers`` sub-directory.

    Returns:
        model: The object returned by ``apply_task_vector_dict``.

    Raises:
        ValueError: If ``task_count`` is not a positive number.
    """
    # Validate before any expensive model loading. Previously an invalid
    # task_count was swallowed once per matrix inside the SVD loop.
    balance_factor = _balance_factor(task_count)

    print(f"Executing adaptive ISO merging, current task index: {task_index}, accumulated tasks: {task_count}")

    # Force all operations on CPU
    device = torch.device("cpu")
    print("All calculations will be executed on CPU")

    if scaling_coef is None:
        scaling_coef = 1.0
        print(f"Using SAIM default scaling coefficient: {scaling_coef}")

    # Ensure previous_model_state is a valid dictionary
    if previous_model_state is None:
        previous_model_state = {}

    print(f"Loading base model: {base_model_path}")
    base_model = _load_cpu_model(base_model_path, cache_dir)

    # Work out whether there is an accumulated vector to blend with.
    cumulative_vector_dict = None
    if task_index == 0:
        print("First task, using base model as starting point")
    elif "current_state" in previous_model_state and task_index is not None and task_index > 0:
        print("Loading previous merged model state from memory")
        cumulative_vector_dict = _accumulated_delta(
            base_model,
            base_model_path,
            previous_model_state["current_state"],
            cache_dir,
            device,
        )

    if cumulative_vector_dict is not None:
        print(f"Merge accumulated vector and current task vector using beta={beta}")
        combined_vector_dict = _blend_vectors(
            cumulative_vector_dict, current_vector_dict, beta, device)
    else:
        # No previous vectors: use the current vector, moved to CPU.
        combined_vector_dict = {
            name: tensor.to(device) for name, tensor in current_vector_dict.items()
        }

    del cumulative_vector_dict, current_vector_dict
    gc.collect()

    print("Applying SVD to merged vector...")
    total_keys = sum(1 for tensor in combined_vector_dict.values() if tensor.dim() == 2)
    new_vector_dict = {}
    processed = 0

    for key, value in combined_vector_dict.items():
        if value.dim() != 2:
            # Non-matrix tensors (biases, norms, ...) are used unchanged.
            new_vector_dict[key] = value
            continue

        processed += 1
        report = processed % 10 == 0
        if report:
            print(f"SVD processing progress: {processed}/{total_keys}")

        try:
            new_vector_dict[key] = _balance_singular_values(
                value, balance_factor, device)
        except RuntimeError as e:
            # torch.linalg.LinAlgError (e.g. non-convergence) is a
            # RuntimeError; keep the unbalanced vector as best effort.
            print(f"SVD decomposition error: {e}, use original vector for key {key}")
            new_vector_dict[key] = value
            gc.collect()
        else:
            if report:
                gc.collect()

    # Release no longer needed vectors
    del combined_vector_dict
    gc.collect()

    print("Applying task vector to base model...")
    merged_model = apply_task_vector_dict(
        new_vector_dict, base_model, scaling_coef=scaling_coef)

    # Store a detached copy so the next call can rebuild the accumulated vector.
    print("Saving merged model state dictionary...")
    previous_model_state["current_state"] = {
        name: tensor.clone().detach()
        for name, tensor in merged_model.state_dict().items()
    }

    return merged_model