from src.merging.swa_merge import apply_swa_merge
from src.models.model_utils import apply_merged_vector
from src.merging.saim import SAIM
from src.merging.iso_c import iso_c
from src.merging.task_arithmetic import task_arithmetic_merge
from src.merging.magmax import magmax_merge
from src.merging.ties_merge import ties_merge
from src.merging.dare import dare_merge
from src.utils.distributed import is_main_process
from src.models import ImageEncoder
import os
import torch


def _method_option(cfg, name, default):
    """Return ``cfg.method_config.<name>`` if that attribute exists, else ``default``."""
    return getattr(cfg.method_config, name, default)


def _has_task_vector(task_vectors, task):
    """Return True when ``task`` is a key of ``task_vectors`` with a non-None value."""
    return task in task_vectors and task_vectors[task] is not None


def _missing_tasks(tasks, task_vectors):
    """Return, in order, the tasks that have no usable entry in ``task_vectors``."""
    return [task for task in tasks if not _has_task_vector(task_vectors, task)]


def _experiment_dir(args):
    """Return absolute ``args.result_dir`` when it is set and truthy, else absolute ``args.save_dir``."""
    if getattr(args, "result_dir", None):
        return os.path.abspath(args.result_dir)
    return os.path.abspath(args.save_dir)


def _store_recovered(task_vectors, missing_tasks, recover_fn, experiment_dir, args):
    """Call ``recover_fn(missing_tasks, experiment_dir, args)`` and write every
    returned ``task -> vector`` pair into ``task_vectors`` in place."""
    recovered = recover_fn(missing_tasks, experiment_dir, args)
    for task, vector in recovered.items():
        task_vectors[task] = vector


