import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms
from torchsummary import summary

import shap
import lime

# Print arrays/tensors in full (no "..." truncation) when they are dumped.
np.set_printoptions(threshold=sys.maxsize)
torch.set_printoptions(profile="full")
# Seed every random number generator used by the script.
torch.manual_seed(0)
import random
random.seed(0)
np.random.seed(0)

# Explanation method to evaluate: 'shap' or 'lime'.
xai = 'shap'
# Network architecture to train: 'ff' (fully connected) or 'conv'.
arch = 'ff'

# Training hyper-parameters.
n_epochs = 100  # upper bound on the number of training epochs
batch_size_train = 32
# NOTE(review): presumably the two test-class sizes added together so that the
# whole test/validation set fits in one batch — confirm against the data.
batch_size_test = 1010 + 974
learning_rate = 0.05
momentum = 0.5

random_seed = 1
# cuDNN is switched off (presumably for reproducibility — confirm).
torch.backends.cudnn.enabled = False
# Re-seeds torch; this overrides the torch.manual_seed(0) call above.
torch.manual_seed(random_seed)

# Dataset roots; each is read with torchvision's ImageFolder.
TRAIN_DATA_PATH = "./data/train/"
TEST_DATA_PATH1 = "./data/test/"       # used for the accuracy report
TEST_DATA_PATH2 = "./data/test_exp/"   # used for the explanation evaluation
VAL_DATA_PATH = "./data/val/"
# Images are converted to grayscale and then to tensors.
TRANSFORM_IMG = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor()])
# NOTE: a transforms.Normalize(mean=[0.5], std=[0.5]) step was drafted here but
# is currently disabled, so no mean/std normalization is applied.

if xai == 'lime':
    class ImageExplanation_2(object):
        """Holder for one LIME image explanation that emits weighted masks.

        ``local_exp`` maps a label to a sequence of ``(superpixel_id, weight)``
        pairs. It and the other dictionaries start out empty and are meant to
        be filled in by the caller after construction.
        """

        def __init__(self, image, segments):
            """Store the explained image and its superpixel segmentation.

            Args:
                image: 3d numpy array
                segments: 2d numpy array, with the output from skimage.segmentation
            """
            self.image = image
            self.segments = segments
            self.intercept = {}
            self.local_exp = {}
            self.local_pred = {}
            self.score = {}

        def get_image_and_mask(self, label, positive_only=True, negative_only=False, hide_rest=False,
                               num_features=5, min_weight=0.):
            """Return the explanation image and superpixel mask for ``label``.

            Args:
                label: label to explain
                positive_only: if True, only take superpixels that positively
                    contribute to the prediction of the label.
                negative_only: if True, only take superpixels that negatively
                    contribute to the prediction of the label. If False, and so
                    is positive_only, then both negative and positive
                    contributions are taken. Both can't be True at the same
                    time.
                hide_rest: if True, the non-explanation part of the returned
                    image is set to zero.
                num_features: number of superpixels to include in explanation
                min_weight: minimum (absolute) weight of the superpixels to
                    include in explanation

            Returns:
                (image, mask), where image is a numpy array shaped like the
                stored image and mask is a 2d float32 array shaped like
                ``segments``. With ``positive_only`` or ``negative_only`` the
                mask holds the weight of every selected superpixel (0
                elsewhere). Otherwise it holds -1 / 1 for negative / positive
                superpixels, and the image additionally has channel 0
                (negative) or channel 1 (positive) of those superpixels set to
                the image maximum.

            Raises:
                KeyError: if ``label`` has no entry in ``local_exp``.
                ValueError: if both ``positive_only`` and ``negative_only``
                    are truthy.
            """
            if label not in self.local_exp:
                raise KeyError('Label not in explanation')
            # Logical `and` (not bitwise `&`): truthy non-bool flags such as
            # 1 and 2 must be rejected as well.
            if positive_only and negative_only:
                raise ValueError("Positive_only and negative_only cannot be true at the same time.")

            segments = self.segments
            image = self.image
            exp = self.local_exp[label]
            mask = np.zeros(segments.shape, np.float32)
            temp = np.zeros(self.image.shape) if hide_rest else self.image.copy()

            if positive_only or negative_only:
                if positive_only:
                    selected = [(f, w) for f, w in exp
                                if w > 0 and w > min_weight]
                else:
                    selected = [(f, w) for f, w in exp
                                if w < 0 and abs(w) > min_weight]
                # The cut-off is applied after filtering, in the order of `exp`.
                for f, w in selected[:num_features]:
                    region = segments == f
                    temp[region] = image[region].copy()
                    mask[region] = w
                return temp, mask

            # Signed mode: the cut-off is applied before the min_weight filter.
            for f, w in exp[:num_features]:
                if np.abs(w) < min_weight:
                    continue
                region = segments == f
                channel = 0 if w < 0 else 1
                mask[region] = -1 if w < 0 else 1
                temp[region] = image[region].copy()
                # Highlight the superpixel by saturating one colour channel.
                temp[region, channel] = np.max(image)
            return temp, mask

