import datetime
import logging
import pathlib
import time
import os
import imageio
import numpy as np
from PIL import Image, ImageDraw, ImageFont

from train.common.image import ImageCorrection
from train.stats import Stats


class VideoRecorder:
    """Buffer per-step rollout data and render it into an annotated MP4.

    Each recorded frame is pasted onto a fixed-size canvas next to text
    panels showing the step's value, reward, actions, inventory and task
    progress, plus global statistics read from the ``Stats`` singleton.
    Nothing is buffered unless recording is enabled and a save directory
    was given.
    """

    # Layout of the rendered canvas (pixels / RGB).
    CANVAS_SIZE = (780, 320)
    LINE_HEIGHT = 12
    TEXT_COLOR = (255, 255, 255)

    def __init__(self, config, rec_save_dir=None, recording=False, video_size=(128, 128), **kwargs):
        """Create a recorder.

        Args:
            config: Project config; ``config.subtask.consensus_code`` maps the
                current task id to a display name when rendering.
            rec_save_dir: Directory for written videos. ``None`` disables
                recording regardless of ``recording``.
            recording: Whether appended data should be buffered.
            video_size: Stored on the instance; not used by the rendering
                code in this class.
            **kwargs: Accepted and ignored.
        """
        if rec_save_dir is None:
            recording = False
            print('Recording disabled due to None save_dir')
        elif recording:
            # Create the output directory up front so later writes have a target.
            pathlib.Path(rec_save_dir).mkdir(parents=True, exist_ok=True)
            print('Recording enabled')
        else:
            print('Recording disabled')
        self.config = config
        self.rec_save_dir = rec_save_dir
        self.recording = recording
        self.video_size = video_size
        self.frames = None
        self.infos = None
        self.tasks = None
        self.rewards = None
        self.actions = None
        self.values = None
        self.inventories = None
        self.statistics = None
        self.reset()

    def reset(self):
        """Clear all per-step buffers."""
        logging.info('Recorder: Resetting video buffer!')
        self.frames = []
        self.infos = []
        self.tasks = []
        self.rewards = []
        self.actions = []
        self.values = []
        self.inventories = []
        self.statistics = []

    def append_info(self, info):
        """Buffer ``str(info)`` for the current step (only while recording)."""
        if self.recording:
            self.infos.append(str(info))

    def append_task(self, task):
        """Buffer ``str(task)`` for the current step (only while recording)."""
        if self.recording:
            self.tasks.append(str(task))

    def append_reward(self, reward):
        """Buffer ``str(reward)`` for the current step (only while recording)."""
        if self.recording:
            self.rewards.append(str(reward))

    def append_value(self, value):
        """Buffer the value estimate for the current step (only while recording)."""
        if self.recording:
            self.values.append(value)

    def append_action(self, action):
        """Buffer the action mapping for the current step (only while recording)."""
        if self.recording:
            self.actions.append(action)

    def append_inventory(self, inventory):
        """Buffer the inventory mapping for the current step (only while recording)."""
        if self.recording:
            self.inventories.append(inventory)

    def append_statistics(self, statistics):
        """Buffer the progress statistics for the current step (only while recording)."""
        if self.recording:
            self.statistics.append(statistics)

    def append_frame(self, frame):
        """Buffer one frame (only while recording).

        The frame must be laid out as height x width x color channel.
        """
        if self.recording:
            self.frames.append(frame)

    def append_frames(self, frames):
        """Buffer every frame of *frames*, iterating over its first axis."""
        if self.recording:
            for frame in frames:
                self.append_frame(frame)

    def write(self, name_tag: str = None, sys_time: str = None, fps: int = 30):
        """Render the buffered frames and save them as ``rec-<time>[-<tag>].mp4``.

        Does nothing unless recording is enabled and at least one frame was
        appended.

        Args:
            name_tag: Optional suffix appended to the file name.
            sys_time: Timestamp string for the file name; defaults to the
                current local time formatted as ``%Y%m%d-%H%M%S``.
            fps: Frame rate of the written video.

        Returns:
            Path of the written file, or ``None`` when nothing was written.
        """
        if not (self.recording and self.frames):
            return None
        logging.info('Recorder: Writing video recording...')
        if sys_time is None:
            sys_time = datetime.datetime.fromtimestamp(time.time()).strftime('%Y%m%d-%H%M%S')
        if name_tag is not None:
            sys_time += "-{}".format(name_tag)
        name = os.path.join(self.rec_save_dir, 'rec-{}'.format(sys_time))
        path = '{}.mp4'.format(name)
        video_frames = self.get_values_video(self.frames)
        imageio.mimwrite(path, video_frames, fps=fps)
        return path

    def get_values_video(self, frames):
        """Render every frame together with its annotation panels.

        Args:
            frames: Non-empty sequence of frames convertible to ``uint8`` RGB
                images. Annotations come from the buffers at the same index;
                missing entries are shown as ``n/a`` or left out.

        Returns:
            ``np.ndarray`` stacking the rendered canvases along axis 0.
        """
        # Loop-invariant: one font and one snapshot of the global statistics
        # serve every frame of this video.
        font = ImageFont.load_default()
        g_stats = Stats.get_instance().copy()
        video = [np.asarray(self._render_frame(i, frame, font, g_stats))
                 for i, frame in enumerate(frames)]
        return np.stack(video, axis=0)

    def _render_frame(self, i, frame, font, g_stats):
        """Compose the annotated canvas for frame number *i*."""
        canvas = Image.new(mode="RGB", size=self.CANVAS_SIZE)
        img = Image.fromarray(frame.astype(np.uint8), "RGB")
        img = ImageCorrection.process(img)
        canvas.paste(img, box=(20, 20))
        draw = ImageDraw.Draw(canvas)
        self._draw_lines(draw, font, 110, 0, self._run_lines(i))
        self._draw_lines(draw, font, 20, 96, self._action_lines(i))
        self._draw_lines(draw, font, 215, 0, self._inventory_lines(i))
        self._draw_lines(draw, font, 350, 0, self._progress_lines(i, g_stats))
        self._draw_lines(draw, font, 520, 60, self._global_stats_lines(g_stats))
        return canvas

    def _draw_lines(self, draw, font, x, y_offset, lines):
        """Draw *lines* top-down at column *x*; ``None`` entries leave a blank row.

        The first line is placed one row below *y_offset*.
        """
        for row, text in enumerate(lines, start=1):
            if text is not None:
                draw.text((x, y_offset + row * self.LINE_HEIGHT), text, self.TEXT_COLOR, font=font)

    def _run_lines(self, i):
        """Text for the current step's value/reward panel."""
        value = self._entry(self.values, i, None)
        reward = self._entry(self.rewards, i, "n/a")
        return [
            "[Current Run]",
            "Value: {}".format(self._format_number(value, ".4f")),
            "Reward: {}".format(reward),
        ]

    def _action_lines(self, i):
        """Text for the current step's action panel."""
        action = self._entry(self.actions, i, {})
        return ["[Actions]", *self._mapping_lines(action)]

    def _inventory_lines(self, i):
        """Text for the current step's mainhand/inventory panel."""
        inventory = self._entry(self.inventories, i, {})
        mainhand = self._lookup(inventory, 'equipped_items', 'mainhand', default={})
        items = self._lookup(inventory, 'inventory', default={})
        return ["[Mainhand]", *self._mapping_lines(mainhand),
                None,
                "[Inventory]", *self._mapping_lines(items)]

    def _progress_lines(self, i, g_stats):
        """Text for the current-progress and global-resource-frequency panel."""
        stats = self._entry(self.statistics, i, {})
        completed = self._lookup(stats, 'progress', 'completed', default="n/a")
        pending = self._lookup(stats, 'progress', 'pending', default="n/a")
        current = self._lookup(stats, 'progress', 'current', default="n/a")
        milestones = self._lookup(stats, 'milestones', default={})
        completed_p = milestones[completed] if completed in milestones else 1.00
        current_name = self._lookup(self.config.subtask.consensus_code, current, default="n/a")
        # Resource frequencies are normalised by completed runs + 1.
        freq_divisor = g_stats['runs']['count'] + 1
        return [
            "[Current Progress]",
            "completed_task: P({}) = {}".format(completed, self._format_number(completed_p, ".3f")),
            "pending_task: {}".format(pending),
            "current_task: {} ({}) | seed: {}".format(current, current_name, g_stats['current']['seed']),
            None,
            "[Global Resource Freq]",
            *["{}: {:.2f}".format(name, freq / freq_divisor)
              for name, freq in g_stats['resources'].items()],
        ]

    def _global_stats_lines(self, g_stats):
        """Text for the global run statistics panel."""
        runs = g_stats['runs']
        reached = [(task, hits) for task, hits in g_stats['milestones'].items() if hits != 0]
        if reached:
            task, hits = reached[-1]
            # A milestone can be recorded before any run completes; avoid dividing by zero.
            prob = hits / runs['count'] if runs['count'] else None
            furthest = "P({}) = {}".format(task, self._format_number(prob, ".3f"))
        else:
            furthest = "P() = {:.3f}".format(1.0)
        lines = [
            "[Global Stats]",
            "furthest_task:",
            furthest,
            "completed_runs: {}".format(runs['count']),
            "total_reward: {}".format(runs['total_reward']),
        ]
        for key in ('mean_reward', 'std_reward', 'min_reward', 'max_reward'):
            val = runs[key]
            lines.append("{}: {:.2f}".format(key, val if val is not None else 0.0))
        return lines

    @staticmethod
    def _entry(buffer, i, default):
        """Return ``buffer[i]``, or *default* when the buffer is too short."""
        return buffer[i] if i < len(buffer) else default

    @staticmethod
    def _lookup(mapping, *keys, default=None):
        """Follow *keys* into nested containers; return *default* if any level is missing."""
        current = mapping
        for key in keys:
            try:
                current = current[key]
            except (KeyError, IndexError, TypeError):
                return default
        return current

    @staticmethod
    def _format_number(value, spec, missing="n/a"):
        """Format *value* with *spec*; ``None`` becomes *missing*, and values that reject *spec* fall back to ``str``."""
        if value is None:
            return missing
        try:
            return format(value, spec)
        except (TypeError, ValueError):
            return str(value)

    @staticmethod
    def _mapping_lines(mapping):
        """Format each ``key: value`` pair of *mapping* as one text line."""
        if mapping is None:
            return []
        return ["{}: {}".format(key, str(value)) for key, value in mapping.items()]
