import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import os
import numpy as np
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.optim.lr_scheduler import ReduceLROnPlateau
import PIL.Image as Image
import matplotlib.pyplot as plt
from tqdm import tqdm

import torchvision
import torch
import numpy as np
import matplotlib.pyplot as plt
# from captum.attr import IntegratedGradients

import random

from scipy.sparse import lil_matrix, csc_matrix
from scipy.sparse.linalg import spsolve
import math

from resnet import resnet18,resnet50,ResNet9


# device = 'cuda:5' if torch.cuda.is_available() else 'cpu'

# Randomisation levels; the values match the `random<level>` suffixes of the
# explanation file names declared further down in this file.
# NOTE(review): not referenced by any code visible in this file - confirm usage.
random_list = [0,0.2,0.4,0.6,0.8,1.0]
# random_list = [1.0,0.8,0.6,0.4,0.2,0.0]

# Sparsity levels iterated by every MoRF/LeRF metric below; each value is
# multiplied by `glob_interval` to obtain a pixel count.
sparsity = [ 5, 10, 15, 20, 25, 30, 35, 40, 45, 50,
            55, 60, 65, 70, 75, 80, 85, 90, 95, 100]


# ((row offset, column offset), weight) for the 8-neighbourhood of a pixel:
# 1/6 for the four edge neighbours and 1/12 for the four diagonal ones.
# Used as the default weighting of NoisyLinearImputer.
neighbors_weights = [((1, 1), 1 / 12), ((0, 1), 1 / 6), ((-1, 1), 1 / 12), ((1, -1), 1 / 12), ((0, -1), 1 / 6),
                     ((-1, -1), 1 / 12), ((1, 0), 1 / 6), ((-1, 0), 1 / 6)]

# NOTE(review): the GPU index is hard-coded; on a machine with fewer than three
# GPUs `cuda:2` does not exist even though CUDA is available.
device = 'cuda:2' if torch.cuda.is_available() else 'cpu'

# Test-time preprocessing: tensor conversion followed by per-channel normalisation.
# NOTE(review): these look like the usual ImageNet statistics - confirm they
# match what the checkpoints were trained with.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# CIFAR-100 test split; `download=True` means importing this module may hit the network.
test_dataset = torchvision.datasets.CIFAR100(root='../../data', train=False, download=True, transform=transform_test)
# One image per batch, in dataset order, so the n-th batch lines up with the
# n-th row of a loaded explanation array.
val_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, shuffle=False)
# 1/100 of the pixel count of a 32x32 image (10.24 pixels).
glob_interval = 32*32 / 100



def init_dl_program(
        device_name,
        seed=None,
        use_cudnn=True,
        deterministic=False,
        benchmark=True,
        use_tf32=False,
        max_threads=None
):
    import torch
    if max_threads is not None:
        torch.set_num_threads(max_threads)  # intraop
        if torch.get_num_interop_threads() != max_threads:
            torch.set_num_interop_threads(max_threads)  # interop
        try:
            import mkl
        except:
            pass
        else:
            mkl.set_num_threads(max_threads)

    if seed is not None:
        random.seed(seed)
        print("SEED ", seed)
        seed += 1
        np.random.seed(seed)
        print("SEED ", seed)
        seed += 1
        torch.manual_seed(seed)
        print("SEED ", seed)

    if isinstance(device_name, (str, int)):
        device_name = [device_name]

    devices = []
    for t in reversed(device_name):
        t_device = torch.device(t)
        devices.append(t_device)
        if t_device.type == 'cuda':
            assert torch.cuda.is_available()
            torch.cuda.set_device(t_device)
            if seed is not None:
                seed += 1
                torch.cuda.manual_seed(seed)
                print("SEED ", seed)

    devices.reverse()
    torch.backends.cudnn.enabled = use_cudnn
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = benchmark

    if hasattr(torch.backends.cudnn, 'allow_tf32'):
        torch.backends.cudnn.allow_tf32 = use_tf32
        torch.backends.cuda.matmul.allow_tf32 = use_tf32

    return devices if len(devices) > 1 else devices[0]


