import os
import shutil

import torch
import numpy as np

from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode

def getTestData(dir=None, resize_size=256, crop_size=224, interpolation=InterpolationMode.BILINEAR):
    """Build the evaluation ``ImageFolder`` dataset rooted at ``<dir>/val``.

    Images are resized (shorter side to ``resize_size``), center-cropped to
    ``crop_size``, converted to tensors and normalized with the per-channel
    mean/std below.

    Args:
        dir: Dataset root containing a ``val`` sub-directory. When ``None``,
            the ``IMAGENET_PATH`` environment variable is used.
        resize_size: Size passed to ``transforms.Resize``.
        crop_size: Size passed to ``transforms.CenterCrop``.
        interpolation: Interpolation mode used by ``transforms.Resize``.

    Returns:
        A ``torchvision.datasets.ImageFolder`` over ``<dir>/val``.

    Raises:
        ValueError: if ``dir`` is ``None`` and ``IMAGENET_PATH`` is not set.
    """
    # Per-channel RGB normalization statistics (the usual ImageNet values,
    # consistent with the IMAGENET_PATH default root).
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    if dir is None:
        dir = os.getenv('IMAGENET_PATH')
        if dir is None:
            # Without this check os.path.join(None, ...) fails with an
            # unhelpful TypeError.
            raise ValueError(
                "No dataset directory given and IMAGENET_PATH is not set")
    valdir = os.path.join(dir, "val")
    test_transforms = transforms.Compose([
        transforms.Resize(resize_size, interpolation=interpolation),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    # Log the actual pipeline (previously this printed the transforms module).
    print(test_transforms)
    testset = datasets.ImageFolder(valdir, test_transforms)
    return testset

def ECE(y_true, y_pred, num_bins=11):
    """Compute the Expected Calibration Error (ECE) of classifier outputs.

    ``y_pred`` is turned into probabilities with a softmax over dim 1. The
    predicted class is the argmax and the confidence is the max probability.
    Confidences are bucketed with ``np.digitize(..., right=True)`` against
    ``np.linspace(0, 1, num_bins)``. Index ``b`` therefore covers
    ``(edges[b-1], edges[b]]``. Index 0 only receives confidences exactly
    equal to 0, so ``num_bins=11`` yields 10 effective intervals of width 0.1.

    Args:
        y_true: Integer class labels (tensor, ndarray or list). They are
            squeezed, so shapes like ``(N,)`` or ``(N, 1)`` both work.
        y_pred: Unnormalized scores (tensor, ndarray or list), softmaxed over
            dim 1. ``(N, C)`` is presumably the expected shape.
        num_bins: Number of bin edges and length of the per-bin outputs.

    Returns:
        dict with keys ``'ECE'``, ``'accuracy'``, ``'confidence_bin'``,
        ``'accuracy_bin'``, ``'counts_bin'``, ``'confidence_raw'``,
        ``'prediction_raw'`` and ``'labels_raw'``. Array values are converted
        to Python lists.
    """
    # as_tensor accepts tensors, ndarrays and lists for both arguments. Unlike
    # torch.tensor it does not warn when handed an existing tensor.
    y_pred = torch.as_tensor(y_pred)
    y_true = torch.as_tensor(y_true)
    # detach() so .numpy() also works when the logits carry autograd history.
    probs = torch.nn.functional.softmax(y_pred.detach(), dim=1)
    n_samples = probs.shape[0]

    y_p = np.squeeze(probs.cpu().numpy())
    y_t = np.squeeze(y_true.detach().cpu().numpy())
    pred_y = np.argmax(y_p, axis=-1)
    correct = (pred_y == y_t).astype(np.float32)
    prob_y = np.max(y_p, axis=-1)

    bins = np.linspace(start=0, stop=1.0, num=num_bins)
    binned = np.digitize(prob_y, bins=bins, right=True)

    errors = np.zeros(num_bins)
    confs = np.zeros(num_bins)
    counts = np.zeros(num_bins)
    corrects = np.zeros(num_bins)
    accs = np.zeros(num_bins)
    for b in range(num_bins):
        mask = binned == b
        count = np.sum(mask)
        counts[b] = count
        corrects[b] = np.sum(correct[mask])
        if count > 0:
            accs[b] = corrects[b] / counts[b]
            confs[b] = np.mean(prob_y[mask])
            # Per-bin |accuracy - confidence|, weighted by bin population.
            errors[b] = np.abs(accs[b] - confs[b]) * counts[b]

    results = {'ECE': np.sum(errors) / n_samples,
               'accuracy': np.sum(correct) / n_samples,
               'confidence_bin': confs,
               'accuracy_bin': accs,
               'counts_bin': np.array(counts),
               'confidence_raw': prob_y,
               'prediction_raw': pred_y,
               'labels_raw': y_true.detach().cpu().numpy(),
               }
    # Make the result JSON-friendly.
    for key, value in results.items():
        if isinstance(value, np.ndarray):
            results[key] = value.tolist()
    return results