from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.autograd import Variable
from sklearn.metrics import roc_auc_score, classification_report


def evaluate_multiclass(labels, predicted_label, predicted_probability):
    """Score one-hot multi-class predictions.

    AUC, recall and precision are computed only over the classes that occur
    in ``labels``. If fewer than two classes occur, those three metrics are
    reported as 0 and ``auc_list`` is empty.

    Args:
        labels: one-hot ground truth as a torch tensor (rows are samples,
            columns are classes).
        predicted_label: one-hot predictions as a NumPy array with the same
            layout as ``labels``.
        predicted_probability: torch tensor of per-class scores with the same
            layout as ``labels``.

    Returns:
        Tuple ``(auc, recall, precision, correct_label, auc_list)`` where
        ``auc`` is the mean of the per-class AUCs in ``auc_list``,
        ``recall``/``precision`` are the macro averages from
        ``classification_report`` and ``correct_label`` is the number of rows
        whose predicted class equals the true class.
    """
    labels_array = labels.detach().cpu().numpy()  # one hot
    prediction_array = predicted_label  # already a one-hot NumPy array

    true_classes = np.argmax(labels_array, 1)
    # Count hits in the full class space, before any column is dropped, so a
    # prediction for a class absent from `labels` can never count as correct.
    correct_label = np.equal(true_classes, np.argmax(prediction_array, 1)).sum()

    # Determine the occurring classes ONCE and reuse the same column indices
    # for labels, predictions and probabilities so the three stay aligned.
    present_classes = np.unique(true_classes)
    if len(present_classes) >= 2:
        labels_present = labels_array[:, present_classes]
        predictions_present = prediction_array[:, present_classes]
        probabilities_present = np.array(
            predicted_probability[:, present_classes].detach().cpu())

        auc_list = roc_auc_score(labels_present, probabilities_present, average=None)
        auc = np.mean(auc_list)

        report = classification_report(labels_present, predictions_present, output_dict=True)
        recall = report['macro avg']['recall']
        precision = report['macro avg']['precision']
    else:
        auc = 0
        recall = 0
        precision = 0
        auc_list = []

    return auc, recall, precision, correct_label, auc_list


# Baseline learning-to-defer (LTD) model classes.
class Linear_net_rej(nn.Module):
    """Single linear layer with ``out_dim + 1`` softmax outputs.

    The additional output is presumably the reject/defer option read by the
    training and evaluation routines of this file -- confirm against callers.

    Args:
        input_dim: number of input features.
        out_dim: number of class outputs (one more output is added).
    """

    def __init__(self, input_dim, out_dim):
        super(Linear_net_rej, self).__init__()
        # an affine operation: y = Wx + b
        self.fc = nn.Linear(input_dim, out_dim + 1)
        # Not used by forward(); kept so the module's parameters and
        # state_dict layout stay unchanged.
        self.fc_rej = nn.Linear(input_dim, 1)
        torch.nn.init.ones_(self.fc.weight)
        torch.nn.init.ones_(self.fc_rej.weight)
        self.softmax = nn.Softmax(-1)

    def forward(self, x):
        """Return the softmax over the ``out_dim + 1`` outputs of ``self.fc``."""
        # The previous implementation also evaluated self.fc_rej(x) and threw
        # the result away; that dead computation is dropped.
        return self.softmax(self.fc(x))


class Linear_net(nn.Module):
    """Linear layer followed by a softmax over the last dimension.

    The layer weights are initialised to ones; the bias keeps the default
    ``nn.Linear`` initialisation.
    """

    def __init__(self, input_dim, out_dim):
        super(Linear_net, self).__init__()
        # y = Wx + b, with W starting as an all-ones matrix
        self.fc1 = nn.Linear(input_dim, out_dim)
        self.softmax = nn.Softmax(-1)
        torch.nn.init.ones_(self.fc1.weight)

    def forward(self, x):
        """Return softmax(fc1(x))."""
        logits = self.fc1(x)
        return self.softmax(logits)


class Linear_net_sig(nn.Module):
    """Linear layer with a single sigmoid output.

    ``out_dim`` is accepted but not used: the layer always has one output.
    The layer weights are initialised to ones.
    """

    def __init__(self, input_dim, out_dim):
        super(Linear_net_sig, self).__init__()
        # y = Wx + b with a single output unit
        self.fc1 = nn.Linear(input_dim, 1)
        self.sigmoid = nn.Sigmoid()
        torch.nn.init.ones_(self.fc1.weight)

    def forward(self, x):
        """Return sigmoid(fc1(x))."""
        activation = self.fc1(x)
        return self.sigmoid(activation)


# Losses plus training and evaluation routines for the outcome predictors and learning-to-defer baselines.
def reject_CrossEntropyLoss(outputs, m, labels, m2, n_classes):
    """Batch-averaged cross-entropy with a reject column.

    Args:
        outputs: per-sample probabilities; column ``n_classes`` is the one
            weighted by ``m``.
        m: per-sample weight on ``-log2`` of the column-``n_classes`` entry
            (expert costs).
        labels: ground-truth class index per sample.
        m2: per-sample weight on ``-log2`` of the true-label entry.
        n_classes: number of classes, i.e. the index of the reject column.

    Returns:
        Sum of the per-sample losses divided by the batch size.
    """
    n_rows = outputs.size()[0]
    reject_prob = outputs[:, n_classes]
    # pick, for every row, the value corresponding to its label
    label_prob = outputs[range(n_rows), labels]
    per_sample = -m * torch.log2(reject_prob) - m2 * torch.log2(label_prob)
    return torch.sum(per_sample) / n_rows


