import torch
import os

# Module-level state shared by the functions in this file, keyed by method name:
#   ISO_C_CUMULATIVE_VECTORS[method_name] -> dict mapping parameter key to the
#       running (un-averaged) sum of every task vector accumulated so far.
#   ISO_C_TASK_COUNTS[method_name] -> number of task vectors in that sum.
ISO_C_CUMULATIVE_VECTORS = {}
ISO_C_TASK_COUNTS = {}

def iso_c(task_vectors, method_name="iso_c"):
    """
    Merge task vectors with ISO-C on top of a running cumulative sum.

    The per-key sum of ``task_vectors`` is added to the cumulative sum stored
    in ``ISO_C_CUMULATIVE_VECTORS[method_name]`` and
    ``ISO_C_TASK_COUNTS[method_name]`` grows by ``len(task_vectors)``. The
    merged result is then derived from the cumulative state:

    * 2D tensors whose key does not contain "text_projection": SVD of the
      cumulative sum with every singular value replaced by their mean.
    * Everything else (and any 2D tensor whose SVD step fails): cumulative
      sum divided by the cumulative task count.

    Keys of the incoming vectors that are absent from an already existing
    cumulative state are ignored; the key set is fixed by the first call.

    Args:
        task_vectors: Sequence of objects exposing a ``vector`` dict
            (parameter key -> tensor). May be empty when cumulative state for
            ``method_name`` already exists; the merge is then recomputed
            without adding anything.
        method_name: Method name, used to select the cumulative state.

    Returns:
        merged_vector: dict mapping each cumulative key to its merged tensor.

    Raises:
        ValueError: If ``task_vectors`` is empty and no cumulative state
            exists for ``method_name``, or if the cumulative task count is 0
            so no average can be formed.
    """
    global ISO_C_CUMULATIVE_VECTORS, ISO_C_TASK_COUNTS

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Performing corrected ISO-C model merge, adding {len(task_vectors)} new task vectors")

    has_state = method_name in ISO_C_CUMULATIVE_VECTORS
    if not task_vectors and not has_state:
        # Checked before touching any global so a bad call leaves no trace.
        raise ValueError(
            f"No task vectors given and no cumulative vectors stored for {method_name}"
        )

    # Get or initialize task count
    if method_name not in ISO_C_TASK_COUNTS:
        ISO_C_TASK_COUNTS[method_name] = 0

    if task_vectors:
        # sum() starts from 0, so every value here is a fresh tensor and the
        # callers' task vectors are never modified in place later on.
        current_task_sum = {
            key: sum(tv.vector[key].to(device) for tv in task_vectors)
            for key in task_vectors[0].vector
        }

        if not has_state:
            print(f"Initializing cumulative task vectors for {method_name}")
            ISO_C_CUMULATIVE_VECTORS[method_name] = current_task_sum
        else:
            print(f"Updating cumulative task vectors for {method_name}")
            cumulative = ISO_C_CUMULATIVE_VECTORS[method_name]
            for key in cumulative:
                if key in current_task_sum:
                    # Stored state may live on another device (e.g. after
                    # being loaded from disk); in-place add needs them equal.
                    cumulative[key] += current_task_sum[key].to(cumulative[key].device)

    # Update task count
    ISO_C_TASK_COUNTS[method_name] += len(task_vectors)
    total_task_count = ISO_C_TASK_COUNTS[method_name]
    print(f"Current cumulative total tasks: {total_task_count}")

    if total_task_count == 0:
        # Only reachable with an empty task_vectors and state that carries no
        # count; dividing would silently produce inf/nan tensors.
        raise ValueError(
            f"Cumulative task count for {method_name} is 0, cannot average"
        )

    new_vector = {}
    print("Computing SVD and applying ISO-C uniformization...")
    for key, stored in ISO_C_CUMULATIVE_VECTORS[method_name].items():
        # Work on a copy so the cumulative state itself is never altered here.
        accumulated = stored.clone().to(device)
        avg_vector = accumulated / total_task_count

        if len(avg_vector.shape) == 2 and "text_projection" not in key:
            try:
                # Average first, then scale back up: kept in this order on
                # purpose (per the original note, to mirror the reference
                # iso.py sequence of floating-point operations).
                iso_input = avg_vector * total_task_count

                # torch.linalg.svd returns (U, S, Vh): the third factor is
                # already transposed, so U @ diag(S) @ Vh rebuilds the input.
                U, S, Vh = torch.linalg.svd(iso_input, full_matrices=False)

                # ISO-C uniformization: set all singular values to their mean
                S_mean = torch.ones_like(S) * S.mean()

                new_vector[key] = torch.linalg.multi_dot((U, torch.diag(S_mean), Vh))
            except Exception as e:
                # Deliberate best effort: fall back to the plain average.
                print(f"Error processing key {key}: {e}, keeping average value")
                new_vector[key] = avg_vector
        else:
            # For non-matrix parameters, use average directly
            new_vector[key] = avg_vector

    return new_vector


