import torch
import os
import numpy as np
from src.models.task_vectors import NonLinearTaskVector
from src.models.model_utils import load_pretrained_model

# torch.quantile rejects inputs above an internal size limit; stay safely below it
# and switch to an order-statistic computation for anything larger.
_QUANTILE_MAX_NUMEL = 16_000_000


def _magnitude_mask(param, drop_rate):
    """Return a boolean mask keeping entries whose |value| exceeds the drop_rate quantile."""
    # torch.quantile only accepts float32/float64, so promote everything else
    # (fp16/bf16 weights, integer buffers) before computing the threshold.
    values = param if param.dtype in (torch.float32, torch.float64) else param.float()
    magnitudes = values.abs()
    flat = magnitudes.flatten()
    count = flat.numel()

    if count == 0:
        # Nothing to rank; an empty mask keeps the shape and avoids a quantile error.
        return torch.zeros_like(magnitudes, dtype=torch.bool)

    if count <= _QUANTILE_MAX_NUMEL:
        threshold = torch.quantile(flat, drop_rate)
    else:
        # Same linearly interpolated quantile, built from two order statistics.
        position = drop_rate * (count - 1)
        lower = int(position)
        upper = min(lower + 1, count - 1)
        low_value = torch.kthvalue(flat, lower + 1).values  # kthvalue is 1-based
        high_value = torch.kthvalue(flat, upper + 1).values
        threshold = torch.lerp(low_value, high_value, position - lower)

    return magnitudes > threshold


def dare_merge(task_vectors, drop_rate=0.9, use_rescale=True, mask_strategy="random"):
    """
    Merge task vectors with the DARE (Drop And REscale) scheme.

    For every parameter key of the first task vector, each task's tensor is
    masked independently (randomly, or by keeping the largest magnitudes),
    optionally rescaled, and the masked tensors are averaged.

    Args:
        task_vectors: Non-empty sequence of objects exposing a ``vector``
            mapping of parameter key -> tensor. Keys are taken from the first
            element; every other element must provide the same keys.
        drop_rate: Proportion of entries to drop, in [0, 1]. Default 0.9.
        use_rescale: If True, each masked tensor is multiplied by
            ``1 / keep_ratio``, where ``keep_ratio`` is the fraction of entries
            actually kept by that tensor's mask (skipped when nothing is kept).
        mask_strategy: ``"random"`` (keep an entry when a uniform sample
            exceeds ``drop_rate``) or ``"magnitude"`` (keep entries whose
            absolute value exceeds the ``drop_rate`` quantile of that tensor).

    Returns:
        dict mapping each parameter key to the merged tensor. Tensors live on
        CUDA when available, otherwise on CPU.

    Raises:
        ValueError: If ``task_vectors`` is empty, ``drop_rate`` is outside
            [0, 1], or ``mask_strategy`` is not supported.
    """
    # Validate up front so bad arguments fail even when there are no keys to merge.
    if not task_vectors:
        raise ValueError("dare_merge requires at least one task vector")
    if not 0.0 <= drop_rate <= 1.0:
        raise ValueError(f"drop_rate must be in [0, 1], got {drop_rate}")
    if mask_strategy not in ("random", "magnitude"):
        raise ValueError(f"Unsupported mask strategy: {mask_strategy}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Performing DARE merge, {len(task_vectors)} task vectors (drop_rate={drop_rate}, use_rescale={use_rescale}, strategy={mask_strategy})")

    merged_vector = {}

    for key in task_vectors[0].vector:
        # Bring every task's tensor for this key onto the same device.
        all_params = [tv.vector[key].to(device) for tv in task_vectors]

        # One independent mask per task.
        if mask_strategy == "random":
            masks = [
                torch.rand_like(param, dtype=torch.float) > drop_rate
                for param in all_params
            ]
        else:
            masks = [_magnitude_mask(param, drop_rate) for param in all_params]

        # Accumulate the masked (and optionally rescaled) tensors as we go
        # instead of holding a third full-size list in memory.
        total = None
        for param, mask in zip(all_params, masks):
            masked_param = param * mask

            if use_rescale:
                keep_ratio = mask.float().mean()
                # Skip rescaling when the mask dropped everything (avoids 1/0).
                if keep_ratio > 0:
                    masked_param = masked_param * (1.0 / keep_ratio)

            total = masked_param if total is None else total + masked_param

        # Plain average over tasks.
        merged_vector[key] = total / len(all_params)

    return merged_vector

