import os, minerl
import numpy as np
import imageio
import time
import torch
from PIL import Image, ImageDraw, ImageFont
from torch.utils.data._utils.collate import default_collate
from tqdm import tqdm
from train.behavioral_cloning.datasets.agent_state import AgentState
from train.behavioral_cloning.spaces.action_spaces import BINARY_ACTIONS, ENUM_ACTIONS, ActionSpace
from train.behavioral_cloning.spaces.input_spaces import InputSpace
from train.envs.minerl_env import make_minerl
from train.pytorch_wrapper.eval_hook import EvalHook
from train.pytorch_wrapper.utils import BColors
from train.behavioral_cloning.datasets.experience import Experience
import logging
logging.getLogger("PIL").setLevel(logging.INFO)
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt

# Touch `minerl` so the import is not reported as unused; presumably it is
# imported for its side effects (e.g. registering environments) — confirm.
minerl.__class__

# Evaluation runs on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Shared colour helper for console/log output.
bcols = BColors()


class EvalEnv(EvalHook):
    """Evaluation hook that rolls out a model in a (vectorized) environment.

    Each evaluation runs ``trials // n_workers`` rounds. In every round the env is
    reset once and stepped ``max_steps`` times with one action per worker, and the
    rewards are accumulated per worker. Optionally the amount of ``watch_item``
    gathered is tracked, and annotated MP4 recordings are written to
    ``recording_dir``.
    """

    def __init__(self, env, n_workers, trials: int = 10, max_steps: int = 1200, verbosity: int = 0,
                 watch_item: str = None, recording_dir: str = None):
        """
        :param env: environment whose ``reset()`` returns a list of per-worker states and whose
            ``step(actions)`` returns per-worker ``(state, reward, done, info)`` sequences.
            If ``None``, evaluation is skipped and all statistics are reported as 0.
        :param n_workers: number of parallel workers (episodes per round) in ``env``.
        :param trials: total number of episodes to evaluate; ``trials // n_workers`` rounds are run.
        :param max_steps: number of env steps per round (``done`` is not used to stop early).
        :param verbosity: log verbosity forwarded to ``EvalHook``.
        :param watch_item: optional inventory item name whose gathered amount is reported.
        :param recording_dir: optional directory in which annotated MP4 recordings are written.
        """
        super().__init__(verbosity=verbosity)
        self.env = env
        self.n_workers = n_workers
        self.trials = trials
        self.max_steps = max_steps
        self.watch_item = watch_item
        self.recording_dir = recording_dir
        self.action_space = None  # must be set via set_action_space() before run()
        self.model = None  # set by __call__()

    def set_action_space(self, action_space):
        """Set the action space whose ``logits_to_dict`` turns model outputs into env actions."""
        self.action_space = action_space

    def eval_every_k(self):
        """Return how often this hook should be run (every 20; units are defined by the caller)."""
        return 20

    def __call__(self, model, **kwargs):
        """Evaluate ``model`` and return the statistics dict produced by :meth:`run`.

        When no env is configured, zeroed reward and item statistics are returned instead.
        Extra keyword arguments are accepted and ignored.
        """
        if self.env is None:
            result = self._summarize([], "reward")
            result.update(self._summarize([], "item"))
            return result
        self.model = model
        return self.run()

    def run(self):
        """Run all evaluation rounds and return summary statistics.

        :return: dict with ``total_/mean_/std_/min_/max_reward`` over all evaluated episodes
            (all 0.0 if no episode ran, i.e. ``trials < n_workers``). When ``watch_item`` is
            set and at least one episode ran, the same ``*_item`` statistics are added for the
            gathered item amount.
        """
        n_rounds = self.trials // self.n_workers
        record = bool(self.recording_dir)
        trial_rewards = []
        trial_watch_item = []
        for trial in tqdm(range(1, n_rounds + 1), desc='EvalEnv'):
            self.log(bcols.print_colored("\nEvaluation trial %02d / %02d" % (trial, n_rounds),
                                         color=bcols.OKGREEN), log_level=1)
            state = self.env.reset()

            # per-worker bookkeeping for this round
            cum_reward = [0 for _ in range(self.n_workers)]
            inventory_watcher = [[] for _ in range(self.n_workers)]
            equip_watcher = [[] for _ in range(self.n_workers)]
            frame_recording = [[] for _ in range(self.n_workers)]
            action_recording = [[] for _ in range(self.n_workers)]
            value_recording = [[] for _ in range(self.n_workers)]
            action_log = np.zeros((self.n_workers, len(BINARY_ACTIONS), self.max_steps), dtype=np.float32)

            # iterate for a fixed number of steps
            for step in range(self.max_steps):
                self.log(bcols.print_colored("\n--- step %05d ---\n" % (step + 1), bcols.WARNING), log_level=2)

                # prepare data for model
                input_dict = default_collate(state)
                for k in input_dict.keys():
                    input_dict[k] = input_dict[k].to(DEVICE)

                # predict next action; evaluation needs no autograd graph, and without
                # no_grad() `.numpy()` on grad-requiring outputs would raise
                with torch.no_grad():
                    out_dict = self.model.forward(input_dict)
                action = self.action_space.logits_to_dict(self.env.action_space.no_op(), out_dict)

                if self.verbosity >= 2:
                    self._print_actions(action)

                # take env step (`done` is not used: every round runs max_steps steps)
                state, reward, done, info = self.env.step(action)

                # the value output covers all workers; convert it once per step
                values = None
                if record and 'value' in out_dict:
                    values = out_dict['value'].cpu().numpy()

                # book keeping
                for i in range(len(state)):
                    ctrl_state = info[i]["meta_controller"]["state"]
                    if "inventory" in ctrl_state:
                        inventory_watcher[i].append(ctrl_state["inventory"])
                    if "equipped_items" in ctrl_state:
                        equip_watcher[i].append(ctrl_state["equipped_items"]["mainhand"]["maxDamage"])
                    cum_reward[i] += reward[i]
                    action_log[i, :, step] = np.asarray([action[i][ba] for ba in BINARY_ACTIONS])
                    if record:
                        # last entry along axis 0 of 'pov' (presumably the newest frame of a stack)
                        frame_recording[i].append(self._to_uint8_frame(state[i]['pov'][-1]))
                        action_recording[i].append(action[i])
                        if values is not None:
                            value_recording[i].append(values[i])

            trial_rewards.extend(cum_reward)

            # item watcher
            if self.watch_item:
                for inv in inventory_watcher:
                    item_count = self._count_gathered(inv, self.watch_item)
                    trial_watch_item.append(item_count)
                    self.log(bcols.print_colored("\nGathered '%s': %d" % (self.watch_item, item_count), bcols.OKBLUE),
                             log_level=1)

            for r in cum_reward:
                self.log(bcols.print_colored("\nTotal Episode Reward: %.1f" % r, bcols.FAIL), log_level=1)

            # write video
            if record:
                self.write_video(frame_recording, value_recording, action_recording, action_log, inventory_watcher,
                                 equip_watcher, cum_reward)

        # print result stats (number of episodes actually run, not the requested `trials`)
        self.log(
            bcols.print_colored("\nTotal Reward of %d trials: %.1f" % (len(trial_rewards), sum(trial_rewards)),
                                bcols.FAIL), log_level=1)

        result_dict = self._summarize(trial_rewards, "reward")
        if trial_watch_item:
            result_dict.update(self._summarize(trial_watch_item, "item"))
        return result_dict

    def write_video(self, frame_recording, value_recording, action_recording, action_log, inventory_recording,
                    equipment_recording, cum_reward):
        """Write one annotated MP4 per worker into ``self.recording_dir``.

        Each frame is pasted onto a 400x240 canvas together with the worker's actions, the
        recorded value (if any), the inventory and the recorded main-hand ``maxDamage`` entry.
        Workers without recorded frames are skipped. Errors raised by the video writer are
        printed, not raised.

        :param frame_recording: per-worker lists of uint8 RGB frames.
        :param value_recording: per-worker lists of value outputs (may be empty).
        :param action_recording: per-worker lists of action dicts, aligned with the frames.
        :param action_log: unused; kept for interface compatibility.
        :param inventory_recording: per-worker lists of inventory dicts.
        :param equipment_recording: per-worker lists of main-hand ``maxDamage`` entries.
        :param cum_reward: per-worker episode rewards, used in the file name.
        """
        white = (255, 255, 255)
        font = ImageFont.load_default()  # loaded once instead of once per frame
        ba_offset, ea_offset = 80, 8
        for v, worker_frames in enumerate(frame_recording):
            if not worker_frames:
                continue  # nothing recorded; np.stack would fail on an empty list
            frames = []
            for i, frame in enumerate(worker_frames):
                canvas = Image.new(mode="RGB", size=(400, 240))
                # paste frame
                canvas.paste(Image.fromarray(frame.astype(np.uint8), "RGB"), box=(20, 20))
                draw = ImageDraw.Draw(canvas)

                # value
                if len(value_recording[v]) > i:
                    draw.text((100, 140), "Value: {:.4f}".format(value_recording[v][i][0]), white, font=font)

                # actions
                act_i = action_recording[v][i]
                for j, act in enumerate(BINARY_ACTIONS):
                    draw.text((20, ba_offset + (j + 1) * 12), "{} {}".format(act_i[act], act), white, font=font)
                for j, act in enumerate(ENUM_ACTIONS):
                    draw.text((100, ea_offset + (j + 1) * 12), "{}: {}".format(act, act_i[act]), white, font=font)
                draw.text((100, 92), "{}\n{}".format("Camera", act_i["camera"]), white, font=font)

                # Inventory / equipment are only recorded on steps where the env reported them,
                # so these lists can be shorter than the frame list; guard the lookups.
                if len(inventory_recording[v]) > i:
                    for j, (name, cnt) in enumerate(inventory_recording[v][i].items()):
                        draw.text((280, (j + 1) * 12), "{}: {}".format(name, cnt), white, font=font)
                if len(equipment_recording[v]) > i:
                    draw.text((100, 160), "Mainhand: {}".format(equipment_recording[v][i]), white, font=font)

                frames.append(np.asarray(canvas))

            timestamp = time.strftime("%Y%m%d%H%M%S")
            # The worker index keeps workers with equal reward, written within the same
            # second, from overwriting each other's files.
            video_file = "recording-rew{}-{}-w{}.mp4".format(cum_reward[v], timestamp, v)
            video_path = os.path.join(self.recording_dir, video_file)
            try:
                imageio.mimwrite(video_path, np.stack(frames, 0), fps=20)
            except Exception as e:
                print("Unable to write video: {}".format(e))

    @staticmethod
    def _print_actions(actions):
        """Print each worker's binary actions (active ones highlighted) and its camera action."""
        for a in actions:
            for ba in BINARY_ACTIONS:
                txt = "%s (%d)" % (ba.rjust(10, " "), a[ba])
                if a[ba]:
                    txt = bcols.print_colored(txt, bcols.OKBLUE)
                print(txt)
            print("camera".rjust(10, " "), np.around(a["camera"], 2))

    @staticmethod
    def _to_uint8_frame(pov):
        """Move axis 0 of ``pov`` to the end and min-max rescale it to a uint8 image (0..255).

        The computation is done in float to avoid integer overflow. A constant frame
        (max == min) becomes all zeros instead of dividing by zero.
        """
        frame = np.moveaxis(pov, 0, -1).astype(np.float64)
        lo, hi = frame.min(), frame.max()
        if hi <= lo:
            return np.zeros(frame.shape, dtype=np.uint8)
        return ((frame - lo) * 255 / (hi - lo)).astype(np.uint8)

    @staticmethod
    def _count_gathered(inventories, item):
        """Return the summed positive increases of ``item`` between consecutive inventory snapshots.

        Decreases are ignored; an item missing from the earlier snapshot counts as 0.
        """
        count = 0
        for prev, cur in zip(inventories, inventories[1:]):
            if item not in cur:
                continue
            gained = cur[item] - (prev[item] if item in prev else 0)
            if gained > 0:
                count += gained
        return count

    @staticmethod
    def _summarize(values, suffix):
        """Return total/mean/std/min/max of ``values`` keyed ``<stat>_<suffix>``; all 0.0 if empty."""
        keys = ["total_" + suffix, "mean_" + suffix, "std_" + suffix, "min_" + suffix, "max_" + suffix]
        if not values:
            return dict.fromkeys(keys, 0.0)
        stats = [sum(values), np.mean(values), np.std(values), min(values), max(values)]
        return dict(zip(keys, stats))
