import gymnasium as gym
from point_mass.pointmass_box2d import PointMassBox2d
from rl.utils.replay_buffer import ReplayBuffer
from rl.utils.base_wrapper import NormalizedGymnasiumBoxEnv

import pickle
import numpy as np


class PointMassProcessor:
    """Turn a pickled point-mass dataset into a replay buffer with stochastic rewards.

    On construction the processor:

    1. builds a ``PointMassBox2d`` environment wrapped in a ``TimeLimit``,
    2. loads the pickled dataset at ``path`` (a mapping of array-likes),
    3. re-samples the risky part of the reward with a seeded generator, and
    4. optionally standardises the observations of both buffer and environment.

    Attributes:
        env: The (optionally observation-normalised) environment.
        risk_prob: Threshold compared against a uniform sample; a cost is
            applied when the sample is *greater* than this value.
        risk_var: Multiplier applied to the standard-normal cost sample.
        dict_data: The loaded dataset, each value converted with ``np.asarray``.
        np_rng: Seeded ``numpy.random.Generator`` used for all sampling here.
        buffer: Replay buffer returned by ``process_data``.
        normalize_obs: Whether observation standardisation was applied.
    """

    def __init__(self,
                 path: str,
                 risk_prob: float = 0.95,
                 risk_var: float = 50,
                 timelimit: int = 200,
                 normalize_obs: bool = False,
                 *,
                 seed: int = 42
                 ):
        """Build the environment, load the dataset and create the buffer.

        Args:
            path: Path of the pickle file holding the dataset mapping.
            risk_prob: Forwarded to the environment and used as the threshold
                for applying a stochastic cost (see ``_compute_rewards``).
            risk_var: Forwarded to the environment and used as the scale of
                the stochastic cost.
            timelimit: Base episode length; the ``TimeLimit`` wrapper is given
                ``timelimit * 3`` steps.
            normalize_obs: If true, standardise buffer observations with the
                dataset's per-dimension mean/std and wrap the environment with
                the same statistics.
            seed: Seed for the environment; its bitwise inverse (as uint32)
                seeds the processor's own generator.
        """
        # NOTE(review): the episode cap is three times `timelimit`; confirm
        # this factor is intentional.
        self.env = gym.wrappers.TimeLimit(PointMassBox2d(risk_prob=risk_prob, risk_var=risk_var,
                                                         seed=seed, eval_env=False,
                                                         ), max_episode_steps=timelimit * 3)

        self.risk_prob = risk_prob
        self.risk_var = risk_var

        # SECURITY: pickle.load executes arbitrary code embedded in the file;
        # only point `path` at datasets from a trusted source.
        with open(path, 'rb') as f:
            self.dict_data = {k: np.asarray(v) for k, v in pickle.load(f).items()}

        # The inverted seed gives the processor a stream distinct from the
        # value handed to the environment.
        self.np_rng = np.random.default_rng(np.invert(seed).astype(np.uint32))
        self.buffer = self.process_data(self.dict_data)
        self.normalize_obs = normalize_obs
        if self.normalize_obs:
            obs = self.buffer.observations
            mu = obs.mean(axis=0, keepdims=True)
            sigma = obs.std(axis=0, keepdims=True)
            self.buffer.observations = (obs - mu) / (sigma + 1e-8)
            self.buffer.next_observations = (self.buffer.next_observations - mu) / (sigma + 1e-8)
            # NOTE(review): the env wrapper receives `sigma` without the 1e-8
            # epsilon used above; confirm the wrapper guards against zero std.
            self.env = NormalizedGymnasiumBoxEnv(self.env, obs_mean=mu, obs_std=sigma)

    def _compute_rewards(self, dict_data):
        """Return the per-transition reward with the stochastic risk cost applied.

        Reads ``'deterministic_reward'``, ``'in_the_risk_zone'`` and
        ``'out_of_world'`` from ``dict_data``. For every transition a cost of
        ``risk_var * N(0, 1)`` is added when the transition is in the risk
        zone *and* a uniform sample exceeds ``risk_prob``. Transitions flagged
        ``out_of_world`` additionally lose 10.

        All randomness comes from ``self.np_rng`` so the result depends only
        on that generator's state, never on NumPy's global random state.
        """
        deterministic_reward = dict_data['deterministic_reward']
        in_the_risk_zone = dict_data['in_the_risk_zone'].astype(np.float32)
        out_of_world = dict_data['out_of_world']

        shape = in_the_risk_zone.shape
        # Mask of transitions on which the random cost fires.
        has_cost = (self.np_rng.uniform(0, 1, size=shape) > self.risk_prob).astype(np.float32)
        # Drawn from the seeded generator (not np.random) for reproducibility.
        costs = self.risk_var * self.np_rng.normal(0, 1, size=shape)
        penalty = has_cost * costs * in_the_risk_zone

        reward = deterministic_reward + penalty
        return np.where(out_of_world, reward - 10, reward)

    def process_data(self, dict_data):
        """Build a replay buffer from ``dict_data`` with re-sampled rewards.

        Args:
            dict_data: Mapping providing ``'observations'``, ``'actions'``,
                ``'next_observations'``, ``'dones'`` plus the keys read by
                ``_compute_rewards``.

        Returns:
            The object returned by ``ReplayBuffer.from_qlearning_dataset``,
            seeded from ``self.np_rng`` and with reward normalisation disabled.
        """
        reward = self._compute_rewards(dict_data)

        q_learning_data_set = {
            "observations": dict_data['observations'],
            "actions": dict_data['actions'],
            "next_observations": dict_data['next_observations'],
            "dones": dict_data['dones'],
            "rewards": reward,
        }
        buffer = ReplayBuffer.from_qlearning_dataset(q_learning_data_set,
                                                     seed=self.np_rng.integers(0, 2 ** 32 - 1),
                                                     normalize_reward=False)

        return buffer