def train_classifier_multiclass(net, data_x, data_y, criterion=nn.CrossEntropyLoss(),
                                optimizer=None, lr=0.1, n_epochs=50):
    """Train ``net`` with full-batch, class-weighted cross-entropy.

    Args:
        net: model called as ``net(inputs)``; its output is passed to
            ``nn.CrossEntropyLoss``.
        data_x: input tensor, one row per sample. It is NOT modified.
        data_y: one-hot label tensor, one row per sample. It is NOT modified.
        criterion: accepted for backward compatibility but ignored; the loss
            is always a ``CrossEntropyLoss`` weighted by the inverse class
            frequency of ``data_y``.
        optimizer: optimizer to use; defaults to ``Adam(net.parameters(), lr)``.
        lr: learning rate for the default optimizer.
        n_epochs: number of full-batch update steps.

    Returns:
        None. ``net`` is updated in place.
    """
    # Inverse class-frequency weights (a class with no samples yields an
    # infinite weight, as before).
    class_frequency = torch.sum(data_y, dim=0) / data_y.shape[0]
    criterion = nn.CrossEntropyLoss(weight=1. / class_frequency)
    if optimizer is None:
        optimizer = optim.Adam(net.parameters(), lr=lr)

    targets = torch.argmax(data_y, dim=1)
    n_samples = len(data_x)
    for _ in range(n_epochs):  # one full-batch step per epoch
        order = np.arange(n_samples)
        np.random.shuffle(order)
        # Index into fresh copies with the SAME permutation. The old code
        # permuted data_x in place while re-deriving labels from the untouched
        # data_y each epoch, so inputs and labels drifted apart after the
        # first epoch and the caller's data_x was left scrambled.
        inputs = data_x[order]
        labels = targets[order]

        optimizer.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()
        optimizer.step()