class NoisyLinearImputer():
    """Fill masked pixels by solving a sparse linear system over their neighbours.

    For every missing pixel one equation ties its value to the weighted values
    of its in-image neighbours: known neighbours contribute to the right-hand
    side, missing neighbours are unknowns of the system.  Gaussian noise scaled
    by ``noise`` is added to the solved values.
    """

    def __init__(self, noise=0.01, weighting=neighbors_weights):
        """
            Noisy linear imputation.
            noise: magnitude of noise to add (absolute, set to 0 for no noise)
            weighting: Weights of the neighboring pixels in the computation.
            List of tuples of (offset, weight)
        """
        self.noise = noise
        self.weighting = weighting

    @staticmethod
    def add_offset_to_indices(indices, offset, mask_shape):
        """Shift flat pixel indices by a 2-D offset.

        indices: flat indices into a mask of shape ``mask_shape``.
        offset: (row offset, column offset).
        Returns a boolean vector that is True where the shifted pixel still
        lies inside the image, and the shifted flat indices (only meaningful
        where that vector is True).
        """
        rows = indices // mask_shape[1] + offset[0]
        cols = indices % mask_shape[1] + offset[1]
        in_image = (rows >= 0) & (cols >= 0) & (rows < mask_shape[0]) & (cols < mask_shape[1])
        return in_image, indices + offset[0] * mask_shape[1] + offset[1]

    @staticmethod
    def setup_sparse_system(mask, img, neighbors_weights):
        """ Vectorized version to set up the equation system.
            mask: (H, W)-array, 0 where a pixel is missing.
            img: (C, H, W)-array of all values.
            neighbors_weights: list of (offset, weight) tuples.
            Return (N,N)-System matrix, (N,C)-Right hand side for each of the C channels,
            where N is the number of missing pixels.
        """
        maskflt = mask.flatten()
        imgflat = img.reshape((img.shape[0], -1))
        indices = np.argwhere(maskflt == 0).flatten()  # Flat indices of the pixels to impute
        # Lookup table: flat pixel index -> index of its unknown in the system.
        coords_to_vidx = np.zeros(len(maskflt), dtype=int)
        coords_to_vidx[indices] = np.arange(len(indices))
        numEquations = len(indices)
        A = lil_matrix((numEquations, numEquations))  # System matrix
        b = np.zeros((numEquations, img.shape[0]))
        # Diagonal weight per equation.  It starts at 1 and loses the weight of
        # every neighbour that falls outside the image, i.e. the weights are
        # assumed to sum to 1 (true for the default 8-neighbourhood).
        sum_neighbors = np.ones(numEquations)
        for offset, weight in neighbors_weights:
            in_image, new_coords = NoisyLinearImputer.add_offset_to_indices(indices, offset, mask.shape)

            valid_coords = new_coords[in_image]  # flat index of each in-image neighbour
            valid_ids = np.flatnonzero(in_image)  # equation each of those neighbours belongs to

            # Known neighbours move to the right-hand side.
            known = maskflt[valid_coords] > 0.5
            b[valid_ids[known], :] -= weight * imgflat[:, valid_coords[known]].T

            # Missing neighbours are unknowns: add their weight to the system matrix.
            unknown = maskflt[valid_coords] < 0.5
            A[valid_ids[unknown], coords_to_vidx[valid_coords[unknown]]] = weight

            # Neighbours outside the image reduce the diagonal weight.
            sum_neighbors[~in_image] -= weight

        diagonal = np.arange(numEquations)
        A[diagonal, diagonal] = -sum_neighbors
        return A, b

    def __call__(self, img: torch.Tensor, mask: torch.Tensor):
        """Linearly impute the pixels of ``img`` where ``mask`` is 0.

        img: original image, (C,H,W) CPU tensor.
        mask: (H,W) CPU tensor, 0 at the pixels to impute.
        Returns a new tensor shaped like ``img``; the input is not modified.
        """
        imgflt = img.reshape(img.shape[0], -1)
        mask_np = mask.numpy()
        # Computed on the NumPy view: np.argwhere on a torch tensor only works
        # through NumPy's fallback conversion.
        indices_linear = np.flatnonzero(mask_np.reshape(-1) == 0)  # Indices that need to be imputed.
        if indices_linear.size == 0:
            # Nothing is masked, so there is no system to solve.
            return imgflt.clone().reshape_as(img)

        # Set up sparse equation system, solve system.  The instance weighting
        # is used so that a custom `weighting` passed to __init__ takes effect.
        # NOTE(review): when every pixel is missing the system has no known
        # values to anchor it - confirm how callers want that case handled.
        A, b = NoisyLinearImputer.setup_sparse_system(mask_np, img.numpy(), self.weighting)
        solution = spsolve(csc_matrix(A), b)
        # spsolve flattens single-column systems; restore the (N, C) layout.
        res = torch.tensor(solution, dtype=torch.float).reshape(len(indices_linear), -1)

        # Fill the values with the solution of the system.
        img_infill = imgflt.clone()
        fill_positions = torch.as_tensor(indices_linear, dtype=torch.long)
        img_infill[:, fill_positions] = res.t() + self.noise * torch.randn_like(res.t())

        return img_infill.reshape_as(img)

    def batched_call(self, img: torch.Tensor, mask: torch.Tensor):
        """ Pseudo implementation of batched interface: imputes sample by sample on the CPU. """
        in_device = img.device
        res_list = [self(single_img.cpu(), mask[i].cpu()) for i, single_img in enumerate(img)]
        return torch.stack(res_list).to(in_device)

