import os
import pickle
import argparse
import numpy as np
import gzip
from tqdm import tqdm
import logging

from train.behavioral_cloning.datasets.minerl_data_pipeline import make, download as download_original

from torch.utils.data import Dataset

from train.behavioral_cloning.datasets.transforms import default_data_transform
from train.behavioral_cloning.spaces.action_spaces import BINARY_ACTIONS, ENUM_ACTIONS, ENUM_ACTION_OPTIONS
from train.behavioral_cloning.datasets.utils import stack_binary_actions, stack_enum_actions
from train.behavioral_cloning.spaces.action_spaces import ActionSpace
from train.behavioral_cloning.spaces.input_spaces import InputSpace

# Set the level of the data pipeline module's logger to INFO (runs once, at import time of this module).
logging.getLogger("train.behavioral_cloning.datasets.minerl_data_pipeline").setLevel(logging.INFO)


class MineRLBaseDataset(Dataset):
    """ MineRL Base-Dataset for Behavioural Cloning. """

    def __init__(self, root=None, sequence_len=10, future_steps=0, data_transform=default_data_transform,
                 preprocessors=(), seed=4711, queue_size=None, include_metadata=False, max_samples=None,
                 data_split=1, frame_skip=1, train=True, prepare=False, download=False,
                 experiment="MineRLTreechop-v0"):
        """ Load (or collect) the demonstration sequences and build the sample index.

        Parameters
        ----------
            root: data directory handed to `_dump_or_load_dataset_pkl`. If None, nothing is loaded and the
                dataset starts with an empty sequence list (sequences can be added with `add_sequence`).
            sequence_len, future_steps, frame_skip: window parameters used by `_prepare_sequence` and
                `_compute_sequence_step_indices`.
            data_transform: callable applied to the tuple of arrays of each prepared (sub-)sequence.
            preprocessors: iterable of callables, each applied to every stored sequence.
            seed, queue_size, include_metadata, max_samples, prepare, download, experiment:
                forwarded unchanged to `_dump_or_load_dataset_pkl`.
            data_split: fraction of the sequences (counted from the front of the list) that forms the
                training split; the remainder is the validation split.
            train: keep the training split if True, otherwise the validation split.
        """

        self.sequence_len = sequence_len
        self.future_steps = future_steps
        self.frame_skip = frame_skip
        self.data_transform = data_transform
        self.preprocessors = preprocessors

        # `_compute_action_statistics` only runs for some experiments (see `_dump_or_load_dataset_pkl`).
        # Define the attribute up-front so it always exists and is None when no statistics were computed.
        self.action_statistics = None

        # load sequences from disc
        if root is not None:
            self.sequences = self._dump_or_load_dataset_pkl(root, prepare, download, experiment, max_samples,
                                                            queue_size, seed, include_metadata)

            # split into train and validation set
            split_idx = int(len(self.sequences) * data_split)
            self.sequences = self.sequences[:split_idx] if train else self.sequences[split_idx:]
        else:
            self.sequences = []

        # prepare sequences for training
        self._pre_process()

        # compute train indices
        self.indices = None
        self._compute_sequence_step_indices()

    def __len__(self):
        """ Dataset size: one item per (sequence, step) pair held in `self.indices`. """
        index_pairs = self.indices
        return len(index_pairs)

    def __getitem__(self, item):
        """ Resolve `item` to its (sequence, step) pair and return the prepared sub-sequence. """
        sequence_id, end_step = self.indices[item]
        prepared = self._prepare_sequence(sequence_id, end_step)
        return prepared

    def add_sequence(self, sequence, insert_index=None):
        """ add new sequence to dataset

        The sequence is first passed through all configured preprocessors, then stored, and finally
        the (sequence, step) index of the dataset is rebuilt.

        Parameters
        ----------
            sequence: the sequence to add (see format below).
            insert_index: list position at which to insert the sequence. If None (default) the sequence
                is appended. Index 0 is a valid position and inserts at the front.

        Format:
        -------
            NOTE: `_prepare_sequence` unpacks every stored sequence as a 7-tuple, so a sequence must have
            this layout (after preprocessing) to be usable for training:

            sequence = (observation_dict, action_dict, reward_seq, next_observation_dict, done_seq,
                        steps_remaining, meta)

            observation_dict: odict_keys(['equipped_items', 'inventory', 'pov'])
                observation_dict['pov']: ndarray(uint8) with shape (n_steps, 64, 64, 3)
                observation_dict['inventory']: OrderedDict('coal', 'cobblestone', ...)
                    observation_dict['inventory']['coal']: array with shape (n_steps, )
                    ...
                observation_dict['equipped_items']: OrderedDict('mainhand', ...)
                    observation_dict['equipped_items']['mainhand']: OrderedDict
                    ...
            action_dict: odict_keys(['attack', 'back', 'camera', 'craft', 'equip', 'forward', 'jump', ...])
                action_dict['attack']: array with shape (n_steps, )
                    ...
            reward_seq: float array with shape (n_steps, )
            next_observation_dict: None (this is useless for us)
            done_seq: binary array with shape (n_steps, )
            steps_remaining: array with shape (n_steps, ) counting down to the end of the sequence
            meta: metadata dict or None
        """

        for processor in self.preprocessors:
            sequence = processor(sequence)

        # compare against None explicitly: a plain truthiness test would treat index 0 as "not given"
        if insert_index is not None:
            self.sequences.insert(insert_index, sequence)
        else:
            self.sequences.append(sequence)

        # re-compute indices of dataset
        self._compute_sequence_step_indices()

    def dump(self, index, experiment, root='.'):
        """ Pickle the sequence at position `index` into the gzip file "<root>/<experiment>.pkl.gz". """
        target_path = os.path.join(root, "%s.pkl.gz" % experiment)
        with gzip.open(target_path, "wb", compresslevel=7) as archive:
            print("dumping dataset pickle ...", end="")
            # protocol -1 selects the highest pickle protocol available
            pickle.dump(self.sequences[index], archive, protocol=-1)
            print("done! (%s)" % target_path)

    def remove_sequence(self, index):
        """ Delete the sequence stored at list position `index` and rebuild the sample index. """
        stored = self.sequences
        stored.pop(index)
        # the (sequence, step) pairs refer to list positions, so they have to be recomputed
        self._compute_sequence_step_indices()

    def get_preprocessed_sequence(self, seq_index):
        """ Prepare sequence `seq_index` over its whole length (no step window is cut out). """
        full_sequence = self._prepare_sequence(seq_index, step=None)
        return full_sequence

    def get_raw_observation(self, item):
        """ Return the raw (untransformed) actions recorded at the step of dataset item `item`.

        Returns
        -------
            5-tuple (None, step_action_dict, None, None, None); only the action slot is filled.
            step_action_dict maps each action name to its value at that step. Entries of the action
            dict that are 0-dimensional arrays are skipped.
        """
        i_seq, step = self.indices[item]

        # Sequences collected by this class are 7-tuples; only the action dict (element 1) is needed.
        # Index it directly instead of unpacking a fixed number of elements, which failed on 7-tuples.
        action_dict = self.sequences[i_seq][1]

        step_action_dict = {key: values[step] for key, values in action_dict.items() if values.ndim > 0}

        return None, step_action_dict, None, None, None

    def num_sequences(self):
        """ Count of sequences currently stored in the dataset. """
        sequence_count = len(self.sequences)
        return sequence_count

    def _prepare_sequence(self, i_seq, step=None):
        """ Turn stored sequence `i_seq` into the arrays used for training.

        Parameters
        ----------
            i_seq: position of the sequence in `self.sequences`.
            step: if given, only the window
                [step - frame_skip * (sequence_len + 1), step + frame_skip * future_steps)
                is used; if None the full sequence is used.

        Processing order: optional windowing, then `self.data_transform`, then keeping every
        `frame_skip`-th step, then transposing the inventory so that time is its first axis.

        Returns
        -------
            8-tuple (pov, binary_actions, camera_actions, enum_actions, inventory, equipped_items,
            reward_seq, steps_remaining). `inventory` / `equipped_items` stay None when the observation
            dict has no such entry (unless the data transform replaces them).

        Raises
        ------
            AssertionError: if the requested window does not lie inside the sequence.
        """

        # get sequence (next observation, done flags and metadata are not needed here)
        observation_dict, action_dict, reward_seq, _, _, steps_remaining, _ = self.sequences[i_seq]

        # prepare inputs (observation_dict.keys() -> 'equipped_items', 'inventory', 'pov')
        # move the last axis of the frames (channels) directly behind the time axis
        pov = np.transpose(observation_dict["pov"], (0, 3, 1, 2))
        binary_actions = stack_binary_actions(action_dict, BINARY_ACTIONS)
        enum_actions = stack_enum_actions(action_dict, ENUM_ACTIONS)
        camera_actions = action_dict["camera"]

        # inventory (one row per item, time along the second axis until the final transpose)
        inventory, equipped_items = None, None
        if "inventory" in observation_dict:
            inventory = np.c_[list(observation_dict["inventory"].values())]
        if "equipped_items" in observation_dict:
            equipped_items = observation_dict["equipped_items"]["mainhand"]["type"]

        # get required slice; explicit None test so that a step of 0 is not mistaken for "no step"
        if step is not None:
            start = step - (self.frame_skip * self.sequence_len) - (self.frame_skip * 1)
            stop = step + (self.frame_skip * self.future_steps)
            assert start >= 0
            assert stop <= pov.shape[0]
            window = slice(start, stop)

            pov = pov[window]
            binary_actions = binary_actions[window]
            camera_actions = camera_actions[window]
            reward_seq = reward_seq[window]
            steps_remaining = steps_remaining[window]

            if inventory is not None:
                inventory = inventory[:, window]

            if equipped_items is not None:
                equipped_items = equipped_items[window]

            if enum_actions is not None:
                enum_actions = enum_actions[window]

        # apply data transforms
        data = pov, binary_actions, camera_actions, enum_actions, inventory, equipped_items, reward_seq, steps_remaining
        pov, binary_actions, camera_actions, enum_actions, inventory, equipped_items, reward_seq, steps_remaining = \
            self.data_transform(data)

        # apply frame skipping
        pov = pov[::self.frame_skip]
        binary_actions = binary_actions[::self.frame_skip]
        camera_actions = camera_actions[::self.frame_skip]
        reward_seq = reward_seq[::self.frame_skip]
        steps_remaining = steps_remaining[::self.frame_skip]

        if inventory is not None:
            inventory = inventory[:, ::self.frame_skip]
        if equipped_items is not None:
            equipped_items = equipped_items[::self.frame_skip]
        if enum_actions is not None:
            enum_actions = enum_actions[::self.frame_skip]

        # transpose inventory to have time dimension in first axis
        # (guarded like every other inventory access: observations without inventory leave it None)
        if inventory is not None:
            inventory = inventory.T

        # nominal shapes as noted by the original author (S = number of returned steps):
        # pov: S, C, H, W
        # discrete_action_matrix: S, 8
        # camera_actions: S, 2
        # enum_actions: S, 5
        # inventory: S, 18
        # equipped_items: S, 1
        # rewards: S, 1
        return pov, binary_actions, camera_actions, enum_actions, inventory, equipped_items, reward_seq, steps_remaining

    def _pre_process(self):
        """ Run every preprocessor over all stored sequences, one preprocessor after the other. """
        for transform in self.preprocessors:
            # replace the list contents in place so the list object itself stays the same
            self.sequences[:] = [transform(sequence) for sequence in self.sequences]

    def _dump_or_load_dataset_pkl(self, root, prepare, download, experiment, max_samples, queue_size, seed,
                                  include_metadata):
        """ Collect the sequences from the MineRL data pipeline or load them from a pickle dump.

        The dump file is "<root>/<experiment>.pkl.gz". If `prepare` is set or the dump does not exist,
        the sequences are collected from the data pipeline (optionally downloading the data first);
        otherwise the existing dump is loaded.

        Parameters
        ----------
            root: data directory (also the location of the dump file).
            prepare: force re-collection; the dump file is only written when this is True.
            download: download the experiment data before collecting.
            experiment: name of the MineRL experiment.
            max_samples: stop collecting after this many sequences (not applied when loading a dump).
            queue_size, seed, include_metadata: forwarded to the pipeline's `sarsd_iter`.

        Every collected sequence is stored as the 7-tuple
        (observation_dict, action_dict, reward_seq, None, done_seq, steps_remaining, meta),
        where meta is None if `include_metadata` is False.

        Returns
        -------
            the list of sequences (also stored in `self.sequences`).
        """
        dump_file = os.path.join(root, "%s.pkl.gz" % experiment)
        needs_statistics = experiment in ["MineRLObtainDiamond-v0", "MineRLObtainDiamondDense-v0"]

        if prepare or not os.path.exists(dump_file):
            # download experiment
            if download:
                print("downloading dataset ...")
                download_original(root, experiment=experiment)

            # init minerl data
            self.data_pipeline = make(experiment, data_dir=root, num_workers=8)

            # collect train examples
            print("collecting sequences ...")
            self.sequences = []
            trajectories = self.data_pipeline.sarsd_iter(num_epochs=1, max_sequence_len=-1,
                                                         queue_size=queue_size, seed=seed,
                                                         include_metadata=include_metadata)
            for entry in tqdm(trajectories):

                # the pipeline appends the metadata as a sixth element only when asked for it
                if include_metadata:
                    observation_dict, action_dict, reward_seq, _, done_seq, meta = entry
                else:
                    observation_dict, action_dict, reward_seq, _, done_seq = entry
                    meta = None

                # add value function target: number of steps left until the end of the sequence
                steps_remaining = np.arange(len(reward_seq) - 1, -1, -1)

                # drop "next_observation_dict" as we don't need it in our dataloader
                self.sequences.append(
                    (observation_dict, action_dict, reward_seq, None, done_seq, steps_remaining, meta))

                if max_samples and len(self.sequences) >= max_samples:
                    break

            if prepare:
                # Write to a temporary file and rename on success. A failed dump must neither leave a
                # truncated "<experiment>.pkl.gz" behind (the next run would try to load it) nor
                # destroy a previously written, valid dump.
                tmp_file = dump_file + ".tmp"
                try:
                    with gzip.GzipFile(tmp_file, "wb", compresslevel=7) as fp:
                        print("dumping dataset pickle ...", end="")
                        pickle.dump(self.sequences, fp, -1)
                    os.replace(tmp_file, dump_file)
                    print("done! (%s)" % dump_file)
                except Exception as e:
                    # dumping is a best-effort speedup only: report the reason and carry on
                    print("Unable to pickle dataset for future speedups! (%s)" % e)
                    try:
                        os.remove(tmp_file)
                    except OSError:
                        # nothing to clean up (the temporary file was never created or is already gone)
                        pass

            if needs_statistics:
                self._compute_action_statistics()

        # load dataset pickle file
        else:
            # NOTE(security): unpickling can execute arbitrary code; only load dump files from a
            # trusted source.
            with gzip.open(dump_file, "rb") as fp:
                self.sequences = pickle.load(fp)
            if needs_statistics:
                self._compute_action_statistics()
            print("loaded dataset (%s)!" % dump_file)

        return self.sequences

    def _compute_action_statistics(self):
        """ Compute action usage statistics over all stored sequences.

        Requires metadata: the last element of every sequence must be a dict with a "stream_name"
        entry. Sequences are grouped by the basename of that name, cut at its last "-".

        The result is stored in `self.action_statistics` as [demonstration_statistics, action_statistics].
        Both are lists with one array per action head (binary actions first, then one head per enum
        action):
            demonstration_statistics: fraction of streams in which the action occurs at least once.
            action_statistics: total number of occurrences divided by the total number of steps.
        """
        n_binary = len(BINARY_ACTIONS)
        per_stream = {}
        total_steps = 0

        # accumulate raw counts, merging all sub-sequences that belong to the same stream
        for sequence in self.sequences:
            actions = sequence[1]
            name = os.path.basename(sequence[-1]["stream_name"])
            name = name[:name.rindex("-")]
            total_steps += len(actions[BINARY_ACTIONS[0]])

            if name not in per_stream:
                fresh = {a: 0 for a in BINARY_ACTIONS}
                fresh.update({e: {opt: 0 for opt in ENUM_ACTION_OPTIONS[e]} for e in ENUM_ACTIONS})
                per_stream[name] = fresh
            counts = per_stream[name]

            for a in BINARY_ACTIONS:
                counts[a] += actions[a].sum()
            for e in ENUM_ACTIONS:
                options = ENUM_ACTION_OPTIONS[e]
                for opt in options:
                    # enum actions are compared against the position of the option in its option list
                    counts[e][opt] += (actions[e] == options.index(opt)).sum()

        # one row per stream, one column per binary action / enum option
        n_columns = np.sum([len(opts) for opts in ENUM_ACTION_OPTIONS.values()]) + n_binary
        counts_matrix = np.zeros((len(per_stream), n_columns), dtype=np.float32)
        for row, counts in enumerate(per_stream.values()):
            flat = [counts[a] for a in BINARY_ACTIONS]
            flat.extend(counts[e][opt] for e in ENUM_ACTIONS for opt in ENUM_ACTION_OPTIONS[e])
            for col, value in enumerate(flat):
                counts_matrix[row, col] = value

        # calculate action frequency as occurrence per demonstration
        demo_freq = np.sum(counts_matrix > 0, axis=0) / counts_matrix.shape[0]
        per_action = counts_matrix.T

        # divvy up stats for individual action heads (binary head first, then each enum head)
        head_sizes = [n_binary] + [len(ENUM_ACTION_OPTIONS[e]) for e in ENUM_ACTIONS]
        demonstration_statistics = []
        action_statistics = []
        offset = 0
        for size in head_sizes:
            head = slice(offset, offset + size)
            demonstration_statistics.append(demo_freq[head])
            action_statistics.append(per_action[head].sum(1) / total_steps)
            offset += size

        self.action_statistics = [demonstration_statistics, action_statistics]

    def _compute_sequence_step_indices(self):
        """ Rebuild `self.indices`, the list of all usable (sequence index, step) pairs.

        A step is usable if frame_skip * (sequence_len + 1) frames precede it and
        frame_skip * future_steps frames are left from it on.
        """
        self.indices = []
        history = self.frame_skip * self.sequence_len + self.frame_skip * 1
        lookahead = self.frame_skip * self.future_steps
        for i_seq, seq in enumerate(self.sequences):
            n_frames = seq[0]["pov"].shape[0]
            self.indices.extend((i_seq, step) for step in range(history, n_frames - lookahead))