def _find_finetuned_checkpoint(finetune_models_dir, task):
    """Return the first existing checkpoint path for *task*, or None if none exists."""
    candidate_names = (
        f"finetuned_dare_{task}.pt",        # DARE specific
        f"finetuned_pretrained_{task}.pt",  # General pretrained fine-tune
        f"finetuned_{task}.pt",             # General naming
    )
    candidates = (os.path.join(finetune_models_dir, name) for name in candidate_names)
    return next((path for path in candidates if os.path.exists(path)), None)


def _build_task_vector(finetuned_path, pretrained_on_device, args):
    """Load one fine-tuned checkpoint and build its task vector against the pretrained weights."""
    # NOTE(security): torch.load unpickles the file; only load checkpoints from
    # a trusted experiment directory.
    finetuned_state_dict = torch.load(finetuned_path, map_location=args.device)
    finetuned_on_device = {k: v.to(args.device) for k, v in finetuned_state_dict.items()}
    # The loaded dicts are locals of this helper, so this function's references
    # to them are dropped as soon as the task vector has been constructed.
    return NonLinearTaskVector(args.model, pretrained_on_device, finetuned_on_device)


def recover_task_vectors_from_finetuned_models(tasks, experiment_dir, args):
    """
    Reconstruct task vectors from saved fine-tuned models for the DARE merge.

    For each task, the first existing checkpoint among
    ``finetuned_dare_<task>.pt``, ``finetuned_pretrained_<task>.pt`` and
    ``finetuned_<task>.pt`` inside ``<experiment_dir>/finetunedModels`` is
    loaded and combined with the pretrained weights into a
    ``NonLinearTaskVector``. Tasks with no checkpoint, or whose recovery
    raises, are reported on stdout and skipped.

    Args:
        tasks: Iterable of task names to recover (``len(tasks)`` is used for
            the final report).
        experiment_dir: Experiment results directory containing
            ``finetunedModels``.
        args: Object passed to ``load_pretrained_model``; ``args.device`` and
            ``args.model`` are also read here.

    Returns:
        dict mapping task name -> recovered task vector, containing only the
        tasks that were recovered successfully.
    """
    import gc

    print("Recovering task vectors from saved fine-tuned models for DARE method...")

    finetune_models_dir = os.path.join(experiment_dir, "finetunedModels")
    pretrained_state_dict = load_pretrained_model(args)
    pretrained_on_device = {k: v.to(args.device) for k, v in pretrained_state_dict.items()}

    recovered_vectors = {}

    for task in tasks:
        finetuned_path = _find_finetuned_checkpoint(finetune_models_dir, task)
        if finetuned_path is None:
            print(f"  Cannot find any fine-tuned model for task '{task}', skipping recovery")
            continue

        # Best effort: one broken checkpoint must not abort the remaining tasks.
        try:
            task_vector = _build_task_vector(finetuned_path, pretrained_on_device, args)
        except Exception as e:
            print(f"  Error recovering task vector for '{task}': {e}")
            continue

        recovered_vectors[task] = task_vector
        print(f"  Successfully recovered task vector for '{task}', source: {os.path.basename(finetuned_path)}")

    print(f"Task vector recovery completed: Successfully recovered {len(recovered_vectors)}/{len(tasks)} task vectors")

    # Drop BOTH references to the pretrained weights before collecting; leaving
    # the on-device copy alive would keep its memory pinned through empty_cache().
    del pretrained_state_dict, pretrained_on_device
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return recovered_vectors