
import copy
import glob
import json
from keras.activations import sigmoid
import keras.backend as K
from keras.layers import Input
from keras.layers import Lambda
from keras.layers import Layer
from keras.layers import Flatten
from keras.models import Model
from keras.optimizers import Adam
from keras.optimizers import SGD
import matplotlib.pyplot as plt
import numpy as np
from numpy import inf
from pathlib import Path
from PIL import Image, ImageOps
import pickle
import sys
original_stdout = sys.stdout
import tensorflow as tf
from scipy.special import comb
from tensorflow import keras
from tqdm import tqdm

from Config import get_data_dir

### CONFIG ###

# --- Concept learning ---
name = 'frisbee'      # COCO class whose concepts are learned
n_concept = 20        # number of concepts to learn
batch_size = 64       # batch size used when learning the concepts
epochs_init = 30      # epochs used to train the concept (topic) model
n_sample = 100        # number of coalitions sampled for Kernel SHAP
epochs_shap = 6       # epochs used to train each Kernel SHAP model

# --- Visualization ---
n_tc = 5              # number of top concepts to visualize
fold = 'train'        # dataset split searched for concept example images
n_row, n_col = 4, 5   # layout of the image grid drawn for each concept
n_i = n_row * n_col   # number of images per concept grid

# When False, only the visualization section at the end of the file runs;
# it reads the pickles that a run with run_main = True writes to out_dir.
run_main = False

# Output folders, created up front so that later writes never hit a
# missing directory.
out_dir = './Outputs/concept-shap/{}'.format(name)
out_dir_concepts = '{}/concepts/'.format(out_dir)
Path(out_dir).mkdir(parents = True, exist_ok = True)
Path(out_dir_concepts).mkdir(parents = True, exist_ok = True)

# `index` is the column of `name` in the model outputs / label arrays.
with open('{}/maps.json'.format(get_data_dir()), 'r') as map_file:
    name2index, _ = json.load(map_file)
index = name2index[name]

### Concept SHAP code (modified from their awa_standalone code)
# Shared initializer for every Weight layer: uniform on [-0.5, 0.5], unseeded.
init = keras.initializers.RandomUniform(minval = -0.5, maxval = 0.5, seed = None)

class Weight(Layer):
    """Layer that ignores its input and returns a trainable weight tensor.

    Calling the layer on some tensor wires it into a functional Keras graph;
    its output is always ``self.kernel``, a variable of shape ``dim`` that is
    initialized with the module-level ``init`` initializer.

    Args:
        dim: shape of the weight tensor to create.
    """

    def __init__(self, dim, **kwargs):
        self.dim = dim
        super().__init__(**kwargs)

    def build(self, input_shape):
        # A single trainable variable named 'proj'.
        self.kernel = self.add_weight(
            name = 'proj',
            shape = self.dim,
            trainable = True,
            initializer = init,
        )
        super().build(input_shape)

    def call(self, x):
        # `x` only anchors the layer in the graph; its value is not used.
        return self.kernel

    def compute_output_shape(self, input_shape):
        # The output shape is the weight's shape, independent of the input.
        return self.dim
    