class MineRLDataset(MineRLBaseDataset):
    """ MineRL Dataset for Behavioural Cloning.

    Wraps the arrays produced by `MineRLBaseDataset` with an input space (builds the model inputs)
    and an action space (builds the optimization targets).
    """

    def __init__(self, input_space: InputSpace, action_space: ActionSpace, **kwargs):
        """ Build the base dataset from `kwargs` and remember the input / action space. """
        super().__init__(**kwargs)
        self.input_space = input_space
        self.action_space = action_space

    def __getitem__(self, item):
        """ return pre-processed training instance

        Returns a dict with the keys "inputs" (from `input_space.prepare`) and "targets"
        (from `action_space.prepare_input`).
        """
        pov, binary_actions, camera_actions, enum_actions, inventory, equipped_items, rewards, steps_remaining = \
            super().__getitem__(item)

        # prepare inputs for model processing
        inputs = self.input_space.prepare(pov, binary_actions, camera_actions, enum_actions, inventory, equipped_items,
                                          rewards, steps_remaining)

        # prepare optimization targets
        targets = self.action_space.prepare_input(binary_actions, camera_actions, enum_actions, inventory,
                                                  equipped_items, rewards, steps_remaining)

        return {"inputs": inputs, "targets": targets}

    def get_preprocessed_sequence(self, seq_index):
        """ return full length pre-processed sequence

        Returns a dict with the model inputs under "inputs"; "targets" is always None here.
        """
        pov, binary_actions, camera_actions, enum_actions, inventory, equipped_items, rewards, steps_remaining = \
            self._prepare_sequence(seq_index, step=None)

        # prepare inputs for model processing; pass the same arguments as __getitem__ does
        # (steps_remaining used to be left out here)
        inputs = self.input_space.prepare(pov, binary_actions, camera_actions, enum_actions, inventory, equipped_items,
                                          rewards, steps_remaining)

        return {"inputs": inputs, "targets": None}


