from collections import deque, OrderedDict

from lifelong_rl.util.eval_util import create_stats_ordered_dict
from lifelong_rl.samplers.utils.rollout_functions import rollout_with_latent
from lifelong_rl.samplers import rollout, multitask_rollout
from lifelong_rl.samplers import PathCollector


class MdpPathCollector(PathCollector):
    """Collect rollout paths using an environment, an embedder and a policy.

    Paths gathered since the last ``end_epoch`` call are kept in a deque
    bounded by ``max_num_epoch_paths_saved``; running totals of collected
    steps and paths are tracked for ``get_diagnostics``.
    """

    def __init__(
            self,
            env,
            embedder,
            policy,
            max_num_epoch_paths_saved=None,
            render=False,
            render_kwargs=None,
    ):
        """Store the collaborators and initialise the counters.

        Args:
            env: Environment handed to the rollout function.
            embedder: Embedder handed to the rollout function.
            policy: Policy handed to the rollout function as ``agent``.
            max_num_epoch_paths_saved: ``maxlen`` of the per-epoch path
                deque (``None`` means unbounded).
            render: Stored on the instance; not forwarded to the rollout
                function by this class.
            render_kwargs: Stored on the instance (``{}`` when ``None``);
                not forwarded to the rollout function by this class.
        """
        if render_kwargs is None:
            render_kwargs = {}

        self._env = env
        self._policy = policy
        self._embedder = embedder
        self._max_num_epoch_paths_saved = max_num_epoch_paths_saved
        self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
        self._render = render
        self._render_kwargs = render_kwargs

        self._num_steps_total = 0
        self._num_paths_total = 0

    def rollout_function(self, *args, **kwargs):
        """Produce one path; subclasses override this to swap the rollout."""
        return rollout(*args, **kwargs)

    def reset_policy(self):
        """Call the policy's ``reset()`` and ``eval()`` before a rollout."""
        self._policy.reset()
        self._policy.eval()

    def finish_path(self, path):
        """Hook invoked on every accepted path before it is stored (no-op)."""
        return

    def end_path_collection(self):
        """Hook invoked once after collection; calls the policy's ``train()``."""
        self._policy.train()

    def _collect_single_path(self, max_path_length):
        """Reset the policy and roll out one path of at most ``max_path_length``."""
        self.reset_policy()
        # Keyword arguments keep the call identical for every sample mode, so
        # the embedder is always supplied and nothing depends on the
        # positional order of the rollout function's parameters.
        return self.rollout_function(
            env=self._env,
            embedder=self._embedder,
            agent=self._policy,
            max_path_length=max_path_length,
        )

    def collect_new_paths(
            self,
            max_path_length,
            num_samples,
            alpha=None,
            sample_mode='steps',
            discard_incomplete_paths=False,
    ):
        """Collect new paths and record them for the current epoch.

        Args:
            max_path_length: Upper bound on the length of a single path.
            num_samples: Number of environment steps to collect when
                ``sample_mode == 'steps'``, or number of paths to collect
                when ``sample_mode == 'paths'``.
            alpha: Accepted for interface compatibility; not used here.
            sample_mode: ``'steps'`` or ``'paths'``.
            discard_incomplete_paths: Only consulted in ``'steps'`` mode.
                When true, a path that is shorter than ``max_path_length``
                and whose last ``terminals`` entry is falsy is dropped and
                collection stops.

        Returns:
            list: The paths collected by this call.

        Raises:
            NotImplementedError: If ``sample_mode`` is not recognised.
        """
        if sample_mode not in ('steps', 'paths'):
            raise NotImplementedError(
                "Unknown sample_mode {!r}; expected 'steps' or 'paths'".format(
                    sample_mode
                )
            )

        paths = []
        num_steps_collected = 0

        if sample_mode == 'steps':
            while num_steps_collected < num_samples:
                max_path_length_this_loop = min(  # Do not go over num_steps
                    max_path_length,
                    num_samples - num_steps_collected,
                )
                path = self._collect_single_path(max_path_length_this_loop)
                path_len = len(path['actions'])
                if (
                        path_len != max_path_length
                        and not path['terminals'][-1]
                        and discard_incomplete_paths
                ):
                    break
                num_steps_collected += path_len
                self.finish_path(path)
                paths.append(path)
        else:  # sample_mode == 'paths'
            for _ in range(num_samples):
                path = self._collect_single_path(max_path_length)
                num_steps_collected += len(path['actions'])
                self.finish_path(path)
                paths.append(path)

        self._num_paths_total += len(paths)
        self._num_steps_total += num_steps_collected
        self._epoch_paths.extend(paths)
        self.end_path_collection()
        return paths

    def get_epoch_paths(self):
        """Return the deque of paths stored since the last ``end_epoch``."""
        return self._epoch_paths

    def end_epoch(self, epoch):
        """Start a fresh, empty per-epoch path deque (``epoch`` is unused)."""
        self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)

    def get_diagnostics(self):
        """Return running totals plus path-length statistics for this epoch."""
        path_lens = [len(path['actions']) for path in self._epoch_paths]
        stats = OrderedDict([
            ('num steps total', self._num_steps_total),
            ('num paths total', self._num_paths_total),
        ])
        stats.update(create_stats_ordered_dict(
            "path length",
            path_lens,
            always_show_all_stats=True,
        ))
        return stats

    def get_snapshot(self):
        """Return a dict holding the environment and the policy."""
        return dict(
            env=self._env,
            policy=self._policy,
        )


