# source: https://github.com/nakamotoo/Cal-QL/tree/main
# https://arxiv.org/pdf/2303.05479.pdf
import os
import random
import uuid
from copy import deepcopy
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import d4rl
import gym
import numpy as np
import pyrallis
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.distributions import Normal, TanhTransform, TransformedDistribution
from utils import split_into_trajectories

# Alias for a batch expressed as a list of tensors; `ReplayBuffer.sample`
# returns one in the order
# [states, actions, rewards, next_states, dones, mc_returns].
TensorBatch = List[torch.Tensor]

# Environment-name fragments for goal-style tasks.
# NOTE(review): not referenced in this part of the file — presumably matched
# against the environment name elsewhere; confirm against the caller.
ENVS_WITH_GOAL = ("antmaze", "pen", "door", "hammer", "relocate")

class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions stored in preallocated tensors.

    Each slot holds a state, action, reward, next state, done flag and a
    Monte-Carlo return. The buffer is first filled from an offline dataset
    with `load_dataset` and can then be extended one transition at a time
    with `add_transition`, which overwrites the oldest slots once the
    buffer is full.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        device: str = "cpu",
    ):
        """Allocate zero-filled float32 storage for `buffer_size` transitions.

        Args:
            state_dim: Length of a single state vector.
            action_dim: Length of a single action vector.
            buffer_size: Maximum number of transitions kept.
            device: Torch device on which all storage tensors are allocated.
        """
        self._buffer_size = buffer_size
        self._pointer = 0  # index of the next slot `add_transition` writes to
        self._size = 0  # number of valid transitions currently stored

        self._states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._actions = torch.zeros(
            (buffer_size, action_dim), dtype=torch.float32, device=device
        )
        self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._next_states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._mc_returns = torch.zeros(
            (buffer_size, 1), dtype=torch.float32, device=device
        )

        self._device = device

    def _to_tensor(self, data: Union[np.ndarray, float, bool]) -> torch.Tensor:
        """Convert an array or scalar to a float32 tensor on the buffer's device."""
        return torch.tensor(data, dtype=torch.float32, device=self._device)

    # Loads data in d4rl format, i.e. from Dict[str, np.array].
    def load_dataset(self, data: Dict[str, np.ndarray]):
        """Fill an empty buffer from a d4rl-style dataset.

        Args:
            data: Mapping with the keys "observations", "actions", "rewards",
                "next_observations", "realterminals" and "mc_returns". A
                trailing unit axis is appended to "rewards", "realterminals"
                and "mc_returns" before they are stored.

        Raises:
            ValueError: If the buffer already holds data, or if the dataset
                has more transitions than the buffer can store.
        """
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        self._states[:n_transitions] = self._to_tensor(data["observations"])
        self._actions[:n_transitions] = self._to_tensor(data["actions"])
        self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None])
        self._next_states[:n_transitions] = self._to_tensor(data["next_observations"])
        self._dones[:n_transitions] = self._to_tensor(data["realterminals"][..., None])
        self._mc_returns[:n_transitions] = self._to_tensor(data["mc_returns"][..., None])
        self._size = n_transitions
        # When the dataset fills the buffer exactly, the write cursor must wrap
        # to slot 0; leaving it at `buffer_size` would make the next
        # `add_transition` index one past the end of the storage tensors.
        self._pointer = 0 if n_transitions == self._buffer_size else n_transitions

        print(f"Dataset size: {n_transitions}")

    def sample(self, batch_size: int) -> List[torch.Tensor]:
        """Draw `batch_size` stored transitions uniformly with replacement.

        Returns:
            The list [states, actions, rewards, next_states, dones, mc_returns]
            (the layout of the module's `TensorBatch` alias), each tensor
            having `batch_size` rows.
        """
        indices = np.random.randint(0, self._size, size=batch_size)
        states = self._states[indices]
        actions = self._actions[indices]
        rewards = self._rewards[indices]
        next_states = self._next_states[indices]
        dones = self._dones[indices]
        mc_returns = self._mc_returns[indices]
        return [states, actions, rewards, next_states, dones, mc_returns]

    def add_transition(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ):
        """Write one transition at the cursor, overwriting the oldest when full.

        Use this method to add new data into the replay buffer during
        fine-tuning. No Monte-Carlo return is passed in, so the slot's
        `mc_returns` entry is set to 0.0.
        """
        self._states[self._pointer] = self._to_tensor(state)
        self._actions[self._pointer] = self._to_tensor(action)
        self._rewards[self._pointer] = self._to_tensor(reward)
        self._next_states[self._pointer] = self._to_tensor(next_state)
        self._dones[self._pointer] = self._to_tensor(done)
        self._mc_returns[self._pointer] = 0.0

        self._pointer = (self._pointer + 1) % self._buffer_size
        self._size = min(self._size + 1, self._buffer_size)


