import numpy as np
from typing import Sequence
from flax.struct import dataclass


@dataclass
class QLearningBatch:
    """Batch of transitions returned by ``ReplayBuffer``.

    ``ReplayBuffer.sample`` and ``ReplayBuffer.enumerate_index`` construct this
    class positionally from a 5-tuple, so the field order below must stay in
    sync with the tuple built by ``ReplayBuffer._sample``.
    """
    # Rows taken from the buffer's ``observations`` storage.
    observations: np.ndarray
    # Rows taken from the buffer's ``actions`` storage.
    actions: np.ndarray
    # Rows taken from the buffer's ``rewards`` storage.
    rewards: np.ndarray
    # Rows taken from the buffer's ``next_observations`` storage.
    next_observations: np.ndarray
    # Rows taken from the buffer's ``dones`` storage.
    dones: np.ndarray


@dataclass
class CEPLearningBatch:
    """Batch of transitions plus generated ("fake") actions.

    ``CEPReplayBuffer.sample`` constructs this class positionally from the
    7-tuple built by ``CEPReplayBuffer._sample``; keep the field order in sync.
    """
    observations: np.ndarray
    actions: np.ndarray
    rewards: np.ndarray
    next_observations: np.ndarray
    # Cast to float32 by ``CEPReplayBuffer._sample`` before being placed here.
    dones: np.ndarray
    # Rows of the buffer's ``fake_actions`` storage.
    fake_actions: np.ndarray
    # Rows of the buffer's ``fake_next_actions`` storage.
    fake_next_actions: np.ndarray


@dataclass
class MRCEPLearningBatch:
    """Batch of transitions plus generated actions, rewards and next observations.

    ``MRCEPReplayBuffer.sample`` constructs this class positionally from the
    10-tuple built by ``MRCEPReplayBuffer._sample``; keep the field order in
    sync with that tuple.
    """
    observations: np.ndarray
    actions: np.ndarray
    rewards: np.ndarray
    next_observations: np.ndarray
    dones: np.ndarray
    # Rows of the buffer's ``fake_actions`` storage.
    fake_actions: np.ndarray
    # Rows of the buffer's ``fake_next_actions`` storage.
    fake_next_actions: np.ndarray
    # Rows of the buffer's ``fake_rewards`` storage.
    fake_rewards: np.ndarray
    # Rows of the buffer's ``fake_next_observations`` storage.
    fake_next_observations: np.ndarray
    # Rows of the buffer's ``fake_actions_fake_next_observations`` storage
    # (presumably the fake actions generated at each fake next observation —
    # TODO confirm against the agent).
    fake_actions_fake_next_observations: np.ndarray