def topic_model(predict,
                input_shape,
                n_concept,
                threshold=0.0,
                mode = 'find_concepts'):
    """Build the ConceptSHAP topic model on top of a fixed prediction head.

    Features are projected onto ``n_concept`` learnable concept vectors. The
    masked, normalized concept scores are mapped back to feature space by a
    2-layer MLP, and ``predict`` is applied to that reconstruction.

    Args:
        predict: Keras model mapping features to logits (the prediction head).
        input_shape: shape of the feature array including the leading sample
            axis; only ``input_shape[1:]`` and ``input_shape[-1]`` are used.
        n_concept: number of concept vectors.
        threshold: projections whose normalized score is not greater than
            this value are masked to zero.
        mode: ``'find_concepts'`` to learn the concepts (the loss adds neighbor
            and similarity regularizers), or ``'eval_completeness'`` to train
            only the reconstruction while the concept layer is frozen.

    Returns:
        ``(model, loss)``. ``model`` outputs
        ``[pred, topic_vector_n, topic_prob_n]`` and ``loss(y_true, y_pred)``
        consumes that output list.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """
    if mode not in ('find_concepts', 'eval_completeness'):
        # Fail loudly: callers unpack ``model, loss = topic_model(...)``, so a
        # sentinel return value would only surface later as a confusing TypeError.
        raise ValueError('bad "mode" parameter: {!r}'.format(mode))

    # Calculate the concept scores from the features
    # (assumes channels-last 4-D feature maps: axis 3 is the feature axis)
    f_input = Input(input_shape[1:], name='f_input')
    f_input_n = Lambda(lambda x: K.l2_normalize(x, axis = 3), name = 'f_input_n')(f_input)

    # Concept vectors, one per column, L2-normalized
    topic_vector = Weight((input_shape[-1], n_concept))(f_input)
    topic_vector_n = Lambda(lambda x: K.l2_normalize(x, axis = 0), name = 'topic_vector_n')(topic_vector)

    # Projections of the raw and of the normalized features onto each concept
    topic_prob = Lambda(lambda x: K.dot(x[0], x[1]), name = 'topic_prob')([f_input, topic_vector_n])
    topic_prob_n = Lambda(lambda x: K.dot(x[0], x[1]), name = 'topic_prob_n')([f_input_n, topic_vector_n])

    # Keep only projections whose normalized score exceeds `threshold`, then divide
    # by their sum over concepts (+1e-3 guards against division by zero)
    topic_prob_mask = Lambda(lambda x: K.cast(K.greater(x, threshold),'float32'), name = 'topic_prob_mask')(topic_prob_n)
    topic_prob_am = Lambda(lambda x: x[0] * x[1], name = 'topic_prob_am')([topic_prob, topic_prob_mask])
    topic_prob_sum = Lambda(lambda x: K.sum(x, axis = 3, keepdims = True) + 1e-3, name = 'topic_prob_sum')(topic_prob_am)
    topic_prob_nn = Lambda(lambda x: x[0] / x[1], name = 'topic_prob_nn')([topic_prob_am, topic_prob_sum]) #Note:  this is l1 normalization

    # Use a 2 layer MLP to reconstruct the features from the concept scores
    rec_vector_1 = Weight((n_concept, 500))(f_input)
    rec_vector_2 = Weight((500, input_shape[-1]))(f_input)

    rec_layer_1 = Lambda(lambda x: K.relu(K.dot(x[0], x[1])), name = 'rec_layer_1')([topic_prob_nn, rec_vector_1])
    rec_layer_2 = Lambda(lambda x: K.dot(x[0], x[1]), name = 'rec_layer_2')([rec_layer_1, rec_vector_2])

    # Get the prediction from the reconstructed features
    # WARNING 1:  have to transpose the output, this may be different for convolution and linear features
    pred = predict(tf.transpose(rec_layer_2, perm = (0, 3, 1, 2)))

    topic_model_pr = Model(inputs = f_input, outputs = [pred, topic_vector_n, topic_prob_n])
    # Because we are trying to identify concepts that are sufficient for the model's predictions, we don't want to re-train the prediction head of the model
    # NOTE(review): assumes layers[-1] is the `predict` sub-model -- confirm if the graph changes.
    topic_model_pr.layers[-1].trainable = False
    if mode == 'eval_completeness':
        # For evaluation, the concepts are fixed
        # NOTE(review): assumes layers[1] is the concept Weight layer -- confirm if the graph changes.
        topic_model_pr.layers[1].trainable = False

    # Setup the loss
    if mode == 'find_concepts':
        def loss(y_true, y_pred):
            pred = y_pred[0]
            topic_vector_n = y_pred[1]
            topic_prob_n = y_pred[2]
            # Fidelity to the target (binarized) predictions
            acc_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(y_true, pred))
            # Reward each concept's two highest scores over all positions in the batch
            neighbor_loss = -10.0 * tf.reduce_mean(tf.nn.top_k(K.transpose(K.reshape(topic_prob_n, (-1, n_concept))), k = 2, sorted = True).values)
            # Penalize pairwise similarity between distinct concept vectors
            similarity_loss = 10.0 * tf.reduce_mean(K.dot(K.transpose(topic_vector_n), topic_vector_n) - np.eye(n_concept))
            return acc_loss + neighbor_loss + similarity_loss
    else:
        def loss(y_true, y_pred):
            pred = y_pred[0]
            acc_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(y_true, pred))
            return acc_loss

    return topic_model_pr, loss