class SequenceReplayBuffer:
    """Stores a fixed number of whole episodes, keeping those with the best score.

    Every episode occupies one slot of zero-padded NumPy storage; the number
    of valid steps in slot `i` is `self._pointer[i]`. An episode's score is
    the sum of its rewards. For environments whose name starts with
    "antmaze", that sum is divided by the norm of the first two entries of
    the episode's first observation (plus a small epsilon).
    """

    # Added to the start-position norm so that an antmaze episode whose first
    # observation starts at the origin does not divide by zero. The same
    # value is used when ranking, storing and comparing episode scores.
    _ANTMAZE_NORM_EPS = 1e-4

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        max_sequence_length: int = 1000,
        num_sequence: int = 10,
        env_name: Optional[str] = None,
    ):
        """Allocate zero-filled float32 storage for `num_sequence` episodes.

        Args:
            state_dim: Length of a single observation vector.
            action_dim: Length of a single action vector.
            max_sequence_length: Maximum number of steps stored per episode.
            num_sequence: Number of episode slots.
            env_name: Environment name; only checked for the "antmaze"
                prefix. `None` is treated as a non-antmaze environment.
        """
        self.env_name = env_name

        slots = (num_sequence, max_sequence_length)
        self._observations = np.zeros(slots + (state_dim,), dtype=np.float32)
        self._actions = np.zeros(slots + (action_dim,), dtype=np.float32)
        self._rewards = np.zeros(slots + (1,), dtype=np.float32)
        self._next_observations = np.zeros(slots + (state_dim,), dtype=np.float32)
        self._terminals = np.zeros(slots + (1,), dtype=np.float32)

        self._max_sequence_length = max_sequence_length
        self._num_sequence = num_sequence
        # Valid length of each slot; integer so it can be used as a slice bound.
        self._pointer = np.zeros(num_sequence, dtype=np.int64)
        self._returns = np.zeros(num_sequence)
        self._size = 0  # total number of valid steps over all slots

    def _episode_return(self, episode: Dict[str, np.ndarray]) -> float:
        """Score an episode: reward sum, start-distance-normalised for antmaze."""
        total = float(np.sum(episode["rewards"]))
        if self.env_name is not None and self.env_name.startswith("antmaze"):
            start_norm = float(np.linalg.norm(episode["observations"][0][:2]))
            total /= self._ANTMAZE_NORM_EPS + start_norm
        return total

    @staticmethod
    def _compute_dones(dataset: Dict[str, np.ndarray]) -> np.ndarray:
        """Flag the last transition of every episode in a flat dataset.

        Transition `i` is flagged when `observations[i + 1]` differs from
        `next_observations[i]` by more than 1e-6 in norm, or when
        `terminals[i] == 1.0`. The final transition is always flagged.
        The result has the shape and dtype of `dataset["rewards"]`.
        """
        dones = np.zeros_like(dataset["rewards"])
        n = len(dones)
        if n > 1:
            observations = np.asarray(dataset["observations"])
            next_observations = np.asarray(dataset["next_observations"])
            gaps = (observations[1:n] - next_observations[: n - 1]).reshape(n - 1, -1)
            discontinuous = np.linalg.norm(gaps, axis=1) > 1e-6
            terminal = np.asarray(dataset["terminals"])[: n - 1].reshape(-1) == 1.0
            dones[: n - 1][np.logical_or(discontinuous, terminal)] = 1
        dones[-1] = 1
        return dones

    def _store_episode(
        self, slot: int, episode: Dict[str, np.ndarray], terminal_key: str
    ) -> None:
        """Copy `episode` into `slot` and update the length bookkeeping.

        Rewards and terminal flags are reshaped to `(length, 1)`, so both
        1-D and column-shaped inputs are accepted.

        Raises:
            ValueError: If the episode has more steps than a slot can hold.
        """
        length = episode["rewards"].shape[0]
        if length > self._max_sequence_length:
            raise ValueError(
                f"Episode of length {length} does not fit into "
                f"max_sequence_length={self._max_sequence_length}"
            )
        self._observations[slot, :length] = episode["observations"]
        self._actions[slot, :length] = episode["actions"]
        self._rewards[slot, :length] = np.reshape(episode["rewards"], (length, 1))
        self._next_observations[slot, :length] = episode["next_observations"]
        self._terminals[slot, :length] = np.reshape(episode[terminal_key], (length, 1))

        # Replace the slot's previous contribution to the total step count.
        self._size += int(length - self._pointer[slot])
        self._pointer[slot] = length

    def _concat_valid(self, storage: np.ndarray) -> np.ndarray:
        """Concatenate the valid (unpadded) steps of every active slot."""
        return np.concatenate(
            [storage[i, : int(self._pointer[i])] for i in range(self._num_sequence)],
            axis=0,
        )

    def load_dataset(self, dataset: Dict[str, np.ndarray]):
        """Split a flat dataset into episodes and keep the best-scoring ones.

        The number of active slots is reduced to the number of episodes found
        when the dataset contains fewer episodes than there are slots.

        Args:
            dataset: Mapping with the keys "observations", "actions",
                "rewards", "next_observations", "terminals" and optionally
                "realterminals".
        """
        dones_float = self._compute_dones(dataset)

        if 'realterminals' in dataset:
            # We updated terminals in the dataset, but continue using
            # the old terminals for consistency with original IQL.
            masks = 1.0 - dataset['realterminals'].astype(np.float32)
        else:
            masks = 1.0 - dataset['terminals'].astype(np.float32)
        # NOTE(review): `split_into_trajectories` is defined outside this file;
        # it is assumed to return a list of dicts with the keys "observations",
        # "actions", "rewards", "next_observations" and "dones" — confirm.
        trajs = split_into_trajectories(
            observations=dataset['observations'].astype(np.float32),
            actions=dataset['actions'].astype(np.float32),
            rewards=dataset['rewards'].astype(np.float32),
            masks=masks,
            dones_float=dones_float.astype(np.float32),
            next_observations=dataset['next_observations'].astype(np.float32))

        returns = [self._episode_return(traj) for traj in trajs]

        self._num_sequence = min(self._num_sequence, len(returns))
        # argsort is ascending, so the tail holds the highest-scoring episodes.
        top_indices = np.argsort(returns)[-self._num_sequence:]

        for slot, traj_index in enumerate(top_indices):
            self._store_episode(slot, trajs[traj_index], terminal_key="dones")
            self._returns[slot] = returns[traj_index]

    def update_top_episodes(self, episode: Dict[str, np.ndarray]) -> bool:
        """Replace the worst stored episode if `episode` scores strictly higher.

        Only active slots (the first `self._num_sequence`) are considered,
        because those are the only slots the getters read.

        Args:
            episode: Mapping with the keys "observations", "actions",
                "rewards", "next_observations" and "terminals".

        Returns:
            True if the episode was stored, False otherwise.

        Raises:
            ValueError: If the episode would be stored but is longer than
                `max_sequence_length`.
        """
        total_return = self._episode_return(episode)
        active_returns = self._returns[: self._num_sequence]
        if active_returns.size == 0:
            return False

        min_index = int(np.argmin(active_returns))
        if total_return <= active_returns[min_index]:
            return False

        self._store_episode(min_index, episode, terminal_key="terminals")
        self._returns[min_index] = total_return
        return True

    def get_obs_action_concat(self):
        """Return all valid steps as rows of `[observation, action]`."""
        observations = self._concat_valid(self._observations)
        actions = self._concat_valid(self._actions)
        return np.concatenate([observations, actions], axis=-1)

    def get_buffer_data_dict(self):
        """Return all valid steps as a flat dataset dictionary.

        Rewards and terminal flags are returned 1-D; the stored terminal
        flags are exposed under the key "realterminals".
        """
        return {
            "observations": self._concat_valid(self._observations),
            "actions": self._concat_valid(self._actions),
            "next_observations": self._concat_valid(self._next_observations),
            "rewards": self._concat_valid(self._rewards).squeeze(-1),
            "realterminals": self._concat_valid(self._terminals).squeeze(-1),
        }