import numpy as np
from abc import ABC, abstractmethod

import time
import torch
from torch.utils.data._utils.collate import default_collate


# The job of the runner is to produce processed data that can then be used directly for optimization.
# It also gives the preprocessed data directly to the lessons buffer.

class AbstractEnvRunner(ABC):
    """
    Common base for runners that collect experience from an environment.

    The constructor resets the environment once, keeps the returned state and
    stores the rollout configuration. Concrete runners must implement run().
    """

    def __init__(self, *, env, nsteps, gamma, lam, device=None):
        """
        Store the rollout configuration and reset the environment.

        env: environment to interact with; when it exposes a num_envs attribute
             that value is used as the number of environments, otherwise 1
        nsteps, gamma, lam, device: stored unchanged on the instance
        """
        # environments without a num_envs attribute count as a single environment
        num_parallel = getattr(env, 'num_envs', 1)

        self.env = env
        self.nenv = num_parallel
        self.device = device
        # the initial state comes from resetting the environment
        self.state = self.env.reset()
        self.nsteps = nsteps
        self.gamma = gamma
        self.lam = lam

        # one "done" flag per environment, all cleared at start
        self.dones = [False] * num_parallel

    @abstractmethod
    def run(self):
        """Collect a batch of experience; has to be provided by subclasses."""
        raise NotImplementedError