# Shared imputer instance with the default noise level (0.01) and the default
# neighbour weighting; used by fill_image_imputer.
noise_imputer = NoisyLinearImputer()

def fill_image_imputer(ori_img, mask, run_num=1):
    """Impute the pixels of ``ori_img`` that are flagged with 1 in ``mask``.

    The shared imputer expects 0 at the pixels to fill, so the mask is
    inverted before the call.  ``run_num`` is accepted but not used.
    """
    keep_mask = 1 - mask
    return noise_imputer(ori_img, keep_mask)


def road_MoRF(net, inputs, targets, scores=None):
    """ROAD metric, most-relevant-first: impute the top-scored pixels and re-classify.

    Args:
        net: classifier, called on batches placed on the module-level ``device``.
        inputs: (1, C, H, W) CPU tensor holding a single image.
        targets: single label (one-element tensor or number).
        scores: NumPy array with one relevance score per pixel (H*W values).

    Returns:
        List of 0/1 flags: first for the unmodified image, then one per entry
        of ``sparsity``, telling whether the prediction matched ``targets``.
    """
    target = float(targets)
    height, width = inputs.shape[-2], inputs.shape[-1]

    def is_correct(batch):
        prediction = net(batch).argmax().item()
        return 1 if abs(prediction - target) < 0.5 else 0

    # argsort is ascending, so ranking by (1 - score) puts the most relevant first.
    score_vector = (1 - scores).reshape(-1)
    score_indexes = np.argsort(score_vector)
    # Cap with the flattened size so that 2-D score maps are handled as well.
    num_pixels = score_vector.shape[0]

    # Pure inference: no autograd graph is needed for any forward pass here.
    with torch.no_grad():
        acc_lists = [is_correct(inputs.to(device))]
        for i in sparsity:
            # NOTE(review): `sparsity` is iterated by value and `glob_interval`
            # is 1/100 of the pixels, so `(i + 1)` removes one interval more
            # than the nominal level - confirm this is intended before
            # changing it, as it affects reported numbers.
            num_removed = min(int(glob_interval * (i + 1)), num_pixels)
            keep = np.ones_like(score_vector)
            keep[score_indexes[:num_removed]] = 0.0
            keep = torch.from_numpy(keep).view(1, 1, height, width)

            # fill_image_imputer expects 1 at the pixels to impute.
            filled = fill_image_imputer(inputs.clone()[0], 1 - keep[0, 0], run_num=1)
            acc_lists.append(is_correct(filled.float().to(device).unsqueeze(0)))
    return acc_lists


