#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
import numpy as np

from xai.thresholds import make_order_thresholds


def calc_deletion_curve(model, img, heatmap, label, imputation, num_steps):
    """Compute a deletion curve for a saliency heatmap.

    Every element of ``heatmap`` is replaced by its rank in ascending order
    (0 = smallest score). For each threshold produced by
    ``make_order_thresholds``, the elements whose rank is below the threshold
    form a boolean ``mask``. The image is passed through
    ``imputation(img, mask)``, and the model output at ``[0, label]`` is
    recorded. Elements outside the mask are counted as deleted, so the
    highest-scoring elements are the ones outside the mask first.

    Args:
        model: Callable whose output is indexable as ``out[0, label]``
            (presumably a batch of logits/probabilities -- confirm with caller).
        img: Input passed to ``model`` and to ``imputation``.
        heatmap: Tensor of saliency scores; its shape is the mask's shape.
        label: Index into the model output to track.
        imputation: Callable ``(img, mask) -> img``. ``mask`` is True for the
            elements counted as kept when computing ``fracs_deleted``.
        num_steps: Forwarded unchanged to ``make_order_thresholds``.

    Returns:
        Tuple ``(outs, fracs_deleted)``:
        ``outs`` is ``torch.stack`` of the recorded outputs; the first entry
        comes from the unmodified ``img``.
        ``fracs_deleted`` is a float NumPy array of the same length holding
        the fraction of heatmap elements outside the mask at each step; the
        first entry is 0.
    """
    # Double argsort turns scores into ranks: ranks[i] is the position of
    # element i in the ascending ordering of the flattened heatmap.
    ranks = torch.argsort(torch.argsort(heatmap.ravel()))
    order_map = ranks.reshape(heatmap.shape)
    num_total_pixels = heatmap.numel()

    outs = [model(img)[0, label]]
    fracs_deleted = [0.0]

    for threshold in make_order_thresholds(num_total_pixels, num_steps):
        mask = order_map < threshold

        imputed_img = imputation(img, mask)
        outs.append(model(imputed_img)[0, label])

        # .item() yields a plain Python number, so the fraction is computed in
        # float64 rather than as a 0-d float32 tensor built through a mixed
        # NumPy/torch operation.
        num_kept = mask.sum().item()
        fracs_deleted.append((num_total_pixels - num_kept) / num_total_pixels)

    return torch.stack(outs), np.array(fracs_deleted)

def calc_auc(outs, fracs_deleted):
    """Return the area under the curve ``outs`` vs ``fracs_deleted``.

    The area is computed with the trapezoidal rule, using ``fracs_deleted``
    as the sample points (x) and ``outs`` as the values (y).

    Args:
        outs: Curve values. Can be a torch tensor (detached and copied to the
            CPU first), a NumPy array, or a sequence.
        fracs_deleted: Sample positions, with the same accepted types as
            ``outs``.

    Returns:
        The trapezoidal integral as returned by NumPy.
    """
    # Detach before .cpu() so the device copy is not recorded in the graph.
    if isinstance(outs, torch.Tensor):
        outs = outs.detach().cpu().numpy()
    if isinstance(fracs_deleted, torch.Tensor):
        fracs_deleted = fracs_deleted.detach().cpu().numpy()
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; fall
    # back to np.trapz only on older NumPy releases that lack trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return trapezoid(outs, fracs_deleted)