from train.behavioral_cloning.spaces.action_spaces import MultiBinarySoftmaxCameraEnumActions
from train.behavioral_cloning.spaces.input_spaces import SingleFrameWithBinaryActionAndContinuousCameraSequence
from train.behavioral_cloning.datasets.transforms import divide_pov_by_255, resize_64_to_48, to_float32, \
    build_data_processor, random_left_right_flip

if __name__ == "__main__":
    # Command line entry point: build the dataset for every dataset directory found under --root
    # and store the action statistics next to the data.

    # init arg parser
    parser = argparse.ArgumentParser(description='Prepare MineRL Dataset.')
    parser.add_argument('--root', help='root path of dataset. (this is where the data will be downloaded)',
                        required=True, type=str)
    parser.add_argument("--env", type=str, default="MineRLObtainDiamond-v0",
                        help="env used (MineRLTreechop-v0, MineRLObtainDiamond-v0, ...)")
    parser.add_argument('--download', help='allow to download data.', action='store_true')
    parser.add_argument('--prepare', help='prepare dataset pickle.', action='store_true')
    parser.add_argument("--max_samples", type=int, default=None, help="maximum number of sequences")
    args = parser.parse_args()

    INPUT_SPACE = SingleFrameWithBinaryActionAndContinuousCameraSequence(32)
    ACTION_SPACE = MultiBinarySoftmaxCameraEnumActions(
        bins=np.array([-22.5, -17.5, -12.5, -7.5, -2.5, 2.5, 7.5, 12.5, 17.5, 22.5],
                      dtype=np.float32))
    DATA_TRANSFORM = build_data_processor([resize_64_to_48, to_float32, divide_pov_by_255])
    TRAIN_TRANSFORM = build_data_processor([resize_64_to_48, to_float32, divide_pov_by_255, random_left_right_flip])

    def _as_savable(parts):
        """ Pack a list of per-head arrays into a single array that np.savez_compressed accepts.

        Heads generally differ in length. NumPy >= 1.24 refuses to build an array from such a ragged
        list implicitly, so a 1-d object array holding the individual arrays is built explicitly in
        that case (reading it back with np.load then requires allow_pickle=True).
        """
        if len({np.shape(part) for part in parts}) <= 1:
            return np.asarray(parts)
        packed = np.empty(len(parts), dtype=object)
        for position, part in enumerate(parts):
            packed[position] = part
        return packed

    # Either --root itself holds the experiment, or each of its sub-directories may hold it.
    if os.path.exists(os.path.join(args.root, args.env)):
        set_list = [args.root]
    else:
        set_list = [os.path.join(args.root, e) for e in os.listdir(args.root)
                    if os.path.isdir(os.path.join(args.root, e))
                    and os.path.exists(os.path.join(args.root, e, args.env))]

    if not set_list:
        print("no dataset directory containing %s found under %s" % (args.env, args.root))

    for dataset in set_list:
        print(dataset)
        # initialize dataset
        minerl_data = MineRLDataset(root=dataset, max_samples=args.max_samples, train=True, download=args.download,
                                    prepare=args.prepare, experiment=args.env, data_split=1,
                                    sequence_len=32,
                                    include_metadata=True,
                                    input_space=INPUT_SPACE,
                                    action_space=ACTION_SPACE,
                                    data_transform=TRAIN_TRANSFORM,
                                    frame_skip=2
                                    )

        # Action statistics are only computed for some experiments; skip the export when absent.
        statistics = getattr(minerl_data, "action_statistics", None)
        if statistics is None:
            print("no action statistics available for %s, nothing to store." % args.env)
            continue
        demo_stats, act_stats = statistics

        # store statistics
        statfile = os.path.join(dataset, "action-statistics.npz")

        # one column name per binary action followed by "<enum action>_<option>" for every enum option
        colnames = list(BINARY_ACTIONS)
        colnames.extend(e + "_" + a for e in ENUM_ACTIONS for a in ENUM_ACTION_OPTIONS[e])
        np.savez_compressed(statfile, demonstration_statistics=_as_savable(demo_stats),
                            step_statistics=_as_savable(act_stats), columns=colnames)
