import numpy as np
import torch
from pathlib import Path
import argparse
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.utils import save_image
from captum.attr import *
from captum.attr import visualization as viz
import matplotlib.pyplot as plt

from guided_diffusion import explainees

# Checkpoint path handed to explainees.DenseNet when the model is built in main().
PATH_CKPT = 'data/explainees/celeba-hq/ckpt.tar'
# Root directory for results; main() creates one sub-directory per run
# (named by --output-dir-name) underneath it.
PATH_OUT = 'data/datasets/gt_keep_masks'

class ImageFromPathDataset(Dataset):
    """Dataset yielding ``(image, path)`` pairs for the files of one directory.

    Only the direct children of the directory are considered; nested
    directories are not traversed.
    """

    def __init__(self, path_imgs_dir):
        """Collect the files to serve.

        Args:
            path_imgs_dir: directory containing the images (``Path`` or ``str``).
        """
        super().__init__()
        # ``iterdir`` yields entries in arbitrary order, so sort them to get a
        # reproducible sample order. Sub-directories are skipped because
        # ``read_image`` cannot decode them.
        files = sorted(p for p in Path(path_imgs_dir).iterdir() if p.is_file())
        # dtype=object keeps the Path instances as-is, also for an empty list.
        self.paths = np.array(files, dtype = object)
        self.length = self.paths.shape[0]

    def __len__(self):
        """Return the number of files found in the directory."""
        return self.length

    def __getitem__(self, idx):
        """Load one image.

        Returns:
            A tuple ``(img, path_img)`` where ``img`` is the decoded image
            divided by 255 and ``path_img`` is the file path as a string.
        """
        path_img = str(self.paths[idx])
        # Rescale the decoded pixel values by 1/255
        # (presumably uint8 in, floats in [0, 1] out - confirm against torchvision).
        img = read_image(path_img) / 255
        return img, path_img

def get_dataloader(args):
    """Build a DataLoader over the images in ``args.path_imgs_dir``.

    Args:
        args: namespace providing ``path_imgs_dir`` and ``batch_size``.

    Returns:
        A ``DataLoader`` wrapping an ``ImageFromPathDataset``.
    """
    return DataLoader(
        ImageFromPathDataset(args.path_imgs_dir),
        batch_size = args.batch_size,
    )

def main(args):
    """Compute attribution maps for every image and save them as grayscale masks.

    The model is explained with Integrated Gradients, optionally wrapped in a
    NoiseTunnel. For each input image one mask is written to
    ``PATH_OUT / args.output_dir_name`` under the image's own file name.

    Args:
        args: namespace providing ``path_imgs_dir``, ``batch_size``,
            ``label_idx``, ``output_dir_name``, ``use_noise_tunnel``,
            ``stdev`` and ``nt_type``.
    """
    # set device
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # get model and wrap it with explainer
    model = explainees.DenseNet(PATH_CKPT).to(device)
    explainer = IntegratedGradients(model)
    if args.use_noise_tunnel:
        explainer = NoiseTunnel(explainer)

    # get loader and label
    dataloader = get_dataloader(args)
    label_idx = args.label_idx
    label_name = model.id_to_cls[label_idx]
    print(f'label: {label_name}')

    # make output dirs
    path_output = Path(PATH_OUT) / args.output_dir_name
    path_output.mkdir(parents = True, exist_ok = True)

    for batch_imgs, batch_paths in dataloader:
        # load images and set requires grad
        batch_imgs = batch_imgs.to(device)
        batch_imgs.requires_grad_()

        # get predictions (no graph needed for the plain forward pass)
        with torch.no_grad():
            batch_outputs = model(batch_imgs)
            batch_probs = torch.sigmoid(batch_outputs)[:, label_idx]
        # One line per image: calling ``.item()`` on the whole batch would
        # raise for any batch size greater than one.
        for prob in batch_probs.tolist():
            print(f'predicted prob for {label_name}: {prob}')

        # our images are in [0, 1] range, so for NoiseTunnel
        # we pass an explicit stdevs instead of relying on the default
        if args.use_noise_tunnel:
            batch_attrs = explainer.attribute(
                batch_imgs,
                target = label_idx,
                stdevs = args.stdev,
                nt_samples = 50,
                nt_samples_batch_size = args.batch_size,
                nt_type = args.nt_type)
        else:
            batch_attrs = explainer.attribute(batch_imgs, target = label_idx)

        # save attribution maps
        for attr, path in zip(batch_attrs, batch_paths):
            # Path.name is separator-agnostic, unlike splitting on '/'.
            name = Path(path).name
            # channels-first -> channels-last so the channel axis can be reduced
            attr = attr.permute(1, 2, 0).numpy(force = True)
            # NOTE(review): _normalize_attr is a private captum helper and may
            # change between releases; presumably returns values in [0, 1]
            # for sign='absolute_value' - confirm.
            attr = viz._normalize_attr(attr, sign = 'absolute_value', reduction_axis = 2)
            # inverted before saving, so larger attributions map to darker pixels
            plt.imsave(str(path_output / name), 1 - attr, cmap = 'gray')

if __name__ == '__main__':
    # Command-line entry point: parse the options and run the attribution pipeline.
    parser = argparse.ArgumentParser()
    # These three have no sensible default; requiring them makes a missing
    # option fail at parse time instead of with an obscure error later on.
    parser.add_argument('--path-imgs-dir', type = Path, required = True, help = 'Path to directory with images')
    parser.add_argument('--output-dir-name', type = str, required = True, help = 'Name of subdirectory where soft masks will be saved')
    parser.add_argument('--label-idx', type = int, required = True, help = 'Index of the label from multilabel classification')
    # Without a default this would be None, which turns off automatic batching
    # in DataLoader and yields unbatched samples.
    parser.add_argument('--batch-size', type = int, default = 1, help = 'Batch size')
    parser.add_argument('--use-noise-tunnel', action = 'store_true', help = 'Whether to use NoiseTunnel or not')
    # Default to a valid choice so --use-noise-tunnel works without --nt-type.
    parser.add_argument('--nt-type', choices = ['smoothgrad', 'smoothgrad_sq', 'vargrad'], default = 'smoothgrad', help = 'NoiseTunnel type')
    parser.add_argument('--stdev', type = float, default = 0.05, help = 'Stdev for NoiseTunnel')
    args = parser.parse_args()
    main(args)