class ReplayBuffer(object):
    def __init__(self,
                 observation_shape: Sequence[int],
                 action_shape: Sequence[int],
                 capacity: int = 1_000_000,
                 seed: int = 42
                 ):
        self.observation_shape = observation_shape
        self.action_shape = action_shape
        self.observations = np.empty(shape=(capacity, *observation_shape), dtype=np.float32)
        self.actions = np.empty(shape=(capacity, *action_shape), dtype=np.float32)
        self.rewards = np.empty(shape=(capacity, 1), dtype=np.float32)
        self.dones = np.empty(shape=(capacity, 1), dtype=np.float32)
        self.next_observations = np.empty(shape=(capacity, *observation_shape), dtype=np.float32)
        self.pointer = 0
        self.capacity = capacity
        self.full = False
        self.np_rng = np.random.default_rng(seed)

    def add(self, observations: np.ndarray, actions: np.ndarray,
            rewards: float, next_observations: np.ndarray, dones: bool | float):

        self.observations[self.pointer] = observations
        self.actions[self.pointer] = actions
        self.rewards[self.pointer] = rewards
        self.dones[self.pointer] = dones
        self.next_observations[self.pointer] = next_observations
        self.pointer += 1
        if self.pointer >= self.capacity:
            self.full = True

        self.pointer %= self.capacity

    def _sample(self, index):
        batch = (self.observations[index], self.actions[index], self.rewards[index],
                 self.next_observations[index], self.dones[index])
        return batch

    def sample(self, batch_size: int):
        index = self.np_rng.integers(0, len(self), size=(batch_size,))
        return QLearningBatch(*self._sample(index))

    def __len__(self):
        if self.full:
            return self.capacity
        else:
            return self.pointer

    @classmethod
    def from_npz(cls,
                 path,
                 normalize_reward: bool = False,
                 terminal_if_timeout=False,
                 seed=42):
        dataset = np.load(path, allow_pickle=True)
        observations = dataset['observations']
        observations = observations.reshape(-1, observations.shape[-1])

        observations = observations.squeeze()
        actions = dataset['actions'].reshape(-1, dataset['actions'].shape[-1])

        actions = actions.squeeze()
        rewards = dataset['rewards'][..., None].reshape(-1, 1)

        terminal = dataset['dones'][..., None].reshape(-1, 1)
        timeout = dataset['timeout'][..., None].reshape(-1, 1)
        next_observations = dataset['next_observations'].reshape(-1, dataset['next_observations'].shape[-1])

        buffer = cls(observations.shape[1:], actions.shape[1:], capacity=len(observations), seed=seed)
        buffer.observations = observations
        buffer.actions = actions
        buffer.rewards = rewards
        buffer.next_observations = next_observations
        if normalize_reward:
            buffer.rewards = (buffer.rewards - np.mean(buffer.rewards))/(np.std(buffer.rewards) + 1e-12)
        if terminal_if_timeout:
            buffer.dones = np.logical_and(terminal, timeout)
        else:
            buffer.dones = terminal
        buffer.full = True
        return buffer

    @classmethod
    def from_qlearning_dataset(cls, dictionary,
                               normalize_reward: bool = True,
                               seed: int = 42,
                               ):
        size = len(dictionary['observations'])
        observation_shape = dictionary['observations'].shape[1:]
        action_shape = dictionary['actions'].shape[1:]

        buffer = cls(observation_shape, action_shape, capacity=size, seed=seed)
        buffer.observations = dictionary['observations']
        buffer.actions = dictionary['actions']

        buffer.rewards = dictionary['rewards'].reshape(-1, 1)
        if normalize_reward:
            buffer.rewards = (buffer.rewards - buffer.rewards.mean()) / buffer.rewards.std().clip(1e-12)

        buffer.next_observations = dictionary['next_observations']
        if 'dones' in dictionary.keys():
            buffer.dones = dictionary['dones']
        else:
            buffer.dones = dictionary['terminals']
        buffer.dones = buffer.dones.reshape(-1, 1)
        buffer.full = True
        buffer.pointer = len(buffer) - 1
        return buffer

    def enumerate_index(self, start_index, end_index):
        batch = (self.observations[start_index: end_index], self.actions[start_index: end_index],
                 self.rewards[start_index: end_index],
                 self.next_observations[start_index: end_index], self.dones[start_index: end_index])
        return QLearningBatch(*batch)


