import os
import timeit
import pickle
import argparse
from datetime import datetime
from tqdm.auto import tqdm

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, FashionMNIST
from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip, ToTensor

import numpy as np; np.set_printoptions(suppress=True)
import pandas as pd
import matplotlib.pyplot as plt

np.random.seed(0)
torch.manual_seed(0)
torch.set_printoptions(sci_mode=False)

#lime image specific utilities
from lime import lime_image
from skimage.segmentation import mark_boundaries
from skimage.color import gray2rgb, rgb2gray, label2rgb
from lime.wrappers.scikit_image import SegmentationAlgorithm
from numpy import linalg as la

from vgg_model import VGG_CIFAR, VGG_FMNIST
from functools import partial

import shap

from ExpCertifyBB import Ecertify


def tensor_imshow(image_tensor):
    """Show an image tensor in a matplotlib window.

    The tensor's axes are reordered with ``permute(1, 2, 0)`` and the result
    is converted to a NumPy array before being passed to ``plt.imshow``;
    ``plt.show()`` is called afterwards.

    Presumably the input is channel-first (C, H, W) -- confirm with callers.
    """
    channels_last = image_tensor.permute(1, 2, 0)
    pixels = channels_last.numpy()
    plt.imshow(pixels)
    plt.show()


def tensor_to_image(image_tensor):
    """Return ``image_tensor`` as a NumPy array with axes reordered to (1, 2, 0).

    For a 3-d tensor of shape (A, B, C) the result has shape (B, C, A).
    """
    reordered = image_tensor.permute(1, 2, 0)
    return reordered.numpy()


def image_to_tensor(image):
    """Wrap a NumPy array as a tensor with axes reordered to (2, 0, 1).

    For a 3-d array of shape (A, B, C) the result has shape (C, A, B).
    The array is wrapped with ``torch.from_numpy``.
    """
    as_tensor = torch.from_numpy(image)
    return as_tensor.permute(2, 0, 1)


def _predict_proba(image_batch, model):
    model.eval()
    with torch.torch.no_grad():
        logits = model(image_batch)
        probs = F.softmax(logits, dim=1)
        return probs.detach()


def _single_predict_proba(image_tensor, model):
    """Return class probabilities for a single (unbatched) input tensor.

    A leading batch dimension is added, the batch is scored with
    ``_predict_proba``, and the first (only) row of the result is returned.
    """
    batch_of_one = image_tensor.unsqueeze(dim=0)
    probabilities = _predict_proba(batch_of_one, model)
    return probabilities[0]



