import numpy as np
import torch

from neroRL.sampler.buffer import Buffer
from neroRL.utils.worker import Worker

class TrajectorySampler():
    """The TrajectorySampler employs n environment workers to sample data for s worker steps regardless if an episode ended.
    Hence, the collected trajectories may contain multiple episodes or incomplete ones."""
    def __init__(self, configs, worker_id, visual_observation_space, vector_observation_space, action_space_shape, model, device) -> None:
        """Initializes the TrajectorySampler and launches its environment workers.

        Arguments:
            configs {dict} -- The whole set of configurations (e.g. training and environment configs)
            worker_id {int} -- Specifies the offset for the port to communicate with the environment, which is needed for Unity ML-Agents environments.
            visual_observation_space {box} -- Dimensions of the visual observation space (None if not available)
            vector_observation_space {tuple} -- Dimensions of the vector observation space (None if not available)
            action_space_shape {tuple} -- Dimensions of the action space
            model {nn.Module} -- The model to retrieve the policy and value from
            device {torch.device} -- The device that is used for retrieving the data from the model

        Raises:
            ValueError -- If a recurrence config is given whose "layer_type" is neither "gru" nor "lstm".
        """
        # Set member variables
        self.configs = configs
        self.visual_observation_space = visual_observation_space
        self.vector_observation_space = vector_observation_space
        self.model = model
        self.n_workers = configs["sampler"]["n_workers"]
        self.worker_steps = configs["sampler"]["worker_steps"]
        self.recurrence = configs["model"].get("recurrence")
        self.device = device

        # Validate the recurrent layer type before any buffer or worker process is created;
        # only GRU and LSTM cells are handled by this sampler.
        if self.recurrence is not None and self.recurrence["layer_type"] not in ("gru", "lstm"):
            raise ValueError("Unsupported recurrent layer type: {!r} (expected 'gru' or 'lstm')".format(
                self.recurrence["layer_type"]))

        # Create Buffer
        self.buffer = Buffer(self.n_workers, self.worker_steps, visual_observation_space, vector_observation_space,
                        action_space_shape, self.recurrence, "helm" in configs["model"], self.device, self.model.share_parameters, self)

        # Launch workers (200 is a fixed port offset added on top of worker_id)
        self.workers = [Worker(configs["environment"], worker_id + 200 + w) for w in range(self.n_workers)]

        # Setup initial observations
        if visual_observation_space is not None:
            self.vis_obs = np.zeros((self.n_workers,) + visual_observation_space.shape, dtype=np.float32)
        else:
            self.vis_obs = None
        if vector_observation_space is not None:
            self.vec_obs = np.zeros((self.n_workers,) + vector_observation_space, dtype=np.float32)
        else:
            self.vec_obs = None

        # Setup initial recurrent cell (GRU: hidden state only, LSTM: (hidden, cell) tuple)
        if self.recurrence is not None:
            hxs, cxs = self.model.init_recurrent_cell_states(self.n_workers, self.device)
            if self.recurrence["layer_type"] == "gru":
                self.recurrent_cell = hxs
            else:
                self.recurrent_cell = (hxs, cxs)
        else:
            self.recurrent_cell = None

        # Setup HELM memory
        # NOTE(review): the memory layout (18 tensors of shape (511, n_workers, 1024)) is hard-coded and
        # created on the CPU; presumably it must match the model's HELM encoder — confirm.
        if "helm" in self.configs["model"]:
            self.helm_memory = [torch.zeros((511, self.n_workers, 1024)) for _ in range(18)]
        else:
            self.helm_memory = None

        # Reset workers
        for worker in self.workers:
            worker.child.send(("reset", None))
        # Grab initial observations
        for i, worker in enumerate(self.workers):
            vis_obs, vec_obs = worker.child.recv()
            if self.vis_obs is not None:
                self.vis_obs[i] = vis_obs
            if self.vec_obs is not None:
                self.vec_obs[i] = vec_obs

    def sample(self, device) -> list:
        """Samples training data (i.e. experience tuples) using n workers for t worker steps.

        For every step t the current observations, recurrent cell states, values, actions,
        log probabilities, rewards and dones are written into self.buffer at index t.
        Workers whose episode ended are reset right away, so sampling continues across episodes.

        Arguments:
            device {torch.device} -- Unused; kept for interface compatibility. The device passed to
                the constructor (self.device) is used instead.

        Returns:
            {list} -- List of completed episodes. Each episode outputs a dictionary containing at least the
            achieved reward and the episode length.
        """
        episode_infos = []

        # Sample actions from the model and collect experiences for training
        for t in range(self.worker_steps):
            # Gradients can be omitted for sampling data
            with torch.no_grad():
                # Convert the current observations once: the same tensors are copied into the
                # buffer and fed to the model.
                vis_obs_batch = torch.tensor(self.vis_obs) if self.vis_obs is not None else None
                vec_obs_batch = torch.tensor(self.vec_obs) if self.vec_obs is not None else None
                if vis_obs_batch is not None:
                    self.buffer.vis_obs[:, t] = vis_obs_batch
                if vec_obs_batch is not None:
                    self.buffer.vec_obs[:, t] = vec_obs_batch
                # Store the recurrent cell states that are fed into the model at this step
                if self.recurrence is not None:
                    if self.recurrence["layer_type"] == "gru":
                        self.buffer.hxs[:, t] = self.recurrent_cell.squeeze(0)
                    elif self.recurrence["layer_type"] == "lstm":
                        self.buffer.hxs[:, t] = self.recurrent_cell[0].squeeze(0)
                        self.buffer.cxs[:, t] = self.recurrent_cell[1].squeeze(0)

                # Forward the model to retrieve the policy (making decisions),
                # the states' value of the value function and the recurrent hidden states (if available)
                if self.helm_memory is not None:
                    self.model.helm_encoder.memory = self.helm_memory
                policy, value, self.recurrent_cell, _, h_helm = self.model(vis_obs_batch, vec_obs_batch, self.recurrent_cell)
                if self.helm_memory is not None:
                    # Keep the memory updated by the forward pass for the next step
                    self.helm_memory = self.model.helm_encoder.memory
                if "helm" in self.configs["model"]:
                    self.buffer.h_helm[:, t] = h_helm
                self.buffer.values[:, t] = value.data

                # Sample actions from each individual policy branch
                actions = []
                log_probs = []
                for action_branch in policy:
                    action = action_branch.sample()
                    actions.append(action)
                    log_probs.append(action_branch.log_prob(action))
                self.buffer.actions[:, t] = torch.stack(actions, dim=1)
                self.buffer.log_probs[:, t] = torch.stack(log_probs, dim=1)

            # Execute actions
            actions = self.buffer.actions[:, t].cpu().numpy() # send actions as batch to the CPU, to save IO time
            for w, worker in enumerate(self.workers):
                worker.child.send(("step", actions[w]))

            # Retrieve results
            for w, worker in enumerate(self.workers):
                vis_obs, vec_obs, self.buffer.rewards[w, t], self.buffer.dones[w, t], info = worker.child.recv()
                if self.vis_obs is not None:
                    self.vis_obs[w] = vis_obs
                if self.vec_obs is not None:
                    self.vec_obs[w] = vec_obs
                if info:
                    # Store the information of the completed episode (e.g. total reward, episode length)
                    episode_infos.append(info)
                    # Reset agent (potential interface for providing reset parameters)
                    worker.child.send(("reset", None))
                    # Get data from reset
                    vis_obs, vec_obs = worker.child.recv()
                    if self.vis_obs is not None:
                        self.vis_obs[w] = vis_obs
                    if self.vec_obs is not None:
                        self.vec_obs[w] = vec_obs
                    # Reset recurrent cell states of this worker only
                    if self.recurrence is not None:
                        if self.recurrence["reset_hidden_state"]:
                            hxs, cxs = self.model.init_recurrent_cell_states(1, self.device)
                            if self.recurrence["layer_type"] == "gru":
                                self.recurrent_cell[:, w] = hxs
                            elif self.recurrence["layer_type"] == "lstm":
                                self.recurrent_cell[0][:, w] = hxs
                                self.recurrent_cell[1][:, w] = cxs
                    # Reset HELM memory of this worker only
                    if "helm" in self.configs["model"]:
                        for memory in self.helm_memory:
                            memory[:, w] = 0.

        return episode_infos

    def last_vis_obs(self) -> torch.Tensor:
        """
        Returns:
            {torch.Tensor} -- The last visual observation of the sampling process, which can be used to calculate
            the advantage (None if no visual observation space is available).
        """
        return torch.tensor(self.vis_obs) if self.vis_obs is not None else None

    def last_vec_obs(self) -> torch.Tensor:
        """
        Returns:
            {torch.Tensor} -- The last vector observation of the sampling process, which can be used to calculate
            the advantage (None if no vector observation space is available).
        """
        return torch.tensor(self.vec_obs) if self.vec_obs is not None else None

    def last_recurrent_cell(self) -> tuple:
        """
        Returns:
            {tuple} -- The latest recurrent cell of the sampling process, which can be used to calculate the advantage.
        """
        return self.recurrent_cell

    def close(self) -> None:
        """Closes the sampler and shuts down its environment workers.

        Shutdown is best-effort per worker: a worker whose pipe is already broken or closed is
        skipped, so the remaining workers still receive the close command.
        """
        for worker in self.workers:
            try:
                worker.child.send(("close", None))
            except (OSError, EOFError):
                # The worker's connection is already gone; keep closing the others.
                pass