class CEPReplayBuffer(ReplayBuffer):
    """Replay buffer that also stores generated ("fake") actions per transition.

    For every transition it keeps ``n_fake_action`` fake actions for the
    observation and ``n_fake_action`` for the next observation.
    """

    def __init__(self,
                 observation_shape: Sequence[int],
                 action_shape: Sequence[int],
                 n_fake_action: int = 16,
                 capacity: int = 1000_000,
                 seed: int = 42,
                 ):
        """Allocate base storage plus the two fake-action arrays.

        Args:
            observation_shape: Shape of a single observation.
            action_shape: Shape of a single action.
            n_fake_action: Number of fake actions stored per transition.
            capacity: Maximum number of stored transitions.
            seed: Seed of the NumPy generator used for sampling.
        """
        super().__init__(observation_shape,
                         action_shape,
                         capacity,
                         seed=seed
                         )
        self.n_fake_action = n_fake_action
        # One block of n_fake_action actions per stored transition.
        fake_shape = (capacity, n_fake_action) + self.actions.shape[1:]
        self.fake_actions = np.empty(shape=fake_shape, dtype=np.float32)
        self.fake_next_actions = np.empty(shape=fake_shape, dtype=np.float32)

    @classmethod
    def from_replay(cls, agent, replay: ReplayBuffer,
                    n_fake_action: int = 32,
                    batch_size=256):
        """Wrap ``replay`` and fill the fake-action arrays using ``agent``.

        The transition arrays are shared with ``replay`` (no copy is made).
        The filled part of ``replay`` is walked in chunks of ``batch_size`` and
        ``agent.generate_fake_action(obs, n_fakes=n_fake_action)`` is called on
        both the observations and the next observations of each chunk.

        Args:
            agent: Object exposing ``generate_fake_action(obs, n_fakes=...)``.
            replay: Source buffer.
            n_fake_action: Number of fake actions to generate per observation.
            batch_size: Chunk size used while generating.
        """
        from tqdm import trange

        size = len(replay)
        buffer = cls(
            observation_shape=replay.observation_shape,
            action_shape=replay.action_shape,
            n_fake_action=n_fake_action,
            capacity=size)

        buffer.observations = replay.observations
        buffer.actions = replay.actions
        buffer.dones = replay.dones
        buffer.next_observations = replay.next_observations
        buffer.rewards = replay.rewards
        buffer.pointer = replay.pointer
        buffer.full = replay.full

        # Ceiling division: a trailing partial chunk gets its own iteration.
        num_iter = (size + batch_size - 1) // batch_size
        for s in trange(num_iter):
            start_index = s * batch_size
            end_index = min(start_index + batch_size, size)
            batch = replay.enumerate_index(start_index, end_index)
            buffer.fake_actions[start_index:end_index] = agent.generate_fake_action(
                batch.observations, n_fakes=n_fake_action)
            buffer.fake_next_actions[start_index:end_index] = agent.generate_fake_action(
                batch.next_observations, n_fakes=n_fake_action)

        return buffer

    def _sample(self, index):
        """Return the stored fields at ``index`` in ``CEPLearningBatch`` order."""
        return (self.observations[index], self.actions[index], self.rewards[index],
                self.next_observations[index], self.dones[index].astype(np.float32),
                self.fake_actions[index], self.fake_next_actions[index])

    def sample(self, batch_size: int) -> CEPLearningBatch:
        """Draw ``batch_size`` transitions uniformly with replacement."""
        index = self.np_rng.integers(0, len(self), size=(batch_size,))
        return CEPLearningBatch(*self._sample(index))

    def save(self, path):
        """Write the buffer contents to ``path`` with ``np.savez``.

        Per the NumPy documentation, ``np.savez`` appends ``.npz`` to a string
        or ``Path`` that does not already end with it; pass the resulting file
        name to ``load``. The sampling generator state is not saved.
        """
        data = {
            "observation_shape": self.observation_shape,
            "action_shape": self.action_shape,
            "capacity": self.capacity,
            "observations": self.observations,
            "actions": self.actions,
            "rewards": self.rewards,
            "dones": self.dones,
            "next_observations": self.next_observations,
            "fake_actions": self.fake_actions,
            "next_fake_actions": self.fake_next_actions,
            "pointer": self.pointer,
            "full": self.full,
        }
        np.savez(path, **data)

    @classmethod
    def load(cls, path):
        """Rebuild a buffer from a file written by ``save``.

        ``n_fake_action`` is recovered from the stored ``fake_actions`` array,
        and the scalar metadata is converted back to plain Python types (it is
        stored as NumPy arrays inside the archive). The buffer is created with
        the default sampling seed.
        """
        keys = ("observation_shape", "action_shape", "capacity", "observations",
                "actions", "rewards", "dones", "next_observations",
                "fake_actions", "next_fake_actions", "full", "pointer")
        # Read everything up front so the archive handle can be closed.
        with np.load(path) as loaded_data:
            stored = {key: loaded_data[key] for key in keys}

        fake_actions = stored["fake_actions"]
        buffer = cls(
            observation_shape=tuple(int(dim) for dim in stored["observation_shape"]),
            action_shape=tuple(int(dim) for dim in stored["action_shape"]),
            # Axis 1 of fake_actions is the per-transition fake-action count.
            n_fake_action=int(fake_actions.shape[1]),
            capacity=int(stored["capacity"]),
        )

        buffer.observations = stored["observations"]
        buffer.actions = stored["actions"]
        buffer.rewards = stored["rewards"]
        buffer.dones = stored["dones"]
        buffer.next_observations = stored["next_observations"]
        buffer.fake_actions = fake_actions
        buffer.fake_next_actions = stored["next_fake_actions"]

        buffer.full = bool(stored["full"])
        buffer.pointer = int(stored["pointer"])

        return buffer


