from collections import defaultdict
from copy import deepcopy
import numpy as np

from features.mvsf import MVSF


class TabularMVSF(MVSF):
    """Tabular successor features (SF) with a per-entry covariance estimate.

    The SF mean ``psi`` and its covariance ``Sigma`` for each task are stored
    as hash tables keyed by state, so states must be hashable. Entries created
    by ``build_successor`` have shape ``(n_actions, n_features)`` for ``psi``
    and ``(n_actions, n_features, n_features)`` for ``Sigma``. Both are
    learned with tabular TD-style updates in ``update_successor``.

    NOTE(review): ``self.psi``, ``self.Sigma`` and ``self.rank`` are read but
    never assigned here; they are presumably provided by the ``MVSF`` base
    class, with ``psi``/``Sigma`` holding one table per task, indexed by task
    index. Confirm against ``MVSF``.
    """

    def __init__(self, alpha, alpha_var, *args,
                 noise_init=lambda size: np.random.uniform(-0.01, 0.01, size=size), **kwargs):
        """Create the tabular learner.

        Args:
            alpha: step size for the SF mean update.
            alpha_var: step size for the SF covariance update.
            *args: forwarded to ``MVSF.__init__``.
            noise_init: callable taking a shape tuple and returning an array of
                that shape. It initializes unseen ``psi`` entries. The default
                draws uniform noise from [-0.01, 0.01).
            **kwargs: forwarded to ``MVSF.__init__``.
        """
        super(TabularMVSF, self).__init__(*args, **kwargs)
        self.alpha = alpha
        self.alpha_var = alpha_var
        self.noise_init = noise_init

    def build_successor(self, task, source=None):
        """Create the SF mean and covariance tables for a new task.

        Args:
            task: object exposing ``action_count()`` and ``feature_dim()``.
                It is only consulted when fresh tables are built.
            source: optional index of a previously built task. When given and
                at least one task exists, that task's tables are deep-copied
                as the starting point.

        Returns:
            ``(psi, Sigma)``. Fresh tables are ``defaultdict``s that create
            entries on first access: ``psi`` entries come from ``noise_init``
            and ``Sigma`` entries start as zeros.
        """
        if source is None or len(self.psi) == 0:
            # Represent SF and covariance as hash tables; states must be hashable.
            n_actions = task.action_count()
            n_features = task.feature_dim()
            psi = defaultdict(lambda: self.noise_init((n_actions, n_features)))
            Sigma = defaultdict(lambda: np.zeros((n_actions, n_features, n_features)))
        else:
            # Transfer: start from independent copies of the source task's
            # tables, so later updates do not modify the source task.
            psi = deepcopy(self.psi[source])
            Sigma = deepcopy(self.Sigma[source])
        return psi, Sigma

    def get_successor(self, states, index):
        """Return the SF mean and covariance of one state for task ``index``.

        ``states`` is used directly as a hash-table key, so it must be a single
        hashable state. With ``defaultdict`` tables, an unseen state is
        inserted via the default factory.

        Returns:
            ``(psi, Sigma)`` with a leading axis of length 1 prepended: for
            tables from ``build_successor`` the shapes are
            ``(1, n_actions, n_features)`` and
            ``(1, n_actions, n_features, n_features)``.
        """
        psi = np.expand_dims(self.psi[index][states], axis=0)
        Sigma = np.expand_dims(self.Sigma[index][states], axis=0)
        return psi, Sigma

    def get_successors(self, states):
        """Return the SF means and covariances of one state for every task.

        Returns:
            ``(psi, Sigma)``, stacked over tasks in order, with a leading axis
            of length 1, i.e. ``(1, n_tasks, n_actions, n_features)`` and
            ``(1, n_tasks, n_actions, n_features, n_features)``.
            NOTE(review): stacking into a regular array assumes all tasks
            share the same action/feature dimensions. Confirm.
        """
        psi = np.expand_dims(np.array([psi_task[states] for psi_task in self.psi]), axis=0)
        Sigma = np.expand_dims(np.array([sigma_task[states] for sigma_task in self.Sigma]), axis=0)
        return psi, Sigma

    def update_successor(self, transitions, index):
        """Apply tabular TD updates to the SF mean and covariance of task ``index``.

        Args:
            transitions: iterable of ``(state, action, phi, next_state,
                next_action, gamma)`` tuples. ``phi`` is flattened, and its
                length must match the table's feature dimension. ``gamma``
                discounts the next-state bootstrap; its square discounts the
                covariance bootstrap.
            index: task index into ``self.psi`` / ``self.Sigma``.

        Raises:
            ValueError: if ``self.rank`` is neither ``'full'`` nor ``'diag'``.
                This is checked before any table is modified.
        """
        # Validate once, up front. The check used to run inside the loop, after
        # the mean of the first transition had already been updated, which left
        # psi partially modified when the rank was invalid.
        if self.rank not in ('full', 'diag'):
            raise ValueError('Invalid rank {}'.format(self.rank))
        full_rank = self.rank == 'full'

        # The per-task tables do not change identity during the loop.
        psi = self.psi[index]
        Sigma = self.Sigma[index]
        for state, action, phi, next_state, next_action, gamma in transitions:

            # Update mean: TD target is phi + gamma * psi(s', a').
            target = phi.flatten() + gamma * psi[next_state][next_action, :]
            error = target - psi[state][action, :]
            psi[state][action, :] = psi[state][action, :] + self.alpha * error

            # Update covariance, e.g. see https://arxiv.org/pdf/1801.08287.pdf
            # The squared TD error of the mean serves as the local "reward";
            # 'diag' keeps only its diagonal.
            if full_rank:
                error_sq = np.outer(error, error)
            else:
                error_sq = np.diag(error ** 2)
            target_cov = error_sq + (gamma ** 2) * Sigma[next_state][next_action, :, :]
            error_cov = target_cov - Sigma[state][action, :, :]
            Sigma[state][action, :, :] = Sigma[state][action, :, :] + self.alpha_var * error_cov
        
