from __future__ import print_function
import torch
import os, sys
import pickle as pkl
from baselines.baseline_utils import Linear_net_rej, Linear_net, train_classifier_rej, test_classifier_rej, \
    train_classifier_multiclass, train_classifier_rej_act, test_classifier_rej_act, \
    train_classifier_madras_original, test_classifier_madras_original, test_classifier_multiclass, \
    Linear_net_madras_class, Linear_net_madras_rej, \
    train_classifier_madras_original_act, test_classifier_madras_original_act
from sklearn.preprocessing import OneHotEncoder, StandardScaler
import numpy as np

class mozannar_ltd:
    """Learning-to-defer baseline pairing a rejector network with a learned expert network.

    Call order suggested by the methods below: ``encode_data`` -> ``train_expert`` ->
    ``train_model`` -> ``eval``.
    """

    # change this depending on what experiments we want to run
    def __init__(self, input_dim, output_dim, alpha=0.5):
        """
        :param input_dim: input width of both networks; also the number of one-hot state columns
        :param output_dim: output width handed to both networks
        :param alpha: weight forwarded to ``train_classifier_rej``
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = Linear_net_rej(input_dim=input_dim, out_dim=output_dim)
        self.expert_model = Linear_net(input_dim=input_dim, out_dim=output_dim)
        self.alpha = alpha
        # Raw (un-encoded) inputs/labels, filled by encode_data.
        self.x = None
        self.y = None
        self.x_test = None
        self.y_test = None
        # Encoded tensors, filled by encode_data.
        self.data_y = None
        self.data_x = None
        self.data_y_test = None
        self.data_x_test = None
        self.enc_x = None
        self.enc_y = None
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.scaler = None

    @staticmethod
    def _make_label_encoder(**kwargs):
        """Build a dense-output ``OneHotEncoder`` on both old and new scikit-learn.

        scikit-learn 1.2 renamed ``sparse`` to ``sparse_output`` and 1.4 removed the old
        keyword, so the new spelling is tried first and the old one is the fallback.
        """
        try:
            return OneHotEncoder(sparse_output=False, **kwargs)
        except TypeError:
            return OneHotEncoder(sparse=False, **kwargs)

    def _one_hot_states(self, states):
        """Return an ``(n, input_dim)`` array with a single 1 per row at column ``int(states[i])``."""
        n_rows = states.shape[0]
        columns = np.fromiter((int(s) for s in states), dtype=np.intp, count=n_rows)
        encoded = np.zeros((n_rows, self.input_dim))
        encoded[np.arange(n_rows), columns] = 1
        return encoded

    def encode_data(self, data_x, data_y, data_x_test, data_y_test, one_hot=True):
        """
        Store the raw arrays and build the tensors used for training and testing.

        Labels are always one-hot encoded (the encoder is fitted on ``data_y`` and re-used for
        ``data_y_test``): rewards are assumed discrete (+1, 0, -1) and treated as classes.

        :param data_x: training inputs (numpy array); integer state indices when ``one_hot`` is True
        :param data_y: training labels, batch x 1
        :param data_x_test: test inputs, same layout as ``data_x``
        :param data_y_test: test labels, same layout as ``data_y``
        :param one_hot: if True, one-hot encode the inputs into ``input_dim`` columns; otherwise
            standardise them with a ``StandardScaler`` fitted on ``data_x``
        """
        self.x = data_x
        self.y = data_y
        self.x_test = data_x_test
        self.y_test = data_y_test
        if one_hot:
            self.data_x = torch.Tensor(self._one_hot_states(self.x))
            self.enc_y = self._make_label_encoder()
            self.data_y = torch.Tensor(self.enc_y.fit_transform(data_y))
            self.data_x_test = torch.Tensor(self._one_hot_states(self.x_test))
            self.data_y_test = torch.Tensor(self.enc_y.transform(data_y_test))
        else:
            self.scaler = StandardScaler()
            self.data_x = torch.Tensor(self.scaler.fit_transform(data_x))
            self.enc_y = self._make_label_encoder()
            self.data_y = torch.Tensor(self.enc_y.fit_transform(data_y))
            self.data_x_test = torch.Tensor(self.scaler.transform(data_x_test))
            self.data_y_test = torch.Tensor(self.enc_y.transform(data_y_test))

    def train_expert(self, lr=0.1, n_epochs=100):
        """Train the expert network on the encoded training data, then run the multiclass test on the test data.

        :param lr: learning rate forwarded to ``train_classifier_multiclass``
        :param n_epochs: number of epochs forwarded to ``train_classifier_multiclass``
        """
        train_classifier_multiclass(net=self.expert_model, data_x=self.data_x, data_y=self.data_y, lr=lr,
                                    n_epochs=n_epochs)
        test_classifier_multiclass(net=self.expert_model, data_x=self.data_x_test, data_y=self.data_y_test,
                                   enc_y=self.enc_y)

    def train_model(self, lr=0.01, n_epochs=100):
        """Train the rejector network against the expert, then run the rejector test on the test data.

        :param lr: learning rate forwarded to ``train_classifier_rej``
        :param n_epochs: number of epochs forwarded to ``train_classifier_rej``
        """
        train_classifier_rej(net=self.model, net_exp=self.expert_model, data_x=self.data_x, data_y=self.data_y,
                             alpha=self.alpha, lr=lr, n_epochs=n_epochs)
        test_classifier_rej(net=self.model, net_exp=self.expert_model, data_x=self.data_x_test, data_y=self.data_y_test)

    def eval(self, one_hot=True):
        """
        Score the trained system on the stored training data and derive a scalar value estimate.

        :param one_hot: if True, average the per-sample values within each state and then over all
            ``input_dim`` states (states with no samples contribute 0); otherwise average over samples
        :return: tuple ``(batch_result, value)`` where ``batch_result`` is whatever
            ``test_classifier_rej`` returns first and ``value`` is the scalar estimate
        """
        batch_result, score_vector = test_classifier_rej(net=self.model, net_exp=self.expert_model,
                                                         data_x=self.data_x,
                                                         data_y=self.data_y)
        # get a value estimate from score vector
        value_vec = np.multiply(self.y[:, 0], score_vector)  # element-wise mult.
        if not one_hot:
            return batch_result, np.mean(value_vec)

        # average over states:
        V = np.zeros(self.input_dim)
        for s in range(self.input_dim):
            idx = np.where(self.x == s)[0]
            if len(idx) > 0:
                V[s] = np.mean(value_vec[idx])
        return batch_result, np.mean(V)


class mozannar_action_ltd:
    """Action-level learning-to-defer baseline: a rejector network that can defer to a clinician policy.

    In ``eval`` the rejector's output index 2 is treated as "defer to the clinician policy".
    """

    # change this depending on what experiments we want to run
    def __init__(self, input_dim, output_dim, clinician_policy, env, pre_encoder=None, alpha=0.5, nst=False):
        """
        :param input_dim: input width of the rejector network; also the number of one-hot state columns
        :param output_dim: output width handed to the rejector network
        :param clinician_policy: expert policy; indexed like an array in the discrete rollout and called
            like a function in the continuous rollout (see ``eval``)
        :param env: environment name; selects the indexing/calling conventions used in ``eval``
        :param pre_encoder: optional fitted encoder used by ``pre_decoder``
        :param alpha: weight forwarded to ``train_classifier_rej_act``
        :param nst: if True, time indices (``data_t``) are forwarded to the train/test helpers
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = Linear_net_rej(input_dim=input_dim, out_dim=output_dim)
        self.expert_model = clinician_policy
        self.alpha = alpha
        # Raw (un-encoded) inputs/labels, filled by encode_data.
        self.x = None
        self.y = None
        self.x_test = None
        self.y_test = None
        # Tensors and side information, filled by encode_data.
        self.data_y = None
        self.data_x = None
        self.data_y_test = None
        self.data_x_test = None
        self.data_x_unencoded = None
        self.data_x_unencoded_test = None
        self.enc_x = None
        self.enc_y = None
        self.data_t = None
        self.data_t_test = None
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.scaler = None
        self.nst = nst
        self.env = env
        self.pre_encoder = pre_encoder

    @staticmethod
    def _make_label_encoder(**kwargs):
        """Build a dense-output ``OneHotEncoder`` on both old and new scikit-learn.

        scikit-learn 1.2 renamed ``sparse`` to ``sparse_output`` and 1.4 removed the old
        keyword, so the new spelling is tried first and the old one is the fallback.
        """
        try:
            return OneHotEncoder(sparse_output=False, **kwargs)
        except TypeError:
            return OneHotEncoder(sparse=False, **kwargs)

    def _one_hot_states(self, states):
        """Return an ``(n, input_dim)`` array with a single 1 per row at column ``int(states[i])``."""
        n_rows = states.shape[0]
        columns = np.fromiter((int(s) for s in states), dtype=np.intp, count=n_rows)
        encoded = np.zeros((n_rows, self.input_dim))
        encoded[np.arange(n_rows), columns] = 1
        return encoded

    def pre_decoder(self, x):
        """
        Map an encoded sparse matrix back to per-feature category indices using ``self.pre_encoder``.

        :param x: encoded matrix exposing ``shape`` and ``sorted_indices().indices``
            (presumably a scipy CSR matrix -- confirm against callers)
        :return: ``x`` unchanged when no ``pre_encoder`` is set, otherwise an array of shape ``x.shape``
        """
        if self.pre_encoder is not None:
            # NOTE(review): the reshape uses the shape of the *encoded* matrix, so it only succeeds when
            # the number of stored entries equals n_samples * n_columns -- confirm this is intended.
            n_samples, n_features = x.shape
            recovered_X = np.array([self.pre_encoder.active_features_[col] for col
                                    in x.sorted_indices().indices]).reshape(n_samples, n_features) - \
                          self.pre_encoder.feature_indices_[:-1]
            return recovered_X
        else:
            return x

    def encode_data(self, data_x, data_y, data_x_test, data_y_test, data_t, data_t_test, data_x_unencoded=None,
                    data_x_unencoded_test=None, one_hot=True):
        """
        Store the raw arrays and build the tensors used for training and testing.

        Rewards are assumed discrete (+1, 0, -1) and treated as classes, so training labels are
        one-hot encoded.

        :param data_x: training inputs (numpy array); integer state indices when ``one_hot`` is True
        :param data_y: training labels, batch x 1
        :param data_x_test: test inputs, same layout as ``data_x``
        :param data_y_test: test labels
        :param data_t: time indices for the training data (stored for ``train_model`` / ``eval``)
        :param data_t_test: time indices for the test data
        :param data_x_unencoded: un-encoded training inputs; only stored when ``one_hot`` is False
        :param data_x_unencoded_test: un-encoded test inputs; only stored when ``one_hot`` is False
        :param one_hot: if True, one-hot encode the inputs into ``input_dim`` columns and encode both
            label sets; otherwise inputs are converted to tensors as-is
        """
        # these two are mainly used for evaluation in eval function.
        self.data_t = data_t
        self.data_t_test = data_t_test
        self.x = data_x
        self.y = data_y
        self.x_test = data_x_test
        self.y_test = data_y_test
        if one_hot:
            self.data_x = torch.Tensor(self._one_hot_states(self.x))
            self.enc_y = self._make_label_encoder()
            self.data_y = torch.Tensor(self.enc_y.fit_transform(data_y))
            self.data_x_test = torch.Tensor(self._one_hot_states(self.x_test))
            self.data_y_test = torch.Tensor(self.enc_y.transform(data_y_test))
        else:
            # NOTE(review): the scaler is created but never applied, and the test labels are not
            # one-hot encoded in this branch (unlike the training labels) -- confirm this is intended.
            self.scaler = StandardScaler()
            self.data_x = torch.Tensor(data_x)
            self.enc_y = self._make_label_encoder(handle_unknown='ignore')
            self.data_y = torch.Tensor(self.enc_y.fit_transform(data_y))
            self.data_x_test = torch.Tensor(data_x_test)
            self.data_y_test = torch.Tensor(data_y_test)
            self.data_x_unencoded = data_x_unencoded
            self.data_x_unencoded_test = data_x_unencoded_test

    def train_model(self, lr=0.01, n_epochs=100, func=False):
        """
        Train the rejector network against the clinician policy, then run the test helper on the test data.

        :param lr: learning rate forwarded to ``train_classifier_rej_act``
        :param n_epochs: number of epochs forwarded to ``train_classifier_rej_act``
        :param func: forwarded to both helpers (is the policy a function or an array?)
        """
        # Time indices are only forwarded when the object was built with nst=True.
        train_t, test_t = (self.data_t, self.data_t_test) if self.nst else (None, None)
        train_classifier_rej_act(net=self.model, net_exp=self.expert_model, data_x=self.data_x, data_y=self.data_y,
                                 alpha=self.alpha, lr=lr, n_epochs=n_epochs, func=func, data_t=train_t, env=self.env,
                                 data_x_unencoded=None)
        test_classifier_rej_act(net=self.model, net_exp=self.expert_model, data_x=self.data_x_test,
                                data_y=self.data_y_test, func=func, env=self.env, data_t=test_t,
                                data_x_unencoded=None)

    def eval(self, true_tx, reward_mat, one_hot=True, func=False, n_trials=5, env="discrete_toy", defer_cost=0.1,
             init_prior=None, results_path=None, baseline='mozannar_ltd', seed=0, alpha_mozannar=1.0):
        """
        Roll out the learned policy for ``n_trials`` trajectories and summarise value and deferrals.

        A summary line is printed. When ``results_path`` is given, the per-step logs are pickled into
        ``<results_path>/models/`` (the directory is created if it does not exist).

        :param true_tx: if ``func`` is False, dict indexed by t with action x state x state matrices;
            otherwise a callable transition function
        :param reward_mat: if ``func`` is False, an array indexed ``[state, action]``
            (``[t][state, action]`` for the "diabetes" environment); otherwise a callable reward function
        :param one_hot: unused; kept for interface compatibility
        :param func: Is the policy a function or array (function for continuous data)
        :param n_trials: number of trajectories
        :param env: unused; the environment given to the constructor (``self.env``) is used throughout
        :param defer_cost: cost subtracted from the reward every time the policy defers
        :param init_prior: unused
        :param results_path: directory under which the trajectory pickle is written; ``None`` skips saving
        :param baseline: label used in the printed summary and in the pickle file name
        :param seed: used only in the pickle file name
        :param alpha_mozannar: used only in the pickle file name
        :return: tuple ``(value_array, total_cost, defer_count)`` with one entry per trajectory:
            cumulative reward (net of defer costs) at the last step, summed defer cost, number of deferrals
        :raises ValueError: if ``func`` is True and ``self.env`` is neither "randomwalk" nor "hiv"
        """
        if not func:
            T = int(len(true_tx.keys()))
        else:
            if self.env not in ("randomwalk", "hiv"):
                raise ValueError("continuous evaluation supports env 'randomwalk' or 'hiv', got %r" % (self.env,))
            T = int(np.max(self.data_t[:, 0]))

        result_dict = {t: [] for t in range(T)}
        cost_dict = {t: [] for t in range(T)}
        state_dict = {t: [] for t in range(T)}
        action_dict = {t: [] for t in range(T)}
        defer_dict = {t: [] for t in range(T)}
        value_dict = {t: [] for t in range(T)}
        defer_freq_dict = {t: [] for t in range(T)}
        defer_times_dict = {jj: [] for jj in range(n_trials)}
        min_t_defer = []  # first deferral time of every trajectory that defers at least once

        if not func:
            # The expert table is indexed by time first for these two environments.
            time_indexed_expert = self.env == "discrete_toy" or self.env == "diabetes"

            for j in range(n_trials):
                min_t_found = False
                cumulative_reward = 0
                state = 0  # every trajectory starts in state 0

                for tt in range(T):
                    state_vec = np.zeros((1, self.input_dim))
                    state_vec[0, state] = 1.
                    _, act = torch.max(self.model(torch.tensor(state_vec).reshape(1, -1).float()), 1)
                    act = int(act.detach().numpy()[0])
                    deferred = act == 2

                    if deferred:  # defer
                        if not min_t_found:
                            min_t_found = True
                            min_t_defer.append(tt)
                        defer_times_dict[j].append(tt)
                        if time_indexed_expert:
                            expert_out = np.array([self.expert_model[tt][int(state)]])
                        else:
                            expert_out = np.array([self.expert_model[int(state)]])
                        # A scalar (or single-column) entry is read as the probability of action 1;
                        # otherwise the entry is a full row of per-action scores.
                        if expert_out.ndim == 1:
                            expert_scores = torch.Tensor(np.vstack((1 - expert_out, expert_out))).T
                        elif expert_out.shape[1] == 1:
                            expert_scores = torch.Tensor(
                                np.vstack((1 - expert_out[:, 0], expert_out[:, 0]))).T
                        else:
                            expert_scores = torch.Tensor(expert_out)
                        _, action = torch.max(expert_scores.data, 1)
                        action = action.detach().numpy()[0]
                    else:
                        action = act

                    if self.env == "diabetes":
                        reward = reward_mat[tt][state, action]
                    else:
                        reward = reward_mat[state, action]
                    cumulative_reward += (reward - defer_cost) if deferred else reward

                    state_dict[tt].append(state)
                    action_dict[tt].append(action)
                    result_dict[tt].append(reward)
                    defer_dict[tt].append(int(deferred))
                    value_dict[tt].append(cumulative_reward)
                    defer_freq_dict[tt].append(int(deferred))
                    cost_dict[tt].append(defer_cost * int(deferred))

                    p_transition = true_tx[int(tt)][action, state]
                    state = np.random.choice(a=range(len(p_transition)), p=p_transition, size=1)[0]

        else:  # for continuous data true dynamics and rewards are functions.
            for j in range(n_trials):
                min_t_found = False
                cumulative_reward = 0
                s_next = None

                for tt in range(T):
                    if tt == 0:
                        # Start from a randomly drawn training sample; its first component is what
                        # gets rewarded and logged at this step.
                        i = np.random.choice(self.data_x.shape[0])
                        state = self.data_x[i, :]
                        model_input = state.reshape(1, -1)
                        reward_state = self.data_x[i, 0]
                    else:
                        state = s_next
                        model_input = torch.Tensor(s_next).reshape(1, -1)
                        reward_state = s_next

                    _, act = torch.max(self.model(model_input), 1)
                    deferred = bool(act == 2)

                    if deferred:
                        if not min_t_found:
                            min_t_defer.append(tt)
                            min_t_found = True
                        defer_times_dict[j].append(tt)
                        if self.env == "randomwalk":
                            action = np.random.binomial(n=1, p=self.expert_model(state), size=1)[0]
                        else:  # "hiv"
                            pp = self.expert_model(state, t=tt)
                            action = np.random.choice(a=range(len(pp)), p=pp, size=1)[0]
                    else:
                        action = act

                    if self.env == "randomwalk":
                        reward = reward_mat(reward_state)
                    else:  # "hiv"
                        reward = reward_mat(reward_state, action)
                    cumulative_reward += (reward - defer_cost) if deferred else reward

                    state_dict[tt].append(reward_state)
                    action_dict[tt].append(action)
                    result_dict[tt].append(reward)
                    defer_dict[tt].append(int(deferred))
                    value_dict[tt].append(cumulative_reward)
                    defer_freq_dict[tt].append(int(deferred))
                    cost_dict[tt].append(defer_cost * int(deferred))

                    if self.env == "randomwalk":
                        s_next, _, _ = true_tx(state, action, tt=tt)
                    else:  # "hiv"
                        s_next = true_tx(state, action, time=tt)

        value_array = np.array(value_dict[T - 1])
        # NOTE(review): this adds the *number* of deferrals per trajectory (not the accumulated defer
        # cost) back onto the value in the printed summary -- confirm that is intended.
        defer_count_per_trial = np.asarray(list(defer_dict.values())).sum(0)
        defer_freq_mat = np.array(list(defer_freq_dict.values()), dtype=float)
        cost_mat = np.array(list(cost_dict.values()), dtype=float)  # costmat is T x N

        # Trajectories that never defer are left out of the timing statistics.
        nan = float('nan')
        per_trial_defer_time = [np.mean(times) for times in defer_times_dict.values() if times]
        mean_defer_time = np.mean(per_trial_defer_time) if per_trial_defer_time else nan
        median_min_t = np.median(min_t_defer) if min_t_defer else nan
        mean_min_t = np.mean(min_t_defer) if min_t_defer else nan

        print(baseline, " min_t:", median_min_t, mean_min_t, 'value:',
              (value_array + defer_count_per_trial).mean(),
              'defer frequency:', np.sum(defer_freq_mat) / np.prod(np.shape(defer_freq_mat)), 'mean defer time:',
              mean_defer_time)

        if results_path is not None:
            models_dir = os.path.join(results_path, 'models')
            os.makedirs(models_dir, exist_ok=True)
            with open(os.path.join(models_dir,
                                   '%s_policy_%d_%s_cost_%f_alpha_%f_trajectories.pkl' % (
                                           baseline, seed, baseline, defer_cost, alpha_mozannar)),
                      'wb') as f:
                pkl.dump({'reward': result_dict, 'states': state_dict, 'actions': action_dict, 'cost': cost_dict,
                          'value': value_dict, 'defer_freq': defer_freq_dict, 'defer_time': defer_times_dict}, f)

        return value_array, cost_mat.sum(0), defer_freq_mat.sum(0)