class MRCEPReplayBuffer(CEPReplayBuffer):
    """CEP buffer that additionally stores generated rewards and next observations.

    On top of the parent's fake actions it keeps, per transition and per fake
    action, a fake reward, a fake next observation, and a further block of
    ``n_fake_action`` actions for each fake next observation.
    """

    def __init__(self,
                 observation_shape: Sequence[int],
                 action_shape: Sequence[int],
                 n_fake_action: int = 16,
                 capacity: int = 1000_000,
                 seed: int = 42,
                 ):
        """Allocate parent storage plus the three additional fake arrays.

        Args:
            observation_shape: Shape of a single observation.
            action_shape: Shape of a single action.
            n_fake_action: Number of fake actions stored per transition.
            capacity: Maximum number of stored transitions.
            seed: Seed of the NumPy generator used for sampling.
        """
        super().__init__(observation_shape, action_shape, n_fake_action, capacity, seed)
        per_fake = (self.capacity, self.n_fake_action)
        self.fake_next_observations = np.empty(per_fake + tuple(self.observation_shape),
                                               dtype=np.float32)
        self.fake_rewards = np.empty(per_fake + (1,), dtype=np.float32)
        # Quadratic in n_fake_action: one action block per fake next observation.
        self.fake_actions_fake_next_observations = np.empty(
            per_fake + (self.n_fake_action,) + tuple(self.action_shape), dtype=np.float32)

    @classmethod
    def from_replay(cls, agent, replay: ReplayBuffer,
                    n_fake_action: int = 16,
                    batch_size=256):
        """Wrap ``replay`` and fill every fake array using ``agent``.

        The transition arrays are shared with ``replay`` (no copy is made).

        NOTE(review): ``agent.generate_fake_action`` is called without an
        explicit count here, so the number of fakes it returns must already
        equal ``n_fake_action`` — confirm against the agent.

        Args:
            agent: Object exposing ``generate_fake_action(obs)`` and
                ``build_fake_next(obs, fake_action)``; the latter returns
                ``(fake_rewards, fake_next_observations)``.
            replay: Source buffer.
            n_fake_action: Number of fake actions stored per transition.
            batch_size: Chunk size used while generating.
        """
        from tqdm import trange

        size = len(replay)
        buffer = cls(
            observation_shape=replay.observation_shape,
            action_shape=replay.action_shape,
            n_fake_action=n_fake_action,
            capacity=size)

        buffer.observations = replay.observations
        buffer.actions = replay.actions
        buffer.dones = replay.dones
        buffer.next_observations = replay.next_observations
        buffer.rewards = replay.rewards
        buffer.pointer = replay.pointer
        buffer.full = replay.full

        # Ceiling division: a trailing partial chunk gets its own iteration.
        num_iter = (size + batch_size - 1) // batch_size
        for s in trange(num_iter):
            start_index = s * batch_size
            end_index = min(start_index + batch_size, size)
            rows = slice(start_index, end_index)
            batch = replay.enumerate_index(start_index, end_index)

            fake_action = agent.generate_fake_action(batch.observations)
            fake_next_action = agent.generate_fake_action(batch.next_observations)
            fake_rewards, fake_next_observation = agent.build_fake_next(batch.observations, fake_action)

            buffer.fake_actions[rows] = fake_action
            buffer.fake_next_actions[rows] = fake_next_action
            # Previously the generated rewards were computed and then dropped,
            # leaving fake_rewards as uninitialised memory. The reshape accepts
            # both (batch, n_fake) and (batch, n_fake, 1) outputs.
            buffer.fake_rewards[rows] = np.asarray(fake_rewards).reshape(
                end_index - start_index, -1, 1)
            buffer.fake_next_observations[rows] = fake_next_observation

            # Flatten (batch, n_fake) so the agent sees one observation per row,
            # then restore the two leading axes on the generated actions.
            fake = agent.generate_fake_action(
                fake_next_observation.reshape(-1, fake_next_observation.shape[-1]))
            fake = fake.reshape(fake_action.shape[0], fake_action.shape[1], fake_action.shape[1], -1)
            buffer.fake_actions_fake_next_observations[rows] = fake
        return buffer

    def save(self, path):
        """Write the buffer contents to ``path`` with ``np.savez``.

        Per the NumPy documentation, ``np.savez`` appends ``.npz`` to a string
        or ``Path`` that does not already end with it; pass the resulting file
        name to ``load``. The sampling generator state is not saved.
        """
        data = {
            "observation_shape": self.observation_shape,
            "action_shape": self.action_shape,
            "capacity": self.capacity,
            "observations": self.observations,
            "actions": self.actions,
            "rewards": self.rewards,
            "dones": self.dones,
            "next_observations": self.next_observations,
            "fake_actions": self.fake_actions,
            "next_fake_actions": self.fake_next_actions,
            "fake_rewards": self.fake_rewards,
            "fake_next_observations": self.fake_next_observations,
            'fake_actions_fake_next_observations': self.fake_actions_fake_next_observations,
            "pointer": self.pointer,
            "full": self.full,
        }
        np.savez(path, **data)

    @classmethod
    def load(cls, path):
        """Rebuild a buffer from a file written by ``save``.

        ``n_fake_action`` is recovered from the stored ``fake_actions`` array,
        and the scalar metadata is converted back to plain Python types (it is
        stored as NumPy arrays inside the archive). The buffer is created with
        the default sampling seed.
        """
        keys = ("observation_shape", "action_shape", "capacity", "observations",
                "actions", "rewards", "dones", "next_observations",
                "fake_actions", "next_fake_actions", "fake_rewards",
                "fake_next_observations", "fake_actions_fake_next_observations",
                "full", "pointer")
        # Read everything up front so the archive handle can be closed.
        with np.load(path) as loaded_data:
            stored = {key: loaded_data[key] for key in keys}

        fake_actions = stored["fake_actions"]
        buffer = cls(
            observation_shape=tuple(int(dim) for dim in stored["observation_shape"]),
            action_shape=tuple(int(dim) for dim in stored["action_shape"]),
            # Axis 1 of fake_actions is the per-transition fake-action count.
            n_fake_action=int(fake_actions.shape[1]),
            capacity=int(stored["capacity"]),
        )

        buffer.observations = stored["observations"]
        buffer.actions = stored["actions"]
        buffer.rewards = stored["rewards"]
        buffer.dones = stored["dones"]
        buffer.next_observations = stored["next_observations"]
        buffer.fake_actions = fake_actions
        buffer.fake_next_actions = stored["next_fake_actions"]
        buffer.fake_rewards = stored["fake_rewards"]
        buffer.fake_next_observations = stored["fake_next_observations"]
        buffer.fake_actions_fake_next_observations = stored["fake_actions_fake_next_observations"]
        buffer.full = bool(stored["full"])
        buffer.pointer = int(stored["pointer"])

        return buffer

    def _sample(self, index):
        """Return the stored fields at ``index`` in ``MRCEPLearningBatch`` order."""
        return (self.observations[index], self.actions[index], self.rewards[index],
                self.next_observations[index],
                # Same float32 cast as CEPReplayBuffer._sample, so boolean
                # done flags come out numeric here too.
                self.dones[index].astype(np.float32),
                self.fake_actions[index], self.fake_next_actions[index], self.fake_rewards[index],
                self.fake_next_observations[index], self.fake_actions_fake_next_observations[index])

    def sample(self, batch_size: int) -> MRCEPLearningBatch:
        """Draw ``batch_size`` transitions uniformly with replacement."""
        index = self.np_rng.integers(0, len(self), size=(batch_size,))
        return MRCEPLearningBatch(*self._sample(index))



