import numpy as np
from abc import ABC


class vi_shaping(ABC):
    """Value iteration on a tabular MDP with optional time-varying dynamics.

    Transitions are arrays of shape ``(n_actions, n_states, n_states)`` and
    rewards are arrays of shape ``(n_states, n_actions)``.  Either may be
    supplied per time step (a mapping keyed ``0 .. tau - 1``) or as a single
    stationary array.

    ``augment_state_action`` extends the model with one extra action and one
    extra absorbing state; ``run`` performs in-place value-iteration sweeps
    and extracts a greedy action per state and time step.
    """

    # Additive smoothing applied to augmented transitions before the rows are
    # re-normalised, so no transition has exactly zero probability.
    _SMOOTHING = 1e-5

    def __init__(self, tx_mat, reward_vec, nst_tx=True, nst_rd=False):
        """Store the model.

        Args:
            tx_mat: transitions; a mapping ``t -> (A, S, S)`` array when
                ``nst_tx`` is true, otherwise a single ``(A, S, S)`` array.
            reward_vec: rewards; a mapping ``t -> (S, A)`` array when
                ``nst_rd`` is true, otherwise a single ``(S, A)`` array.
            nst_tx: whether the transitions are non-stationary.
            nst_rd: whether the rewards are non-stationary.
        """
        # indexed by episode and of shape action x state x state
        self.tx_dict = tx_mat
        # of shape state x actions
        self.reward = reward_vec
        self.nst_tx = nst_tx
        self.nst_rd = nst_rd

    # ------------------------------------------------------------------ #
    # internal helpers
    # ------------------------------------------------------------------ #
    def _transitions_at(self, t):
        """Return the ``(A, S, S)`` transition tensor used at time step ``t``."""
        if isinstance(self.tx_dict, np.ndarray):
            return self.tx_dict
        return self.tx_dict[t]

    def _reward_at(self, t):
        """Return the ``(S, A)`` reward table used at time step ``t``."""
        return self.reward[t] if self.nst_rd else self.reward

    def _horizon(self):
        """Return the number of time steps swept by ``run``."""
        if not isinstance(self.tx_dict, np.ndarray):
            return len(self.tx_dict)
        # Stationary dynamics: the horizon is driven by the rewards, if those
        # vary over time, and collapses to a single step otherwise.
        return len(self.reward) if self.nst_rd else 1

    def _augment_transitions(self, tx):
        """Return ``tx`` grown by one action and one absorbing state."""
        tx = np.asarray(tx, dtype=float)
        n_actions, n_states = tx.shape[:2]
        augmented = np.zeros((n_actions + 1, n_states + 1, n_states + 1))
        # Keep the original dynamics for the original states and actions.
        augmented[:n_actions, :n_states, :n_states] = tx
        # The extra state is absorbing under every action.
        augmented[:, n_states, n_states] = 1
        # NOTE(review): rows of the extra action taken from an ordinary state
        # are left empty here, so they become uniform after smoothing.
        # Confirm this is the intended dynamics for that action.
        augmented += self._SMOOTHING
        augmented /= np.sum(augmented, axis=-1, keepdims=True)
        return augmented

    @staticmethod
    def _augment_rewards(reward, n_states, n_actions, deferral_cost):
        """Return ``reward`` grown by one action and one state.

        The extra action and the extra state both pay ``deferral_cost``.
        """
        augmented = np.zeros((n_states + 1, n_actions + 1))
        augmented[:n_states, :n_actions] = reward
        augmented[:n_states, n_actions] = deferral_cost
        augmented[n_states, :] = deferral_cost
        return augmented

    # ------------------------------------------------------------------ #
    # public API
    # ------------------------------------------------------------------ #
    def augment_state_action(self, deferral_cost=-0.1):
        """Add one extra action and one absorbing state to the model.

        ``self.tx_dict`` and ``self.reward`` are replaced by their augmented
        counterparts.  Per-time-step mappings are updated in place, entry by
        entry.

        Args:
            deferral_cost: reward assigned to the extra action in every
                original state, and to every action in the extra state.

        Raises:
            ValueError: if ``nst_tx`` is false and the transitions are not a
                single 3-D array.
        """
        # augment transitions
        if self.nst_tx:
            # Model sizes are read from the first time step.
            n_actions, n_states = self.tx_dict[0].shape[:2]
            for k in self.tx_dict:
                self.tx_dict[k] = self._augment_transitions(self.tx_dict[k])
        else:
            stationary = np.asarray(self.tx_dict)
            if stationary.ndim != 3:
                raise ValueError(
                    "stationary transitions must be a single array of shape "
                    "(n_actions, n_states, n_states)"
                )
            n_actions, n_states = stationary.shape[:2]
            self.tx_dict = self._augment_transitions(stationary)

        # augment rewards
        if self.nst_rd:
            for k in self.reward:
                self.reward[k] = self._augment_rewards(
                    self.reward[k], n_states, n_actions, deferral_cost
                )
        else:
            self.reward = self._augment_rewards(
                self.reward, n_states, n_actions, deferral_cost
            )

    def run(self, discount=0.99, max_iter=2000, theta=1e-5):
        """Run value iteration and extract a greedy policy.

        A single value vector is updated in place, state by state, while
        sweeping the time steps from last to first.  Sweeps repeat until the
        largest change within one sweep is at most ``theta`` or the backup
        budget is exhausted.

        Args:
            discount: discount factor applied to the next-state value.
            max_iter: budget on the number of per-time-step backups.  It is
                checked between full sweeps, so the final sweep always
                completes.
            theta: convergence threshold on the largest value change.

        Returns:
            ``(policy, V_func)`` where ``policy`` is a float array of shape
            ``(n_states, tau)`` holding the greedy action index for each state
            and time step, and ``V_func`` is the value vector of length
            ``n_states``.
        """
        tau = self._horizon()
        n_states = self._transitions_at(0).shape[1]  # augmented states
        V_func = np.zeros(n_states)

        delta = np.inf
        n_backups = 0
        while delta > theta and n_backups < max_iter:
            delta = 0.0
            for t in reversed(range(tau)):
                tx_t = self._transitions_at(t)
                reward_t = self._reward_at(t)
                for s in range(n_states):
                    previous = V_func[s]
                    q_values = reward_t[s, :] + discount * np.dot(tx_t[:, s, :], V_func)
                    V_func[s] = np.max(q_values)
                    delta = max(delta, np.abs(previous - V_func[s]))
                n_backups += 1

        # Get non-stationary policy: greedy with respect to the final values.
        policy = np.zeros((n_states, tau))
        for t in range(tau):
            tx_t = self._transitions_at(t)
            reward_t = self._reward_at(t)
            for s in range(n_states):
                q_values = reward_t[s, :] + discount * np.dot(tx_t[:, s, :], V_func)
                policy[s, t] = np.argmax(q_values)

        return policy, V_func
