from typing import List

import torch
from torch import Tensor
from torch.utils.data import DataLoader

from utils import verbose_iterator


def test(
    cl_type, model, dataloader_list: List[DataLoader], n_tasks, n_classes, base, verbose
) -> Tensor:
    """
    Evaluate the model(s) on a list of data loaders.

    For every dataloader, each candidate model is scored on all batches and
    the best accuracy among the candidates is reported.

    :param cl_type: Unused by this function; kept for interface compatibility
    :param model: Sequence of models to evaluate, or a dict with the keys
        ``"feature"`` (callable applied to the output of ``base``) and
        ``"classifier"`` (the sequence of models to evaluate)
    :param dataloader_list: List of dataloaders for test datasets; each one
        yields ``(inputs, labels)`` pairs where ``inputs`` is a tensor or a
        dict whose values can be moved with ``.to(device)``
    :param n_tasks: Unused by this function; kept for interface compatibility
    :param n_classes: Unused by this function; kept for interface compatibility
    :param base: Callable applied to the inputs before the model(s)
    :param verbose: Forwarded to ``verbose_iterator``
    :return: Tensor of accuracies for all test sets, in the same order as
        ``dataloader_list``; a dataloader that yields no samples scores 0
    """
    print("Testing...")
    acc_list = torch.zeros(len(dataloader_list))
    if isinstance(model, dict):
        feature = model["feature"]
        model = model["classifier"]
    else:
        feature = None

    # Switching to eval mode and locating the device do not depend on the
    # dataloader, so do both once up front.
    for model_ in model:
        model_.eval()
    # Inputs and labels are moved to the device of the first model.
    device = next(model[0].parameters()).device

    with torch.no_grad():
        for i, dataloader in enumerate(verbose_iterator(dataloader_list, verbose)):
            # One running count of correct predictions per candidate model.
            correct = torch.zeros(len(model))
            total = 0

            for inputs, labels in dataloader:
                if isinstance(inputs, torch.Tensor):
                    inputs = inputs.to(device)
                elif isinstance(inputs, dict):
                    inputs = {k: v.to(device) for k, v in inputs.items()}
                labels = labels.to(device)
                total += len(labels)

                inputs = base(inputs)
                if feature is not None:
                    inputs = feature(inputs)
                for j, model_ in enumerate(model):
                    predictions = model_(inputs).argmax(dim=1)
                    correct[j] += (predictions == labels).sum().item()

            # Store at index i (not i - 1) so results line up with
            # dataloader_list. An empty dataloader keeps its initial 0 rather
            # than dividing by zero.
            if total > 0:
                acc_list[i] = correct.max().item() / total

    return acc_list