@dataclass
class SequentialQLearningBatch:
    """Batch of fixed-length transition windows.

    ``SequentialReplayBuffer.sample`` constructs this class positionally from
    the 8-tuple built by ``SequentialReplayBuffer._sample_multi_env``; keep the
    field order in sync with that tuple.
    """
    observations: np.ndarray
    actions: np.ndarray
    # ``next_history`` shifted one step later along the time axis, with the
    # first step zeroed.
    history: np.ndarray
    # Observations and actions concatenated along the last axis.
    next_history: np.ndarray
    rewards: np.ndarray
    next_observations: np.ndarray
    dones: np.ndarray
    # Cumulative sum of ``dones`` along the time axis, clipped to [0, 1].
    masks: np.ndarray


class SequentialReplayBuffer(ReplayBuffer):
    """Replay buffer that samples contiguous windows of ``sample_length`` steps.

    Sampling indexes axis 1 of every stored array by environment, so storage
    is expected to be laid out as ``(time, env, ...)``;
    ``from_qlearning_dataset`` installs arrays in that layout with one
    environment.

    NOTE(review): the constructor itself only allocates the base-class layout
    ``(capacity, ...)`` with no environment axis, so a buffer filled through
    ``add`` does not match what ``sample`` expects — confirm intended usage.
    """

    def __init__(self,
                 observation_shape,
                 action_shape,
                 capacity: int = int(1e+4),
                 sample_length: int = 200,
                 n_envs: int = 1,
                 seed: int = 42
                 ):
        """Create the buffer.

        Args:
            observation_shape: Shape of a single observation.
            action_shape: Shape of a single action.
            capacity: Maximum number of stored time steps.
            sample_length: Number of consecutive steps in each sampled window.
            n_envs: Number of environments to choose from when sampling.
            seed: Seed of the NumPy generator used for sampling.
        """
        super().__init__(
            observation_shape=observation_shape,
            action_shape=action_shape,
            capacity=capacity,
            seed=seed,
        )
        self.sample_length = sample_length
        self.n_envs = n_envs

    def fix_buffer(self):
        """Convert the observation and action storage to NumPy arrays.

        ``np.asarray`` is the conversion call; the ``np.ndarray`` constructor
        used before interprets its argument as a shape and fails on array
        input.
        """
        self.observations = np.asarray(self.observations)
        self.actions = np.asarray(self.actions)

    def _sample_multi_env(self, index, env_index):
        """Gather windows starting at ``index`` from environment ``env_index``.

        Returns a tuple in ``SequentialQLearningBatch`` field order:
        (observations, actions, history, next_history, rewards,
        next_observations, dones, masks).
        """
        def windows(x):
            # All length-sample_length windows along axis 0 of a 2-D array;
            # swap + squeeze reorders the result to (window, step, feature).
            view = np.lib.stride_tricks.sliding_window_view(x, window_shape=(self.sample_length, 1))
            return view.swapaxes(1, 2).squeeze(-1)

        obs = windows(self.observations[:, env_index])[index]
        action = windows(self.actions[:, env_index])[index]
        rewards = windows(self.rewards[:, env_index])[index]
        next_obs = windows(self.next_observations[:, env_index])[index]
        dones = windows(self.dones[:, env_index])[index]

        # Running count of done flags along the window, clipped to [0, 1]:
        # 0 before the first done step, 1 from that step onward.
        masks = np.cumsum(dones, axis=1).astype(np.float32).clip(0, 1)

        # next_history pairs each step's observation with its action; history
        # is the same sequence delayed by one step, with a zero first step.
        next_history = np.concatenate([obs, action], axis=-1)
        history = np.roll(next_history, shift=1, axis=-2)
        history[:, 0, :] = 0.

        return (obs, action, history, next_history, rewards, next_obs, dones, masks)

    def sample(self, batch_size: int):
        """Draw ``batch_size`` windows from one randomly chosen environment.

        Window starts are drawn uniformly from every position that leaves room
        for a full window, including the last one.

        Raises:
            ValueError: If fewer than ``sample_length`` steps are stored.
        """
        n_windows = len(self) - self.sample_length + 1
        if n_windows <= 0:
            raise ValueError(
                f"buffer holds {len(self)} steps, fewer than sample_length={self.sample_length}")
        index = self.np_rng.integers(0, n_windows, size=(batch_size,))
        env_index = self.np_rng.integers(0, self.n_envs)
        return SequentialQLearningBatch(*self._sample_multi_env(index, env_index))

    @classmethod
    def from_qlearning_dataset(cls,
                               dictionary,
                               normalize_reward: bool = True,
                               sample_length: int = 200,
                               seed: int = 42,
                               ):
        """Build a full single-environment buffer from a dict of arrays.

        Reads ``observations``, ``actions``, ``rewards``, ``next_observations``
        and either ``dones`` or, if absent, ``terminals``. A length-1
        environment axis is inserted at position 1 of every array.

        Args:
            dictionary: Mapping with the keys listed above.
            normalize_reward: Standardise rewards to zero mean / unit std.
            sample_length: Number of consecutive steps in each sampled window.
            seed: Seed of the sampling generator.
        """
        observations = dictionary['observations']
        actions = dictionary['actions']

        buffer = cls(observations.shape[1:], actions.shape[1:],
                     capacity=len(observations), seed=seed, sample_length=sample_length)
        buffer.observations = observations[:, None]
        buffer.actions = actions[:, None]
        buffer.next_observations = dictionary['next_observations'][:, None]

        rewards = dictionary['rewards'].reshape(-1, 1, 1)
        if normalize_reward:
            # clip() guards against division by zero for constant rewards.
            rewards = (rewards - rewards.mean()) / rewards.std().clip(1e-12)
        buffer.rewards = rewards

        done_key = 'dones' if 'dones' in dictionary else 'terminals'
        buffer.dones = dictionary[done_key].reshape(-1, 1, 1)

        buffer.full = True
        buffer.pointer = len(buffer) - 1
        return buffer