def train_classifier_binary(net, data_x, data_y):
    """Train ``net`` with full-batch binary cross-entropy (SGD, lr 0.1, 100 epochs).

    Args:
        net: model called as ``net(inputs)``; its output is passed to
            ``torch.nn.BCELoss`` together with the labels.
        data_x: input tensor, one row per sample. It is NOT modified.
        data_y: target tensor passed to ``BCELoss``. It is NOT modified.

    Returns:
        None. ``net`` is updated in place.
    """
    bce = torch.nn.BCELoss()
    optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0)
    n_samples = len(data_x)
    for _ in range(100):  # one full-batch step per epoch
        order = np.arange(n_samples)
        np.random.shuffle(order)
        # Shuffle copies rather than permuting the caller's tensors in place
        # (the old in-place assignment silently reordered data_x and data_y).
        inputs = data_x[order]
        labels = data_y[order]

        # forward + backward + optimize
        loss = bce(net(inputs), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def test_classifier_multiclass(net, data_x, data_y, enc_y):
    """Print the accuracy and AUC of ``net`` on a multi-class data set.

    Args:
        net: model called as ``net(data_x)``; returns per-class scores.
        data_x: input tensor, one row per sample.
        data_y: one-hot label tensor, one row per sample.
        enc_y: unused; kept for call compatibility.

    Returns:
        None. Results are printed.
    """
    with torch.no_grad():
        labels = torch.argmax(data_y, dim=1)
        outputs = net(data_x)
        _, predicted = torch.max(outputs.data, 1)
        total = labels.size(0)
        correct = (predicted == labels).sum().item()
    print('Accuracy of the network on test: %d %%' % (
            100 * correct / total))

    # One-hot encode the predictions row by row. The previous
    # `pred_onehot[:, predicted] = 1.` set entire COLUMNS to 1 for every class
    # predicted anywhere in the batch, which is not a one-hot encoding.
    pred_onehot = np.zeros(data_y.shape)
    pred_onehot[np.arange(total), predicted.cpu().numpy()] = 1.

    # NOTE(review): a softmax is applied to `outputs` here; if `net` already
    # ends in a softmax this is applied twice -- confirm this is intended.
    auc, recall, precision, correct_label, auc_list = evaluate_multiclass(
        labels=data_y, predicted_label=pred_onehot,
        predicted_probability=torch.nn.Softmax(-1)(outputs))
    print('auc:', auc, 'auc_class:', auc_list)


# Evaluation helper, not a pytest test: keep pytest from collecting it (because
# of its ``test_`` prefix) when this module is imported into a test module.
test_classifier_multiclass.__test__ = False


def test_classifier_binary(net, data_x, data_y):
    """Print the accuracy of a binary classifier.

    Predictions are ``torch.round(net(data_x))`` and are compared element by
    element with ``data_y``.

    Args:
        net: model called as ``net(data_x)``; each row of its output must hold
            a single value.
        data_x: input tensor, one row per sample.
        data_y: target tensor, one (single-valued) entry per sample.

    Returns:
        None. The accuracy is printed.
    """
    with torch.no_grad():
        outputs = net(data_x)
        predicted = torch.round(outputs.data)
        total = data_y.size(0)
        # Compare per element via .item() so the comparison does not depend on
        # predicted and data_y having broadcast-compatible shapes.
        correct = sum(int(predicted[i].item() == data_y[i].item()) for i in range(total))
    # The old message claimed "10000 test images" regardless of the data; use
    # the same wording as the multi-class evaluation helper.
    print('Accuracy of the network on test: %d %%' % (
            100 * correct / total))


# Evaluation helper, not a pytest test: keep pytest from collecting it (because
# of its ``test_`` prefix) when this module is imported into a test module.
test_classifier_binary.__test__ = False


def train_classifier_rej(net, net_exp, data_x, data_y, alpha, lr=0.01, n_epochs=100):
    """Train a classifier-with-reject network against a differentiable expert.

    Each epoch the data is shuffled, split into batches of 64 and optimised
    with SGD on ``reject_CrossEntropyLoss``.

    Args:
        net: model being trained; called as ``net(inputs)``.
        net_exp: expert model called as ``net_exp(inputs)``; the argmax of its
            output is taken as the expert's prediction.
        data_x: input tensor, one row per sample. It is NOT modified.
        data_y: one-hot label tensor, one row per sample. It is NOT modified.
        alpha: weight on the true-label term for samples the expert gets
            right (samples the expert gets wrong use weight 1).
        lr: SGD learning rate.
        n_epochs: number of passes over the data.

    Returns:
        None. ``net`` is updated in place.
    """
    optimizer = optim.SGD(net.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(data_x) * 50)

    targets = torch.argmax(data_y, dim=1)
    n_samples = len(data_x)
    for _ in range(n_epochs):  # loop over the dataset multiple times
        order = np.arange(n_samples)
        np.random.shuffle(order)
        # Apply ONE permutation to copies of inputs and labels. The old code
        # permuted data_x in place but re-derived labels from the untouched
        # data_y every epoch, so pairs were misaligned from epoch 2 onwards.
        x_batches = torch.split(data_x[order], 64)
        y_batches = torch.split(targets[order], 64)
        for inputs, labels in zip(x_batches, y_batches):
            optimizer.zero_grad()

            # m[j] == 1 where the expert predicts the true label of sample j
            _, expert_pred = torch.max(net_exp(inputs).data, 1)
            m = ((expert_pred == labels) * 1).clone().detach()
            m2 = torch.tensor([alpha if hit else 1 for hit in m])

            outputs = net(inputs)
            # NOTE(review): the reject column index is hard-coded to 2, i.e.
            # this assumes two classes -- confirm against callers.
            loss = reject_CrossEntropyLoss(outputs, m, labels, m2, 2)
            loss.backward()
            optimizer.step()
            scheduler.step()


def train_classifier_rej_act(net, net_exp, data_x, data_y, alpha, env, data_x_unencoded=None, lr=0.01, n_epochs=100,
                             func=True, data_t=None):
    """Train a classifier-with-reject network against a tabular or callable expert.

    Each epoch the data is shuffled, split into batches of 64 and optimised
    with SGD on ``reject_CrossEntropyLoss``.

    How the expert is queried:
        * ``func`` false: ``net_exp`` is indexed. For env ``"discrete_toy"`` /
          ``"diabetes"`` as ``net_exp[t][state]``, otherwise as
          ``net_exp[state]``, where ``state`` is the argmax of the input row.
        * ``func`` true, ``data_t`` given: ``net_exp(x, t)`` for env
          ``"sepsis_diabetes"`` / ``"hiv"``, ``net_exp(x)`` for
          ``"randomwalk"``; any other env raises ``ValueError``.
        * ``func`` true, ``data_t`` None: ``net_exp(x[0])`` per row.
    A 1-D or single-column expert output ``p`` is expanded to ``(1 - p, p)``.

    Args:
        net: model being trained; called as ``net(inputs)``.
        net_exp: expert, a lookup table or a callable (see above).
        data_x: input tensor, one row per sample. It is NOT modified.
        data_y: one-hot label tensor. It is NOT modified.
        alpha: weight on the true-label term for samples the expert gets
            right (samples the expert gets wrong use weight 1).
        env: environment name selecting how the expert is queried.
        data_x_unencoded: unused by this function; kept for call
            compatibility. It is NOT modified.
        lr: SGD learning rate.
        n_epochs: number of passes over the data.
        func: whether ``net_exp`` is callable (True) or a lookup table.
        data_t: optional per-sample time information; the callable branches
            read its first column. It is NOT modified.

    Raises:
        ValueError: if ``func`` is true, ``data_t`` is given and ``env`` is
            not one of the supported names.
    """
    optimizer = optim.SGD(net.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(data_x) * 50)

    targets = torch.argmax(data_y, dim=1)
    n_samples = len(data_x)
    for _ in range(n_epochs):  # loop over the dataset multiple times
        order = np.arange(n_samples)
        np.random.shuffle(order)
        # Apply ONE permutation to copies of inputs, labels and times. The old
        # code permuted data_x / data_t in place while re-deriving labels from
        # the untouched data_y, misaligning the pairs from epoch 2 onwards.
        x_batches = torch.split(data_x[order], 64)
        y_batches = torch.split(targets[order], 64)
        if data_t is not None:
            times = torch.as_tensor(data_t[order], dtype=torch.get_default_dtype())
            t_batches = torch.split(times, 64)
        else:
            t_batches = x_batches

        for inputs, labels, t_batch in zip(x_batches, y_batches, t_batches):
            optimizer.zero_grad()

            # --- query the expert -------------------------------------------
            if not func:
                _, states = torch.max(inputs, 1)
                if env == "discrete_toy" or env == "diabetes":
                    m = np.array([net_exp[int(t)][int(i)] for i, t in zip(states, t_batch)])
                else:
                    m = np.array([net_exp[int(i)] for i in states])
            elif data_t is not None:
                if env == "sepsis_diabetes" or env == "hiv":
                    m = np.array([net_exp(i, t) for i, t in
                                  zip(inputs.detach().numpy(), t_batch[:, 0].detach().numpy())])
                elif env == "randomwalk":
                    m = np.array([net_exp(i) for i, t in zip(inputs, t_batch[:, 0])])
                else:
                    # Previously this fell through and failed later with an
                    # unbound local variable.
                    raise ValueError(
                        "train_classifier_rej_act: env %r is not supported with func=True and data_t given" % (env,))
            else:
                m = np.array([net_exp(i) for i in inputs[:, 0]])

            # --- bring the expert output into (n, n_classes) form -----------
            if m.ndim == 1:
                m = torch.Tensor(np.vstack((1 - m, m))).T
            elif m.shape[1] == 1:
                m = torch.Tensor(np.vstack((1 - m[:, 0], m[:, 0]))).T
            else:
                m = torch.Tensor(m)

            # m[j] == 1 where the expert predicts the true label of sample j
            _, expert_pred = torch.max(m.data, 1)
            m = ((expert_pred == labels) * 1).clone().detach()
            m2 = torch.tensor([alpha if hit else 1 for hit in m])

            outputs = net(inputs)
            # NOTE(review): the reject column index is hard-coded to 2, i.e.
            # this assumes two classes -- confirm against callers.
            loss = reject_CrossEntropyLoss(outputs, m, labels, m2, 2)
            loss.backward()
            optimizer.step()
            scheduler.step()


def test_classifier_rej_act(net, net_exp, data_x, data_y, env, data_x_unencoded=None, data_t=None, func=False):
    """Evaluate a classifier-with-reject network together with its expert.

    A sample is deferred to the expert when the argmax of ``net``'s output is
    index 2; otherwise the classifier's own prediction is used.

    How the expert is queried:
        * ``func`` false: ``net_exp[t][state]`` for env ``"discrete_toy"`` /
          ``"diabetes"`` (``t`` from the first column of ``data_t``),
          otherwise ``net_exp[state]``; ``state`` is the argmax of the row.
        * ``func`` true, ``data_t`` given: ``net_exp(x, t)`` for env
          ``"sepsis_diabetes"`` / ``"hiv"``, ``net_exp(x)`` for
          ``"randomwalk"``; any other env raises ``ValueError``.
        * ``func`` true, ``data_t`` None: ``net_exp(x[0])`` per row.
    A 1-D or single-column expert output ``p`` is expanded to ``(1 - p, p)``.

    Args:
        net: model called as ``net(data_x)``.
        net_exp: expert, a lookup table or a callable (see above).
        data_x: input tensor, one row per sample.
        data_y: one-hot label tensor.
        env: environment name selecting how the expert is queried.
        data_x_unencoded: unused; kept for call compatibility.
        data_t: optional per-sample time information.
        func: whether ``net_exp`` is callable (True) or a lookup table.

    Returns:
        ``([coverage %, system accuracy %, expert accuracy %, classifier
        accuracy %], score, outputs)`` where ``score`` is a float array with
        1.0 for every sample the system got right and ``outputs`` is
        ``net(data_x)``.

    Raises:
        ValueError: if ``func`` is true, ``data_t`` is given and ``env`` is
            not one of the supported names.
    """
    # Never updated in this function, so "alone classifier" is always printed as 0.
    alone_correct = 0
    with torch.no_grad():
        inputs = data_x
        labels = torch.argmax(data_y, dim=1)

        # --- query the expert -----------------------------------------------
        if not func:
            _, states = torch.max(inputs, 1)
            # int(...) so that dict- and array-like tables work too, matching
            # the lookups done by the training routine.
            if env == "discrete_toy" or env == "diabetes":
                m = np.array([net_exp[int(t)][int(i)] for i, t in zip(states, data_t[:, 0])])
            else:
                m = np.array([net_exp[int(i)] for i in states])
        elif data_t is not None:
            if env == "sepsis_diabetes" or env == "hiv":
                m = np.array([net_exp(i, t) for i, t in zip(data_x.detach().numpy(), data_t[:, 0])])
            elif env == "randomwalk":
                m = np.array([net_exp(i) for i, t in zip(inputs, data_t[:, 0])])
            else:
                # Previously this fell through and failed later with an
                # unbound local variable.
                raise ValueError(
                    "test_classifier_rej_act: env %r is not supported with func=True and data_t given" % (env,))
        else:
            m = np.array([net_exp(i) for i in inputs[:, 0]])

        # --- bring the expert output into (n, n_classes) form ---------------
        if m.ndim == 1:
            m = torch.Tensor(np.vstack((1 - m, m))).T
        elif m.shape[1] == 1:
            m = torch.Tensor(np.vstack((1 - m[:, 0], m[:, 0]))).T
        else:
            m = torch.Tensor(m)

        _, predicted_exp = torch.max(m.data, 1)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)

        # --- tally who handled each sample and whether it was right ---------
        deferred = predicted == 2  # index 2 is the reject option
        expert_hit = predicted_exp == labels
        classifier_hit = predicted == labels
        system_hit = (deferred & expert_hit) | (~deferred & classifier_hit)

        exp = int((deferred & expert_hit).sum().item())
        correct = int((~deferred & classifier_hit).sum().item())
        correct_sys = exp + correct
        exp_total = int(deferred.sum().item())
        total = len(inputs) - exp_total
        score = system_hit.cpu().numpy().astype(np.float64)
        real_total = labels.size(0)

    cov = str(total) + str(" out of") + str(real_total)
    # The small constants keep the ratios defined when a group is empty.
    to_print = {"coverage": cov, "system accuracy": 100 * correct_sys / real_total,
                "expert accuracy": 100 * exp / (exp_total + 0.0002),
                "classifier accuracy": 100 * correct / (total + 0.0001),
                "alone classifier": 100 * alone_correct / real_total}
    print(to_print)
    return [100 * total / real_total, 100 * correct_sys / real_total, 100 * exp / (exp_total + 0.0002),
            100 * correct / (total + 0.0001)], score, outputs


