import numpy as np
import scipy.special
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

def generate_boltzmann_distribution(q):
    """
        Generate a Boltzmann distribution for a given Q-function.

        Applies a row-wise softmax, pi(a | s) = exp(Q(s, a)) / sum_b exp(Q(s, b)),
        to a 2-D array of action values indexed as [state, action].

        The row maximum is subtracted before exponentiating. This does not change
        the result mathematically but prevents overflow (exp -> inf, inf / inf -> nan)
        for large Q-values.

        Returns an array of the same shape as q whose rows sum to 1.
    """
    q = np.asarray(q, dtype=float)
    # Shift each row so its largest entry is 0; exp() then stays in (0, 1].
    unnormalized = np.exp(q - np.max(q, axis=1, keepdims=True))
    return unnormalized / np.sum(unnormalized, axis=1, keepdims=True)

def inverse_action_value_iteration(action_probabilities, mdp, gamma=1.0, theta=0.00001, epsilon=1e-6, max_iterations=None):
    """
        Implementation of Inverse Action-value Iteration (IAVI).

        For every state whose observed action probabilities are not all zero, the
        immediate rewards of all actions are obtained by solving (least squares)
            r(s, a) - 1/(nA-1) * sum_{b != a} r(s, b)
                = z(s, a) - 1/(nA-1) * sum_{b != a} z(s, b),
        where z(s, a) = log(pi(a | s) + epsilon) - gamma * (1 - done(s, a)) * sum_s' P(s' | s, a) * max_a' Q(s', a').
        The Q-table is updated from the new rewards, and sweeps repeat until the
        largest change of any reward entry is below theta.

        Args:
            action_probabilities: observed action probabilities indexed as [state][action].
                Rows that are entirely zero are skipped.
            mdp: object exposing nS, nA, transition_probabilities (indexed [s][a] -> vector
                over next states) and dones (indexed [s, a]), as built by MDP in this module.
            gamma: discount factor.
            theta: convergence threshold on the largest reward change in one sweep.
            epsilon: added to probabilities before taking the log to avoid log(0).
            max_iterations: optional cap on the number of sweeps (None = no cap).

        Returns the immediate reward given an MDP and action probabilities (along with Q-function and Boltzmann distribution),
        as the tuple (q, r, boltzmann_distribution).
    """
    n_actions = mdp.nA

    # Coefficient matrix of the per-state linear system:
    # row a has 1 on the diagonal and -1/(nA-1) elsewhere. It does not depend on the
    # state, so it is built once. (nA == 1 raises ZeroDivisionError, as before.)
    coefficients = np.full((n_actions, n_actions), -1.0 / (n_actions - 1))
    np.fill_diagonal(coefficients, 1.0)

    # Initialize reward and Q-table.
    r = np.zeros((mdp.nS, mdp.nA))
    q = np.zeros((mdp.nS, mdp.nA))

    sweeps = 0
    # Iterate until convergence.
    while True:
        delta = 0.0
        # Go through the MDP in reverse index order.
        # NOTE(review): this assumes state indices are topologically sorted, so
        # successors are visited before their predecessors.
        for i in reversed(range(mdp.nS)):
            if np.all(action_probabilities[i] == 0):
                continue

            log_probs = np.log(np.asarray(action_probabilities[i], dtype=float) + epsilon)
            # q is not modified while the target vector is built, so the
            # next-state values only need to be computed once per state.
            next_values = np.max(q, axis=1)
            future = gamma * (1 - mdp.dones[i]) * (np.asarray(mdp.transition_probabilities[i]) @ next_values)

            # Target vector: the coefficient matrix applied to z = log pi - future value
            # yields exactly z_a - 1/(nA-1) * sum_{b != a} z_b for every action a.
            y = coefficients @ (log_probs - future)

            # Solve system of linear equations via least squares.
            x = np.linalg.lstsq(coefficients, y, rcond=None)[0]

            # Measure the change BEFORE overwriting r[i]; comparing after the
            # assignment would always yield 0 and stop after a single sweep.
            delta = max(delta, np.max(np.abs(r[i] - x)))

            # Update Q-table based on updated reward.
            r[i] = x
            for a in range(n_actions):
                # Recompute the max per action: q[i] itself changes inside this loop,
                # which matters if state i can transition to itself.
                q[i, a] = r[i, a] + np.sum(mdp.transition_probabilities[i][a] * gamma * (1 - mdp.dones[i, a]) * np.max(q, axis=1))

        sweeps += 1
        if delta < theta or (max_iterations is not None and sweeps >= max_iterations):
            break

    return q, r, generate_boltzmann_distribution(q)