def road_LeRF(net, inputs, targets, scores=None):
    """ROAD metric, least-relevant-first: impute the lowest-scored pixels and re-classify.

    Args:
        net: classifier, called on batches placed on the module-level ``device``.
        inputs: (1, C, H, W) CPU tensor holding a single image.
        targets: single label (one-element tensor or number).
        scores: NumPy array with one relevance score per pixel (H*W values).

    Returns:
        List of 0/1 flags: first for the unmodified image, then one per entry
        of ``sparsity``, telling whether the prediction matched ``targets``.
    """
    target = float(targets)
    height, width = inputs.shape[-2], inputs.shape[-1]

    def is_correct(batch):
        prediction = net(batch).argmax().item()
        return 1 if abs(prediction - target) < 0.5 else 0

    # Ascending argsort: the least relevant pixels come first.
    score_vector = scores.reshape(-1)
    score_indexes = np.argsort(score_vector)
    # Cap with the flattened size so that 2-D score maps are handled as well.
    num_pixels = score_vector.shape[0]

    # Pure inference: no autograd graph is needed for any forward pass here.
    with torch.no_grad():
        acc_lists = [is_correct(inputs.to(device))]
        for i in sparsity:
            # NOTE(review): `sparsity` is iterated by value and `glob_interval`
            # is 1/100 of the pixels, so `(i + 1)` removes one interval more
            # than the nominal level - confirm this is intended before
            # changing it, as it affects reported numbers.
            num_removed = min(int(glob_interval * (i + 1)), num_pixels)
            keep = np.ones_like(score_vector)
            keep[score_indexes[:num_removed]] = 0.0
            keep = torch.from_numpy(keep).view(1, 1, height, width)

            # fill_image_imputer expects 1 at the pixels to impute.
            filled = fill_image_imputer(inputs.clone()[0], 1 - keep[0, 0], run_num=1)
            acc_lists.append(is_correct(filled.float().to(device).unsqueeze(0)))
    return acc_lists


def random_MoRF(net, inputs, targets, scores=None):
    """Most-relevant-first removal: zero out the top-scored pixels and re-classify.

    Args:
        net: classifier, called on batches placed on the module-level ``device``.
        inputs: (1, C, H, W) CPU tensor holding a single image.
        targets: single label (one-element tensor or number).
        scores: NumPy array with one relevance score per pixel (H*W values).

    Returns:
        List of 0/1 flags: first for the unmodified image, then one per entry
        of ``sparsity``, telling whether the prediction matched ``targets``.
    """
    target = float(targets)
    height, width = inputs.shape[-2], inputs.shape[-1]

    def is_correct(batch):
        prediction = net(batch).argmax().item()
        return 1 if abs(prediction - target) < 0.5 else 0

    # argsort is ascending, so ranking by (1 - score) puts the most relevant first.
    score_vector = (1 - scores).reshape(-1)
    score_indexes = np.argsort(score_vector)
    # Cap with the flattened size so that 2-D score maps are handled as well.
    num_pixels = score_vector.shape[0]

    # Pure inference: no autograd graph is needed for any forward pass here.
    with torch.no_grad():
        acc_lists = [is_correct(inputs.to(device))]
        for i in sparsity:
            # NOTE(review): `sparsity` is iterated by value and `glob_interval`
            # is 1/100 of the pixels, so `(i + 1)` removes one interval more
            # than the nominal level - confirm this is intended before
            # changing it, as it affects reported numbers.
            num_removed = min(int(glob_interval * (i + 1)), num_pixels)
            keep = np.ones_like(score_vector)
            keep[score_indexes[:num_removed]] = 0.0
            keep = torch.from_numpy(keep).view(1, 1, height, width)

            # Removed pixels are set to zero in the normalised input.
            acc_lists.append(is_correct((inputs * keep).float().to(device)))
    return acc_lists


def random_LeRF(net, inputs, targets, scores=None):
    """Least-relevant-first removal: zero out the lowest-scored pixels and re-classify.

    Args:
        net: classifier, called on batches placed on the module-level ``device``.
        inputs: (1, C, H, W) CPU tensor holding a single image.
        targets: single label (one-element tensor or number).
        scores: NumPy array with one relevance score per pixel (H*W values).

    Returns:
        List of 0/1 flags: first for the unmodified image, then one per entry
        of ``sparsity``, telling whether the prediction matched ``targets``.
    """
    target = float(targets)
    height, width = inputs.shape[-2], inputs.shape[-1]

    def is_correct(batch):
        prediction = net(batch).argmax().item()
        return 1 if abs(prediction - target) < 0.5 else 0

    # Ascending argsort: the least relevant pixels come first.
    score_vector = scores.reshape(-1)
    score_indexes = np.argsort(score_vector)
    # Cap with the flattened size so that 2-D score maps are handled as well.
    num_pixels = score_vector.shape[0]

    # Pure inference: no autograd graph is needed for any forward pass here.
    with torch.no_grad():
        acc_lists = [is_correct(inputs.to(device))]
        for i in sparsity:
            # NOTE(review): `sparsity` is iterated by value and `glob_interval`
            # is 1/100 of the pixels, so `(i + 1)` removes one interval more
            # than the nominal level - confirm this is intended before
            # changing it, as it affects reported numbers.
            num_removed = min(int(glob_interval * (i + 1)), num_pixels)
            keep = np.ones_like(score_vector)
            keep[score_indexes[:num_removed]] = 0.0
            keep = torch.from_numpy(keep).view(1, 1, height, width)

            # Removed pixels are set to zero in the normalised input.
            acc_lists.append(is_correct((inputs * keep).float().to(device)))
    return acc_lists


