# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import pdb  # pylint: disable=unused-import
import logging
import typing as tp
import dataclasses
import collections
from pathlib import Path

import numpy as np
import torch
from dm_env import specs, TimeStep
from tqdm import tqdm
from url_benchmark.replay_buffer import EpisodeBatch
from url_benchmark.dmc import ExtendedGoalTimeStep

Specs = tp.Sequence[specs.Array]
logger = logging.getLogger(__name__)

EpisodeTuple = tp.Tuple[np.ndarray, ...]
Episode = tp.Dict[str, np.ndarray]
T = tp.TypeVar("T", np.ndarray, torch.Tensor)


def episode_len(episode: Episode) -> int:
    """Return the number of transitions held by ``episode``.

    The count is read from the leading dimension of the first array in the
    mapping; one is subtracted to discount the dummy first transition.
    """
    first_array = next(iter(episode.values()))
    stored_rows = first_array.shape[0]
    return stored_rows - 1


def load_episode(fn: Path) -> tp.Dict[str, np.ndarray]:
    """Read every array stored in the archive ``fn`` into a plain dict.

    The arrays are materialised while the file is still open, so the returned
    dict does not depend on the file handle afterwards.
    """
    arrays = {}
    with fn.open('rb') as handle:
        archive = np.load(handle)
        for key in archive.keys():
            arrays[key] = archive[key]
    return arrays  # type: ignore


def relabel_episode(env: tp.Any, episode: tp.Dict[str, np.ndarray], goal_func: tp.Any) -> tp.Dict[str, np.ndarray]:
    """Recompute the rewards (and optionally the goals) of ``episode`` in place.

    Every state in ``episode['physics']`` is pushed into ``env.physics`` and
    the task reward is re-evaluated for it. When ``goal_func`` is given it is
    called with ``env`` after each state has been set, and its results are
    stored under ``'goal'``.

    Returns:
        The same ``episode`` dict, with ``'reward'`` (and possibly ``'goal'``)
        overwritten.
    """
    reward_rows = []
    goal_rows = []
    for state in episode['physics']:
        with env.physics.reset_context():
            env.physics.set_state(state)
        # The reward is queried once the reset context has been left.
        step_reward = env.task.get_reward(env.physics)
        reward_rows.append(np.full((1,), step_reward, dtype=np.float32))
        if goal_func is not None:
            goal_rows.append(goal_func(env))
    episode['reward'] = np.array(reward_rows, dtype=np.float32)
    if goal_rows:
        episode['goal'] = np.array(goal_rows, dtype=np.float32)
    return episode

# class ReplayBufferIterable:
#     def __init__(self, replay_buffer: "ReplayBuffer") -> None:
#         self._replay_buffer = replay_buffer
#
#     def __next__(self) -> EpisodeBatch:
#         return self._replay_buffer.sample()


