import torch
import argparse
from tqdm import tqdm

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image
import pathlib

import torchvision
import torch.nn as nn

IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
                    'tif', 'tiff', 'webp'}


def validate(model, loader, args):
    """Print the mean predictive entropy and top-1 accuracy of ``model`` on ``loader``.

    Args:
        model: ``torch.nn.Module`` mapping an image batch to logits of shape
            ``(batch, num_classes)``. It is switched to eval mode and run on the
            device of its parameters (CPU if it has none).
        loader: iterable of ``(images, labels)`` batches exposing a ``.dataset``
            attribute; both averages are taken over ``len(loader.dataset)``.
        args: parsed CLI namespace; currently unused, kept for call compatibility.

    Returns:
        ``(mean_entropy, accuracy)`` as Python floats (entropy uses the natural
        log). Both are 0.0 when the dataset is empty.
    """
    n_samples = len(loader.dataset)
    # Infer the device from the model instead of relying on a module-level
    # `device` global, which only exists when this file is run as a script.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    entropy_total = 0.0
    n_correct = 0
    model.eval()
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        # Pass the loader itself (not iter(loader)) so tqdm can read its length.
        for img, label in tqdm(loader):
            logits = model(img.to(device))
            pred = torch.argmax(logits, dim=-1)
            n_correct += (pred == label.to(device)).sum().item()

            # log_softmax is numerically stable: log(softmax(x)) yields -inf when
            # a probability underflows to 0, and 0 * -inf turns the entropy into NaN.
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
            entropy = -(log_probs.exp() * log_probs).sum(dim=1)
            entropy_total += entropy.sum().item()

    mean_entropy = entropy_total / n_samples if n_samples else 0.0
    accuracy = n_correct / n_samples if n_samples else 0.0
    print(f"Average entropy: {mean_entropy}")
    # This function is called for both the forget and the remain split, so the
    # message must not claim the number belongs to the forgotten class.
    print(f"Average accuracy: {accuracy}")
    return mean_entropy, accuracy


class ImagePathDataset(torch.utils.data.Dataset):
    """Images stored in per-class subdirectories of ``img_folder``.

    Each immediate subdirectory of ``img_folder`` is one class; its name is used
    as the label and is converted with ``int()`` in ``__getitem__``. Files are
    kept when their extension (compared case-insensitively) is in
    ``IMAGE_EXTENSIONS``. The resulting ``(path, dir_name)`` pairs are sorted.

    Args:
        img_folder: root directory containing one subdirectory per class.
        transforms: optional callable applied to each loaded RGB PIL image.
        n: if given, only the first ``n`` files (in sorted order) are exposed;
            must not exceed the number of files found.
        forget_label: label of the forgotten class, matched against directory
            names as ``str(forget_label)``.
        mode: ``'forget'`` keeps only the forgotten-class directory; any other
            value keeps every other directory.
    """

    def __init__(self, img_folder, transforms=None, n=None, forget_label=0, mode='forget'):
        self.transforms = transforms
        self.forget_label = forget_label
        self.mode = mode
        root = pathlib.Path(img_folder)
        forget_name = str(forget_label)

        def wanted(dir_name):
            # 'forget' selects exactly the forgotten class; anything else selects the rest.
            if mode == 'forget':
                return dir_name == forget_name
            return dir_name != forget_name

        # Filter class directories *before* listing their contents so excluded
        # classes are never scanned.
        class_dirs = [d for d in root.glob("*") if d.is_dir() and wanted(d.name)]
        # Compare suffixes case-insensitively so files such as "IMG_01.JPG" are
        # not silently dropped on case-sensitive file systems.
        self.files = sorted(
            (file, class_dir.name)
            for class_dir in class_dirs
            for file in class_dir.iterdir()
            if file.is_file() and file.suffix[1:].lower() in IMAGE_EXTENSIONS
        )
        assert n is None or n <= len(self.files)
        self.n = len(self.files) if n is None else n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        """Return ``(image, label)`` for the ``i``-th file; ``label`` is ``int(dir_name)``."""
        path, label = self.files[i]
        # Image.open is lazy and may keep the file open (e.g. multi-frame TIFF/WebP);
        # the context manager guarantees the handle is released after conversion.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)
        return img, int(label)


def GetImageFolderLoader(path, data_type, img_size, batch_size, forget_label, mode):
    """Build an evaluation ``DataLoader`` over the class folders under ``path``.

    Images are resized to ``img_size`` x ``img_size``, converted to tensors and
    normalized with per-channel mean 0.5 and std 0.5. Batches are produced in
    the dataset's order (no shuffling).

    Args:
        path: root folder holding one subdirectory per class.
        data_type: dataset name; accepted for interface compatibility but unused.
        img_size: side length of the square resize.
        batch_size: number of samples per batch.
        forget_label: label of the forgotten class.
        mode: ``'forget'`` for the forgotten class only, otherwise the rest.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping an ``ImagePathDataset``.
    """
    channel_mean = [0.5] * 3
    channel_std = [0.5] * 3
    preprocess = transforms.Compose(
        [
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]
    )

    split = ImagePathDataset(
        path,
        transforms=preprocess,
        forget_label=forget_label,
        mode=mode,
    )
    return DataLoader(split, batch_size=batch_size)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # data related settings
    parser.add_argument("--sample_path", type=str, required=True,
                        help="Path to folder containing samples")
    parser.add_argument('--dataset', type=str, required=True,
                        choices=['cifar10', 'stl10', 'imagenette'],
                        help='name of the dataset: cifar10, stl10 or imagenette')
    parser.add_argument("--label_of_forgotten_class", type=int, default=0,
                        help="Class label of forgotten class (for calculating average prob)")
    parser.add_argument('-b', '--batch-size', type=int, default=16,
                        help='test batch size for inference')
    parser.add_argument('--device', type=str, default=None,
                        help="torch device to run on, e.g. 'cuda:5' or 'cpu' "
                             "(default: 'cuda' if available, else 'cpu')")
    args = parser.parse_args()

    # ResNet-34 with a 10-way head, built directly via num_classes; this yields
    # the same architecture and state-dict keys as swapping `fc` by hand and
    # avoids the deprecated `pretrained=` keyword.
    model = torchvision.models.resnet34(num_classes=10)
    # NOTE: torch.load unpickles the checkpoint; only load trusted files.
    model.load_state_dict(torch.load(f"{args.dataset}_resnet34.pth", map_location='cpu'))

    # Bound at module scope (not inside a function) so any code in this file
    # that reads `device` as a global still finds it.
    if args.device is not None:
        device = torch.device(args.device)
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    forget_loader = GetImageFolderLoader(args.sample_path, args.dataset, 224, args.batch_size,
                                         args.label_of_forgotten_class, mode='forget')
    remain_loader = GetImageFolderLoader(args.sample_path, args.dataset, 224, args.batch_size,
                                         args.label_of_forgotten_class, mode='remain')
    validate(model, forget_loader, args)
    validate(model, remain_loader, args)

