'''This file implements the perturbation-based saliency method by Greydanus et al. 2017:
Greydanus, S., Koul, A., Dodge, J., & Fern, A. (2018, July). Visualizing and understanding atari agents. In International Conference on Machine Learning (pp. 1792-1801). PMLR.'''

import cv2
import numpy as np

from PIL import Image
from scipy.ndimage.filters import gaussian_filter
from scipy.stats import entropy
import math

def get_mask(center, size, r):
    '''
    Build a soft 2D perturbation mask around a pixel.

    Input:
    center: (row, col) of the pixel the mask is built around.
    size: (rows, cols) of the mask to create.
    r: Standard deviation passed to the Gaussian blur.

    Output:
    A float array of shape `size`, scaled so that its largest value is 1.
    '''
    row_c, col_c = center[0], center[1]
    n_rows, n_cols = size[0], size[1]

    # Open grids of row/column offsets measured from the centre pixel.
    d_row, d_col = np.ogrid[-row_c:n_rows - row_c, -col_c:n_cols - col_c]

    # Seed: every in-bounds pixel whose squared distance to the centre is <= 1.
    inside = d_col * d_col + d_row * d_row <= 1
    seed = np.zeros(size)
    seed[inside] = 1

    # Blur the seed, then rescale by the peak value.
    blurred = gaussian_filter(seed, sigma=r)
    return blurred / blurred.max()

def softmax(vec):
    '''
    Softmax over all entries of `vec`.

    The global maximum is subtracted before exponentiating so that large
    inputs do not overflow; the result sums to 1 over the whole input.
    '''
    shifted = vec - np.max(vec)
    weights = np.exp(shifted)
    total = weights.sum()
    return weights / total

def greydanus(value_func, obs):
    '''
    Compute perturbation-based attention maps for a batch of observations.

    Every observation is blended with a blurred copy of itself inside a soft
    mask (see get_mask) centred on each position of a grid with stride d.
    The score of a grid position is half the squared L2 distance between the
    softmax of value_func's output for the unperturbed batch and the softmax
    of its output for the perturbed batch. The coarse score grid is upsampled
    bilinearly to the frame size and each map is divided by its maximum.

    Input:
    value_func: The step() function of an RL agent. Returns a 2D array of
                Q-values (#samples x #actions) given a series of observations.
    obs: Preprocessed observations for the RL agent, array-like of shape
         (N, H, W, C).

    Output:
    vf_att: float32 array of shape (N, H, W) with one attention map per input
            observation. A map whose scores are all zero stays all zero.

    Raises:
    ValueError: if obs is not 4-dimensional.
    '''
    frames = np.asarray(obs)
    if frames.ndim != 4:
        raise ValueError(
            f'obs must have shape (N, H, W, C), got {frames.shape}'
        )
    N, H, W = frames.shape[:3]
    d = r = 3  # d: stride of the probe grid; r: blur kernel size / mask sigma

    if N == 0:
        # Nothing to explain; skip the agent calls and the normalisation.
        return np.zeros((0, H, W), dtype=np.float32)

    original = value_func(obs)
    original = np.array([softmax(orig) for orig in original])

    # One score cell per probe position. Sizing the grid from the ranges
    # avoids a trailing row/column that is never written and would otherwise
    # be interpolated into the upsampled map as a band of zeros.
    row_starts = range(0, H, d)
    col_starts = range(0, W, d)
    vf_scores = np.zeros((N, len(row_starts), len(col_starts)))

    # NOTE(review): the third argument of cv2.GaussianBlur is sigmaX, not a
    # border mode. The original call passed cv2.BORDER_DEFAULT there, so that
    # constant's numeric value has been acting as the blur sigma. It is kept
    # (now by keyword) so results do not change; confirm whether sigma=r was
    # intended.
    gauss_noise = np.array([
        cv2.GaussianBlur(frame, (r, r), sigmaX=cv2.BORDER_DEFAULT)
        for frame in frames
    ])

    print("Computing Greydanus saliency maps")
    print('_' * len(row_starts))  # progress bar outline: one slot per grid row
    for row, i in enumerate(row_starts):
        for col, j in enumerate(col_starts):
            mask = get_mask(center=[i, j], size=[H, W], r=r).reshape(1, H, W, 1)
            # Blend towards the blurred frame where the mask is large.
            interpolated_obs = (1 - mask) * frames + mask * gauss_noise
            perturbed = value_func(interpolated_obs)
            perturbed = np.array([softmax(pert) for pert in perturbed])
            # Half the squared L2 distance between the two distributions.
            vf_scores[:, row, col] = 0.5 * np.sum((original - perturbed) ** 2, axis=1)
        print('.', end='', flush=True)
    print()

    # PIL's resize expects (width, height); the resulting arrays are (H, W).
    vf_scores = np.array([
        np.array(Image.fromarray(grid).resize((W, H), Image.BILINEAR))
        for grid in vf_scores
    ]).astype(np.float32)

    # Scale each map by its own peak. Maps with a zero peak are left at zero
    # instead of being divided by zero.
    peak = vf_scores.max(axis=(1, 2), keepdims=True)
    vf_att = np.divide(
        vf_scores, peak, out=np.zeros_like(vf_scores), where=peak != 0
    )
    return vf_att

