import numpy as np
import os
import pprint
import random
from typing import Any, Mapping, Optional

from policy.sample_batch import SampleBatch, MultiAgentBatch
from utils.framework import try_import_torch

_printer = pprint.PrettyPrinter(indent=2, width=60)


def summarize(obj: Any) -> Any:
    """Return a pretty-formatted string for an object.

    This has special handling for pretty-formatting of commonly used data types
    in RLlib, such as SampleBatch, numpy arrays, etc.

    Args:
        obj: The object to format.

    Returns:
        The summarized object.
    """

    return _printer.pformat(_summarize(obj))


def _summarize(obj):
    if isinstance(obj, Mapping):
        return {k: _summarize(v) for k, v in obj.items()}
    elif hasattr(obj, "_asdict"):
        return {
            "type": obj.__class__.__name__,
            "data": _summarize(obj._asdict()),
        }
    elif isinstance(obj, list):
        return [_summarize(x) for x in obj]
    elif isinstance(obj, tuple):
        return tuple(_summarize(x) for x in obj)
    elif isinstance(obj, np.ndarray):
        if obj.size == 0:
            return _StringValue("np.ndarray({}, dtype={})".format(obj.shape, obj.dtype))
        elif obj.dtype == np.object or obj.dtype.type is np.str_:
            return _StringValue(
                "np.ndarray({}, dtype={}, head={})".format(
                    obj.shape, obj.dtype, _summarize(obj[0])
                )
            )
        else:
            return _StringValue(
                "np.ndarray({}, dtype={}, min={}, max={}, mean={})".format(
                    obj.shape,
                    obj.dtype,
                    round(float(np.min(obj)), 3),
                    round(float(np.max(obj)), 3),
                    round(float(np.mean(obj)), 3),
                )
            )
    elif isinstance(obj, MultiAgentBatch):
        return {
            "type": "MultiAgentBatch",
            "policy_batches": _summarize(obj.policy_batches),
            "count": obj.count,
        }
    elif isinstance(obj, SampleBatch):
        return {
            "type": "SampleBatch",
            "data": {k: _summarize(v) for k, v in obj.items()},
        }
    else:
        return obj


class _StringValue:
    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return self.value


def update_global_seed_if_necessary(seed: Optional[int] = None) -> None:
    """Seed global modules such as random, numpy, torch.

    This is useful for debugging and testing.

    Args:
        seed: An optional int seed. If None, will not do anything.
    """
    if seed is None:
        return

    # Python random module.
    random.seed(seed)
    # Numpy.
    np.random.seed(seed)

    # Torch.
    torch, _ = try_import_torch()
    torch.manual_seed(seed)
    # See https://github.com/pytorch/pytorch/issues/47672.
    cuda_version = torch.version.cuda
    if cuda_version is not None and float(torch.version.cuda) >= 10.2:
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = "4096:8"
    else:
        from distutils.version import LooseVersion

        if LooseVersion(torch.__version__) >= LooseVersion("1.8.0"):
            # Not all Operations support this.
            torch.use_deterministic_algorithms(True)
        else:
            torch.set_deterministic(True)
    # This is only for Convolution no problem.
    torch.backends.cudnn.deterministic = True
