import enum
import tabulate
import torch
import numpy as np

from dataclasses import dataclass, fields

from expground.types import DataArray, Dict, Callable, Sequence, Tensor
from expground.logger import Log


class EpisodeKeys(enum.Enum):
    """String keys naming the data columns carried around with an episode.

    The values of OBSERVATION, REWARD, NEXT_OBSERVATION, ACTION, ACTION_MASK,
    ACTION_LOGITS, DONE and ACTION_DIST coincide with the field names of
    ``Episode``; the remaining members have no matching ``Episode`` field.
    """

    OBSERVATION = "observation"  # == Episode.observation
    REWARD = "reward"  # == Episode.reward
    ACC_REWARD = "accumulative_reward"
    NEXT_OBSERVATION = "next_observation"  # == Episode.next_observation
    ACTION = "action"  # == Episode.action
    ACTION_MASK = "action_mask"  # == Episode.action_mask
    ACTION_LOGITS = "logits"  # == Episode.logits (note the shorter value)
    NEXT_ACTION_LOGITS = "next_logits"
    NEXT_ACTION_MASK = "next_action_mask"
    DONE = "done"  # == Episode.done
    ACTION_DIST = "action_distribution"  # == Episode.action_distribution
    GLOBAL_STATE = "global_state"
    NEXT_GLOBAL_STATE = "next_global_state"
    STATE_VALUE = "state_value"
    INFO = "info"


@dataclass
class Episode:
    """Container for the data arrays of an episode.

    Every non-``None`` core field (all fields except ``extras``) must have the
    same ``len()``; every non-``None`` entry of ``extras`` must have that same
    length too. Both conditions are checked in ``__post_init__``.
    """

    observation: DataArray
    action: DataArray
    reward: DataArray
    next_observation: DataArray
    action_mask: DataArray
    done: DataArray
    action_distribution: DataArray
    logits: DataArray
    # Optional named arrays beyond the core fields; None means "no extras".
    extras: Dict[str, DataArray] = None

    def _core_items(self):
        """Yield ``(name, value)`` for each core field whose value is not None."""
        for fld in fields(self):
            if fld.name == "extras":
                continue
            value = getattr(self, fld.name)
            if value is not None:
                yield fld.name, value

    def __post_init__(self):
        """Check length consistency.

        Raises:
            AssertionError: if the non-None core fields do not all share one
                length (including when every core field is None), or if any
                non-None ``extras`` entry has a different length.
        """
        lens_dict = {name: len(value) for name, value in self._core_items()}
        lens_set = set(lens_dict.values())
        assert len(lens_set) == 1, f"Inconsistency between fields: {lens_dict}"
        expected = lens_set.pop()
        if self.extras:
            # None entries are skipped, matching what clean_data() keeps; an
            # empty extras dict is accepted.
            extra_lens = {k: len(v) for k, v in self.extras.items() if v is not None}
            assert all(
                n == expected for n in extra_lens.values()
            ), f"Inconsitency in extras: {extra_lens} expected length is: {expected}"

    def clean_data(self):
        """Return a flat dict of all non-None data.

        Core fields come first, then ``extras`` entries merged in at the top
        level. An extras key equal to a core field name overwrites that field.
        """
        res = dict(self._core_items())
        if self.extras is not None:
            res.update((k, v) for k, v in self.extras.items() if v is not None)
        for k, v in res.items():
            # getattr: values without a .shape (e.g. lists) must not crash logging.
            Log.debug("cleaned data for %s with shape %s", k, getattr(v, "shape", None))
        return res

    @staticmethod
    def _shape_str(value):
        """``str`` of ``value``'s shape as a tuple, or None if it has no shape."""
        shape = getattr(value, "shape", None)
        return None if shape is None else str(tuple(shape))

    @staticmethod
    def _nbytes(value):
        """Size of ``value``'s data in bytes, or None if it cannot be determined."""
        if isinstance(value, torch.Tensor):
            return value.element_size() * value.nelement()
        return getattr(value, "nbytes", None)

    def visualize(self):
        """Print and return a table of the episode's data: name, shape, bytes.

        Layout (values illustrative)::

            name                 shape           bytes
            -------------------  --------------  -------
            observation          (n_batch, ...)  xxx
            action               (n_batch, ...)  xxx
            ...

        Core fields that are not None are listed first, followed by the
        non-None ``extras`` entries. Cells whose shape or byte size cannot be
        determined are left blank.

        Returns:
            str: the rendered table.
        """
        items = list(self._core_items())
        if self.extras is not None:
            items.extend((k, v) for k, v in self.extras.items() if v is not None)
        rows = [[name, self._shape_str(v), self._nbytes(v)] for name, v in items]
        table = tabulate.tabulate(rows, headers=["name", "shape", "bytes"])
        print(table)
        return table


