import numpy as np
import torch
from contextlib import nullcontext

def class_il_eval(model, n_classes, test_dataloaders, task_id, device, model_type, use_torch_amp=False):
    """Evaluate ``model`` in the class-incremental setting (task identity unknown).

    For every batch the model is queried once per task head
    (``tid`` in ``range(len(test_dataloaders))``); the per-head outputs are
    concatenated along dim 1 and the arg-max over the softmax of the
    concatenation is compared against the labels as-is.

    Args:
        model: Callable as ``model(images, tid, use_lora=True, training=False)``
            and exposing ``eval()``.
        n_classes: Divisor used to map a class index to a task index
            (``index // n_classes``).
            NOTE(review): this assumes every head emits exactly ``n_classes``
            columns and that labels are global class indices -- confirm.
        test_dataloaders: Sequence of iterables yielding ``(images, labels)``
            tensor batches, one iterable per task.
        task_id: The confusion matrices are sized for ``task_id + 1`` tasks.
            NOTE(review): heads and ``task_predicted_counts`` are sized from
            ``len(test_dataloaders)`` instead; a label or prediction outside
            the ``task_id + 1`` range raises ``IndexError`` -- presumably
            callers pass exactly ``task_id + 1`` loaders; confirm.
        device: Device the batches are moved to before the forward pass.
        model_type: Unused here; kept so the signature stays stable.
        use_torch_amp: When true, run the forward passes under CUDA autocast.

    Returns:
        A 6-tuple of:
        ``overall_accuracy`` (percent, ``0`` when no sample was seen),
        ``accuracies_per_task`` (one percent value per loader, ``0`` for an
        empty loader),
        ``confusion_matrix_class`` (rows: true class, cols: predicted class),
        ``confusion_matrix_task`` (rows: true task, cols: predicted task),
        ``(confidences, predictions, labels)`` as NumPy arrays over all samples,
        ``task_predicted_counts`` (number of predictions falling in each task).
    """
    model.eval()

    num_tasks = len(test_dataloaders)
    seen_tasks = task_id + 1
    seen_classes = n_classes * seen_tasks

    correct = 0
    total = 0
    accuracies_per_task = []

    confusion_matrix_class = np.zeros((seen_classes, seen_classes), dtype=int)
    confusion_matrix_task = np.zeros((seen_tasks, seen_tasks), dtype=int)
    task_predicted_counts = np.zeros(num_tasks)

    all_confidences = []
    all_predictions = []
    all_labels = []

    # torch.autocast(device_type="cuda") is the non-deprecated spelling of
    # torch.cuda.amp.autocast(); the device type is kept as "cuda" so the
    # behaviour of the flag is unchanged.
    amp_ctx = torch.autocast(device_type="cuda") if use_torch_amp else nullcontext()

    with torch.no_grad(), amp_ctx:
        for test_dataloader in test_dataloaders:
            correct_task = 0
            total_task = 0

            for images, labels in test_dataloader:
                images, labels = images.to(device), labels.to(device)

                # Query every task head and stitch the outputs into one
                # score vector per sample.
                head_outputs = [
                    model(images, tid, use_lora=True, training=False)
                    for tid in range(num_tasks)
                ]
                outputs = torch.cat(head_outputs, dim=1)
                probs = torch.softmax(outputs, dim=1)
                confidences, predicted = torch.max(probs, 1)

                # One reduction / host sync per batch, shared by both counters.
                batch_correct = (predicted == labels).sum().item()
                batch_size = labels.size(0)
                correct += batch_correct
                total += batch_size
                correct_task += batch_correct
                total_task += batch_size

                # Move each tensor to the host once and reuse the arrays.
                confidences_np = confidences.cpu().numpy()
                predicted_np = predicted.cpu().numpy()
                labels_np = labels.cpu().numpy()

                all_confidences.extend(confidences_np)
                all_predictions.extend(predicted_np)
                all_labels.extend(labels_np)

                # np.add.at accumulates correctly when the same cell occurs
                # several times in a batch (plain fancy-index += would not).
                true_tasks = labels_np // n_classes
                predicted_tasks = predicted_np // n_classes
                np.add.at(confusion_matrix_class, (labels_np, predicted_np), 1)
                np.add.at(confusion_matrix_task, (true_tasks, predicted_tasks), 1)
                np.add.at(task_predicted_counts, predicted_tasks, 1)

            task_accuracy = 100 * correct_task / total_task if total_task > 0 else 0
            accuracies_per_task.append(task_accuracy)

    overall_accuracy = 100 * correct / total if total > 0 else 0

    return (
        overall_accuracy,
        accuracies_per_task,
        confusion_matrix_class,
        confusion_matrix_task,
        (np.array(all_confidences), np.array(all_predictions), np.array(all_labels)),
        task_predicted_counts,
    )


def task_il_eval(model, n_classes, test_dataloaders, device, model_type, use_torch_amp=False):
    """Evaluate ``model`` in the task-incremental setting (task identity given).

    The loader at position ``task_id`` in ``test_dataloaders`` is evaluated
    with the head selected by that same ``task_id``. Labels are reduced with
    ``labels % n_classes`` before being compared with the arg-max of the
    head's output.

    Args:
        model: Callable as ``model(images, task_id, use_lora=True,
            training=False)`` and exposing ``eval()``.
        n_classes: Modulus applied to the labels.
            NOTE(review): presumably the number of classes per task, with
            labels given as global class indices -- confirm.
        test_dataloaders: Sequence of iterables yielding ``(images, labels)``
            tensor batches, one iterable per task.
        device: Device the batches are moved to before the forward pass.
        model_type: Unused here; kept so the signature stays stable.
        use_torch_amp: When true, run the forward passes under CUDA autocast.
            Defaults to ``False``, which keeps the full-precision behaviour.

    Returns:
        ``(overall_accuracy, accuracies_per_task)`` with accuracies in percent.
        ``overall_accuracy`` is ``0`` when no sample was seen. A loader that
        yields no samples contributes no entry to ``accuracies_per_task``, so
        the list can be shorter than ``test_dataloaders``.
    """
    model.eval()

    total_correct = 0
    total_samples = 0
    accuracies_per_task = []

    amp_ctx = torch.autocast(device_type="cuda") if use_torch_amp else nullcontext()

    with torch.no_grad(), amp_ctx:
        for task_id, test_dataloader in enumerate(test_dataloaders):
            task_correct = 0
            task_samples = 0

            for images, labels in test_dataloader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images, task_id, use_lora=True, training=False)
                mapped_labels = labels % n_classes
                # Under no_grad there is nothing to detach, so the legacy
                # ``.data`` access is unnecessary.
                _, predicted = torch.max(outputs, 1)

                # One reduction / host sync per batch, shared by both counters.
                batch_correct = (predicted == mapped_labels).sum().item()
                batch_size = mapped_labels.size(0)

                total_samples += batch_size
                total_correct += batch_correct
                task_samples += batch_size
                task_correct += batch_correct

            # Empty loaders are skipped rather than recorded as 0.
            if task_samples > 0:
                accuracies_per_task.append(100 * task_correct / task_samples)

    overall_accuracy = 100 * total_correct / total_samples if total_samples > 0 else 0

    return overall_accuracy, accuracies_per_task
