from typing import Optional
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.vec_env import VecNormalize
from rl.utils.base_wrapper import NormalizedGymnasiumBoxEnv
from rl.utils.replay_buffer import QLearningBatch
from airsims.airsim_env import AirSimEnv
import pickle
from gymnasium.wrappers import TimeLimit


class NumpyBuffer(ReplayBuffer):
    """Replay buffer whose ``sample`` returns NumPy arrays wrapped in a ``QLearningBatch``."""

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> QLearningBatch:
        """Draw a batch from the parent buffer and convert it to NumPy.

        Args:
            batch_size: Number of transitions to draw.
            env: Accepted for signature compatibility with the parent class;
                it is not forwarded to the parent sampler, so no
                ``VecNormalize`` statistics are applied here.

        Returns:
            A ``QLearningBatch`` built, in order, from the sampled
            observations, actions, rewards, next observations and dones,
            each moved to CPU and converted to a NumPy array.
        """
        # Zero-argument super() resolves to ReplayBuffer.sample. The previous
        # super(ReplayBuffer, self) started the lookup *after* ReplayBuffer and
        # therefore bypassed whatever sampling logic ReplayBuffer itself adds
        # on top of its own base class.
        batch = super().sample(batch_size)
        return QLearningBatch(
            batch.observations.cpu().numpy(),
            batch.actions.cpu().numpy(),
            batch.rewards.cpu().numpy(),
            batch.next_observations.cpu().numpy(),
            batch.dones.cpu().numpy(),
        )

    def __len__(self) -> int:
        """Return the first dimension of ``observations`` when full, else ``pos``."""
        if self.full:
            return self.observations.shape[0]
        return self.pos

class AirsimPreprocessor(object):
    """Pair an AirSim Gymnasium environment with a replay buffer loaded from disk.

    On construction the pickled buffer is loaded, per-feature observation
    statistics are computed from its filled part, and (optionally) both the
    stored observations and the environment's observations are standardized
    with those statistics. The environment is finally wrapped in a
    ``TimeLimit`` of 1000 steps.
    """

    def __init__(self,
                 ip: str,
                 path: str,
                 normalize_obs: bool = True,
                 normalize_reward: bool = False,
                 verbose: bool = False,
                 terminal_if_success: bool = True,
                 hard: bool = False,
                 seed: int = 0,
                 ):
        """Create the environment and load/prepare the replay buffer.

        Args:
            ip: Forwarded to ``AirSimEnv``.
            path: Path of the pickled replay buffer to load.
            normalize_obs: When true, standardize the buffer's observations and
                next observations in place and wrap the environment in
                ``NormalizedGymnasiumBoxEnv`` using the same statistics.
            normalize_reward: Stored on the instance; not otherwise used in
                this class.
            verbose: Forwarded to ``AirSimEnv``.
            terminal_if_success: Forwarded to ``AirSimEnv``.
            hard: Forwarded to ``AirSimEnv``.
            seed: Forwarded to ``AirSimEnv``.
        """
        wrapped_env = AirSimEnv(ip=ip, seed=seed, verbose=verbose,
                                terminal_if_success=terminal_if_success, hard=hard)
        self.normalize_reward = normalize_reward

        # NOTE(security): pickle.load executes arbitrary code embedded in the
        # file; only point `path` at buffers from a trusted source.
        with open(path, 'rb') as handle:
            replay = pickle.load(handle)
        # Re-class the loaded object so it gains NumpyBuffer's sample/__len__.
        replay.__class__ = NumpyBuffer
        self.buffer: NumpyBuffer = replay

        # Statistics are taken over the filled part of the buffer only.
        # Axes (0, 1) are reduced with keepdims, so the leading two axes are
        # kept as size-1 dimensions -- presumably (buffer slot, env index);
        # confirm against the stored buffer layout.
        filled = replay.observations[:len(replay)]
        self.obs_mean = filled.mean(axis=(0, 1), keepdims=True)
        # Lower-bound the std so the division below never hits zero.
        self.obs_std = filled.std(axis=(0, 1), keepdims=True).clip(1e-12)

        if normalize_obs:
            replay.observations = self._standardize(
                replay.observations, self.obs_mean, self.obs_std)
            replay.next_observations = self._standardize(
                replay.next_observations, self.obs_mean, self.obs_std)
            wrapped_env = NormalizedGymnasiumBoxEnv(
                wrapped_env, obs_mean=self.obs_mean, obs_std=self.obs_std)

        self.gymnasium_env = TimeLimit(wrapped_env, max_episode_steps=1000)

    @staticmethod
    def _standardize(values, mean, std):
        """Return ``(values - mean) / std`` with ``std`` floored at 1e-12."""
        floored_std = std.clip(1e-12)
        return (values - mean) / floored_std

    @property
    def env(self):
        """The fully wrapped Gymnasium environment."""
        return self.gymnasium_env

    def get_replay_buffer(self):
        """Return the loaded (and possibly standardized) replay buffer."""
        return self.buffer
