from Environment.environment import non_state_factors
import numpy as np

class FlatNormalization:
    """Rescale flat observations and goals using a pair of limits.

    Values are shifted by the midpoint of the two limits and divided by half
    of their difference, so a value equal to the limit at index 0 maps to -1
    and a value equal to the limit at index 1 maps to +1.

    Attributes:
        flat_lims: Tuple of two arrays, each the concatenation of
            ``lim_dict[name][0]`` (resp. ``[1]``) over every name in
            ``all_names`` that is not in ``non_state_factors``.
        goal_lims: The goal limits exactly as passed to the constructor.
        mean / goal_mean: Midpoint of the corresponding limits.
        var / goal_var: Half the difference of the corresponding limits
            (a half-range, not a statistical variance, despite the name).
    """

    def __init__(self, all_names, lim_dict, goal_lims):
        """Compute the midpoint and half-range used for normalization.

        Args:
            all_names: Names to look up in ``lim_dict``; any name found in
                ``non_state_factors`` is skipped.
            lim_dict: Mapping from name to an indexable whose items 0 and 1
                are arrays accepted by ``np.concatenate``.
            goal_lims: Indexable whose items 0 and 1 support ``+``, ``-``
                and division by a scalar (presumably NumPy arrays — confirm
                against callers).
        """
        # Filter once so both concatenations see the same names in the same
        # order (this also keeps a one-shot iterator from being exhausted
        # before the second pass).
        state_names = [n for n in all_names if n not in non_state_factors]
        self.flat_lims = (np.concatenate([lim_dict[n][0] for n in state_names]),
                          np.concatenate([lim_dict[n][1] for n in state_names]))
        self.goal_lims = goal_lims
        self.goal_mean = (goal_lims[1] + goal_lims[0]) / 2
        self.goal_var = (goal_lims[1] - goal_lims[0]) / 2
        self.mean = (self.flat_lims[1] + self.flat_lims[0]) / 2
        self.var = (self.flat_lims[1] - self.flat_lims[0]) / 2

    @staticmethod
    def _normalize(values, mean, half_range):
        """Return ``(values - mean) / half_range`` without dividing by zero."""
        # A dimension whose two limits coincide has a half-range of 0, which
        # would turn the result into NaN/inf. Adding the boolean mask (True
        # counts as 1) swaps those zeros for a unit scale and leaves every
        # other entry, and the dtype, untouched.
        scale = half_range + (half_range == 0)
        if np.ndim(values) != np.ndim(mean):
            # Rank differs from the statistics: add a leading axis so the
            # statistics broadcast across it.
            mean = np.expand_dims(mean, axis=0)
            scale = np.expand_dims(scale, axis=0)
        return (values - mean) / scale

    def normalize_obs(self, obs):
        """Normalize an observation (or a stack of them) with the flat limits.

        Args:
            obs: Array whose trailing dimension lines up with ``self.mean``;
                inputs of a different rank are broadcast against the
                statistics along a new leading axis.

        Returns:
            ``(obs - self.mean) / self.var``, with zero entries of
            ``self.var`` treated as 1.
        """
        return self._normalize(obs, self.mean, self.var)

    def normalize_goal(self, obs):
        """Normalize a goal (or a stack of them) with the goal limits.

        Args:
            obs: Array whose trailing dimension lines up with
                ``self.goal_mean``; inputs of a different rank are broadcast
                against the statistics along a new leading axis.

        Returns:
            ``(obs - self.goal_mean) / self.goal_var``, with zero entries of
            ``self.goal_var`` treated as 1.
        """
        return self._normalize(obs, self.goal_mean, self.goal_var)
