import torch
from loguru import logger
from torch.utils.data import DataLoader

from utils.training import get_embedded_img

def evaluate_model(model, scenario, args):
    """Evaluate ``model`` on every task in ``scenario`` and report accuracy.

    A ``DataLoader`` is built over each task's test set. When
    ``args.train_type == 'intercontinet'`` predictions come from
    ``model.predict_classes(x)`` and are compared to the labels directly;
    otherwise the output of ``model(x)`` is cut to its first
    ``model.num_classes`` columns and arg-maxed. Per-task and overall
    accuracies are logged via ``logger`` and, if ``args.wandb`` is set,
    to Weights & Biases.

    Gradients are disabled while evaluating. The model's train/eval mode is
    not changed here.

    Args:
        model: Model under evaluation. Inputs are moved with ``.cuda()``
            before being passed to it.
        scenario: Iterable of per-task test datasets whose items unpack into
            ``(x, y, t)``.
        args: Config namespace. Reads ``test_batch_size``, ``train_type``,
            ``wandb``, ``track_buffer`` and, when ``track_buffer`` is set,
            ``num_cert``.

    Returns:
        ``(avg_acc, avg_cert)``. ``avg_acc`` is the number of correct
        predictions divided by the number of samples pooled over all tasks
        (not the mean of per-task accuracies). ``avg_cert`` is
        ``args.num_cert.compute().item()`` when ``args.track_buffer`` is set,
        else ``0``. A task, or a whole scenario, with no samples reports an
        accuracy of ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    report = ""
    total_correct = 0
    total_num = 0
    # Pure inference: don't build autograd graphs for these forward passes.
    with torch.no_grad():
        for task_id, test_taskset in enumerate(scenario):
            correct = 0
            num_samples = 0
            # NOTE(review): shuffling does not affect accuracy; it is kept so the
            # RNG draws made here (and therefore later randomness) are unchanged.
            test_loader = DataLoader(test_taskset, batch_size=args.test_batch_size, shuffle=True)
            for x, y, _ in test_loader:
                x, y = x.cuda(), y.cuda()
                if args.train_type == 'intercontinet':
                    # The output is compared to y directly, so it is expected to
                    # already be class indices.
                    pred = model.predict_classes(x)
                    correct += (pred == y).sum().item()
                else:
                    pred = model(x)[:, :model.num_classes]
                    correct += (pred.argmax(1) == y).sum().item()
                num_samples += len(y)
            total_num += num_samples
            total_correct += correct
            if num_samples:
                acc = correct / num_samples
            else:
                logger.warning(f"Task {task_id + 1} has no test samples; reporting accuracy 0.0")
                acc = 0.0
            report += f"Task {task_id + 1}: accuracy = {acc}\n"
            if args.wandb:
                import wandb
                wandb.log({f"Task {task_id + 1} Acc": acc})
    avg_acc = total_correct / total_num if total_num else 0.0
    report += f"Total accuracy = {avg_acc}"
    logger.info(report)
    # Computed once and reused for both logging and the return value.
    avg_cert = args.num_cert.compute().item() if args.track_buffer else 0
    if args.wandb:
        import wandb
        wandb.log({"Avg Acc": avg_acc})
        if args.track_buffer:
            wandb.log({"Avg Cert": avg_cert})
    return avg_acc, avg_cert

def eval_current_task(model, dataloader, task_id, args):
    """Compute, log and return ``model``'s accuracy on a single dataloader.

    When ``args.train_type == 'intercontinet'`` predictions come from
    ``model.predict_classes(x)`` and are compared to the labels directly;
    otherwise the output of ``model(x)`` is cut to its first
    ``model.num_classes`` columns and arg-maxed. Gradients are disabled
    while evaluating; the model's train/eval mode is not changed.

    Args:
        model: Model under evaluation. Inputs are moved with ``.cuda()``
            before being passed to it.
        dataloader: Iterable of ``(x, y, t)`` batches.
        task_id: Used only to label the log line and the wandb key. It is
            printed exactly as given (no ``+ 1`` offset is applied).
        args: Config namespace. Reads ``train_type`` and ``wandb``.

    Returns:
        Fraction of correctly classified samples, or ``0.0`` (with a
        warning) if ``dataloader`` yields no samples.
    """
    correct = 0
    num_samples = 0
    # Pure inference: don't build autograd graphs for these forward passes.
    with torch.no_grad():
        for x, y, _ in dataloader:
            x, y = x.cuda(), y.cuda()
            if args.train_type == 'intercontinet':
                # The output is compared to y directly, so it is expected to
                # already be class indices.
                pred = model.predict_classes(x)
                correct += (pred == y).sum().item()
            else:
                pred = model(x)[:, :model.num_classes]
                correct += (pred.argmax(1) == y).sum().item()
            num_samples += len(y)
    if num_samples:
        acc = correct / num_samples
    else:
        logger.warning(f"Task {task_id} has no samples; reporting accuracy 0.0")
        acc = 0.0
    logger.info(f"Task {task_id}: accuracy = {acc}")
    if args.wandb:
        import wandb
        wandb.log({f"Task {task_id} Acc": acc})
    return acc

def eval_all_task(model, dataloaders, args):
    """Compute, log and return the mean per-task validation accuracy.

    Each dataloader is scored separately. The reported figure is the
    unweighted mean of the per-task accuracies, so every task counts equally
    regardless of its size. Prediction follows the same rule as the other
    evaluators in this module: ``model.predict_classes(x)`` when
    ``args.train_type == 'intercontinet'``, otherwise arg-max over
    ``model(x)[:, :model.num_classes]``. Gradients are disabled while
    evaluating; the model's train/eval mode is not changed.

    Args:
        model: Model under evaluation. Inputs are moved with ``.cuda()``
            before being passed to it.
        dataloaders: Iterable of per-task dataloaders yielding
            ``(x, y, t)`` batches.
        args: Config namespace. Reads ``wandb`` and, if present,
            ``train_type``.

    Returns:
        The mean accuracy that is also logged. Dataloaders that yield no
        samples are skipped (with a warning). If no task has samples, the
        result is ``0.0``.
    """
    # getattr keeps callers whose args lack ``train_type`` on the original
    # logits/arg-max path.
    train_type = getattr(args, "train_type", None)
    accs = []
    # Pure inference: don't build autograd graphs for these forward passes.
    with torch.no_grad():
        for task_id, dataloader in enumerate(dataloaders):
            correct = 0
            num_samples = 0
            for x, y, _ in dataloader:
                x, y = x.cuda(), y.cuda()
                # NOTE(review): disabled image-embedding hook; get_embedded_img is
                # imported at the top of the file for it.
                # if args.embed_img:
                #     x = get_embedded_img(x, args)
                if train_type == 'intercontinet':
                    # The output is compared to y directly, so it is expected to
                    # already be class indices.
                    pred = model.predict_classes(x)
                    correct += (pred == y).sum().item()
                else:
                    pred = model(x)[:, :model.num_classes]
                    correct += (pred.argmax(1) == y).sum().item()
                num_samples += len(y)
            if num_samples:
                accs.append(correct / num_samples)
            else:
                logger.warning(f"Validation loader {task_id} has no samples; skipping it")
    acc_total = sum(accs) / len(accs) if accs else 0.0
    logger.info(f"Total val accuracy = {acc_total}")
    if args.wandb:
        import wandb
        wandb.log({"Total val Acc": acc_total})
    return acc_total