import numpy as np
import torch
from typing import Any, Dict, List, Optional, Tuple, Union
from utils import split_into_trajectories

# Return type of ReplayBuffer.sample: [states, actions, rewards, next_states, dones].
TensorBatch = List[torch.Tensor]

class ReplayBuffer:
    """Fixed-size circular buffer of (state, action, reward, next_state, done) transitions.

    Every field is pre-allocated as a float32 tensor on ``device``. Rewards and
    dones are stored with a trailing singleton axis, i.e. shape ``(buffer_size, 1)``.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        device: str = "cpu",
    ):
        self._buffer_size = buffer_size
        # Slot the next add_transition writes to; wraps modulo buffer_size.
        self._pointer = 0
        # Number of valid transitions held (never exceeds buffer_size).
        self._size = 0

        self._states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._actions = torch.zeros(
            (buffer_size, action_dim), dtype=torch.float32, device=device
        )
        self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._next_states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._device = device

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        """Copy ``data`` into a new float32 tensor on the buffer's device."""
        return torch.tensor(data, dtype=torch.float32, device=self._device)

    # Loads data in d4rl format, i.e. from Dict[str, np.array].
    def load_dataset(self, data: Dict[str, np.ndarray]):
        """Fill an empty buffer from a D4RL-style dataset dict.

        Reads ``observations``, ``actions``, ``rewards``, ``next_observations``
        and ``realterminals``. ``rewards`` and ``realterminals`` get a trailing
        axis added to match the ``(n, 1)`` storage.

        Raises:
            ValueError: if the buffer already holds data, or the dataset has more
                transitions than ``buffer_size``.
        """
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        self._states[:n_transitions] = self._to_tensor(data["observations"])
        self._actions[:n_transitions] = self._to_tensor(data["actions"])
        self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None])
        self._next_states[:n_transitions] = self._to_tensor(data["next_observations"])
        self._dones[:n_transitions] = self._to_tensor(data["realterminals"][..., None])
        self._size += n_transitions
        # Wrap the write position. When the dataset exactly fills the buffer,
        # the next add_transition must overwrite slot 0 instead of indexing
        # one past the end.
        self._pointer = n_transitions if n_transitions < self._buffer_size else 0

        print(f"Dataset size: {n_transitions}")

    def sample(self, batch_size: int) -> "TensorBatch":
        """Draw ``batch_size`` transitions uniformly at random, with replacement.

        Returns:
            ``[states, actions, rewards, next_states, dones]``, all gathered with
            the same random indices.

        Raises:
            ValueError: if the buffer is empty.
        """
        if self._size == 0:
            raise ValueError("Cannot sample from an empty replay buffer")
        indices = np.random.randint(0, self._size, size=batch_size)
        return [
            self._states[indices],
            self._actions[indices],
            self._rewards[indices],
            self._next_states[indices],
            self._dones[indices],
        ]

    def add_transition(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ):
        """Store one transition at the write pointer, overwriting the oldest once full."""
        # Use this method to add new data into the replay buffer during fine-tuning.
        self._states[self._pointer] = self._to_tensor(state)
        self._actions[self._pointer] = self._to_tensor(action)
        self._rewards[self._pointer] = self._to_tensor(reward)
        self._next_states[self._pointer] = self._to_tensor(next_state)
        self._dones[self._pointer] = self._to_tensor(done)

        self._pointer = (self._pointer + 1) % self._buffer_size
        self._size = min(self._size + 1, self._buffer_size)