# New function: Recover task vectors from fine-tuned models
def recover_task_vectors_from_finetuned_models(tasks, experiment_dir, args):
    """
    Recover task vectors based on fine-tuned models

    Expects ``<experiment_dir>/pretrained_model.pt`` and a
    ``<experiment_dir>/finetunedModels`` directory. For each task the file
    ``finetuned_pretrained_<task>.pt`` is tried first, then
    ``finetuned_iso_c_<task>.pt``; tasks with neither file are skipped with a
    message.

    Args:
        tasks: List of tasks to recover
        experiment_dir: Experiment directory path
        args: Program parameters; ``args.device`` is used as ``map_location``
            when loading checkpoints and ``args.model`` is passed to
            ``NonLinearTaskVector``.

    Returns:
        recovered_vectors: dict mapping task -> recovered task vector. Empty
        when the pretrained checkpoint or the fine-tuned directory is missing.
    """
    print(f"Recovering task vectors from fine-tuned models: {tasks}")
    recovered_vectors = {}

    # Validate both locations before doing any expensive work.
    base_model_path = os.path.join(experiment_dir, "pretrained_model.pt")
    if not os.path.exists(base_model_path):
        print(f"Pretrained model file does not exist: {base_model_path}")
        return recovered_vectors

    finetune_dir = os.path.join(experiment_dir, "finetunedModels")
    if not os.path.exists(finetune_dir):
        print(f"Fine-tuned models directory does not exist: {finetune_dir}")
        return recovered_vectors

    # Project import kept function-local (as in the original) and deferred
    # until it is actually needed.
    from src.models.task_vectors import NonLinearTaskVector

    # NOTE(security): torch.load unpickles the file; only point this function
    # at checkpoints from a trusted source.
    base_model_state = torch.load(base_model_path, map_location=args.device)

    for task in tasks:
        # Two possible naming formats, in order of preference.
        candidates = (
            os.path.join(finetune_dir, f"finetuned_pretrained_{task}.pt"),
            os.path.join(finetune_dir, f"finetuned_iso_c_{task}.pt"),
        )
        model_path = next((path for path in candidates if os.path.exists(path)), None)
        if model_path is None:
            print(f"Cannot find fine-tuned model file for task {task}")
            continue

        print(f"Recovering vector for task {task} from fine-tuned model: {model_path}")
        finetune_state = torch.load(model_path, map_location=args.device)

        # Task vector is built from the pretrained/fine-tuned pair.
        recovered_vectors[task] = NonLinearTaskVector(args.model, base_model_state, finetune_state)
        print(f"Successfully recovered vector for task {task}")

    print(f"Total recovered {len(recovered_vectors)} task vectors")
    return recovered_vectors