def merge_tasks_incremental(tasks_to_merge, task_vectors, current_model, base_model_state_dict, cfg, args,
                            method="iso_c", task_count=1, svd_stats=None, early_stopping=True, patience=3, all_seen_tasks=None):
    """Incrementally merge task vectors into a model, with optional alpha/beta optimization.

    A merged vector is built with ``method`` ("simple", "SAIM", "iso_c",
    "magmax", "task_arithmetic", "ties_merge" or "dare"); "swa" instead merges
    a saved fine-tuned checkpoint directly into ``current_model``.

    When ``cfg.method_config.optimize_beta`` is set and ``method == "SAIM"``,
    beta is searched on the main process and broadcast to the other ranks.
    When ``cfg.method_config.optimize_alpha`` is set, alpha is searched on the
    main process and the resulting model is shared with the other ranks via a
    temporary checkpoint file; otherwise alpha = 1.0 is applied.

    Args:
        tasks_to_merge: Tasks whose vectors are merged in this step.
        task_vectors: Mapping task -> task vector (entries may be None). The
            iso_c, magmax and dare branches may fill in missing entries in
            place with vectors recovered from fine-tuned checkpoints.
        current_model: Model in its current state; returned unchanged when
            nothing can be merged.
        base_model_state_dict: State dict the merged vector is applied to.
        cfg: Config; method options are read from ``cfg.method_config``.
        args: Runtime arguments (``device``, ``result_dir``/``save_dir``, and
            ``world_size``/``rank`` when running distributed).
        method: Merging method name.
        task_count: Forwarded to ``SAIM`` and to beta optimization.
        svd_stats: Returned as-is only when no valid task vector is available.
        early_stopping: Forwarded to the alpha/beta optimizers.
        patience: Forwarded to the alpha/beta optimizers.
        all_seen_tasks: Every task seen so far; magmax and dare merge all of
            them, iso_c uses them to detect earlier tasks.

    Returns:
        ``(model, stats)`` where ``stats`` is ``svd_stats`` only when nothing
        could be merged, and ``None`` on every other path.
    """
    from src.optimizers.optimization import optimize_merge_alpha, optimize_merge_beta

    is_main = is_main_process()
    distributed = getattr(args, "world_size", 1) > 1

    optimize_alpha = _method_option(cfg, "optimize_alpha", False)
    optimize_beta = _method_option(cfg, "optimize_beta", False)
    use_sabcd = _method_option(cfg, "use_sabcd", False)
    # Beta used when it is not optimized: 1.0 with use_sabcd, otherwise 0.8.
    best_beta = 1.0 if use_sabcd else 0.8

    if is_main:
        alpha_status = "Enabled" if optimize_alpha else "Disabled (using default value 1.0)"
        beta_status = "Enabled" if optimize_beta else f"Disabled (using default value {best_beta})"
        print(f"Incremental task merging: {tasks_to_merge} (using {method} method)")
        print(f"Alpha optimization: {alpha_status}")
        print(f"Beta optimization: {beta_status}")

    selected_task_vectors = [task_vectors[task] for task in tasks_to_merge
                             if _has_task_vector(task_vectors, task)]
    if not selected_task_vectors:
        if is_main:
            print("Warning: No valid task vectors available for merging")
        return current_model, svd_stats

    current_state_dict = current_model.state_dict()

    # Step 1 (SAIM only): optimize beta on the main process, then share it.
    if method == "SAIM" and optimize_beta:
        if is_main:
            print(f"\nStep 1: Optimizing beta parameter for {method}...")
            best_beta, _ = optimize_merge_beta(
                base_model_state_dict,
                selected_task_vectors,
                current_state_dict,
                tasks_to_merge,
                task_vectors,
                cfg,
                args,
                task_count=task_count,
                early_stopping=early_stopping,
                patience=patience,
                all_seen_tasks=all_seen_tasks,
                method=method
            )
            print(f"Best beta value found for {method} method: {best_beta}")

        if distributed:
            # Every rank must merge with the beta chosen on rank 0.
            beta_tensor = torch.tensor([best_beta], device=args.device)
            torch.distributed.broadcast(beta_tensor, 0)
            best_beta = beta_tensor.item()

    # Build the merged vector (all processes execute this so results agree).
    if method == "simple":
        merged_vector = {}
        for key in selected_task_vectors[0].vector:
            tvs = [tv.vector[key].to(current_state_dict[key].device)
                   for tv in selected_task_vectors]
            merged_vector[key] = sum(tvs) / len(tvs)

    elif method == "SAIM":
        merged_vector = SAIM(
            selected_task_vectors, current_state_dict, base_model_state_dict, task_count, best_beta)

    elif method == "iso_c":
        # When earlier tasks exist, the ISO-C module's cumulative state should
        # already contain them; if it does not, rebuild it from those tasks.
        if all_seen_tasks and len(all_seen_tasks) > len(tasks_to_merge):
            previous_tasks = [task for task in all_seen_tasks if task not in tasks_to_merge]

            from src.merging.iso_c import recover_task_vectors_from_finetuned_models, rebuild_iso_c_cumulative_vectors

            try:
                from src.merging.iso_c import ISO_C_CUMULATIVE_VECTORS, ISO_C_TASK_COUNTS
                # ``not ...get(method)`` treats a missing entry, a zero count and
                # an empty container alike as "no cumulative state".
                need_recovery = (method not in ISO_C_CUMULATIVE_VECTORS
                                 or not ISO_C_TASK_COUNTS.get(method))
            except (ImportError, AttributeError, TypeError):
                need_recovery = True

            # NOTE(review): recovery and rebuild run on the main process only;
            # other ranks keep their own ISO-C module state. Confirm this is
            # intended when world_size > 1.
            if need_recovery and is_main:
                print("Detected missing ISO-C cumulative vectors, attempting recovery from fine-tuned models...")

                missing_tasks = _missing_tasks(previous_tasks, task_vectors)
                if missing_tasks:
                    _store_recovered(task_vectors, missing_tasks,
                                     recover_task_vectors_from_finetuned_models,
                                     _experiment_dir(args), args)

                # Rebuild from previous tasks only; iso_c() adds the current ones.
                rebuild_success = rebuild_iso_c_cumulative_vectors(
                    previous_tasks,
                    task_vectors,
                    method_name=method
                )
                if not rebuild_success:
                    print("Warning: Unable to fully rebuild ISO-C cumulative vectors, merge results may be inconsistent")

        merged_vector = iso_c(selected_task_vectors, method_name=method)

    elif method == "magmax":
        all_tasks = all_seen_tasks if all_seen_tasks else tasks_to_merge
        missing_tasks = _missing_tasks(all_tasks, task_vectors)
        if missing_tasks and is_main:
            from src.merging.magmax import recover_task_vectors_from_finetuned_models
            _store_recovered(task_vectors, missing_tasks,
                             recover_task_vectors_from_finetuned_models,
                             os.path.abspath(cfg.prev_experiment_dir), args)

        # MagMax merges every seen task, not only the current ones.
        all_task_vectors = [task_vectors[task] for task in all_tasks
                            if _has_task_vector(task_vectors, task)]
        merged_vector = magmax_merge(all_task_vectors)

    elif method == "task_arithmetic":
        merged_vector = task_arithmetic_merge(
            selected_task_vectors, method_name="task_arithmetic")

    elif method == "ties_merge":
        ties_k = _method_option(cfg, "ties_k", 20)
        merged_vector = ties_merge(
            selected_task_vectors, k=ties_k, method_name="ties_merge")

    elif method == "swa":
        # SWA merges the saved fine-tuned checkpoint directly into the model.
        finetune_model_path = os.path.join(
            args.result_dir,
            "finetunedModels",
            f"finetuned_pretrained_{tasks_to_merge[0]}.pt"
        )
        if not os.path.exists(finetune_model_path):
            if is_main:
                print(f"Warning: Cannot find fine-tuned model file {finetune_model_path}")
            return current_model, None
        finetuned_state_dict = torch.load(
            finetune_model_path, map_location=args.device)
        merged_model = apply_swa_merge(
            current_model, finetuned_state_dict, method)
        return merged_model, None

    elif method == "dare":
        drop_rate = _method_option(cfg, "dare_drop_rate", 0.9)
        use_rescale = _method_option(cfg, "dare_use_rescale", True)
        mask_strategy = _method_option(cfg, "dare_mask_strategy", "random")

        if is_main:
            print(
                f"Performing DARE merge algorithm: drop_rate={drop_rate}, use_rescale={use_rescale}, mask_strategy={mask_strategy}")
        all_tasks = all_seen_tasks if all_seen_tasks else tasks_to_merge
        missing_tasks = _missing_tasks(all_tasks, task_vectors)
        if missing_tasks and is_main:
            from src.merging.dare import recover_task_vectors_from_finetuned_models
            _store_recovered(task_vectors, missing_tasks,
                             recover_task_vectors_from_finetuned_models,
                             os.path.abspath(cfg.prev_experiment_dir), args)

        # DARE merges every seen task, not only the current ones.
        all_task_vectors = [task_vectors[task] for task in all_tasks
                            if _has_task_vector(task_vectors, task)]
        merged_vector = dare_merge(
            all_task_vectors,
            drop_rate=drop_rate,
            use_rescale=use_rescale,
            mask_strategy=mask_strategy
        )

    else:
        if is_main:
            print(f"Warning: Unknown merging method {method}, returning unmodified model")
        return current_model, None

    # selected_task_vectors is non-empty and None-free, unlike
    # task_vectors[tasks_to_merge[0]], so it is a safe source of model_name.
    reference_vector = selected_task_vectors[0]

    if not optimize_alpha:
        alpha = 1.0
        if is_main:
            print(f"Using default alpha value: {alpha}")
        merged_model = apply_merged_vector(
            base_model_state_dict,
            merged_vector,
            alpha,
            args.device,
            method,
            reference_vector.model_name
        )
        return merged_model, None

    # Step 2: optimize alpha on the main process and share the resulting model.
    # Fixed filename so every rank derives the same path.
    temp_model_path = os.path.join(
        _experiment_dir(args), f"optimized_model_{method}_temp.pt")

    if is_main:
        print("\nStep 2: Optimizing alpha parameter...")
        best_alpha, best_model = optimize_merge_alpha(
            base_model_state_dict,
            merged_vector,
            current_state_dict,
            tasks_to_merge,
            task_vectors,
            cfg,
            args,
            early_stopping=early_stopping,
            patience=patience,
            all_seen_tasks=all_seen_tasks,
            method=method
        )
        print(f"Best alpha value found for {method} method: {best_alpha}")

        if distributed:
            os.makedirs(os.path.dirname(temp_model_path), exist_ok=True)
            torch.save(best_model.state_dict(), temp_model_path)
            print(f"Optimized model saved to: {temp_model_path}")

    if distributed:
        # Wait until the main process has written the checkpoint.
        torch.distributed.barrier()

        if not is_main:
            print(f"Process {args.rank}: Loading optimized model from {temp_model_path}")
            best_model = ImageEncoder(reference_vector.model_name)
            best_model.load_state_dict(torch.load(
                temp_model_path, map_location=args.device))

        # Wait until every rank has loaded it before deleting the file.
        torch.distributed.barrier()

        if is_main and os.path.exists(temp_model_path):
            try:
                os.remove(temp_model_path)
            except Exception as e:
                print(f"Warning: Unable to delete temporary file {temp_model_path}: {e}")

    return best_model, None