import glob
from enum import Enum

import minerl
import os
import json
import numpy as np
from tqdm import tqdm
from train import Config
from itertools import product
from alignment.sequence import Sequence
from alignment.vocabulary import Vocabulary
from alignment.sequencealigner import SimpleScoring, GlobalSequenceAligner

from train.proc.proc_base import Subprocess
from train.subtask_identifier.task_builder import TaskBuilder


class FileObject:
    """Per-episode record of file paths and the data derived from them.

    Every attribute starts empty (``None`` / ``0``) and is filled in by the
    module-level pipeline functions.
    """

    def __init__(self):
        # Files found on disk for this episode.
        self.metadata_file: "str | None" = None  # path to the *metadata.json file
        self.rendered_file: "str | None" = None  # path to the *rendered.npz file
        # Raw data read from rendered_file (array name -> array).
        self.rendered: "object | None" = None
        self.inventory: "np.ndarray | None" = None  # the 'observation_inventory' array
        # Inventory-change encoding produced by encode_inventory_change.
        self.enc_inventory: "np.ndarray | None" = None  # per-change inventory differences
        self.enc_steps: "np.ndarray | None" = None  # step indices of those changes
        self.sequence: "list | None" = None  # one character per change
        self.seq_len: int = 0  # len(sequence)


# Inventory item column index -> single-character code. An episode's
# inventory changes are spelled with these letters so that episodes can be
# compared with string sequence alignment. The letters only need to be unique.
# NOTE(review): item names below are in alphabetical order, presumably matching
# the column order of 'observation_inventory' -- confirm against the dataset.
CharEncoding = dict(enumerate((
    "G",  # 0: coal
    "A",  # 1: cobblestone
    "L",  # 2: crafting_table
    "M",  # 3: dirt
    "F",  # 4: furnace
    "W",  # 5: iron_axe
    "K",  # 6: iron_ingot
    "Q",  # 7: iron_ore
    "E",  # 8: iron_pickaxe
    "S",  # 9: log
    "P",  # 10: planks
    "V",  # 11: stick
    "I",  # 12: stone
    "C",  # 13: stone_axe
    "Y",  # 14: stone_pickaxe
    "H",  # 15: torch
    "R",  # 16: wooden_axe
    "N",  # 17: wooden_pickaxe
)))


class FilterType(Enum):
    """Rule for ``filter_shortest_sequences``: each member names what gets DELETED.

    Lengths are compared with statistics of all current ``seq_len`` values.
    """

    # delete lengths not strictly inside (mean - std, mean + std)
    OutsideOfPlusMinusStdFromMean = 0
    # delete lengths strictly above the mean
    GreaterThanMean = 1
    # delete lengths strictly above the median
    GreaterThanMedian = 2
    # delete lengths strictly above percent * max length
    GreaterThanNPercentOfMax = 3
    # delete lengths strictly below percent * max length
    LessThanNPercentOfMax = 4


def align_sequences(seq_a, seq_b, match_score=1, mismatch_score=-1, gap_score=-2):
    """Globally align two sequences and return every optimal alignment.

    Args:
        seq_a: first sequence, wrapped in ``alignment.sequence.Sequence``.
        seq_b: second sequence, wrapped the same way.
        match_score: score for aligning two equal symbols (``SimpleScoring``).
        mismatch_score: score for aligning two different symbols (``SimpleScoring``).
        gap_score: gap score passed to ``GlobalSequenceAligner``.

    Returns:
        list of decoded alignments, one per optimal alignment produced by the
        aligner's backtrace. Nothing is printed.
    """
    # A single shared vocabulary so equal symbols in both sequences get the same id.
    vocabulary = Vocabulary()
    a_encoded = vocabulary.encodeSequence(Sequence(seq_a))
    b_encoded = vocabulary.encodeSequence(Sequence(seq_b))

    aligner = GlobalSequenceAligner(SimpleScoring(match_score, mismatch_score), gap_score)
    # With backtrace=True the aligner also returns all optimal alignments;
    # only those are needed here, not the score.
    _, encoded_list = aligner.align(a_encoded, b_encoded, backtrace=True)
    return [vocabulary.decodeSequenceAlignment(encoded) for encoded in encoded_list]