def ours_random_MoRF1(net, inputs, targets,
                     robust_ratio=0.1, num_samples=50, max_ratio=0.1, use_max=False,
                     scores=None):
    """Randomised most-relevant-first removal, averaged over sampled masks.

    For every sparsity level the top-scored pixels form the removal set.  In
    each of ``num_samples`` random draws a pixel of that set is zeroed with
    probability ``ratio`` and left untouched otherwise; the returned value is
    the fraction of draws that are still classified as ``targets``.

    Args:
        net: classifier, called on batches placed on the module-level ``device``.
        inputs: (1, C, H, W) tensor holding a single image.
        targets: single label (one-element tensor or number).
        robust_ratio: probability of zeroing a pixel of the removal set.
        num_samples: number of random masks per sparsity level.
        max_ratio: with ``use_max``, upper bound on the expected fraction of
            all pixels that is zeroed.
        use_max: whether to apply the ``max_ratio`` bound.
        scores: NumPy array with one relevance score per pixel (H*W values).

    Returns:
        List whose first entry is the 0/1 correctness on the unmodified image,
        followed by one accuracy in [0, 1] per entry of ``sparsity``.
    """
    target = float(targets)
    height, width = inputs.shape[-2], inputs.shape[-1]

    # argsort is ascending, so ranking by (1 - score) puts the most relevant first.
    score_vector = (1 - scores).reshape(-1)
    score_indexes = np.argsort(score_vector)
    # Cap with the flattened size so that 2-D score maps are handled as well.
    all_pixels = score_vector.shape[0]

    # Pure inference: no autograd graph is needed for any forward pass here.
    with torch.no_grad():
        prediction = net(inputs.to(device)).argmax().item()
        acc_lists = [1 if abs(prediction - target) < 0.5 else 0]

        # The replicated input is the same for every sparsity level.
        tiled_inputs = torch.tile(inputs, [num_samples, 1, 1, 1]).to(device)

        for i in sparsity:
            # NOTE(review): `sparsity` is iterated by value and `glob_interval`
            # is 1/100 of the pixels, so `(i + 1)` removes one interval more
            # than the nominal level - confirm this is intended before
            # changing it, as it affects reported numbers.
            expl_size = min(int(glob_interval * (i + 1)), all_pixels)
            keep = np.ones_like(score_vector)
            keep[score_indexes[:expl_size]] = 0.0
            keep = torch.from_numpy(keep).view(1, 1, height, width).to(device)

            ratio = robust_ratio
            if use_max and expl_size * robust_ratio > max_ratio * all_pixels:
                # Bound the expected number of zeroed pixels by max_ratio * all_pixels.
                ratio = max_ratio * all_pixels / expl_size

            draws = torch.rand([num_samples, keep.shape[-3], keep.shape[-2], keep.shape[-1]], device=device)
            # 1 where a pixel of the removal set is put back, 0 where it stays removed.
            restored = (draws > ratio).to(draws.dtype)
            sample_masks = keep + (1 - keep) * restored

            predictions = net((tiled_inputs * sample_masks).float()).argmax(dim=-1).cpu()
            hits = (torch.abs(predictions.float() - target) < 0.5).float()
            acc_lists.append(hits.mean().item())

    return acc_lists


