from src.eval.evaluation import evaluate_model_on_tasks
from src.models.model_utils import apply_merged_vector
from src.merging.saim import SAIM

def print_optimization_results(param_name, param_accuracies, best_param):
    """Print a human-readable summary of a coefficient sweep.

    Args:
        param_name: Name of the swept coefficient; used in every printed line.
        param_accuracies: Sequence of ``(value, average_accuracy, per_task)``
            tuples, where ``per_task`` maps task name to accuracy. Accuracies
            are stored as fractions and printed as percentages (``x * 100``).
            It is iterated twice, so it should be a re-iterable sequence
            (e.g. a list), not a one-shot iterator.
        best_param: Value to mark with ``*``; the per-task breakdown of the
            first entry whose value equals it is printed (if non-empty).
    """
    print(f"\n{param_name} coefficient optimization results summary:")
    for value, avg_acc, _task_accs in param_accuracies:
        marker = " "
        if value == best_param:
            marker = "*"
        print(f"{marker} {param_name} = {value}: Average accuracy = {avg_acc*100:.2f}%")

    # Per-task breakdown for the first entry matching best_param, if any.
    matching = (task_accs for value, _, task_accs in param_accuracies if value == best_param)
    best_breakdown = next(matching, None)
    if best_breakdown:
        print(f"\nPerformance of best {param_name} on each task:")
        for task_name, task_acc in best_breakdown.items():
            print(f"  - {task_name}: {task_acc*100:.2f}%")

    print(f"\n{param_name} optimization completed, best value: {best_param}")

def optimize_parameter(param_name, param_values, merged_vector_fn, apply_model_fn,
                      eval_tasks, args, early_stopping=True, patience=3):
    """
    General parameter optimization framework (grid search with optional early stopping).

    For each candidate value, optionally builds a merged vector, builds a model
    from it, moves the model to ``args.device``, evaluates it on ``eval_tasks``
    and keeps the model with the highest average accuracy. Ties keep the
    earlier value (strict ``>`` comparison).

    Args:
        param_name: Parameter name ('alpha' or 'beta'); used only in log output.
        param_values: Iterable of candidate parameter values, evaluated in order.
        merged_vector_fn: Function taking a parameter value and returning a merged
            vector, or None (the apply function then receives None).
        apply_model_fn: Function ``(param_value, merged_vector) -> model`` that
            creates the candidate model.
        eval_tasks: List of tasks to evaluate.
        args: Arguments containing necessary evaluation parameters like ``device``.
        early_stopping: Whether to enable early stopping.
        patience: Stop after this many consecutive non-improving candidates.

    Returns:
        tuple: (best parameter value, best model). If ``param_values`` is empty,
        returns ``(1.0, None)``.
    """
    print(f"Starting optimization of {param_name} coefficient...")
    device = args.device

    print(f"Will evaluate performance on {len(eval_tasks)} tasks: {eval_tasks}")

    best_accuracy = -1.0
    best_param = 1.0  # Returned unchanged if param_values is empty
    best_model = None
    not_improved_count = 0  # Consecutive non-improvements (early stopping)
    param_accuracies = []  # (value, average accuracy, per-task accuracies)

    for param_value in param_values:
        print(f"Evaluating {param_name} value = {param_value}:")

        # 1. Generate merged vector (if this sweep builds one per value)
        merged_vector = merged_vector_fn(param_value) if merged_vector_fn else None

        # 2. Apply vector to create model
        test_model = apply_model_fn(param_value, merged_vector).to(device)
        # The local reference to the merged vector is no longer needed.
        del merged_vector

        # 3. Evaluate model performance
        task_accuracies, avg_accuracy = evaluate_model_on_tasks(test_model, eval_tasks, args)
        param_accuracies.append((param_value, avg_accuracy, task_accuracies))

        # 4. Update best results
        improved = avg_accuracy > best_accuracy
        if improved:
            best_accuracy = avg_accuracy
            best_param = param_value
            best_model = test_model
            not_improved_count = 0  # Reset early stopping counter
            print(f"  * Found new best {param_name}: {best_param}, average accuracy: {best_accuracy*100:.2f}%")
        else:
            not_improved_count += 1
            print(f"  - Accuracy not improved, consecutive non-improvements: {not_improved_count}/{patience}")

        # Drop this iteration's reference to the candidate. A non-best model can
        # then be freed before the next candidate is built, so only the best
        # model and the one under construction stay alive (assuming
        # apply_model_fn keeps no reference of its own).
        del test_model

        # Early stopping only triggers right after a non-improvement.
        if not improved and early_stopping and not_improved_count >= patience:
            print(f"\nEarly stopping triggered! No performance improvement in {patience} consecutive iterations. Stopping at {param_name} = {param_value}")
            break

    # Print results summary
    print_optimization_results(param_name, param_accuracies, best_param)

    return best_param, best_model
   