def init_optimization(model, loss, learning_rate = 0.001):
    """Compile ``model`` with a fresh Adam optimizer and return that optimizer.

    The training loops in this file apply gradients manually with the returned
    optimizer; ``compile`` only attaches ``loss`` and the optimizer to the model.

    Args:
        model: Keras model to compile.
        loss: loss callable passed to ``model.compile``.
        learning_rate: Adam step size (default 0.001, the previous value).

    Returns:
        The newly created ``Adam`` optimizer.
    """
    # `lr` is a deprecated alias that newer Keras versions ignore (with a warning)
    # or reject; `learning_rate` is accepted across TF2-era Keras releases.
    optimizer = Adam(learning_rate = learning_rate)
    model.compile(loss = loss, optimizer = optimizer)
    return optimizer

def get_acc(data, model, split = 'val', eval_batch_size = None):
    """Return how often ``model``'s logits agree with the stored predictions.

    ``data['<split>_preds']`` holds the original model's thresholded (0/1)
    predictions, so this measures agreement with the original model rather
    than accuracy against ground-truth labels.

    Args:
        data: dict with ``'<split>_features'`` and ``'<split>_preds'`` arrays.
        model: callable (e.g. a topic model) whose first output for a batch of
            features is the vector of logits.
        split: split to evaluate (default ``'val'``, the previous behavior).
        eval_batch_size: batch size; defaults to the module-level ``batch_size``.

    Returns:
        Fraction of samples where ``logit > 0`` matches the stored prediction.

    Raises:
        ValueError: if the split contains no samples.
    """
    if eval_batch_size is None:
        eval_batch_size = batch_size
    features = data['{}_features'.format(split)]
    targets = np.asarray(data['{}_preds'.format(split)])
    total = len(features)
    if total == 0:
        raise ValueError('get_acc: split "{}" has no samples'.format(split))

    correct = 0
    for start in range(0, total, eval_batch_size):
        stop = start + eval_batch_size
        # Evaluate the `model` argument (not a global), so each Kernel SHAP
        # model is scored on its own predictions.
        logits = np.asarray(model(features[start:stop])[0])
        correct += np.sum((logits > 0) == targets[start:stop])

    return correct / total

def sample_binary(n_concept, n_sample, p = 0.2):
    """Draw ``n_sample`` random 0/1 rows of length ``n_concept``.

    Each entry is independently 1 with probability ``p``. Rows are drawn one
    after another from NumPy's global RNG. The result is a float array of
    shape ``(n_sample, n_concept)``.
    """
    probabilities = [1 - p, p]
    samples = np.zeros((n_sample, n_concept))
    for row in samples:
        # `row` is a view into `samples`, so this fills the matrix in place.
        row[:] = np.random.choice(2, n_concept, p = probabilities)
    return samples

def kernel(n, k, p = 0.2):
    """Kernel SHAP weight for a size-``k`` coalition of ``n`` players.

    The Shapley kernel ``(n - 1) / ((n - k) * k * C(n, k))`` is divided by
    ``p**k * (1 - p)**(n - k)``. That is the probability of drawing one
    particular coalition when each player is included independently with
    probability ``p``. For ``k == 0`` or ``k == n`` the kernel is unbounded;
    with NumPy scalars this evaluates to ``inf``, which ``kernel_shap`` clips.
    """
    shapley_weight = (n - 1) * 1.0 / ((n - k) * k * comb(n, k))
    draw_probability = np.power(p, k) * np.power(1 - p, n - k)
    return shapley_weight / draw_probability