class Runner(AbstractEnvRunner):
    """
    We use this object to make a mini batch of experiences
    __init__:
    - Initialize the runner

    run():
    - Make a mini batch
    """

    @staticmethod
    def transform_action_dict(actions):
        """
        Take a list of list of dictionary actions and return a numpy array of actions.
        """
        return np.asarray(actions, dtype=np.object)

    @staticmethod
    def transform_state_dict(states):
        """
        Take in a list of dictionary's and return a list of states
        """
        # number of states
        length = len(states)
        # number of elements in the state dictionary
        num_dict_elements = len(states[0])
        state_list = [[] for _ in range(num_dict_elements)]
        for i in range(length):
            for j, k in zip(state_list, states[i].keys()):
                j.append(states[i][k])

        return np.stack(state_list[0]), np.stack(state_list[1]), np.stack(state_list[2])

    @staticmethod
    def dict_to_device(dictionary, device):
        for k in dictionary.keys():
            dictionary[k] = dictionary[k].to(device)
        return dictionary

    @staticmethod
    def sf01(arr):
        """
        swap and then flatten axes 0 and 1
        """
        s = arr.shape
        return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:])

    def __init__(self, *, model, dataset, nenv, env, nsteps, seq_len, gamma, lam, episode_length, reward_scale,
                 episodic_interactions, task_to_check="log", task_based_reward=False, num_tasks_to_check=1,
                 safe_exploration=False, incremental_threshold=10, incremental=False, resource_penalization_scale=0.5,
                 timestep_based_reward=False, step_counter_worker_envs_reset=False, envname=None, device):
        super().__init__(env=env, nsteps=nsteps, gamma=gamma, lam=lam, device=device)

        self.model = model
        self.input_space = dataset.INPUT_SPACE
        self.action_space = dataset.ACTION_SPACE

        self.envname = envname
        self.nenv = nenv
        # fill up the agent states with seq_len observations
        self.seq_len = seq_len
        self.episode_length = episode_length
        self.steps = 0
        self.resource_penalization_scale = resource_penalization_scale
        self.reward_scale = reward_scale
        self.timestep_based_reward = timestep_based_reward
        # not allowed to use timestep_based_reward with resource_penalization_scale since negative
        # rewards are not supported
        assert not (self.resource_penalization_scale and self.timestep_based_reward)
        self.returns = np.zeros(self.nenv)
        self.episodic_returns = []
        self.episodic_interactions = episodic_interactions
        self.step_counter_worker_envs_reset = step_counter_worker_envs_reset
        self.safe_exploration = safe_exploration
        self.task_to_check = task_to_check
        self.task_based_reward = task_based_reward
        if self.task_based_reward:
            self.reward_received = [False for _ in range(self.nenv)]
            self.last_inventory = [0 for _ in range(self.nenv)]
            self.num_tasks_to_check = num_tasks_to_check
        self.incremental = incremental
        self.incremental_threshold = incremental_threshold

    def run(self):
        """
        Perform rollout and prepare results for further processing
        """

        # perform rollout
        with torch.no_grad():
            b_states, b_rewards, b_dones, b_actions, b_values, b_logprobs, b_last_values, epinfo = self.rollout()

        # compute the advantage and return from the batch of data
        b_rewards, b_dones, b_values, b_advs, b_returns = self.compute_advantage(b_rewards, b_dones,
                                                                                 b_last_values, b_values)

        # convert rest of data to numpy arrays
        b_logprobs = np.stack(b_logprobs)
        pov, bin_actions, camera_actions = self.transform_state_dict(b_states)
        b_actions = np.asarray(b_actions, dtype=np.object)

        # actions and states are in the form of list
        return (*map(self.sf01, (b_returns, b_advs, b_logprobs, pov, bin_actions, camera_actions, b_actions, b_values)),
                epinfo)

    def rollout(self, nsteps=None):
        """
        Perform rollout of nsteps with current policy
        """

        # init
        nsteps = self.nsteps if nsteps is None else nsteps
        states, actions, rewards, dones = [], [], [], []
        values, logprobs = [], []
        epinfo = []

        # reset to beginning of episode
        if self.episodic_interactions:
            self.state = self.env.reset()
            if self.task_based_reward:
                self.last_inventory = [0 for _ in range(self.nenv)]
                self.reward_received = [False for _ in range(self.nenv)]

        if self.safe_exploration:
            # with some prob select among the following
            choice = np.random.choice([0, 1, 2], p=[0.6, 0.2, 0.2])
            if choice == 1:
                # do some random actions
                # Move in one direction for fixed number of steps with attack on
                # do some random actions
                n1 = np.random.randint(20, 50)
                for i in range(n1):
                    action = self.env.action_space.sample()
                    self.state, reward, self.dones, infos = self.env.step((action,)*self.nenv)

                n2 = np.random.randint(20, 50)
                for i in range(n2):
                    action = self.env.action_space.noop()
                    action["jump"] = 1
                    action["attack"] = 1
                    action["sprint"] = 1
                    self.state, reward, self.dones, infos = self.env.step((action,)*self.nenv)

                n3 = np.random.randint(5, 10)
                for i in range(n3):
                    action = self.env.action_space.sample()
                    self.state, reward, self.dones, infos = self.env.step((action,)*self.nenv)
            elif choice == 2:
                # do some random actions
                n1 = np.random.randint(5, 10)
                for i in range(n1):
                    action = self.env.action_space.sample()
                    self.state, reward, self.dones, infos = self.env.step((action,)*self.nenv)

        last_inventory_state = [None for _ in range(self.nenv)]
        reward_steps = [0 for _ in range(self.nenv)]

        # perform nstep rollout
        start_time = time.time()
        for _ in range(nsteps):
            # collate, to tensor, to device
            input_dict = default_collate(self.state)
            input_dict = self.dict_to_device(input_dict, self.device)

            # predict next action, value, logprob of action
            out_dict = self.model(input_dict)

            # translate predictions to environment actions
            action = self.action_space.logits_to_dict(self.env.action_space.noop(), out_dict)

            # evaluate model
            if self.envname == "MineRLTreechop-v0":
                value, action_log_probs, camera_log_probs, action_entropy, camera_entropy = self.action_space.evaluate_actions(out_dict, action)

                # append the logprobs
                logprob_action = torch.cat((action_log_probs, camera_log_probs), 1).cpu().numpy()
            else:
                value, action_log_probs, camera_log_probs, equip_log_probs, place_log_probs, craft_log_probs, \
                nearbyCraft_log_probs, nearbySmelt_log_probs, action_entropy, camera_entropy, equip_entropy, \
                place_entropy, craft_entropy, nearbyCraft_entropy, nearbySmelt_entropy = \
                    self.action_space.evaluate_actions(out_dict, action)

                # append the logprobs
                logprob_action = torch.cat((action_log_probs, camera_log_probs, equip_log_probs, place_log_probs,
                                            craft_log_probs, nearbyCraft_log_probs, nearbySmelt_log_probs), 1).cpu().numpy()

            # take env step
            self.state, reward, self.dones, infos = self.env.step(action)
            resources_deltas = [0.0 for _ in range(self.nenv)]
            # compare current resources
            if last_inventory_state[0] is not None:
                for i, info in enumerate(infos):
                    resources_deltas[i] = np.array([min(v - last_inventory_state[i][k], 0) for k, v in info['meta_controller']['state']['inventory'].items()]).mean().item()

            if self.task_based_reward and self.task_to_check != 'diamond':
                r = []
                # checks the inventory for the task, and if it was not present before give reward
                # if it was present before, give zero reward
                for rew_rec, info, done in zip(np.arange(len(self.reward_received)), infos, self.dones):
                    if info['meta_controller']['state']['inventory'][self.task_to_check] > self.last_inventory[rew_rec]:
                        r.append(1.0)
                        self.reward_received[rew_rec] = True
                        self.last_inventory[rew_rec] = info['meta_controller']['state']['inventory'][self.task_to_check]
                    else:
                        r.append(0.0)
                reward = np.asarray(r, dtype=np.float32)
            elif self.task_based_reward and self.task_to_check == "diamond":
                r = []
                for i, r in enumerate(reward):
                    if reward[i] >= 1000:
                        reward[i] = 1
                    else:
                        reward[i] = 0
                reward = np.asarray(r, dtype=np.float32)
            elif self.incremental:
                for i, r in enumerate(reward):
                    if r < self.incremental_threshold and r > 0:
                        reward[i] = 0.0
                    elif r > 0 and r >= self.incremental_threshold:
                        reward[i] = r/self.incremental_threshold
            else:
                reward = np.asarray(reward, dtype=np.float32)

            # increment step if reward
            reward_steps = [s+1 for s in reward_steps]

            # add resource consumption penalization
            reward = reward + self.resource_penalization_scale * np.asarray(resources_deltas, dtype=np.float32)

            self.returns = self.returns + reward
            reward *= self.reward_scale

            if self.timestep_based_reward:
                reward = reward / (1 + np.log(1 + np.array(reward_steps)))

            # reset if reward has finished
            reward_steps = [0 if reward[i] > 0 else s for i, s in enumerate(reward_steps)]

            # if episode is done, for that env set reward received flat to False
            for i, done in enumerate(self.dones):
                if done:
                    if self.task_based_reward:
                        self.reward_received[i] = False
                        self.last_inventory[i] = 0
                    self.episodic_returns.append(self.returns[i])
                    self.returns[i] = 0

            # check if episode info is available
            for info in infos:
                maybeepinfo = info.get('episode')
                if maybeepinfo:
                    epinfo.append(maybeepinfo)

            # get numpy arrays from model inputs
            input_dict = self.dict_to_device(input_dict, "cpu")

            # increase step counter and reset worker-envs if necessary
            if not self.episodic_interactions and self.step_counter_worker_envs_reset:
                self.steps += 1
                if self.steps >= self.episode_length:
                    self.state = self.env.reset()
                    # set reward received flag to false
                    if self.task_based_reward:
                        self.last_inventory = [0 for _ in range(self.nenv)]
                        self.reward_received = [False for _ in range(self.nenv)]
                    self.steps = 0
                    for i in range(self.nenv):
                        if not self.dones[i]:
                            self.episodic_returns.append(self.returns[i])
                            self.returns[i] = 0
                    self.dones = [True for _ in range(self.nenv)]

            # book keeping
            states.append(input_dict)
            actions.append(action)
            rewards.append(reward)
            dones.append(self.dones)
            values.append(value.squeeze().cpu().numpy())
            logprobs.append(logprob_action)

            # update the inventory state for resource tracking
            for i, info in enumerate(infos):
                if 'inventory' in info['meta_controller']['state']:
                    last_inventory_state[i] = info['meta_controller']['state']['inventory']

        # compute fps of rollout
        rollout_time = time.time() - start_time
        print("rollout of %d steps performed with %.2f fps" % (nsteps, float(nsteps) / rollout_time))

        # compile input dictionary for model
        input_dict = default_collate(self.state)
        for k in input_dict.keys():
            input_dict[k] = input_dict[k].to(self.device)

        # predict next action, value, logprob of action
        out_dict = self.model(input_dict)
        last_value = out_dict["value"].cpu().numpy()

        if self.episodic_interactions:
            for i in range(self.nenv):
                self.episodic_returns.append(self.returns[i])
                self.returns[i] = 0
            self.dones = [False for _ in range(self.nenv)]
            dones.pop(-1)
            dones.append([True for _ in range(self.nenv)])

        return states, rewards, dones, actions, values, logprobs, last_value, epinfo

    def compute_advantage(self, b_rewards, b_dones, b_last_values, b_values):
        b_rewards = np.stack(b_rewards)
        b_dones = np.stack(b_dones)
        b_last_values = b_last_values
        b_values = np.stack(b_values)

        b_advs = np.zeros_like(b_rewards)
        lastgaelam = 0

        # reversely iterate rollout
        for t in reversed(range(self.nsteps)):

            # special case: reuse value of previous rollout
            if t == self.nsteps - 1:
                nextnonterminal = 1.0 - np.array(self.dones)
                nextvalues = b_last_values.squeeze()
            # standard case
            else:
                nextnonterminal = 1.0 - b_dones[t+1]
                nextvalues = b_values[t+1]
            # compute advantage
            delta = b_rewards[t] + self.gamma * nextvalues * nextnonterminal - b_values[t]
            b_advs[t] = lastgaelam = delta + self.gamma * self.lam * nextnonterminal * lastgaelam

        b_returns = b_advs + b_values

        return b_rewards, b_dones, b_values, b_advs, b_returns
