import time
import numpy

from dataclasses import dataclass
from typing import Dict, Any, Sequence, Union, Callable, Tuple, List, Iterator
from types import LambdaType, FunctionType
from collections import defaultdict

from torch import Tensor as torch_tensor

# Public type aliases shared by the configuration/feedback classes below.
Tensor = torch_tensor  # re-export of ``torch.Tensor``
DataArray = numpy.ndarray
AgentID = str
PolicyID = str
EnvConfig = Dict[str, Any]
# NOTE: there is intentionally no ``PolicyConfig = Dict[str, Any]`` alias:
# the name ``PolicyConfig`` is bound to the dataclass defined below, which
# would silently shadow such an alias at import time.


class MetricType:
    """String keys used to look up metrics in rollout results."""

    LIVE_STEP: str = "live_step"
    REWARD: str = "reward"


@dataclass
class PolicyConfig:
    """Description of a policy together with the spaces/configs used to build it.

    ``policy``, ``observation_space``, ``action_space``, ``model_config`` and
    ``custom_config`` may each be a concrete value or a plain function/lambda
    taking a key. Functions are called with that key whenever a concrete value
    is needed (see ``_resolve``). Classes are not plain functions, so a policy
    class is used as-is.
    """

    policy: Union[str, Callable]
    observation_space: Union[Dict, Callable] = None
    action_space: Union[Dict, Callable] = None
    human_readable: str = "policy"
    mapping: Callable = lambda agent_id: agent_id
    custom_config: Dict[str, Any] = None
    model_config: Dict[str, Any] = None

    def __post_init__(self):
        # Replace missing/empty configs with a fresh dict per instance.
        self.custom_config = self.custom_config or {}
        self.model_config = self.model_config or {}

    @staticmethod
    def _resolve(value, key):
        """Return ``value(key)`` if ``value`` is a plain function, else ``value``."""
        # LambdaType is types.FunctionType: it matches lambdas and ``def``
        # functions, but not classes or other callable objects.
        return value(key) if isinstance(value, LambdaType) else value

    @staticmethod
    def _copy_config(config):
        """Shallow-copy dict configs; pass function-valued configs through."""
        return dict(config) if isinstance(config, dict) else config

    def copy(self, key=None):
        """Return a new ``PolicyConfig`` with the spaces resolved for ``key``.

        Function-valued observation/action spaces are called with ``key``.
        ``policy``, ``human_readable`` and ``mapping`` are carried over
        unchanged. Dict configs are shallow-copied so mutating the copy's
        ``custom_config``/``model_config`` does not affect this instance.
        """
        return PolicyConfig(
            policy=self.policy,
            observation_space=self._resolve(self.observation_space, key),
            action_space=self._resolve(self.action_space, key),
            human_readable=self.human_readable,
            mapping=self.mapping,
            custom_config=self._copy_config(self.custom_config),
            model_config=self._copy_config(self.model_config),
        )

    def new_policy_instance(self, key) -> Any:
        """Build a policy for ``key``.

        Resolves every function-valued attribute with ``key``, then calls
        ``policy_cls(observation_space, action_space, model_config,
        custom_config)`` positionally.
        """
        policy_cls = self._resolve(self.policy, key)
        # Arguments are resolved left to right: spaces first, then configs.
        return policy_cls(
            self._resolve(self.observation_space, key),
            self._resolve(self.action_space, key),
            self._resolve(self.model_config, key),
            self._resolve(self.custom_config, key),
        )


@dataclass
class TrainingConfig:
    """Bundle of a trainer class and its hyper-parameter dict."""

    # The trainer class itself (a type, not an instance).
    trainer_cls: type
    # Free-form hyper-parameters keyed by name.
    hyper_params: Dict[str, Any]


@dataclass
class StopperConfig:
    """Configuration of a stopping criterion.

    ``config`` may be omitted or passed as ``None``; either way it becomes a
    fresh empty dict for this instance.
    """

    stopper: Union[str, Callable]
    config: Dict[str, Any] = None

    def __post_init__(self):
        # A new dict per instance, so instances never share a mutable default.
        if self.config is None:
            self.config = {}


@dataclass
class EnvDescription:
    """Named recipe for building an environment: a factory plus its kwargs."""

    creator: Callable
    name: str
    config: Dict[str, Any]

    def create_instance(self):
        """Build a new environment by calling ``creator(**config)``."""
        factory, kwargs = self.creator, self.config
        return factory(**kwargs)