# Evaluation helper, not a pytest test: keep pytest from collecting it (because
# of its ``test_`` prefix) when this module is imported into a test module.
test_classifier_rej_act.__test__ = False


def test_classifier_rej(net, net_exp, data_x, data_y):
    """Evaluate a classifier-with-reject network together with an expert model.

    A sample is deferred to the expert when the argmax of ``net``'s output is
    index 2; otherwise the classifier's own prediction is used.

    Args:
        net: model called as ``net(data_x)``.
        net_exp: expert model called as ``net_exp(data_x)``; the argmax of its
            output is the expert's prediction.
        data_x: input tensor, one row per sample.
        data_y: one-hot label tensor.

    Returns:
        ``([coverage %, system accuracy %, expert accuracy %, classifier
        accuracy %], score)`` where ``score`` is a float array with 1.0 for
        every sample the system (classifier or expert) got right.
    """
    # Never updated in this function, so "alone classifier" is always printed as 0.
    alone_correct = 0
    with torch.no_grad():
        labels = torch.argmax(data_y, dim=1)
        _, predicted_exp = torch.max(net_exp(data_x).data, 1)
        outputs = net(data_x)
        _, predicted = torch.max(outputs.data, 1)

        # Tally with boolean masks instead of a per-sample .item() loop.
        deferred = predicted == 2  # index 2 is the reject option
        expert_hit = predicted_exp == labels
        classifier_hit = predicted == labels
        system_hit = (deferred & expert_hit) | (~deferred & classifier_hit)

        exp = int((deferred & expert_hit).sum().item())
        correct = int((~deferred & classifier_hit).sum().item())
        correct_sys = exp + correct
        exp_total = int(deferred.sum().item())
        total = len(data_x) - exp_total
        score = system_hit.cpu().numpy().astype(np.float64)
        real_total = labels.size(0)

    cov = str(total) + str(" out of") + str(real_total)
    # The small constants keep the ratios defined when a group is empty.
    to_print = {"coverage": cov, "system accuracy": 100 * correct_sys / real_total,
                "expert accuracy": 100 * exp / (exp_total + 0.0002),
                "classifier accuracy": 100 * correct / (total + 0.0001),
                "alone classifier": 100 * alone_correct / real_total}
    print(to_print)
    return [100 * total / real_total, 100 * correct_sys / real_total, 100 * exp / (exp_total + 0.0002),
            100 * correct / (total + 0.0001)], score


