import argparse
import collections
import copy
import gym
import env.wrappers as gym_wrappers
import json
import os
from pathlib import Path
import shelve
import logging
from gym.envs.classic_control.rendering import SimpleImageViewer
import ray
import ray.cloudpickle as cloudpickle
from ray.rllib.agents.registry import get_trainer_class
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.env.env_context import EnvContext
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
from ray.tune.utils import merge_dicts
from ray.tune.registry import get_trainable_cls, _global_registry, ENV_CREATOR
from ray.rllib.agents.impala.impala import ImpalaTrainer
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.impala.vtrace_torch_policy import VTraceTorchPolicy
from ray.rllib.agents.dqn.r2d2 import R2D2Trainer
from ray.rllib.agents.dqn.r2d2_torch_policy import R2D2TorchPolicy
from actors.dqnpolicy import build_custom_dqn_policy
from actors.ppo.ppopolicy import PPOTorchCustomPolicy
from custom_scenario import scenario_generator

import torch
import numpy as np

from starlette.requests import Request


logger = logging.getLogger(__name__)
# RolloutSaver is defined locally below rather than imported via
# `from ray.rllib.evaluate import RolloutSaver`.
from arguments import create_parser
from config import return_config
from ray.rllib.models import ModelCatalog

# Module-level image-viewer placeholder. Note: rollout___ binds its own local
# `simple_image_viewer`, so this global is not updated by it.
simple_image_viewer = None
# Command-line parser, built at import time by the project's `arguments` module.
parser = create_parser()


class RolloutSaver:
    """Utility class for storing rollouts.

    Currently supports two behaviours: the original, which simply dumps
    everything to a pickle file once complete, and a mode which stores each
    rollout as an entry in a Python shelf db file. The latter mode is more
    robust to memory problems or crashes part-way through the rollout
    generation. Each rollout is stored with a key based on the episode number
    (0-indexed), and the number of episodes is stored with the key
    "num_episodes", so to load the shelf file, use something like::

        with shelve.open('rollouts.pkl') as rollouts:
            for episode_index in range(rollouts["num_episodes"]):
                rollout = rollouts[str(episode_index)]

    If ``outfile`` is None, no rollout data is written (the step and episode
    counters are still maintained).

    Args:
        outfile: Destination path, or None to disable saving.
        use_shelve: Store each rollout in a shelf db as soon as it ends,
            instead of pickling all rollouts together on exit.
        write_update_file: Keep a ``__progress_<outfile name>`` file next to
            ``outfile`` holding a one-line progress summary; it is deleted
            on exit.
        target_steps: Step budget, used only in the progress message.
        target_episodes: Episode budget, used only in the progress message.
        save_info: Also record the env ``info`` dict for each step.
    """

    def __init__(
        self,
        outfile=None,
        use_shelve=False,
        write_update_file=False,
        target_steps=None,
        target_episodes=None,
        save_info=False,
    ):
        self._outfile = outfile
        self._update_file = None
        self._use_shelve = use_shelve
        self._write_update_file = write_update_file
        self._shelf = None
        self._num_episodes = 0
        self._rollouts = []
        self._current_rollout = []
        self._total_steps = 0
        self._target_episodes = target_episodes
        self._target_steps = target_steps
        self._save_info = save_info

    def _get_tmp_progress_filename(self):
        """Return the progress-file path: ``__progress_<name>`` beside outfile."""
        outpath = Path(self._outfile)
        return outpath.parent / ("__progress_" + outpath.name)

    @property
    def outfile(self):
        return self._outfile

    def __enter__(self):
        if self._outfile:
            if self._use_shelve:
                # Open a shelf file to store each rollout as it comes in.
                self._shelf = shelve.open(self._outfile)
            else:
                # Original behaviour - keep all rollouts in memory and save
                # them all at the end. But check we can actually write to the
                # outfile before going through the effort of generating the
                # rollouts.
                try:
                    with open(self._outfile, "wb"):
                        pass
                except IOError:
                    print(
                        "Can not open {} for writing - cancelling rollouts.".format(
                            self._outfile
                        )
                    )
                    raise
            if self._write_update_file:
                # Open a file to track rollout progress.
                self._update_file = self._get_tmp_progress_filename().open(mode="w")
        return self

    def __exit__(self, type, value, traceback):
        # Compare against None: a Shelf is a mapping, so an empty shelf (no
        # rollout stored yet) is falsy and would otherwise never be closed
        # nor receive its "num_episodes" entry.
        if self._shelf is not None:
            self._shelf["num_episodes"] = self._num_episodes
            self._shelf.close()
            self._shelf = None
        elif self._outfile and not self._use_shelve:
            # Dump everything as one big pickle; `with` guarantees the handle
            # is flushed and closed.
            with open(self._outfile, "wb") as out:
                cloudpickle.dump(self._rollouts, out)
        if self._update_file is not None:
            # Close the handle before deleting the temp progress file; some
            # platforms (e.g. Windows) refuse to delete an open file.
            self._update_file.close()
            self._get_tmp_progress_filename().unlink()
            self._update_file = None

    def _get_progress(self):
        """Return a one-line, human-readable progress summary."""
        if self._target_episodes:
            return "{} / {} episodes completed".format(
                self._num_episodes, self._target_episodes
            )
        elif self._target_steps:
            return "{} / {} steps completed".format(
                self._total_steps, self._target_steps
            )
        else:
            return "{} episodes completed".format(self._num_episodes)

    def begin_rollout(self):
        """Start collecting steps for a new episode."""
        self._current_rollout = []

    def end_rollout(self):
        """Store the current episode and refresh the progress file."""
        if self._outfile:
            if self._use_shelve:
                # Save this episode as a new entry in the shelf database,
                # using the episode number as the key.
                self._shelf[str(self._num_episodes)] = self._current_rollout
            else:
                # Append this rollout to our list, to save later.
                self._rollouts.append(self._current_rollout)
        self._num_episodes += 1
        if self._update_file is not None:
            # Overwrite the single progress line in place; truncate so no
            # stale tail of a previous, longer line can remain.
            self._update_file.seek(0)
            self._update_file.write(self._get_progress() + "\n")
            self._update_file.truncate()
            self._update_file.flush()

    def append_step(self, obs, action, next_obs, reward, done, info):
        """Add a step to the current rollout, if we are saving them."""
        if self._outfile:
            if self._save_info:
                self._current_rollout.append(
                    [obs, action, next_obs, reward, done, info]
                )
            else:
                self._current_rollout.append([obs, action, next_obs, reward, done])
        self._total_steps += 1