def kernel_shap(n_concept, binary_input, scores, p = 0.2, inf_weight = 10000):
    """Estimate per-concept Kernel SHAP values by weighted least squares.

    Solves ``argmin_phi sum_i k_i (y_i - x_i . phi)^2`` via the pseudo-inverse
    of ``X^T K X``. Here ``k_i = kernel(n_concept, |x_i|, p)``.

    Args:
        n_concept: number of concepts (players) per coalition.
        binary_input: 0/1 matrix of shape (n_sample, n_concept); row i marks
            the concepts kept for sample i.
        scores: length-n_sample sequence of scores for those rows.
        p: inclusion probability that was used to sample ``binary_input``.
        inf_weight: finite weight substituted for infinite kernel values
            (the empty and full coalitions); default 10000 as before.

    Returns:
        1-D array of ``n_concept`` SHAP estimates.
    """
    x = np.array(binary_input)
    k = np.array([kernel(n_concept, i, p = p) for i in np.sum(binary_input, axis = 1)])
    k[k == np.inf] = inf_weight
    y = np.array(scores)
    # Scaling the columns of X^T by the weights via broadcasting gives the same
    # values as X^T @ diag(k) without allocating an n_sample x n_sample matrix.
    xtk = x.transpose() * k
    xkx = np.matmul(xtk, x)
    xky = np.matmul(xtk, y)
    return np.matmul(np.linalg.pinv(xkx), xky)

### Image normalization and processing stuff ###
# Per-channel RGB mean and standard deviation (the common ImageNet values),
# shaped (1, 1, 3) so they broadcast over an (H, W, 3) image array.
mean = np.reshape(np.array([0.485, 0.456, 0.406]), (1, 1, 3))
std = np.reshape(np.array([0.229, 0.224, 0.225]), (1, 1, 3))

