import numpy as np
from abc import ABC, abstractmethod

import time
import torch
import numpy as np
from torch.utils.data._utils.collate import default_collate


# The runner's job is to provide processed data that can be used directly for optimization.
# It also passes preprocessed data directly to the lessons buffer.

class AbstractEnvRunner(ABC):
    """
    Base class for runners: stores the environment together with the rollout
    hyper-parameters and resets the environment once on construction.
    Subclasses have to implement run().
    """

    def __init__(self, *, env, nsteps, gamma, lam, device=None):
        self.env = env
        # environments without a num_envs attribute are treated as a single environment
        self.nenv = getattr(env, 'num_envs', 1)
        self.device = device
        # the initial state is whatever the environment returns on reset
        self.state = self.env.reset()
        self.nsteps = nsteps
        self.gamma = gamma
        self.lam = lam

        # note: recurrent models/policies might additionally need an initial policy state here
        # one "done" flag per environment, all cleared at the start
        self.dones = [False] * self.nenv

    @abstractmethod
    def run(self):
        """
        Has to be provided by subclasses.
        """
        raise NotImplementedError


class Runner(AbstractEnvRunner):
    """
    We use this object to make a mini batch of experiences
    __init__:
    - Initialize the runner

    run():
    - Make a mini batch
    """

    @staticmethod
    def transform_action_dict(actions):
        """
        Take a list of list of dictionary actions and return a numpy array of actions.
        """
        return np.asarray(actions, dtype=np.object)

    @staticmethod
    def transform_state_dict(states):
        """
        Take in a list of dictionary's and return a list of states
        """
        # number of states
        length = len(states)
        # number of elements in the state dictionary
        num_dict_elements = len(states[0])
        state_list = [[] for _ in range(num_dict_elements)]
        for i in range(length):
            for j, k in zip(state_list, states[i].keys()):
                j.append(states[i][k])

        return np.stack(state_list[0]), np.stack(state_list[1]), np.stack(state_list[2])

    @staticmethod
    def dict_to_device(dictionary, device):
        for k in dictionary.keys():
            dictionary[k] = dictionary[k].to(device)
        return dictionary

    @staticmethod
    def sf01(arr):
        """
        swap and then flatten axes 0 and 1
        """
        s = arr.shape
        return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:])

    def __init__(self, *, model, rudder, dataset, nenv, env, nsteps, seq_len, gamma, lam, episode_length, reward_scale, episodic_interactions, device):
        super().__init__(env=env, nsteps=nsteps, gamma=gamma, lam=lam, device=device)

        self.model = model
        self.input_space = dataset.INPUT_SPACE
        self.action_space = dataset.ACTION_SPACE

        self.nenv = nenv
        # fill up the agent states with seq_len observations
        self.seq_len = seq_len
        self.episode_length = episode_length
        self.steps = 0
        self.reward_scale = reward_scale
        self.returns = np.zeros(self.nenv)
        self.episodic_returns = []
        self.episodic_interactions = episodic_interactions
        self.rudder = rudder # rudder object with its own model and functions

    def run(self):
        """
        Perform rollout and prepare results for further processing
        """

        # perform rollout
        with torch.no_grad():
            b_states, b_rewards, b_dones, b_actions, b_values, b_logprobs, b_last_values, epinfo = self.rollout()

        # compute the advantage and return from the batch of data
        b_rewards, b_dones, b_values, b_advs, b_returns = self.compute_advantage(b_rewards, b_dones,
                                                                                 b_last_values, b_values)

        # convert rest of data to numpy arrays
        b_logprobs = np.stack(b_logprobs)
        pov, bin_actions, camera_actions = self.transform_state_dict(b_states)
        b_actions = np.asarray(b_actions, dtype=np.object)

        # actions and states are in the form of list
        return (*map(self.sf01, (b_returns, b_advs, b_logprobs, pov, bin_actions, camera_actions, b_actions, b_values)),
                epinfo)

    def learn_from_buffer(self, writer, rudder_pretraining=False, stop_loss=10, batch_size=64):
        # learn from samples in the buffer until, loss is below stop_loss
        self.rudder.update(writer, stop_loss, rudder_pretraining, batch_size)

    def rollout(self, nsteps=None):
        """
        Perform rollout of nsteps with current policy
        """

        # init
        nsteps = self.nsteps if nsteps is None else nsteps
        states, actions, rewards, dones = [], [], [], []
        values, logprobs = [], []
        epinfo = []

        # reset to beginning of episode
        if self.episodic_interactions:
            self.state = self.env.reset()

        # perform nstep rollout
        start_time = time.time()
        for _ in range(nsteps):
            # collate, to tensor, to device
            input_dict = default_collate(self.state)
            input_dict = self.dict_to_device(input_dict, self.device)

            # predict next action, value, logprob of action
            out_dict = self.model(input_dict)

            # translate predictions to environment actions
            action = self.action_space.logits_to_dict(self.env.action_space.noop(), out_dict)

            # evaluate model
            value, action_log_probs, camera_log_probs, action_entropy, camera_entropy = \
                self.action_space.evaluate_actions(out_dict, action)

            # append the logprobs
            logprob_action = torch.cat((action_log_probs, camera_log_probs), 1).cpu().numpy()

            # take env step
            self.state, reward, self.dones, infos = self.env.step(action)
            reward = np.asarray(reward, dtype=np.float32)
            self.returns = self.returns + reward
            reward *= self.reward_scale

            # check if episode info is available
            for info in infos:
                maybeepinfo = info.get('episode')
                if maybeepinfo:
                    epinfo.append(maybeepinfo)

            # get numpy arrays from model inputs
            input_dict = self.dict_to_device(input_dict, "cpu")

            # increase step counter and reset worker-envs if necessary
            if not self.episodic_interactions:
                self.steps += 1
                if self.steps >= self.episode_length:
                    self.state = self.env.reset()
                    self.dones = [False for _ in range(self.nenv)]
                    self.steps = 0
                    self.episodic_returns.append(self.returns)
                    self.returns = np.zeros(self.nenv)

            # book keeping
            states.append(input_dict)
            actions.append(action)
            rewards.append(reward)
            dones.append(self.dones)
            values.append(value.squeeze().cpu().numpy())
            logprobs.append(logprob_action)

        # compute fps of rollout
        rollout_time = time.time() - start_time
        print("rollout of %d steps performed with %.2f fps" % (nsteps, float(nsteps) / rollout_time))

        # compile input dictionary for model
        input_dict = default_collate(self.state)
        for k in input_dict.keys():
            input_dict[k] = input_dict[k].to(self.device)

        # predict next action, value, logprob of action
        out_dict = self.model.forward(input_dict)
        last_value = out_dict["value"].cpu().numpy()

        if self.episodic_interactions:
            self.dones = [False for _ in range(self.nenv)]
            self.episodic_returns.append(self.returns)
            self.returns = np.zeros(self.nenv)

        return states, rewards, dones, actions, values, logprobs, last_value, epinfo

    def compute_advantage(self, b_rewards, b_dones, b_last_values, b_values):
        b_rewards = np.stack(b_rewards)
        b_dones = np.stack(b_dones)
        b_last_values = b_last_values
        b_values = np.stack(b_values)

        b_advs = np.zeros_like(b_rewards)
        lastgaelam = 0

        # reversely iterate rollout
        for t in reversed(range(self.nsteps)):

            # special case: reuse value of previous rollout
            if t == self.nsteps - 1:
                nextnonterminal = 1.0 - np.array(self.dones)
                nextvalues = b_last_values.squeeze()
            # standard case
            else:
                nextnonterminal = 1.0 - b_dones[t+1]
                nextvalues = b_values[t+1]
            # compute advantage
            delta = b_rewards[t] + self.gamma * nextvalues * nextnonterminal - b_values[t]
            b_advs[t] = lastgaelam = delta + self.gamma * self.lam * nextnonterminal * lastgaelam

        b_returns = b_advs + b_values

        return b_rewards, b_dones, b_values, b_advs, b_returns