def pairwise_compare_sequences(key_pairs, seq_dict):
    """Align the sequences of every pair of distinct keys and pool the results.

    Args:
        key_pairs: iterable of ``(key_a, key_b)`` tuples; pairs whose keys are
            equal are skipped. Progress is shown with ``tqdm``.
        seq_dict: mapping of key -> sequence handed to ``align_sequences``.

    Returns:
        flat list of all alignments of all pairs, in iteration order.
    """
    return [
        aligned
        for first_key, second_key in tqdm(key_pairs)
        if first_key != second_key
        for aligned in align_sequences(seq_dict[first_key], seq_dict[second_key])
    ]


def rank_alignment(top: int, alignments):
    """Return the ``top`` best alignments as a flat dict and print each of them.

    Alignments are ordered by ``quality()`` descending (stable for ties). For the
    alignment at 0-based rank ``r`` the dict receives ``'{r}first'`` and
    ``'{r}second'`` holding its ``first`` and ``second`` sequences. Each one is
    printed with its 1-based rank and quality.
    """
    ranked = sorted(alignments, key=lambda candidate: candidate.quality(), reverse=True)
    best = {}
    # max(..., 0): a non-positive ``top`` selects nothing
    for rank, candidate in enumerate(ranked[:max(top, 0)]):
        best['{}first'.format(rank)] = candidate.first
        best['{}second'.format(rank)] = candidate.second
        print('{}: {} =>\n{}'.format(rank + 1, candidate.quality(), candidate))
    return best


def filter_shortest_sequences(file_dict, filter_type: FilterType, percent: float = None):
    """Delete entries from ``file_dict`` (in place) based on their ``seq_len``.

    Prints the median, mean, std, min and max of the current sequence lengths,
    then the number of entries that remain after filtering.

    Args:
        file_dict: mapping of key -> object with a numeric ``seq_len``; modified in place.
        filter_type: which entries to delete:
            - OutsideOfPlusMinusStdFromMean: seq_len not strictly inside
              (mean - std, mean + std). NOTE: if all lengths are equal (std == 0)
              this deletes every entry.
            - GreaterThanMean: seq_len > mean
            - GreaterThanMedian: seq_len > median
            - GreaterThanNPercentOfMax: seq_len > percent * max
            - LessThanNPercentOfMax: seq_len < percent * max
        percent: fraction of the maximum length (e.g. 0.05). Required for the two
            ``NPercentOfMax`` filter types; ignored by the others.

    Raises:
        TypeError: if ``percent`` is None for an ``NPercentOfMax`` filter type.
        NotImplementedError: for an unknown ``filter_type``.
    """
    if percent is None and filter_type in (FilterType.GreaterThanNPercentOfMax,
                                           FilterType.LessThanNPercentOfMax):
        raise TypeError('percent is required for filter type {}'.format(filter_type))

    seq_lens = [v.seq_len for v in file_dict.values()]
    if not seq_lens:
        # Statistics of an empty set are undefined (np.max raises); nothing to filter.
        print('remaining', 0)
        return

    median = np.median(seq_lens)
    mean = np.mean(seq_lens)
    std = np.std(seq_lens)
    max_len = np.max(seq_lens)
    min_len = np.min(seq_lens)
    print('median', median)
    print('mean', mean)
    print('std', std)
    print('min', min_len)
    print('max', max_len)

    # Pick the deletion predicate once instead of re-dispatching per entry.
    if filter_type == FilterType.OutsideOfPlusMinusStdFromMean:
        def should_delete(n):
            return not (mean - std < n < mean + std)
    elif filter_type == FilterType.GreaterThanMedian:
        def should_delete(n):
            return n > median
    elif filter_type == FilterType.GreaterThanMean:
        def should_delete(n):
            return n > mean
    elif filter_type == FilterType.GreaterThanNPercentOfMax:
        def should_delete(n):
            return n > percent * max_len
    elif filter_type == FilterType.LessThanNPercentOfMax:
        def should_delete(n):
            return n < percent * max_len
    else:
        raise NotImplementedError('Unknown filter type: {}'.format(filter_type))

    # Collect keys first: the dict must not change size while being iterated.
    deletion_list = [k for k, v in file_dict.items() if should_delete(v.seq_len)]
    for k in deletion_list:
        del file_dict[k]
    print('remaining', len(file_dict))


def filter_successful_file_dict(file_dict):
    """Delete (in place) every entry whose metadata JSON has a falsy ``'success'``.

    Each value's ``metadata_file`` is opened and parsed as JSON. A missing
    ``'success'`` key raises KeyError.
    """
    def _succeeded(entry):
        with open(entry.metadata_file) as handle:
            return json.load(handle)['success']

    # Gather failures first so the dict is not resized during iteration.
    failed_keys = [key for key, entry in file_dict.items() if not _succeeded(entry)]
    for key in failed_keys:
        file_dict.pop(key)


