from argparse import ArgumentParser

import numpy as np

from offline.envs.custom.utils import save_dataset
from offline.envs.registration import (
    make_env_and_load_data_dict,
    get_ref_scores,
)
from offline.utils.suppress_warnings import (
    suppress_absl_warnings,
    suppress_gymnasium_warnings,
)


def build_argument_parser():
    """Create the command-line parser used by this script.

    The parser accepts one or more positional ``dataset_names`` followed by
    the options ``-f/--force`` (flag), ``-l/--label`` (required value) and
    ``-s/--silent`` (flag). The resulting namespace attributes match the
    keyword parameters of ``main``.

    Returns:
        The configured ``ArgumentParser``; no arguments are parsed here.
    """
    # Declared in registration order, which is also the order shown by --help.
    argument_table = (
        (("dataset_names",), {"nargs": "+"}),
        (("-f", "--force"), {"action": "store_true"}),
        (("-l", "--label"), {"required": True}),
        (("-s", "--silent"), {"action": "store_true"}),
    )
    cli = ArgumentParser()
    for names, options in argument_table:
        cli.add_argument(*names, **options)
    return cli


def main(
    dataset_names: list[str],
    force: bool,
    label: str,
    silent: bool,
) -> None:
    """Merge several datasets of one environment and save them as one.

    Every dataset is loaded with ``make_env_and_load_data_dict`` and tagged
    with an ``"infos/sources"`` array holding its position in
    ``dataset_names``. Unless ``silent`` is set, the arrays stored under the
    keys shared by all datasets (excluding keys starting with ``"metadata/"``)
    are concatenated along axis 0 and handed to ``save_dataset``.

    Args:
        dataset_names: Names of the datasets to merge, in merge order.
        force: Forwarded unchanged to ``save_dataset``.
        label: Forwarded unchanged to ``save_dataset``.
        silent: When true, datasets are still loaded and validated, but
            nothing is concatenated or saved.

    Raises:
        ValueError: If ``dataset_names`` is empty, a loaded dataset already
            contains ``"infos/sources"``, an environment has no spec, the
            datasets come from different environments, or their reference
            scores disagree.
    """
    if not dataset_names:
        raise ValueError("At least one dataset name is required.")

    suppress_absl_warnings()
    suppress_gymnasium_warnings()

    results = []
    max_scores = []
    min_scores = []
    env_ids = []
    for index, dataset_name in enumerate(dataset_names):
        env, data_dict, _ = make_env_and_load_data_dict(dataset_name)
        try:
            max_score, min_score = get_ref_scores(dataset_name)
        except KeyError:
            # Presumably no reference scores are registered for this dataset;
            # None is passed through to save_dataset in that case.
            max_score, min_score = None, None
        # Explicit raise rather than assert: the check must survive `-O`.
        if "infos/sources" in data_dict:
            raise ValueError(
                f"Dataset {dataset_name!r} already contains 'infos/sources'."
            )
        # Tag every transition with the index of the dataset it came from.
        data_dict["infos/sources"] = np.full_like(
            data_dict["rewards"], index, dtype=int
        )
        # Mark the final entry as a timeout when it is not a terminal.
        # NOTE(review): presumably so an unfinished trailing episode does not
        # run into the next dataset after concatenation -- confirm.
        if not data_dict["terminals"][-1]:
            data_dict["timeouts"][-1] = True
        if env.spec is None:
            raise ValueError("Cannot save datasets from envs with no EnvSpecs.")
        env_ids.append(env.spec.id)
        max_scores.append(max_score)
        min_scores.append(min_score)
        results.append(data_dict)

    if len(frozenset(env_ids)) > 1:
        raise ValueError(
            "Trying to merge datasets from different environments."
        )
    if len(frozenset(max_scores)) != 1:
        raise ValueError(
            "Trying to merge datasets with different reference max scores."
        )
    if len(frozenset(min_scores)) != 1:
        raise ValueError(
            "Trying to merge datasets with different reference min scores."
        )

    # NOTE(review): saving is skipped entirely when `silent` is set, so the
    # flag behaves like a validate-only mode -- confirm this is intended.
    if not silent:
        # Intersect once over all datasets; an empty intersection stays empty
        # instead of being mistaken for "not initialised yet".
        common_keys = set.intersection(
            *(set(data_dict.keys()) for data_dict in results)
        )
        total_results = {
            key: np.concatenate([res[key] for res in results], axis=0)
            for key in common_keys
            if not key.startswith("metadata/")
        }
        save_dataset(
            env_id=env_ids[0],
            force=force,
            label=label,
            max_score=max_scores[0],
            min_score=min_scores[0],
            results=total_results,
        )


if __name__ == "__main__":
    # Parse the command line, then pass each parsed attribute to main() as
    # the keyword argument of the same name.
    cli_namespace = build_argument_parser().parse_args()
    main(**vars(cli_namespace))
