from typing import Dict
from ml_collections import ConfigDict
import numpy as np


def _trim_dict(d: Dict[str, np.ndarray], n: int):
    """Return a new dict keeping only the first ``n`` leading-axis entries of each array.

    Args:
        d: Mapping from name to array.
        n: Number of leading entries to keep along axis 0.

    Returns:
        A new dict with the same keys; each value is ``v[:n, ...]`` of the input.
    """
    trimmed = {}
    for key, arr in d.items():
        trimmed[key] = arr[:n, ...]
    return trimmed


def _append_dict(d1: Dict[str, np.ndarray], d2: Dict[str, np.ndarray]):
    """Concatenate two dicts of arrays key-by-key along axis 0.

    Args:
        d1: Mapping from name to array; its entries come first in the result.
        d2: Mapping with exactly the same keys as ``d1``.

    Returns:
        A new dict where each value is ``np.concatenate([d1[k], d2[k]], axis=0)``.

    Raises:
        ValueError: If the two dicts do not have the same set of keys.
    """
    # Explicit check instead of `assert`, so the validation still runs under
    # `python -O` (which strips assert statements).
    if d1.keys() != d2.keys():
        mismatched = sorted(set(d1.keys()) ^ set(d2.keys()), key=str)
        raise ValueError(
            f"Cannot append dicts with different keys; mismatched keys: {mismatched}"
        )
    return {k: np.concatenate([d1[k], d2[k]], axis=0) for k in d1}


class MultiDataLoader(object):
    """Load a transition dataset from one file, or splice the heads of two files.

    All methods are static; the class only serves as a namespace. A config
    either names a single ``file`` (loaded in its entirety) or names both
    ``file_0`` and ``file_1``, in which case the first ``rows_0`` rows of the
    former are followed by the first ``rows_1`` rows of the latter.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Build the default config, optionally overridden by ``updates``.

        Args:
            updates: Optional mapping/ConfigDict of overrides; references are
                resolved before it is merged into the defaults.

        Returns:
            A ``ConfigDict`` with keys ``file``, ``file_0``, ``rows_0``,
            ``file_1`` and ``rows_1``.
        """
        config = ConfigDict()

        # Just load a single file in its entirety:
        config.file = ""

        # Load and mix two files:
        config.file_0 = ""
        config.rows_0 = 0
        config.file_1 = ""
        config.rows_1 = 0

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @staticmethod
    def is_load_config_set(config: "ConfigDict"):
        """Return True if ``config`` names a single file or a first multi-file entry."""
        return (config.file != "") or (config.file_0 != "")

    @staticmethod
    def load(config: "ConfigDict", verbose=True):
        """Load the dataset described by ``config``.

        Args:
            config: Object with ``file``, ``file_0``, ``file_1``, ``rows_0``
                and ``rows_1`` attributes (see ``get_default_config``).
            verbose: When true, print a summary for each loaded file and for
                the spliced result.

        Returns:
            A dict of arrays as produced by ``loadfile``; in two-file mode the
            rows of ``file_0`` precede the rows of ``file_1``.

        Raises:
            ValueError: If ``file`` is combined with ``file_0``/``file_1``, or
                if ``file`` is empty and ``file_0``/``file_1`` are not both set.
        """
        if config.file != "":
            if config.file_0 != "" or config.file_1 != "":
                raise ValueError(
                    "Cannot specify any multi-file argument `file_i` when single-file argument `file` is set"
                )
            return MultiDataLoader.loadfile(config.file, verbose)

        if config.file_0 == "" or config.file_1 == "":
            raise ValueError(
                "Both `file_0` and `file_1` must be set. If you only want a single file, specify `file` instead."
            )

        p0 = MultiDataLoader.loadfile(config.file_0, verbose)
        p1 = MultiDataLoader.loadfile(config.file_1, verbose)

        rv = _append_dict(
            _trim_dict(p0, config.rows_0), _trim_dict(p1, config.rows_1)
        )

        if verbose:
            num_rows = len(rv["observations"])
            # Scale the per-row mean up to 1000 rows so the printed number
            # matches its label; guard against an empty splice.
            per_1000 = 1000 * np.sum(rv["rewards"]) / num_rows if num_rows else 0.0
            print(
                f"Spliced {num_rows} rows, mean reward per 1000 steps {per_1000}"
            )

        return rv

    @staticmethod
    def loadfile(filename: str, verbose: bool = True):
        """Read one data file into a dict of arrays.

        The file is read with ``np.load`` and must provide the entries ``s``,
        ``a``, ``sp``, ``r``, ``done`` and ``trajectories``. The last element
        of every per-step array is dropped; ``subseq_observations`` is ``s``
        shifted forward by one step.

        Args:
            filename: Path passed to ``np.load``.
            verbose: When true, print the row/trajectory counts and the mean
                reward per trajectory.

        Returns:
            Dict with keys ``observations``, ``actions``,
            ``next_observations``, ``subseq_observations``, ``rewards`` and
            ``dones`` (the latter cast to float32).
        """
        loaded = np.load(filename)
        try:
            num_traj = loaded["trajectories"]

            rv = dict(
                observations=loaded["s"][:-1],
                actions=loaded["a"][:-1],
                next_observations=loaded["sp"][:-1],
                subseq_observations=loaded["s"][1:],
                rewards=loaded["r"][:-1],
                dones=loaded["done"][:-1].astype(np.float32),
            )
        finally:
            # An .npz archive keeps its file handle open until closed; other
            # return types of np.load (plain arrays) have no close().
            close = getattr(loaded, "close", None)
            if close is not None:
                close()

        if verbose:
            print(f"Loaded from {filename}")
            print(
                f"Loaded {len(rv['observations'])} rows / {num_traj} trajectories from file, mean reward per episode {np.sum(rv['rewards'])/num_traj}"
            )

        return rv