def load_rendered_episodes(file_dict):
    """Load each episode's rendered ``.npz`` archive and its inventory observations.

    For every value ``v`` in ``file_dict`` this sets:
      - ``v.rendered``: dict mapping each array name in ``v.rendered_file`` to its array
      - ``v.inventory``: ``v.rendered['observation_inventory']``

    ``np.load`` on an ``.npz`` returns a lazily-reading ``NpzFile`` that keeps
    its file handle open until closed. The archive is therefore read eagerly
    inside a ``with`` block, so no handle stays open per episode. The trade-off
    is that all arrays of the archive are held in memory.
    """
    for v in file_dict.values():
        with np.load(v.rendered_file) as archive:
            v.rendered = {name: archive[name] for name in archive.files}
        v.inventory = v.rendered['observation_inventory']


def encode_inventory_change(file_dict, char_encoding=None):
    """Encode each episode's inventory history as a sequence of item characters.

    For every value ``v`` in ``file_dict`` (``v.inventory`` must already be a
    2-D array of steps x item columns), the inventory at each step is compared
    with the previous step. The step before the first is treated as an all-zero
    inventory. Steps with no change are dropped. Each remaining step is encoded
    by the character of the column with the largest difference (``argmax``,
    first column on ties). For a step with only decreases, that is the least
    negative column.

    Sets on each value:
      - ``sequence``: list of characters, one per changed step
      - ``seq_len``: ``len(sequence)``
      - ``enc_inventory``: (n_changes, n_items) array of inventory differences
      - ``enc_steps``: (n_changes,) array of the step indices that changed

    Args:
        file_dict: mapping of key -> FileObject-like value; updated in place.
        char_encoding: optional mapping column index -> character. Defaults to
            the module-level ``CharEncoding``.
    """
    encoding = CharEncoding if char_encoding is None else char_encoding
    for v in file_dict.values():
        inventory = np.asarray(v.inventory)
        if inventory.dtype.kind in 'biu':
            # Promote to signed 64-bit so item losses give negative differences
            # instead of wrapping around for unsigned/bool inventories.
            inventory = inventory.astype(np.int64)
        # Row i is inventory[i] - inventory[i - 1], with a zero row before step 0.
        start = np.zeros((1, inventory.shape[1]), dtype=inventory.dtype)
        diffs = np.diff(inventory, axis=0, prepend=start)
        changed = np.any(diffs != 0, axis=1)
        changes = diffs[changed]
        if len(changes):
            v.sequence = [encoding[int(col)] for col in np.argmax(changes, axis=1)]
        else:
            v.sequence = []
        v.seq_len = len(v.sequence)
        v.enc_inventory = changes
        v.enc_steps = np.flatnonzero(changed)


def compute_stats(file_dict):
    """Placeholder for per-episode action and inventory statistics.

    Not implemented yet: ``file_dict`` is left untouched and ``None`` is returned.
    """
    # TODO: extract action and inventory statistics
    return None


def extract_data(file_dict, consensus):
    """Align each episode's ``sequence`` against ``consensus`` and print the results.

    Only prints (prefixed with ``'extract'``); nothing is stored or returned.
    """
    for episode in file_dict.values():
        for aligned in align_sequences(episode.sequence, consensus):
            print('extract', aligned)


