import pickle
import os
from collections import OrderedDict
import matplotlib.pyplot as plt
import gzip
import numpy as np
import imageio
from PIL import Image, ImageDraw, ImageFont
import time
import datetime
import json
from train.common.image import ImageCorrection
from tqdm import tqdm


class SeedParser:
    """Helpers to extract seed tokens from file names and to persist them with pickle."""

    @staticmethod
    def parse_seed_list(path):
        """Return the seed token of every ``.p`` file found in directory ``path``.

        The token is the third ``_``-separated field of the third ``-``-separated
        field of the file name. Files that do not end with ``.p`` are ignored.
        """
        return [
            entry.split('-')[2].split('_')[2]
            for entry in os.listdir(path)
            if entry.endswith(".p")
        ]

    @staticmethod
    def save(filename, seeds):
        """Pickle ``seeds`` into the file ``filename``."""
        with open(filename, 'wb') as handle:
            pickle.dump(seeds, handle)

    @staticmethod
    def load(filename):
        """Return the object unpickled from the file ``filename``."""
        with open(filename, 'rb') as handle:
            return pickle.load(handle)


class Experience(object):
    """
    Model for collecting and replaying trajectories (experience).
    """

    def __init__(self, meta_info: dict = None):
        """
        Create an empty experience.

        :param meta_info: optional metadata, e.g. the seed or env used.
        """
        self.meta_info = meta_info

        # per-step records, filled by the append_* methods
        self.states, self.actions, self.transitions_info = [], [], []
        # checkpoint name -> number of actions recorded when it was set
        self.checkpoints, self.corrupted = OrderedDict(), False

    def append(self, state, action, reward, done, info):
        """ record one full transition: the state, then the action with its reward/done/info """
        self.append_state(state)
        self.append_action_and_info(action, reward=reward, done=done, info=info)

    def append_state(self, state):
        """ record one state
        (optional; storing every state can use a lot of memory and disc space) """
        recorded_states = self.states
        recorded_states.append(state)

    def append_action_and_info(self, action, reward=None, done=None, info=None):
        """ record an action together with its reward, done flag and info """
        transition = dict(reward=reward, done=done, info=info)
        self.actions.append(action)
        self.transitions_info.append(transition)

    def set_checkpoint(self, checkpoint):
        """ register a checkpoint at the current number of recorded actions """
        steps_so_far = len(self.actions)
        self.checkpoints[checkpoint] = steps_so_far

    def replay_action_generator(self, checkpoint=None):
        """ yield the recorded actions in order, stopping at the given checkpoint.

        If ``checkpoint`` is not a known checkpoint, all recorded actions are yielded.
        """
        known = checkpoint in self.checkpoints
        last_step = self.checkpoints[checkpoint] if known else len(self.actions)

        for action in self.actions[:last_step]:
            yield action

    def replay_on_env(self, env, checkpoint=None, noop_action=None):
        """ replay trajectory actions on provided env until certain checkpoint is reached.

        After the replay, the live inventory (``obs['inventory']``) is compared with the
        inventory stored in ``self.states`` for the last replayed step; the 'dirt' entry is
        ignored. While the live inventory is behind the recorded one, ``noop_action`` is
        stepped for up to 10 frames so that pending items can show up. No extra step is
        taken when the inventories already match.

        :param env: environment whose ``step(action)`` returns ``(obs, reward, done, info)``.
        :param checkpoint: checkpoint up to which actions are replayed; if it is unknown,
            all recorded actions are replayed.
        :param noop_action: action stepped while waiting for the inventory to catch up.
        :return: tuple ``(env, obs)`` where ``obs`` is the last observation received.
        :raises ValueError: if there is no action to replay.
        :raises Exception: if the inventories still do not match after waiting.
        """
        max_wait_frames = 10

        obs = None
        last_step = -1
        for last_step, action in enumerate(self.replay_action_generator(checkpoint)):
            obs, _reward, _done, _info = env.step(action)

        # without a single step there is no observation to compare against
        if last_step < 0:
            raise ValueError('cannot replay experience: no actions to replay.')

        # inventory that was recorded at the last replayed step
        recorded_inventory = self.states[last_step]['inventory']

        # wait for up to 10 frames to see if anything will show up in inventory
        waited = 0
        while waited < max_wait_frames and not self.inventories_ok(obs, recorded_inventory):
            print('Inventories not matching, waiting...', waited)
            obs, _reward, _done, _info = env.step(noop_action)
            waited += 1

        # final test
        if not self.inventories_ok(obs, recorded_inventory):
            print('inventories of live and recorded agent (without dirt) not matching.')
            print('inventory live agent: ', obs['inventory'])
            print('inventory recorded agent: ', recorded_inventory)
            raise Exception('inventory exception!')

        # print inventories in case of success
        print('Successfully replayed inventory.')
        print('Inventory live agent: ', obs['inventory'])
        print('Inventory recorded agent: ', recorded_inventory)
        return env, obs


    def inventories_ok(self, obs, recorded_inventory):
        """
        Compare the live inventory (``obs['inventory']``) with a recorded inventory. Returns True if
        every entry of the live inventory, except 'dirt', is greater than or equal to the
        corresponding entry of the recorded inventory, and False otherwise.
        """
        live_inventory = obs['inventory']
        return all(
            live_inventory[item] >= recorded_inventory[item]
            for item in live_inventory
            if item != 'dirt'
        )


    def plot_all_observations(self):
        """
        Debugging helper: writes the 'pov' entry of every recorded state as a numbered PNG
        into 'figures/testexperience' (the directory is created if needed).
        """
        os.makedirs('figures/testexperience', exist_ok=True)
        for index, observation in enumerate(self.states):
            print('saving state', index)
            plt.imsave('figures/testexperience/{:05d}.png'.format(index), observation['pov'])

    def save(self, file_path, compressed=False):
        """ save collected experience to pickle file

        :param file_path: target file; missing parent directories are created.
        :param compressed: if true, the pickle is written through gzip (compresslevel 7).
        """
        target_dir = os.path.dirname(file_path)
        # a bare file name has an empty dirname, and os.makedirs('') raises FileNotFoundError
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        if compressed:
            with gzip.GzipFile(file_path, "wb", compresslevel=7) as fp:
                pickle.dump(self, fp, pickle.HIGHEST_PROTOCOL)
        else:
            with open(file_path, "wb") as fp:
                pickle.dump(self, fp, pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def load(file_path, compressed=False):
        """ factory method to load experience """
        if not compressed:
            with open(file_path, "rb") as fp:
                return pickle.load(fp)
        elif compressed:
            with gzip.open(file_path, "rb") as fp:
                return pickle.load(fp)

    @staticmethod
    def recursive_fill_skip_or_black(i, frames):
        if i <= 0:
            return np.zeros((64, 64, 3), dtype=np.uint8)
        else:
            return frames[i-1] if frames[i-1] is not None else Experience.recursive_fill_skip_or_black(i-1, frames)

    @staticmethod
    def video(root_path, experience_file, frames, rewards, stages, povs=None, trans_rewards=None, video_only=False, skip_frames=False):
        os.makedirs(root_path, exist_ok=True)
        with open('configs/experiment/consensus_code.json') as f:
            consensus_code = json.load(f)
        if len(frames) == 0:
            print("Experience: Cannot write empty frames video!")
            return
        video = []
        proc_frames = []
        for i, (frame, reward, stage) in enumerate(tqdm(zip(frames, rewards, stages))):
            if video_only:
                canvas_dim = 256
                canvas = Image.new(mode="RGB", size=(canvas_dim, canvas_dim))
                if frame is None and skip_frames:
                    continue
                elif frame is None:
                    frame = Experience.recursive_fill_skip_or_black(i, frames)
                img = Image.fromarray(frame.astype(np.uint8), "RGB")
                img = img.resize((canvas_dim, canvas_dim))
                img = ImageCorrection.process(img)
                proc_frames.append(img)
                canvas.paste(img, box=(0, 0))
            else:
                canvas_dim = 240 if povs is not None else 210
                canvas = Image.new(mode="RGB", size=(canvas_dim, 100))
                if frame is None and skip_frames:
                    continue
                elif frame is None:
                    frame = Experience.recursive_fill_skip_or_black(i, frames)
                img = Image.fromarray(frame.astype(np.uint8), "RGB")
                canvas.paste(ImageCorrection.process(img), box=(0, 0))
                if povs is not None:
                    if povs[i] is None:
                        povs[i] = Experience.recursive_fill_skip_or_black(i, povs)
                    img = Image.fromarray(povs[i].astype(np.uint8), "RGB")
                    canvas.paste(ImageCorrection.process(img), box=(64, 0))
                draw = ImageDraw.Draw(canvas)
                font = ImageFont.load_default()
                rew_dim = 150 if povs is not None else 75
                draw.text((rew_dim, 10), "[Reward]", (255, 255, 255), font=font)
                draw.text((rew_dim, 22), str(reward), (255, 255, 255), font=font)
                draw.text((rew_dim, 34), "[Current Task]", (255, 255, 255), font=font)
                cur_stage = stage[-1][-1]
                draw.text((rew_dim, 46), f'{cur_stage} ({consensus_code[cur_stage]})', (255, 255, 255), font=font)
                if trans_rewards is not None:
                    draw.text((rew_dim, 58), "[ORI Reward]", (255, 255, 255), font=font)
                    draw.text((rew_dim, 60), str(trans_rewards[i]), (255, 255, 255), font=font)
                draw.text((2, 70), '[Completed Tasks]', (255, 255, 255), font=font)
                compl_stage = '-' if len(stage[:-1]) == 0 else f'{stage[:-1]}'
                draw.text((2, 82), compl_stage, (255, 255, 255), font=font)
            video.append(np.asarray(canvas))
        video = np.stack(video, axis=0)
        sys_time = datetime.datetime.fromtimestamp(time.time()).strftime('%Y%m%d-%H%M%S')
        stage = '[]' if len(stages) == 0 else stages[-1]
        name = os.path.join(root_path, f'{experience_file}_stage-[{stage}]_rew-{rewards[-1]}_rec-{sys_time}')
        imageio.mimwrite('{}.mp4'.format(name), video, fps=30)

    @staticmethod
    def validate_and_fix(experience_recording):
        corrected_checkpoints = {}
        checkpoint_names = {}

        # ============= CORRECTION FOR DOUBLE PYRAMID CHECKPOINTS =================
        offset = 0
        checkpoints = experience_recording.checkpoints
        # verify all existing checkpoints based on the consensus and correct if they occur multiple times
        for k, v in checkpoints.items():
            # strip to consensus name without date
            consensus = k.split('.')[-1]
            # check if this consensus stage already occurred
            if consensus in corrected_checkpoints:
                # get last checkpoint values
                from_del_idx = corrected_checkpoints[consensus]
                to_del_idx = v
                # remove redundant entries
                del experience_recording.actions[from_del_idx:to_del_idx]
                del experience_recording.states[from_del_idx:to_del_idx]
                del experience_recording.transitions_info[from_del_idx:to_del_idx]
                # correct all future checkpoints by the detected offset
                offset += to_del_idx - from_del_idx
                # recreate the corrected_checkpoints discarding follow up stages
                corrected_checkpoints = {k: v for k, v in corrected_checkpoints.items() if consensus not in k}
            # build original name mapping
            checkpoint_names[consensus] = k
            # add stage to corrected_consensus with corrected offset
            corrected_checkpoints[consensus] = v - offset
        # recreate proper named mapping of checkpoints and reassign correction
        experience_recording.checkpoints = OrderedDict()
        for k, v in corrected_checkpoints.items():
            experience_recording.checkpoints[checkpoint_names[k]] = v

        # ============= CORRECTION FOR MULTIPLE DONE FLAGS IN STAGES ==============
        checkpoints = corrected_checkpoints
        corrected_checkpoints = {k.split('.')[-1]: v for k, v in experience_recording.checkpoints.items()}
        idx = 0
        # find last done
        done_idx = 0
        for i in range(len(experience_recording.transitions_info)):
            if experience_recording.transitions_info[i]['done']:
                done_idx = i
        # set correct index if done was found in between stages
        for k, v in checkpoints.items():
            for i in range(v - idx):
                if experience_recording.transitions_info[idx + i]['done'] and idx + i < done_idx:
                    last_consensus_stage = k.split('.')[-1][:-1]
                    # it is expected to have at least two stages in the consensus such that you can set the index
                    # the previous stage to begin otherwise the current file defined as corrupted
                    assert last_consensus_stage in corrected_checkpoints
                    corrected_checkpoints[last_consensus_stage] = idx + i + 1
            idx = v
        experience_recording.checkpoints = OrderedDict()
        for k, v in corrected_checkpoints.items():
            experience_recording.checkpoints[checkpoint_names[k]] = v

        return experience_recording