def ours_random_LeRF1(net, inputs, targets,
                     robust_ratio=0.1, num_samples=50, max_ratio=0.1, use_max=False,
                     scores=None):
    """Randomised least-relevant-first removal, averaged over sampled masks.

    For every sparsity level the lowest-scored pixels form the removal set.  In
    each of ``num_samples`` random draws a pixel of that set is zeroed with
    probability ``ratio`` and left untouched otherwise; the returned value is
    the fraction of draws that are still classified as ``targets``.

    Args:
        net: classifier, called on batches placed on the module-level ``device``.
        inputs: (1, C, H, W) tensor holding a single image.
        targets: single label (one-element tensor or number).
        robust_ratio: probability of zeroing a pixel of the removal set.
        num_samples: number of random masks per sparsity level.
        max_ratio: with ``use_max``, upper bound on the expected fraction of
            all pixels that is zeroed.
        use_max: whether to apply the ``max_ratio`` bound.
        scores: NumPy array with one relevance score per pixel (H*W values).

    Returns:
        List whose first entry is the 0/1 correctness on the unmodified image,
        followed by one accuracy in [0, 1] per entry of ``sparsity``.
    """
    target = float(targets)
    height, width = inputs.shape[-2], inputs.shape[-1]

    # Ascending argsort: the least relevant pixels come first.
    score_vector = scores.reshape(-1)
    score_indexes = np.argsort(score_vector)
    # Cap with the flattened size so that 2-D score maps are handled as well.
    all_pixels = score_vector.shape[0]

    # Pure inference: no autograd graph is needed for any forward pass here.
    with torch.no_grad():
        prediction = net(inputs.to(device)).argmax().item()
        acc_lists = [1 if abs(prediction - target) < 0.5 else 0]

        # The replicated input is the same for every sparsity level.
        tiled_inputs = torch.tile(inputs, [num_samples, 1, 1, 1]).to(device)

        for i in sparsity:
            # NOTE(review): `sparsity` is iterated by value and `glob_interval`
            # is 1/100 of the pixels, so `(i + 1)` removes one interval more
            # than the nominal level - confirm this is intended before
            # changing it, as it affects reported numbers.
            expl_size = min(int(glob_interval * (i + 1)), all_pixels)
            keep = np.ones_like(score_vector)
            keep[score_indexes[:expl_size]] = 0.0
            keep = torch.from_numpy(keep).view(1, 1, height, width).to(device)

            ratio = robust_ratio
            if use_max and expl_size * robust_ratio > max_ratio * all_pixels:
                # Bound the expected number of zeroed pixels by max_ratio * all_pixels.
                ratio = max_ratio * all_pixels / expl_size

            draws = torch.rand([num_samples, keep.shape[-3], keep.shape[-2], keep.shape[-1]], device=device)
            # 1 where a pixel of the removal set is put back, 0 where it stays removed.
            restored = (draws > ratio).to(draws.dtype)
            sample_masks = keep + (1 - keep) * restored

            predictions = net((tiled_inputs * sample_masks).float()).argmax(dim=-1).cpu()
            hits = (torch.abs(predictions.float() - target) < 0.5).float()
            acc_lists.append(hits.mean().item())

    return acc_lists


# Directory holding the model checkpoints and the explanation arrays.
global_dir = '../../data/cifar100/IG/'
# global_dir1 = '../../data/cifar100/0923/'
# Explanation files for the original model, one template per randomisation
# level; the trailing `%d` is filled with the seed by the evaluation functions.
explanation_name_list= [global_dir+'resnet-9-explanations_random0.000000_%d.npy',
                        global_dir+'resnet-9-explanations_random0.200000_%d.npy',
                        global_dir + 'resnet-9-explanations_random0.400000_%d.npy',
                        global_dir + 'resnet-9-explanations_random0.600000_%d.npy',
                        global_dir + 'resnet-9-explanations_random0.800000_%d.npy',
                        global_dir + 'resnet-9-explanations_random1.000000_%d.npy']

# Same layout for the explanations of the fine-tuned model.
finetune_explanation_name_list= [global_dir+'resnet-9-finetune-explanations_random0.000000_%d.npy',
                        global_dir+'resnet-9-finetune-explanations_random0.200000_%d.npy',
                        global_dir + 'resnet-9-finetune-explanations_random0.400000_%d.npy',
                        global_dir + 'resnet-9-finetune-explanations_random0.600000_%d.npy',
                        global_dir + 'resnet-9-finetune-explanations_random0.800000_%d.npy',
                        global_dir + 'resnet-9-finetune-explanations_random1.000000_%d.npy']