# New function: Rebuild cumulative vectors
def rebuild_iso_c_cumulative_vectors(tasks, task_vectors, method_name="iso_c"):
    """
    Rebuild ISO-C cumulative vectors based on recovered task vectors

    Replaces ``ISO_C_CUMULATIVE_VECTORS[method_name]`` with the per-key sum of
    the available task vectors and sets ``ISO_C_TASK_COUNTS[method_name]`` to
    the number of vectors summed. The key set is taken from the first
    available task; keys other tasks lack are left without their contribution.

    The sum is accumulated in freshly copied tensors, so the tensors held by
    the objects in ``task_vectors`` are never modified.

    Args:
        tasks: List of tasks to rebuild
        task_vectors: dict mapping task -> object exposing a ``vector`` dict
            (parameter key -> tensor); missing or ``None`` entries are
            reported and skipped.
        method_name: Method name

    Returns:
        success: ``True`` if the state was rebuilt, ``False`` if none of
        ``tasks`` had a usable vector (existing state is then left untouched).
    """
    global ISO_C_CUMULATIVE_VECTORS, ISO_C_TASK_COUNTS

    print(f"Rebuilding cumulative vectors for {method_name}, based on {len(tasks)} tasks")

    # Ensure all tasks have available vectors
    available_tasks = [task for task in tasks if task in task_vectors and task_vectors[task] is not None]
    if len(available_tasks) != len(tasks):
        missing = [task for task in tasks if task not in available_tasks]
        print(f"Warning: Missing task vectors: {missing}")

    if not available_tasks:
        print("No available task vectors, cannot rebuild cumulative vectors")
        return False

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Accumulate locally and publish at the end, so a failure midway does not
    # leave the global state half rebuilt.
    cumulative = None
    task_count = 0
    for task in available_tasks:
        source = task_vectors[task].vector
        if cumulative is None:
            # copy=True is essential: a plain .to(device) returns the very
            # same tensor when it already lives on `device`, and the in-place
            # additions below would then corrupt the first task's vector.
            cumulative = {key: tensor.to(device, copy=True) for key, tensor in source.items()}
        else:
            for key in cumulative:
                if key in source:
                    cumulative[key] += source[key].to(device)
        task_count += 1

    ISO_C_CUMULATIVE_VECTORS[method_name] = cumulative
    ISO_C_TASK_COUNTS[method_name] = task_count

    print(f"Successfully rebuilt cumulative vectors for {method_name}, containing {ISO_C_TASK_COUNTS[method_name]} tasks")
    return True


# Helper functions for saving and loading cumulative vectors
def save_iso_c_cumulative_vectors(save_dir):
    """
    Write every method's cumulative vectors and task count to ``save_dir``.

    The directory is created if needed. Each method gets its own file,
    ``iso_c_cumulative_vector_<method_name>.pt``, holding a dict with the keys
    "vectors" and "task_count"; a method without a recorded count is saved
    with a count of 0.
    """
    os.makedirs(save_dir, exist_ok=True)

    for name in ISO_C_CUMULATIVE_VECTORS:
        n_tasks = ISO_C_TASK_COUNTS[name] if name in ISO_C_TASK_COUNTS else 0
        target = os.path.join(save_dir, f"iso_c_cumulative_vector_{name}.pt")

        payload = {
            "vectors": ISO_C_CUMULATIVE_VECTORS[name],
            "task_count": n_tasks,
        }
        torch.save(payload, target)
        print(f"Saved ISO-C cumulative vectors for {name} to {target}, task count: {n_tasks}")


def load_iso_c_cumulative_vectors(load_dir):
    """
    Load all cumulative vectors and task counts from disk

    Every file in ``load_dir`` named ``iso_c_cumulative_vector_<method>.pt``
    is read; other files are ignored. Tensors are mapped onto CUDA when it is
    available and onto the CPU otherwise, the same choice ``iso_c`` makes, so
    state saved on one kind of machine can be resumed on another.

    Two payload layouts are accepted:

    * a dict with "vectors" and "task_count": both are restored;
    * anything else (old format): stored as the cumulative vectors, and
      ``ISO_C_TASK_COUNTS`` is left untouched for that method.

    Args:
        load_dir: Directory to scan. ``os.listdir`` raises if it is missing.
    """
    global ISO_C_CUMULATIVE_VECTORS, ISO_C_TASK_COUNTS

    prefix = "iso_c_cumulative_vector_"
    suffix = ".pt"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for filename in os.listdir(load_dir):
        if not (filename.startswith(prefix) and filename.endswith(suffix)):
            continue

        # Strip exactly one prefix and one suffix. str.replace would remove
        # every occurrence and mangle method names that contain ".pt".
        method_name = filename[len(prefix):len(filename) - len(suffix)]
        load_path = os.path.join(load_dir, filename)

        # NOTE(security): torch.load unpickles the file; only load state
        # files from a trusted source.
        saved_data = torch.load(load_path, map_location=device)

        if isinstance(saved_data, dict) and "vectors" in saved_data and "task_count" in saved_data:
            ISO_C_CUMULATIVE_VECTORS[method_name] = saved_data["vectors"]
            ISO_C_TASK_COUNTS[method_name] = saved_data["task_count"]
            print(f"Loaded ISO-C cumulative vectors for {method_name}, task count: {ISO_C_TASK_COUNTS[method_name]}")
        else:
            # Compatible with old format
            ISO_C_CUMULATIVE_VECTORS[method_name] = saved_data
            print(f"Loaded ISO-C cumulative vectors for {method_name} (old format, no task count)")