def action_value_iteration(reward, mdp, gamma=1.0, theta=0.00001):
    """
        Implementation of Action-value Iteration.

        Repeats synchronous Bellman optimality backups
            Q(s, a) = r(s, a) + gamma * (1 - done(s, a)) * sum_s' P(s' | s, a) * max_a' Q(s', a')
        (the same form used by inverse_action_value_iteration) until the largest
        change of any Q entry in one sweep is below theta.

        The reward is added outside the expectation over next states. States whose
        transition row is all zero (e.g. the abstract terminal states of MDP) therefore
        keep their own reward instead of silently dropping it.

        NOTE: with gamma = 1 termination relies on the backups converging, e.g. when
        the MDP is acyclic or every path terminates.

        Args:
            reward: reward table indexed as [state, action].
            mdp: object exposing nS, nA, transition_probabilities (nS, nA, nS) and dones (nS, nA).
            gamma: discount factor.
            theta: convergence threshold.

        Returns Q-function (along with Boltzmann distribution).
    """
    reward = np.asarray(reward, dtype=float)
    transitions = np.asarray(mdp.transition_probabilities, dtype=float)
    # Per-(s, a) factor applied to the bootstrapped value; zero for terminating actions.
    continuation = gamma * (1 - np.asarray(mdp.dones, dtype=float))

    # Initialize Q-table.
    q = np.zeros((mdp.nS, mdp.nA))

    # Iterate until convergence.
    while True:
        q_old = q
        # (nS, nA, nS) @ (nS,) -> (nS, nA): expected max Q-value of the next state.
        q = reward + continuation * (transitions @ np.max(q_old, axis=1))
        if np.max(np.abs(q - q_old)) < theta:
            break

    return q, generate_boltzmann_distribution(q)

def fit_weights(feature_matrix, rewards):
    """
        Least-squares fit of a linear model that maps features to rewards.

        Solves min_theta ||feature_matrix @ theta - rewards||_2. A non-linear
        regressor (e.g. a neural network) could be substituted here.

        Returns:
            (fitted_rewards, theta): the rewards reproduced by the linear model,
            and the fitted weights.
    """
    solution, _residuals, _rank, _singular_values = np.linalg.lstsq(feature_matrix, rewards, rcond=None)
    reconstructed = feature_matrix @ solution
    return reconstructed, solution

class MDP:
    """
        The MDP as defined in the paper.

        A chain of ``num_real_time_states`` states followed by two abstract terminal
        states: ``failure`` (index nS - 2) and ``success`` (index nS - 1). In every
        real-time state i there are two actions, stay (0) and release (1):

        * i < idx_before_cue: stay -> i + 1 (reward 0); release -> failure (reward 0, done).
        * idx_before_cue <= i < idx_correct: stay -> i + 1 (reward 0);
          release -> success (reward 1, done).
        * i >= idx_correct: both actions -> failure (reward 0, done).

        The defaults reproduce the original hard-coded layout (15 states, 2 actions).

        Attributes:
            nS, nA: number of states and actions.
            failure, success: indices of the two abstract terminal states.
            dones: (nS, nA) array, 1.0 where the action terminates the episode.
            transition_probabilities: (nS, nA, nS) deterministic transition table;
                the rows of the two abstract terminal states are all zero.
            P: dict mapping each real-time state to {action: (next_state, reward, done)}.

        Raises:
            ValueError: unless 0 <= idx_before_cue <= idx_correct < num_real_time_states.
    """
    def __init__(self, num_real_time_states=13, idx_before_cue=9, idx_correct=12):
        # idx_correct must leave at least one state after the rewarded window; otherwise
        # the last "stay" would point at the failure index without terminating.
        if not 0 <= idx_before_cue <= idx_correct < num_real_time_states:
            raise ValueError(
                "Expected 0 <= idx_before_cue <= idx_correct < num_real_time_states, "
                "got %s, %s, %s" % (idx_before_cue, idx_correct, num_real_time_states)
            )

        stay = 0
        release = 1

        self.nS = num_real_time_states + 2
        self.nA = 2

        self.failure = self.nS - 2
        self.success = self.nS - 1

        self.dones = np.zeros((self.nS, self.nA))
        self.transition_probabilities = np.zeros((self.nS, self.nA, self.nS))

        self.P = {}
        for i in range(num_real_time_states):
            if i < idx_before_cue:
                # Releasing ends the episode in failure.
                outcomes = {stay: (i + 1, 0, False), release: (self.failure, 0, True)}
            elif i < idx_correct:
                # Releasing ends the episode in success with reward 1.
                outcomes = {stay: (i + 1, 0, False), release: (self.success, 1, True)}
            else:
                # Every action ends the episode in failure.
                outcomes = {stay: (self.failure, 0, True), release: (self.failure, 0, True)}
            self.P[i] = outcomes
            # Mirror the dict representation into the dense arrays.
            for action, (next_state, _reward, done) in outcomes.items():
                self.transition_probabilities[i][action][next_state] = 1.0
                if done:
                    self.dones[i][action] = 1.0