class DefaultMapping(collections.defaultdict):
    """A ``defaultdict`` whose ``default_factory`` receives the missing key.

    ``collections.defaultdict`` calls its factory with no arguments; here a
    missing key is filled with ``default_factory(key)``. That value is stored
    under ``key`` and then returned. As with ``defaultdict``, looking up a
    missing key raises ``KeyError`` when ``default_factory`` is None.
    """

    def __missing__(self, key):
        if self.default_factory is None:
            # Mirror defaultdict: without a factory a missing key is a plain
            # KeyError, not "'NoneType' object is not callable".
            raise KeyError(key)
        self[key] = value = self.default_factory(key)
        return value


def default_policy_agent_mapping(unused_agent_id, *unused_args, **unused_kwargs):
    """Map every agent to the default policy.

    Extra positional and keyword arguments are accepted and ignored. This lets
    the function be called like a multi-agent ``policy_mapping_fn`` that
    receives more than the agent id (e.g. ``mapping(agent_id, episode)``).
    """
    return DEFAULT_POLICY_ID


def keep_going(steps, num_steps, episodes, num_episodes):
    """Report whether more rollout data should be collected.

    A budget of 0 or None means "unlimited". Collection stops as soon as
    either the episode budget or the step budget has been reached; the
    episode budget is checked first.
    """
    budget_exhausted = (num_episodes and episodes >= num_episodes) or (
        num_steps and steps >= num_steps
    )
    return not budget_exhausted


def _as_concept_tensor(values):
    """Return ``values`` as a float32 tensor reshaped to (1, 1, -1)."""
    return torch.tensor(values, dtype=torch.float32).reshape((1, 1, -1))