class madras_ltd:
    """Learning-to-defer baseline with separate classifier, rejector and learned expert networks.

    Call order suggested by the methods below: ``encode_data`` -> ``train_expert`` ->
    ``train_model`` -> ``eval``.
    """

    # change this depending on what experiments we want to run
    def __init__(self, input_dim, output_dim):
        """
        :param input_dim: input width of all three networks; also the number of one-hot state columns
        :param output_dim: output width of the classifier and expert, and ``class_dim`` of the rejector
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = Linear_net_madras_class(input_dim=input_dim, out_dim=output_dim)
        self.model_rej = Linear_net_madras_rej(input_dim=input_dim, class_dim=output_dim)
        self.expert_model = Linear_net(input_dim=input_dim, out_dim=output_dim)
        self.alpha = 0
        # Raw (un-encoded) inputs/labels, filled by encode_data.
        self.x = None
        self.y = None
        self.x_test = None
        self.y_test = None
        # Encoded tensors, filled by encode_data.
        self.data_y = None
        self.data_x = None
        self.data_y_test = None
        self.data_x_test = None
        self.enc_x = None
        self.enc_y = None
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.scaler = None

    @staticmethod
    def _make_label_encoder(**kwargs):
        """Build a dense-output ``OneHotEncoder`` on both old and new scikit-learn.

        scikit-learn 1.2 renamed ``sparse`` to ``sparse_output`` and 1.4 removed the old
        keyword, so the new spelling is tried first and the old one is the fallback.
        """
        try:
            return OneHotEncoder(sparse_output=False, **kwargs)
        except TypeError:
            return OneHotEncoder(sparse=False, **kwargs)

    def _one_hot_states(self, states):
        """Return an ``(n, input_dim)`` array with a single 1 per row at column ``int(states[i])``."""
        n_rows = states.shape[0]
        columns = np.fromiter((int(s) for s in states), dtype=np.intp, count=n_rows)
        encoded = np.zeros((n_rows, self.input_dim))
        encoded[np.arange(n_rows), columns] = 1
        return encoded

    def encode_data(self, data_x, data_y, data_x_test, data_y_test, one_hot=True):
        """
        Store the raw arrays and build the tensors used for training and testing.

        Labels are always one-hot encoded (the encoder is fitted on ``data_y`` and re-used for
        ``data_y_test``): rewards are assumed discrete (+1, 0, -1) and treated as classes.

        :param data_x: training inputs (numpy array); integer state indices when ``one_hot`` is True
        :param data_y: training labels, batch x 1
        :param data_x_test: test inputs, same layout as ``data_x``
        :param data_y_test: test labels, same layout as ``data_y``
        :param one_hot: if True, one-hot encode the inputs into ``input_dim`` columns; otherwise
            standardise them with a ``StandardScaler`` fitted on ``data_x``
        """
        self.x = data_x
        self.y = data_y
        self.x_test = data_x_test
        self.y_test = data_y_test
        if one_hot:
            self.data_x = torch.Tensor(self._one_hot_states(self.x))
            self.enc_y = self._make_label_encoder()
            self.data_y = torch.Tensor(self.enc_y.fit_transform(data_y))
            self.data_x_test = torch.Tensor(self._one_hot_states(self.x_test))
            self.data_y_test = torch.Tensor(self.enc_y.transform(data_y_test))
        else:
            self.scaler = StandardScaler()
            self.data_x = torch.Tensor(self.scaler.fit_transform(data_x))
            self.enc_y = self._make_label_encoder()
            self.data_y = torch.Tensor(self.enc_y.fit_transform(data_y))
            self.data_x_test = torch.Tensor(self.scaler.transform(data_x_test))
            self.data_y_test = torch.Tensor(self.enc_y.transform(data_y_test))

    def train_expert(self, lr=0.1, n_epochs=100):
        """Train the expert network on the encoded training data, then run the multiclass test on the test data.

        :param lr: learning rate forwarded to ``train_classifier_multiclass``
        :param n_epochs: number of epochs forwarded to ``train_classifier_multiclass``
        """
        train_classifier_multiclass(net=self.expert_model, data_x=self.data_x, data_y=self.data_y,
                                    lr=lr, n_epochs=n_epochs)
        test_classifier_multiclass(net=self.expert_model, data_x=self.data_x_test, data_y=self.data_y_test,
                                   enc_y=self.enc_y)

    def train_model(self, lr=0.1, n_epochs=30):
        """Train the classifier and rejector against the expert, then run the test helper on the test data.

        :param lr: learning rate forwarded to ``train_classifier_madras_original``
        :param n_epochs: number of epochs forwarded to ``train_classifier_madras_original``
        """
        train_classifier_madras_original(net_class=self.model, net_rej=self.model_rej, net_exp=self.expert_model,
                                         data_x=self.data_x, data_y=self.data_y, lr=lr, n_epochs=n_epochs)
        test_classifier_madras_original(net_class=self.model, net_rej=self.model_rej, net_exp=self.expert_model,
                                        data_x=self.data_x_test, data_y=self.data_y_test)

    def eval(self, one_hot=True):
        """
        Score the trained system on the stored training data and derive a scalar value estimate.

        :param one_hot: if True, average the per-sample values within each state and then over all
            ``input_dim`` states (states with no samples contribute 0); otherwise average over samples
        :return: tuple ``(batch_result, value)`` where ``batch_result`` is whatever
            ``test_classifier_madras_original`` returns first and ``value`` is the scalar estimate
        """
        batch_result, score_vector = test_classifier_madras_original(net_class=self.model, net_rej=self.model_rej,
                                                                     net_exp=self.expert_model, data_x=self.data_x,
                                                                     data_y=self.data_y)
        # get a value estimate from score vector
        value_vec = np.multiply(self.y[:, 0], score_vector)  # element-wise mult.
        if not one_hot:
            return batch_result, np.mean(value_vec)

        # average over states:
        V = np.zeros(self.input_dim)
        for s in range(self.input_dim):
            idx = np.where(self.x == s)[0]
            if len(idx) > 0:
                V[s] = np.mean(value_vec[idx])
        return batch_result, np.mean(V)


class madras_action_ltd:
    # change this depending on what experiments we want to run
    def __init__(self, input_dim, output_dim, target_policy, clinician_policy, env, defer_cost=0.01, pre_encoder=None,
                 nst=False,
                 train_net=False):
        """
        :param input_dim: input width of the rejector (and of the classifier when ``train_net`` is set)
        :param output_dim: output width of the classifier / class width of the rejector
        :param target_policy: stored as ``self.model`` when ``train_net`` is falsy
        :param clinician_policy: stored as ``self.expert_model``
        :param env: environment identifier, stored as ``self.env``
        :param defer_cost: stored as ``self.defer_cost``
        :param pre_encoder: optional encoder consulted by ``pre_decoder``
        :param nst: when truthy, ``train_model`` also forwards the time stamps (``data_t``)
        :param train_net: when truthy, build a fresh ``Linear_net_madras_class`` instead of using ``target_policy``
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Problem configuration.
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.env = env
        self.nst = nst
        self.pre_encoder = pre_encoder
        self.defer_cost = defer_cost
        self.alpha = 0

        # Classifier: a freshly built network, or the supplied target policy used as-is.
        self.train_net = train_net
        self.model = (Linear_net_madras_class(input_dim=input_dim, out_dim=output_dim)
                      if train_net else target_policy)
        self.model_rej = Linear_net_madras_rej(input_dim=input_dim, class_dim=output_dim)
        self.expert_model = clinician_policy

        # Data holders and fitted transformers start out empty.
        for holder in ('x', 'y', 'x_test', 'y_test',
                       'data_y', 'data_x', 'data_x_unencoded', 'data_x_unencoded_test',
                       'data_y_test', 'data_x_test',
                       'enc_x', 'enc_y', 'data_t', 'data_t_test', 'scaler'):
            setattr(self, holder, None)

    def pre_decoder(self, x):
        if self.pre_encoder is not None:
            n_samples, n_features = x.shape
            recovered_X = np.array([self.pre_encoder.active_features_[col] for col
                                    in x.sorted_indices().indices]).reshape(n_samples, n_features) - \
                          self.pre_encoder.feature_indices_[:-1]
            return recovered_X
        else:
            return x

    def encode_data(self, data_x, data_y, data_x_test, data_y_test, data_t, data_t_test, data_x_unencoded=None,
                    data_x_unencoded_test=None, one_hot=True):
        """
        Store the raw splits and build the tensors used for training and testing.

        :param data_x: numpy array of training inputs; with ``one_hot`` each row holds one state index
        :param data_y: batch x 1 array of training targets; rewards are assumed discrete (+1, 0, -1) and are
                       treated as a multi-class problem, hence one-hot encoded
        :param data_x_test: test inputs, same layout as ``data_x``
        :param data_y_test: test targets, same layout as ``data_y``
        :param data_t: time stamps of the training samples (mainly used for evaluation in ``eval``)
        :param data_t_test: time stamps of the test samples
        :param data_x_unencoded: optional un-encoded training inputs; with ``one_hot`` and ``None`` the encoded
                                 tensors (train and test) are stored instead
        :param data_x_unencoded_test: optional un-encoded test inputs
        :param one_hot: one-hot encode the inputs over ``input_dim`` states (True) or use them as given (False)
        """

        def make_label_encoder():
            # Newer scikit-learn spells the dense-output switch ``sparse_output``; older releases only know
            # ``sparse``. Try the new name first and fall back.
            try:
                return OneHotEncoder(sparse_output=False)
            except TypeError:
                return OneHotEncoder(sparse=False)

        def one_hot_states(states):
            # One row per sample with a single 1 in the column of its state index.
            encoded = np.zeros((states.shape[0], self.input_dim))
            for row in range(encoded.shape[0]):
                encoded[row, int(states[row])] = 1
            return encoded

        # these two are mainly used for evaluation in eval function.
        self.data_t = data_t
        self.data_t_test = data_t_test

        self.x = data_x
        self.y = data_y
        self.x_test = data_x_test
        self.y_test = data_y_test
        self.enc_y = make_label_encoder()

        if one_hot:
            self.data_x = torch.Tensor(one_hot_states(self.x))
            # Fit on both splits so that every label seen in either split gets a column.
            self.enc_y.fit(np.vstack((data_y, data_y_test)))
            self.data_y = torch.Tensor(self.enc_y.transform(np.vstack(data_y)))

            self.data_x_test = torch.Tensor(one_hot_states(self.x_test))
            self.data_y_test = torch.Tensor(self.enc_y.transform(data_y_test))
            if data_x_unencoded is None:
                self.data_x_unencoded = self.data_x
                self.data_x_unencoded_test = self.data_x_test
            else:
                # Previously a caller-supplied un-encoded copy was silently dropped in this branch.
                self.data_x_unencoded = data_x_unencoded
                self.data_x_unencoded_test = data_x_unencoded_test
        else:
            # NOTE(review): the scaler is created but never fitted or applied in this method.
            self.scaler = StandardScaler()
            self.data_x = torch.Tensor(data_x)
            self.data_y = torch.Tensor(self.enc_y.fit_transform(data_y))
            self.data_x_test = torch.Tensor(data_x_test)
            self.data_y_test = torch.Tensor(self.enc_y.transform(data_y_test))
            self.data_x_unencoded = data_x_unencoded
            self.data_x_unencoded_test = data_x_unencoded_test

    def train_expert(self, lr=0.1, n_epochs=100):
        """
        Run ``train_classifier_multiclass`` for the expert model on the training tensors, then
        ``test_classifier_multiclass`` on the test tensors.

        :param lr: learning rate forwarded to the training routine
        :param n_epochs: number of epochs forwarded to the training routine
        """
        expert = self.expert_model
        train_classifier_multiclass(net=expert, data_x=self.data_x, data_y=self.data_y, lr=lr, n_epochs=n_epochs)
        # The fitted label encoder is forwarded to the test routine.
        test_classifier_multiclass(net=expert, data_x=self.data_x_test, data_y=self.data_y_test, enc_y=self.enc_y)

    def train_model(self, lr=0.1, n_epochs=30, func=False):
        """
        Run ``train_classifier_madras_original_act`` on the training tensors, then
        ``test_classifier_madras_original_act`` on the test tensors.

        :param lr: learning rate forwarded to the training routine
        :param n_epochs: number of epochs forwarded to the training routine
        :param func: forwarded unchanged to both routines
        """
        # Arguments common to the training and the test call.
        shared = dict(net_class=self.model, net_rej=self.model_rej, net_exp=self.expert_model,
                      func=func, env=self.env, data_x_unencoded=None, train_net=self.train_net,
                      defer_cost=self.defer_cost)

        # Time stamps are only forwarded when ``nst`` is set; otherwise the helpers' own default applies.
        train_extra = {'data_t': self.data_t} if self.nst else {}
        test_extra = {'data_t': self.data_t_test} if self.nst else {}

        train_classifier_madras_original_act(data_x=self.data_x, data_y=self.data_y, lr=lr, n_epochs=n_epochs,
                                             **shared, **train_extra)
        test_classifier_madras_original_act(data_x=self.data_x_test, data_y=self.data_y_test,
                                            **shared, **test_extra)

    def eval(self, true_tx, reward_mat, one_hot=True, func=False, env=None, n_trials=5, defer_cost=0.1,
             results_path=None, baseline="madras_ltd", seed=0):
        """
        :param true_tx: dict indexed by t with action x state x state matrices
        :param reward_mat: action x state
        :param one_hot: Is the data one-hot encoded? (False for continous data)
        :param func: Is the policy a function or array (function for continuous data)
        :return: performance summary and average value
        """
        if not func:
            # get a value estimate from score vector
            T = int(len(true_tx.keys()))

            result_dict = {t: [] for t in range(T)}
            cost_dict = {t: [] for t in range(T)}
            state_dict = {t: [] for t in range(T)}
            action_dict = {t: [] for t in range(T)}
            defer_dict = {t: [] for t in range(T)}
            value_dict = {t: [] for t in range(T)}
            defer_freq_dict = {t: [] for t in range(T)}
            defer_times_dict = {jj: [] for jj in range(n_trials)}
            min_t_defer = []

            # get a value estimate from score vector
            for j in range(n_trials):
                defer_cost_total = 0
                min_t_found = False
                cumulative_reward = 0
                for tn, tt in enumerate(range(T)):
                    if tn == 0:
                        xx = np.zeros(self.data_x.shape[1]).reshape(1, -1)
                        xx[0, 0] = 1
                        _, states = torch.max(torch.Tensor(xx), 1)
                        states = int(states[0])
                        if self.train_net:
                            op = self.model(torch.Tensor(xx))
                        else:
                            if not func:
                                # print(env)
                                if env == "discrete_toy" or env == "diabetes":

                                    # print('here')
                                    # print('model:', self.model)
                                    # exit(1)
                                    op = np.array([self.model[tt][states]])
                                else:
                                    op = np.array([self.model[states]])
                                # expert_prediction = np.array([net_exp[int(i)] for i in inputs[:, 0]])
                            else:
                                if self.data_t is not None:
                                    if env == "sepsis_diabetes":
                                        # print('sepsis')
                                        op = np.array([self.model(states, tt)])
                                    else:
                                        op = np.array([self.model(states, tt)])
                                else:
                                    op = np.array([self.model(states)])
                            # print(outputs.shape)
                            if op.ndim == 1:
                                op = torch.Tensor(np.vstack((1 - op, op))).T
                            elif op.shape[1] == 1:
                                op = torch.Tensor(
                                    np.vstack((1 - op[:, 0], op[:, 0]))).T
                            else:
                                op = torch.Tensor(op)
                        # r = (rej[i][1].item() >= 0.5)
                        rej = self.model_rej(torch.Tensor(xx), op)
                        act = rej[0][1].item() >= 0.5
                        # act = int(act.detach().numpy()[0])
                        _, states = torch.max(torch.Tensor(xx), 1)
                        states = states.detach().cpu().numpy()[0]
                        if act == 1:  # defer
                            min_t_defer.append(tt)
                            min_t_found = True
                            defer_times_dict[j].append(tt)

                            if self.env == "discrete_toy":
                                action = np.random.binomial(n=1, p=self.expert_model[tt][states], size=1)[0]
                            else:
                                action = np.random.choice(a=range(len(self.expert_model[tt][states])),
                                                          p=self.expert_model[tt][states], size=1)[0]
                        else:
                            _, action = torch.max(op, 1)
                        if self.env == "diabetes":
                            reward = reward_mat[tt][states, action]
                            if act == 1:
                                defer_cost_total += defer_cost
                                cumulative_reward += (reward - defer_cost)
                            else:
                                cumulative_reward += reward
                        else:
                            reward = reward_mat[states, action]
                            if act == 1:
                                defer_cost_total += defer_cost
                                cumulative_reward += (reward - defer_cost)
                            else:
                                cumulative_reward += reward

                        state_dict[tt].append(states)
                        action_dict[tt].append(action)
                        result_dict[tt].append(reward)
                        defer_dict[tt].append(int(act == 1))
                        value_dict[tt].append(cumulative_reward)
                        defer_freq_dict[tt].append(int(act == 1))
                        cost_dict[tt].append(defer_cost * int(act == 1))

                        p_transition = true_tx[int(tt)][
                            action, states]
                        s_next = np.random.choice(a=range(len(p_transition)), p=p_transition, size=1)[0]
                        s_next_vec = np.zeros((1, self.input_dim))
                        # print(s_next_vec.shape, s_next)
                        s_next_vec[0, s_next] = 1.
                    else:
                        # print(s_next_vec.shape)
                        # op = self.model(torch.tensor(s_next_vec).reshape(1, -1).float())
                        s_next_vec = s_next_vec.reshape(1, -1)
                        if self.train_net:
                            op = self.model(torch.Tensor(s_next_vec))
                        else:
                            if not func:
                                # print(env)
                                if env == "discrete_toy" or env == "diabetes":
                                    # _, states = torch.max(torch.Tensor(s_next_vec), 1)
                                    op = np.array([self.model[tt][s_next]])
                                else:
                                    op = np.array([self.model[s_next]])
                                # expert_prediction = np.array([net_exp[int(i)] for i in inputs[:, 0]])
                            else:
                                if self.data_t is not None:
                                    if env == "sepsis_diabetes":
                                        # print('sepsis')
                                        op = np.array([self.model(s_next, tt)])
                                    else:
                                        op = np.array([self.model(s_next, tt)])
                                else:
                                    op = np.array([self.model(s_next)])
                            # print(outputs.shape)
                            if op.ndim == 1:
                                op = torch.Tensor(np.vstack((op, 1 - op))).T
                            elif op.shape[1] == 1:
                                op = torch.Tensor(
                                    np.vstack((op[:, 0], 1 - op[:, 0]))).T
                            else:
                                op = torch.Tensor(op)
                        # _, act = torch.max(self.model_rej(torch.tensor(s_next_vec).reshape(1, -1).float(), op), 1)
                        rej = self.model_rej(torch.tensor(s_next_vec).reshape(1, -1).float(), op)
                        act = rej[0][1].item() >= 0.5
                        # act = int(act.detach().numpy()[0])
                        # , states = torch.max(torch.tensor(s_next_vec).reshape(1, -1), 1)
                        if act == 1:  # defer
                            if not min_t_found:
                                min_t_found = True
                                min_t_defer.append(tt)

                            if self.env == "discrete_toy":
                                action = np.random.binomial(n=1, p=self.expert_model[tt][s_next], size=1)[0]
                            else:
                                action = np.random.choice(a=range(len(self.expert_model[tt][s_next])),
                                                          p=self.expert_model[tt][s_next], size=1)[0]
                        else:
                            _, action = torch.max(op, 1)

                        if self.env == "diabetes":
                            reward = reward_mat[tt][s_next, action]
                            if act == 1:
                                defer_cost_total += defer_cost
                                cumulative_reward += (reward - defer_cost)
                            else:
                                cumulative_reward += reward
                        else:
                            reward = reward_mat[s_next, action]
                            if act == 1:
                                defer_cost_total += defer_cost
                                cumulative_reward += (reward - defer_cost)
                            else:
                                cumulative_reward += reward

                        state_dict[tt].append(s_next)
                        action_dict[tt].append(action)
                        result_dict[tt].append(reward)
                        defer_dict[tt].append(int(act == 1))
                        value_dict[tt].append(cumulative_reward)
                        defer_freq_dict[tt].append(int(act == 1))
                        cost_dict[tt].append(defer_cost * int(act == 1))

                        p_transition = true_tx[int(tt)][
                            action, s_next]
                        s_next = np.random.choice(a=range(len(p_transition)), p=p_transition, size=1)[0]
                        s_next_vec = np.zeros((1, self.input_dim))
                        s_next_vec[0, s_next] = 1.

        else:  # for continuous data true dynamics and rewards are functions. currently this will only work for
            # randomwalk data+sepsis
            # get a value estimate from score vector
            T = int(max(self.data_t))

            result_dict = {t: [] for t in range(T)}
            cost_dict = {t: [] for t in range(T)}
            state_dict = {t: [] for t in range(T)}
            action_dict = {t: [] for t in range(T)}
            defer_dict = {t: [] for t in range(T)}
            value_dict = {t: [] for t in range(T)}
            defer_freq_dict = {t: [] for t in range(T)}
            defer_times_dict = {jj: [] for jj in range(n_trials)}
            min_t_defer = []
            for j in range(n_trials):
                defer_cost_total = 0
                min_t_found = False
                cumulative_reward = 0
                for tn, tt in enumerate(range(T)):
                    if tn == 0:
                        # op = self.model(self.data_x[i, :].reshape(1, -1))
                        i = np.random.choice(a=self.data_x.shape[0], size=1)[0]
                        if self.train_net:
                            op = self.model(self.data_x[i, :].reshape(1, -1))
                        else:
                            # _, states = torch.max(self.data_x[i, :].reshape(1, -1), 1)
                            states = self.data_x[i, 0]
                            if self.env == "randomwalk":
                                # action = np.random.binomial(n=1, p=self.model(states), size=1)[0]
                                op = self.model(states)
                                op = np.array(op).reshape(1, -1)
                            elif self.env == "hiv":
                                op = self.model(states, tt)
                                op = np.array(op).reshape(1, -1)
                                # action = np.random.choice(a=range(len(pp)), p=pp, size=1)[0]

                        # print(outputs.shape)
                        if op.ndim == 1:
                            op = torch.Tensor(np.vstack((1 - op, op))).T
                        elif op.shape[1] == 1:
                            op = torch.Tensor(
                                np.vstack((1 - op[:, 0], op[:, 0]))).T
                        else:
                            op = torch.Tensor(op)

                        _, act = torch.max(self.model_rej(self.data_x[i, :].reshape(1, -1), op),
                                           1)
                        act = int(act.detach().numpy()[0])
                        if act == 1:
                            if not min_t_found:
                                min_t_found = True
                                min_t_defer.append(tt)

                            if self.env == "randomwalk":
                                action = np.random.binomial(n=1, p=self.expert_model(self.data_x[i, :]), size=1)[0]
                            elif self.env == "hiv":
                                pp = self.expert_model(self.data_x[i, :], t=tt)
                                action = np.random.choice(a=range(len(pp)), p=pp, size=1)[0]
                        else:
                            _, action = torch.max(op, 1)

                        if self.env == "randomwalk":
                            reward = reward_mat(self.data_x[i, 0])
                        elif self.env == "hiv":
                            reward = reward_mat(self.data_x[i, 0], action)

                        if act == 1:
                            defer_cost_total += defer_cost
                            cumulative_reward += (reward - defer_cost)
                        else:
                            cumulative_reward += reward

                        state_dict[tt].append(self.data_x[i, 0])
                        action_dict[tt].append(action)
                        result_dict[tt].append(reward)
                        defer_dict[tt].append(int(act == 1))
                        value_dict[tt].append(cumulative_reward)
                        defer_freq_dict[tt].append(int(act == 1))
                        cost_dict[tt].append(defer_cost * int(act == 1))

                        if self.env == "randomwalk":
                            s_next, _, _ = true_tx(self.data_x[i, :], action, tt=tt)
                        elif self.env == "hiv":
                            s_next = true_tx(self.data_x[i, :], action, time=tt)
                            # s_next = np.random.choice(a=range(len(pp)), p=pp, size=1)[0]
                        # s_next_vec = np.zeros(self.input_dim)
                        # s_next_vec[s_next] = 1
                        # s_next_vec = s_next_vec.reshape(1, -1)
                    else:
                        # op = self.model(self.data_x[i, :].reshape(1, -1))
                        if self.train_net:
                            op = self.model(s_next.reshape(1, -1))
                        else:
                            if self.env == "randomwalk":
                                # action = np.random.binomial(n=1, p=self.model(states), size=1)[0]
                                op = self.model(s_next)
                                op = np.array(op).reshape(1, -1)
                            elif self.env == "hiv":
                                op = self.model(s_next, tt)
                                op = np.array(op).reshape(1, -1)
                                # action = np.random.choice(a=range(len(pp)), p=pp, size=1)[0]

                        # print(outputs.shape)
                        if op.ndim == 1:
                            op = torch.Tensor(np.vstack((1 - op, op))).T
                        elif op.shape[1] == 1:
                            op = torch.Tensor(
                                np.vstack((1 - op[:, 0], op[:, 0]))).T
                        else:
                            op = torch.Tensor(op)

                        _, act = torch.max(self.model_rej(torch.Tensor(np.array(s_next)).reshape(-1, 1), op),
                                           1)
                        act = int(act.detach().numpy()[0])
                        if act == 1:
                            if not min_t_found:
                                min_t_found = True
                                min_t_defer.append(tt)
                            if self.env == "randomwalk":
                                action = np.random.binomial(n=1, p=self.expert_model(s_next), size=1)[0]
                            elif self.env == "hiv":
                                pp = self.expert_model(s_next, t=tt)
                                action = np.random.choice(a=range(len(pp)), p=pp, size=1)[0]
                        else:
                            _, action = torch.max(op, 1)

                        if self.env == "randomwalk":
                            reward = reward_mat(s_next)
                        elif self.env == "hiv":
                            reward = reward_mat(s_next, action)

                        if act == 1:
                            defer_cost_total += defer_cost
                            cumulative_reward += (reward - defer_cost)
                        else:
                            cumulative_reward += reward

                        state_dict[tt].append(s_next)
                        action_dict[tt].append(action)
                        result_dict[tt].append(reward)
                        defer_dict[tt].append(int(act == 1))
                        value_dict[tt].append(cumulative_reward)
                        defer_freq_dict[tt].append(int(act == 1))
                        cost_dict[tt].append(defer_cost * int(act == 1))

                        if self.env == "randomwalk":
                            s_next, _, _ = true_tx(s_next, action, tt=tt)
                        else:
                            s_next = true_tx(s_next, action, time=tt)
                            # s_next = np.random.choice(a=range(len(pp)), p=pp, size=1)[0]

        value_array = np.array(value_dict[T - 1])
        defer_freq_mat = np.array(list(defer_freq_dict.values()), dtype=float)
        cost_mat = np.array(list(cost_dict.values()), dtype=float)  # costmat is T x N
        mean_defer_time = np.nanmean(np.array([np.nanmean(defer_times_dict[jj]) for jj in range(n_trials)]))

        print(baseline, " min_t:", np.median(min_t_defer), np.mean(min_t_defer), 'value:', value_array.mean(),
              'defer frequency:', np.sum(defer_freq_mat) / np.prod(np.shape(defer_freq_mat)), 'mean defer time:',
              mean_defer_time)

        with open(os.path.join(results_path, 'models',
                               '%s_policy_%d_%s_cost_%f_trajectories.pkl' % (
                                       baseline, seed, baseline, defer_cost)),
                  'wb') as f:
            pkl.dump({'reward': result_dict, 'states': state_dict, 'actions': action_dict, 'cost': cost_dict,
                      'value': value_dict, 'defer_freq': defer_freq_dict, 'defer_time': defer_times_dict}, f)

        return value_array, cost_mat.sum(0), defer_freq_mat.sum(0)