def evaluate_ori_ori(seed):
    """Evaluate the original model with plain (zero-fill) MoRF / LeRF removal.

    For every explanation file of ``explanation_name_list`` that exists for
    ``seed`` and has no result file yet, runs ``random_LeRF`` and
    ``random_MoRF`` on each validation image and stores the results next to
    the explanation file with the suffix ``_ori_model_ori.npy``.

    The saved file is a 4-element object array
    ``[acc_LeRF, acc_MoRF, acc_lists_LeRF, acc_lists_MoRF]`` and therefore has
    to be read with ``np.load(..., allow_pickle=True)``.
    """
    net = ResNet9(3, 100)
    state_dict = torch.load(global_dir + 'resnet-9.pth',map_location='cpu')
    net.load_state_dict(state_dict['net'])
    net.to(device)
    net.eval()

    for name_template in explanation_name_list:
        path = name_template % seed
        if not os.path.exists(path):
            continue

        # Skip work that has already been done in an earlier run.
        save_path = path.replace('.npy', '_ori_model_ori.npy')
        if os.path.exists(save_path):
            continue

        explanations = np.load(path)
        acc_lists_LeRF = []
        acc_lists_MoRF = []
        # The loader yields one image at a time in dataset order, so the
        # enumeration index selects the matching explanation row.
        for sample_idx, (inputs, targets) in enumerate(tqdm(val_loader)):
            scores = explanations[sample_idx].copy().reshape(-1)

            acc_lists_LeRF.append(
                random_LeRF(net, inputs.clone(), targets.clone(), scores=scores))
            acc_lists_MoRF.append(
                random_MoRF(net, inputs.clone(), targets.clone(), scores=scores))

        acc_lists_LeRF = np.array(acc_lists_LeRF)
        acc_lists_MoRF = np.array(acc_lists_MoRF)
        acc_LeRF = acc_lists_LeRF.mean(axis=0)
        acc_MoRF = acc_lists_MoRF.mean(axis=0)
        print(acc_LeRF)
        print(acc_MoRF)

        # The four results have different shapes.  Passing them as a plain list
        # makes NumPy >= 1.24 raise ValueError (ragged array) after the whole
        # evaluation has run, so an object array is filled element by element.
        results = np.empty(4, dtype=object)
        for slot, value in enumerate((acc_LeRF, acc_MoRF, acc_lists_LeRF, acc_lists_MoRF)):
            results[slot] = value
        np.save(save_path, results)


def evaluation_ori_rfid(seed, robust_ratio_MoRF=0.5, robust_ratio_LeRF=0.5):
    """Evaluate the original model with the randomised MoRF / LeRF metrics.

    For every explanation file of ``explanation_name_list`` that exists for
    ``seed`` and has no result file yet, runs ``ours_random_LeRF1`` and
    ``ours_random_MoRF1`` (with ``use_max=False``) on each validation image and
    stores the results next to the explanation file with the suffix
    ``_ori_model_rfid.npy``.

    Args:
        seed: seed number substituted into the explanation file names.
        robust_ratio_MoRF: ``robust_ratio`` passed to ``ours_random_MoRF1``.
        robust_ratio_LeRF: ``robust_ratio`` passed to ``ours_random_LeRF1``.

    The saved file is a 4-element object array
    ``[acc_LeRF, acc_MoRF, acc_lists_LeRF, acc_lists_MoRF]`` and therefore has
    to be read with ``np.load(..., allow_pickle=True)``.
    """
    net = ResNet9(3, 100)
    state_dict = torch.load(global_dir + 'resnet-9.pth', map_location='cpu')
    net.load_state_dict(state_dict['net'])
    net.to(device)
    net.eval()

    for name_template in explanation_name_list:
        path = name_template % seed
        if not os.path.exists(path):
            continue

        # Skip work that has already been done in an earlier run.
        save_path = path.replace('.npy', '_ori_model_rfid.npy')
        if os.path.exists(save_path):
            continue

        explanations = np.load(path)

        # point evaluation
        acc_lists_LeRF = []
        acc_lists_MoRF = []
        # The loader yields one image at a time in dataset order, so the
        # enumeration index selects the matching explanation row.
        for sample_idx, (inputs, targets) in enumerate(tqdm(val_loader)):
            scores = explanations[sample_idx].copy().reshape(-1)

            acc_lists_LeRF.append(
                ours_random_LeRF1(net, inputs.clone(), targets.clone(),
                                  robust_ratio=robust_ratio_LeRF, use_max=False,
                                  scores=scores))
            acc_lists_MoRF.append(
                ours_random_MoRF1(net, inputs.clone(), targets.clone(),
                                  robust_ratio=robust_ratio_MoRF, use_max=False,
                                  scores=scores))

        acc_lists_LeRF = np.array(acc_lists_LeRF)
        acc_lists_MoRF = np.array(acc_lists_MoRF)
        acc_LeRF = acc_lists_LeRF.mean(axis=0)
        acc_MoRF = acc_lists_MoRF.mean(axis=0)
        print(acc_LeRF)
        print(acc_MoRF)

        # The four results have different shapes.  Passing them as a plain list
        # makes NumPy >= 1.24 raise ValueError (ragged array) after the whole
        # evaluation has run, so an object array is filled element by element.
        results = np.empty(4, dtype=object)
        for slot, value in enumerate((acc_LeRF, acc_MoRF, acc_lists_LeRF, acc_lists_MoRF)):
            results[slot] = value
        np.save(save_path, results)