if __name__ == '__main__':
    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default="cifar10", help='dataset name')
    parser.add_argument('--seed', type=int, default=1234, help='random seed')
    parser.add_argument('--imgno', type=int, default=3, help='image index to certify')
    args = parser.parse_args()

    # NOTE(review): --dataset only names the results directory; the data and
    # model loaded further down are always the CIFAR10 ones.
    dataset = args.dataset
    seed = args.seed
    imgno = args.imgno

    # Honor --seed.  Previously the flag was parsed but never applied, so
    # every run used the import-time seeding (seed 0) regardless of the
    # command line.  Pass --seed 0 to reproduce runs made before this change.
    np.random.seed(seed)
    torch.manual_seed(seed)

    print(dataset)
    # One timestamped results directory per invocation.
    run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = f'./results/{dataset}_{run_id}'
    # exist_ok avoids the check-then-create race of an explicit exists() test.
    os.makedirs(results_dir, exist_ok=True)
    print(results_dir)


    # Both splits only convert images to tensors.  Augmentation (RandomCrop /
    # RandomHorizontalFlip) is deliberately left out of the training pipeline:
    # the padded random crop introduces some black borders.
    transforms = {
        "train": Compose([ToTensor()]),
        "test": ToTensor(),
    }

    # From here on ``dataset`` is a dict of the two CIFAR10 splits (it no
    # longer holds the dataset name parsed from the command line).
    dataset = {
        split: CIFAR10(
            root="./data/cifar10",
            train=(split == "train"),
            download=False,
            transform=transforms[split],
        )
        for split in ["train", "test"]
    }

    # Restore the trained VGG weights onto the CPU.
    MODEL_PATH = './doing-basic-things/saved_models_pytorch/cifar10/vgg_cifar10_trained.pth'
    model = VGG_CIFAR()
    state_dict = torch.load(MODEL_PATH, map_location='cpu')
    model.load_state_dict(state_dict)

    model_on_gpu = next(model.parameters()).is_cuda
    print(f"model is in gpu? {model_on_gpu}")

    nlabels = 10
    datatype = np.float32
    nsamples = 1000

    # Model-bound prediction helpers for batches and for single tensors.
    predict_proba = partial(_predict_proba, model=model)
    single_predict_proba = partial(_single_predict_proba, model=model)
    # here image is a numpy batch array
    def batch_predict(images):
        """Score a batch of images given as a NumPy (or array-like) batch.

        Args:
            images: array whose elements can be reshaped to
                ``(-1, 32, 32, 3)`` (e.g. rows of flattened images).

        Returns:
            NumPy array of class probabilities, one row per image, as
            produced by ``predict_proba``.
        """
        # Convert the whole batch in one vectorised step instead of building
        # one tensor per image and stacking them.  ``astype`` always copies,
        # so the caller's array is never shared with the tensor.
        as_float = np.asarray(images).reshape(-1, 32, 32, 3).astype(np.float32)
        batch = torch.from_numpy(as_float)
        # The model is fed channel-first input: (N, 32, 32, 3) -> (N, 3, 32, 32).
        batch = batch.permute(0, 3, 1, 2)
        return predict_proba(batch).numpy()


    segmenter = SegmentationAlgorithm('quickshift', kernel_size=1, max_dist=200, ratio=0.2)

    # Fetch the (image, label) pair once; image is a tensor of shape (3, 32, 32).
    inputimg, inputimglabel = dataset["test"][imgno]
    img = inputimg.permute(1, 2, 0).numpy()  # img must be (32, 32, 3) numpy array, otherwise segmenter won't work
    nrows = img.shape[0]
    print(f"certifying shap for idx={imgno} with ground truth label={inputimglabel} ({dataset['train'].classes[inputimglabel]})")

    # Per-feature mean/std of the training images, flattened in the same
    # (H, W, C) order as ``img.flatten()``.
    train_pixels = dataset["train"].data
    train_flat = train_pixels.reshape(len(train_pixels), -1)
    feature_names = [f"x_{i}" for i in range(train_flat.shape[1])]
    # torchvision keeps CIFAR10 ``.data`` as raw uint8 pixels (0..255) while
    # ToTensor rescales images to [0, 1].  Bring the statistics onto the same
    # [0, 1] scale as ``img`` so the baseline matches what the model is fed.
    EX = pd.Series(train_flat.mean(0) / 255.0, index=feature_names)
    npEX = np.array(EX)
    StdX = np.array(train_flat.std(0) / 255.0)

    def get_SHAP_classifier(label_x0, phi, phi0, x0, EX):
        """Build a linear surrogate model from SHAP attributions.

        The returned model predicts ``phi0[label_x0] + coef . z`` where
        ``coef_i = phi[label_x0][i] / (x0_i - EX_i)``; callers are expected to
        pass centred inputs ``z = x - EX`` to ``predict``.

        Args:
            label_x0: index selecting the entry of ``phi`` and ``phi0`` to use.
            phi: indexable of per-feature attributions; ``phi[label_x0]`` must
                broadcast against ``x0 - EX``.
            phi0: indexable of base values; ``phi0[label_x0]`` is the intercept.
            x0: the explained input as a flat array.
            EX: baseline feature values (array-like, e.g. a pandas Series).

        Returns:
            A ``Ridge`` instance whose ``coef_`` and ``intercept_`` are set by
            hand (it is never fitted).
        """
        # Import the submodule explicitly: a bare ``import sklearn`` does not
        # guarantee that ``sklearn.linear_model`` has been loaded.
        from sklearn.linear_model import Ridge

        attributions = np.asarray(phi[label_x0], dtype=float)
        delta = np.asarray(x0, dtype=float) - np.asarray(EX, dtype=float)
        # With ``where=`` but no ``out=``, masked-out entries would be left as
        # uninitialised memory.  Start from zeros so features with
        # x0 == EX get a well-defined coefficient of 0.
        coef = np.divide(
            attributions,
            delta,
            out=np.zeros(np.broadcast(attributions, delta).shape, dtype=float),
            where=delta != 0,
        )
        g = Ridge(alpha=1.0, fit_intercept=True)
        g.coef_ = coef
        g.intercept_ = phi0[label_x0]
        return g


    # Explain the selected image with Kernel SHAP, then turn the attributions
    # for its ground-truth label into a linear surrogate model.
    x0 = img.flatten()  # flat copy of the (32, 32, 3) image
    # NOTE(review): ``nsamples=500`` is given to the constructor, not to
    # ``shap_values``; the constructor may not consume it, in which case the
    # library's default sampling budget is used -- confirm against the
    # installed SHAP version.
    SHAPEXPL = shap.KernelExplainer(batch_predict, EX, nsamples=500)
    shap_phi = SHAPEXPL.shap_values(x0, l1_reg="num_features(1000)")
    shap_phi0 = SHAPEXPL.expected_value
    # NOTE(review): get_SHAP_classifier indexes ``shap_phi`` and ``shap_phi0``
    # by label, which presumes a per-class leading index in what SHAP
    # returns -- verify for the installed SHAP version.
    func = get_SHAP_classifier(inputimglabel, shap_phi, shap_phi0, x0, EX)


    def _e(x, expl_func):
        """Evaluate the explanation model at a single point.

        x: single 1d numpy array of shape (d, )
        expl_func: a callable/sklearn model with predict method

        The input is centred on ``EX`` and wrapped in a one-element batch;
        the first (only) prediction is returned.
        """
        centred_batch = [x - EX]
        predictions = expl_func.predict(centred_batch)
        return predictions[0]


    # x is a single image
    def _bb(x, label_x0):
        """Return the black-box probability of class ``label_x0`` at ``x``.

        ``x`` is a flattened image; it is reshaped to a (-1, 32, 32, 3) numpy
        batch, cast to ``datatype`` and scored with ``batch_predict``.  The
        probability of ``label_x0`` for the first image is returned.
        """
        as_batch = x.reshape(-1, 32, 32, 3).astype(datatype)
        return batch_predict(as_batch)[0][label_x0]


    # Bind the explained label and the surrogate model so that both the
    # black box and the explanation are functions of x alone.
    bb = partial(_bb, label_x0=inputimglabel)
    e = partial(_e, expl_func=func)


    def f(x):  # Fidelity function
        """Return the fidelity of the explanation at ``x``: 1 - |bb(x) - e(x)|.

        A normalised variant, 1 - |bb - e| / max(|bb|, |e|), was considered
        but is not used.
        """
        # The black box is evaluated first, then the explanation.
        return 1 - np.abs(bb(x) - e(x))


    datatype = img.dtype

    result = {}
    QUERY_BUDGET = [10, 100, 1000, 10000]
    STRATEGIES = [1, 2, 3, 4]
    CHOICE = "min"

    # Certification settings shared by every (strategy, budget) pair.
    theta = 0.75  # fidelity threshold
    Z = 10  # number of hypercubes to certify
    sigma = 0.001

    # - certification code starts here
    # certresult[strategy][budget] -> dict of per-run widths and summary stats
    certresult = {}

    for s in STRATEGIES:
        print(f"strategy: {s}")
        # Strategy 4 gets 2 repetitions, every other strategy 10.
        NUMRUNS = 2 if s == 4 else 10
        certresult[s] = {}
        for Q in QUERY_BUDGET:

            # Strategy 4 is not run with budgets below 100.
            if s == 4 and Q < 100:
                continue

            print(f"\tcertifying with s={s}, Q={Q} with NUMRUNS={NUMRUNS}")

            x = img.flatten()
            d = len(x)
            certicubeperrun = np.zeros(NUMRUNS)

            t_0 = timeit.default_timer()
            for irun in range(NUMRUNS):
                print(f"irun: {irun}")
                ub = 1  # initial hypercube half-width
                lb = 0  # since x is the center of the hypercube
                try:
                    Certicube = Ecertify(x, theta, Z, Q, lb, ub, sigma, s, f, choice=CHOICE, eps_mul=0.01)
                except Exception as exc:
                    # The exception must NOT be bound to ``e``: that name is
                    # the module-level explainer wrapper used by ``f``, and
                    # Python deletes an ``except ... as`` target when the
                    # handler exits, which would make every later f(x) fail
                    # with NameError.
                    print(exc)
                    # Best effort: a failed run is recorded as -1 and
                    # therefore enters the mean / standard error below.
                    Certicube = -1
                certicubeperrun[irun] = Certicube
            t_1 = timeit.default_timer()

            time_per_run = round((t_1 - t_0) / NUMRUNS, 3)
            w_mean = np.mean(certicubeperrun)
            w_error = np.std(certicubeperrun) / np.sqrt(NUMRUNS)  # standard error of the mean
            print(str(w_mean) + ' +- ' + str(w_error))
            print(f"found: {w_mean:.4f} +- {w_error:.6f}, 1/d: {1 / d:.4f}")
            print(f"\nTime per run: {time_per_run} s")

            certresult[s][Q] = {
                "w": certicubeperrun,
                "time-per-run": time_per_run,
                "num-runs": NUMRUNS,
                "w-mean": w_mean,
                "w-error": w_error,
                "choice": CHOICE,
            }

    result[imgno] = certresult

    # Persist all results for this image as a pickle in the run directory.
    fname = f"result_{imgno}_cifar10_shap"
    res_object_name = os.path.join(results_dir, fname)
    with open(res_object_name, 'wb') as output:
        pickle.dump(result, output, pickle.HIGHEST_PROTOCOL)

