import numpy as np

from tensorflow.keras import backend as K, Model
from tensorflow.keras.layers import concatenate, Input, Lambda

from features.mvsf import MVSF


class KerasMVSF(MVSF):
    """Keras-backed implementation of the ``MVSF`` interface.

    Every task gets its own successor-feature network (``psi``) and its own
    covariance network (``Sigma``), each paired with a target copy that is
    re-synchronised every ``target_update_ev`` updates. All networks share a
    single Keras input tensor, which also lets the per-task outputs be stacked
    into the combined ``all_psi`` / ``all_Sigma`` models.

    NOTE(review): ``self.n_tasks``, ``self.psi``, ``self.Sigma`` and
    ``self.GPI`` are read here but never defined in this class; they are
    presumably maintained by ``MVSF`` -- confirm against the base class.
    """

    def __init__(self, keras_psi_model_handle, keras_Sigma_model_handle, target_update_ev, *args,
                 update_sigma=True, **kwargs):
        """Store the network factories and the update configuration.

        Args:
            keras_psi_model_handle: callable that receives the shared input
                tensor and returns a Keras model for ``psi``. Its output is
                assumed to have shape ``[n_batch, n_actions, n_features]``.
            keras_Sigma_model_handle: same kind of callable for ``Sigma``; its
                output is indexed exactly like the ``psi`` output.
            target_update_ev: number of ``update_successor`` calls on a task
                after which that task's target networks are refreshed.
            *args: forwarded positionally to ``MVSF.__init__``.
            update_sigma: keyword-only; when false, ``update_successor``
                neither trains nor re-synchronises the ``Sigma`` networks.
            **kwargs: forwarded to ``MVSF.__init__``. ``learning_rate_w``,
                ``use_true_reward`` and ``rank`` are fixed by this class and
                must not be supplied again.
        """
        super().__init__(*args, learning_rate_w=0.0, use_true_reward=True, rank='diag', **kwargs)
        self.keras_psi_model_handle = keras_psi_model_handle
        self.keras_Sigma_model_handle = keras_Sigma_model_handle
        self.target_update_ev = target_update_ev
        self.update_sigma = update_sigma
        self.key = 'sfdqn'

    def reset(self):
        """Reset the base-class state and clear the per-task update counters."""
        super().reset()
        # one counter per task; build_successor appends a new entry each time
        self.updates_since_target_updated = []

    @staticmethod
    def _diag_embed(values):
        """Return an array whose trailing matrices are ``diag(values[..., :])``.

        Args:
            values: NumPy array of any rank >= 1 holding diagonal entries in
                its last axis.

        Returns:
            A new array of shape ``values.shape + (values.shape[-1],)`` and the
            same dtype, zero everywhere except on the diagonal of the last two
            axes.
        """
        size = values.shape[-1]
        embedded = np.zeros(values.shape + (size,), dtype=values.dtype)
        diagonal = np.arange(size)
        # paired index arrays select the diagonal of the trailing two axes
        embedded[..., diagonal, diagonal] = values
        return embedded

    def build_successor(self, task, source=None):
        """Create the ``psi`` / ``Sigma`` networks (and targets) for a new task.

        Args:
            task: object providing ``action_count()``, ``feature_dim()`` and
                ``encode_dim()``; only consulted for the first task, when the
                shared input tensor is created.
            source: optional index of an existing task whose online weights
                initialise the new networks. Ignored for the first task.

        Returns:
            ``((psi, target_psi), (Sigma, target_Sigma))`` -- the online and
            target Keras models for the new task.
        """
        # input tensor for all networks is shared
        if self.n_tasks == 0:
            self.n_actions = task.action_count()
            self.n_features = task.feature_dim()
            self.inputs = Input(shape=(task.encode_dim(),))

        # build the SF and covariance networks
        # output shape is assumed to be [n_batch, n_actions, n_features]
        psi = self.keras_psi_model_handle(self.inputs)
        Sigma = self.keras_Sigma_model_handle(self.inputs)

        # warm-start both networks from the source task, if there is one
        if source is not None and self.n_tasks > 0:
            source_psi, _ = self.psi[source]
            psi.set_weights(source_psi.get_weights())
            source_Sigma, _ = self.Sigma[source]
            Sigma.set_weights(source_Sigma.get_weights())

        # add a task axis so predictions can be stacked across all existing tasks
        expand_psi = Lambda(lambda x: K.expand_dims(x, axis=1))(psi.output)
        expand_Sigma = Lambda(lambda x: K.expand_dims(x, axis=1))(Sigma.output)
        if self.n_tasks == 0:
            self.all_psi_outputs = expand_psi
            self.all_Sigma_outputs = expand_Sigma
        else:
            self.all_psi_outputs = concatenate([self.all_psi_outputs, expand_psi], axis=1)
            self.all_Sigma_outputs = concatenate([self.all_Sigma_outputs, expand_Sigma], axis=1)
        self.all_psi = Model(inputs=self.inputs, outputs=self.all_psi_outputs)
        self.all_Sigma = Model(inputs=self.inputs, outputs=self.all_Sigma_outputs)

        # dummy compile or keras complains when predicting
        self.all_psi.compile('sgd', 'mse')
        self.all_Sigma.compile('sgd', 'mse')

        # build target models and start them from the online weights
        target_psi = self.keras_psi_model_handle(self.inputs)
        target_Sigma = self.keras_Sigma_model_handle(self.inputs)
        target_psi.set_weights(psi.get_weights())
        target_Sigma.set_weights(Sigma.get_weights())
        self.updates_since_target_updated.append(0)

        return (psi, target_psi), (Sigma, target_Sigma)

    def get_successor(self, state, index):
        """Predict ``psi`` and the diagonal covariance for a single task.

        Args:
            state: batch of encoded states fed to the networks.
            index: index of the task whose online networks are queried.

        Returns:
            ``(psi, Sigma)`` where ``psi`` is the raw network prediction and
            ``Sigma`` is the covariance prediction with one extra trailing
            axis, so each predicted vector becomes a diagonal matrix.
        """
        psi_model, _ = self.psi[index]
        Sigma_model, _ = self.Sigma[index]
        psi = psi_model.predict_on_batch(state)
        Sigma = Sigma_model.predict_on_batch(state)

        # covariance is currently diagonal, so expand each vector into a matrix
        return psi, self._diag_embed(Sigma)

    def get_successors(self, state):
        """Predict ``psi`` and the diagonal covariance for every task at once.

        Args:
            state: batch of encoded states fed to the combined networks.

        Returns:
            ``(psi, Sigma)`` as in ``get_successor``, but with an additional
            task axis at position 1 (the axis along which ``build_successor``
            concatenates the per-task outputs).
        """
        psi = self.all_psi.predict_on_batch(state)
        Sigma = self.all_Sigma.predict_on_batch(state)

        # covariance is currently diagonal, so expand each vector into a matrix
        return psi, self._diag_embed(Sigma)

    def update_successor(self, transitions, index):
        """Run one training step on the networks of task ``index``.

        Args:
            transitions: ``None`` (nothing to do) or a tuple
                ``(states, actions, phis, next_states, gammas)``. ``gammas``
                must support ``len`` and ``reshape``; ``actions`` is used as an
                integer index into the action axis.
            index: index of the task being updated.
        """
        if transitions is None:
            return
        states, actions, phis, next_states, gammas = transitions
        rows = np.arange(len(gammas))
        gammas = gammas.reshape((-1, 1))  # column vector so it broadcasts over features

        # next actions come from GPI
        # NOTE(review): assumes q1 is [n_batch, n_tasks, n_actions] -- confirm in MVSF.GPI
        q1, _ = self.GPI(next_states, index)
        next_actions = np.argmax(np.max(q1, axis=1), axis=-1)

        # compute the targets and TD errors
        psi, target_psi = self.psi[index]
        current_psi = psi.predict_on_batch(states)
        next_psi = target_psi.predict_on_batch(next_states)[rows, next_actions, :]
        targets = phis + gammas * next_psi
        # must be taken before current_psi is overwritten with the targets below
        errors = targets - current_psi[rows, actions, :]

        # train the SF network; only the entries of the taken actions differ
        # from the current prediction, so only they contribute to the loss
        current_psi[rows, actions, :] = targets
        psi.train_on_batch(states, current_psi)

        # update sigma network
        if self.update_sigma:

            # compute the targets for Sigma
            Sigma, target_Sigma = self.Sigma[index]
            next_Sigma = target_Sigma.predict_on_batch(next_states)[rows, next_actions, :]
            targets2 = errors ** 2 + gammas ** 2 * next_Sigma

            # train the Sigma network
            current_Sigma = Sigma.predict_on_batch(states)
            current_Sigma[rows, actions, :] = targets2
            Sigma.train_on_batch(states, current_Sigma)

        # target network update
        self.updates_since_target_updated[index] += 1
        if self.updates_since_target_updated[index] >= self.target_update_ev:
            target_psi.set_weights(psi.get_weights())
            if self.update_sigma:
                target_Sigma.set_weights(Sigma.get_weights())
            self.updates_since_target_updated[index] = 0
        