# Training set: small shuffled mini-batches.
train_data = torchvision.datasets.ImageFolder(
    root=TRAIN_DATA_PATH,
    transform=TRANSFORM_IMG,
)
train_data_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size_train,
    shuffle=True,
    num_workers=1,
)

# Test set used for the accuracy report (shuffled, large batches).
test_data = torchvision.datasets.ImageFolder(
    root=TEST_DATA_PATH1,
    transform=TRANSFORM_IMG,
)
test_data_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=batch_size_test,
    shuffle=True,
    num_workers=1,
)

# Validation set that drives early stopping (fixed order, large batches).
val_data = torchvision.datasets.ImageFolder(
    root=VAL_DATA_PATH,
    transform=TRANSFORM_IMG,
)
val_data_loader = torch.utils.data.DataLoader(
    val_data,
    batch_size=batch_size_test,
    shuffle=False,
    num_workers=1,
)

# Show which folder name was mapped to which class index.
print(train_data.class_to_idx)

if arch == 'ff':
    class Net(nn.Module):
        """Fully connected classifier: 784 -> 8 -> 2, softmax output."""

        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(28 * 28, 8)
            self.fc2 = nn.Linear(8, 2)

        def forward(self, x):
            """Flatten the input and return class probabilities along dim 1."""
            flat = x.reshape(-1, 28 * 28)
            hidden = F.relu(self.fc1(flat))
            # The last layer's output also goes through ReLU before softmax.
            scores = F.relu(self.fc2(hidden))
            return F.softmax(scores, dim=1)
elif arch == 'conv':
    class Net(nn.Module):
        """Small CNN: two 3x3 conv + max-pool stages, then 125 -> 8 -> 2."""

        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
            self.conv2 = nn.Conv2d(3, 5, kernel_size=3)
            self.conv2_drop = nn.Dropout2d()
            self.fc1 = nn.Linear(5 * 5 * 5, 8)
            self.fc2_1 = nn.Linear(8, 2)

        def forward(self, x):
            """Run both conv stages and return class probabilities along dim 1."""
            stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))
            # Dropout is applied to the second conv output before pooling.
            dropped = self.conv2_drop(self.conv2(stage1))
            stage2 = F.relu(F.max_pool2d(dropped, 2))
            flat = stage2.view(-1, 5 * 5 * 5)
            # No activation between fc1 and fc2_1.
            hidden = self.fc1(flat)
            scores = F.relu(self.fc2_1(hidden))
            return F.softmax(scores, dim=1)

# Build the selected architecture and report its layers for a 1x28x28 input.
network = Net()
# NOTE(review): summary() may already print the table itself, in which case
# the outer print() only echoes its return value — confirm.
print(summary(network, (1, 28, 28)))

# SGD with momentum over all network parameters.
optimizer = optim.SGD(
    network.parameters(),
    lr=learning_rate,
    momentum=momentum,
)

from pytorchtools import EarlyStopping
def train():
    """Train the module-level ``network`` with early stopping and return it.

    Each epoch runs one optimisation pass over ``train_data_loader`` and one
    evaluation pass over ``val_data_loader``, prints the mean training and
    validation loss, and hands the validation loss to ``EarlyStopping``
    (patience 2). Training stops after ``n_epochs`` epochs or as soon as early
    stopping triggers.

    Returns:
        The same ``network`` object, after its weights have been reloaded
        from ``checkpoint.pt`` (presumably the best checkpoint written by
        ``EarlyStopping`` — confirm against pytorchtools).
    """
    early_stopping = EarlyStopping(patience=2, verbose=True)
    epoch_len = len(str(n_epochs))  # width used to right-align the epoch counter

    for epoch in range(1, n_epochs + 1):
        network.train()
        train_losses = []
        for data, target in train_data_loader:
            optimizer.zero_grad()
            output = network(data)
            # The network returns softmax probabilities, so take the log
            # before the negative log-likelihood loss.
            # NOTE(review): log() gives -inf if a probability underflows to 0.
            loss = F.nll_loss(output.log(), target)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        network.eval()  # prep model for evaluation
        valid_losses = []
        # No gradients are needed for validation; skipping them avoids
        # building an autograd graph for the large validation batches.
        with torch.no_grad():
            for data, target in val_data_loader:
                output = network(data)
                loss = F.nll_loss(output.log(), target)
                valid_losses.append(loss.item())

        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)

        print(f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] '
              f'train_loss: {train_loss:.5f} '
              f'valid_loss: {valid_loss:.5f}')

        early_stopping(valid_loss, network)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    network.load_state_dict(torch.load('checkpoint.pt'))
    return network

