import os.path

import torch
import torch.nn.functional as F
from utils import preprocess
from functools import partial
import torch.nn as nn
import model_editing.helpers.context_helpers as ch

# Preferred device for this module: CUDA when one is available, otherwise CPU.
DEFAULT_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# def hook_feature_with_name(name, module, input, output):
#     # hooked_features[name + '_pre'] = input[0]
#     hooked_features[name + '_post'] = output


# def _extract_layer_grads(self, module, in_grad, out_grad):
#     # function to collect the gradient outputs
#     # from each layer
#
#     if not module.bias is None:
#         self.feature_grads.append(out_grad[0])

class IntegratedGradientsWithRef(object):
    """Integrated gradients between an input and an explicit reference.

    ``k`` points are interpolated on the straight line from a reference
    tensor to the input. The chosen objective is back-propagated at every
    point; forward/backward hooks on the model's sub-modules capture the
    module inputs and the averaged gradients, which ``shap_values`` then
    combines into one score per hooked layer.

    Attributes:
        model: the wrapped network (switched to eval mode on construction).
        k: number of interpolation steps.
        exp_obj: objective to explain, one of ``'logit'``, ``'prob'``
            (log-softmax of the target class) or ``'contrast'``.
        dataset_name, arch: only used to build the context cache path.
        hooked_features: ``name -> [input at first forward, input at last
            forward]`` for every hooked module.
        hooked_gradients: ``name -> sum of grad_output[0] / k`` over the
            backward passes, plus ``'layer1.0_in'`` for grad_input[0] of
            ``layer1.0``.
        act_context_list: cached context tensors, in module iteration order,
            for every module whose cache file exists.
    """

    # Objectives understood by ``_objective``.
    _OBJECTIVES = ('logit', 'prob', 'contrast')

    # Blocks scored by ``shap_values``, in forward order.
    _BLOCK_NAMES = ('layer1.0', 'layer1.1', 'layer2.0', 'layer2.1',
                    'layer3.0', 'layer3.1', 'layer4.0', 'layer4.1')

    def __init__(self, model, k=10, exp_obj='logit', dataset_name='imagenet',
                 arch='resnet18'):
        """Install the hooks on ``model`` and load the cached contexts.

        Args:
            model: network to explain.
            k: number of interpolation steps between reference and input.
            exp_obj: 'logit', 'prob' or 'contrast'.
            dataset_name: dataset tag used in the context cache path.
            arch: architecture tag used in the context cache path.
        """
        self.model = model
        self.model.eval()
        self.k = k
        self.exp_obj = exp_obj
        self.dataset_name = dataset_name
        self.arch = arch

        self.hooked_features = {}
        self.hooked_gradients = {}
        self.act_context_list = []
        # Kept so the hooks can be detached again with ``remove_hooks``.
        self._hook_handles = []

        # ``ch.get_all_modules`` is iterated as a name -> module mapping;
        # presumably it flattens nn.Sequential children into dotted names
        # such as 'layer1.0' -- TODO confirm against context_helpers.
        all_modules = ch.get_all_modules(model)

        for name, submodule in all_modules.items():
            if 'pool' not in name and 'conv' not in name:
                self._hook_handles.append(submodule.register_forward_hook(
                    partial(self._hook_feature_with_name, name)))
                # NOTE(review): this is the legacy backward hook, which the
                # PyTorch docs deprecate in favour of
                # register_full_backward_hook. It is kept because the full
                # hook rejects modules that modify tensors in place; the
                # grad_input it reports for multi-op modules should be
                # verified.
                self._hook_handles.append(submodule.register_backward_hook(
                    partial(self._hook_grad_with_name, name)))

            # Contexts are looked up for every module, hooked or not.
            cache_f = f"./model_editing/cache/covariances/backdoor_{dataset_name}_{arch}_{name}/ZM_k.pt"

            if os.path.exists(cache_f):
                # Load onto the module's default device instead of a
                # hard-coded 'cuda' so CPU-only hosts work too.
                context = torch.load(cache_f, map_location=DEFAULT_DEVICE).to(DEFAULT_DEVICE)
                self.act_context_list.append(context)

    def remove_hooks(self):
        """Detach every hook this explainer registered on the model."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _hook_feature_with_name(self, name, module, input, output):
        """Forward hook: remember the module *input* (not its output).

        The first call for ``name`` stores ``[input[0], 0]``; later calls
        overwrite slot 1, so after a run slot 0 holds the first forward pass
        and slot 1 the most recent one. With a single forward pass slot 1
        stays at the integer 0.
        """
        record = self.hooked_features.get(name)
        if record is None:
            self.hooked_features[name] = [input[0], 0]
        else:
            record[1] = input[0]

    def _accumulate_gradient(self, key, grad):
        """Add ``grad / k`` to the running sum stored under ``key``."""
        scaled = grad / self.k
        if key in self.hooked_gradients:
            self.hooked_gradients[key] += scaled
        else:
            self.hooked_gradients[key] = scaled

    def _hook_grad_with_name(self, name, module, in_grad, out_grad):
        """Backward hook: average ``out_grad[0]`` over the ``k`` passes.

        For ``'layer1.0'`` the first input gradient is additionally averaged
        under ``'layer1.0_in'``; ``shap_values`` uses it as the gradient
        paired with that block's input.
        """
        self._accumulate_gradient(name, out_grad[0])
        if name == 'layer1.0':
            self._accumulate_gradient(name + '_in', in_grad[0])

    def _expand_reference(self, ref_tensor):
        """Repeat a ``(batch, ...)`` reference ``k`` times along a new dim 1.

        Works for any number of trailing dimensions.
        """
        repeats = [1, self.k] + [1] * (ref_tensor.dim() - 1)
        return ref_tensor.unsqueeze(1).repeat(repeats)

    def _get_samples_input(self, input_tensor, reference_tensor):
        '''
        Calculate the interpolation points.
        Args:
            input_tensor: Tensor of shape (batch, ...), where ... indicates
                          the input dimensions.
            reference_tensor: A tensor of shape (batch, k, ...) where ...
                indicates dimensions, and k is the number of interpolation
                steps.
        Returns:
            samples_input: A tensor of shape (batch, k, ...) with the
                interpolated points between ref (coefficient 0) and input
                (coefficient 1). With k == 1 the single point is the input.
        '''
        steps = self.k
        if steps == 1:
            coefs = torch.ones(1)
        else:
            coefs = torch.linspace(0, 1, steps)
        # Follow the input's device rather than a global default so CPU
        # tensors keep working on CUDA-capable hosts.
        coefs = coefs.to(input_tensor.device)

        # (1, k, 1, ..., 1) broadcasts over the batch and input dimensions.
        interp_coef = coefs.view(1, steps, *([1] * (input_tensor.dim() - 1)))

        end_point_ref = (1.0 - interp_coef) * reference_tensor
        end_point_input = interp_coef * input_tensor.unsqueeze(1)
        return end_point_input + end_point_ref

    def _get_samples_delta(self, input_tensor, reference_tensor):
        """Return ``input - reference`` broadcast to ``(batch, k, ...)``."""
        return input_tensor.unsqueeze(1) - reference_tensor

    def _objective(self, logits, labels):
        """Scalar to back-propagate for ``logits`` (batch, classes).

        'logit'    -> sum of the target-class entries of ``logits``.
        'prob'     -> sum of the target-class log-softmax values.
        'contrast' -> target logit minus the softmax-weighted mean of the
                      remaining logits, summed over the batch.

        Raises:
            ValueError: if ``self.exp_obj`` is not a known objective.
        """
        if self.exp_obj == 'logit':
            return -1 * F.nll_loss(logits, labels, reduction='sum')
        if self.exp_obj == 'prob':
            return -1 * F.nll_loss(F.log_softmax(logits, dim=1), labels, reduction='sum')
        if self.exp_obj == 'contrast':
            b_num, c_num = logits.shape[0], logits.shape[1]
            # Index helpers live on the logits' device; mixing a CPU mask
            # with CUDA labels raises a device-mismatch error.
            rows = torch.arange(b_num, device=logits.device)
            mask = torch.ones(b_num, c_num, dtype=torch.bool, device=logits.device)
            mask[rows, labels] = False
            neg_cls_output = logits[mask].reshape(b_num, c_num - 1)
            neg_weight = F.softmax(neg_cls_output, dim=1)
            weighted_neg_output = (neg_weight * neg_cls_output).sum(dim=1)
            pos_cls_output = logits[rows, labels]
            return (pos_cls_output - weighted_neg_output).sum()
        raise ValueError(
            "unknown exp_obj {!r}; expected one of {}".format(self.exp_obj, self._OBJECTIVES))

    def _get_grads(self, samples_input, sparse_labels=None):
        """Back-propagate the objective at every interpolation step.

        Args:
            samples_input: tensor of shape (batch, k, ...).
            sparse_labels: optional class indices, one per batch element.
                When omitted, the arg-max of the model output at step 0 is
                used. NOTE(review): with k > 1 step 0 is the reference
                point, not the input -- confirm this is intended.

        Returns:
            float32 tensor shaped like ``samples_input`` holding
            ``d objective / d sample`` divided by ``k`` for every step.

        Raises:
            ValueError: if ``self.exp_obj`` is not a known objective.

        The global autograd mode is left as the caller had it.
        """
        if self.exp_obj not in self._OBJECTIVES:
            raise ValueError(
                "unknown exp_obj {!r}; expected one of {}".format(self.exp_obj, self._OBJECTIVES))

        grad_tensor = torch.zeros(samples_input.shape, dtype=torch.float32,
                                  device=samples_input.device)

        beg_output, end_output = 0, 0
        with torch.enable_grad():
            for i in range(self.k):
                # Detach first so the slice is always a leaf, even when the
                # caller's input already required grad.
                particular_slice = samples_input[:, i].detach().requires_grad_(True)

                logits = self.model(particular_slice)

                if i == 0:
                    beg_output = logits.detach()
                if i == self.k - 1:
                    end_output = logits.detach()

                if sparse_labels is None:
                    sparse_labels = logits.max(1, keepdim=False)[1]

                batch_output = self._objective(logits, sparse_labels.flatten())

                self.model.zero_grad()
                batch_output.backward()

                grad_tensor[:, i] = particular_slice.grad / self.k

        # Diagnostic: change of the target-class output from first to last step.
        print('------- output --------')
        with torch.no_grad():
            rows = torch.arange(samples_input.size(0), device=end_output.device)
            output_diff = (end_output - beg_output)[rows, sparse_labels.flatten()]
        print(torch.mean(output_diff))

        return grad_tensor

    def shap_values(self, input_tensor, ref_tensor, sparse_labels=None):
        """
        Score every hooked block with integrated gradients from
        ``ref_tensor`` to ``input_tensor``.

        Args:
            input_tensor (torch.Tensor): inputs to explain, shape (batch, ...).
            ref_tensor (torch.Tensor): reference inputs, same shape as
                ``input_tensor``.
            sparse_labels (optional, default=None): class indices to
                explain; see ``_get_grads`` for the default.

        Returns:
            dict mapping each hooked module name to a float: the L2 norm of
            its normalised ``gradient * activation difference`` after
            ``rewrite_helpers.projected_conv`` was applied with the module's
            context key.

        Raises:
            KeyError: if a name in ``_BLOCK_NAMES`` has no recorded gradient
                or a hooked module is not listed in ``_BLOCK_NAMES``.
            IndexError: if fewer contexts were cached than modules hooked.
        """
        import model_editing.helpers.rewrite_helpers as rh

        reference_tensor = self._expand_reference(ref_tensor)
        samples_input = self._get_samples_input(input_tensor, reference_tensor)

        # Start from a clean slate: the hooks fire on *every* forward and
        # backward pass of the model, so leftovers from an earlier call
        # would be mixed into this one.
        self.hooked_features.clear()
        self.hooked_gradients.clear()
        grad_tensor = self._get_grads(samples_input, sparse_labels)

        # Everything below is bookkeeping on recorded tensors; no graph needed.
        with torch.no_grad():
            # Input-level attribution, printed as a sanity check only.
            samples_delta = self._get_samples_delta(input_tensor, reference_tensor)
            attribution = (samples_delta * grad_tensor).sum(1)
            print('#### attribution sum (sample 0): {:.4f} ####'.format(float(torch.sum(attribution[0]))))

            ig_input_sum_abs = torch.abs(torch.sum(attribution, dim=tuple(range(1, attribution.dim()))))
            print('#### attribution abs mean: {:.4f} ####'.format(float(torch.mean(ig_input_sum_abs))))

            # The forward hook stores each block's *input*, so pair it with
            # the gradient flowing into that block: the output gradient of
            # the preceding block, or 'layer1.0_in' for the first one.
            block_names = list(self._BLOCK_NAMES)
            upstream = ['layer1.0_in'] + block_names[:-1]
            input_grads = {name: self.hooked_gradients[src]
                           for name, src in zip(block_names, upstream)}

            int_grad_change = {}
            # NOTE(review): contexts are indexed from 1, i.e. the first
            # cached context is skipped -- confirm this offset matches the
            # order in which the cache files were found.
            for key_idx, (key, feat) in enumerate(self.hooked_features.items(), start=1):
                grad = input_grads[key]
                context = self.act_context_list[key_idx]
                feat_d = feat[1] - feat[0]

                # Query the context key with the first recorded activation.
                context_k = ch.get_context_key_with_act(feat[0], context)

                ig = grad * feat_d
                ig = ig / torch.sum(torch.abs(ig))

                # Project ig with the context key (semantics defined in
                # rewrite_helpers).
                ig = rh.projected_conv(ig, context_k, unfold=False)

                # |x| and x share the same L2 norm, so no abs() is needed.
                int_grad_change[key] = float(torch.norm(ig, p=2))

        return int_grad_change

### correct calculation
# samples_delta = self._get_samples_inter_delta(input_tensor, ref_tensor)
# mult_grads = samples_delta * grad_tensor.sum(1)
# attribution = mult_grads
# print('------------ attribution right ---------------')
# print(torch.sum(attribution[0]))