# Evaluation helper, not a pytest test: keep pytest from collecting it (because
# of its ``test_`` prefix) when this module is imported into a test module.
test_classifier_rej.__test__ = False


def sample_gumbel(shape, eps=1e-20):
    """Draw a tensor of the given shape as ``-log(-log(U + eps) + eps)``, U ~ Uniform[0, 1).

    ``eps`` keeps both logarithms away from ``log(0)``.
    """
    uniform = torch.rand(shape)
    inner_log = torch.log(uniform + eps)
    return -torch.log(eps - inner_log)


def gumbel_softmax_sample(logits, temperature):
    """Add Gumbel noise to ``logits`` and return the temperature-scaled softmax over the last dim."""
    noisy_logits = logits + sample_gumbel(logits.size())
    scaled = noisy_logits / temperature
    return F.softmax(scaled, dim=-1)


def gumbel_softmax(logits, temperature, hard=False):
    """
    input: [*, n_class]
    return: [*, n_class] the sample produced by gumbel_softmax_sample

    ``hard`` is accepted but not used: the sample is returned as is, without
    any one-hot discretisation.
    """
    return gumbel_softmax_sample(logits, temperature)


def gumbel_binary_sample(logits, t=0.5, eps=1e-20):
    """Draw a relaxed binary sample using Gumbel noise.

    ``logits`` enters through ``log2(logits + eps)`` and
    ``log2(1 - logits + eps)``, each perturbed by independent Gumbel noise and
    divided by the temperature ``t``; the result is the normalised weight of
    the first ("on") term.
    """
    noise_on = sample_gumbel(logits.size())
    noise_off = sample_gumbel(logits.size())
    on_score = (torch.log2(logits + eps) + noise_on) / t
    off_score = (torch.log2(1 - logits + eps) + noise_off) / t
    on_weight = torch.exp(on_score)
    off_weight = torch.exp(off_score)
    return torch.div(on_weight, on_weight + off_weight)


class Linear_net_madras_class(nn.Module):
    """Linear layer followed by a softmax over the last dimension.

    Note that the attribute holding the softmax is named ``sigmoid``. With the
    default ``out_dim=1`` the softmax runs over a single value, so every
    output is 1.
    """

    def __init__(self, input_dim, out_dim=1):
        super(Linear_net_madras_class, self).__init__()
        self.fc1 = nn.Linear(input_dim, out_dim)
        # Despite its name this attribute is a softmax module.
        self.sigmoid = nn.Softmax(-1)

    def forward(self, x):
        """Return softmax(fc1(x))."""
        logits = self.fc1(x)
        return self.sigmoid(logits)


class Linear_net_madras_rej(nn.Module):
    """Two-output rejector that sees the input and the classifier's output.

    ``x`` and ``y_hat`` are concatenated along dimension 1, passed through a
    linear layer with two outputs and a softmax over the last dimension. Note
    that the attribute holding the softmax is named ``sigmoid``.
    """

    def __init__(self, input_dim, class_dim=2):
        super(Linear_net_madras_rej, self).__init__()
        # y = Wx + b over the concatenated [x, y_hat] vector
        self.fc1 = nn.Linear(input_dim + class_dim, 2)
        # Despite its name this attribute is a softmax module.
        self.sigmoid = nn.Softmax(-1)

    def forward(self, x, y_hat):
        """Return softmax(fc1([x, y_hat]))."""
        joined = torch.cat((x, y_hat), 1)
        return self.sigmoid(self.fc1(joined))


def madras_loss_original(outputs, rej, labels, expert, eps=10e-12, defer_cost=0.01):
    """Batch-averaged mixture of classifier loss and expert loss.

    Per sample the loss is
    ``rej[:, 0] * classifier_loss + rej[:, 1] * expert_loss`` where each loss
    is ``-log2`` of the probability assigned to the true label and the expert
    loss additionally pays ``defer_cost``.

    Args:
        outputs: classifier probabilities, one row per sample.
        rej: rejector output; column 0 weights the classifier, column 1 the
            expert.
        labels: ground-truth class index per sample.
        expert: expert probabilities, one row per sample.
        eps: constant added inside the logarithms to avoid ``log2(0)``.
        defer_cost: constant added to the expert loss.

    Returns:
        Sum of the per-sample losses divided by the batch size.
    """
    batch_size = outputs.size()[0]
    rows = range(batch_size)
    net_loss = -torch.log2(outputs[rows, labels] + eps)
    exp_loss = -torch.log2(expert[rows, labels] + eps) + defer_cost
    # The previous implementation also drew a Gumbel sample from `rej` here
    # and never used it; that dead computation (and its random draws) is gone.
    system_loss = rej[rows, 0] * net_loss + rej[rows, 1] * exp_loss
    return torch.sum(system_loss) / batch_size


