import numpy as np
import torch
from sklearn import metrics


def normalize_masked_data(data, mask, att_min, att_max):
    """Shift ``data`` by ``att_min``, scale by ``att_max``, and zero masked-out entries.

    The result is ``(data - att_min) / att_max`` with every entry where
    ``mask == 0`` reset to zero. Entries of ``att_max`` that are exactly zero
    are treated as one so the division is always defined.

    The caller's ``att_max`` is left untouched; the zero-free divisor is
    returned as the third element instead.

    Args:
        data: Tensor of values to normalize.
        mask: Tensor indexable against the normalized result; entries equal
            to 0 mark positions that are forced to zero in the output.
        att_min: Offset subtracted from ``data`` (broadcast by torch).
        att_max: Divisor applied after the shift (broadcast by torch).

    Returns:
        Tuple ``(data_norm, att_min, att_max_safe)`` where ``att_max_safe`` is
        ``att_max`` with zeros replaced by ones.

    Raises:
        Exception: If the normalized data contains NaNs (checked before the
            masked entries are zeroed).
    """
    # We don't want to divide by zero. Build the divisor out of place so the
    # tensor owned by the caller is not silently modified.
    att_max_safe = torch.where(att_max == 0., torch.ones_like(att_max), att_max)

    data_norm = (data - att_min) / att_max_safe

    if torch.isnan(data_norm).any():
        raise Exception("nans!")

    # set masked out elements back to zero
    data_norm[mask == 0] = 0

    return data_norm, att_min, att_max_safe


def obtain_impute_result(model, dataloader, device, args):
    """Evaluate reconstruction error on time points hidden from the model.

    Each batch is split along its last axis into ``channels`` value columns,
    ``channels`` mask columns and one trailing time-point column, where
    ``channels = data.size(-1) // 2``. When ``args.sample_tp`` is a fraction
    below 1, a subsampled copy of the batch is fed to the model and the error
    is measured only on observed entries that were removed by subsampling.

    The model is switched to eval mode for the duration of the loop and put
    back into train mode before returning.

    Args:
        model: Callable invoked as ``model(inputs, input_tp, query_tp)`` that
            returns a 4-tuple whose second element is the reconstruction.
        dataloader: Iterable yielding batch tensors.
        device: Device each batch is moved to.
        args: Object exposing ``sample_tp``.

    Returns:
        The per-batch masked MSE, weighted by batch size and averaged over all
        samples seen.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no batches.
    """
    mse, num = 0, 0
    model.eval()
    with torch.no_grad():
        for data in dataloader:
            data = data.to(device)
            batch_size = data.size(0)
            channels = data.size(-1) // 2
            observed_data = data[:, :, :channels]
            observed_mask = data[:, :, channels:-1]
            observed_tp = data[:, :, -1]
            if args.sample_tp and args.sample_tp < 1:
                # Clones are passed because subsample_timepoints edits its
                # arguments in place.
                subsampled_data, subsampled_tp, subsampled_mask = subsample_timepoints(
                    observed_data.clone(), observed_tp.clone(), observed_mask.clone(), args.sample_tp)
            else:
                subsampled_data, subsampled_tp, subsampled_mask = \
                    observed_data, observed_tp, observed_mask

            _, recon, _, _ = model(
                torch.cat([subsampled_data, subsampled_mask], dim=-1), subsampled_tp, observed_tp)

            # Score only the entries that were observed but withheld from the
            # model. Built out of place so the batch tensor is not modified.
            # NOTE(review): without subsampling nothing is withheld, so this
            # mask is all zeros and mse_with_mask divides by a zero count.
            heldout_mask = observed_mask.masked_fill(subsampled_mask == 1., 0)

            # mse_with_mask returns a per-batch mean; weight it by the batch
            # size so that dividing by the sample count yields a proper average.
            mse += mse_with_mask(observed_data, recon, heldout_mask) * batch_size
            num += batch_size
    model.train()
    return mse / num