class SequenceReplayBuffer:
    """Keeps the ``num_sequence`` highest-scoring episodes as padded sequences.

    Slot ``i`` holds one episode of at most ``max_sequence_length`` steps.
    ``_pointer[i]`` is its valid length (0 for an unused slot) and
    ``_returns[i]`` is its score.

    An episode's score is its summed reward. For ``antmaze`` environments the
    sum is further divided by ``_ANTMAZE_NORM_EPS`` plus the L2 norm of the
    first two components of the episode's first observation.
    """

    # Keeps antmaze scores finite for episodes whose start position is the origin.
    _ANTMAZE_NORM_EPS = 1e-4

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        max_sequence_length: int = 1000,
        num_sequence: int = 10,
        env_name: Optional[str] = None,
    ):
        self.env_name = env_name

        self._observations = np.zeros((num_sequence, max_sequence_length, state_dim), dtype=np.float32)
        self._actions = np.zeros((num_sequence, max_sequence_length, action_dim), dtype=np.float32)
        self._rewards = np.zeros((num_sequence, max_sequence_length, 1), dtype=np.float32)
        self._next_observations = np.zeros((num_sequence, max_sequence_length, state_dim), dtype=np.float32)
        self._terminals = np.zeros((num_sequence, max_sequence_length, 1), dtype=np.float32)

        self._max_sequence_length = max_sequence_length
        self._num_sequence = num_sequence
        # Valid length per slot. Integer dtype so that `_size` stays an int.
        self._pointer = np.zeros(num_sequence, dtype=np.int64)
        self._returns = np.zeros(num_sequence)
        # Total number of valid steps across all slots.
        self._size = 0

    def _is_antmaze(self) -> bool:
        """Return True if ``env_name`` starts with ``"antmaze"`` (False when unset)."""
        return self.env_name is not None and self.env_name.startswith("antmaze")

    def _score_episode(self, episode: Dict[str, np.ndarray]) -> float:
        """Return the ranking score of ``episode`` as described in the class docstring."""
        score = float(np.sum(episode["rewards"]))
        if self._is_antmaze():
            start_xy = np.asarray(episode["observations"][0][:2])
            score /= self._ANTMAZE_NORM_EPS + float(np.linalg.norm(start_xy))
        return score

    @staticmethod
    def _compute_dones_float(dataset: Dict[str, np.ndarray]) -> np.ndarray:
        """Flag trajectory boundaries in a flat transition dataset.

        Step ``i`` is marked 1 when ``observations[i + 1]`` differs from
        ``next_observations[i]`` by an L2 distance above 1e-6, or when
        ``terminals[i] == 1``. The last step is always marked 1. The result has
        the shape and dtype of ``dataset['rewards']``.
        """
        dones_float = np.zeros_like(dataset["rewards"])
        n = len(dones_float)
        if n == 0:
            return dones_float
        if n > 1:
            following_obs = np.asarray(dataset["observations"][1:n])
            next_obs = np.asarray(dataset["next_observations"][: n - 1])
            # Flatten each step so the norm is per step, whatever the obs rank.
            gaps = (following_obs - next_obs).reshape(n - 1, -1)
            discontinuous = np.linalg.norm(gaps, axis=1) > 1e-6
            terminal = np.asarray(dataset["terminals"][: n - 1]) == 1.0
            dones_float[: n - 1] = discontinuous | terminal
        dones_float[-1] = 1
        return dones_float

    def _write_slot(
        self,
        slot: int,
        observations: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        next_observations: np.ndarray,
        terminals: np.ndarray,
        score: float,
    ) -> None:
        """Overwrite ``slot`` with one episode and update its length, score and ``_size``.

        ``rewards`` and ``terminals`` may be shaped ``(length,)`` or
        ``(length, 1)``; both are stored as ``(length, 1)``.

        Raises:
            ValueError: if the episode is longer than ``max_sequence_length``.
        """
        length = rewards.shape[0]
        if length > self._max_sequence_length:
            raise ValueError(
                f"Episode of length {length} exceeds max_sequence_length="
                f"{self._max_sequence_length}"
            )
        self._observations[slot, :length] = observations
        self._actions[slot, :length] = actions
        self._rewards[slot, :length] = np.reshape(rewards, (length, 1))
        self._next_observations[slot, :length] = next_observations
        self._terminals[slot, :length] = np.reshape(terminals, (length, 1))

        # Steps past `length` are left as stale data; `_pointer` masks them out.
        self._size += length - int(self._pointer[slot])
        self._pointer[slot] = length
        self._returns[slot] = score

    def load_dataset(self, dataset: Dict[str, np.ndarray]):
        """Replace the buffer contents with the top-scoring episodes of ``dataset``.

        ``dataset`` is a flat dict with ``observations``, ``actions``,
        ``rewards``, ``next_observations``, ``terminals`` and optionally
        ``realterminals``. It is split with ``split_into_trajectories``, and the
        ``num_sequence`` best episodes by score are stored in ascending score
        order (slot 0 holds the lowest kept score).

        Raises:
            ValueError: if a selected episode is longer than ``max_sequence_length``.
        """
        dones_float = self._compute_dones_float(dataset)

        if "realterminals" in dataset:
            # We updated terminals in the dataset, but continue using
            # the old terminals for consistency with original IQL.
            masks = 1.0 - dataset["realterminals"].astype(np.float32)
        else:
            masks = 1.0 - dataset["terminals"].astype(np.float32)
        trajs = split_into_trajectories(
            observations=dataset["observations"].astype(np.float32),
            actions=dataset["actions"].astype(np.float32),
            rewards=dataset["rewards"].astype(np.float32),
            masks=masks,
            dones_float=dones_float.astype(np.float32),
            next_observations=dataset["next_observations"].astype(np.float32))
        scores = [self._score_episode(traj) for traj in trajs]

        # Start from a clean slate so a repeated call neither leaves stale
        # slots behind nor double-counts `_size`.
        self._pointer[:] = 0
        self._returns[:] = 0.0
        self._size = 0

        # argsort is ascending, so the best episodes sit at the tail. Clamping
        # the start avoids the `[-0:]` pitfall (selecting everything) when
        # num_sequence is 0.
        first = max(len(scores) - self._num_sequence, 0)
        for slot, traj_index in enumerate(np.argsort(scores)[first:]):
            episode = trajs[traj_index]
            # NOTE(review): terminal flags come from the trajectory's "dones"
            # entry but are exported as "realterminals" by get_buffer_data_dict.
            # Confirm what split_into_trajectories stores under "dones".
            self._write_slot(
                slot,
                episode["observations"],
                episode["actions"],
                episode["rewards"],
                episode["next_observations"],
                episode["dones"],
                scores[traj_index],
            )

    def update_top_episodes(self, episode: Dict[str, np.ndarray]) -> bool:
        """Replace the lowest-scoring slot with ``episode`` if it scores strictly higher.

        ``episode`` must provide ``observations``, ``actions``, ``rewards``,
        ``next_observations`` and ``terminals``. ``rewards`` and ``terminals``
        are per-step, shaped ``(length,)`` or ``(length, 1)``.

        Returns:
            True if the episode was stored, False otherwise.

        Raises:
            ValueError: if a stored episode would exceed ``max_sequence_length``.
        """
        if self._num_sequence == 0:
            return False
        score = self._score_episode(episode)
        min_index = int(np.argmin(self._returns))
        # Written as `not >` so a NaN score is rejected, matching a plain `>` test.
        if not score > self._returns[min_index]:
            return False
        self._write_slot(
            min_index,
            episode["observations"],
            episode["actions"],
            episode["rewards"],
            episode["next_observations"],
            episode["terminals"],
            score,
        )
        return True

    def _valid_steps(self, storage: np.ndarray) -> np.ndarray:
        """Stack the valid steps of every slot of ``storage`` along axis 0, slot by slot."""
        step_index = np.arange(self._max_sequence_length)
        valid = step_index[None, :] < self._pointer[:, None]
        # Boolean indexing over (slot, step) walks in row-major order, which is
        # the same ordering as concatenating each slot's prefix in turn.
        return storage[valid]

    def get_obs_action_concat(self):
        """Return valid observations and actions joined along the feature axis."""
        return np.concatenate(
            [self._valid_steps(self._observations), self._valid_steps(self._actions)],
            axis=-1,
        )

    def get_buffer_data_dict(self):
        """Return all valid steps as a flat D4RL-style dict; rewards/terminals are 1-D."""
        return {
            "observations": self._valid_steps(self._observations),
            "actions": self._valid_steps(self._actions),
            "next_observations": self._valid_steps(self._next_observations),
            "rewards": self._valid_steps(self._rewards).squeeze(-1),
            "realterminals": self._valid_steps(self._terminals).squeeze(-1),
        }