"""Utility functions and parameters common to all parts of the package
"""
import pathlib
import json
import os
import torch
import torch.nn as nn
import numpy as np
from shapreg.removal import MarginalExtension

import models.nn_model
import shap_dataset.shap_dataset


####### Utility functions ###########
def get_task_settings():
    """Load the task settings JSON shipped alongside this module.

    Reads ``shap_dataset/raw_datasets/task_settings.json`` relative to the
    directory containing this file and returns the parsed JSON content.

    Returns:
        The decoded JSON document.

    Raises:
        FileNotFoundError: if the settings file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    this_directory = pathlib.Path(__file__).parent.resolve()
    settings_path = this_directory / "shap_dataset" / "raw_datasets" / "task_settings.json"
    # JSON text is UTF-8 (RFC 8259); do not depend on the platform's locale
    # encoding, which differs e.g. on Windows.
    with open(settings_path, encoding="utf-8") as f:
        dataset_settings = json.load(f)
    return dataset_settings


def get_dataset(task_name: str, with_splits=False):
    """Instantiate the dataset registered under ``task_name``.

    Recognised names are ``"entacmaea"``, ``"sgemm"`` and ``"gb1"`` (exact
    match), any name starting with ``"avGFP"``, and ``"harvard<N>"``, where
    the integer suffix ``N`` is passed on as ``no_features``.

    Args:
        task_name: Identifier of the task/dataset.
        with_splits: If True, split the dataset using
            ``models.nn_model.NNDataset(...).get_splits()`` and return the
            train/test slices of ``dataset.x`` and ``dataset.y`` instead of
            the dataset object itself.

    Returns:
        The dataset object, or ``(x_train, x_test, y_train, y_test)`` when
        ``with_splits`` is True.

    Raises:
        ValueError: If ``task_name`` is unknown, or the harvard suffix is not
            an integer.
    """
    # Factories are lambdas so the dataset module is only touched for the
    # task that was actually requested.
    exact_factories = {
        "entacmaea": lambda: shap_dataset.shap_dataset.EntacmaeaDataset(),
        "sgemm": lambda: shap_dataset.shap_dataset.SGEMMDataset(),
        "gb1": lambda: shap_dataset.shap_dataset.GB1Dataset(),
    }

    if task_name in exact_factories:
        dataset = exact_factories[task_name]()
    elif task_name.startswith("harvard"):
        # The feature count is encoded in the name, e.g. "harvard10".
        feature_count = int(task_name[len("harvard"):])
        dataset = shap_dataset.shap_dataset.HarvardCleanEnergyDataset(no_features=feature_count)
    elif task_name.startswith("avGFP"):
        dataset = shap_dataset.shap_dataset.avGFPDataset()
    else:
        raise ValueError(f"Unknown dataset: \"{task_name}\"")

    if not with_splits:
        return dataset

    train_split, test_split = models.nn_model.NNDataset(dataset).get_splits()
    # NOTE(review): x is indexed with a list of indices while y is indexed
    # with the raw ``indices`` object; presumably x needs a list for fancy
    # indexing. Confirm against the dataset classes before unifying.
    x_train = dataset.x[list(train_split.indices)]
    x_test = dataset.x[list(test_split.indices)]
    y_train = dataset.y[train_split.indices]
    y_test = dataset.y[test_split.indices]
    return x_train, x_test, y_train, y_test


def make_dir(dir):
    """Create directory ``dir`` (and any missing parents) if it does not exist.

    Args:
        dir: Path of the directory to create (str or path-like).

    Raises:
        FileExistsError: If ``dir`` exists but is not a directory.
    """
    # exist_ok=True removes the race in "if not exists: makedirs", where
    # another process can create the directory between the check and the
    # call and make makedirs raise.
    os.makedirs(dir, exist_ok=True)

# FastSHAP - ICLR: https://github.com/iclr1814/fastshap/blob/master/fastshap_torch/utils.py
class MarginalImputerTorch:
    '''
    Evaluate a model while replacing features with samples from the marginal
    distribution.

    For every input row, each background sample is used once. Features whose
    mask entry is 1 are taken from the input; the others are taken from the
    background sample. The model outputs are then averaged over all
    background samples.

    Args:
      model: callable applied to a 2-D float tensor of stacked imputed inputs
        (batch * n_background rows); its output's first dimension must equal
        the number of rows it receives.
      background: np.ndarray or torch.Tensor of background samples, 2-D with
        shape (n_background, num_features). NumPy input is converted to a
        float32 tensor.
      groups: optional iterable of feature-index groups that must cover every
        index in range(num_features) exactly once; each group becomes one
        player (one column of the mask S).
      link: optional nn.Module applied to the model output (defaults to
        nn.Identity()).
      device: torch device where the background, group matrix and converted
        inputs are placed.

    Raises:
      ValueError: if background or link has an unsupported type, or groups
        does not partition the features.
    '''

    def __init__(self, model, background, groups=None, link=None, device='cpu'):
        # Store model.
        self.model = model
        self.device = device

        # Store background samples on the target device.
        if isinstance(background, np.ndarray):
            background = torch.tensor(background, dtype=torch.float32,
                                      device=self.device)
        elif isinstance(background, torch.Tensor):
            background = background.to(device=self.device)
        else:
            raise ValueError('background must be np.ndarray or torch.Tensor')
        self.background = background
        # Background tiled once per input row; rebuilt in __call__ whenever
        # the batch size changes.
        self.background_repeat = background
        self.n_background = len(background)

        # Set up link.
        if link is None:
            self.link = nn.Identity()
        elif isinstance(link, nn.Module):
            self.link = link
        else:
            raise ValueError('unsupported link function: {}'.format(link))

        # Store feature groups.
        num_features = background.shape[1]
        if groups is None:
            self.num_players = num_features
            self.groups_matrix = None
        else:
            # Materialize once: groups is iterated twice and len() is taken.
            groups = list(groups)

            # Verify that every feature index appears in exactly one group.
            # np.array_equal also rejects length mismatches, which an
            # elementwise == could silently broadcast (e.g. [0, 0] vs [0]).
            inds_list = [ind for group in groups for ind in group]
            if not np.array_equal(np.sort(inds_list), np.arange(num_features)):
                raise ValueError(
                    'groups must contain each feature index in range({}) '
                    'exactly once'.format(num_features))

            # Map groups to features: row i has ones at the features of group i.
            self.num_players = len(groups)
            self.groups_matrix = torch.zeros(
                len(groups), num_features, dtype=torch.float32, device=self.device)
            for i, group in enumerate(groups):
                self.groups_matrix[i, group] = 1

    def __call__(self, x, S):
        '''
        Evaluate model with marginal imputation.

        Args:
          x: inputs, np.ndarray or torch.Tensor of shape (batch, num_features).
          S: masks of shape (batch, num_players); an entry of 1 keeps the
            corresponding feature (or feature group) from x, and an entry of 0
            replaces it with background values. np.ndarray, torch.Tensor, or
            anything torch.as_tensor accepts.

        Returns:
          Model predictions averaged over the background samples, with first
          dimension batch. A np.ndarray if x was a np.ndarray, otherwise a
          torch.Tensor.
        '''
        # Convert NumPy inputs to float32 tensors on the imputer's device.
        # S is converted whenever it is not already a tensor, so a torch x
        # may be paired with a NumPy mask.
        numpy_conversion = isinstance(x, np.ndarray)
        if numpy_conversion:
            x = torch.tensor(x, dtype=torch.float32, device=self.device)
        if numpy_conversion or not isinstance(S, torch.Tensor):
            S = torch.as_tensor(S, dtype=torch.float32, device=self.device)

        # Expand player-level masks to feature-level masks.
        if self.groups_matrix is not None:
            S = torch.mm(S, self.groups_matrix)

        # Set up background repeat (cached across calls with the same batch size).
        batch_size = len(x)
        n_rows = batch_size * self.n_background
        if len(self.background_repeat) != n_rows:
            self.background_repeat = self.background.repeat(batch_size, 1)

        # Repeat each input row and its mask n_background times, so that row
        # i * n_background + j pairs input i with background sample j.
        x_tiled = x.unsqueeze(1).repeat(1, self.n_background, 1).reshape(n_rows, -1)
        S_tiled = S.unsqueeze(1).repeat(1, self.n_background, 1).reshape(n_rows, -1)

        # Replace masked-out features with background values.
        x_tiled = S_tiled * x_tiled + (1 - S_tiled) * self.background_repeat

        # Make predictions and average over the background samples.
        pred = self.link(self.model(x_tiled))
        pred = pred.reshape(batch_size, self.n_background, *pred.shape[1:])
        pred = torch.mean(pred, dim=1)
        if numpy_conversion:
            pred = pred.detach().cpu().numpy()
        return pred