def optimize_merge_alpha(base_model_state_dict, merged_vector, current_state_dict, tasks_to_merge, task_vectors,
                         cfg, args, early_stopping=True, patience=3, all_seen_tasks=None, method="SAIM"):
    """Search for the best scaling coefficient ``alpha`` for a fixed merged vector.

    The merged vector is computed by the caller and reused for every candidate;
    each alpha is passed to ``apply_merged_vector`` together with it, and the
    resulting models are compared via ``optimize_parameter``.

    Args:
        base_model_state_dict: Base model weights passed to ``apply_merged_vector``.
        merged_vector: Precomputed merged vector applied for every alpha.
        current_state_dict: Not used by this function; kept for interface compatibility.
        tasks_to_merge: Tasks being merged. The first task's entry in
            ``task_vectors`` supplies ``model_name``. This list is also the
            evaluation set when ``all_seen_tasks`` is falsy.
        task_vectors: Indexable by task (presumably a dict task name -> task
            vector); entries must expose ``.model_name``.
        cfg: Not used by this function; kept for interface compatibility.
        args: Must provide ``device`` plus whatever evaluation requires.
        early_stopping: Forwarded to ``optimize_parameter``.
        patience: Forwarded to ``optimize_parameter``.
        all_seen_tasks: If truthy, evaluate on these tasks instead of ``tasks_to_merge``.
        method: Merging method name. ``"SAIM"`` and ``"iso_c"`` search alpha in
            [0.5, 3.0]; any other method searches [0.1, 3.0]. Both use step 0.1.

    Returns:
        tuple: (best alpha, best model) as returned by ``optimize_parameter``.
    """
    # Candidate alphas, step 0.1 (rounded to avoid float noise such as 0.30000000000000004)
    if method in ("SAIM", "iso_c"):
        alpha_values = [round(0.1 * i, 1) for i in range(5, 31)]  # 0.5 to 3.0
    else:
        alpha_values = [round(0.1 * i, 1) for i in range(1, 31)]  # 0.1 to 3.0

    # Loop-invariant: look the model name up once rather than per candidate.
    model_name = task_vectors[tasks_to_merge[0]].model_name

    def apply_alpha_model(alpha, _unused_merged_vec):
        # optimize_parameter passes None here because no merged_vector_fn is
        # given; the fixed, caller-supplied merged_vector is used for every alpha.
        return apply_merged_vector(
            base_model_state_dict,
            merged_vector,
            alpha,
            args.device,
            method,
            model_name
        )

    # Determine evaluation task list
    eval_tasks = all_seen_tasks if all_seen_tasks else tasks_to_merge

    return optimize_parameter(
        "alpha",
        alpha_values,
        None,  # No merged_vector_fn: merged_vector is already computed
        apply_alpha_model,
        eval_tasks,
        args,
        early_stopping,
        patience
    )

def optimize_merge_beta(base_model_state_dict, current_task_vectors, current_state_dict, tasks_to_merge,
                      task_vectors, cfg, args, task_count=1, early_stopping=True, patience=3, all_seen_tasks=None,
                      method="SAIM"):
    """Search for the best ``beta`` coefficient of the merging method.

    For each beta in [0.1, 2.0] (step 0.1), a merged vector is recomputed with
    ``SAIM`` and applied with a fixed alpha of 1.0. The resulting models are
    compared via ``optimize_parameter``.

    Args:
        base_model_state_dict: Base model weights, passed to ``SAIM`` and ``apply_merged_vector``.
        current_task_vectors: Passed to ``SAIM``.
        current_state_dict: Passed to ``SAIM``.
        tasks_to_merge: Tasks being merged. The first task's entry in
            ``task_vectors`` supplies ``model_name``. This list is also the
            evaluation set when ``all_seen_tasks`` is falsy.
        task_vectors: Indexable by task (presumably a dict task name -> task
            vector); entries must expose ``.model_name``.
        cfg: Not used by this function; kept for interface compatibility.
        args: Must provide ``device`` plus whatever evaluation requires.
        task_count: Passed to ``SAIM``.
        early_stopping: Forwarded to ``optimize_parameter``.
        patience: Forwarded to ``optimize_parameter``.
        all_seen_tasks: If truthy, evaluate on these tasks instead of ``tasks_to_merge``.
        method: Merging method; only ``"SAIM"`` is supported.

    Returns:
        tuple: (best beta, best model) as returned by ``optimize_parameter``.

    Raises:
        ValueError: If ``method`` is not supported. This is checked up front,
            before any optimization output or model construction.
    """
    if method != "SAIM":
        raise ValueError(f"Unsupported method: {method}")

    # Candidate betas: 0.1 to 2.0, step 0.1 (rounded to avoid float noise)
    beta_values = [round(0.1 * i, 1) for i in range(1, 21)]

    # Loop-invariant: look the model name up once rather than per candidate.
    model_name = task_vectors[tasks_to_merge[0]].model_name

    def create_beta_vector(beta):
        # Recompute the merged vector for this beta.
        return SAIM(
            current_task_vectors,
            current_state_dict,
            base_model_state_dict,
            task_count,
            beta
        )

    def apply_beta_model(beta, merged_vec):
        # beta is already baked into merged_vec; apply it with alpha fixed at 1.0.
        return apply_merged_vector(
            base_model_state_dict,
            merged_vec,
            1.0,
            args.device,
            method,
            model_name
        )

    # Determine evaluation task list
    eval_tasks = all_seen_tasks if all_seen_tasks else tasks_to_merge

    return optimize_parameter(
        "beta",
        beta_values,
        create_beta_vector,
        apply_beta_model,
        eval_tasks,
        args,
        early_stopping,
        patience
    )