import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, TensorDataset, DataLoader
from simshap.utils import ExactShapleySampler, DatasetRepeat
from tqdm.auto import tqdm
def evaluate_explainer(explainer, x, num_players,
                       inference=False):
    '''
    Evaluate the explainer model and reshape its output.

    Args:
      explainer: explainer model, called as ``explainer(x)``.
      x: input batch.
      num_players: number of players.
      inference: whether this is inference time (or training).

    Returns:
      If the explainer output is 4-D (image case), it is reshaped to
      ``(len(x), -1, num_players)`` during training; at inference time the
      original 4-D shape is restored. Any other output (tabular case) is
      returned unchanged.
    '''
    pred = explainer(x)

    # Tabular: the explainer output is already in its final layout.
    if len(pred.shape) != 4:
        return pred

    # Image: flatten the spatial grid into the player axis.
    image_shape = pred.shape
    pred = pred.reshape(len(x), -1, num_players)

    if inference:
        # Callers at inference time expect the explainer's native layout back.
        pred = pred.reshape(image_shape)
    return pred


def calculate_grand_coalition(dataset, imputer, batch_size, device,
                              num_workers, num_players=None):
    '''
    Calculate the value of the grand coalition for each input.

    Args:
      dataset: dataset object yielding 1-tuples ``(x,)``.
      imputer: value function. When ``num_players`` is None it is called as
        ``imputer(x, S)`` with an all-ones coalition mask sized by
        ``imputer.num_players``; otherwise it is called as ``imputer(x)``.
      batch_size: minibatch size.
      device: torch device.
      num_workers: number of worker threads.
      num_players: None selects the masked ``imputer(x, S)`` calling
        convention; any other value selects ``imputer(x)`` (the value itself
        is not otherwise used).

    Returns:
      Tensor of grand-coalition values, one row per input; a 1-D imputer
      output is promoted to shape ``(N, 1)``.
    '''
    use_mask = num_players is None
    if use_mask:
        # All players present; sliced per batch to handle a short final batch.
        ones = torch.ones(batch_size, imputer.num_players,
                          dtype=torch.float32, device=device)

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        pin_memory=True, num_workers=num_workers)
    grand = []
    with torch.no_grad():
        for (x,) in loader:
            x = x.to(device)
            if use_mask:
                grand.append(imputer(x, ones[:len(x)]))
            else:
                grand.append(imputer(x))

    grand = torch.cat(grand)
    if len(grand.shape) == 1:
        grand = grand.unsqueeze(1)
    return grand

def generate_validation_data(val_set, imputer, validation_samples, sampler,
                             batch_size, device, num_workers, num_players=None):
    '''
    Generate coalition values for validation dataset.

    Args:
      val_set: validation dataset object.
      imputer: value function. When ``num_players`` is None it is a surrogate
        called as ``imputer(x, S)``; otherwise it is the original model called
        on the masked input ``imputer(x * S)``.
      validation_samples: number of sampled coalitions per validation example.
      sampler: Shapley sampler exposing ``sample(n, paired_sampling=...)``.
      batch_size: minibatch size.
      device: torch device.
      num_workers: number of worker threads.
      num_players: number of players for the single-model path; None selects
        the surrogate path and uses ``imputer.num_players``.

    Returns:
      ``(val_S, val_values)``: ``val_S`` has shape
      ``(len(val_set), validation_samples, players)``; ``val_values`` stacks
      the per-sample imputer outputs along dim 1.
    '''
    surrogate = num_players is None
    n_players = imputer.num_players if surrogate else num_players

    val_S = sampler.sample(
        validation_samples * len(val_set), paired_sampling=True).reshape(
        len(val_set), validation_samples, n_players)

    val_values = []
    # These are fixed targets; no gradients are needed (matches the other
    # imputer calls in this module, which also run under no_grad).
    with torch.no_grad():
        for i in range(validation_samples):
            dset = DatasetRepeat([val_set, TensorDataset(val_S[:, i])])
            loader = DataLoader(dset, batch_size=batch_size, shuffle=False,
                                pin_memory=False, num_workers=num_workers)
            values = []
            for x, S in loader:
                if surrogate:
                    out = imputer(x.to(device), S.to(device))
                else:
                    if x.ndim == 4:
                        # Image: S is [batch, num_players]; treat players as a
                        # square grid and upsample it to the image resolution.
                        # Assumes square images and a perfect-square player
                        # count, as the original implementation did.
                        side = int(np.sqrt(n_players))
                        S = S.view(S.size(0), 1, side, side)
                        S = nn.Upsample(scale_factor=x.size(-1) // side,
                                        mode='nearest')(S)
                    out = imputer(x.to(device) * S.to(device))
                values.append(out.detach().cpu())
            val_values.append(torch.cat(values))

    val_values = torch.stack(val_values, dim=1)
    return val_S, val_values

