#!/usr/bin/env python3

from torcheval.metrics import BinaryAUPRC, MulticlassAccuracy
from torch.utils.data import DataLoader
from src.utils import configure
from ..model.arch import Model
from src import const
import random
import torch
import sys


def accuracy(model, metric, gen, is_multilabel=False):
    """Stream every batch of ``gen`` through ``model`` into ``metric``; return ``metric.compute()``.

    Args:
        model: callable whose first output (``model(X)[0]``) is fed to the metric.
        metric: torcheval-style object exposing ``update(input, target)`` and ``compute()``.
        gen: iterable of ``(X, (heatmap, y))`` batches; ``heatmap`` is ignored here.
        is_multilabel: when True, outputs and targets are flattened and passed
            element-wise; otherwise the target class index is ``y.argmax(1)``.

    Returns:
        Whatever ``metric.compute()`` returns after all batches were consumed.
    """
    with torch.no_grad():
        for inputs, (_, targets) in gen:
            logits = model(inputs)[0].detach()
            if not is_multilabel:
                # Targets arrive one-hot / score-like; the metric wants class indices.
                metric.update(logits, targets.argmax(1))
                continue
            metric.update(logits.flatten(), targets.flatten())
    return metric.compute()


def iou(model, gen, device=None):
    """Return the dataset-level foreground intersection-over-union of the model's semantic maps.

    Every batch ``X, (heatmap, y)`` from ``gen`` is run through ``model``. The
    second model output (the CAMs) is converted with ``model.get_semantic_map``
    and binarized (any nonzero value counts as foreground). Intersection and
    union counts are accumulated over the whole generator before dividing:

        IoU = |pred & truth| / |pred | truth| = TP / (TP + FP + FN)

    Args:
        model: callable returning ``(pred, cams)`` and exposing ``get_semantic_map``.
        gen: iterable of ``(X, (heatmap, y))`` batches. ``heatmap`` is binarized
            the same way as the prediction (nonzero -> foreground); this assumes
            it is a 0/1 mask -- TODO confirm for each dataset.
        device: device holding the running counts; defaults to ``const.DEVICE``.

    Returns:
        0-dim float tensor; NaN when the union is empty (e.g. ``gen`` yields nothing).
    """
    if device is None:
        device = const.DEVICE
    intersection = torch.tensor(0, device=device)  # TP
    union = torch.tensor(0, device=device)         # TP + FP + FN

    with torch.no_grad():
        for X, (heatmap, _) in gen:
            _, cams = model(X)
            pred_mask = model.get_semantic_map(cams).to(torch.bool)
            true_mask = heatmap.to(torch.bool)

            # True negatives are deliberately excluded: counting every matching
            # pixel (pred == truth) would compute pixel accuracy, not IoU.
            intersection += (pred_mask & true_mask).sum()
            union += (pred_mask | true_mask).sum()

    # Integer tensors use true division here, so 0 / 0 yields NaN rather than raising.
    return intersection / union


if __name__ == '__main__':
    # One argument: the model path relative to DOWNSTREAM_MODELS_DIR, without the
    # '.pt' suffix. Its second-to-last path component is passed to configure().
    if len(sys.argv) < 2 or '/' not in sys.argv[1]:
        sys.exit(f"usage: {sys.argv[0]} <.../config/model>  "
                 "(path under DOWNSTREAM_MODELS_DIR, without '.pt')")
    name = sys.argv[1]
    configure(name.split('/')[-2])
    random.seed(const.SEED)

    # Dataset modules are imported inside the branches so only the selected one
    # is loaded. NOTE: these relative imports require running via `python -m`.
    if const.DATASET == 'hardimagenet':
        from ..data.hard_imagenet import Dataset
        data = Dataset('val', ft=True)
    else:
        from ..data.oxford_iiit_pet import Dataset
        data = Dataset('valid')

    # 'high' allows float32 matmuls to use faster reduced-precision internal
    # algorithms (e.g. TF32) where the hardware supports them.
    torch.set_float32_matmul_precision('high')

    is_multilabel = 'sbd' in name
    model = Model(is_contrastive='default' not in name, return_logits=True, logits_only=False)
    weights_path = const.DOWNSTREAM_MODELS_DIR / f'{name}.pt'
    model.load_state_dict(torch.load(weights_path, map_location=const.DEVICE, weights_only=True))
    model.name = name
    model.eval()

    torch.multiprocessing.set_start_method('spawn', force=True)

    device = torch.device(const.DEVICE)
    # Multilabel ('sbd') runs are scored with AUPRC over flattened labels;
    # all other runs use multiclass accuracy.
    metric = BinaryAUPRC(device=device) if is_multilabel else MulticlassAccuracy(device=device)

    # The reported metrics do not depend on batch order, so evaluation does not
    # shuffle (deterministic runs); the same loader is simply iterated twice.
    loader = DataLoader(data, batch_size=const.BATCH_SIZE, num_workers=3, shuffle=False)
    print(accuracy(model, metric, loader, is_multilabel=is_multilabel))
    print(iou(model, loader))
