from collections import deque
import gzip
import pickle
from itertools import islice

import numpy as np

from softlearning.utils.numpy import softmax
from .replay_pool import ReplayPool


def random_int_with_variable_range(mins, maxs):
    """Draw one random integer from each half-open range ``[mins, maxs)``.

    ``mins`` and ``maxs`` may be scalars or arrays that broadcast together;
    the result follows NumPy broadcasting and has an integer dtype. For a
    zero-width range (``mins == maxs``) the value ``mins`` is returned.

    Args:
        mins: Inclusive lower bound(s).
        maxs: Exclusive upper bound(s).

    Returns:
        Integer scalar or array with one sample per range.
    """
    result = np.floor(np.random.uniform(mins, maxs)).astype(int)
    # ``np.random.uniform`` documents that ``high`` itself may be returned
    # due to floating-point rounding, which would make ``result == maxs``
    # (an out-of-range index). Step such draws back by one, but leave
    # zero-width ranges alone, where ``mins`` is the only possible value.
    overflow = (result >= maxs) & (np.asarray(maxs) > np.asarray(mins))
    return result - overflow.astype(int)


class TrajectoryReplayPool(ReplayPool):
    """Replay pool that stores whole trajectories instead of single samples.

    A trajectory is a dict mapping field names (e.g. ``'observations'``,
    ``'rewards'``) to arrays whose leading axis indexes time steps; all
    fields of one trajectory are assumed to share that leading length.
    At most ``max_size`` trajectories are retained: once the pool is full,
    adding a trajectory evicts the oldest one.
    """

    def __init__(self,
                 observation_space,
                 action_space,
                 max_size,
                 obs_filter=False,
                 modify_rew=False):
        # ``observation_space``, ``action_space``, ``obs_filter`` and
        # ``modify_rew`` are accepted for interface compatibility but are
        # not used by this class.
        super(TrajectoryReplayPool, self).__init__()

        max_size = int(max_size)
        self._max_size = max_size

        # Both deques share ``maxlen`` so evictions keep them aligned.
        self._trajectories = deque(maxlen=max_size)
        self._trajectory_lengths = deque(maxlen=max_size)
        # Total number of time steps ever added (evicted ones included).
        self._num_samples = 0
        self._trajectories_since_save = 0

    @property
    def num_trajectories(self):
        """Number of trajectories currently stored."""
        return len(self._trajectories)

    @property
    def size(self):
        """Number of time steps currently stored across all trajectories."""
        return sum(self._trajectory_lengths)

    @property
    def num_samples(self):
        """Total number of time steps added over the pool's lifetime."""
        return self._num_samples

    def add_paths(self, trajectories):
        """Append complete trajectories to the pool.

        Args:
            trajectories: Iterable of trajectory dicts. The length of each
                trajectory is read from its first field.
        """
        # Materialize once: a generator would be exhausted by the first pass.
        trajectories = list(trajectories)
        # Compute lengths before mutating, so a malformed trajectory cannot
        # leave the two deques out of sync.
        new_lengths = [
            trajectory[next(iter(trajectory.keys()))].shape[0]
            for trajectory in trajectories
        ]
        self._trajectories.extend(trajectories)
        self._trajectory_lengths.extend(new_lengths)
        self._num_samples += sum(new_lengths)
        self._trajectories_since_save += len(trajectories)

    def add_path(self, trajectory):
        """Append a single complete trajectory to the pool."""
        self.add_paths([trajectory])

    def add_sample(self, sample):
        raise NotImplementedError(
            f"{self.__class__.__name__} only supports adding full paths at"
            " once.")

    def add_samples(self, samples):
        raise NotImplementedError(
            f"{self.__class__.__name__} only supports adding full paths at"
            " once.")

    def batch_by_indices(self,
                         episode_indices,
                         step_indices,
                         field_name_filter=None,
                         n_step=None,
                         Qs=None,
                         policy=None,
                         discount=None,
                         use_max_n_step=False):
        """Gather one sample per ``(episode_indices[i], step_indices[i])``.

        Args:
            episode_indices: Indices into the stored trajectories.
            step_indices: Time-step index within each selected trajectory.
            field_name_filter: Accepted for interface compatibility; currently
                ignored (every field of the first selected trajectory is
                returned).
            n_step: If given, also compute ``batch['q_targets']`` from (up
                to) ``n_step``-step bootstrapped returns.
            Qs: Iterable of objects exposing ``predict([obs, actions])``;
                required when ``n_step`` is given.
            policy: Object exposing ``actions_np([obs])``; required when
                ``n_step`` is given.
            discount: Discount factor; required when ``n_step`` is given.
            use_max_n_step: Keep the largest of the 1..n-step targets rather
                than the longest-horizon one.

        Returns:
            Dict mapping each field name to an array of shape
            ``(batch_size, *field_shape)``, plus a float32 ``'q_targets'``
            array of shape ``(batch_size, 1)`` when ``n_step`` is given.
            An empty dict when the index lists are empty.

        Raises:
            ValueError: If the index lists differ in length, or ``n_step``
                is invalid or given without ``Qs``, ``policy`` and
                ``discount``.
        """
        if len(episode_indices) != len(step_indices):
            raise ValueError(
                "episode_indices and step_indices must have the same length,"
                f" got {len(episode_indices)} and {len(step_indices)}.")

        batch_size = len(episode_indices)
        if batch_size == 0:
            return {}

        trajectories = [self._trajectories[i] for i in episode_indices]

        batch = {
            field_name: np.empty(
                (batch_size, *values.shape[1:]), dtype=values.dtype)
            for field_name, values in trajectories[0].items()
        }
        for i, (episode, step_index) in enumerate(
                zip(trajectories, step_indices)):
            for field_name, episode_values in episode.items():
                batch[field_name][i] = episode_values[step_index]

        if n_step is not None:
            batch['q_targets'] = self._n_step_q_targets(
                trajectories, step_indices, n_step, Qs, policy, discount,
                use_max_n_step)

        return batch

    @staticmethod
    def _n_step_q_targets(trajectories, step_indices, n_step, Qs, policy,
                          discount, use_max_n_step):
        """Compute one bootstrapped n-step Q target per sampled step.

        For a sample at step ``t`` the window is truncated at the end of its
        trajectory (``window <= n_step``). For each ``k`` in ``1..window``
        the candidate target is::

            sum_{j<k} discount**j * r[t+j]
                + discount**k * (1 - terminal[t+k-1]) * minQ(s'[t+k-1])

        where ``minQ`` is the minimum over all ``Qs`` predictions at the next
        observation paired with the policy's action. The last candidate is
        kept, or the largest one when ``use_max_n_step`` is set.

        Returns:
            float32 array of shape ``(len(trajectories), 1)``.
        """
        if Qs is None or policy is None or discount is None:
            raise ValueError(
                "n_step targets require `Qs`, `policy` and `discount`.")
        if n_step < 1:
            raise ValueError(f"n_step must be >= 1, got {n_step}.")

        window_lengths = [
            min(n_step, episode['next_observations'].shape[0] - step_index)
            for episode, step_index in zip(trajectories, step_indices)
        ]

        # Evaluate the policy and every Q function once, on the concatenated
        # next-observation windows of all samples.
        next_observations = np.concatenate([
            episode['next_observations'][step_index:step_index + window]
            for episode, step_index, window
            in zip(trajectories, step_indices, window_lengths)
        ], axis=0)
        next_actions = policy.actions_np([next_observations])
        next_Qs_values = tuple(
            Q.predict([next_observations, next_actions]) for Q in Qs)

        q_targets = np.empty((len(trajectories), 1), dtype=np.float32)
        start_index = 0
        for i, (episode, step_index, window) in enumerate(
                zip(trajectories, step_indices, window_lengths)):
            rewards = episode['rewards']
            terminals = episode['terminals']
            # Copy: ``rewards[step_index]`` can be a view into the stored
            # trajectory (e.g. rewards of shape (T, 1)), and the in-place
            # ``+=`` below must never write back into the pool.
            discounted_return = np.array(rewards[step_index], copy=True)
            candidates = []
            for j in range(window):
                if j > 0:
                    discounted_return += (
                        discount ** j * rewards[step_index + j])
                row = start_index + j
                min_next_q = np.amin(tuple(
                    next_Q_values[row] for next_Q_values in next_Qs_values))
                candidates.append(
                    discounted_return
                    + discount ** (j + 1)
                    * (1.0 - terminals[step_index + j])
                    * min_next_q)
            q_targets[i] = (
                np.amax(candidates) if use_max_n_step else candidates[-1])
            start_index += window

        if start_index != next_observations.shape[0]:
            raise RuntimeError(
                "Internal error: consumed %d next-observation rows, expected"
                " %d." % (start_index, next_observations.shape[0]))
        return q_targets

    def random_batch(self, batch_size, n_step=None, Qs=None, policy=None,
                     discount=None, use_max_n_step=False, *args, **kwargs):
        """Sample ``batch_size`` steps uniformly over all stored steps.

        Trajectories are drawn with probability proportional to their length
        and a step is then drawn uniformly inside each, so every stored step
        is equally likely. Extra ``*args``/``**kwargs`` are ignored.

        Returns:
            Batch dict as produced by ``batch_by_indices``; ``{}`` when the
            pool is empty.
        """
        num_trajectories = len(self._trajectories)
        if num_trajectories < 1:
            return {}

        trajectory_lengths = np.array(self._trajectory_lengths)
        trajectory_probabilities = (
            trajectory_lengths / np.sum(trajectory_lengths))

        trajectory_indices = np.random.choice(
            np.arange(num_trajectories),
            size=batch_size,
            replace=True,
            p=trajectory_probabilities)
        # Stored lengths are aligned with the trajectories deque.
        sampled_lengths = trajectory_lengths[trajectory_indices]

        step_indices = random_int_with_variable_range(
            np.zeros_like(sampled_lengths, dtype=np.int64),
            sampled_lengths)

        return self.batch_by_indices(
            trajectory_indices,
            step_indices,
            n_step=n_step,
            Qs=Qs,
            policy=policy,
            discount=discount,
            use_max_n_step=use_max_n_step)

    def last_n_batch(self, last_n, field_name_filter=None, **kwargs):
        """Return the ``last_n`` most recently added steps, oldest first.

        Returns:
            Batch dict as produced by ``batch_by_indices``; ``{}`` when the
            pool is empty or nothing is selected.
        """
        num_trajectories = len(self._trajectories)
        if num_trajectories < 1:
            return {}

        trajectory_indices = []
        step_indices = []

        remaining = last_n
        # Walk from the newest trajectory backwards, taking its final steps
        # (newest first) until ``last_n`` steps have been collected.
        for trajectory_index in range(num_trajectories - 1, -1, -1):
            trajectory = self._trajectories[trajectory_index]
            trajectory_length = trajectory[next(iter(trajectory))].shape[0]

            steps_from_this_episode = min(trajectory_length, remaining)
            step_indices += range(
                trajectory_length - 1,
                trajectory_length - steps_from_this_episode - 1,
                -1)
            trajectory_indices += [trajectory_index] * steps_from_this_episode

            remaining -= trajectory_length
            if remaining <= 0:
                break

        # Collected newest-first; restore chronological order.
        trajectory_indices.reverse()
        step_indices.reverse()

        return self.batch_by_indices(
            trajectory_indices, step_indices,
            field_name_filter=field_name_filter)

    def save_latest_experience(self, pickle_path):
        """Gzip-pickle the trajectories added since the last save or load.

        If more trajectories were added than the pool retains, only the
        retained ones are written. Resets the since-save counter.
        """
        num_trajectories = self.num_trajectories
        start_index = max(num_trajectories - self._trajectories_since_save, 0)
        end_index = num_trajectories

        # deque doesn't support direct slicing, thus need to use islice
        latest_trajectories = tuple(islice(
            self._trajectories, start_index, end_index))

        with gzip.open(pickle_path, 'wb') as f:
            pickle.dump(latest_trajectories, f)

        self._trajectories_since_save = 0

    def load_experience(self, experience_path):
        """Load gzip-pickled trajectories and append them to the pool.

        NOTE: ``pickle.load`` can execute arbitrary code; only load files
        produced by a trusted source (e.g. ``save_latest_experience``).
        """
        with gzip.open(experience_path, 'rb') as f:
            latest_trajectories = pickle.load(f)

        self.add_paths(latest_trajectories)
        self._trajectories_since_save = 0

    def return_all_samples(self):
        """Concatenate every stored trajectory field-wise along axis 0.

        Field names are taken from the oldest stored trajectory. Returns
        ``{}`` when the pool is empty.
        """
        if not self._trajectories:
            return {}
        return {
            field_name: np.concatenate(
                [trajectory[field_name] for trajectory in self._trajectories],
                axis=0)
            for field_name in self._trajectories[0].keys()
        }