# Train the network, then report its accuracy on the test loader
# (one line is printed per batch).
model = train()
model.eval()
pred_y = None
# Evaluation needs no gradients; disabling them avoids building an autograd
# graph for the large test batches.
with torch.no_grad():
    for images, labels in test_data_loader:
        test_output = model(images)
        # Predicted class = index of the largest output along dim 1.
        pred_y = test_output.argmax(dim=1)
        accuracy = (pred_y == labels).sum().item() / float(labels.size(0))
        print('Test Accuracy of the feedforward model: %.4f' % accuracy)


# Rebuild the datasets with loaders whose batch size equals the dataset size,
# so that a single next(iter(...)) call yields every sample at once.
# The test set now comes from TEST_DATA_PATH2, and "./data/gt/" supplies the
# images used as explanation ground truth in the metric computation below.
train_data = torchvision.datasets.ImageFolder(root=TRAIN_DATA_PATH, transform=TRANSFORM_IMG)
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=len(train_data), shuffle=True,  num_workers=1)
test_data = torchvision.datasets.ImageFolder(root=TEST_DATA_PATH2, transform=TRANSFORM_IMG)
test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=len(test_data), shuffle=False, num_workers=1)
exp_data = torchvision.datasets.ImageFolder(root="./data/gt/", transform=TRANSFORM_IMG)
exp_data_loader = torch.utils.data.DataLoader(exp_data, batch_size=len(exp_data), shuffle=False, num_workers=1)
# Pull the single full-size batch out of each loader as an (images, labels) pair.
train_all = next(iter(train_data_loader))
test_all = next(iter(test_data_loader))
exp_all = next(iter(exp_data_loader))
train_images_all, _ = train_all  # training labels are not needed
test_images_all, label_test_images_all = test_all
exp_images_all, label_exp_images_all = exp_all

# Predicted class for every explanation-test image (index of the largest output).
test_output = model(test_images_all)
pred_y = torch.max(test_output, 1)[1].data.squeeze()
# print(pred_y)  # uncomment to inspect the predictions

# Compute one attribution map per test image with the selected method.
# Both branches leave the maps in `sp` and the ground-truth images in
# `exp_numpy`.
if xai == 'shap':
    print('computing shap values')
    # The full training set is passed to DeepExplainer as its second argument.
    e = shap.DeepExplainer(model, train_images_all)
    shap_values = e.shap_values(test_images_all)
    # The double swapaxes moves axis 1 to the end while keeping the other axes
    # in order, i.e. (N, C, H, W) -> (N, H, W, C) for 4-d input.
    # NOTE(review): iterating shap_values assumes one array per class —
    # confirm for the installed shap version.
    shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values]
    test_numpy = np.swapaxes(np.swapaxes(test_images_all.numpy(), 1, -1), 1, 2)
    exp_numpy = np.swapaxes(np.swapaxes(exp_images_all.numpy(), 1, -1), 1, 2)
    # For every sample keep the attribution map of its ground-truth label
    # (not of the predicted class).
    sp = []
    for idx, ii in enumerate(label_test_images_all):
        sp.append(shap_numpy[ii][idx])
    sp = np.array(sp)
    # Plot the first four attribution maps next to their input images.
    pl = shap.image_plot(sp[:4], test_numpy[:4])
