import os
import torch
import json
import datetime
import gc
import numpy as np
from collections import defaultdict
from omegaconf import DictConfig, OmegaConf
import hydra

from src.models import ImageEncoder
from src.models.task_vectors import NonLinearTaskVector
from src.utils import parse_arguments
from src.eval.eval import eval_single_dataset
from src.finetune.sabcd_finetune import sabcd_finetune
from src.finetune.continual_finetune import continual_finetune
from src.data.data_utils import DATASETS
from src.config.config_utils import convert_omegaconf_to_native, create_experiment_dir, load_method_config
from src.models.model_utils import load_pretrained_model
from src.merging.merging import merge_tasks_incremental
from src.merging.task_arithmetic import save_cumulative_vectors, load_cumulative_vectors
from src.merging.ties_merge import save_ties_cumulative_vectors, load_ties_cumulative_vectors
from src.merging.swa_merge import save_swa_weights, load_swa_weights
from analyze.sar_analysis import calculate_sar_metrics


@hydra.main(config_path="config", config_name="config", version_base="1.3")
def run_continual_learning(cfg: DictConfig) -> None:
    """
    Implement a true continual learning process and compare different merging methods:
    1. Each method fine-tunes the new task using its own current model
    2. Merge the fine-tuned task vectors into their respective current models using different methods
    3. Evaluate the performance of different methods on all seen tasks
    4. Calculate SAR metrics for each stage to assess subspace alignment
    5. Continue processing the next task
    """
    args = parse_arguments()
    args.model_location = cfg.model_location
    args.model = cfg.model
    args.data_location = os.path.expanduser(cfg.data_location)
    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    args.save_dir = os.path.join(args.model_location, args.model)

    # Merge method settings
    merge_methods = cfg.merge_methods if hasattr(
        cfg, 'merge_methods') else ["SAIM"]

    # Global default configuration
    global_config = {
        "optimize_alpha": cfg.optimize_alpha if hasattr(cfg, 'optimize_alpha') else False,
        "optimize_beta": cfg.optimize_beta if hasattr(cfg, 'optimize_beta') else False,
        "early_stopping": cfg.early_stopping if hasattr(cfg, 'early_stopping') else True,
        "early_stopping_patience": cfg.early_stopping_patience if hasattr(cfg, 'early_stopping_patience') else 3,
        "use_sabcd": cfg.use_sabcd if hasattr(cfg, 'use_sabcd') else True,
        "finetune_from_pretrained": cfg.finetune_from_pretrained if hasattr(cfg, 'finetune_from_pretrained') else True,
        "vector_from_pretrained": cfg.vector_from_pretrained if hasattr(cfg, 'vector_from_pretrained') else False,
        "calculate_sar": cfg.calculate_sar if hasattr(cfg, 'calculate_sar') else True,  # Whether to calculate SAR metrics
        "sar_epsilon": cfg.sar_epsilon if hasattr(cfg, 'sar_epsilon') else 0.05  # Epsilon parameter for SAR calculation
    }

    # Experiment continuation settings
    continue_experiment = cfg.continue_experiment if hasattr(
        cfg, 'continue_experiment') else False
    prev_experiment_dir = cfg.prev_experiment_dir if hasattr(
        cfg, 'prev_experiment_dir') else None
    start_task_idx = cfg.start_task_idx if hasattr(
        cfg, 'start_task_idx') else 0

    # Load specific configurations for each method
    method_configs = {}

    for method in merge_methods:
        try:
            # Try to load method-specific configuration file
            method_config_path = os.path.join(
                hydra.utils.get_original_cwd(),
                "config", "method", f"{method}.yaml"
            )

            if os.path.exists(method_config_path):
                print(f"Loading specific configuration for {method} method: {method_config_path}")
                method_cfg = OmegaConf.load(method_config_path)

                # Use new function to load all configuration items
                method_configs[method] = load_method_config(
                    method_cfg, global_config)

                # Print main configuration information
                print(f"{method} method configuration: Alpha optimization={method_configs[method].get('optimize_alpha', False)}, " +
                      f"Beta optimization={method_configs[method].get('optimize_beta', False)}, " +
                      f"Use SABCD={method_configs[method].get('use_sabcd', True)}, " +
                      f"Fine-tune from pretrained={method_configs[method].get('finetune_from_pretrained', True)}, " +
                      f"Vector from pretrained={method_configs[method].get('vector_from_pretrained', False)}")
            else:
                print(f"No specific configuration file found for {method} method, using global configuration")
                method_configs[method] = global_config.copy()
        except Exception as e:
            print(f"Error loading {method} method configuration: {e}, using global default configuration")
            method_configs[method] = global_config.copy()

    if continue_experiment and prev_experiment_dir:
        print(f"\n===== Continuing experiment: {prev_experiment_dir} =====")
        print(
            f"Continuing experiment from task index {start_task_idx} (corresponding task: {DATASETS[start_task_idx]})")

        # Use previous result directory
        result_dir = prev_experiment_dir

        # Load previous experiment configuration
        try:
            with open(os.path.join(result_dir, "experiment_config.json"), "r") as f:
                prev_config = json.load(f)
                # Can recover some settings from previous configuration
                if "merge_methods" in prev_config:
                    merge_methods = prev_config["merge_methods"]

                # If there are method-specific configurations, can also recover from here
                if "method_configs" in prev_config:
                    for method, config in prev_config["method_configs"].items():
                        if method in method_configs:
                            method_configs[method].update(config)
        except Exception as e:
            print(f"Error loading previous configuration: {e}")
            prev_config = {"merge_methods": merge_methods}

        # Load previous experiment results
        try:
            results_path = os.path.join(result_dir, "experiment_results.json")
            if os.path.exists(results_path):
                with open(results_path, "r") as f:
                    experiment_results = json.load(f)
                print(f"Previous experiment results loaded")
            else:
                experiment_results = defaultdict(dict)
                print("No previous results file found, creating new results dictionary")
        except Exception as e:
            print(f"Error loading previous results: {e}")
            experiment_results = defaultdict(dict)

        # Load previous method comparison results
        try:
            comparison_path = os.path.join(
                result_dir, "methods_comparison.json")
            if os.path.exists(comparison_path):
                with open(comparison_path, "r") as f:
                    comparison_results = json.load(f)
                print(f"Previous method comparison results loaded")
            else:
                comparison_results = {}
        except Exception as e:
            print(f"Error loading previous method comparison results: {e}")
            comparison_results = {}
            
        # Load previous SAR analysis results
        try:
            sar_path = os.path.join(result_dir, "sar_analysis_results.json")
            if os.path.exists(sar_path):
                with open(sar_path, "r") as f:
                    sar_results = json.load(f)
                print(f"Previous SAR analysis results loaded")
            else:
                sar_results = {}
                print("No previous SAR results file found, creating new results dictionary")
        except Exception as e:
            print(f"Error loading previous SAR results: {e}")
            sar_results = {}

        # Try to load previous cumulative vectors
        try:
            # Only load cumulative vectors for methods currently in use
            for method in merge_methods:
                if method == "task_arithmetic":
                    load_cumulative_vectors(result_dir)
                    print(f"Cumulative vectors for task_arithmetic loaded")
                elif method == "ties_merge":
                    load_ties_cumulative_vectors(result_dir)
                    print(f"Cumulative vectors for ties_merge loaded")
                elif method == "swa":
                    load_swa_weights(result_dir)
                    print(f"Cumulative weights for swa loaded")
        except Exception as e:
            print(f"Error loading cumulative vectors: {e}")

    else:
        print("Creating new experiment")
        for method in merge_methods:
            config = method_configs[method]
            print(f"{method} method: Alpha optimization={'Enabled' if config['optimize_alpha'] else 'Disabled'}, " +
                  f"Beta optimization={'Enabled' if config['optimize_beta'] else 'Disabled'}, " +
                  f"Use SABCD={'Yes' if config['use_sabcd'] else 'No'}, " +
                  f"Fine-tune mode={'Fine-tune from pretrained model' if config['finetune_from_pretrained'] else 'Fine-tune from current model'}, " +
                  f"Vector calculation base={'Pretrained model' if config['vector_from_pretrained'] else 'Current model'}")

        first_method = merge_methods[0]
        result_dir = create_experiment_dir(
            model_name=args.model,
            method_name=first_method,
            use_sabcd=method_configs[first_method].get("use_sabcd", True),
            finetune_from_pretrained=method_configs[first_method].get("finetune_from_pretrained", True)
        )

        # Record experiment configuration
        config_info = {
            "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "datasets": DATASETS,
            "model": args.model,
            "merge_methods": merge_methods,
            "method_configs": {method: config for method, config in method_configs.items()},
            "global_config": global_config
        }
        converted_config_info = convert_omegaconf_to_native(config_info)

        with open(os.path.join(result_dir, "experiment_config.json"), "w") as f:
            json.dump(converted_config_info, f, indent=2)

        # New experiment results dictionary
        experiment_results = defaultdict(dict)
        comparison_results = {}
        sar_results = {}  # SAR analysis results

    # Load pretrained model as starting point
    pretrained_check = load_pretrained_model(args)

    # Initialize independent models and tracking variables for each merging method
    current_models = {}
    merged_tasks_history = {}
    task_counts = {}
    
    # Dictionary for SAR calculation results by stage
    sar_by_stage = {}

    for method in merge_methods:
        if continue_experiment:
            # Try to load previous model
            model_path = os.path.join(
                result_dir, f"model_{method}_after_task_{start_task_idx}.pt")
            if os.path.exists(model_path):
                print(f"Loading previous {method} model: {model_path}")
                model = ImageEncoder(args.model)
                model.load_state_dict(torch.load(
                    model_path, map_location=args.device))
                model = model.to(args.device)
                current_models[method] = model

                # Restore previous task history
                if method in experiment_results and any(key.startswith(f"task_{start_task_idx}_") for key in experiment_results[method]):
                    for task_key in experiment_results[method]:
                        if task_key.startswith(f"task_{start_task_idx}_"):
                            merged_tasks_history[method] = experiment_results[method][task_key]["merged_tasks"]
                            print(
                                f"Restoring task history for {method} method: {merged_tasks_history[method]}")
                            break
                else:
                    # If no specific record found, assume all tasks before start_task_idx have been processed
                    merged_tasks_history[method] = [DATASETS[i]
                                                    for i in range(start_task_idx)]
                    print(
                        f"No specific record found, assuming processed tasks for {method} method: {merged_tasks_history[method]}")

                # Set task count
                task_counts[method] = len(merged_tasks_history[method])
            else:
                print(f"Previous {method} model not found, creating new model")
                model = ImageEncoder(args.model)
                model.load_state_dict(pretrained_check)
                model = model.to(args.device)
                current_models[method] = model
                merged_tasks_history[method] = []
                task_counts[method] = 0
        else:
            # Create brand new model instances
            model = ImageEncoder(args.model)
            # Load the same pretrained weights
            model.load_state_dict(pretrained_check)
            # Move to correct device
            model = model.to(args.device)
            # Store in dictionary
            current_models[method] = model
            merged_tasks_history[method] = []
            task_counts[method] = 0

    print("Model loading completed, ready to start continual learning")

    # Create fine-tuned model directory
    finetune_models_dir = os.path.join(result_dir, "finetunedModels")
    os.makedirs(finetune_models_dir, exist_ok=True)

    # Initialize task vector dictionary and fine-tuned model dictionary
    task_vectors_dict = {}
    finetuned_models_dict = {}  # Save fine-tuned models for SAR calculation
    for method in merge_methods:
        task_vectors_dict[method] = {}
        finetuned_models_dict[method] = {}

    # Process each task sequentially
    for task_idx, dataset_name in enumerate(DATASETS[start_task_idx:], start=start_task_idx):
        print("\n" + "="*50)
        print(f"Task {task_idx+1}/{len(DATASETS)}: {dataset_name}")
        print("="*50)

        # Clean memory before task starts
        gc.collect()
        torch.cuda.empty_cache()

        # Build the "<dataset>Val" split name and set per-task training hyperparameters
        dataset_val = dataset_name + "Val"

        args.batch_size = 32 if args.model == "ViT-L-14" else 128
        args.num_grad_accumulation = 4 if args.model == "ViT-L-14" else 1
        args.lr = 1e-5

        # Process independent fine-tuning and task vector calculation for each method
        for method in merge_methods:
            # Get specific configuration for this method
            method_config = method_configs[method]

            # Decide fine-tune mode based on configuration
            if method_config["finetune_from_pretrained"]:
                print(f"\n{method} method - Step 1: Fine-tune task {dataset_name} based on original pretrained model")

                # Fine-tune based on pretrained model
                finetune_model_path = os.path.join(
                    finetune_models_dir, f"finetuned_pretrained_{dataset_name}.pt")
                pretrained_model_path = os.path.join(
                    result_dir, "pretrained_model.pt")

                # Save pretrained model first
                if not os.path.exists(pretrained_model_path):
                    torch.save(pretrained_check, pretrained_model_path)

                # External fine-tuned model loading is currently disabled (the code below is commented out), so this flag stays False
                external_model_loaded = False
                
                # ckpt_model_path = os.path.join(args.model_location, args.model, dataset_name+'Val', "nonlinear_finetuned.pt")
                # if method_config["finetune_from_pretrained"] and not method_config["use_sabcd"]:
                #     if os.path.exists(ckpt_model_path):
                #         print(f"Prioritize loading external fine-tuned model: {ckpt_model_path}")
                #         external_state_dict = torch.load(ckpt_model_path, map_location=args.device)
                #         torch.save(external_state_dict, finetune_model_path)
                #         external_model_loaded = True
                #     else:
                #         print(f"External fine-tuned model not found: {ckpt_model_path}")
                
                # if not external_model_loaded:
                #     model_type_dir = args.model  # e.g., ViT-B-32
                #     finetune_type = "SABCD" if method_config["use_sabcd"] else "NOSABCD"
                #     external_model_path = os.path.join("finetunedModels", model_type_dir, finetune_type, f"finetuned_pretrained_{dataset_name}.pt")
                #     if not os.path.exists(external_model_path):
                #         external_model_path = os.path.join("finetunedModels", model_type_dir, finetune_type, f"{dataset_name}.pt")
                #     if os.path.exists(external_model_path):
                #         print(f"Found external pre-fine-tuned model: {external_model_path}, loading directly")
                #         external_state_dict = torch.load(external_model_path, map_location=args.device)
                #         torch.save(external_state_dict, finetune_model_path)
                #         external_model_loaded = True
                #     else:
                #         print(f"External pre-fine-tuned model not found: {external_model_path}")
                        
                # If no external model and local doesn't exist, perform fine-tuning
                if not external_model_loaded and (not os.path.exists(finetune_model_path) or not continue_experiment):
                    print(f"Fine-tuned model not found, performing fine-tuning operation...")
                    if method_config["use_sabcd"]:
                        # Use SA-BCD fine-tuning for subsequent tasks
                        print(f"{method} method - Using SA-BCD fine-tuning method for task")
                        sabcd_finetune(
                            args=args,
                            train_dataset=dataset_val,
                            starting_model_path=pretrained_model_path,
                            output_path=finetune_model_path
                        )
                    else:
                        continual_finetune(
                            args=args,
                            train_dataset=dataset_val,
                            starting_model_path=pretrained_model_path,
                            output_path=finetune_model_path
                        )
                elif os.path.exists(finetune_model_path) and not external_model_loaded:
                    print(f"Existing fine-tuned model found: {finetune_model_path}, skipping fine-tuning step")

                # Load fine-tuned model state
                finetuned_state_dict = torch.load(finetune_model_path)
                finetuned_on_device = {
                    k: v.to(args.device) for k, v in finetuned_state_dict.items()}
                
                # Store fine-tuned model for subsequent SAR analysis
                finetuned_models_dict[method][dataset_name] = finetuned_state_dict

                # Ensure pretrained model state dict is available on device
                pretrained_on_device = {
                    k: v.to(args.device) for k, v in pretrained_check.items()}
                
                current_state_dict = current_models[method].state_dict()
                current_on_device = {k: v.to(args.device) for k, v in current_state_dict.items()}

            else:
                # Fine-tune based on current merged model
                print(f"\n{method} method - Step 1: Fine-tune task {dataset_name} based on current merged model")

                # Create specific fine-tuned model path for current method
                finetune_model_path = os.path.join(
                    finetune_models_dir, f"finetuned_{method}_{dataset_name}.pt")
                temp_current_model_path = os.path.join(
                    result_dir, f"temp_{method}_current_model.pt")
                torch.save(
                    current_models[method].state_dict(), temp_current_model_path)

                # Check if fine-tuned model already exists
                if os.path.exists(finetune_model_path) and continue_experiment:
                    print(f"Existing fine-tuned model found: {finetune_model_path}, skipping fine-tuning step")
                else:
                    # If not exists, perform fine-tuning
                    if method_config["use_sabcd"]:
                        # Use SA-BCD fine-tuning for subsequent tasks
                        print(f"{method} method - Using SA-BCD fine-tuning method for task")
                        sabcd_finetune(
                            args=args,
                            train_dataset=dataset_val,
                            starting_model_path=temp_current_model_path,
                            output_path=finetune_model_path
                        )
                    else:
                        continual_finetune(
                            args=args,
                            train_dataset=dataset_val,
                            starting_model_path=temp_current_model_path,
                            output_path=finetune_model_path
                        )

                # Load fine-tuned model state
                finetuned_state_dict = torch.load(finetune_model_path)
                finetuned_on_device = {
                    k: v.to(args.device) for k, v in finetuned_state_dict.items()}
                
                # Store fine-tuned model for subsequent SAR analysis
                finetuned_models_dict[method][dataset_name] = finetuned_state_dict

                # Get current model state dict and move to device
                current_state_dict = current_models[method].state_dict()
                current_on_device = {k: v.to(args.device)
                                    for k, v in current_state_dict.items()}

                # Ensure pretrained model state dict is available on device
                pretrained_on_device = {
                    k: v.to(args.device) for k, v in pretrained_check.items()}
                
                # Store fine-tuned model for subsequent SAR analysis
                finetuned_models_dict[method][dataset_name] = finetuned_state_dict

                # Get current model state dict and move to device
                current_state_dict = current_models[method].state_dict()
                current_on_device = {k: v.to(args.device)
                                     for k, v in current_state_dict.items()}

                # Ensure pretrained model state dict is available on device
                pretrained_on_device = {
                    k: v.to(args.device) for k, v in pretrained_check.items()}

            # Independently of fine-tune mode, decide task vector calculation base based on vector_from_pretrained configuration
            if method_config["vector_from_pretrained"]:
                # Use pretrained model as base to calculate task vector
                print(f"\n{method} method - Step 2: Calculate task vector (from pretrained model to fine-tuned model)")
                task_vector = NonLinearTaskVector(
                    args.model, pretrained_on_device, finetuned_on_device)
            else:
                # Use current merged model as base to calculate task vector
                print(f"\n{method} method - Step 2: Calculate task vector (from current merged model to fine-tuned model)")
                task_vector = NonLinearTaskVector(
                    args.model, current_on_device, finetuned_on_device)

            # Store directly in memory
            task_vectors_dict[method][dataset_name] = task_vector

            # Release fine-tuned model state dict
            del finetuned_state_dict
            if 'finetuned_on_device' in locals():
                del finetuned_on_device
            if 'current_on_device' in locals():
                del current_on_device
            gc.collect()
            torch.cuda.empty_cache()

            # Process merging task vectors into model
            print(f"\n{method} method - Step 3: Merge task vectors into model")

            # Update task count
            task_counts[method] += 1

            # Add to merged tasks list for this method
            if dataset_name not in merged_tasks_history[method]:
                merged_tasks_history[method].append(dataset_name)

            # Create complete evaluation task list including current task
            all_tasks_for_eval = merged_tasks_history[method].copy()

            # Use task vectors in memory
            current_task_vectors = task_vectors_dict[method]

            args.result_dir = result_dir
            cfg.method_config = OmegaConf.create(method_config)
            # Merge tasks - using method-specific configuration
            current_models[method], _ = merge_tasks_incremental(
                [dataset_name],  # Current task
                current_task_vectors,  # Dictionary containing task vectors
                current_models[method],
                pretrained_check,
                cfg,
                args,
                method=method,
                task_count=task_counts[method],
                early_stopping=global_config["early_stopping"],
                patience=global_config["early_stopping_patience"],
                all_seen_tasks=all_tasks_for_eval,  # Use complete list including current task
            )

            # Save current model
            model_path = os.path.join(
                result_dir, f"model_{method}_after_task_{task_idx+1}.pt")
            torch.save(current_models[method].state_dict(), model_path)
            
            # Step 3.1: Calculate SAR metrics after merging - using CPU for calculation
            if global_config["calculate_sar"]:
                print(f"\n{method} method - Step 3.1: Calculate SAR metrics to assess subspace alignment (using CPU calculation)")
                
                # Get state dict from current model and transfer to CPU
                merged_state_dict = {k: v.cpu() for k, v in current_models[method].state_dict().items()}
                
                # Prepare single-task model state dicts for SAR calculation, using saved fine-tuned models
                task_models_for_sar = {}
                for task_name in merged_tasks_history[method]:
                    # 1. Check if fine-tuned model is already in memory
                    if task_name in finetuned_models_dict[method]:
                        # Transfer to CPU
                        task_models_for_sar[task_name] = {k: v.cpu() for k, v in finetuned_models_dict[method][task_name].items()}
                    else:
                        # 2. If not in memory, try to load from disk (directly to CPU)
                        try:
                            # Determine possible file names based on fine-tune mode
                            if method_configs[method]["finetune_from_pretrained"]:
                                # Case of fine-tuning from pretrained model
                                model_path = os.path.join(finetune_models_dir, f"finetuned_pretrained_{task_name}.pt")
                            else:
                                # Case of fine-tuning from current model
                                model_path = os.path.join(finetune_models_dir, f"finetuned_{method}_{task_name}.pt")
                            
                            # Check if file exists, load directly to CPU
                            if os.path.exists(model_path):
                                print(f"Loading fine-tuned model for task {task_name} to CPU from disk: {model_path}")
                                task_model_state = torch.load(model_path, map_location='cpu')
                                task_models_for_sar[task_name] = task_model_state
                                # Also update memory dictionary
                                finetuned_models_dict[method][task_name] = task_model_state
                            else:
                                # Try alternative naming format, also load to CPU
                                alt_model_path = os.path.join(finetune_models_dir, 
                                                        f"finetuned_{'pretrained' if not method_configs[method]['finetune_from_pretrained'] else method}_{task_name}.pt")
                                if os.path.exists(alt_model_path):
                                    print(f"Loading fine-tuned model for task {task_name} to CPU from disk (alternative path): {alt_model_path}")
                                    task_model_state = torch.load(alt_model_path, map_location='cpu')
                                    task_models_for_sar[task_name] = task_model_state
                                    finetuned_models_dict[method][task_name] = task_model_state
                                else:
                                    print(f"Warning: Fine-tuned model file for task {task_name} not found, skipping SAR calculation for it")
                        except Exception as e:
                            print(f"Error loading fine-tuned model for task {task_name}: {e}")
                
                # Transfer pretrained model to CPU as well
                pretrained_cpu = {k: v.cpu() for k, v in pretrained_check.items()}
                
                # Calculate SAR metrics - on CPU
                print(f"Calculating SAR metrics for {method} method after task {dataset_name} on CPU...")
                try:
                    sar_metrics = calculate_sar_metrics(
                        pretrained_cpu, 
                        task_models_for_sar,
                        merged_state_dict,
                        device='cpu'  # Explicitly specify using CPU
                    )
                    
                    # Record SAR analysis results
                    stage_key = f"{method}_task_{task_idx+1}_{dataset_name}"
                    sar_by_stage[stage_key] = {
                        "task_avg_sar": sar_metrics["task_avg_sar"],
                        "overall_avg_sar": sar_metrics["overall_avg_sar"],
                        "valid_layer_count": sar_metrics["valid_layer_count"],
                        "skipped_layers": len(sar_metrics["skipped_layers"]),
                        "merged_tasks": merged_tasks_history[method].copy()
                    }
                    
                    # Print SAR results
                    print(f"\nSAR metrics for {method} method after merging task {dataset_name}:")
                    print("Average SAR values for each task:")
                    for task_name, avg_sar in sar_metrics["task_avg_sar"].items():
                        if isinstance(avg_sar, float) and not np.isnan(avg_sar):
                            print(f"  {task_name}: {avg_sar:.4f}")
                        else:
                            print(f"  {task_name}: NaN")
                    
                    overall_avg_sar = sar_metrics["overall_avg_sar"]
                    if isinstance(overall_avg_sar, float) and not np.isnan(overall_avg_sar):
                        print(f"\nOverall average SAR value: {overall_avg_sar:.4f}")
                    else:
                        print("\nOverall average SAR value: NaN")
                        
                    print(f"Valid layer count: {sar_metrics['valid_layer_count']} (skipped {len(sar_metrics['skipped_layers'])} layers)")
                    
                    # Perform memory cleanup after SAR calculation
                    task_models_for_sar.clear()
                    del pretrained_cpu, merged_state_dict
                    gc.collect()
                    
                    # Save detailed SAR analysis results
                    sar_detail_path = os.path.join(
                        result_dir, 
                        f"sar_metrics_{method}_task_{task_idx+1}_{dataset_name}.json"
                    )
                    with open(sar_detail_path, 'w') as f:
                        # Convert nan and inf to strings
                        def convert_nan_inf(obj):
                            if isinstance(obj, float):
                                if np.isnan(obj):
                                    return "NaN"
                                elif np.isinf(obj):
                                    return "Inf" if obj > 0 else "-Inf"
                                else:
                                    return obj
                            elif isinstance(obj, dict):
                                return {k: convert_nan_inf(v) for k, v in obj.items()}
                            elif isinstance(obj, list):
                                return [convert_nan_inf(item) for item in obj]
                            else:
                                return obj
                        
                        json.dump(convert_nan_inf(sar_metrics), f, indent=2)
                        
                    print(f"Detailed SAR metrics saved to: {sar_detail_path}")
                except Exception as e:
                    print(f"Error calculating SAR metrics: {e}")
                    sar_by_stage[f"{method}_task_{task_idx+1}_{dataset_name}"] = {
                        "error": str(e),
                        "merged_tasks": merged_tasks_history[method].copy()
                    }

        # Evaluate performance of each method's current model on all seen tasks
        print(f"\nStep 4: Evaluate performance of all methods' models on seen tasks")

        # Store results for current task
        current_task_results = {}

        for method in merge_methods:
            print(f"\n----- Evaluating model performance for {method} method -----")
            task_accuracies = {}

            # Evaluate seen tasks
            for seen_task in merged_tasks_history[method]:
                seen_task_val = seen_task + "Val"
                result = eval_single_dataset(
                    current_models[method], seen_task_val, args)
                acc = result['top1']
                task_accuracies[seen_task] = acc
                print(f"  - Task {seen_task}: {acc*100:.2f}%")
                torch.cuda.empty_cache()  # Clean cache after each evaluation

            # Calculate average accuracy
            avg_accuracy = sum(task_accuracies.values()) / len(task_accuracies)
            print(f"\nAverage accuracy for {method} method on all seen tasks: {avg_accuracy*100:.2f}%")

            # Store evaluation results
            experiment_results[method][f"task_{task_idx+1}_{dataset_name}"] = {
                "task_accuracies": task_accuracies,
                "average_accuracy": avg_accuracy,
                "merged_tasks": merged_tasks_history[method].copy()
            }

            # Current stage results
            current_task_results[method] = {
                "average_accuracy": avg_accuracy
            }

            #  Save cumulative vectors
            try:
                for method in merge_methods:
                    if method == "task_arithmetic":
                        save_cumulative_vectors(result_dir)
                        print(f"Cumulative vectors for task_arithmetic saved")
                    elif method == "ties_merge":
                        save_ties_cumulative_vectors(result_dir)
                        print(f"Cumulative vectors for ties_merge saved")
                    elif method == "swa":
                        save_swa_weights(result_dir)
                        print(f"Cumulative weights for swa saved")
                print("Required cumulative vectors saved")
            except Exception as e:
                print(f"Error saving cumulative vectors: {e}")

        # Compare performance differences between methods at current stage
        best_method = max(current_task_results.items(),
                          key=lambda x: x[1]["average_accuracy"])[0]
        print(f"\nBest performing method after task {task_idx+1} ({dataset_name}): {best_method}, "
              f"Accuracy: {current_task_results[best_method]['average_accuracy']*100:.2f}%")

        # Save experiment results
        results_path = os.path.join(result_dir, "experiment_results.json")
        converted_results = convert_omegaconf_to_native(experiment_results)
        with open(results_path, "w") as f:
            json.dump(converted_results, f, indent=2, default=lambda o: float(
                o) if isinstance(o, (np.float16, np.float32, np.float64)) else o)

        # Save method comparison results
        comparison_results[f"task_{task_idx+1}_{dataset_name}"] = {
            "best_method": best_method,
            "accuracies": {method: current_task_results[method]["average_accuracy"] for method in merge_methods}
        }

        converted_comparison = convert_omegaconf_to_native(comparison_results)
        comparison_path = os.path.join(result_dir, "methods_comparison.json")
        with open(comparison_path, "w") as f:
            json.dump(converted_comparison, f, indent=2, default=lambda o: float(
                o) if isinstance(o, (np.float16, np.float32, np.float64)) else o)
            
        # Save SAR analysis results
        if global_config["calculate_sar"]:
            sar_results.update(sar_by_stage)
            sar_path = os.path.join(result_dir, "sar_analysis_results.json")
            with open(sar_path, "w") as f:
                json.dump(sar_results, f, indent=2, default=lambda o: float(
                    o) if isinstance(o, (np.float16, np.float32, np.float64)) else o)

        # Clean memory after task ends
        gc.collect()
        torch.cuda.empty_cache()

    # Save final comparison results
    final_comparison = {}
    for method in merge_methods:
        final_task = f"task_{len(DATASETS)}_{DATASETS[-1]}"
        if method in experiment_results and final_task in experiment_results[method]:
            final_comparison[method] = experiment_results[method][final_task]["average_accuracy"]

    if final_comparison:
        best_final_method = max(final_comparison.items(), key=lambda x: x[1])[0]
        print(
            f"\nBest performing method after all tasks: {best_final_method}, Final accuracy: {final_comparison[best_final_method]*100:.2f}%")

        converted_final = convert_omegaconf_to_native({
            "best_method": best_final_method,
            "final_accuracies": final_comparison
        })
        final_comparison_path = os.path.join(result_dir, "final_comparison.json")
        with open(final_comparison_path, "w") as f:
            json.dump(converted_final, f, indent=2)

    # If SAR calculation is enabled, generate SAR metrics trend summary over tasks
    if global_config["calculate_sar"] and sar_by_stage:
        sar_trend = {}
        for method in merge_methods:
            sar_trend[method] = []
            for task_idx, dataset_name in enumerate(DATASETS[:len(merged_tasks_history[method])], start=1):
                stage_key = f"{method}_task_{task_idx}_{dataset_name}"
                if stage_key in sar_by_stage:
                    sar_trend[method].append({
                        "task_idx": task_idx,
                        "task_name": dataset_name,
                        "overall_avg_sar": sar_by_stage[stage_key].get("overall_avg_sar", "NaN"),
                        "task_specific_sar": sar_by_stage[stage_key].get("task_avg_sar", {}).get(dataset_name, "NaN")
                    })
        
        # Save SAR trend
        sar_trend_path = os.path.join(result_dir, "sar_trend.json")
        with open(sar_trend_path, "w") as f:
            json.dump(sar_trend, f, indent=2, default=lambda o: "NaN" if isinstance(o, float) and np.isnan(o) else float(
                o) if isinstance(o, (np.float16, np.float32, np.float64)) else o)
        print(f"SAR trend data saved to: {sar_trend_path}")

    print(f"\nContinual learning experiment completed! All results saved to {result_dir} folder")


if __name__ == "__main__":
    # Script entry point: run the experiment, report any failure, and then
    # issue the OS "shutdown" command.
    # NOTE(review): the shutdown looks intended for unattended runs on a
    # dedicated machine -- confirm before running this on a shared host.
    import time
    import traceback

    try:
        run_continual_learning()
    except Exception as e:
        # Only Exception subclasses are handled here; KeyboardInterrupt and
        # SystemExit propagate and therefore skip the shutdown below.
        banner = "=" * 80
        print("\n" + banner)
        print(f"Program execution error: {e}")
        print("Error details:")
        traceback.print_exc()
        print(banner + "\n")
        # Pause before shutting down (presumably so the traceback can be read).
        time.sleep(5)

    # Issued exactly once for both the success path and the handled-failure
    # path; previously the failure path ran this command twice.
    os.system("shutdown")