"""Implementation of a Gaussian Policy"""
# Libraries
from policies import BasePolicy
from abc import ABC
import numpy as np
import copy
import time
from collections import defaultdict
import hashlib
from numpy.testing import assert_allclose


class LinearGaussianPolicy(BasePolicy, ABC):
    """
    Implementation of a Gaussian Policy which is linear in the state.
    Thus, the mean will be: parameters @ state.
    The standard deviation is chosen by the user at construction time and can
    be shrunk by ``std_decay`` on each call to ``reduce_exploration`` (never
    below ``std_min``). The variance ``self.var`` is cached and kept in sync
    with ``self.std_dev``.
    """
    def __init__(
            self, parameters: np.ndarray = None,
            std_dev: float = 0.1,
            std_decay: float = 0,
            std_min: float = 1e-4,
            dim_state: int = 1,
            dim_action: int = 1,
            multi_linear: bool = False

    ) -> None:
        """
        Args:
            parameters: initial linear weights used to compute the mean
                (``parameters @ state``); may be None and set later via
                ``set_parameters``.
            std_dev: initial standard deviation of the Gaussian; must be > 0.
            std_decay: amount subtracted from ``std_dev`` by each call to
                ``reduce_exploration``.
            std_min: lower bound for ``std_dev`` when decaying.
            dim_state: dimension of the state vector.
            dim_action: dimension of the action vector.
            multi_linear: if True, ``set_parameters`` reshapes a flat
                parameter vector into ``dim_action`` rows.
        """
        # Superclass initialization
        super().__init__()

        # Attributes with checks
        self.parameters = parameters

        err_msg = "[GaussPolicy] standard deviation must be positive!"
        assert std_dev > 0, err_msg
        self.std_dev = std_dev
        # Cached variance: several methods read it directly, so it must be
        # refreshed every time std_dev changes (see reduce_exploration).
        self.var = std_dev ** 2

        # Additional attributes
        self.dim_state = dim_state
        self.dim_action = dim_action
        self.tot_params = dim_action * dim_state
        self.multi_linear = multi_linear
        self.std_decay = std_decay
        self.std_min = std_min

    def calculate_mean(self, state):
        """Return the Gaussian mean for ``state`` under the current parameters.

        A 1-D state gives ``parameters @ state``. For higher-rank input, the
        last axis is treated as the state dimension and the mean is
        ``state @ parameters.T`` (leading axes act as batch/time axes).
        """
        if state.ndim == 1:
            return np.array(self.parameters @ state, dtype=np.float64)
        else:
            return np.array(state @ self.parameters.T, dtype=np.float64)

    def calculate_target_mean(self, state, parameter):
        """Like ``calculate_mean`` but with an explicit ``parameter`` vector.

        ``parameter`` is reshaped to ``(dim_action, dim_state)`` before use.
        """
        parameter = parameter.reshape(self.dim_action, self.dim_state)
        if state.ndim == 1:
            return np.array(parameter @ state, dtype=np.float64)
        else:
            return np.array(state @ parameter.T, dtype=np.float64)

    def draw_action(
            self, state, return_mean=False
    ) -> "np.ndarray | tuple[np.ndarray, np.ndarray]":
        """Sample an action from N(mean(state), std_dev^2).

        Args:
            state: a single state; ``len(state)`` must equal ``dim_state``.
            return_mean: if True, also return the mean used for sampling.

        Returns:
            The sampled action, or ``(action, mean)`` when ``return_mean``.

        Raises:
            ValueError: if the state length differs from ``dim_state``.
        """
        if len(state) != self.dim_state:
            err_msg = "[GaussPolicy] the state has not the same dimension of the parameter vector:"
            err_msg += f"\n{len(state)} vs. {self.dim_state}"
            raise ValueError(err_msg)

        mean = self.calculate_mean(state)

        action = np.array(np.random.normal(mean, self.std_dev), dtype=np.float64)

        # Some algorithms can cache the mean and reuse it.
        if return_mean:
            return action, mean

        return action

    def reduce_exploration(self):
        """Decrease ``std_dev`` by ``std_decay``, clipped below at ``std_min``.

        The cached variance ``self.var`` is updated as well so that every
        log-probability / score method uses the same, current spread.
        """
        self.std_dev = np.clip(self.std_dev - self.std_decay, self.std_min, np.inf)
        # Keep the cached variance consistent with the decayed std_dev;
        # otherwise methods using self.var would keep the initial variance.
        self.var = self.std_dev ** 2

    def set_parameters(self, thetas) -> None:
        """Replace the policy parameters.

        Without ``multi_linear`` the parameters are stored as a deep copy of
        ``thetas``; with it, ``thetas`` is split into ``dim_action`` equal
        chunks stacked as rows.
        """
        if not self.multi_linear:
            self.parameters = copy.deepcopy(thetas)
        else:
            self.parameters = np.array(np.split(thetas, self.dim_action))

    def get_parameters(self):
        """Return the current parameters (not a copy)."""
        return self.parameters

    def compute_pi(self, state, action):
        """Return the element-wise Gaussian density of ``action`` at ``state``."""
        mean = self.calculate_mean(state)
        fact = 1 / (np.sqrt(2 * np.pi) * self.std_dev)
        prob = fact * np.exp(-((action - mean) ** 2) / (2 * (self.std_dev ** 2)))

        return prob

    def compute_log_pi(self, state, action):
        """Return the element-wise Gaussian log-density of ``action`` at ``state``."""
        mean = self.calculate_mean(state)

        fact = 1 / (np.sqrt(2 * np.pi) * self.std_dev)
        log_prob = np.log(fact) - ((action - mean) ** 2) / (2 * (self.var))

        return log_prob

    def compute_sum_log_pi(self, states, actions):
        """Return the total log-density of ``actions`` given ``states``.

        The element-wise log-densities are summed over all entries.
        """
        # NOTE(review): presumably means is (timesteps, action_dim) — confirm.
        means = self.calculate_mean(states)

        log_fact = -np.log(np.sqrt(2 * np.pi) * self.std_dev)
        log_prob = log_fact - ((actions - means) ** 2) / (2 * self.var)

        return np.sum(log_prob)

    def compute_score(self, state, action) -> np.ndarray:
        """Return the score (gradient of log pi w.r.t. the parameters).

        Computed as ``(action - mean) * state / std_dev**2``. In
        ``multi_linear`` mode the outer product of deviation and state is
        flattened. Delegates to the base class when ``std_dev`` is zero.
        """
        if self.std_dev == 0:
            return super().compute_score(state, action)

        mean = self.calculate_mean(state)

        action_deviation = action - mean
        if self.multi_linear:
            # Column vector so the product below is an outer product
            # (one row of gradients per action dimension).
            action_deviation = action_deviation[:, np.newaxis]
        scores = (action_deviation * state) / (self.std_dev ** 2)
        if self.multi_linear:
            scores = np.ravel(scores)
        return scores

    def compute_score_trajectory(self, states, actions):
        """Return per-step scores for one trajectory, flattened per step.

        NOTE(review): assumes states is (T, dim_state) and actions is
        (T, dim_action) — TODO confirm; the result is (T, dim_action*dim_state).
        """
        means = self.calculate_mean(states)
        action_deviations = actions - means
        scores = (action_deviations[:, :, np.newaxis] * states[:, np.newaxis, :]) / self.var
        return scores.reshape(scores.shape[0], -1)

    def compute_score_all_trajectories(self, states_queue, actions_queue, means):
        """Return scores for precomputed ``means``, one or many trajectories.

        With 2-D ``states_queue`` this matches ``compute_score_trajectory``.
        Otherwise the first two axes are kept (presumably trajectories and
        time steps — confirm against callers) and the outer product of
        action deviation and state is flattened along the last axis.
        """
        action_deviations = actions_queue - means

        if states_queue.ndim == 2:
            scores = (action_deviations[:, :, np.newaxis] * states_queue[:, np.newaxis, :]) / self.var
            return scores.reshape(scores.shape[0], -1)

        action_deviations = action_deviations[:, :, :, np.newaxis]
        states_expanded = states_queue[:, :, np.newaxis, :]

        scores = (action_deviations * states_expanded) / self.var

        return scores.reshape(scores.shape[0], scores.shape[1], -1)

    def diff(self, state):
        """Not implemented for this policy."""
        raise NotImplementedError

    def compute_sum_all_log_pi(self, states, actions):
        """Return per-trajectory summed log-densities and the means used.

        The element-wise log-densities are summed over axes 1 and 2.

        Args:
            states: batched states (presumably (N, T, dim_state) — confirm).
            actions: either a 2-D action array, broadcast against every entry
                of the leading axis of the means, or a 3-D array aligned with
                the means.

        Returns:
            ``(log_prob_sums, means)``.

        Raises:
            ValueError: if ``actions`` is neither 2- nor 3-dimensional.
        """
        means = self.calculate_mean(states)

        if actions.ndim == 2:
            # Add a leading axis so the same actions are compared with each
            # batch of means.
            action_deviations = actions[np.newaxis, :, :] - means
        elif actions.ndim == 3:
            action_deviations = actions - means
        else:
            # Explicit error instead of an UnboundLocalError further down.
            raise ValueError(
                "[GaussPolicy] actions must be 2- or 3-dimensional, "
                f"got ndim={actions.ndim}"
            )

        log_fact = -np.log(np.sqrt(2 * np.pi) * self.std_dev)

        log_probs = log_fact - (action_deviations ** 2) / (2 * self.var)

        return np.sum(log_probs, axis=(1, 2)), means