def train_classifier_madras_original(net_class, net_rej, net_exp, data_x, data_y, lr=0.01, n_epochs=30):
    """Jointly train a classifier and a rejector against an expert model.

    Each epoch the data is shuffled and split into batches of 64. The rejector
    is trained on ``madras_loss_original`` using a detached copy of the
    classifier output; the classifier is trained on its own ``-log2``
    likelihood of the true label.

    Args:
        net_class: classifier called as ``net_class(inputs)``.
        net_rej: rejector called as ``net_rej(inputs, classifier_output)``.
        net_exp: expert called as ``net_exp(inputs)``.
        data_x: input tensor, one row per sample. It is NOT modified.
        data_y: one-hot label tensor. It is NOT modified.
        lr: SGD learning rate for both optimizers.
        n_epochs: number of passes over the data.

    Returns:
        None. ``net_class`` and ``net_rej`` are updated in place.
    """
    optimizer_class = optim.SGD(net_class.parameters(), lr=lr)
    optimizer_rej = optim.SGD(net_rej.parameters(), lr=lr)

    targets = torch.argmax(data_y, 1)
    n_samples = len(data_x)
    for _ in range(n_epochs):  # loop over the dataset multiple times
        order = np.arange(n_samples)
        np.random.shuffle(order)
        # Apply ONE permutation to copies of inputs and labels. The old code
        # permuted data_x in place but re-derived labels from the untouched
        # data_y every epoch, so pairs were misaligned from epoch 2 onwards.
        x_batches = torch.split(data_x[order], 64)
        y_batches = torch.split(targets[order], 64)
        for inputs, labels in zip(x_batches, y_batches):
            optimizer_class.zero_grad()
            optimizer_rej.zero_grad()

            expert_prediction = net_exp(inputs)
            outputs = net_class(inputs)
            # The rejector must not back-propagate into the classifier.
            outputs_no_grad = outputs.detach()
            batch_size = outputs.size()[0]

            rej = net_rej(inputs, outputs_no_grad)
            loss_rej = madras_loss_original(outputs_no_grad, rej, labels, expert_prediction)
            loss_rej.backward()

            loss_class = -torch.log2(outputs[range(batch_size), labels] + 10e-12)
            loss_class = torch.sum(loss_class) / batch_size
            loss_class.backward()

            optimizer_class.step()
            optimizer_rej.step()


def train_classifier_madras_original_act(net_class, net_rej, net_exp, data_x, data_y, env, data_x_unencoded=None,
                                         data_t=None, lr=0.01,
                                         n_epochs=30,
                                         func=False, train_net=False, defer_cost=0.01):
    """Train a rejector (and optionally the classifier) against a tabular or callable expert.

    Each epoch the data is shuffled and split into batches of 64. The rejector
    is trained on ``madras_loss_original``; when ``train_net`` is true the
    classifier is also trained on its ``-log2`` likelihood of the true label.

    How ``net_exp`` (and ``net_class`` when ``train_net`` is false) is queried:
        * ``func`` false: by lookup, ``table[t][state]`` for env
          ``"discrete_toy"`` / ``"diabetes"``, otherwise ``table[state]``;
          ``state`` is the argmax of the input row.
        * ``func`` true: by call, with arguments chosen per ``env`` and per
          whether ``data_t`` is given (see the branches below).
    A 1-D or single-column result ``p`` is expanded to ``(1 - p, p)``.

    Args:
        net_class: classifier; a module when ``train_net`` is true, otherwise
            a lookup table or callable.
        net_rej: rejector called as ``net_rej(inputs, classifier_output)``.
        net_exp: expert, a lookup table or a callable.
        data_x: input tensor, one row per sample. It is NOT modified.
        data_y: one-hot label tensor. It is NOT modified.
        env: environment name selecting how expert/classifier are queried.
        data_x_unencoded: optional tensor of unencoded inputs, row-aligned
            with ``data_x``; falls back to ``data_x`` batches when None. It is
            NOT modified.
        data_t: optional per-sample time information. It is NOT modified.
        lr: SGD learning rate.
        n_epochs: number of passes over the data.
        func: whether expert/classifier are callables (True) or tables.
        train_net: whether ``net_class`` is a trainable module.
        defer_cost: constant cost added to the expert loss.

    Returns:
        None. ``net_rej`` (and ``net_class`` if trained) are updated in place.
    """

    def _as_two_columns(values):
        # Expand a vector (or single column) p to columns (1 - p, p);
        # anything wider is converted to a tensor unchanged.
        if values.ndim == 1:
            return torch.Tensor(np.vstack((1 - values, values))).T
        if values.shape[1] == 1:
            return torch.Tensor(np.vstack((1 - values[:, 0], values[:, 0]))).T
        return torch.Tensor(values)

    optimizer_class = optim.SGD(net_class.parameters(), lr=lr) if train_net else None
    optimizer_rej = optim.SGD(net_rej.parameters(), lr=lr)

    discrete_env = env == "discrete_toy" or env == "diabetes"
    timed_env = env == "sepsis_diabetes" or env == "hiv"
    targets = torch.argmax(data_y, 1)
    n_samples = len(data_x)

    for _ in range(n_epochs):  # loop over the dataset multiple times
        order = np.arange(n_samples)
        np.random.shuffle(order)
        # Apply ONE permutation to copies of every array. The old code
        # permuted data_x / data_t / data_x_unencoded in place while labels
        # were re-derived from the untouched data_y each epoch, misaligning
        # inputs and labels from epoch 2 onwards.
        x_batches = torch.split(data_x[order], 64)
        y_batches = torch.split(targets[order], 64)
        if data_t is not None:
            shuffled_t = torch.as_tensor(data_t[order], dtype=torch.get_default_dtype())
            t_batches = torch.split(shuffled_t, 64)
        else:
            t_batches = x_batches
        if data_x_unencoded is not None:
            unenc_batches = torch.split(data_x_unencoded[order], 64)
        else:
            unenc_batches = x_batches

        for inputs, labels, times, inputs_unenc in zip(x_batches, y_batches, t_batches, unenc_batches):
            if train_net:
                optimizer_class.zero_grad()
            optimizer_rej.zero_grad()

            # --- expert prediction ------------------------------------------
            if not func:
                _, states = torch.max(inputs, 1)
                if discrete_env:
                    expert_prediction = np.array([net_exp[int(t)][int(i)] for i, t in zip(states, times)])
                else:
                    expert_prediction = np.array([net_exp[int(i)] for i in states])
            elif data_t is not None:
                if timed_env:
                    expert_prediction = np.array([net_exp(i, t) for i, t in
                                                  zip(inputs.detach().numpy(), times[:, 0].detach().numpy())])
                elif env == "randomwalk":
                    expert_prediction = np.array([net_exp(i) for i in inputs_unenc.detach().numpy()])
                else:
                    expert_prediction = np.array([net_exp(i, t) for i, t in
                                                  zip(inputs_unenc.detach().numpy(),
                                                      times[:, 0].detach().numpy())])
            else:
                expert_prediction = np.array([net_exp(i) for i in inputs[:, 0]])
            expert_prediction = _as_two_columns(expert_prediction)

            # --- classifier output ------------------------------------------
            if train_net:
                outputs = net_class(inputs)
            else:
                if not func:
                    _, states = torch.max(inputs, 1)
                    if discrete_env:
                        outputs = np.array([net_class[int(t)][int(i)] for i, t in zip(states, times[:, 0])])
                    else:
                        outputs = np.array([net_class[int(i)] for i in states])
                elif data_t is not None:
                    if timed_env:
                        outputs = np.array([net_class(i, t) for i, t in
                                            zip(inputs.detach().numpy(), times[:, 0].detach().numpy())])
                    elif env == "randomwalk":
                        outputs = np.array([net_class(i) for i in inputs.detach().numpy()])
                    else:
                        outputs = np.array([net_class(i, t) for i, t in
                                            zip(inputs_unenc.detach().numpy()[:, 0],
                                                times[:, 0].detach().numpy())])
                else:
                    outputs = np.array([net_class(i) for i in inputs_unenc.detach().numpy()[:, 0]])
                outputs = _as_two_columns(outputs)

            # The rejector must not back-propagate into the classifier.
            outputs_no_grad = outputs.detach()
            batch_size = outputs.size()[0]

            rej = net_rej(inputs, outputs_no_grad)

            # defer_cost must be passed by keyword: the fifth positional
            # parameter of madras_loss_original is `eps`, so the old positional
            # call silently used defer_cost as eps and ignored the cost.
            loss_rej = madras_loss_original(outputs_no_grad, rej, labels, expert_prediction,
                                            defer_cost=defer_cost)
            loss_rej.backward()

            if train_net:
                loss_class = -torch.log2(outputs[range(batch_size), labels] + 10e-15)
                loss_class = torch.sum(loss_class) / batch_size
                loss_class.backward()
                optimizer_class.step()
            optimizer_rej.step()