def evaluation_finetune_ffid(seed, robust_ratio_MoRF=0.5, robust_ratio_LeRF=0.5,threshold=0.1):
    """Evaluate the fine-tuned model with the randomised MoRF / LeRF metrics.

    For every explanation file of ``finetune_explanation_name_list`` that
    exists for ``seed`` and has no result file yet, runs ``ours_random_LeRF1``
    and ``ours_random_MoRF1`` (with ``use_max=True``) on each validation image
    and stores the results next to the explanation file with the suffix
    ``_finetune_model_ffid.npy``.

    Args:
        seed: seed number substituted into the explanation file names.
        robust_ratio_MoRF: ``robust_ratio`` passed to ``ours_random_MoRF1``.
        robust_ratio_LeRF: ``robust_ratio`` passed to ``ours_random_LeRF1``.
        threshold: accepted for interface compatibility; not used.

    The saved file is a 4-element object array
    ``[acc_LeRF, acc_MoRF, acc_lists_LeRF, acc_lists_MoRF]`` and therefore has
    to be read with ``np.load(..., allow_pickle=True)``.
    """
    net = ResNet9(3, 100)
    state_dict = torch.load(global_dir + 'resnet-9-finetune.pth', map_location='cpu')
    net.load_state_dict(state_dict['net'])
    net.to(device)
    net.eval()

    for name_template in finetune_explanation_name_list:
        path = name_template % seed
        if not os.path.exists(path):
            continue

        # Skip work that has already been done in an earlier run.
        save_path = path.replace('.npy', '_finetune_model_ffid.npy')
        if os.path.exists(save_path):
            continue

        explanations = np.load(path)

        # point evaluation
        acc_lists_LeRF = []
        acc_lists_MoRF = []
        # The loader yields one image at a time in dataset order, so the
        # enumeration index selects the matching explanation row.
        for sample_idx, (inputs, targets) in enumerate(tqdm(val_loader)):
            scores = explanations[sample_idx].copy().reshape(-1)

            acc_lists_LeRF.append(
                ours_random_LeRF1(net, inputs.clone(), targets.clone(),
                                  robust_ratio=robust_ratio_LeRF, use_max=True,
                                  scores=scores))
            acc_lists_MoRF.append(
                ours_random_MoRF1(net, inputs.clone(), targets.clone(),
                                  robust_ratio=robust_ratio_MoRF, use_max=True,
                                  scores=scores))

        acc_lists_LeRF = np.array(acc_lists_LeRF)
        acc_lists_MoRF = np.array(acc_lists_MoRF)
        acc_LeRF = acc_lists_LeRF.mean(axis=0)
        acc_MoRF = acc_lists_MoRF.mean(axis=0)
        print(acc_LeRF)
        print(acc_MoRF)

        # The four results have different shapes.  Passing them as a plain list
        # makes NumPy >= 1.24 raise ValueError (ragged array) after the whole
        # evaluation has run, so an object array is filled element by element.
        results = np.empty(4, dtype=object)
        for slot, value in enumerate((acc_LeRF, acc_MoRF, acc_lists_LeRF, acc_lists_MoRF)):
            results[slot] = value
        np.save(save_path, results)



if __name__ == "__main__":
    # Run the three evaluations once per seed; each one skips result files
    # that already exist on disk.
    for run_seed in range(5):
        init_dl_program(device, seed=run_seed)

        for evaluation in (evaluate_ori_ori, evaluation_ori_rfid, evaluation_finetune_ffid):
            evaluation(run_seed)
        