elif xai == 'lime':
    from skimage.color import gray2rgb, rgb2gray
    def pd(images):
        """Classifier function handed to LIME.

        Converts each image to grayscale, adds a leading channel axis, stacks
        the results into one batch, and returns the model output as a numpy
        array.
        """
        # NOTE(review): the batch dtype follows whatever rgb2gray returns; the
        # model presumably needs float32 — confirm for the installed skimage.
        batch = torch.stack(tuple(torch.from_numpy(np.expand_dims(rgb2gray(i), axis=0)) for i in images), dim=0)
        return model(batch).detach().numpy()
    # Move axis 1 to the end, then keep only index 0 of that axis, so each
    # image becomes a 2-d array.
    test_numpy = np.swapaxes(np.swapaxes(test_images_all.numpy(), 1, -1), 1, 2)
    test_numpy = test_numpy[:, :, :, 0]
    exp_numpy = np.swapaxes(np.swapaxes(exp_images_all.numpy(), 1, -1), 1, 2)
    exp_numpy = exp_numpy[:, :, :, 0]

    print('computing lime values')
    from lime import lime_image
    explainer = lime_image.LimeImageExplainer(feature_selection='none')
    lime_values = []
    for indx, ii in enumerate(test_images_all):
        # Explain both labels using SLIC superpixels (n_segments=20) and 1000 samples.
        explanation = explainer.explain_instance(np.array(ii[0]), pd, labels=(0,1), top_labels=None, hide_color=0, num_samples=1000, segmentation_fn=lime_image.SegmentationAlgorithm('slic', n_segments=20))
        # Copy the result into ImageExplanation_2, whose positive_only mask
        # carries the superpixel weights.
        explanation_2 = ImageExplanation_2(explanation.image, explanation.segments)
        explanation_2.intercept = explanation.intercept
        explanation_2.local_exp = explanation.local_exp
        explanation_2.local_pred = explanation.local_pred
        explanation_2.score = explanation.score
        # Mask of positively contributing superpixels for the ground-truth label.
        temp, mask = explanation_2.get_image_and_mask(label_test_images_all[indx].item(), positive_only=True, num_features=20, hide_rest=False)
        lime_values.append(mask)
    lime_values = np.array(lime_values)
    sp = lime_values

def closest_node_which(node, nodes):
    """Return the row index in ``nodes`` of the point nearest to ``node``.

    ``nodes`` is a 2-d collection with one point per row. Nearness is the
    squared Euclidean distance, and the first minimal row wins on ties.
    """
    offsets = np.asarray(nodes) - node
    squared_dist = np.square(offsets).sum(axis=1)
    return squared_dist.argmin()


def closest_node(node, nodes):
    """Return the Euclidean distance from ``node`` to its nearest row of ``nodes``.

    ``nodes`` is a 2-d collection with one point per row. The distance is
    returned unnormalised.
    """
    offsets = np.asarray(nodes) - node
    nearest_squared = np.square(offsets).sum(axis=1).min()
    return np.sqrt(nearest_squared)

# Score the attribution maps in `sp` against the ground-truth images in
# `exp_numpy`, averaging two distance-weighted scores over all test samples:
#   TWP - every positively attributed pixel contributes its attribution times
#         exp(-distance to the nearest ground-truth pixel), normalised by the
#         total positive attribution.
#   TWR - every ground-truth pixel contributes the attribution of its nearest
#         positively attributed pixel times exp(-distance to it), normalised
#         by the sum of those attributions.
metric = 0
metric_r = 0
counts = len(test_data)
# Attributions must exceed this value to count as positive. A percentile of
# the positive attributions could be used here instead.
thres = 0
for i in range(counts):
    p = exp_numpy[i]
    q = sp[i]
    if xai == 'shap':
        # SHAP arrays keep a trailing channel axis; drop it.
        p = p[:, :, 0]
        q = q[:, :, 0]

    indices_p = np.argwhere(p > 0)      # ground-truth pixel coordinates
    indices_q = np.argwhere(q > thres)  # positively attributed pixel coordinates

    # With no ground-truth pixels or no positive attribution both scores are
    # 0/0. Such a sample is scored 0 (it still counts in the averages)
    # instead of crashing the whole evaluation.
    if indices_p.size == 0 or indices_q.size == 0:
        continue

    #TWP
    metric_one = 0
    for point in indices_q:
        metric_one += q[point[0], point[1]] * np.exp(-closest_node(point, indices_p))
    metric += metric_one / q[q > thres].sum(dtype=np.float64)

    #TWR
    metric_one_r = 0
    deno = 0
    for point in indices_p:
        nearest = indices_q[closest_node_which(point, indices_q)]
        weight = q[nearest[0], nearest[1]]
        metric_one_r += weight * np.exp(-closest_node(point, indices_q))
        deno += weight
    metric_r += metric_one_r / deno

metric = metric / counts
print('TWP: %.6f' % metric)
metric_r = metric_r / counts
print('TWR: %.6f' % metric_r)
