import datetime
import logging
import os
import time

import torch
import minerl
import numpy as np
from tqdm import tqdm
from torch.utils.data._utils.collate import default_collate

from train.common.recorder import VideoRecorder
from train.common.utils import import_module
from train.monitor.consts import UNINIT_EXPECTED_MEAN_REWARD
from train.proc.proc_base import Subprocess, repeated_run
from train.stats import Stats

from train.task import TaskType

# No-op attribute access that references ``minerl`` so the import above is not
# reported or stripped as unused. NOTE(review): presumably the import is kept
# for its side effects (e.g. registering MineRL environments) — confirm.
minerl.__class__


def write_imitation_debug(info, action):
    """Emit three DEBUG log lines describing an imitation step.

    Logs the inventory, the crafting/placing/equipping entries of ``action``
    and the item currently held in the main hand.

    :param info: per-environment info sequence from ``env.step``; only
        ``info[0]['meta_controller']['state']`` is read.
    :param action: action dict containing the keys ``craft``, ``nearbyCraft``,
        ``nearbySmelt``, ``place`` and ``equip``.
    """
    meta_state = info[0]['meta_controller']['state']
    logging.debug('Proc: {}'.format(meta_state['inventory']))

    logged_keys = ('craft', 'nearbyCraft', 'nearbySmelt', 'place', 'equip')
    template = 'Proc: craft: {}, nearbyCraft: {}, nearbySmelt: {}, place: {}, equip: {}'
    logging.debug(template.format(*(action[key] for key in logged_keys)))

    mainhand_item = meta_state['equipped_items']['mainhand']
    logging.debug('Proc: mainhand: {}'.format(mainhand_item))


