import os, minerl
import numpy as np
import imageio
import time
import torch
from PIL import Image, ImageDraw, ImageFont
from torch.utils.data._utils.collate import default_collate
from tqdm import tqdm
from train.behavioral_cloning.datasets.agent_state import AgentState
from train.behavioral_cloning.spaces.action_spaces import BINARY_ACTIONS, ENUM_ACTIONS, ActionSpace
from train.behavioral_cloning.spaces.input_spaces import InputSpace
from train.envs.minerl_env import make_minerl
from train.pytorch_wrapper.eval_hook import EvalHook
from train.pytorch_wrapper.utils import BColors
from train.behavioral_cloning.datasets.experience import Experience
import matplotlib
from train.common.image import ImageCorrection

matplotlib.use("agg")
import matplotlib.pyplot as plt

# Reference the ``minerl`` module so the otherwise-unused import above is not
# flagged or stripped by tooling; the import presumably matters for its side
# effects (e.g. registering the MineRL environments) -- TODO confirm.
minerl.__name__

# Device that collated observations are moved to before being fed to the model:
# CUDA when it is available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Shared helper used to colour console / log output.
bcols = BColors()


class EvalEnv(EvalHook):
    """Evaluation hook that rolls a model out in a multi-worker environment.

    :meth:`run` performs ``trials // n_workers`` trials. Each trial resets the
    environment and steps it for at most ``max_steps`` steps. During the trial
    it accumulates a per-worker reward and records frames, actions, values,
    inventory and equipment. When ``record_dir`` is set (and ``record_video``
    is true), it writes one annotated video per worker.

    Several constructor arguments (``input_space``, ``seq_len``, ``transforms``,
    ``seed``, ``checkpoint``, ``experience_seed``, ``frame_skip``,
    ``env_server``, ``server_port``, ``watch_item``) are only stored as
    attributes by this class.

    Args:
        env_name: Environment name; used in output file names.
        action_space: Converts model outputs into env actions via
            ``logits_to_dict``.
        trials: Total number of evaluation episodes; must be divisible by
            ``n_workers``.
        max_steps: Maximum number of env steps per trial.
        record_dir: Output directory for videos/plots; ``None`` disables recording.
        argstr: Free-form string appended to output file names.
        verbosity: Log level forwarded to ``EvalHook``; >= 2 also prints every
            predicted action.
        experience_recording: Optional path, loaded with ``Experience().load``.
        n_workers: Number of parallel workers the env returns results for.
        record_video: If False, no video is written even when ``record_dir`` is set.
        task_based_reward: If True, the accumulated reward is task specific
            (see :meth:`_reward_increment`).
        task_to_check: Inventory item whose increase is rewarded, or ``'diamond'``.
        plot_action_log: If True, also save a PNG heat-map of the binary-action
            history next to each video.

    Raises:
        ValueError: If ``n_workers`` < 1 or ``trials`` is not divisible by it.
    """

    def __init__(self, env_name: str = None,
                 input_space: InputSpace = None, action_space: ActionSpace = None,
                 trials: int = 10, max_steps: int = 1200, seq_len: int = 32, transforms=None, record_dir: str = None,
                 argstr: str = None, seed: int = None, verbosity: int = 0, experience_recording=None,
                 checkpoint: str = None, experience_seed: int = None, frame_skip: int = 2, env_server: bool = False,
                 server_port: int = 9999, n_workers: int = 1, watch_item: str = None, record_video=True,
                 task_based_reward=False, task_to_check=None, plot_action_log: bool = False):
        super().__init__(verbosity=verbosity)
        if n_workers < 1:
            raise ValueError("n_workers must be a positive integer!")
        if trials % n_workers != 0:
            raise ValueError("Number of trials must be divisible by n_workers!")
        self.env_name = env_name
        self.input_space = input_space
        self.action_space = action_space
        self.trials = trials
        self.max_steps = max_steps
        self.seq_len = seq_len
        self.transforms = transforms
        self.record_dir = record_dir
        self.arg_str = argstr
        self.seed = seed
        self.experience_recording = experience_recording
        self.checkpoint = checkpoint
        self.experience_seed = experience_seed
        self.frame_skip = frame_skip
        self.env_server = env_server
        self.n_workers = n_workers
        self.server_port = server_port
        self.watch_item = watch_item
        self.record_video = record_video
        self.task_based_reward = task_based_reward
        self.task_to_check = task_to_check
        self.plot_action_log = plot_action_log
        if experience_recording:
            self.experience_recording = Experience().load(experience_recording)

    def eval_every_k(self):
        """Return the evaluation interval reported to the hook machinery (always 10)."""
        return 10

    def run(self, model, env, **kwargs):
        """Roll ``model`` out in ``env`` and optionally record annotated videos.

        Args:
            model: Called as ``model.forward(input_dict)`` on the collated
                observations. Its output dict is turned into actions by
                ``action_space.logits_to_dict`` and may contain a ``'value'`` entry.
            env: Environment whose ``reset()`` and ``step()`` return one
                observation (and reward / done / info) per worker.
            **kwargs: Unused.
        """
        trials_per_worker = self.trials // self.n_workers
        for trial in tqdm(range(1, trials_per_worker + 1), desc='EvalEnv'):
            self.log(bcols.print_colored("\nEvaluation trial %02d / %02d" % (trial, trials_per_worker),
                                         color=bcols.OKGREEN), log_level=1)
            # TODO: decide whether to seed every trial or only the initial state;
            # env seeding is currently disabled.
            state = env.reset()

            # per-worker book keeping for this trial
            cum_reward = [0 for _ in range(self.n_workers)]
            inventory_watcher = [[] for _ in range(self.n_workers)]
            equip_watcher = [[] for _ in range(self.n_workers)]
            frame_recording = [[] for _ in range(self.n_workers)]
            action_recording = [[] for _ in range(self.n_workers)]
            value_recording = [[] for _ in range(self.n_workers)]
            action_log = np.zeros((self.n_workers, len(BINARY_ACTIONS), self.max_steps), dtype=np.float32)
            last_info = None

            for step in range(self.max_steps):
                self.log(bcols.print_colored("\n--- step %05d ---\n" % (step + 1), bcols.WARNING), log_level=2)

                # batch the per-worker observations and move them to the model's device
                input_dict = default_collate(state)
                for k in input_dict.keys():
                    input_dict[k] = input_dict[k].to(DEVICE)

                # predict the next action for every worker
                out_dict = model.forward(input_dict)
                action = self.action_space.logits_to_dict(env.action_space.no_op(), out_dict)
                if self.verbosity >= 2:
                    self._print_actions(action)

                state, reward, done, info = env.step(action)

                # value-head output (if any), converted to numpy once for all workers
                values = out_dict['value'].cpu().detach().numpy() if 'value' in out_dict else None

                for i in range(len(state)):
                    meta_state = info[i]["meta_controller"]["state"]
                    if "inventory" in meta_state:
                        inventory_watcher[i].append(meta_state["inventory"])
                    if "equipped_items" in meta_state:
                        equip_watcher[i].append(meta_state["equipped_items"]["mainhand"]["maxDamage"])

                    cum_reward[i] += self._reward_increment(i, reward, info, last_info)

                    frame_recording[i].append(self._pov_to_uint8(state[i]['pov'][-1]))
                    action_recording[i].append(action[i])
                    if values is not None:
                        value_recording[i].append(values[i])
                    action_log[i, :, step] = np.asarray([action[i][ba] for ba in BINARY_ACTIONS])

                last_info = info

                # stop early once every worker has finished its episode
                if np.all(done):
                    break

            self.log("Trial %02d cumulative reward per worker: %s" % (trial, cum_reward), log_level=1)

            if self.record_dir and self.record_video:
                try:
                    self.write_video(frame_recording, value_recording, action_recording, action_log,
                                     inventory_watcher, equip_watcher, cum_reward)
                except Exception as e:
                    # best effort: a failed recording must not abort the evaluation
                    print(f"EvalHook: Caught exception during video writing: {e}")

    def _reward_increment(self, i, reward, info, last_info):
        """Return the amount to add to worker ``i``'s cumulative reward for one step.

        - ``task_based_reward`` off: the raw env reward.
        - ``task_to_check == 'diamond'``: the raw reward if it is >= 1000, else 0.
        - Any other task: 1 if the ``task_to_check`` inventory count grew since
          the previous step, else 0. The raw reward is used instead on the first
          step, for rewards >= 1000, or when the current step has no inventory.
        """
        r = reward[i]
        if not self.task_based_reward:
            return r
        if self.task_to_check == 'diamond':
            # only the diamond reward (>= 1000) counts for this task
            return r if r >= 1000 else 0
        if last_info is None or r >= 1000:
            return r
        current = info[i]["meta_controller"]["state"]
        if "inventory" not in current:
            return r
        previous = last_info[i]["meta_controller"]["state"]
        if "inventory" not in previous:
            # no previous count to compare against -> no task reward this step
            return 0
        diff = (current["inventory"].get(self.task_to_check, 0)
                - previous["inventory"].get(self.task_to_check, 0))
        return 1 if diff > 0 else 0

    @staticmethod
    def _pov_to_uint8(pov):
        """Move axis 0 of ``pov`` to the end and min-max scale it to ``uint8`` [0, 255].

        A constant frame would divide by zero, so it is returned as all zeros instead.
        """
        frame = np.moveaxis(pov, 0, -1)
        lo, hi = frame.min(), frame.max()
        if hi > lo:
            return ((frame - lo) * 255 / (hi - lo)).astype(np.uint8)
        return np.zeros(frame.shape, dtype=np.uint8)

    @staticmethod
    def _print_actions(actions):
        """Print each worker's binary actions (active ones highlighted) and camera action."""
        for act in actions:
            for name in BINARY_ACTIONS:
                txt = "%s (%d)" % (name.rjust(10, " "), act[name])
                if act[name]:
                    txt = bcols.print_colored(txt, bcols.OKBLUE)
                print(txt)
            print("camera".rjust(10, " "), np.around(act["camera"], 2))

    def write_video(self, frame_recording, value_recording, action_recording, action_log, inventory_recording,
                    equipment_recording, cum_reward):
        """Write one annotated mp4 per worker into ``record_dir``.

        Each video frame is a 400x240 canvas with the corrected pov frame
        pasted in. The following are overlaid on top of it:

        - the predicted value, when one was recorded for that step;
        - the binary, enum and camera actions;
        - the inventory and the mainhand ``maxDamage``, when recorded.

        The file name contains the env name, the worker's cumulative reward,
        a timestamp and ``arg_str``. When ``plot_action_log`` is set, an
        action-history PNG is saved next to the video.
        """
        os.makedirs(self.record_dir, exist_ok=True)
        font = ImageFont.load_default()
        white = (255, 255, 255)
        ba_offset, ea_offset = 80, 8

        for v, worker_frames in enumerate(frame_recording):
            values = value_recording[v] if v < len(value_recording) else []
            inventories = inventory_recording[v] if v < len(inventory_recording) else []
            equipment = equipment_recording[v] if v < len(equipment_recording) else []

            frames = []
            for i, frame in enumerate(worker_frames):
                canvas = Image.new(mode="RGB", size=(400, 240))
                img = Image.fromarray(frame.astype(np.uint8), "RGB")
                img = ImageCorrection.process(img)
                canvas.paste(img, box=(20, 20))
                draw = ImageDraw.Draw(canvas)

                # value
                if i < len(values):
                    draw.text((100, 140), "Value: {:.4f}".format(values[i][0]), white, font=font)

                # actions
                actions = action_recording[v][i]
                for j, act in enumerate(BINARY_ACTIONS):
                    draw.text((20, ba_offset + (j + 1) * 12), "{} {}".format(actions[act], act), white, font=font)
                for j, act in enumerate(ENUM_ACTIONS):
                    draw.text((100, ea_offset + (j + 1) * 12), "{}: {}".format(act, actions[act]), white, font=font)
                draw.text((100, 92), "{}\n{}".format("Camera", actions["camera"]), white, font=font)

                # inventory (checked per worker: recordings may be shorter than the frames)
                if i < len(inventories):
                    for j, (name, cnt) in enumerate(inventories[i].items()):
                        draw.text((280, (j + 1) * 12), "{}: {}".format(name, cnt), white, font=font)

                # equipped item
                if i < len(equipment):
                    draw.text((100, 160), "Mainhand: {}".format(equipment[i]), white, font=font)

                frames.append(np.asarray(canvas))

            if not frames:
                print('EvalHook: Evaluation video was empty. Skipping this env recoding.')
                continue

            frames = np.stack(frames, 0)
            timestamp = time.strftime("%Y%m%d%H%M%S")
            video_file = "recording-{}-rew{}-{}-{}.mp4".format(self.env_name, cum_reward[v], timestamp, self.arg_str)
            video_path = os.path.join(self.record_dir, video_file)
            imageio.mimwrite(video_path, frames, fps=20)

            if self.plot_action_log:
                self._plot_action_log(action_log[v], cum_reward[v], timestamp)

    def _plot_action_log(self, worker_action_log, reward, timestamp):
        """Save a heat-map of one worker's binary-action history as a PNG in ``record_dir``."""
        n_actions = len(BINARY_ACTIONS)
        fig = plt.figure("action_hist", figsize=(60, 10))
        fig.clf()
        plt.imshow(worker_action_log, interpolation="nearest", aspect="auto", vmin=0, vmax=1)
        ticks = list(range(n_actions))
        labels = ["%s [%d]" % (BINARY_ACTIONS[t], t) for t in ticks]
        plt.yticks(ticks, labels, fontsize=24)
        # one row per binary action, first action at the top
        plt.ylim([n_actions - 0.5, -0.5])
        image_file = "action_log-{}-rew{}-{}-{}.png".format(self.env_name, reward, timestamp, self.arg_str)
        plt.savefig(os.path.join(self.record_dir, image_file))
        # release the figure so repeated evaluations do not accumulate memory
        plt.close(fig)