class SubtaskBuilder(Subprocess):
    """Derive a consensus inventory-change sequence and build tasks from it.

    ``run()`` makes sure the dataset is present, encodes episodes as
    inventory-change character sequences, selects a consensus sequence by
    pairwise alignment, and hands it to ``TaskBuilder.extract_dummy_tasks``.
    """

    def __init__(self, config):
        super().__init__('SubtaskBuilder', config)
        self.config = config
        self.task_builder = TaskBuilder(self.config)
        # episode directory -> FileObject, filled by run()
        self.file_dict = None
        # first return value of TaskBuilder.extract_dummy_tasks(consensus_encoding)
        self.list_of_tasks = None
        # the selected consensus (first sequence of the best alignment)
        self.consensus_encoding = None
        # second return value of extract_dummy_tasks (presumably readable names -- confirm)
        self.consensus_names = None

    def load_data(self):
        """Download the dataset for ``config.env.env`` unless its directory already exists."""
        datadir = self.config.subtask.datadir
        # exist_ok avoids a check-then-create race when several processes start together
        os.makedirs(datadir, exist_ok=True)
        target_file = os.path.join(datadir, self.config.env.env)
        # download data only if it does not exist already
        if not os.path.exists(target_file):
            minerl.data.download(directory=datadir, experiment=self.config.env.env)

    def build_file_dict(self):
        """Index episode files under the data directory.

        Returns:
            dict mapping each episode directory (dirname of its metadata file)
            to a FileObject with ``metadata_file`` and ``rendered_file`` set.
            Only episodes that have BOTH files are included: a rendered file
            without metadata cannot be checked for success, and every consumer
            in ``run()`` loads the rendered file. Files are visited in sorted
            order so the result (and thus the consensus) does not depend on
            filesystem listing order.
        """
        datadir = self.config.subtask.datadir
        metadata_pattern = os.path.join(datadir, '**/*metadata.json')
        metadata_files = sorted(glob.glob(metadata_pattern, recursive=True))

        rend_pattern = os.path.join(datadir, '**/*rendered.npz')
        rend_files = sorted(glob.glob(rend_pattern, recursive=True))

        file_dict = {}
        for fn in metadata_files:
            fo = FileObject()
            fo.metadata_file = fn
            file_dict[os.path.dirname(fn)] = fo
        for fn in rend_files:
            fo = file_dict.get(os.path.dirname(fn))
            if fo is not None:
                fo.rendered_file = fn
        return {key: fo for key, fo in file_dict.items() if fo.rendered_file is not None}

    @property
    def consensus(self):
        """Tasks built from the consensus encoding (same object as ``list_of_tasks``)."""
        return self.list_of_tasks

    def run(self):
        """Compute the consensus sequence, build tasks, then process all episodes.

        Consensus selection, using successful episodes only:
          1. encode episodes as inventory-change sequences;
          2. drop sequences shorter than 5 % of the longest, then those longer
             than the median of the rest;
          3. align all ordered pairs of distinct episodes and keep the top 5;
          4. re-align the sequences of those 5 alignments with each other;
             the first sequence of the best one is the consensus.

        Raises:
            RuntimeError: if no successful episodes exist, or too few remain
                after filtering to produce any alignment.
        """
        self.load_data()
        self.file_dict = self.build_file_dict()

        # Shallow copy: the filters delete keys from file_dict_ only; the
        # FileObjects themselves are shared with self.file_dict.
        file_dict_ = self.file_dict.copy()
        filter_successful_file_dict(file_dict_)
        if not file_dict_:
            raise RuntimeError('no successful episodes found in {}'.format(self.config.subtask.datadir))
        load_rendered_episodes(file_dict_)
        encode_inventory_change(file_dict_)

        # compute consensus
        filter_shortest_sequences(file_dict_, FilterType.LessThanNPercentOfMax, percent=0.05)
        filter_shortest_sequences(file_dict_, FilterType.GreaterThanMedian)

        seq_dict = {k: v.sequence for k, v in file_dict_.items()}
        alignments = pairwise_compare_sequences(key_pairs=product(seq_dict, seq_dict), seq_dict=seq_dict)
        top_n = rank_alignment(top=5, alignments=alignments)
        if not top_n:
            raise RuntimeError('at least two episodes must survive filtering to compute a consensus')
        alignments = pairwise_compare_sequences(key_pairs=product(top_n, top_n), seq_dict=top_n)
        top_n = rank_alignment(top=1, alignments=alignments)
        # first value inserted by rank_alignment = first sequence of the best alignment
        self.consensus_encoding = next(iter(top_n.values()))

        # build tasks
        self.list_of_tasks, self.consensus_names = self.task_builder.extract_dummy_tasks(self.consensus_encoding)

        # signal complete flag for parallel executions
        self._complete()

        # load all sequences (successful or not)
        file_dict_ = self.file_dict.copy()
        load_rendered_episodes(file_dict_)
        encode_inventory_change(file_dict_)

        # compute stats
        compute_stats(file_dict_)

        # extract data
        extract_data(file_dict_, self.consensus_encoding)

        print('done')


if __name__ == '__main__':
    # Script entry point. The config path is resolved relative to the
    # current working directory.
    config_path = 'configs/experiment/config.meta.json'
    config = Config(config_path)
    sb = SubtaskBuilder(config)
    sb.run()