def compute_shapley_sampling(values, S, num_players, grand, null, adj_ratio=1):
    '''
    Build sampled Shapley-value targets from coalition values.

    Args:
      values: coalition values, shape [BS, NS, outD].
      S: coalition masks, shape [BS, NS, inD].
      num_players: number of players.
      grand: grand-coalition values, shape [BS, outD].
      null: null-coalition values, broadcastable to ``grand``.
      adj_ratio: scalar multiplier applied to the per-coalition weights.

    Returns:
      Tensor of shape [BS, outD, inD].
    '''
    n_samples = S.size(1)
    subset_size = torch.sum(S, dim=-1, keepdim=True)  # [BS, NS, 1]

    # Members of a coalition are weighted by (n - |S|), non-members by -|S|.
    inside = S * (num_players - subset_size)
    outside = (1 - S) * subset_size
    scaled = (inside - outside) * adj_ratio / n_samples  # [BS, NS, inD]

    # Weighted sum of coalition values over the sample axis.
    marginal = torch.einsum("bsi,bso->boi", scaled, values)  # [BS, outD, inD]

    # Split (grand - null) evenly across players and add it to every entry.
    baseline = (grand - null).unsqueeze(-1) / num_players  # [BS, outD, 1]
    return marginal + baseline

class SimSHAPSampling:
    '''
    Explainer trainer that regresses an explainer network onto sampled
    Shapley targets produced by ``compute_shapley_sampling``.

    Args:
      explainer: explainer model (torch.nn.Module).
      imputer: value function; called as ``imputer(x, S)`` during training.
      num_players: number of players; defaults to ``imputer.num_players``.
      device: torch device used for training.
    '''

    def __init__(self,
                 explainer,
                 imputer,
                 num_players=None,
                 device=None):
        # Set up explainer and imputer.
        self.explainer = explainer
        self.imputer = imputer
        if num_players is None:
            self.num_players = imputer.num_players
        else:
            self.num_players = num_players
        self.sampler = ExactShapleySampler(self.num_players, device=device)
        # Passed as ``adj_ratio`` to compute_shapley_sampling; presumably the
        # normalizer of the sampler's coalition distribution — TODO confirm.
        self.weight_sum = self.sampler.get_weight_sum()
        self.device = device

    @staticmethod
    def _to_dataset(data, name):
        '''Wrap an array/tensor in a TensorDataset; pass a Dataset through.'''
        if isinstance(data, np.ndarray):
            return TensorDataset(torch.tensor(data, dtype=torch.float32))
        if isinstance(data, torch.Tensor):
            return TensorDataset(data)
        if isinstance(data, Dataset):
            return data
        raise ValueError('{} must be np.ndarray, torch.Tensor or '
                         'Dataset'.format(name))

    @staticmethod
    def _to_tensor(x):
        '''Convert an np.ndarray to a float32 tensor; pass tensors through.'''
        if isinstance(x, np.ndarray):
            return torch.tensor(x, dtype=torch.float32)
        if isinstance(x, torch.Tensor):
            return x
        raise ValueError('data must be np.ndarray or torch.Tensor')

    def train(self,
              train_data,
              val_data,
              batch_size,
              num_samples,
              max_epochs,
              lr=2e-4,
              min_lr=1e-5,
              lr_factor=0.5,
              paired_sampling=True,
              validation_samples=None,
              lookback=5,
              training_seed=None,
              validation_seed=None,
              num_workers=0,
              bar=False,
              verbose=False,
              accum_iter=1):
        '''
        Train explainer model.

        Args:
          train_data: training data with inputs only (np.ndarray, torch.Tensor,
            torch.utils.data.Dataset).
          val_data: validation data with inputs only (np.ndarray, torch.Tensor,
            torch.utils.data.Dataset).
          batch_size: minibatch size.
          num_samples: number of sampled coalitions per training example.
          max_epochs: number of training epochs.
          lr: learning rate for AdamW.
          min_lr: currently unused (no LR scheduler is active).
          lr_factor: currently unused (no LR scheduler is active).
          paired_sampling: whether to use paired sampling.
          validation_samples: number of sampled coalitions per validation
            example; defaults to ``num_samples`` when None.
          lookback: currently unused (no early stopping is performed).
          training_seed: if given, seeds torch's global RNG before training.
          validation_seed: if given, seeds torch's global RNG before the
            validation coalitions are sampled.
          num_workers: number of worker threads in data loader.
          bar: whether to show a per-epoch progress bar.
          verbose: whether to show train/val loss in the epoch progress bar.
          accum_iter: currently unused.
        '''
        explainer = self.explainer
        num_players = self.num_players
        imputer = self.imputer
        device = self.device
        explainer.train()

        train_set = self._to_dataset(train_data, 'train_data')
        val_set = self._to_dataset(val_data, 'val_data')

        # generate_validation_data multiplies by this value, so None would
        # crash; fall back to the training sample count.
        if validation_samples is None:
            validation_samples = num_samples

        # Grand coalition value.
        grand_train = calculate_grand_coalition(
            train_set, imputer, batch_size * num_samples, device,
            num_workers).cpu()
        grand_val = calculate_grand_coalition(
            val_set, imputer, batch_size * num_samples, device,
            num_workers).cpu()

        # Null coalition, evaluated on the first training example.
        with torch.no_grad():
            zeros = torch.zeros(1, num_players, dtype=torch.float32,
                                device=device)
            null = imputer(train_set[0][0].unsqueeze(0).to(device), zeros)
            if len(null.shape) == 1:
                null = null.reshape(1, 1)
        self.null = null

        # drop_last keeps every batch at exactly batch_size, which the
        # reshapes in the training loop rely on.
        train_loader = DataLoader(
            DatasetRepeat([train_set, TensorDataset(grand_train)]),
            batch_size=batch_size, shuffle=False, pin_memory=True,
            drop_last=True, num_workers=num_workers)

        # Seeding torch's global RNG assumes the sampler draws from it
        # — TODO confirm for ExactShapleySampler.
        if validation_seed is not None:
            torch.manual_seed(validation_seed)
        val_S, val_values = generate_validation_data(
            val_set, imputer, validation_samples, self.sampler,
            batch_size * num_samples, device, num_workers)

        val_loader = DataLoader(
            DatasetRepeat(
                [val_set, TensorDataset(grand_val, val_S, val_values)]),
            batch_size=batch_size * num_samples,
            pin_memory=False, num_workers=num_workers)

        loss_fn = nn.MSELoss()
        optimizer = optim.AdamW(explainer.parameters(), lr=lr)
        self.loss_list = []

        if training_seed is not None:
            torch.manual_seed(training_seed)

        pbar = tqdm(range(max_epochs))
        for _ in pbar:
            batch_iter = (tqdm(train_loader, desc='Training epoch')
                          if bar else train_loader)
            train_losses = []
            for x, grand in batch_iter:
                # Fresh coalitions for every batch.
                S = self.sampler.sample(batch_size * num_samples,
                                        paired_sampling=paired_sampling)

                x = x.to(device)
                S = S.to(device)  # [BS*num_samples, inD]
                grand = grand.to(device)

                # Each example repeated num_samples times, consecutively.
                x_tiled = x.repeat_interleave(num_samples, dim=0)
                with torch.no_grad():
                    values = imputer(x_tiled, S)  # [BS*num_samples, outD]

                pred = evaluate_explainer(explainer, x, num_players)

                S = S.reshape(batch_size, num_samples, num_players)
                values = values.view(batch_size, num_samples, -1)
                sampled_GT = compute_shapley_sampling(
                    values, S, num_players, grand, null,
                    adj_ratio=self.weight_sum)
                loss = loss_fn(pred, sampled_GT) * num_players
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                # Detach so the epoch's graphs are not kept alive.
                train_losses.append(loss.detach())

            # Validation: no gradients needed.
            explainer.eval()
            val_losses = []
            with torch.no_grad():
                for x, grand, S, values in val_loader:
                    x = x.to(device)
                    S = S.to(device)  # [BS, validation_samples, inD]
                    grand = grand.to(device)
                    values = values.to(device)  # [BS, validation_samples, outD]

                    pred = evaluate_explainer(explainer, x, num_players)
                    sampled_GT = compute_shapley_sampling(
                        values, S, num_players, grand, null,
                        adj_ratio=self.weight_sum)
                    val_losses.append(
                        loss_fn(pred, sampled_GT).item() * num_players)
            explainer.train()

            if verbose:
                # Both losses are reported on the same num_players-scaled
                # scale; NaN when a loader yielded no batches.
                train_loss = (torch.stack(train_losses).mean().item()
                              if train_losses else float('nan'))
                val_loss = (sum(val_losses) / len(val_losses)
                            if val_losses else float('nan'))
                pbar.set_description('Loss/Train: {}, Loss/Val: {}'.format(
                    train_loss, val_loss))
        explainer.eval()

    def shap_values(self, x, original=False):
        '''
        Generate SHAP values.

        Args:
          x: input examples (np.ndarray or torch.Tensor).
          original: if True, reshape with ``self.num_players`` and skip the
            null-coalition refresh; otherwise use ``self.imputer.num_players``
            and recompute ``self.null`` from the first example.

        Returns:
          np.ndarray of explainer outputs.
        '''
        x = self._to_tensor(x)
        device = next(self.explainer.parameters()).device

        if original:
            n_players = self.num_players
        else:
            # Surrogate path: refresh the cached null-coalition value.
            with torch.no_grad():
                zeros = torch.zeros(1, self.num_players, dtype=torch.float32,
                                    device=device)
                null = self.imputer(x[:1].to(device), zeros)
            if len(null.shape) == 1:
                null = null.reshape(1, 1)
            self.null = null
            n_players = self.imputer.num_players

        with torch.no_grad():
            pred = evaluate_explainer(
                self.explainer, x.to(device), n_players, inference=True)
        return pred.cpu().data.numpy()