import os
import timeit
import pickle
import argparse
from datetime import datetime
from tqdm.auto import tqdm

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, FashionMNIST
from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip, ToTensor

import numpy as np; np.set_printoptions(suppress=True)
import matplotlib.pyplot as plt

np.random.seed(0)
torch.manual_seed(0)
torch.set_printoptions(sci_mode=False)

# LIME image-specific utilities
from lime import lime_image
from skimage.segmentation import mark_boundaries
from skimage.color import gray2rgb, rgb2gray, label2rgb
from lime.wrappers.scikit_image import SegmentationAlgorithm
from numpy import linalg as la

from vgg_model import VGG_CIFAR, VGG_FMNIST
from functools import partial

from ExpCertifyBB import Ecertify


def tensor_imshow(image_tensor):
    """Show a 3-D image tensor in a matplotlib window.

    The tensor's axes are reordered with ``permute(1, 2, 0)`` (axis 0 moves to
    the end, e.g. ``(C, H, W) -> (H, W, C)``) before it is converted to a
    NumPy array and handed to ``plt.imshow``.
    """
    axis0_last = image_tensor.permute(1, 2, 0)
    plt.imshow(axis0_last.numpy())
    plt.show()


def tensor_to_image(image_tensor):
    """Return ``image_tensor`` as a NumPy array with axis 0 moved to the end.

    Equivalent to ``permute(1, 2, 0)`` followed by ``.numpy()``, e.g. a
    ``(C, H, W)`` tensor becomes an ``(H, W, C)`` array.
    """
    axis0_last = image_tensor.permute(1, 2, 0)
    return axis0_last.numpy()


def image_to_tensor(image):
    """Wrap a 3-D NumPy image as a tensor with its last axis moved to the front.

    Equivalent to ``torch.from_numpy(image)`` followed by ``permute(2, 0, 1)``,
    e.g. an ``(H, W, C)`` array becomes a ``(C, H, W)`` tensor. The tensor
    shares memory with ``image``.
    """
    as_tensor = torch.from_numpy(image)
    return as_tensor.permute(2, 0, 1)


def _predict_proba(image_batch, model):
    """Return softmax probabilities of ``model`` for a batch of inputs.

    Args:
        image_batch: Tensor passed unchanged to ``model``.
        model: Callable ``torch.nn.Module`` whose output holds the class
            logits along dimension 1.

    Returns:
        Tensor with ``softmax(logits, dim=1)``, detached from autograd.

    Note:
        The model is switched to evaluation mode and left in that mode.
    """
    model.eval()
    # Use the public ``torch.no_grad`` entry point; the previous
    # ``torch.torch.no_grad`` only resolved through an accidental
    # self-referential attribute of the package.
    with torch.no_grad():
        logits = model(image_batch)
        probs = F.softmax(logits, dim=1)
    return probs.detach()


def _single_predict_proba(image_tensor, model):
    """Return the probability vector of ``model`` for one un-batched input.

    A leading batch axis of size one is added, the batch is scored with
    ``_predict_proba``, and the first (only) row of the result is returned.
    """
    single_item_batch = image_tensor.unsqueeze(dim=0)
    batch_probs = _predict_proba(single_item_batch, model)
    return batch_probs[0]



