import collections
from typing import Optional
import jax
import jax.numpy as jnp
import d4rl
import gym
import numpy as np
from tqdm import tqdm

# Immutable record describing one minibatch of transitions. Every field holds
# the entries of the corresponding storage array at the sampled indices.
Batch = collections.namedtuple(
    'Batch',
    (
        'observations',
        'actions',
        'rewards',
        'masks',
        'next_observations',
    ),
)


def split_into_trajectories(observations, actions, rewards, masks, dones_float,
                            next_observations):
    """Group flat, index-aligned transition sequences into trajectories.

    Each transition becomes a tuple
    ``(observation, action, reward, mask, done, next_observation)``. A new
    trajectory is started after every transition whose ``dones_float`` entry
    equals 1.0, except after the final transition, so the result never ends
    with an empty trajectory caused by a trailing done flag.

    Args:
        observations: Sequence of observations; its length sets the number of
            transitions processed.
        actions: Sequence of actions, indexed like ``observations``.
        rewards: Sequence of rewards, indexed like ``observations``.
        masks: Sequence of masks, indexed like ``observations``.
        dones_float: Sequence of done flags; 1.0 marks the end of a trajectory.
        next_observations: Sequence of next observations, indexed like
            ``observations``.

    Returns:
        A list of trajectories, each a list of transition tuples. For empty
        input the result is ``[[]]``.
    """
    total = len(observations)
    current = []
    episodes = [current]

    for step in tqdm(range(total)):
        transition = (observations[step], actions[step], rewards[step],
                      masks[step], dones_float[step], next_observations[step])
        current.append(transition)

        # Open a fresh trajectory only when more transitions are still to come.
        if dones_float[step] == 1.0 and step + 1 < total:
            current = []
            episodes.append(current)

    return episodes


def merge_trajectories(trajs):
    """Flatten trajectories back into six stacked, index-aligned arrays.

    Args:
        trajs: Iterable of trajectories, each an iterable of 6-tuples
            ``(observation, action, reward, mask, done, next_observation)``.

    Returns:
        A tuple ``(observations, actions, rewards, masks, dones_float,
        next_observations)`` in which every element is the ``np.stack`` of the
        corresponding tuple component across all transitions, in trajectory
        order.
    """
    # One bucket per tuple component, in the order of the returned tuple.
    columns = tuple([] for _ in range(6))

    for traj in trajs:
        for obs, act, rew, mask, done, next_obs in traj:
            step = (obs, act, rew, mask, done, next_obs)
            for bucket, value in zip(columns, step):
                bucket.append(value)

    return tuple(np.stack(bucket) for bucket in columns)


class Dataset(object):
    """Holds index-aligned transition arrays and samples uniform minibatches."""

    def __init__(self, observations: np.ndarray, actions: np.ndarray,
                 rewards: np.ndarray, masks: np.ndarray,
                 dones_float: np.ndarray, next_observations: np.ndarray,
                 size: int):
        """Store the given arrays as attributes without copying them.

        Args:
            observations: Observations, indexed by transition.
            actions: Actions, indexed by transition.
            rewards: Rewards, indexed by transition.
            masks: Masks, indexed by transition.
            dones_float: Done flags, indexed by transition.
            next_observations: Next observations, indexed by transition.
            size: Exclusive upper bound for the indices drawn by ``sample``.
        """
        self.size = size
        self.observations = observations
        self.next_observations = next_observations
        self.actions = actions
        self.rewards = rewards
        self.masks = masks
        self.dones_float = dones_float

    def sample(self, batch_size: int) -> Batch:
        """Return ``batch_size`` transitions drawn uniformly with replacement.

        Indices come from ``np.random.randint(self.size, size=batch_size)``,
        so they depend on NumPy's global random state. ``dones_float`` is not
        part of the returned batch.
        """
        picks = np.random.randint(self.size, size=batch_size)

        def take(column):
            # Fancy-index one storage array with the sampled indices.
            return column[picks]

        return Batch(
            observations=take(self.observations),
            actions=take(self.actions),
            rewards=take(self.rewards),
            masks=take(self.masks),
            next_observations=take(self.next_observations),
        )