def rollout___(
    agent,
    env_name,
    num_steps,
    num_episodes=0,
    saver=None,
    no_render=True,
    video_dir=None,
    multiagent=False,
    args=None,
):
    """Roll out a trained agent locally, optionally intervening on its concepts.

    Episodes run on the agent's local-worker env until ``keep_going`` reports
    that ``num_steps`` / ``num_episodes`` has been reached (0 or None means
    unlimited). Every step is passed to ``saver``. The env ``info`` dict is
    augmented with ``info["agent_infos"][agent_id]["predicted_concepts"]``,
    and with ``info["render"]`` when rendering.

    Args:
        agent: Trainer exposing ``workers`` (a ``WorkerSet``), ``config`` and
            the project-specific ``compute_single_action_`` method.
        env_name: Not used by this function; kept for interface compatibility.
        num_steps: Step budget.
        num_episodes: Episode budget.
        saver: ``RolloutSaver`` receiving the steps; a no-op saver if None.
        no_render: Skip ``env.render()`` when True.
        video_dir: If set, wrap the env in ``gym_wrappers.Monitor`` so that
            every episode is recorded into this directory.
        multiagent: Overwritten by whether the local env is a
            ``MultiAgentEnv``.
        args: Parsed CLI namespace; ``correct_concepts``, ``concept_update``
            and ``replacement`` are read from it. If None, no concept
            intervention is applied.

    Raises:
        AssertionError: If ``agent`` has no ``WorkerSet``.
        ValueError: If ``env.render()`` returns an unsupported value.
    """
    simple_image_viewer = None
    if saver is None:
        saver = RolloutSaver()

    # Run evaluation on the local worker: its env and policies are directly
    # accessible, unlike those living on remote workers.
    workers = getattr(agent, "workers", None)
    if not isinstance(workers, WorkerSet):
        # Explicit raise (not `assert`) so the check survives `python -O`.
        raise AssertionError("Agent has no workers!")
    local_worker = workers.local_worker()
    env = local_worker.env
    multiagent = isinstance(env, MultiAgentEnv)
    if local_worker.multiagent:
        policy_agent_mapping = agent.config["multiagent"]["policy_mapping_fn"]
    else:
        # The mapping is called below as `mapping(agent_id, None)`; adapt the
        # single-argument default so the single-agent path does not raise
        # TypeError.
        def policy_agent_mapping(agent_id, *_unused):
            return default_policy_agent_mapping(agent_id)

    policy_map = local_worker.policy_map
    state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
    use_lstm = {p: len(s) > 0 for p, s in state_init.items()}
    action_init = {
        p: flatten_to_single_ndarray(m.action_space.sample())
        for p, m in policy_map.items()
    }

    # Concept-intervention settings are fixed for the whole rollout.
    correct_concepts = getattr(args, "correct_concepts", False)
    concept_values = getattr(args, "concept_update", None)
    replacement = getattr(args, "replacement", [])
    do_update = replacement is not None and len(replacement) > 0

    # If monitoring has been requested, manually wrap our environment with a
    # gym monitor, which is set to record every episode.
    if video_dir:
        env = gym_wrappers.Monitor(
            env=env,
            directory=video_dir,
            video_callable=lambda _: True,
            force=True,
            multiagent=multiagent,
        )

    steps = 0
    episodes = 0
    while keep_going(steps, num_steps, episodes, num_episodes):
        mapping_cache = {}  # in case policy_agent_mapping is stochastic
        saver.begin_rollout()
        obs = env.reset()
        agent_states = DefaultMapping(
            lambda agent_id: state_init[mapping_cache[agent_id]]
        )
        prev_actions = DefaultMapping(
            lambda agent_id: action_init[mapping_cache[agent_id]]
        )
        prev_rewards = collections.defaultdict(lambda: 0.0)
        done = False
        reward_total = 0.0
        while not done and keep_going(steps, num_steps, episodes, num_episodes):
            multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs}
            # NOTE(review): `env.get_concepts()` is project-specific; it is
            # indexed by agent id below — confirm its exact structure.
            correct_concept_list = env.get_concepts() if correct_concepts else None
            action_dict = {}
            agent_infos = {}
            for agent_id, a_obs in multi_obs.items():
                if a_obs is None:
                    continue
                policy_id = mapping_cache.setdefault(
                    agent_id, policy_agent_mapping(agent_id, None)
                )

                if correct_concepts:
                    concept_update = _as_concept_tensor(correct_concept_list[agent_id])
                    logger.debug("correcting concept")
                elif concept_values is not None:
                    concept_update = _as_concept_tensor(concept_values)
                    logger.debug("injecting false concept")
                else:
                    concept_update = None
                    logger.debug("concept update is none")

                if use_lstm[policy_id]:
                    a_action, p_state, policy_info = agent.compute_single_action_(
                        a_obs,
                        state=agent_states[agent_id],
                        prev_action=prev_actions[agent_id],
                        prev_reward=prev_rewards[agent_id],
                        policy_id=policy_id,
                        concept_update=concept_update,
                        do_update=do_update,
                        concepts_to_update=replacement,
                    )
                    agent_states[agent_id] = p_state
                    logger.debug("obs: %s", a_obs)
                    softmax_concepts = policy_info["concepts_after_softmax"]
                else:
                    # Without a recurrent state the return value is treated as
                    # the bare action, so no policy extras (and no predicted
                    # concepts) are available. Previously this branch read an
                    # unassigned/stale `info` and crashed.
                    # NOTE(review): confirm compute_single_action_'s return
                    # shape for non-recurrent policies.
                    a_action = agent.compute_single_action_(
                        a_obs,
                        prev_action=prev_actions[agent_id],
                        prev_reward=prev_rewards[agent_id],
                        policy_id=policy_id,
                        concept_update=concept_update,
                        do_update=do_update,
                        concepts_to_update=replacement,
                    )
                    softmax_concepts = None
                a_action = flatten_to_single_ndarray(a_action)
                action_dict[agent_id] = a_action
                prev_actions[agent_id] = a_action
                agent_infos[agent_id] = {"predicted_concepts": softmax_concepts}

            action = action_dict if multiagent else action_dict[_DUMMY_AGENT_ID]
            env.return_all = True
            next_obs, reward, done, info = env.step(action)
            info["agent_infos"] = agent_infos

            if multiagent:
                for agent_id, r in reward.items():
                    prev_rewards[agent_id] = r
                reward_total += sum(r for r in reward.values() if r is not None)
            else:
                prev_rewards[_DUMMY_AGENT_ID] = reward
                reward_total += reward

            if not no_render:
                rendered = env.render()
                if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3:
                    # ImageViewer not defined yet, try to create one.
                    if simple_image_viewer is None:
                        try:
                            from gym.envs.classic_control.rendering import (
                                SimpleImageViewer,
                            )

                            simple_image_viewer = SimpleImageViewer()
                        except (ImportError, ModuleNotFoundError):
                            # Stop rendering for the rest of the rollout, so
                            # the import is not retried (and warned about)
                            # on every step.
                            no_render = True
                            logger.warning(
                                "Could not import gym.envs.classic_control."
                                "rendering! Try `pip install gym[all]`."
                            )
                    if simple_image_viewer is not None:
                        simple_image_viewer.imshow(rendered)
                # Any other ndarray is unsupported; test it before the `in`
                # check, which would be ambiguous for arrays.
                elif isinstance(rendered, np.ndarray) or rendered not in [
                    True,
                    False,
                    None,
                ]:
                    raise ValueError(
                        "The env's ({}) `render()` method returned an"
                        " unsupported value! Make sure you either return a "
                        "uint8/w x h x 3 (RGB) image or handle rendering in a "
                        "window and then return `True`.".format(env)
                    )
                info["render"] = rendered
            saver.append_step(obs, action, next_obs, reward, done, info)
            steps += 1
            obs = next_obs
        saver.end_rollout()
        print("Episode #{}: reward: {}".format(episodes, reward_total))
        if done:
            episodes += 1
