import numpy as np

from tensorflow.keras import backend as K, Model
from tensorflow.keras.layers import concatenate, Input, Lambda

from agents.agent import Agent


class C51SF:
    """Distributional (C51-style) successor features with risk-aware GPI.

    Each successor-feature network predicts, for every action and every
    feature dimension, a categorical distribution over ``n_atoms`` evenly
    spaced atoms on ``[v_min, v_max]`` (output shape
    ``[n_batch, n_actions, n_features, n_atoms]``). Action values for a reward
    weight vector ``w`` are the mean of ``psi . w`` minus a variance-based
    risk penalty (see ``GPI_w``).

    Args:
        model_lambda: callable that builds a Keras model from the shared input
            tensor; it is called once for each online and each target network.
        v_min, v_max: per-feature lower/upper bounds of the atom support.
        target_update_ev: number of ``update_successor`` calls between copies
            of the online weights into the target network.
        risk_aversion: coefficient of the risk penalty used in ``GPI_w``.
        method: ``'gauss'`` or ``'laplace'``; selects the penalty form.
        n_atoms: number of atoms per feature.
        Extra positional/keyword arguments are accepted and ignored.

    Raises:
        ValueError: if ``method`` is not ``'gauss'`` or ``'laplace'``.
    """

    def __init__(self, model_lambda, v_min, v_max, target_update_ev, risk_aversion, method, *args,
                 n_atoms=51, **kwargs):
        if method not in ('gauss', 'laplace'):
            # explicit check (not assert) so validation survives `python -O`
            raise ValueError("method must be 'gauss' or 'laplace', got {!r}".format(method))
        self.model_lambda = model_lambda
        self.v_min = np.array(v_min).reshape((1, -1, 1))
        self.v_max = np.array(v_max).reshape((1, -1, 1))
        self.target_update_ev = target_update_ev
        self.n_atoms = n_atoms
        self.risk_aversion = risk_aversion
        self.method = method
        self.key = 'sfc51'

        # atoms: z[j, i] = v_min[j] + i * delta_z[j], shape [n_features, n_atoms]
        self.delta_z = (self.v_max - self.v_min) / (self.n_atoms - 1.)
        self.z = (self.v_min + np.arange(self.n_atoms) * self.delta_z)[0]

        # create the (empty) task library so the object is usable before an explicit reset()
        self.reset()

    def build_successor(self, task, source=None):
        """Build the online and target SF networks for a new task.

        Args:
            task: new task; on the first task its action count, feature
                dimension and encoding dimension define the shared input.
            source: optional index of an existing task whose online weights
                initialise the new online network.

        Returns:
            Tuple ``(psi, target_psi)`` of Keras models whose outputs have
            shape ``[n_batch, n_actions, n_features, n_atoms]``.
        """

        # input tensor for all networks is shared
        if self.n_tasks == 0:
            self.n_actions = task.action_count()
            self.n_features = task.feature_dim()
            self.inputs = Input(shape=(task.encode_dim(),))

        # build SF network and copy its weights from previous task
        # output shape is [n_batch, n_actions, n_features, n_atoms]
        psi = self.model_lambda(self.inputs)
        if source is not None and self.n_tasks > 0:
            source_psi, _ = self.psi[source]
            psi.set_weights(source_psi.get_weights())

        # concatenate SF predictions across all existing tasks
        # output shape is [n_batch, n_policies, n_actions, n_features, n_atoms]
        expand_psi = Lambda(lambda x: K.expand_dims(x, axis=1))(psi.output)
        if self.n_tasks == 0:
            self.all_psi_outputs = expand_psi
        else:
            self.all_psi_outputs = concatenate([self.all_psi_outputs, expand_psi], axis=1)
        self.all_psi = Model(inputs=self.inputs, outputs=self.all_psi_outputs)
        # placeholder optimizer/loss; within this class all_psi is only used for prediction
        self.all_psi.compile('sgd', 'mse')

        # build target models and copy their weights
        target_psi = self.model_lambda(self.inputs)
        target_psi.set_weights(psi.get_weights())
        self.updates_since_target_updated.append(0)

        return psi, target_psi

    def get_successor(self, states, index):
        """Return the online SF distribution of task ``index`` for ``states``."""
        psi, _ = self.psi[index]
        return psi.predict_on_batch(states)

    def get_successors(self, states):
        """Return SF distributions of all tasks, shape [n_batch, n_policies, n_actions, n_features, n_atoms]."""
        return self.all_psi.predict_on_batch(states)

    def update_successor(self, transitions, index, **kwargs):
        """Do one distributional Bellman update of the SF network of task ``index``.

        Args:
            transitions: ``(states, actions, phis, next_states, gammas)`` or
                ``None`` (no-op). ``phis`` is ``[n_batch, n_features]`` and
                ``gammas`` has one discount per transition.
            index: index of the task whose SF networks are updated.
        """

        # unpack transitions
        if transitions is None:
            return
        states, actions, phis, next_states, gammas = transitions
        n_batch = len(gammas)
        gammas = gammas.reshape((n_batch, 1, 1))
        phis = phis[:, :, np.newaxis]

        # next-state distributions come from the target net, next actions from GPI
        psi, target_psi = self.psi[index]
        z_ = target_psi.predict_on_batch(next_states)
        q1, _ = self.GPI(next_states, index)
        next_actions = np.argmax(np.max(q1, axis=1), axis=-1)

        # project T z = phi + gamma * z back onto the fixed atom support
        Tz = np.clip(phis + gammas * self.z[np.newaxis, :, :], self.v_min, self.v_max)
        # clip again in index space: float round-off can push bj past n_atoms - 1
        bj = np.clip((Tz - self.v_min) / self.delta_z, 0., self.n_atoms - 1.)
        m_l = np.floor(bj).astype(int)
        m_u = np.ceil(bj).astype(int)
        # when bj falls exactly on an atom, floor == ceil and both weights below
        # would be 0, dropping that mass; move one index to a neighbour instead
        m_l[(m_u > 0) & (m_l == m_u)] -= 1
        m_u[(m_l < self.n_atoms - 1) & (m_l == m_u)] += 1

        batch_idx = np.arange(n_batch)
        p_next = z_[batch_idx, next_actions]  # [n_batch, n_features, n_atoms]
        delta_u = p_next * (m_u - bj)  # weight assigned to the lower atom
        delta_l = p_next * (bj - m_l)  # weight assigned to the upper atom

        # scatter-add the projected mass into the slot of the action actually taken;
        # broadcast indices: batch [n,1,1], action [n,1,1], feature [1,F,1], atom [n,F,A]
        b_idx = batch_idx[:, np.newaxis, np.newaxis]
        a_idx = np.asarray(actions).reshape((n_batch, 1, 1))
        f_idx = np.arange(bj.shape[1])[np.newaxis, :, np.newaxis]
        m_prob = np.zeros_like(z_)
        np.add.at(m_prob, (b_idx, a_idx, f_idx, m_l), delta_u)
        np.add.at(m_prob, (b_idx, a_idx, f_idx, m_u), delta_l)

        # fit model
        # NOTE(review): untaken actions get all-zero targets; this presumably relies on
        # psi being compiled with a cross-entropy-style loss — confirm in model_lambda
        psi.train_on_batch(states, m_prob)

        # target network update
        self.updates_since_target_updated[index] += 1
        if self.updates_since_target_updated[index] >= self.target_update_ev:
            target_psi.set_weights(psi.get_weights())
            self.updates_since_target_updated[index] = 0

    def _moments(self, p):
        """Return per-feature (mean, variance) of categorical distributions ``p`` over ``self.z``."""
        # self.z is [n_features, n_atoms] and broadcasts against the trailing axes of p
        m1 = np.sum(p * self.z, axis=-1)
        m2 = np.sum(p * (self.z ** 2), axis=-1)
        # clamp tiny negative values caused by floating-point cancellation
        return m1, np.maximum(m2 - m1 ** 2, 0.)

    def GPI_w(self, states, w):
        """Risk-adjusted generalized policy improvement for reward weights ``w``.

        Returns:
            ``(q, task)`` where ``q`` has shape ``[n_batch, n_policies, n_actions]``
            and ``task`` is, per state (squeezed), the index of the policy
            containing the highest q-value.
        """
        mean_f, var_f = self._moments(self.get_successors(states))
        w = np.reshape(w, (1, 1, 1, -1))
        mean = np.sum(mean_f * w, axis=-1)
        # per-feature variances are combined without covariance terms
        var = np.sum(var_f * (w ** 2), axis=-1)
        if self.method == 'gauss':
            penalty = self.risk_aversion * var
        elif self.risk_aversion == 0:
            # limit of the Laplace penalty below as risk_aversion -> 0
            # (avoids dividing by zero)
            penalty = np.zeros_like(var)
        else:
            penalty = (-1. / self.risk_aversion) * np.log(
                np.maximum(1. - 0.5 * self.risk_aversion ** 2 * var, 1e-15))
        q = mean - penalty
        task = np.squeeze(np.argmax(np.max(q, axis=2), axis=1))
        return q, task

    def GPI(self, states, index, update_counters=False):
        """Run ``GPI_w`` with the fitted weights of task ``index``, optionally counting chosen policies."""
        q, task = self.GPI_w(states, self.fit_w[index])
        if update_counters:
            # np.add.at counts repeated policy indices correctly when states is a batch
            np.add.at(self.gpi_counters[index], task, 1)
        return q, task

    def compute_mean_variance(self, states, policy_index):
        """Return per-action, per-feature (mean, variance) of the SF distribution of ``policy_index``."""
        return self._moments(self.get_successor(states, policy_index))

    def GPI_usage_percent(self, index):
        """Fraction of GPI choices for task ``index`` that selected another task's policy."""
        counts = self.gpi_counters[index]
        return 1. - (float(counts[index]) / np.sum(counts))

    def reset(self):
        """Clear the task library, reward weights, GPI counters and target-update counters."""
        self.n_tasks = 0
        self.psi = []
        self.true_w = []
        self.fit_w = []
        self.gpi_counters = []
        self.updates_since_target_updated = []

    def add_training_task(self, task, source=None):
        """Add SF networks, reward weights and GPI counters for a new task.

        Args:
            task: the new task; its ``get_w()`` gives the reward weights.
            source: optional index of a task whose weights initialise the new SF network.
        """

        # add successor features to the library
        self.psi.append(self.build_successor(task, source))
        self.n_tasks = len(self.psi)

        # reward weights are not learned (update_reward is a no-op), so the
        # fitted weights are the task's true weights
        true_w = task.get_w()
        self.true_w.append(true_w)
        self.fit_w.append(true_w)

        # add statistics for counting actions transferred from other tasks
        self.gpi_counters[:] = [np.append(counts, 0) for counts in self.gpi_counters]
        self.gpi_counters.append(np.zeros((self.n_tasks,), dtype=int))

    def update_reward(self, phi, r, index):
        """No-op: reward weights are fixed to the true task weights in ``add_training_task``."""
        pass
    
