import os
from collections import deque, defaultdict

import numpy as np
import torch
from baselines.common.running_mean_std import RunningMeanStd

from plr import LevelSampler
from util import flatten_suffix_dict, is_discrete_actions, get_obs_at_index, set_obs_at_index

# import os
import matplotlib as mpl
# mpl.use("macOSX")
import matplotlib.pyplot as plt


class Runner:
    """
    Performs rollouts of an agent in a vectorized env.

    Optionally drives Prioritized Level Replay (PLR) through a
    ``LevelSampler`` and maintains a count-based belief model over the env's
    belief tokens.
    """
    def __init__(
        self,
        args,
        venv,
        agent,
        train=False,
        plr_args=None,
        device='cpu'):
        """
        args: Run configuration (reads num_steps, num_processes, env_name,
            use_plr, force_obl_correction, use_learned_beliefs, ...).
        venv: Vectorized, adversarial gym env with agent-specific wrappers.
        agent: Protagonist trainer.
        train: If True, start in training mode; otherwise in eval mode.
        plr_args: Optional kwargs for ``LevelSampler``; falsy disables PLR.
        device: Stored as ``self.device``.
        """
        self.args = args
        self.venv = venv
        self.is_discrete_actions = is_discrete_actions(self.venv)
        self.agent = agent
        self.agent_rollout_steps = args.num_steps
        self.device = device

        if train:
            self.train()
        else:
            self.eval()

        # Set up batch stats
        self._update_batch_episodic_counts()

        # Set up PLR
        self.level_sampler = None
        if plr_args:
            self.level_sampler = LevelSampler(**plr_args)

        # Learned beliefs are only used when OBL correction is also forced.
        self.use_learned_beliefs = \
            args.force_obl_correction and args.use_learned_beliefs

        # Initialize the belief model from the env's belief-token spec.
        self.belief_spec = self.venv.get_belief_spec()
        self._init_belief_model(spec=self.belief_spec)

        self.reset()

    def reset(self):
        """
        Reset counters and the return history, reset the envs, and copy the
        initial observations into index 0 of the agent's rollout storage.

        When a level sampler exists, also records each process's current
        level seed (queried from the env) in ``self.level_seeds``, an int32
        tensor of shape (num_processes, 1); otherwise it is None.
        """
        args = self.args

        self.num_updates = 0
        self.total_episodes_collected = 0
        self.total_seeds_collected = 0
        self.student_grad_updates = 0

        max_return_queue_size = 10
        self.agent_returns = deque(maxlen=max_return_queue_size)

        obs = self.venv.reset()
        self.agent.storage.copy_obs_to_index(obs, 0)

        self.level_seeds = None
        if self.level_sampler:
            self.level_seeds = torch.zeros(args.num_processes, 1, dtype=torch.int32)
            for i in range(args.num_processes):
                self.level_seeds[i] = torch.tensor(self.venv.level_seed(i))

        self.belief_token_counts = None

    def train(self):
        """Put the runner and its agent in training mode."""
        self.is_training = True
        self.agent.train()

    def eval(self):
        """Put the runner and its agent in eval mode."""
        self.is_training = False
        self.agent.eval()

    def state_dict(self):
        """
        Return a checkpointable dict of the runner's state.

        Contains the actor-critic and optimizer state dicts, counters, the
        recent return history, and references (not copies) to the level
        sampler, batch episodic counts and belief model.
        """
        return {
            'agent_state_dict': self.agent.algo.actor_critic.state_dict(),
            'optimizer_state_dict': self.agent.algo.optimizer.state_dict(),
            'agent_returns': self.agent_returns,
            'num_updates': self.num_updates,
            'total_episodes_collected': self.total_episodes_collected,
            'total_seeds_collected': self.total_seeds_collected,
            'student_grad_updates': self.student_grad_updates,
            'level_sampler': self.level_sampler,
            'batch_episodic_counts': self.batch_episodic_counts,
            'belief_model': self.belief_model
        }

    def load_state_dict(self, state_dict):
        """
        Restore state produced by ``state_dict``.

        Keys missing from ``state_dict`` are restored as None (``dict.get``).
        """
        agent_state_dict = state_dict.get('agent_state_dict')
        self.agent.algo.actor_critic.load_state_dict(agent_state_dict)

        optimizer_state_dict = state_dict.get('optimizer_state_dict')
        self.agent.algo.optimizer.load_state_dict(optimizer_state_dict)

        self.agent_returns = state_dict.get('agent_returns')
        self.num_updates = state_dict.get('num_updates')
        self.total_episodes_collected = state_dict.get('total_episodes_collected')
        self.total_seeds_collected = state_dict.get('total_seeds_collected')
        self.student_grad_updates = state_dict.get('student_grad_updates')

        self.level_sampler = state_dict.get('level_sampler')

        self.batch_episodic_counts = state_dict.get('batch_episodic_counts')

        self.belief_model = state_dict.get('belief_model')

    def _get_rollout_return_stats(self, rollout_returns):
        """
        Summarize per-process episode returns from one rollout.

        Args:
            rollout_returns: One list of episode returns per process.

        Returns:
            Dict with ``mean_return`` and ``max_return`` (float tensors of
            shape (num_processes, 1); 0 for processes with no finished
            episode) and the raw ``returns`` lists.
        """
        # @todo: need to record agent curricula-specific env metrics:
        # - shortest path length
        # - num blocks
        # - passable ratio
        mean_return = torch.zeros(self.args.num_processes, 1)
        max_return = torch.zeros(self.args.num_processes, 1)
        for b, returns in enumerate(rollout_returns):
            if returns:
                mean_return[b] = float(np.mean(returns))
                max_return[b] = float(np.max(returns))

        return {
            'mean_return': mean_return,
            'max_return': max_return,
            'returns': rollout_returns
        }

    def _get_env_stats_multigrid(self, agent_info, log_replay_complexity):
        """
        Query the venv's level-structure metrics and flatten them into one
        dict (via ``flatten_suffix_dict`` with suffixes 'count',
        'passable_ratio', 'shortest_path_length', and none for the aux
        properties). ``agent_info`` and ``log_replay_complexity`` are unused.
        """
        clutter_count = self.venv.get_clutter_count()
        passable_ratio = self.venv.get_passable()
        shortest_path_lengths = self.venv.get_shortest_path_length()
        aux_properties = self.venv.get_aux_properties()

        stats = {}
        stats.update(flatten_suffix_dict(clutter_count, 'count'))
        stats.update(flatten_suffix_dict(passable_ratio, 'passable_ratio'))
        stats.update(flatten_suffix_dict(
            shortest_path_lengths, 'shortest_path_length'))
        stats.update(flatten_suffix_dict(aux_properties, None))

        return stats

    def _get_env_stats_minihack(self, agent_info, log_replay_complexity=False):
        """MiniHack uses the same venv metric queries as MultiGrid."""
        return self._get_env_stats_multigrid(agent_info, log_replay_complexity)

    def _get_env_stats(self, agent_info, log_replay_complexity=False):
        """
        Return env metrics prefixed for logging.

        Each metric ``k`` yields both ``'env_metrics/plr_' + k`` and
        ``'env_metrics/' + k``; the first holds the value when
        ``log_replay_complexity`` is truthy, the second otherwise, and the
        other is None. Environments other than MultiGrid*/MiniHack* have no
        env metrics, so an empty dict is returned for them.
        """
        env_name = self.args.env_name
        if env_name.startswith('MultiGrid'):
            stats = self._get_env_stats_multigrid(agent_info, log_replay_complexity)
        elif env_name.startswith('MiniHack'):
            stats = self._get_env_stats_minihack(agent_info, log_replay_complexity)
        else:
            # Previously left `stats` unbound (UnboundLocalError) here.
            stats = {}

        stats_ = {}
        for k, v in stats.items():
            stats_['env_metrics/plr_' + k] = v if log_replay_complexity else None
            stats_['env_metrics/' + k] = v if not log_replay_complexity else None

        return stats_

    def _get_batch_episodic_stats(self):
        """
        Return ``{'train/<key>_mean': total / num_episodes}`` for every
        accumulated episodic count (excluding ``num_episodes`` itself).
        """
        total = self.batch_episodic_counts['num_episodes']
        return {
            f'train/{k}_mean': float(v) / total
            for k, v in self.batch_episodic_counts.items()
            if k != 'num_episodes'
        }

    def _sample_replay_decision(self):
        """Ask the level sampler whether to replay a level this round."""
        return self.level_sampler.sample_replay_decision()

    def _init_belief_model(self, spec):
        """
        Initialize ``self.belief_model`` from a belief-token spec.

        Each categorical token gets ``k`` (its size), ``counts`` (an all-ones
        int array, so later updates are add-one smoothed) and an initial
        ``p``: uniform over ``size`` outcomes, or the scalar 0.5 when
        ``size == 1``.

        Raises:
            ValueError: for any token type other than ``'categorical'``.
        """
        self.belief_model = {}
        for name, info in spec.items():
            if info['type'] != 'categorical':
                raise ValueError(f"Unsupported r.v., {info['type']}")
            size = info['size']
            self.belief_model[name] = {
                'k': size,
                # `np.int` was removed in NumPy 1.24; it was an alias of `int`.
                'counts': np.ones(size, dtype=int),
                'p': 0.5 if size == 1 else np.ones(size) / size,
            }

    def _update_belief_model(self, sample):
        """
        Accumulate a batch of belief tokens into the count model.

        Args:
            sample: Mapping of token name -> array that is summed over its
                leading axis (presumably the process axis — confirm against
                ``venv.get_belief_tokens``); the sum must broadcast onto that
                token's ``counts``.

        Returns:
            Dict of token name -> current ``p`` for every token in the model.

        Raises:
            ValueError: if ``sample`` holds a token whose spec type is not
                ``'categorical'``.
        """
        for name, tokens in sample.items():
            token_type = self.belief_spec[name]['type']
            if token_type != 'categorical':
                raise ValueError(f"Unsupported r.v., {token_type}")
            counts = self.belief_model[name]['counts']
            counts += tokens.sum(0)  # in place: updates the stored array
            # NOTE(review): for size-1 tokens this replaces the initial scalar
            # 0.5 with [1.0] after the first update — confirm intended.
            self.belief_model[name]['p'] = counts / counts.sum()

        return {name: info['p'] for name, info in self.belief_model.items()}

    def _update_batch_episodic_counts(self, count_dict=None):
        """
        Clear the per-batch episodic counts when ``count_dict`` is None;
        otherwise add its values and count one more episode.
        """
        if count_dict is None:  # clear counts
            self.batch_episodic_counts = defaultdict(int)
            self.batch_episodic_counts['num_episodes'] = 0
        else:
            for k, v in count_dict.items():
                self.batch_episodic_counts[k] += v
            self.batch_episodic_counts['num_episodes'] += 1

    def agent_rollout(self,
                      agent,
                      num_steps,
                      update=False,
                      batchwise_plr=False,
                      level_replay=False, # Only used for robust PLR
                      level_sampler=None,
                      discard_grad=False):
        """
        Collect ``num_steps`` steps per process into ``agent.storage`` and
        optionally update the agent.

        Args:
            agent: Trainer exposing ``storage``, ``act``, ``process_action``,
                ``insert``, ``get_value`` and ``update``.
            num_steps: Env steps per process.
            update: If True, compute returns, update the level sampler (if
                any) and call ``agent.update``.
            batchwise_plr: If True and ``level_sampler`` is given, reseed every
                process with a newly sampled level before the rollout, and
                sample replay/unseen levels (instead of ``sample()``) when an
                episode finishes.
            level_replay: With ``batchwise_plr``, sample replay levels rather
                than unseen ones.
            level_sampler: Optional ``LevelSampler``.
            discard_grad: Forwarded to ``agent.update``.

        Returns:
            Dict with ``mean_return``, ``max_return`` and ``returns``; when
            ``update`` is True also ``value_loss``, ``action_loss``,
            ``dist_entropy``, ``update_info`` and, if
            ``args.log_action_complexity``, ``action_complexity``.
        """
        args = self.args

        self._update_batch_episodic_counts()

        rollout_returns = [[] for _ in range(args.num_processes)]

        # If batchwise PLR, reset env to new samples. Requires a sampler:
        # without one there is nothing to sample from, and `self.level_seeds`
        # is only allocated when the runner owns a sampler.
        if batchwise_plr and level_sampler is not None:
            for i in range(args.num_processes):
                if level_replay:  # sample replay levels + update level seeds
                    level_seed = level_sampler.sample_replay_level()
                else:  # sample new levels + update level seeds
                    level_seed = level_sampler.sample_unseen_level()
                self.venv.seed(level_seed, i)
                self.level_seeds[i] = torch.tensor(level_seed)

            obs = self.venv.reset()
            self.agent.storage.copy_obs_to_index(obs, 0)

            if self.use_learned_beliefs and not level_replay:
                belief_tokens = self.venv.get_belief_tokens()
                belief_dist = self._update_belief_model(sample=belief_tokens)

                self.venv.set_belief_dist(belief_dist)

        for step in range(num_steps):
            if args.render:
                self.venv.render_to_screen()

            # Sample actions
            with torch.no_grad():
                obs_id = agent.storage.get_obs(step)
                value, action, action_log_dist, recurrent_hidden_states = agent.act(
                    obs_id, agent.storage.get_recurrent_hidden_state(step), agent.storage.masks[step])
                if self.is_discrete_actions:
                    # Pick the chosen action's entry out of the log-dist.
                    action_log_prob = action_log_dist.gather(-1, action)
                else:
                    action_log_prob = action_log_dist

            # Observe reward and next obs
            _action = agent.process_action(action.cpu())
            obs, reward, done, infos = self.venv.step(_action)

            if args.clip_reward:
                reward = torch.clamp(reward, -args.clip_reward, args.clip_reward)

            if step >= num_steps - 1:
                # Handle early termination due to cliffhanger rollout: mark
                # still-running episodes as truncated at the rollout boundary.
                # NOTE(review): 'truncated_obs' is only consumed below for
                # infos that also carry 'episode' — confirm cliffhanger infos
                # reach insert_truncated_obs.
                if agent.storage.use_proper_time_limits:
                    for i, done_ in enumerate(done):
                        if not done_:
                            infos[i]['cliffhanger'] = True
                            infos[i]['truncated'] = True
                            infos[i]['truncated_obs'] = get_obs_at_index(obs, i)

                # Force done for all processes on the last step.
                # `np.float` was removed in NumPy 1.24; it was an alias of `float`.
                done = np.ones_like(done, dtype=float)

            if level_sampler:
                next_level_seeds = self.level_seeds.clone().detach()

            for i, info in enumerate(infos):
                if 'episode' not in info:
                    continue

                rollout_returns[i].append(info['episode']['r'])

                if not args.use_plr:
                    self.total_seeds_collected += 1

                self.total_episodes_collected += 1

                if 'episodic_counts' in info:
                    self._update_batch_episodic_counts(info['episodic_counts'])

                # Handle early termination
                if agent.storage.use_proper_time_limits and 'truncated_obs' in info:
                    agent.storage.insert_truncated_obs(info['truncated_obs'], index=i)

                if level_sampler:
                    if batchwise_plr:
                        if level_replay:
                            level_seed = level_sampler.sample_replay_level()
                        else:
                            level_seed = level_sampler.sample_unseen_level()
                    else:
                        level_seed = level_sampler.sample()

                    obs_i = self.venv.seed(level_seed, i)
                    set_obs_at_index(obs, obs_i, i)
                    next_level_seeds[i] = level_seed  # done step matched to done level seed

            # If done then clean the history of observations.
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done])
            # 0 where the env (or the cliffhanger logic above) set 'truncated'.
            bad_masks = torch.FloatTensor(
                [[0.0] if 'truncated' in info else [1.0] for info in infos])

            # Inserted with the seeds of the levels that produced this step;
            # reseeded processes only switch to their new seed afterwards.
            agent.insert(
                obs, recurrent_hidden_states,
                action, action_log_prob, action_log_dist,
                value, reward, masks, bad_masks,
                level_seeds=self.level_seeds)

            if level_sampler:
                self.level_seeds = next_level_seeds

        rollout_info = self._get_rollout_return_stats(rollout_returns)

        # Update non-env agent if required
        if update:
            with torch.no_grad():
                obs_id = agent.storage.get_obs(-1)
                next_value = agent.get_value(
                    obs_id, agent.storage.get_recurrent_hidden_state(-1),
                    agent.storage.masks[-1]).detach()

            agent.storage.compute_returns(
                next_value, args.use_gae, args.gamma, args.gae_lambda)

            # Update level sampler and remove any ejected seeds level store
            if level_sampler:
                level_sampler.update_with_rollouts(agent.storage)

            value_loss, action_loss, dist_entropy, info = agent.update(discard_grad=discard_grad)

            if level_sampler:
                level_sampler.after_update()

            rollout_info.update({
                'value_loss': value_loss,
                'action_loss': action_loss,
                'dist_entropy': dist_entropy,
                'update_info': info,
            })

            # Compute LZ complexity of action trajectories
            if args.log_action_complexity:
                rollout_info.update({'action_complexity': agent.storage.get_action_complexity()})

        return rollout_info

    def run(self):
        """
        Run one rollout (with an agent update when training) and return a
        flat dict of logging stats.

        In eval mode no update happens, so the loss entries are None and the
        grad-norm / action-complexity entries are omitted.
        """
        args = self.args
        agent = self.agent

        level_replay = False
        if args.use_plr:
            level_replay = self._sample_replay_decision()

        # Discard student gradients if not level replay (sampling new levels)
        no_exploratory_grad_updates = \
            vars(args).get('no_exploratory_grad_updates', False)
        student_discard_grad = bool(
            args.use_plr and not level_replay and no_exploratory_grad_updates)

        if self.is_training and not student_discard_grad:
            self.student_grad_updates += 1

        # Agent rollout
        agent_info = self.agent_rollout(
            agent=agent,
            num_steps=self.agent_rollout_steps,
            update=self.is_training,
            batchwise_plr=no_exploratory_grad_updates,
            level_replay=level_replay,
            level_sampler=self.level_sampler,
            discard_grad=student_discard_grad)

        if self.is_training:
            self.num_updates += 1

        # === LOGGING ===
        # Env metrics land under 'env_metrics/plr_*' for replayed levels and
        # 'env_metrics/*' otherwise (the other set is None).
        stats = self._get_env_stats(agent_info, log_replay_complexity=level_replay)

        # Push into the bounded return history (maxlen set in reset()).
        self.agent_returns.extend(
            r for b in agent_info['returns'] for r in reversed(b))
        mean_agent_return = 0
        if len(self.agent_returns) > 0:
            mean_agent_return = np.mean(self.agent_returns)

        stats.update({
            'steps': self.num_updates * args.num_processes * args.num_steps,
            'total_episodes': self.total_episodes_collected,
            'total_seeds': self.total_seeds_collected,
            'total_student_grad_updates': self.student_grad_updates,

            'mean_agent_return': mean_agent_return,
            # Loss keys exist only when agent_rollout performed an update.
            'agent_value_loss': agent_info.get('value_loss'),
            'agent_pg_loss': agent_info.get('action_loss'),
            'agent_dist_entropy': agent_info.get('dist_entropy'),

            **self._get_batch_episodic_stats(),
        })

        if args.log_grad_norm and 'update_info' in agent_info:
            agent_grad_norm = np.mean(agent_info['update_info']['grad_norms'])
            stats.update({
                'agent_grad_norm': agent_grad_norm,
            })

        if args.log_action_complexity and 'action_complexity' in agent_info:
            stats.update({
                'agent_action_complexity': agent_info['action_complexity'],
            })

        return stats