if __name__ == '__main__':
    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default="cifar10", help='dataset name')
    parser.add_argument('--seed', type=int, default=1234, help='random seed')
    parser.add_argument('--imgno', type=int, default=3, help='image index to certify')
    args = parser.parse_args()

    # NOTE(review): in the code visible here `--dataset` is only used to name
    # the results directory (the loader further down always reads CIFAR10) and
    # `--seed` is never applied to any RNG -- the seeds at the top of the file
    # are fixed to 0. Confirm whether that is intended.
    dataset = args.dataset
    seed = args.seed
    imgno = args.imgno

    print(dataset)
    # One timestamped output directory per invocation.
    run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = f'./results/{dataset}_{run_id}'
    # exist_ok avoids the check-then-create race of `exists()` + `makedirs()`
    # when several jobs start within the same second.
    os.makedirs(results_dir, exist_ok=True)
    print(results_dir)


    # Per-split preprocessing: both splits are only converted to tensors.
    # The training augmentations are disabled; the padded crop introduces
    # some black borders.
    train_transform = Compose([
        # RandomCrop(32, padding=4),
        # RandomHorizontalFlip(),
        ToTensor(),
    ])
    transforms = {"train": train_transform, "test": ToTensor()}

    # From here on `dataset` maps split name -> CIFAR10 dataset object
    # (it replaces the dataset-name string parsed from the command line).
    dataset = {}
    for split in ("train", "test"):
        is_train_split = split == "train"
        dataset[split] = CIFAR10(
            root="./data/cifar10",
            train=is_train_split,
            download=False,  # the files must already be present under `root`
            transform=transforms[split],
        )

    # Constants shared by the explanation code below.
    nlabels = 10
    datatype = np.float32
    nsamples = 1000

    # Restore the trained VGG weights, mapping every tensor to the CPU.
    MODEL_PATH = './doing-basic-things/saved_models_pytorch/cifar10/vgg_cifar10_trained.pth'
    model = VGG_CIFAR()
    trained_weights = torch.load(MODEL_PATH, map_location='cpu')
    model.load_state_dict(trained_weights)

    first_parameter = next(model.parameters())
    print(f"model is in gpu? {first_parameter.is_cuda}")

    # Bind the model once so the scoring helpers only take the input.
    predict_proba = partial(_predict_proba, model=model)
    single_predict_proba = partial(_single_predict_proba, model=model)

    def batch_predict(images):
        """Classification function handed to LIME.

        ``images`` is an iterable of NumPy arrays (e.g. a batch shaped
        ``(10, 32, 32, 3)``). The arrays are stacked along a new axis 0,
        axis 3 is moved to position 1, and the model probabilities are
        returned as a NumPy array.
        """
        per_image_tensors = [torch.from_numpy(single) for single in images]
        stacked = torch.stack(per_image_tensors, dim=0)
        axis3_second = stacked.permute(0, 3, 1, 2)
        probabilities = predict_proba(axis3_second)
        return probabilities.numpy()


    # Segmentation used by LIME to split the image into interpretable regions.
    segmenter = SegmentationAlgorithm('quickshift', kernel_size=1, max_dist=200, ratio=0.2)

    # Reference sample: image tensor of shape (3, 32, 32) and its label.
    inputimg, inputimglabel = dataset["test"][imgno]
    # img must be a (32, 32, 3) numpy array, otherwise the segmenter won't work
    img = tensor_to_image(inputimg)
    nrows = img.shape[0]
    print(f"certifying lime for idx={imgno} with ground truth label={inputimglabel} ({dataset['train'].classes[inputimglabel]})")

    # Explain the reference image for every label.
    explainer = lime_image.LimeImageExplainer(verbose=False)
    explanation = explainer.explain_instance(
        img,
        batch_predict,  # classification function
        top_labels=nlabels,
        hide_color=0,
        num_samples=nsamples,
        segmentation_fn=segmenter,
    )

    num_segments = len(np.unique(explanation.segments))
    toplabelinp = explanation.top_labels[0]

    # Containers for the per-pixel surrogate coefficients and intercepts,
    # filled in by the loop that follows. The surrogate prediction is later
    # compared with the probability the model outputs.
    coeff = np.zeros((nlabels, nrows, nrows))
    intercepts = np.zeros(nlabels)
    seg_size = np.zeros(num_segments)

    # Spread every LIME segment weight evenly over the pixels of its segment:
    # coeff[l][i][j] = weight(label l, segment of pixel (i, j)) / #pixels in
    # that segment, so summing coeff[l] over a segment gives back its weight.
    # Only the top-left nrows x nrows window is used, as before.
    segment_map = np.asarray(explanation.segments)[:nrows, :nrows]
    # Pixel count per segment id, computed once instead of per label.
    seg_size = np.bincount(segment_map.ravel(), minlength=len(np.unique(explanation.segments)))

    for l in range(nlabels):
        intercepts[l] = explanation.intercept[l]

        # One lookup table per label replaces the linear search that was run
        # for every pixel. `setdefault` keeps the first entry for an id, which
        # matches the `break` on first match in the search it replaces.
        weight_by_segment = {}
        for segment_id, segment_weight in explanation.local_exp[l]:
            weight_by_segment.setdefault(int(segment_id), segment_weight)

        # Segments missing from local_exp get coefficient 0 (they used to end
        # up as 0/0 = NaN); max(count, 1) only protects ids without pixels,
        # which are never indexed through segment_map.
        per_pixel_weight = np.array(
            [weight_by_segment.get(sid, 0.0) / max(count, 1) for sid, count in enumerate(seg_size)]
        )
        coeff[l] = per_pixel_weight[segment_map]

    # Module-level buffers; `_e` allocates its own local arrays of the same
    # names, so these are not what the surrogate prediction writes to.
    masktr = np.zeros((nlabels, nrows, nrows))
    pred_exp = np.zeros((1, nlabels))

    # The surrogate prediction (mask of the test input combined with the
    # coefficients obtained from the explained input) is implemented in `_e`.

    def _bb(x, label_x0):
        """Black-box score: model probability of class ``label_x0``.

        ``x`` is a flattened image; it is reshaped to a batch of
        ``(32, 32, 3)`` arrays, cast to ``datatype`` and scored with
        ``batch_predict``. The probability of the first item is returned.
        """
        image_batch = x.reshape(-1, 32, 32, 3).astype(datatype)
        class_probs = batch_predict(image_batch)
        first_item_probs = class_probs[0]
        return first_item_probs[label_x0]


    def _e(x, label_x0):
        """Surrogate (explanation) score of class ``label_x0`` at ``x``.

        ``x`` is a flattened image. A fresh LIME explanation is computed for
        it (50 samples), its mask for ``label_x0`` is taken from
        ``get_image_and_mask``, and the linear surrogate fitted on the
        reference image is evaluated on that mask:

            dot(coeff[label_x0], |mask|) + intercepts[label_x0]

        Only the requested label is evaluated; the thresholded ``masktr``
        array and the predictions for the other labels were computed on every
        call but never read, so they are no longer built.
        """
        explanationtst = explainer.explain_instance(
            x.reshape(32, 32, 3).astype(datatype),
            batch_predict,  # classification function
            top_labels=nlabels,
            hide_color=0,
            num_samples=50,
            # num_samples=nsamples // 10,
            segmentation_fn=segmenter,
        )

        # The returned image is not needed, only the segment mask.
        _, masktst = explanationtst.get_image_and_mask(label_x0, positive_only=False, hide_rest=False)

        # Predict using mask of test input with coeffs obtained from the input
        # the reference explanation was created for.
        flat_coeff = np.reshape(coeff[label_x0], nrows ** 2)
        flat_mask = np.reshape(abs(masktst), nrows ** 2)
        return np.dot(flat_coeff, flat_mask) + intercepts[label_x0]


    def f(x):  # Fidelity function
        """Fidelity of the explanation at ``x``: ``1 - |bb(x) - e(x)|`` (1 - MAE)."""
        # The black box is queried first, then the explanation.
        # (A normalized-MAE variant dividing by max(|bb(x)|, |e(x)|) is not used.)
        blackbox_score = bb(x)
        surrogate_score = e(x)
        absolute_gap = np.abs(blackbox_score - surrogate_score)
        return 1 - absolute_gap


    # Use the dtype of the explained image for every reshaped query.
    datatype = img.dtype

    # Fix the label so the certifier only sees functions of x.
    bb = partial(_bb, label_x0=inputimglabel)
    e = partial(_e, label_x0=inputimglabel)

    result = {}
    # QUERY_BUDGET = 10 ** (np.arange(4) + 1)
    QUERY_BUDGET = [10, 100, 1000, 10000]
    STRATEGIES = [1, 2, 3, 4]
    CHOICE = "min"

    # - certification code starts here
    # certresult[strategy][query budget] -> dict with the per-run widths,
    # their mean / standard error, timing and run count.
    certresult = {}

    for s in STRATEGIES:
        print(f"strategy: {s}")
        certresult[s] = {}
        for Q in QUERY_BUDGET:
            # Number of repetitions: 2 for strategy 4, otherwise 5 for the
            # largest budget and 10 for the others.
            if s == 4:
                NUMRUNS = 2
            elif Q == 10000:
                NUMRUNS = 5
            else:
                NUMRUNS = 10

            # Strategy 4 is not run with budgets below 100.
            if s == 4 and Q < 100:
                continue

            print(f"\tcertifying with s={s}, Q={Q} with NUMRUNS={NUMRUNS}")

            certresult[s][Q] = {"w": [], "time-per-run": None, "num-runs": NUMRUNS}

            x = img.flatten()
            d = len(x)

            theta = 0.75  # fidelity threshold
            Z = 10  # number of hypercubes to certify
            eps = 0.01 / d  # min gap between lb and ub (not passed to Ecertify below)
            sigma = 0.001
            certicubeperrun = np.zeros(NUMRUNS)

            t_0 = timeit.default_timer()
            for irun in range(NUMRUNS):
                print(f"irun: {irun}")
                ub = 1  # initial hypercube half-width
                lb = 0  # since x is the center of the hypercube
                try:
                    Certicube = Ecertify(x, theta, Z, Q, lb, ub, sigma, s, f, choice=CHOICE, eps_mul=0.01)
                    print(Certicube)
                except Exception as exc:
                    # The exception must NOT be bound to `e`: Python deletes the
                    # `as` target when the handler exits, which would remove the
                    # global surrogate function `e` and make every later f(x)
                    # fail with NameError.
                    print(exc)
                    # Best effort: a failed run is recorded with the sentinel -1
                    # (and therefore enters the mean below).
                    Certicube = -1
                certicubeperrun[irun] = Certicube
            t_1 = timeit.default_timer()
            time_per_run = round((t_1 - t_0) / NUMRUNS, 3)

            # Mean width and its standard error over the runs.
            w_mean = np.mean(certicubeperrun)
            w_error = np.std(certicubeperrun) / np.sqrt(NUMRUNS)
            print(str(w_mean) + ' +- ' + str(w_error))
            print(f"found: {w_mean:.4f} +- {w_error:.6f}, 1/d: {1 / d:.4f}")
            print(f"\nTime per run: {time_per_run} s")

            certresult[s][Q]["w"] = certicubeperrun
            certresult[s][Q]["time-per-run"] = time_per_run
            certresult[s][Q]["num-runs"] = NUMRUNS
            certresult[s][Q]["w-mean"] = w_mean
            certresult[s][Q]["w-error"] = w_error
            certresult[s][Q]["choice"] = CHOICE

    result[imgno] = certresult

    # Persist {imgno: certresult} in the run's results directory.
    fname = f"result_{imgno}_cifar10"
    res_object_name = os.path.join(results_dir, fname)
    with open(res_object_name, 'wb') as output:
        pickle.dump(result, output, pickle.HIGHEST_PROTOCOL)



