def image_grid(imgs, rows, cols):
    """Paste PIL images into a ``rows`` x ``cols`` grid in row-major order.

    Args:
        imgs: non-empty sequence of PIL images. The first image's size is used
            as the cell size for every cell.
        rows: number of grid rows.
        cols: number of grid columns. Images beyond ``rows * cols`` fall
            outside the canvas.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``; unfilled cells stay black.

    Raises:
        ValueError: if ``imgs`` is empty (there is no cell size to use).
    """
    if not imgs:
        raise ValueError('image_grid needs at least one image')
    w, h = imgs[0].size
    grid = Image.new('RGB', size = (cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box = (col * w, row * h))
    return grid

### Run Concept SHAP ###
if run_main:
    # Local import so the file-level import block does not need to change.
    import contextlib

    # Load the full model
    model = keras.models.load_model('./Outputs/initial-tune/trial0/tf')
    model.compile()

    # Restrict it to the class of interest (output column `index`)
    preds = model.layers[-1].output
    preds_index = preds[:, index]
    model = Model(model.input, preds_index)

    # redirect_stdout restores sys.stdout even if summary() raises, unlike a
    # manual sys.stdout swap.
    with open('{}/initial-model.txt'.format(out_dir), 'w') as f, contextlib.redirect_stdout(f):
        model.summary()

    # Divide the model into a "feature extractor" and a "prediction head"
    # slice_index picks the layer whose output is used as the features;
    # the trailing comment lists candidate cut points.
    slice_index = -31 # -9, -24, -31

    features = model.layers[slice_index].output

    feature_model = Model(model.input, features, name = 'feature_model')
    with open('{}/feature-extractor.txt'.format(out_dir), 'w') as f, contextlib.redirect_stdout(f):
        feature_model.summary()

    # Re-apply the layers after the cut point to a fresh Input. Layers that take
    # two inputs are re-wired by hand; these offsets are specific to this
    # architecture and to slice_index.
    pred_layers = [Input(features.shape[1:])]
    c = 0
    for i in range(slice_index + 1, 0):
        if i == -24:
            pred_layers.append(model.layers[i]([pred_layers[c], pred_layers[c - 5]]))
        elif i == -17:
            pred_layers.append(model.layers[i](pred_layers[c - 5]))
        elif i == -16:
            pred_layers.append(model.layers[i]([pred_layers[c], pred_layers[c - 1]]))
        elif i == -9:
            pred_layers.append(model.layers[i]([pred_layers[c], pred_layers[c - 5]]))
        else:
            pred_layers.append(model.layers[i](pred_layers[c]))
        c += 1

    predict_model = Model(pred_layers[0], pred_layers[-1], name = 'prediction_model')

    with open('{}/prediction-head.txt'.format(out_dir), 'w') as f, contextlib.redirect_stdout(f):
        predict_model.summary()

    # Sanity check this process on a folder of example images
    corr = []
    imgs = []
    # sorted() makes the set of images shown in the grid reproducible
    for i, v in enumerate(sorted(glob.glob('./ExternalData/frisbee+people+outside/*'))):
        img = Image.open(v).convert('RGB')

        img_tf = tf.image.resize_with_pad(img, 224, 224)
        img_tf /= 255

        if i < n_i:
            imgs.append(Image.fromarray(np.array(255 * img_tf.numpy(), dtype = np.uint8)))
        x = (img_tf - mean) / std
        x = tf.transpose(x, perm = [2, 0, 1])
        x = tf.expand_dims(x, axis = 0)

        pred_orig = model(x).numpy()
        pred_split = predict_model(feature_model(x)).numpy()

        # Check that the composed feature and prediction models match the original model
        # NOTE(review): exact float equality; tiny numerical differences are reported too.
        if not np.all(pred_orig == pred_split):
            print('Mismatched predictions: ', v)

        # Check if the model detected the object
        corr.append(pred_orig[0] > 0)

    plt.rcParams["figure.figsize"] = (16, 20)
    grid = image_grid(imgs, n_row, n_col)
    plt.title('Average accuracy: {}'.format(np.mean(corr)))
    plt.imshow(grid)
    plt.savefig('{}/sanity-check.png'.format(out_dir))
    plt.close()

    # Check the feature extractor's perceptual field: for each spatial position
    # (i, j) of the feature map, mark the input pixels with a non-zero gradient of
    # that position's mean squared activation (assumes a square feature map).
    x = tf.ones((1, 3, 224, 224))

    y = feature_model(x)
    dim = y.shape[2]

    fields = []
    imgs = []
    for i in range(dim):
        tmp = []
        for j in range(dim):

            with tf.GradientTape() as tape:
                tape.watch(x)
                y = feature_model(x)
                v = tf.square(y)
                v = tf.reduce_mean(v, axis = (0, 1))
                v = v[i, j]
            gradients = tape.gradient(v, x)

            # 255 where the input pixel affects position (i, j), else 0; converted to HWC
            vis = 255 * (gradients.numpy() != 0)
            vis = np.squeeze(vis)
            vis = np.transpose(vis, axes = (1, 2, 0))
            tmp.append(vis)
            vis = np.array(vis, dtype = np.uint8)
            imgs.append(Image.fromarray(vis))
        fields.append(tmp)

    with open('{}/fields.pkl'.format(out_dir), 'wb') as f:
        pickle.dump(fields, f)

    plt.rcParams["figure.figsize"] = (20, 20)
    grid = image_grid(imgs, dim, dim)
    plt.imshow(grid)
    plt.savefig('{}/receptive-field.png'.format(out_dir))
    plt.close()

    # Get the model predictions and features for the dataset
    print('Computing the "data" object')
    data = {}
    for split in ['train', 'val']:

        with open('{}/{}/images.json'.format(get_data_dir(), split), 'r') as f:
            images = json.load(f)

        labels = []
        features = []
        preds = []
        for i, v in enumerate(tqdm(list(images))):
            file = images[v]['file']
            label = images[v]['label']

            x = np.array(Image.open(file).convert('RGB'))
            x = x / 255
            x = (x - mean) / std
            x = np.transpose(x, axes = [2, 0, 1])
            x = np.expand_dims(x, axis = 0)

            feature = feature_model(x)

            logit = predict_model(feature)

            labels.append(label)
            features.append(np.squeeze(np.array(feature)))
            # Binarized prediction of the original model (the target for the topic model)
            preds.append(np.squeeze(1.0 * (np.array(logit) > 0.0)))

        data['{}_labels'.format(split)] = np.array(labels)
        # WARNING 1:  have to transpose the output, this may be different for convolutional and linear features
        data['{}_features'.format(split)] = np.transpose(np.array(features), axes = [0, 2, 3, 1])
        data['{}_preds'.format(split)] = np.array(preds, dtype = np.float32)

    with open('{}/data.pkl'.format(out_dir), 'wb') as f:
        pickle.dump(data, f, protocol = 4)

    # Learn the important concepts
    print('Learning the concepts')
    tm, loss = topic_model(predict_model, data['train_features'].shape, n_concept)
    with open('{}/topic-model.txt'.format(out_dir), 'w') as f, contextlib.redirect_stdout(f):
        tm.summary()
    optimizer = init_optimization(tm, loss)
    print('Starting accuracy: ', get_acc(data, tm))

    dataset = tf.data.Dataset.from_tensor_slices((data['train_features'], data['train_preds']))
    dataset = dataset.repeat(epochs_init).batch(batch_size)

    for v in dataset:
        with tf.GradientTape() as tape:
            y_pred = tm(v[0])
            l = loss(v[1], y_pred)
        gradients = tape.gradient(l, tm.trainable_weights)
        optimizer.apply_gradients(zip(gradients, tm.trainable_weights))

    print('Ending accuracy: ', get_acc(data, tm))

    # Remove redundant concepts
    # NOTE(review): indices 1, -6, -4 assume topic_model's layer order (concept
    # Weight, then the two reconstruction Weights) -- confirm if it changes.
    topic_vec = tm.layers[1].get_weights()[0]
    w1 = tm.layers[-6].get_weights()[0]
    w2 = tm.layers[-4].get_weights()[0]

    topic_vec_n = topic_vec / np.linalg.norm(topic_vec, axis = 0, keepdims = True)

    # Pairwise cosine similarity between concept vectors, diagonal zeroed
    aa = np.matmul(topic_vec_n.T, topic_vec_n)- np.eye(n_concept)

    # For each pair with similarity > 0.8, drop the later concept
    remove_list = set()
    for i in range(n_concept):
        for j in range(i + 1, n_concept):
            if aa[i,j] > 0.8:
                remove_list.add(j)

    # sorted() makes the concept order deterministic; `keep` defines the column
    # order of the SHAP samples and of the saved SHAP values.
    remove = np.array(sorted(remove_list))
    keep = np.array(sorted(set(range(n_concept)) - remove_list))
    n_concept_alive = len(keep)

    print('Concepts Removed: ', remove)
    print('Concepts Kept: ', keep)

    with open('{}/concepts.pkl'.format(out_dir), 'wb') as f:
        pickle.dump({'keep': keep, 'remove': remove, 'topic_vec_n': topic_vec_n}, f)


    # Setup the model for Kernel SHAP
    tm_shap, loss = topic_model(predict_model, data['train_features'].shape, n_concept, mode = 'eval_completeness')

    # Generate samples to use for Kernel SHAP; the last sample keeps every concept
    binary_input = sample_binary(n_concept_alive, n_sample)
    binary_input[-1, :] = 1

    # Process each of those samples
    print('Computing Kernel SHAP')
    scores = []
    for i in tqdm(range(n_sample)):

        # Remove the chosen concepts by zeroing their vectors
        topic_vec_tmp = copy.copy(topic_vec)
        if len(remove) != 0:
            topic_vec_tmp[:, remove] = 0.0
        topic_vec_tmp[:, keep[binary_input[i, :] == 0]] = 0.0

        # Reset the SHAP model weights
        tm_shap.layers[1].set_weights([topic_vec_tmp])
        tm_shap.layers[-6].set_weights([w1])
        tm_shap.layers[-4].set_weights([w2])

        # Reset the optimization process
        optimizer = init_optimization(tm_shap, loss)

        # Train the SHAP model
        dataset = tf.data.Dataset.from_tensor_slices((data['train_features'], data['train_preds']))
        dataset = dataset.repeat(epochs_shap).batch(batch_size)

        for v in dataset:
            with tf.GradientTape() as tape:
                y_pred = tm_shap(v[0])
                l = loss(v[1], y_pred)
            gradients = tape.gradient(l, tm_shap.trainable_weights)
            optimizer.apply_gradients(zip(gradients, tm_shap.trainable_weights))

        # Get the score for this point
        scores.append(get_acc(data, tm_shap))

    # Compute Kernel SHAP
    shap = kernel_shap(n_concept_alive, binary_input, scores)
    with open('{}/shap.pkl'.format(out_dir), 'wb') as f:
        pickle.dump(shap, f)
    plt.hist(shap)
    plt.savefig('{}/shap.png'.format(out_dir))
    plt.close()

### Visualize the results ###
# This section reads the pickles written by the `run_main` section.
with open('{}/{}/images.json'.format(get_data_dir(), fold), 'r') as f:
    images = json.load(f)
# NOTE(review): assumes these paths name the same files, in the same order, as the
# images[v]['file'] entries used when the features were extracted -- confirm.
files = ['{}/{}/images/{}.jpg'.format(get_data_dir(), fold, v) for v in images]

with open('{}/fields.pkl'.format(out_dir), 'rb') as f:
    fields = pickle.load(f)

with open('{}/data.pkl'.format(out_dir), 'rb') as f:
    data = pickle.load(f)
features = data['{}_features'.format(fold)]
labels = data['{}_labels'.format(fold)]

with open('{}/concepts.pkl'.format(out_dir), 'rb') as f:
    tmp = pickle.load(f)
topic_vec_n = tmp['topic_vec_n']
keep = tmp['keep']

with open('{}/shap.pkl'.format(out_dir), 'rb') as f:
    shap = pickle.load(f)

# Columns of `concepts` line up with the entries of `shap` (both ordered by `keep`)
concepts = topic_vec_n[:, keep]
# Positions in `keep`, sorted by SHAP value (largest first)
top_concept_indices = (-1.0 * shap).argsort()
# Concept score at every spatial location of every image
activations = np.matmul(features, concepts)

# Visualize each of the most important concepts
plt.rcParams["figure.figsize"] = (16, 20)
for i in range(min(n_tc, len(top_concept_indices))):
    concept = top_concept_indices[i]

    # Get how strongly a concept is expressed in each image
    score = activations[:, :, :, concept]
    score = np.max(score, axis = (1, 2))

    # Find the most expressive images: highest score first (argsort is ascending)
    top_image_indices = score.argsort()[::-1]

    # Filter for those that are examples of this class
    imgs = []
    for j in top_image_indices:
        if labels[j, index] != 1:
            continue

        # Find what part of the image produces this activation
        tmp = activations[j, :, :, concept]
        top_i, top_j = np.unravel_index(np.argmax(tmp), tmp.shape)

        # Use that position's receptive field as a mask: keep it, white out the rest
        mask = fields[top_i][top_j] / 255
        img = np.array(Image.open(files[j]).convert('RGB'))
        masked_img = mask * img + 255 * (1 - mask) * np.ones(img.shape)
        masked_img = np.array(masked_img, dtype = np.uint8)
        masked_img = Image.fromarray(masked_img)
        masked_img = ImageOps.expand(masked_img, border = 5, fill = (0, 0, 0))

        imgs.append(masked_img)
        if len(imgs) == n_i:
            break

    # image_grid needs at least one image; skip concepts with no positive examples
    if not imgs:
        print('No images of class "{}" found for concept {}'.format(name, keep[concept]))
        continue

    # Display those images
    grid = image_grid(imgs, n_row, n_col)
    plt.imshow(grid)
    plt.title('Concept: {}'.format(keep[concept]))
    plt.savefig('{}/{}.png'.format(out_dir_concepts, i))
    plt.close()
       