class ReplayBuffer:
    def __init__(self,
                 max_episodes: int, discount: float, future: float, max_episode_length: tp.Optional[int] = None) -> None:
        # data_specs: Specs,
        # self._data_specs = tuple(data_specs)
        # self._meta_specs = tuple(meta_specs)
        # self._batch_size = batch_size
        self._max_episodes = max_episodes
        self._discount = discount
        assert 0 <= future <= 1
        self._future = future
        self._current_episode: tp.Dict[str, tp.List[np.ndarray]] = collections.defaultdict(list)
        self._idx = 0
        self._full = False
        self._num_transitions = 0
        self._storage: tp.Dict[str, np.ndarray] = collections.defaultdict()
        self._collected_episodes = 0
        self._batch_names = set(field.name for field in dataclasses.fields(ExtendedGoalTimeStep))
        self._episodes_length = np.zeros(max_episodes, dtype=np.int32)
        self._episodes_selection_probability = None
        self._is_fixed_episode_length = True
        self._max_episode_length = max_episode_length

    def __len__(self) -> int:
        return self._max_episodes if self._full else self._idx

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._backward_compatibility()

    def _backward_compatibility(self):
        if self._storage and not hasattr(self, '_episodes_length'):
            self._episodes_length = np.array([len(array) - 1 for array in self._storage["discount"]], dtype=np.int32)
            self._episodes_length[len(self):] = 0
            assert self._episodes_length[:len(self)].min() == self._episodes_length[:len(self)].max()
            self._episodes_selection_probability = None
            self._is_fixed_episode_length = True
            self._max_episode_length = None

    def restore_episode_lengths(self):
        self._episodes_length = np.array([len(array) - 1 for array in self._storage["discount"]], dtype=np.int32)
        self._episodes_length[len(self):] = 0
        assert self._episodes_length[:len(self)].min() == self._episodes_length[:len(self)].max()
        self._episodes_selection_probability = None
        self._is_fixed_episode_length = True
        self._max_episode_length = None

    def add(self, time_step: TimeStep) -> None:
        dtype = np.float32
        for field in dataclasses.fields(time_step):
            value = time_step[field.name]
            if np.isscalar(value):
                value = np.full((1,), value, dtype=dtype)
            if isinstance(value, np.ndarray):
                self._current_episode[field.name].append(np.array(value, dtype=dtype))
        if time_step.last():
            if not hasattr(self, "_batch_names"):
                self._batch_names = set(field.name for field in dataclasses.fields(ExtendedGoalTimeStep))
            for name, value_list in self._current_episode.items():
                values = np.array(value_list, dtype)
                if name not in self._storage:
                    # first iteration, the buffer is created with appropriate size
                    _shape = values.shape
                    if self._max_episode_length is not None:
                        _shape = (self._max_episode_length,) + _shape[1:]
                    self._storage[name] = np.empty((self._max_episodes,) + _shape, dtype=dtype)
                self._storage[name][self._idx][:len(values)] = values
            self._episodes_length[self._idx] = len(self._current_episode['discount']) - 1  # compensate for the dummy transition at the beginning
            if self._episodes_length[self._idx] != self._episodes_length[self._idx - 1] and self._episodes_length[self._idx - 1] != 0:
                self._is_fixed_episode_length = False
            self._current_episode = collections.defaultdict(list)
            self._collected_episodes += 1
            self._idx = (self._idx + 1) % self._max_episodes
            self._full = self._full or self._idx == 0
            self._episodes_selection_probability = None

    @property
    def avg_episode_length(self) -> int:
        return round(self._episodes_length[:len(self)].mean())
    
    def _sample_goals(self, batch_size, ep_idx, step_idx, p_randomgoal=None, p_trajgoal=None, p_currgoal=None):
        assert p_randomgoal + p_trajgoal + p_currgoal == 1.0
        geom_sample=False
        # Random goals
        # goal_indx = np.random.randint(self.dataset.size, size=batch_size)
        random_ep_idx = np.random.randint(0, len(self), size=batch_size)
        random_step_idx = np.random.randint(0, self._episodes_length[random_ep_idx]) + 1

        goals = self._storage['observation'][random_ep_idx, random_step_idx]

        # Goals from the same trajectory
        # final_state_indx = self.terminal_locs[np.searchsorted(self.terminal_locs, indx)]
        final_state_indx = self._episodes_length[ep_idx]

        distance = np.random.rand(batch_size)
        if geom_sample:
            us = np.random.rand(batch_size)
            middle_goal_indx = np.minimum(step_idx + np.ceil(np.log(1 - us) / np.log(self.discount)).astype(int), final_state_indx)
        else:
            middle_goal_indx = np.round((np.minimum(step_idx + 1, final_state_indx) * distance + final_state_indx * (1 - distance))).astype(int)

        traj_bool = np.random.rand(batch_size) < p_trajgoal / (1.0 - p_currgoal)
        ind = np.where(traj_bool)[0]
        ep_ind = ep_idx[ind]
        goals[ind] = self._storage['observation'][ep_ind, middle_goal_indx[ind]]
        # goals = np.where(np.random.rand(batch_size) < p_trajgoal / (1.0 - p_currgoal), self._storage['observation'][ep_idx, middle_goal_indx], goals)

        # Goals at the current state
        rand_bool = np.random.rand(batch_size) < p_currgoal
        ind = np.where(rand_bool)[0]
        ep_ind = ep_idx[ind]
        step_ind = step_idx[ind]
        goals[ind] = self._storage['observation'][ep_ind, step_ind]
        rewards = np.zeros((batch_size, ), dtype=np.float32)
        successes = np.ones((batch_size, ), dtype=np.float32)
        rewards = np.where(rand_bool, successes, rewards)
        # print('-------------------')
        # print(rewards.shape)
        return goals, rewards


    def sample(self, batch_size, custom_reward: tp.Optional[tp.Any] = None, with_physics: bool = False) -> EpisodeBatch:
        if not hasattr(self, "_batch_names"):
            self._batch_names = set(field.name for field in dataclasses.fields(ExtendedGoalTimeStep))
        if not isinstance(self._future, float):
            assert isinstance(self._future, bool)
            self._future = float(self._future)

        if self._is_fixed_episode_length:
            ep_idx = np.random.randint(0, len(self), size=batch_size)
        else:
            if self._episodes_selection_probability is None:
                self._episodes_selection_probability = self._episodes_length / self._episodes_length.sum()
            ep_idx = np.random.choice(np.arange(len(self._episodes_length)), size=batch_size, p=self._episodes_selection_probability)

        eps_lengths = self._episodes_length[ep_idx]
        # add +1 for the first dummy transition
        step_idx = np.random.randint(0, eps_lengths) + 1
        assert (step_idx <= eps_lengths).all()
        if self._future < 1:
            # future_idx = step_idx + np.random.randint(0, self.episode_length - step_idx + 1, size=self._batch_size)
            future_idx = step_idx + np.random.geometric(p=(1 - self._future), size=batch_size)
            future_idx = np.clip(future_idx, 0, eps_lengths)
            assert (future_idx <= eps_lengths).all()
        # meta = {name: data[ep_idx, step_idx - 1] for name, data in self._storage.items() if name not in self._batch_names}
        obs = self._storage['observation'][ep_idx, step_idx - 1]
        obs_hash = self._storage['obs_hash'][ep_idx, step_idx - 1]
        action = self._storage['action'][ep_idx, step_idx]
        next_obs = self._storage['observation'][ep_idx, step_idx]
        next_obs_hash = self._storage['obs_hash'][ep_idx, step_idx]
        phy = self._storage['physics'][ep_idx, step_idx]
        if custom_reward is not None:
            reward = np.array([[custom_reward.from_physics(p)] for p in phy], dtype=np.float32)
        else:
            reward = self._storage['reward'][ep_idx, step_idx]
        discount = self._discount * self._storage['discount'][ep_idx, step_idx]
        goal: tp.Optional[np.ndarray] = next_obs
        next_goal: tp.Optional[np.ndarray] = next_obs
        future_obs: tp.Optional[np.ndarray] = None
        future_goal: tp.Optional[np.ndarray] = None
        if 'goal' in self._storage.keys():
            goal = self._storage['goal'][ep_idx, step_idx - 1]
            next_goal = self._storage['goal'][ep_idx, step_idx]
            if self._future < 1:
                future_goal = self._storage['goal'][ep_idx, future_idx - 1]
        # elif self._future:
        if self._future < 1:
            future_obs = self._storage['observation'][ep_idx, future_idx - 1]
        additional = {}
        if with_physics:
            additional["_physics"] = phy
        # TODO remove type ignore when working
        return EpisodeBatch(obs=obs, obs_hash=obs_hash, goal=goal, action=action, reward=reward, discount=discount,
                            next_obs=next_obs,next_obs_hash=next_obs_hash, next_goal=next_goal,
                            future_obs=future_obs, future_goal=future_goal, **additional)
    
    def sample_bh(self, batch_size, num_neg_samples, custom_reward: tp.Optional[tp.Any] = None, with_physics: bool = False) -> EpisodeBatch:
        if not hasattr(self, "_batch_names"):
            self._batch_names = set(field.name for field in dataclasses.fields(ExtendedGoalTimeStep))
        if not isinstance(self._future, float):
            assert isinstance(self._future, bool)
            self._future = float(self._future)

        if self._is_fixed_episode_length:
            ep_idx = np.random.randint(0, len(self), size=batch_size)
        else:
            if self._episodes_selection_probability is None:
                self._episodes_selection_probability = self._episodes_length / self._episodes_length.sum()
            ep_idx = np.random.choice(np.arange(len(self._episodes_length)), size=batch_size, p=self._episodes_selection_probability)

        eps_lengths = self._episodes_length[ep_idx]
        # add +1 for the first dummy transition
        step_idx = np.random.randint(0, eps_lengths) + 1
        assert (step_idx <= eps_lengths).all()
        if self._future < 1:
            # future_idx = step_idx + np.random.randint(0, self.episode_length - step_idx + 1, size=self._batch_size)
            future_idx = step_idx + np.random.geometric(p=(1 - self._future), size=batch_size)
            future_idx = np.clip(future_idx, 0, eps_lengths)
            assert (future_idx <= eps_lengths).all()
        # meta = {name: data[ep_idx, step_idx - 1] for name, data in self._storage.items() if name not in self._batch_names}
        obs = self._storage['observation'][ep_idx, step_idx - 1]
        obs_hash = self._storage['obs_hash'][ep_idx, step_idx - 1]
        action = self._storage['action'][ep_idx, step_idx]
        next_obs = self._storage['observation'][ep_idx, step_idx]
        next_obs_hash = self._storage['obs_hash'][ep_idx, step_idx]
        phy = self._storage['physics'][ep_idx, step_idx]
        if custom_reward is not None:
            reward = np.array([[custom_reward.from_physics(p)] for p in phy], dtype=np.float32)
        else:
            reward = self._storage['reward'][ep_idx, step_idx]
        discount = self._discount * self._storage['discount'][ep_idx, step_idx]
        goal: tp.Optional[np.ndarray] = None
        next_goal: tp.Optional[np.ndarray] = None
        future_obs: tp.Optional[np.ndarray] = None
        future_goal: tp.Optional[np.ndarray] = None
        goal = next_obs
        next_goal = next_obs

        obs = np.repeat(obs, num_neg_samples, axis=0)
        obs_hash = np.repeat(obs_hash, num_neg_samples, axis=0)
        action = np.repeat(action, num_neg_samples, axis=0)
        next_obs = np.repeat(next_obs, num_neg_samples, axis=0)
        next_obs_hash = np.repeat(next_obs_hash, num_neg_samples, axis=0)
        phy = np.repeat(phy, num_neg_samples, axis=0)
        reward = np.repeat(reward, num_neg_samples, axis=0)
        discount = np.repeat(discount, num_neg_samples, axis=0)
        goal = np.repeat(goal, num_neg_samples, axis=0)
        next_goal = np.repeat(next_goal, num_neg_samples, axis=0)
        future_obs = np.repeat(future_obs, num_neg_samples, axis=0)
        # future_goal = np.repeat(future_goal, num_neg_samples, axis=0)

        # goals = self._sample_goals(batch_size * num_neg_samples, ep_idx, step_idx, p_randomgoal=1.0, p_trajgoal=0.0, p_currgoal=0.0)
        random_ep_idx = np.random.randint(0, len(self), size=batch_size * num_neg_samples)
        random_step_idx = np.random.randint(0, self._episodes_length[random_ep_idx], size=batch_size * num_neg_samples) + 1

        goals = self._storage['observation'][random_ep_idx, random_step_idx]

        neg_idx = np.where(np.arange(batch_size * num_neg_samples)% num_neg_samples !=0)[0]
        # goal = np.where(np.arange(batch_size) % num_neg_samples == 0, goal, goals)
        goal[neg_idx] = goals[neg_idx]
        next_goal = goal

        # if 'goal' in self._storage.keys():
        #     goal = self._storage['goal'][ep_idx, step_idx - 1]
        #     next_goal = self._storage['goal'][ep_idx, step_idx]
        # if self._future < 1:
        #     future_goal = self._storage['goal'][ep_idx, future_idx - 1]
        # elif self._future:
        if self._future < 1:
            future_obs = self._storage['observation'][ep_idx, future_idx - 1]
        additional = {}
        if with_physics:
            additional["_physics"] = phy
        # TODO remove type ignore when working
        return EpisodeBatch(obs=obs, obs_hash=obs_hash, goal=goal, action=action, reward=reward, discount=discount,
                            next_obs=next_obs,next_obs_hash=next_obs_hash, next_goal=next_goal,
                            future_obs=future_obs, future_goal=future_goal, **additional)
    

    def sample_icvf(self, batch_size, custom_reward: tp.Optional[tp.Any] = None, with_physics: bool = False) -> EpisodeBatch:
        if not hasattr(self, "_batch_names"):
            self._batch_names = set(field.name for field in dataclasses.fields(ExtendedGoalTimeStep))
        if not isinstance(self._future, float):
            assert isinstance(self._future, bool)
            self._future = float(self._future)

        if self._is_fixed_episode_length:
            ep_idx = np.random.randint(0, len(self), size=batch_size)
        else:
            if self._episodes_selection_probability is None:
                self._episodes_selection_probability = self._episodes_length / self._episodes_length.sum()
            ep_idx = np.random.choice(np.arange(len(self._episodes_length)), size=batch_size, p=self._episodes_selection_probability)

        eps_lengths = self._episodes_length[ep_idx]
        # add +1 for the first dummy transition
        step_idx = np.random.randint(0, eps_lengths) + 1
        assert (step_idx <= eps_lengths).all()
        if self._future < 1:
            # future_idx = step_idx + np.random.randint(0, self.episode_length - step_idx + 1, size=self._batch_size)
            future_idx = step_idx + np.random.geometric(p=(1 - self._future), size=batch_size)
            future_idx = np.clip(future_idx, 0, eps_lengths)
            assert (future_idx <= eps_lengths).all()
        # meta = {name: data[ep_idx, step_idx - 1] for name, data in self._storage.items() if name not in self._batch_names}
        obs = self._storage['observation'][ep_idx, step_idx - 1]
        obs_hash = self._storage['obs_hash'][ep_idx, step_idx - 1]
        action = self._storage['action'][ep_idx, step_idx]
        next_obs = self._storage['observation'][ep_idx, step_idx]
        next_obs_hash = self._storage['obs_hash'][ep_idx, step_idx]
        phy = self._storage['physics'][ep_idx, step_idx]
        # if custom_reward is not None:
        #     reward = np.array([[custom_reward.from_physics(p)] for p in phy], dtype=np.float32)
        # else:
        #     reward = self._storage['reward'][ep_idx, step_idx]
        discount = self._discount * self._storage['discount'][ep_idx, step_idx]
        goal: tp.Optional[np.ndarray] = None
        next_goal: tp.Optional[np.ndarray] = None
        future_obs: tp.Optional[np.ndarray] = None
        future_goal: tp.Optional[np.ndarray] = None
        # if 'goal' in self._storage.keys():
        #     goal = self._storage['goal'][ep_idx, step_idx - 1]
        #     next_goal = self._storage['goal'][ep_idx, step_idx]
        #     if self._future < 1:
        #         future_goal = self._storage['goal'][ep_idx, future_idx - 1]

        goal, reward = self._sample_goals(batch_size, ep_idx, step_idx, p_randomgoal=0.5, p_trajgoal=0.3, p_currgoal=0.2)
        next_goal = goal

        # elif self._future:
        if self._future < 1:
            future_obs = self._storage['observation'][ep_idx, future_idx - 1]
        additional = {}
        if with_physics:
            additional["_physics"] = phy
        # TODO remove type ignore when working
        return EpisodeBatch(obs=obs, obs_hash=obs_hash, goal=goal, action=action, reward=reward, discount=discount,
                            next_obs=next_obs,next_obs_hash=next_obs_hash, next_goal=next_goal,
                            future_obs=future_obs, future_goal=future_goal, **additional)
        

    def load(self, env: tp.Any, replay_dir: Path, relabel: bool = True, goal_func: tp.Any = None) -> None:
        eps_fns = sorted(replay_dir.glob('*.npz'))
        for eps_fn in tqdm(eps_fns):
            if self._full:
                break
            episode = load_episode(eps_fn)
            if relabel:
                episode = relabel_episode(env, episode, goal_func)
            # for field in dataclasses.fields(TimeStep):
            for name, values in episode.items():
                # values = episode[field.name]
                if name not in self._storage:
                    # first iteration, the buffer is created with appropriate size
                    self._storage[name] = np.empty((self._max_episodes,) + values.shape, dtype=np.float32)
                self._storage[name][self._idx] = np.array(values, dtype=np.float32)
            self._idx = (self._idx + 1) % self._max_episodes
            self._full = self._full or self._idx == 0

    def relabel(self, custom_reward) -> None:

        for (ep_idx, phy) in tqdm(enumerate(self._storage["physics"])):
            reward = np.array([[custom_reward.from_physics(p)] for p in phy], dtype=np.float32)
            self._storage["reward"][ep_idx] = reward
        self._max_episodes = len(self._storage["physics"])
        self._full = True

    # def __iter__(self) -> ReplayBufferIterable:
    #     ''' Returns the Iterator object '''
    #     return ReplayBufferIterable(self)

    # def __iter__(self) -> tp.Iterator[EpisodeBatch[np.ndarray]]:
    #     while True:
    #         yield self.sample()