def default_dtype_mapping(dtype):
    """Map a numpy / Python dtype to the torch dtype used for tensor casting.

    Rules are tried in order; the first whose candidates contain ``dtype``
    (``in`` semantics, i.e. identity or ``==``) wins:

    * ``np.int32``, ``np.int64``, ``int``  -> ``torch.int32`` (int64 is narrowed)
    * ``float``, ``np.float32``            -> ``torch.float32``
    * ``np.float64``                       -> ``torch.float64``
    * ``bool``, ``np.bool_``               -> ``torch.float32``

    NOTE: a ``numpy.dtype`` compares equal to its corresponding Python type
    (``np.dtype("float64") == float``). Hence the ``.dtype`` of a float64
    array matches the second rule and yields ``torch.float32``; the
    ``torch.float64`` rule is only reached for the ``np.float64`` scalar type
    itself.

    Raises:
        NotImplementedError: if no rule matches ``dtype``.
    """
    rules = (
        ((np.int32, np.int64, int), torch.int32),
        ((float, np.float32), torch.float32),
        ((np.float64,), torch.float64),
        ((bool, np.bool_), torch.float32),
    )
    for candidates, torch_dtype in rules:
        if dtype in candidates:
            return torch_dtype
    raise NotImplementedError(f"dtype: {dtype} has no transmission rule.") from None


def walk(caster, v):
    """Recursively apply ``caster`` to the leaf values of ``v``.

    - ``None`` is returned unchanged (e.g. optional ``Episode`` fields such as
      the default ``extras``), so casters never receive ``None``.
    - An ``Episode`` is converted to a new dict of its instance attributes
      (``extras`` stays a nested dict), and every value is walked.
    - A dict (``Dict``) is walked value by value into a shallow copy
      (``v.copy()``); the caller's dict is not modified.
    - Anything else is passed to ``caster``.

    Args:
        caster (Callable): Function applied to each leaf value.
        v: The value to walk.

    Returns:
        The cast value; Episodes and dicts come back as dicts.
    """
    if v is None:
        return None
    if isinstance(v, Episode):
        v = vars(v)
    if isinstance(v, Dict):
        # Copy instead of assigning into v: the original dict (or the Episode's
        # __dict__) belongs to the caller.
        out = v.copy()
        for k, item in v.items():
            out[k] = walk(caster, item)
        return out
    return caster(v)


def tensor_cast(
    custom_caster: Callable = None,
    callback: Callable = None,
    dtype_mapping: Dict = None,
    device="cpu",
):
    """Build a decorator that casts a method's arguments into tensors.

    Every positional and keyword argument except ``self`` is passed through
    ``walk`` together with the caster; ``walk`` recurses into dicts.

    Args:
        custom_caster (Callable, optional): Caster used instead of the default
            one. Defaults to None.
        callback (Callable, optional): Called with the wrapped method's return
            value; its own result is ignored. Defaults to None.
        dtype_mapping (Callable, optional): Despite the ``Dict`` annotation,
            this is called as ``dtype_mapping(x.dtype)`` and must return a
            torch dtype. Used by the default caster only. Defaults to
            ``default_dtype_mapping``.
        device (str or torch.device, optional): Target device for the default
            caster. Defaults to "cpu".

    Returns:
        Callable: A decorator for methods (the first argument is not cast).
    """
    import functools

    dtype_mapping = dtype_mapping or default_dtype_mapping

    def _default_caster(x):
        # Existing tensors are passed through untouched.
        if isinstance(x, torch.Tensor):
            return x
        # Copy first so the tensor never aliases the caller's array (the copy is
        # also contiguous), then convert straight to the target dtype. Going
        # through torch.FloatTensor first would round integers above 2**24.
        return torch.as_tensor(x.copy(), dtype=dtype_mapping(x.dtype), device=device)

    cast_to_tensor = custom_caster or _default_caster

    def decorator(func):
        @functools.wraps(func)
        def wrap(self, *args, **kwargs):
            new_args = [walk(cast_to_tensor, arg) for arg in args]
            new_kwargs = {k: walk(cast_to_tensor, v) for k, v in kwargs.items()}
            rets = func(self, *new_args, **new_kwargs)
            if callback is not None:
                callback(rets)
            return rets

        return wrap

    return decorator