def obtain_classifier_result(model, classifier, dataloader, device, criterions=('auc', 'auprc')):
    """Run ``model`` + ``classifier`` over ``dataloader`` and compute metrics.

    For every batch the model is called as ``model(data[..., :-1], data[..., -1])``
    and the first element of its 4-tuple result is taken as the representation.
    A 3-D representation is reduced to ``reprs[:, 0]`` before classification,
    unless the classifier has a ``gru_rnn`` attribute, in which case it
    receives the representation unchanged.

    Both modules are switched to eval mode for the loop and put back into
    train mode afterwards.

    Args:
        model: Encoder returning a 4-tuple; only the first element is used.
        classifier: Module mapping representations to class scores.
        dataloader: Iterable yielding ``(data, label)`` pairs.
        device: Device each batch is moved to.
        criterions: Iterable of metric names among ``'acc'``, ``'auc'``,
            ``'auprc'``, ``'precision'``, ``'recall'`` and ``'f1'``.

    Returns:
        Tuple ``(results, loss)`` where ``results`` maps each requested metric
        name to its value and ``loss`` is the sum over batches of the
        cross-entropy loss, as a Python float.

    Raises:
        NotImplementedError: If ``criterions`` contains an unknown name.
    """
    pred, true, loss = [], [], 0
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()
    classifier.eval()
    with torch.no_grad():
        for data, label in dataloader:
            data, label = data.to(device), label.to(device)
            reprs, _, _, _ = model(data[..., :-1], data[..., -1])

            # Classifiers without a `gru_rnn` attribute only receive index 0
            # along the second axis of a 3-D representation.
            if len(reprs.size()) == 3 and not hasattr(classifier, 'gru_rnn'):
                reprs = reprs[:, 0]
            out = classifier(reprs)

            pred.append(out.cpu().numpy())
            true.append(label.cpu().numpy())
            loss += criterion(out, label)

    # Restore training mode before any post-processing that might raise.
    model.train()
    classifier.train()

    pred = np.concatenate(pred, axis=0)
    true = np.concatenate(true)

    # More than two score columns -> macro averaging, otherwise binary.
    average = 'macro' if pred.shape[-1] > 2 else 'binary'
    results = {}
    for crit in criterions:
        if crit == 'acc':
            ret = np.mean(pred.argmax(1) == true)
        elif crit == 'auc':
            # Column 1 of the raw classifier output is used as the score.
            ret = metrics.roc_auc_score(true, pred[:, 1])
        elif crit == 'auprc':
            ret = metrics.average_precision_score(true, pred[:, 1])
        elif crit == 'precision':
            ret = metrics.precision_score(true, pred.argmax(1), average=average, zero_division=1.)
        elif crit == 'recall':
            ret = metrics.recall_score(true, pred.argmax(1), average=average, zero_division=1.)
        elif crit == 'f1':
            ret = metrics.f1_score(true, pred.argmax(1), average=average, zero_division=1.)
        else:
            raise NotImplementedError(f"Unknown criterion: {crit!r}")

        results[crit] = ret
    return results, loss.item()


def mse_with_mask(x, x_hat, mask):
    """Return the mask-weighted squared error between ``x`` and ``x_hat``.

    The residual is multiplied by ``mask`` before squaring, and the summed
    result is divided by ``mask.sum()``. For a 0/1 mask this is the mean
    squared error over the selected entries.
    """
    masked_residual = (x - x_hat) * mask
    squared_error_total = masked_residual.pow(2).sum()
    selected_count = mask.sum()
    return squared_error_total / selected_count


def subsample_timepoints(data, time_steps, mask, percentage_tp_to_sample=None):
    """Randomly keep a fraction of the observed time points of each series.

    For every sample ``i``, the time points with at least one positive mask
    entry are collected, ``int(count * percentage_tp_to_sample)`` of them are
    drawn without replacement via ``np.random.choice``, and all the others are
    zeroed in ``data`` (and in ``mask`` when given).

    ``data`` and ``mask`` are modified in place; pass clones to keep the
    originals intact.

    Args:
        data: Tensor indexed as ``data[sample, time_point]``.
        time_steps: Returned unchanged.
        mask: Observation mask laid out like ``data``, or ``None`` to treat
            every time point as observed.
        percentage_tp_to_sample: Fraction of observed time points to keep.
            ``None`` disables subsampling and returns the inputs untouched.

    Returns:
        Tuple ``(data, time_steps, mask)``.
    """
    if percentage_tp_to_sample is None:
        # Nothing to subsample; previously this fell through to int(n * None).
        return data, time_steps, mask

    # Subsample percentage of points from each time series
    for i in range(data.size(0)):
        if mask is None:
            # Without a mask, every time point counts as observed.
            non_missing_tp = np.arange(data.size(1))
        else:
            # take mask for current sample and sum over all features to find
            # the time points that have at least one measurement
            current_mask = mask[i].sum(-1).cpu()
            non_missing_tp = np.where(current_mask > 0)[0]

        n_to_sample = int(len(non_missing_tp) * percentage_tp_to_sample)
        subsampled_idx = np.random.choice(non_missing_tp, n_to_sample, replace=False)
        tp_to_set_to_zero = np.setdiff1d(non_missing_tp, subsampled_idx)

        data[i, tp_to_set_to_zero] = 0.
        if mask is not None:
            mask[i, tp_to_set_to_zero] = 0.

    return data, time_steps, mask