def test_classifier_madras_original(net_class, net_rej, net_exp, data_x, data_y):
    """Evaluate a classifier + rejector pair together with an expert model.

    A sample is deferred to the expert when the rejector's second output is
    at least 0.5; otherwise the classifier's prediction is used.

    Args:
        net_class: classifier called as ``net_class(data_x)``.
        net_rej: rejector called as ``net_rej(data_x, classifier_output)``.
        net_exp: expert called as ``net_exp(data_x)``.
        data_x: input tensor, one row per sample.
        data_y: one-hot label tensor.

    Returns:
        ``([coverage %, system accuracy %, expert accuracy %, classifier
        accuracy %], score)`` where ``score`` is a float array with 1.0 for
        every sample the system (classifier or expert) got right.
    """
    with torch.no_grad():
        labels = torch.argmax(data_y, dim=1)
        _, predicted_exp = torch.max(net_exp(data_x).data, 1)
        outputs = net_class(data_x)
        _, predicted = torch.max(outputs.data, 1)
        rej = net_rej(data_x, outputs)

        deferred = rej[:, 1] >= 0.5
        expert_hit = predicted_exp == labels
        classifier_hit = predicted == labels
        # Score each sample with whoever actually handled it. The old loop
        # scored classifier-handled samples with the EXPERT's prediction.
        system_hit = (deferred & expert_hit) | (~deferred & classifier_hit)

        exp = int((deferred & expert_hit).sum().item())
        correct = int((~deferred & classifier_hit).sum().item())
        correct_sys = exp + correct
        exp_total = int(deferred.sum().item())
        total = len(data_x) - exp_total
        alone_correct = int(classifier_hit.sum().item())
        score = system_hit.cpu().numpy().astype(np.float64)
        real_total = labels.size(0)

    cov = str(total) + str(" out of") + str(real_total)
    # The small constants keep the ratios defined when a group is empty.
    to_print = {"coverage": cov, "system accuracy": 100 * correct_sys / real_total,
                "expert accuracy": 100 * exp / (exp_total + 0.0002),
                "classifier accuracy": 100 * correct / (total + 0.0001),
                "alone classifier": 100 * alone_correct / real_total}
    print(to_print)
    return [100 * total / real_total, 100 * correct_sys / real_total, 100 * exp / (exp_total + 0.0002),
            100 * correct / (total + 0.0001)], score


# Evaluation helper, not a pytest test: keep pytest from collecting it (because
# of its ``test_`` prefix) when this module is imported into a test module.
test_classifier_madras_original.__test__ = False