@dataclass
class RolloutConfig:
    """Settings controlling how rollouts are collected.

    ``max_step``, ``num_simulation`` and ``max_episode`` must be positive;
    a ``ValueError`` is raised otherwise.
    """

    # Rollout function.
    caller: Callable

    # Fragment length (maximum number of data transitions) at each iteration.
    fragment_length: int

    # NOTE(review): older docs described this field as "the number of rollout
    # episodes at each iteration", which matches ``max_episode`` better;
    # presumably this caps steps per episode -- confirm with the rollout caller.
    max_step: int = 10

    # Simulation rounds for every iteration/policy combination.
    num_simulation: int = 10

    # Deprecated: rollout mode, either "episodic" or "stepping".
    mode: str = "episodic"

    # Whether to use a vectorized environment.
    vector_mode: bool = False

    # Presumably the number of rollout episodes at each iteration -- confirm
    # with the rollout caller.
    max_episode: int = 10

    # Whether to use a remote environment (defaults to False).
    remote_env: bool = False

    def __post_init__(self):
        # Explicit checks rather than ``assert`` so validation still runs under
        # ``python -O``; negative values are rejected as well as zero/None.
        for field_name in ("max_episode", "max_step", "num_simulation"):
            value = getattr(self, field_name)
            if not value or value <= 0:
                raise ValueError(
                    f"{field_name} must be a positive integer, got {value!r}"
                )


@dataclass
class TrainingFeedback:
    """Training and rollout results reported for one agent.

    Call ``merge`` to summarize ``trainings`` and ``rollouts`` before reading
    ``pretty_output`` or ``payoff``.
    """

    agent_id: AgentID
    # Per-epoch training results; the last entry must support ``.items()``.
    trainings: Sequence[Any]
    # Rollout results; each entry must support ``.items()`` with numeric values.
    rollouts: Sequence[Any]

    def __post_init__(self):
        self._merged_training_result = dict()
        self._merged_rollout_result = dict()
        self._info = dict()

    def pretty_output(self):
        """Return the summary dict built by the most recent ``merge`` call."""
        return self._info

    def merge(self):
        """Summarize the recorded training and rollout results.

        ``_info["training"]`` holds ``num_epoch`` (number of training entries)
        plus the metrics of the last training entry, if any.
        ``_info["rollout"]`` holds ``num_epoch`` (number of rollouts) plus the
        mean of each metric over the last 10 rollouts.
        """
        window = 10

        self._info["training"] = {"num_epoch": len(self.trainings)}
        self._info["rollout"] = {"num_epoch": len(self.rollouts)}

        # Average over the most recent rollouts only. Divide by the size of
        # that window, not by the total number of rollouts; otherwise the
        # means shrink once more than ``window`` rollouts have been recorded.
        recent = self.rollouts[-window:]
        means = defaultdict(float)
        for rollout in recent:
            for k, v in rollout.items():
                means[k] += v / len(recent)
        self._info["rollout"].update(means)

        # Report only the latest (converged) training results, if there are any.
        if self.trainings:
            for k, v in self.trainings[-1].items():
                self._info["training"][k] = v

        self._merged_training_result.update(self._info["training"])
        self._merged_rollout_result.update(self._info["rollout"])

    def payoff(self):
        """Return the merged mean rollout reward.

        Raises:
            KeyError: if ``merge`` has not produced a ``MetricType.REWARD`` entry.
        """
        return self._merged_rollout_result[MetricType.REWARD]


@dataclass
class RolloutFeedback:
    """Per-agent rewards and step counts from one rollout.

    ``identify`` defaults to a ``RolloutFeedback_<time.time()>`` tag; an
    explicitly supplied value is kept as-is.
    """

    agent_rewards: Dict[str, float]
    agent_steps: Dict[str, int]
    policy_mapping: Dict[AgentID, PolicyID]
    identify: str = None

    def __post_init__(self):
        # Generate a tag only when the caller did not provide one.
        if self.identify is None:
            self.identify = f"RolloutFeedback_{time.time()}"


@dataclass
class SimulationFeedback:
    """Accumulates per-agent rewards across simulation steps.

    ``identify`` defaults to a ``SimulationFeedback_<time.time()>`` tag; an
    explicitly supplied value is kept. ``agent_rewards`` always contains every
    agent in ``policy_mapping`` (starting at 0.0). Caller-supplied totals are
    copied in on top of those zeros, so the caller's dict is never mutated.
    """

    num_simulation: int
    policy_mapping: Dict[AgentID, PolicyID]
    agent_rewards: Dict[str, float] = None
    identify: str = None

    def __post_init__(self):
        if self.identify is None:
            self.identify = f"SimulationFeedback_{time.time()}"
        # Seed every mapped agent with 0.0 so ``step`` can add to it, then
        # layer any caller-supplied totals on top.
        rewards = {agent: 0.0 for agent in self.policy_mapping}
        if self.agent_rewards:
            rewards.update(self.agent_rewards)
        self.agent_rewards = rewards

    def step(self, agent_id: AgentID, observation, action, reward, info):
        """Add ``reward`` to ``agent_id``'s running total.

        ``observation``, ``action`` and ``info`` are accepted but not used here.

        Raises:
            KeyError: if ``agent_id`` has no entry in ``agent_rewards``.
        """
        self.agent_rewards[agent_id] += reward


class LearningMode:
    """String identifiers for on-policy vs. off-policy learning."""

    OFF_POLICY: str = "off_policy"
    ON_POLICY: str = "on_policy"


class TrainingParadigm:
    """String identifiers for the supported training paradigms."""

    # Fully centralized training and execution.
    CECT: str = "fully_centralized"
    # Centralized training with decentralized execution.
    DECT: str = "decentralized_execution_and_centralized_training"
