from src.eval.eval import eval_single_dataset
from src.data.data_utils import DATASETS

def evaluate_task(dataset_name, merged_model, task_vectors, pretrained_model, args, merged_tasks=None):
    """Evaluate the merged model on the validation split of a single dataset.

    Args:
        dataset_name: Base dataset name; the split evaluated is
            ``dataset_name + "Val"``.
        merged_model: Model passed to ``eval_single_dataset``.
        task_vectors: Currently unused; kept for interface compatibility.
        pretrained_model: Currently unused; kept for interface compatibility.
        args: Evaluation arguments forwarded to ``eval_single_dataset``.
        merged_tasks: Optional collection of dataset names that have been
            merged into ``merged_model``. Only used to set ``is_merged``;
            defaults to an empty collection.

    Returns:
        dict with keys:
            ``'accuracy'``: the ``'top1'`` value returned by
                ``eval_single_dataset`` (printed as a percentage, so
                presumably a fraction in [0, 1] -- confirm against eval code).
            ``'is_merged'``: whether ``dataset_name`` is in ``merged_tasks``.
    """
    # Avoid a shared mutable default; None means "no merged tasks".
    if merged_tasks is None:
        merged_tasks = []
    is_merged = dataset_name in merged_tasks
    dataset_val = dataset_name + "Val"

    # Evaluate using merged model
    print(f"  - Evaluating using merged model (merged tasks: {is_merged})")
    merged_accuracy = eval_single_dataset(
        merged_model, dataset_val, args)['top1']
    print(f"  - Merged model accuracy: {merged_accuracy*100:.2f}%")

    return {
        'accuracy': merged_accuracy,
        'is_merged': is_merged
    }

def evaluate_tasks(task_range, merged_model, task_vectors, pretrained_model, args, merged_tasks=None):
    """Evaluate the merged model on each dataset in an inclusive index range.

    Every task is evaluated the same way, with ``merged_model`` via
    ``evaluate_task``. ``merged_tasks`` only affects the ``'is_merged'`` flag
    in each result.

    Args:
        task_range: Pair ``(start, end)`` of indices into ``DATASETS``; both
            endpoints are included.
        merged_model: Model to evaluate.
        task_vectors: Forwarded to ``evaluate_task`` (unused there).
        pretrained_model: Forwarded to ``evaluate_task`` (unused there).
        args: Evaluation arguments forwarded to ``evaluate_task``.
        merged_tasks: Optional collection of dataset names already merged into
            ``merged_model``; defaults to an empty collection.

    Returns:
        dict mapping dataset name to the result dict from ``evaluate_task``.
    """
    # Avoid a shared mutable default; normalise before forwarding so the
    # callee always receives a real collection.
    if merged_tasks is None:
        merged_tasks = []

    results = {}
    # +1 because the end index of task_range is inclusive.
    for i in range(task_range[0], task_range[1] + 1):
        dataset_name = DATASETS[i]

        print(f"Evaluating task {i}: {dataset_name}")
        result = evaluate_task(dataset_name, merged_model,
                               task_vectors, pretrained_model, args, merged_tasks)
        results[dataset_name] = result
        print(f"Task {dataset_name} final accuracy: {result['accuracy']*100:.2f}%")

    return results

def evaluate_model_on_tasks(model, tasks, args):
    """Run ``model`` on the validation split of every task and report accuracies.

    Args:
        model: Model passed to ``eval_single_dataset``.
        tasks: Iterable of base dataset names; each is evaluated on
            ``name + "Val"``.
        args: Evaluation arguments forwarded to ``eval_single_dataset``.

    Returns:
        Tuple ``(task_accuracies, avg_accuracy)``. ``task_accuracies`` maps each
        task name to its ``'top1'`` score. ``avg_accuracy`` is the mean over
        every evaluation performed (duplicates in ``tasks`` count each time),
        or ``0`` when ``tasks`` is empty.
    """
    per_task = {}
    scores = []

    for task_name in tasks:
        outcome = eval_single_dataset(model, task_name + "Val", args)
        scores.append(outcome['top1'])
        per_task[task_name] = outcome['top1']
        print(f"  - Task {task_name}: {outcome['top1']*100:.2f}%")

    # Mean of all collected scores; an empty run yields 0.
    if scores:
        mean_score = sum(scores) / len(scores)
    else:
        mean_score = 0
    print(f"  > Average accuracy across all tasks: {mean_score*100:.2f}%")

    return per_task, mean_score

def select_best_tasks(results, n=3, exclude_tasks=None):
    """Return the names of the ``n`` highest-accuracy tasks.

    Args:
        results: Mapping from task name to a result dict containing an
            ``'accuracy'`` key (e.g. the output of ``evaluate_tasks``).
        n: Maximum number of task names to return.
        exclude_tasks: Optional iterable of task names to skip; defaults to
            none.

    Returns:
        List of up to ``n`` task names, ordered by descending accuracy. Ties
        keep their original order in ``results`` because ``sorted`` is stable.
    """
    # Avoid a shared mutable default, and build the exclusion set once so
    # each membership test is O(1).
    excluded = set(exclude_tasks) if exclude_tasks is not None else set()

    # Filter out excluded tasks
    filtered_results = {k: v for k, v in results.items() if k not in excluded}

    # Sort by accuracy
    sorted_tasks = sorted(filtered_results.items(),
                          key=lambda x: x[1]['accuracy'], reverse=True)

    # Select top n
    best_tasks = [task for task, _ in sorted_tasks[:n]]
    print(f"Selected best tasks: {best_tasks}")

    return best_tasks