class D4RLDataset(Dataset):
    """Dataset populated from ``d4rl.qlearning_dataset(env)``.

    Besides the arrays returned by D4RL, a ``dones_float`` array is derived
    that marks the last transition of each stored trajectory.
    """

    def __init__(self,
                 env: gym.Env,
                 clip_to_eps: bool = True,
                 eps: float = 1e-5):
        """Load the D4RL data for ``env`` and cast every array to float32.

        Args:
            env: Environment passed to ``d4rl.qlearning_dataset``.
            clip_to_eps: If true, clip actions to ``[-(1 - eps), 1 - eps]``.
            eps: Margin used for the action clipping.
        """
        dataset = d4rl.qlearning_dataset(env)

        if clip_to_eps:
            lim = 1 - eps
            dataset['actions'] = np.clip(dataset['actions'], -lim, lim)

        observations = dataset['observations']
        next_observations = dataset['next_observations']
        terminals = dataset['terminals']

        dones_float = np.zeros_like(dataset['rewards'])

        # A transition ends a trajectory when it is flagged terminal or when
        # the next stored observation does not continue from its
        # next_observation (Euclidean gap above 1e-6). Computed for all
        # consecutive pairs at once instead of one Python iteration per
        # transition.
        gaps = observations[1:] - next_observations[:-1]
        flat_gaps = gaps.reshape(gaps.shape[0],
                                 int(np.prod(gaps.shape[1:], dtype=int)))
        discontinuous = np.linalg.norm(flat_gaps, axis=1) > 1e-6
        dones_float[:-1] = discontinuous | (terminals[:-1] == 1.0)

        # The final transition always closes the last trajectory.
        dones_float[-1] = 1

        super().__init__(observations.astype(np.float32),
                         actions=dataset['actions'].astype(np.float32),
                         rewards=dataset['rewards'].astype(np.float32),
                         # Mask is 0 for terminal transitions and 1 otherwise.
                         masks=1.0 - terminals.astype(np.float32),
                         dones_float=dones_float.astype(np.float32),
                         next_observations=next_observations.astype(
                             np.float32),
                         size=len(observations))


class PartialD4RLDataset(D4RLDataset):
    """``D4RLDataset`` restricted to a selected subset of its transitions."""

    def __init__(self,
                 env: gym.Env,
                 clip_to_eps: bool = True,
                 eps: float = 1e-5,
                 selected_index=None,
                 ):
        """Load the full D4RL dataset, then keep only the selected rows.

        Args:
            env: Environment passed on to ``D4RLDataset``.
            clip_to_eps: Passed on to ``D4RLDataset``.
            eps: Passed on to ``D4RLDataset``.
            selected_index: Index applied along the first axis of every
                stored array, e.g. an integer index array, a list of indices
                or a boolean mask. ``None`` keeps the full dataset.
        """
        super(PartialD4RLDataset, self).__init__(
            env=env,
            clip_to_eps=clip_to_eps,
            eps=eps, )
        self.select_index = selected_index
        if self.select_index is not None:
            self.observations = self.observations[self.select_index]
            self.actions = self.actions[self.select_index]
            self.rewards = self.rewards[self.select_index]
            self.masks = self.masks[self.select_index]
            self.dones_float = self.dones_float[self.select_index]
            self.next_observations = self.next_observations[self.select_index]
            # Take the size from the selected data rather than from the index
            # object: a boolean mask has the length of the full dataset, and a
            # plain list has no ``shape`` attribute.
            self.size = len(self.observations)