def test_classifier_madras_original_act(net_class, net_rej, net_exp, data_x, data_y, env, data_x_unencoded=None,
                                        data_t=None, func=False, train_net=False, defer_cost=0.01):
    """Evaluate a classifier + rejector pair with a tabular or callable expert.

    A sample is deferred to the expert when the rejector's second output is
    at least 0.5; otherwise the classifier's prediction is used.

    How ``net_exp`` (and ``net_class`` when ``train_net`` is false) is queried:
        * ``func`` false: by lookup, ``table[t][state]`` for env
          ``"discrete_toy"`` / ``"diabetes"``, otherwise ``net_exp[state]``
          (the classifier lookup is not implemented for other envs and raises
          ``ValueError``); ``state`` is the argmax of the input row.
        * ``func`` true: by call, with arguments chosen per ``env`` and per
          whether ``data_t`` is given (see the branches below).
    A 1-D or single-column result ``p`` is expanded to ``(1 - p, p)``, the
    same column order the training routine uses.

    Args:
        net_class: classifier; a module when ``train_net`` is true, otherwise
            a lookup table or callable.
        net_rej: rejector called as ``net_rej(data_x, classifier_output)``.
        net_exp: expert, a lookup table or a callable.
        data_x: input tensor, one row per sample.
        data_y: one-hot label tensor.
        env: environment name selecting how expert/classifier are queried.
        data_x_unencoded: optional tensor of unencoded inputs.
        data_t: optional per-sample time information.
        func: whether expert/classifier are callables (True) or tables.
        train_net: whether ``net_class`` is a module to call on ``data_x``.
        defer_cost: unused; kept for call compatibility.

    Returns:
        ``([coverage %, system accuracy %, expert accuracy %, classifier
        accuracy %], score, outputs)`` where ``score`` is a float array with
        1.0 for every sample the system got right and ``outputs`` is the
        classifier output.

    Raises:
        ValueError: if ``train_net`` and ``func`` are both false and ``env``
            is not ``"discrete_toy"`` / ``"diabetes"``.
    """

    def _as_two_columns(values):
        # Expand a vector (or single column) p to columns (1 - p, p). The
        # training routine builds (1 - p, p); this function used to build
        # (p, 1 - p), which flipped every argmax taken from scalar outputs.
        if values.ndim == 1:
            return torch.Tensor(np.vstack((1 - values, values))).T
        if values.shape[1] == 1:
            return torch.Tensor(np.vstack((1 - values[:, 0], values[:, 0]))).T
        return torch.Tensor(values)

    discrete_env = env == "discrete_toy" or env == "diabetes"
    timed_env = env == "sepsis_diabetes" or env == "hiv"

    with torch.no_grad():
        inputs = data_x
        times = data_t
        labels = torch.argmax(data_y, dim=1)

        # --- expert prediction ----------------------------------------------
        if not func:
            _, states = torch.max(inputs, 1)
            if discrete_env:
                m = np.array([net_exp[int(t)][int(i)] for i, t in zip(states, times[:, 0])])
            else:
                # Index the expert table by state, as the training routine
                # does; the old code indexed it by `times` and left `states`
                # unused.
                m = np.array([net_exp[int(i)] for i in states])
        elif data_t is not None:
            if timed_env:
                m = np.array([net_exp(i, t) for i, t in zip(data_x.detach().numpy(), data_t[:, 0])])
            elif env == "randomwalk":
                m = np.array([net_exp(i) for i, t in zip(data_x.detach().numpy(), data_t[:, 0])])
            else:
                m = np.array([net_exp(i, t) for i, t in zip(data_x_unencoded.detach().numpy(), data_t[:, 0])])
        else:
            m = np.array([net_exp(i) for i in data_x_unencoded.detach().numpy()])
        m = _as_two_columns(m)
        _, predicted_exp = torch.max(m.data, 1)

        # --- classifier output ----------------------------------------------
        if train_net:
            outputs = net_class(inputs)
        else:
            if not func:
                if discrete_env:
                    _, states = torch.max(inputs, 1)
                    outputs = np.array([net_class[int(t)][int(i)] for i, t in zip(states, times[:, 0])])
                else:
                    raise ValueError("Mozannar ltd: not implemented for this setting")
            elif data_t is not None:
                if timed_env:
                    outputs = np.array(
                        [net_class(i, t) for i, t in zip(data_x.detach().numpy(), data_t[:, 0])])
                elif env == "randomwalk":
                    outputs = np.array([net_class(i) for i in inputs])
                else:
                    outputs = np.array([net_class(i, t) for i, t in zip(inputs, data_t[:, 0])])
            else:
                outputs = np.array([net_class(i) for i in inputs[:, 0]])
            outputs = _as_two_columns(outputs)

        _, predicted = torch.max(outputs.data, 1)
        rej = net_rej(inputs, outputs)

        # --- tally who handled each sample and whether it was right ---------
        deferred = rej[:, 1] >= 0.5
        expert_hit = predicted_exp == labels
        classifier_hit = predicted == labels
        # Score each sample with whoever actually handled it. The old loop
        # scored classifier-handled samples with the EXPERT's prediction.
        system_hit = (deferred & expert_hit) | (~deferred & classifier_hit)

        exp = int((deferred & expert_hit).sum().item())
        correct = int((~deferred & classifier_hit).sum().item())
        correct_sys = exp + correct
        exp_total = int(deferred.sum().item())
        total = len(inputs) - exp_total
        alone_correct = int(classifier_hit.sum().item())
        score = system_hit.cpu().numpy().astype(np.float64)
        real_total = labels.size(0)

    cov = str(total) + str(" out of") + str(real_total)
    # The small constants keep the ratios defined when a group is empty.
    to_print = {"coverage": cov, "system accuracy": 100 * correct_sys / real_total,
                "expert accuracy": 100 * exp / (exp_total + 0.0002),
                "classifier accuracy": 100 * correct / (total + 0.0001),
                "alone classifier": 100 * alone_correct / real_total}
    print(to_print)
    return [100 * total / real_total, 100 * correct_sys / real_total, 100 * exp / (exp_total + 0.0002),
            100 * correct / (total + 0.0001)], score, outputs


# Evaluation helper, not a pytest test: keep pytest from collecting it (because
# of its ``test_`` prefix) when this module is imported into a test module.
test_classifier_madras_original_act.__test__ = False