if __name__ == "__main__":
    # Demo: recover a reward from random action probabilities with IAVI, fit it with
    # random linear "neuron" features, simulate inhibiting some neurons, and plot
    # the resulting release probabilities to ./NeuRL.pdf.
    # NOTE: results depend on the exact sequence of np.random calls below (no seed is set).
    # Generate MDP as used in the paper.
    mdp = MDP()
    print("Generate random action distribution.")
    # Generate random action probabilities, normalized so each state's row sums to 1.
    real_action_probabilities = np.random.rand(mdp.nS, mdp.nA)
    real_action_probabilities /= np.sum(real_action_probabilities, keepdims=True, axis=1)
    # In the defined MDP, the last two states are abstract terminal states (success/failure).
    real_action_probabilities[-2:] = 0.5
    print("Find reward by IAVI.")
    # Find reward function and induced Boltzmann distribution.
    _, reward, boltzmann_distribution = inverse_action_value_iteration(real_action_probabilities, mdp)
    # Sum of elementwise scipy.special.kl_div(found, real) over all states and actions.
    print("Distance between real action distribution and found Boltzmann distribution: %s"%scipy.special.kl_div(boltzmann_distribution, real_action_probabilities).sum())

    # Set a number of random neurons.
    num_neurons = 30
    print("Generate random features of %s neurons."%num_neurons)
    # Generate random features (one column per neuron, one row per state).
    feature_matrix = np.random.rand(mdp.nS, num_neurons)
    # In the defined MDP, the last two states are abstract terminal states (success/failure).
    feature_matrix[-2:] = 0
    print("Fit weights for random features.")
    # Fit all features to reward (linear least squares).
    fitted_reward, theta = fit_weights(feature_matrix, reward)
    # And get the induced Boltzmann distribution after fitting the all features.
    _, fitted_boltzmann_distribution = action_value_iteration(fitted_reward, mdp)
    print("Distance between real action distribution and found fitted Boltzmann distribution: %s"%scipy.special.kl_div(fitted_boltzmann_distribution, real_action_probabilities).sum())

    # Randomly select some neurons for inhibition (=non-active); randint's upper bound is exclusive (2..15).
    num_inhibited_neurons = np.random.randint(2, 16)
    inhibited_boltzmann_distributions = []
    print("Inhibit %s random neurons."%num_inhibited_neurons)
    print("Generate new reward based on inhibited neurons and the previously found weights.")
    # 100 random draws of which neurons are inhibited; the spread is shown as mean +/- std.
    for i in range(100):
        # Shuffled 0/1 mask over neurons with exactly num_inhibited_neurons zeros.
        active_neurons = np.random.permutation([0] * num_inhibited_neurons + [1] * (num_neurons - num_inhibited_neurons))
        # Set features of inhibited neurons to zero (mask broadcasts over the neuron axis).
        inhibited_feature_matrix = feature_matrix * active_neurons
        # Generate new reward based on inhibited features and found weights.
        inhibited_reward = np.dot(inhibited_feature_matrix, theta)
        _, inhibited_boltzmann_distribution = action_value_iteration(inhibited_reward, mdp)
        inhibited_boltzmann_distributions.append(inhibited_boltzmann_distribution)

    # Create plot.
    # usetex requires a working LaTeX installation (hence the escaped "\%" in the y label).
    plt.rc('text', usetex=True)
    font = {'family' : 'serif',
        'weight' : 'bold',
        'size'   : 10}
    plt.rc('font', **font)
    # Creates the current figure that the plt.* calls below draw on.
    fig = plt.figure(figsize=(6.4, 3.5))
    # One x position per real-time state (0.0 .. 2.4 in steps of 0.2, 13 points).
    # NOTE(review): the point count must equal mdp.nS - 2; this is hard-coded, not derived from mdp.
    bins = np.arange(0, 2.6, 0.2)
    # Dense grid for the quadratic interpolation; points past the last bin (2.4) fall
    # outside the interpolation range and, with bounds_error=False, are not drawn.
    x_inter = np.linspace(0, 2.6, 300)
    # Column 1 is the release action; the last two (abstract terminal) states are excluded.
    real_y_inter = interp1d(bins, real_action_probabilities[:-2, 1], kind='quadratic', bounds_error=False)
    boltzmann_y_inter = interp1d(bins, boltzmann_distribution[:-2, 1], kind='quadratic', bounds_error=False)
    fitted_y_inter = interp1d(bins, fitted_boltzmann_distribution[:-2, 1], kind='quadratic', bounds_error=False)
    # Mean and standard deviation over the 100 inhibition draws, per state and action.
    inhibited_mean = np.mean(inhibited_boltzmann_distributions, axis=0)
    inhibited_std = np.std(inhibited_boltzmann_distributions, axis=0)
    inhibited_y_inter_mean = interp1d(bins, inhibited_mean[:-2, 1], kind='quadratic', bounds_error=False)
    inhibited_y_inter_std_min = interp1d(bins, inhibited_mean[:-2, 1]-inhibited_std[:-2, 1], kind='quadratic', bounds_error=False)
    inhibited_y_inter_std_max = interp1d(bins, inhibited_mean[:-2, 1]+inhibited_std[:-2, 1], kind='quadratic', bounds_error=False)
    # Probabilities are scaled by 100 to plot percentages.
    plt.plot(x_inter, real_y_inter(x_inter)*100, label="Real Action Distribution", linewidth=3.0)
    plt.plot(x_inter, boltzmann_y_inter(x_inter)*100, label="Boltzmann Distribution after IAVI", linewidth=3.0, linestyle="--")
    plt.plot(x_inter, fitted_y_inter(x_inter)*100, label="Boltzmann Distribution after Fitting Weights to Reward", linewidth=3.0, linestyle=":")
    p = plt.plot(x_inter, inhibited_y_inter_mean(x_inter)*100, label="Inhibition of %s Random Neurons"%num_inhibited_neurons, linewidth=3.0)
    # Shade mean +/- one std in the same color as the inhibition mean line.
    plt.fill_between(x_inter, inhibited_y_inter_std_min(x_inter)*100, inhibited_y_inter_std_max(x_inter)*100, alpha=0.2, color=p[-1].get_color())
    ax = plt.gca()
    # NOTE(review): at 0.2 per state, x=1.6 and x=2.2 correspond to state indices 8 and 11;
    # presumably these mark the task window - confirm against MDP's idx_before_cue / idx_correct.
    ax.axvline(x=1.6, color='k', linewidth=2.0, linestyle="--" , alpha=0.8) 
    ax.axvline(x=2.2, color='k', linewidth=2.0, linestyle="--" , alpha=0.8) 
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    plt.xlabel("Time [s]")
    plt.ylabel(r"\% of Releases")
    # Move the last legend entry (the inhibition curve) to the front.
    handles, labels = plt.gca().get_legend_handles_labels()
    handles = handles[-1:] + handles[0:-1]
    labels = labels[-1:] + labels[0:-1]
    # Legend is placed above the axes; the title is lifted (y=1.4) to clear it.
    plt.legend(handles, labels, bbox_to_anchor=(0,1.02,1,0.2), loc="lower left", mode="expand", borderaxespad=0, ncol=1)
    plt.title("Random Showcase of Toolbox", y=1.4)
    print("Create plot in ./NeuRL.pdf.")
    plt.savefig("NeuRL.pdf", bbox_inches='tight')