class LatentPathCollector(MdpPathCollector):

    """
    Path collector that samples a latent at the start of every trajectory
    and feeds it as input to a PriorLatentPolicy.
    """

    def __init__(self, sample_latent_every=None, *args, **kwargs):
        """Forward the remaining arguments to the parent and set latent state."""
        super().__init__(*args, **kwargs)

        self.rollout_func = rollout_with_latent
        self.prev_latent = None
        self.sample_latent_every = sample_latent_every

    def rollout_function(self, *args, **kwargs):
        """Roll out with ``rollout_with_latent``, adding ``sample_latent_every``."""
        return rollout_with_latent(
            *args,
            sample_latent_every=self.sample_latent_every,
            **kwargs
        )

    def finish_path(self, path):
        """Attach the latent remembered by ``reset_policy`` to the path."""
        path['latent'] = self.prev_latent

    def reset_policy(self):
        """Reset the policy, pin its latent, resample it and remember it."""
        super().reset_policy()
        policy = self._policy
        policy.fixed_latent = True
        policy.sample_latent()
        self.prev_latent = policy.get_current_latent()

    def end_path_collection(self):
        """Run the parent hook, then unpin the latent and resample it."""
        super().end_path_collection()
        policy = self._policy
        policy.fixed_latent = False
        policy.sample_latent()


class GoalConditionedPathCollector(PathCollector):
    """Collect paths with ``multitask_rollout`` using dict observations.

    The observation and desired-goal keys given at construction time are
    forwarded to every rollout.
    """

    def __init__(
            self,
            env,
            policy,
            max_num_epoch_paths_saved=None,
            render=False,
            render_kwargs=None,
            observation_key='observation',
            desired_goal_key='desired_goal',
    ):
        """Store the collaborators, rollout options and zeroed counters."""
        self._env = env
        self._policy = policy
        self._max_num_epoch_paths_saved = max_num_epoch_paths_saved
        self._render = render
        self._render_kwargs = {} if render_kwargs is None else render_kwargs
        self._epoch_paths = deque(maxlen=max_num_epoch_paths_saved)
        self._observation_key = observation_key
        self._desired_goal_key = desired_goal_key

        self._num_steps_total = 0
        self._num_paths_total = 0

    def collect_new_paths(
            self,
            max_path_length,
            num_steps,
            discard_incomplete_paths,
    ):
        """Roll out paths until ``num_steps`` environment steps are gathered.

        A path that is shorter than ``max_path_length``, whose last
        ``terminals`` entry is falsy, is dropped when
        ``discard_incomplete_paths`` is true, and collection stops there.
        Returns the list of paths kept by this call.
        """
        collected = []
        steps_so_far = 0
        while steps_so_far < num_steps:
            # Never request more steps than the remaining budget.
            budget = min(max_path_length, num_steps - steps_so_far)
            path = multitask_rollout(
                self._env,
                self._policy,
                max_path_length=budget,
                render=self._render,
                render_kwargs=self._render_kwargs,
                observation_key=self._observation_key,
                desired_goal_key=self._desired_goal_key,
                return_dict_obs=True,
            )
            n_actions = len(path['actions'])
            is_short = n_actions != max_path_length
            if is_short and not path['terminals'][-1] and discard_incomplete_paths:
                break
            steps_so_far += n_actions
            collected.append(path)

        self._num_paths_total += len(collected)
        self._num_steps_total += steps_so_far
        self._epoch_paths.extend(collected)
        return collected

    def get_epoch_paths(self):
        """Return the deque of paths stored since the last ``end_epoch``."""
        return self._epoch_paths

    def end_epoch(self, epoch):
        """Swap in a fresh, empty per-epoch path deque (``epoch`` is unused)."""
        self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)

    def get_diagnostics(self):
        """Return running totals plus path-length statistics for this epoch."""
        stats = OrderedDict()
        stats['num steps total'] = self._num_steps_total
        stats['num paths total'] = self._num_paths_total

        lengths = [len(p['actions']) for p in self._epoch_paths]
        length_stats = create_stats_ordered_dict(
            "path length",
            lengths,
            always_show_all_stats=True,
        )
        stats.update(length_stats)
        return stats

    def get_snapshot(self):
        """Return a dict with the environment, policy and the two dict keys."""
        return {
            'env': self._env,
            'policy': self._policy,
            'observation_key': self._observation_key,
            'desired_goal_key': self._desired_goal_key,
        }