class PrioritizedReplayBuffer:
    """Fixed-capacity ring buffer of transitions sampled by priority.

    Transitions are kept in preallocated NumPy arrays. ``insert`` writes at
    ``position`` and wraps around at ``capacity``, overwriting the oldest
    entries. ``sample`` draws indices with probability
    ``priority / sum(priorities)`` over the filled part of the buffer.

    ``alpha``, ``beta`` and ``epsilon`` are stored on the instance but are not
    applied anywhere in this class.
    NOTE(review): presumably callers fold them into the priorities they pass
    to ``insert`` / ``update_priorities`` -- confirm against the training loop.
    """

    def __init__(self, observation_space: gym.spaces.Box, action_dim: int, capacity: int,
                 alpha=0.6, beta=0.4, epsilon=1e-3):
        """Allocate storage for ``capacity`` transitions.

        Args:
            observation_space: Provides ``shape`` and ``dtype`` for the
                observation arrays.
            action_dim: Length of each stored action vector.
            capacity: Maximum number of transitions held at once.
            alpha: Prioritization strength (stored only).
            beta: Importance sampling weight (stored only).
            epsilon: Small constant to avoid zero priorities (stored only).
        """
        # Buffer for storing experiences
        self.observations = np.empty((capacity, *observation_space.shape),
                                     dtype=observation_space.dtype)
        self.actions = np.empty((capacity, action_dim), dtype=np.float32)
        self.rewards = np.empty((capacity,), dtype=np.float32)
        self.masks = np.empty((capacity,), dtype=np.float32)
        self.dones_float = np.empty((capacity,), dtype=np.float32)
        self.next_observations = np.empty((capacity, *observation_space.shape),
                                          dtype=observation_space.dtype)

        self.capacity = capacity
        self.alpha = alpha  # prioritization strength
        self.beta = beta  # importance sampling weight
        self.epsilon = epsilon  # small constant to avoid zero priorities
        self.priorities = np.zeros((capacity,), dtype=np.float32)
        self.position = 0  # index the next insert writes to
        self.size = 0  # number of valid transitions currently stored

    def insert(self, observation: np.ndarray, action: np.ndarray,
               reward: float, mask: float, done_float: float,
               next_observation: np.ndarray, priority: float):
        """Write one transition at the cursor and advance it, wrapping around.

        Once the buffer is full the oldest transition is overwritten.
        ``priority`` is stored as given.
        """
        self.observations[self.position] = observation
        self.actions[self.position] = action
        self.rewards[self.position] = reward
        self.masks[self.position] = mask
        self.dones_float[self.position] = done_float
        self.next_observations[self.position] = next_observation

        # Set priority based on error
        self.priorities[self.position] = priority

        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int):
        """Sample a batch of experiences based on their priorities.

        Indices are drawn with replacement via ``np.random.choice``, with
        probability proportional to the stored priority of each transition.

        Returns:
            A tuple ``(batch, indices)``: ``batch`` is a ``Batch`` of the
            sampled transitions and ``indices`` are their buffer positions,
            usable with ``update_priorities``.

        Raises:
            ValueError: If the buffer is empty or the stored priorities do
                not have a positive sum.
        """
        if self.size == 0:
            raise ValueError('Cannot sample from an empty replay buffer.')

        # Only the filled part of the buffer carries meaningful priorities.
        # Normalize in float64 so the probabilities sum to 1 as tightly as
        # np.random.choice expects.
        priorities = self.priorities[:self.size].astype(np.float64)
        total = priorities.sum()
        if not total > 0:
            raise ValueError(
                'Cannot sample: the sum of priorities must be positive.')
        probs = priorities / total

        indices = np.random.choice(len(probs), batch_size, p=probs)

        samples = Batch(
            observations=self.observations[indices],
            actions=self.actions[indices],
            rewards=self.rewards[indices],
            masks=self.masks[indices],
            next_observations=self.next_observations[indices]
        )

        return samples, indices

    def update_priorities(self, indices, priorities):
        """Overwrite the priorities stored at ``indices`` with ``priorities``."""
        self.priorities[indices] = priorities

    def initialize_with_dataset(self, dataset, num_samples: Optional[int] = None):
        """Initialize the buffer with an offline dataset.

        Args:
            dataset: Object exposing ``observations``, ``actions``,
                ``rewards``, ``masks``, ``dones_float`` and
                ``next_observations`` arrays.
            num_samples: Number of transitions to load. ``None`` loads the
                whole dataset; smaller values load a random subset.

        Loaded transitions occupy the start of the buffer and all receive
        priority 1.
        """
        dataset_size = len(dataset.observations)

        # Determine how many samples to load (default is the full dataset)
        if num_samples is None:
            num_samples = dataset_size
        else:
            num_samples = min(dataset_size, num_samples)
        assert self.capacity >= num_samples, 'Dataset cannot be larger than the replay buffer capacity.'

        # Shuffle the dataset if loading a subset
        if num_samples < dataset_size:
            indices = np.random.permutation(dataset_size)[:num_samples]
        else:
            indices = np.arange(num_samples)

        # Load data into the buffer
        self.observations[:num_samples] = dataset.observations[indices]
        self.actions[:num_samples] = dataset.actions[indices]
        self.rewards[:num_samples] = dataset.rewards[indices]
        self.masks[:num_samples] = dataset.masks[indices]
        self.dones_float[:num_samples] = dataset.dones_float[indices]
        self.next_observations[:num_samples] = dataset.next_observations[indices]

        # Give every loaded transition the same priority of 1
        self.priorities[:num_samples] = 1

        # Update buffer's state. When the dataset fills the buffer exactly the
        # write cursor must wrap to 0; leaving it at `capacity` would make the
        # next insert index past the end of the storage arrays.
        self.position = 0 if num_samples == self.capacity else num_samples
        self.size = num_samples