class EvaluateProc(Subprocess):
    """Subprocess that rolls out ``task_history`` in the environment for evaluation.

    Every trial resets the environment and executes the tasks in order:
    ``TaskType.Learning`` tasks query their model for one action per loop
    iteration, ``TaskType.Imitation`` tasks replay the action sequence returned
    by ``task.get_imitation_actions``. Rewards, frames, actions and inventory
    gains are forwarded to a ``VideoRecorder`` and the global ``Stats``
    singleton; depending on ``config.client_env_option`` episode experiences
    are also saved to ``config.monitor.experiences_path``.
    """

    # A single step reward at or above this value marks the episode as "diamond found".
    DIAMOND_REWARD = 1024

    def __init__(self, config, env, task_history, device):
        super(EvaluateProc, self).__init__('Evaluate', config)
        self.task_history = task_history
        self.env = env
        self.device = device
        self.option = self.config.client_env_option
        # "all" enables both experience recording and replay.
        self.record_episodes = self.option in ("all", "record")
        self.uses_replay = self.option in ("all", "replay")

        dataset = import_module(self.config.rl.dataset)
        self.action_space = dataset.ACTION_SPACE

        self.total_reward = UNINIT_EXPECTED_MEAN_REWARD
        self.mean_reward = UNINIT_EXPECTED_MEAN_REWARD
        self.std_reward = UNINIT_EXPECTED_MEAN_REWARD
        self.min_reward = UNINIT_EXPECTED_MEAN_REWARD
        self.max_reward = UNINIT_EXPECTED_MEAN_REWARD
        self.experience_list = []

        # Inverted ``consensus_code`` mapping; looked up with a task's ``target``
        # and joined into progress / milestone strings.
        self.task_mapping = {v: k for k, v in self.config.subtask.consensus_code.items()}
        # Same content, used for experience checkpoint/directory names (kept as its own dict).
        self.encoding = dict(self.task_mapping)
        self.recorder = VideoRecorder(self.config, rec_save_dir=self.config.monitor.record_path,
                                      recording=self.config.debug.enabled)

        self.trials = self.config.monitor.max_evaluation_episodes
        self.max_imitation_fail_retries = self.config.monitor.max_imitation_fail_retries

        # Step cap only applies in debug mode with a positive configured limit.
        self.max_eval_steps = np.inf
        if self.config.debug.enabled and self.config.monitor.max_eval_steps > 0:
            self.max_eval_steps = self.config.monitor.max_eval_steps

    def write_debug(self, step):
        """Log progress every ``config.monitor.eval_step_debug_log`` loop steps."""
        if step % self.config.monitor.eval_step_debug_log == 0:
            logging.info('Proc: Passed step: {}'.format(step))

    def record_video_stats(self, info, cum_reward, action, step, task, value, stats, resources):
        """Append one frame plus its annotations to the recorder and update global stats.

        :param info: per-environment info sequence from ``env.step``.
        :param cum_reward: cumulative episode reward so far.
        :param action: action shown for this frame.
        :param step: rollout loop counter.
        :param task: current task (its ``id`` is recorded).
        :param value: value estimate to display (0 for imitation steps).
        :param stats: statistics dict; a ``'milestones'`` entry is (re)written in place.
        :param resources: per-item inventory gain added to ``Stats``.
        """
        self.recorder.append_frame(info[0]['meta_controller']['state']['pov'])
        self.recorder.append_reward(cum_reward)
        self.recorder.append_action(action)
        self.recorder.append_info('step: {}'.format(step))
        self.recorder.append_task('task: {}'.format(task.id))
        self.recorder.append_value(value)
        self.recorder.append_inventory({
            'inventory': info[0]['meta_controller']['state']['inventory'],
            'equipped_items': info[0]['meta_controller']['state']['equipped_items']
        })

        Stats.get_instance().add_resources(resources)
        g_stats = Stats.get_instance().copy()
        # milestone counts normalised by the number of runs (+1 for the current run)
        stats['milestones'] = {}
        for k, v in g_stats['milestones'].items():
            stats['milestones'][k] = v / (g_stats['runs']['count'] + 1)

        self.recorder.append_statistics(stats)

    @staticmethod
    def _inventory_gain(before_meta, info):
        """Return the per-item inventory increase since ``before_meta`` (decreases clamp to 0).

        Keys come from ``before_meta['state']['inventory']``; they are compared
        with ``info[0]['meta_controller']['state']['inventory']``.
        """
        before = before_meta['state']['inventory']
        after = info[0]['meta_controller']['state']['inventory']
        return {k: max(after[k] - before[k], 0) for k in before.keys()}

    @staticmethod
    def _as_success_flag(result):
        """Reduce a ``task.success`` result to its boolean success flag.

        The three-argument call in ``run`` returns ``(success, inc_steps)``. If the
        two-argument form returns the same pair, testing it directly would always
        be truthy, so a tuple is reduced to its first element; anything else is
        returned unchanged.
        """
        if isinstance(result, tuple):
            return result[0]
        return result

    @repeated_run
    def run(self):
        """Run ``self.trials`` evaluation episodes and publish reward statistics.

        Per episode, tasks are executed starting at milestone 0 (or at
        ``len(config.replay.replay_until)`` when replay is enabled) until the
        environment reports done, every task has been processed, or
        ``max_eval_steps`` loop steps have elapsed. Failed learning tasks are
        retried up to ``median_task_steps * max_steps_factor`` times before
        being skipped; failed imitation tasks rewind the rollout to an earlier
        task. Finally total/mean/std/min/max of the episode rewards are stored
        on ``self``, reported to ``Stats`` and logged.
        """
        # run multiple evaluation trials
        trial_rewards = []
        logging.info('Proc: Starting evaluation...')
        for i in tqdm(range(1, self.trials + 1), desc='TestEnv'):
            # reset agent environment
            state = self.env.reset()
            logging.info('Proc: Eval run {}'.format(i))
            # iterate until done or max steps reached
            cum_reward = 0
            milestone = 0
            step = 0
            repeat = 0
            done = (False, )

            sys_time = None
            found_diamond = False
            # send noops to wait for action completion
            noop = self.env.action_space.no_op()

            # get experiences from env
            consensus_dir = ''
            experience = None
            if self.record_episodes:
                experience = self.env.env_method('get_experience')[0]

            # skip first step to create valid meta controller state
            state, reward, done, info = self.env.step((noop,))
            if reward[0] >= self.DIAMOND_REWARD:
                found_diamond = True
            old_meta = None
            new_meta = info[0]['meta_controller']

            # if replay until is used, then skip directly to the newest milestone
            if self.uses_replay:
                milestone = len(self.config.replay.replay_until)

            # rollout task execution
            while not done[0] and milestone < len(self.task_history) and step < self.max_eval_steps:
                # get current task
                task = self.task_history[milestone]

                # progress strings shown alongside the recorded frames
                progress_stats = {'progress': {
                    'completed': "".join(
                        [self.task_mapping[t.target] for t in self.task_history[:milestone]]),
                    'pending': "".join(
                        [self.task_mapping[t.target] for t in self.task_history[milestone:]]),
                    'current': self.task_mapping[task.target],
                }}

                # handle task type
                if task.task_type == TaskType.Learning:
                    # update old meta state
                    old_meta = new_meta
                    # prepare data for model
                    input_dict = default_collate(state)
                    for k in input_dict.keys():
                        input_dict[k] = input_dict[k].to(self.device)
                    with torch.no_grad():
                        model = task.get_model().to(self.device)
                        model.eval()
                        # predict next action
                        out_dict = model.forward(input_dict)
                        action = self.action_space.logits_to_dict(self.env.action_space.no_op(), out_dict)[0]
                    # take env step
                    state, reward, done, info = self.env.step((action,))
                    if reward[0] >= self.DIAMOND_REWARD:
                        found_diamond = True
                    cum_reward += reward[0]
                    self.record_video_stats(info, cum_reward, action, step, task, out_dict['value'].item(),
                                            progress_stats, self._inventory_gain(old_meta, info))
                elif task.task_type == TaskType.Imitation:
                    # get the sequence of imitation actions
                    action_sequence = task.get_imitation_actions(new_meta['state'])
                    # update old meta state
                    old_meta = new_meta
                    if len(action_sequence) <= 0:
                        # Nothing to imitate: move on to the next task right away. Falling through
                        # to the success check below would test the *next* task's sub-sequence
                        # against an unchanged meta state and then advance again or rewind.
                        milestone += 1
                        logging.warning('Proc: No actions to perform. Skipping')
                        repeat = 0
                        self.write_debug(step)
                        step += 1
                        continue

                    for action in action_sequence:
                        # take env step
                        state, reward, done, info = self.env.step((action,))
                        if reward[0] >= self.DIAMOND_REWARD:
                            found_diamond = True
                        cum_reward += reward[0]
                        self.record_video_stats(info, cum_reward, action, step, task, 0, progress_stats,
                                                self._inventory_gain(old_meta, info))
                        # wait for the action to complete
                        for _ in range(self.config.subtask.num_imitation_noop):
                            noop_before_meta = info[0]['meta_controller']
                            state, reward, done, info = self.env.step((noop,))
                            if reward[0] >= self.DIAMOND_REWARD:
                                found_diamond = True
                            cum_reward += reward[0]
                            # frames from the wait no-ops are labelled with the imitation action
                            self.record_video_stats(info, cum_reward, action, step, task, 0, progress_stats,
                                                    self._inventory_gain(noop_before_meta, info))
                        # debug imitation entries
                        write_imitation_debug(info, action)
                        # update meta information
                        new_meta = info[0]['meta_controller']
                        # check if success through intermediate steps was reached
                        if self._as_success_flag(task.success(old_meta, new_meta)):
                            task.imitation_fail_cnt = 0
                            logging.info('Proc: Early imitation success. Skip remaining actions.')
                            break
                else:
                    raise NotImplementedError('No such task type defined: {}'.format(task.task_type))

                # update meta information
                new_meta = info[0]['meta_controller']

                # check if task has completed
                success, _inc_steps = task.success(
                    old_meta, new_meta,
                    [t for t in self.task_history[milestone:] if t.target == task.target])
                if success or found_diamond:
                    if task.task_type == TaskType.Imitation:
                        task.imitation_fail_cnt = 0
                    # go to next milestone
                    milestone += 1
                    Stats.get_instance().incr_milestone("".join(
                        [self.task_mapping[t.target] for t in self.task_history[:milestone]]))
                    logging.info('Proc: Eval milestone {} complete! Progress to next task...'.format(task.id))

                    # set checkpoint for experiences trajectory recording
                    if self.record_episodes:
                        sys_time = datetime.datetime.fromtimestamp(time.time()).strftime('%Y%m%d-%H%M%S')
                        consensus_dir = ''.join([self.encoding[t.target] for t in self.task_history[:milestone]])
                        experience.set_checkpoint(checkpoint='{}.{}'.format(sys_time, consensus_dir))

                    # notify transition and handle task or memory specific cleanup
                    if milestone < len(self.task_history):
                        task.notify_transition(self.task_history[milestone])
                    repeat = 0
                else:
                    repeat += 1
                    # rewind rollout back to last task
                    if task.task_type == TaskType.Learning:
                        # give up after a budget proportional to the task's median step count
                        if repeat > task.median_task_steps * self.config.monitor.max_steps_factor:
                            logging.warning('Proc: Maximum agent step trials reached. Skipping learning task...')
                            milestone += 1
                            repeat = 0
                    elif task.task_type == TaskType.Imitation:
                        logging.warning('Proc: Failed to imitate sequence. Rollout rewinds back to initial task.')
                        rewind_idx = milestone
                        # move back or stay at the first task
                        if rewind_idx > 0:
                            prev_milestone = milestone - 1
                            # go back until previous learning statement appears
                            while prev_milestone > 0 \
                                    and self.task_history[prev_milestone].task_type == TaskType.Imitation:
                                prev_milestone -= 1
                            rewind_idx = prev_milestone
                            # move back in time until beginning of the sub-sequence
                            while rewind_idx > 0 and \
                                    self.task_history[rewind_idx].target == self.task_history[prev_milestone].target:
                                rewind_idx -= 1
                            # move forward to the first entry of the next learning
                            # statement at the start of the sub-sequence
                            if self.task_history[rewind_idx].id != self.task_history[prev_milestone].id:
                                rewind_idx += 1
                            milestone = rewind_idx
                            task.imitation_fail_cnt += 1
                            task.imitation_ready = False
                        logging.warning('Proc: Rewinding task back to index: {}'.format(milestone))

                self.write_debug(step)
                step += 1

            # store experience if successful
            if self.record_episodes and ((milestone > 0 and consensus_dir != '') or found_diamond):
                # do two noops to wait for resources to be collected
                self.env.step((noop,))
                self.env.step((noop,))
                dir_name = os.path.join(self.config.monitor.experiences_path, consensus_dir)
                if not os.path.exists(dir_name):
                    os.makedirs(dir_name)
                file_name = os.path.join(dir_name, '{}_experience_recorder-env_seed_{}-timestamp_{}.p'
                                         .format(self.config.experiment_name, experience.meta_info, sys_time))
                experience.save(file_name)

            # update global stats
            Stats.get_instance().incr_runs_count()

            # write recording
            self.recorder.write('{}_rew-{}_milestones-[{}]'
                                .format(self.config.experiment_name,
                                        cum_reward,
                                        "".join([self.task_mapping[t.target] for t in self.task_history[:milestone]])),
                                sys_time=sys_time)
            self.recorder.reset()

            trial_rewards.append(cum_reward)

        # collect metrics
        logging.info('Proc: Finished evaluation. Writing stats...')
        if not trial_rewards:
            # min()/max() would raise on an empty list (e.g. max_evaluation_episodes == 0)
            logging.warning('Proc: No evaluation trials were run. Reward stats are left unchanged.')
            self._complete()
            return
        self.total_reward = sum(trial_rewards)
        self.mean_reward = np.mean(trial_rewards)
        self.std_reward = np.std(trial_rewards)
        self.min_reward = min(trial_rewards)
        self.max_reward = max(trial_rewards)
        Stats.get_instance().summarize_reward({
            'total_reward': self.total_reward,
            'mean_reward': self.mean_reward,
            'std_reward': self.std_reward,
            'min_reward': self.min_reward,
            'max_reward': self.max_reward
        })
        logging.info('Stats:\nTotal reward: {}\nMean reward: {}\nStd. reward: {}\nMin. reward: {}\nMax. reward: {}'
                     .format(self.total_reward,
                             self.mean_reward,
                             self.std_reward,
                             self.min_reward,
                             self.max_reward))

        self._complete()
