import numpy as np
import copy

from expground.types import AgentID, DataArray, Dict, Any, Tuple
from expground.logger import Log
from expground.common.schedules import LinearSchedule
from expground.algorithms.base_trainer import Trainer
from expground.algorithms.loss_func import LossFunc
from expground.algorithms.base_policy import Policy

from .policy import PPO
from .config import DEFAULT_CONFIG


def concat_dicts(dicts, shuffle: bool = False):
    """Concatenate a sequence of array dictionaries key by key.

    The key set is taken from the first dictionary; every other dictionary
    must provide those keys. Values with fewer than two dimensions are joined
    with ``np.hstack``; values with two or more dimensions are joined with
    ``np.vstack`` (i.e. along the first axis).

    The input dictionaries are never modified: a new dictionary is returned.

    Args:
        dicts: Non-empty sequence of dictionaries mapping keys to arrays.
        shuffle: If True, apply one shared random permutation (drawn from
            ``np.random``) along the first axis of every value, so entries
            that were aligned across keys stay aligned.

    Returns:
        A new dictionary holding the concatenated (and optionally shuffled)
        arrays.

    Raises:
        IndexError: If ``dicts`` is empty.
        KeyError: If a later dictionary lacks a key of the first one.
    """
    first = dicts[0]

    if len(dicts) == 1:
        # Nothing to join; shallow-copy so the shuffle below cannot rebind
        # entries of the caller's dictionary.
        res = dict(first)
    else:
        res = {}
        for k, v in first.items():
            # Stack all pieces in one call instead of re-copying the growing
            # result once per input dictionary.
            stack = np.hstack if np.ndim(v) < 2 else np.vstack
            res[k] = stack([d[k] for d in dicts])

    if shuffle and res:
        # The permutation length comes from the first value; all values are
        # indexed with the same permutation.
        n_batch = len(next(iter(res.values())))
        idx = np.random.permutation(n_batch)
        res = {k: v[idx] for k, v in res.items()}

    return res


class PPOTrainer(Trainer):
    """Trainer that prepares batches for, and post-processes, PPO updates."""

    def __init__(
        self,
        loss_func: LossFunc,
        training_config: Dict[str, Any],
        policy_instance: Policy,
    ):
        """Build the trainer with defaults merged under the given config.

        The effective configuration starts from a deep copy of
        ``DEFAULT_CONFIG["training_config"]``; entries of ``training_config``
        are then applied on top, so user-supplied values take precedence and
        the defaults only fill in missing keys. The caller's dictionary is
        not modified.

        Args:
            loss_func: Loss function forwarded to the base ``Trainer``.
            training_config: User overrides, or ``None`` to use the defaults
                unchanged (a warning is logged in that case).
            policy_instance: Policy forwarded to the base ``Trainer``.
        """
        merged_config = copy.deepcopy(DEFAULT_CONFIG["training_config"])
        if training_config is None:
            Log.warning(
                "PPOTrainer doesn't detect legal training config, will load default as:\n{}".format(
                    DEFAULT_CONFIG["training_config"]
                )
            )
        else:
            # User values must win over the defaults, not the other way round.
            merged_config.update(training_config)

        super().__init__(
            loss_func, training_config=merged_config, policy_instance=policy_instance
        )

    def _before_loss(self, policy: PPO, batch):
        """Collapse the per-agent batch mapping into a single batch.

        When the ``"share"`` training option is truthy, the batches of all
        agents are concatenated and shuffled via ``concat_dicts``. Otherwise
        exactly one agent batch is expected and returned as is.

        Args:
            policy: The policy being trained (unused here).
            batch: Mapping whose values are the individual batches.

        Returns:
            A tuple ``(batch, {})`` of the prepared batch and an empty dict.
        """
        # NOTE(review): self._training_config is presumably set by the base
        # Trainer from the config passed in __init__ -- confirm.
        if self._training_config["share"]:
            batch = concat_dicts(list(batch.values()), shuffle=True)
        else:
            assert (
                len(batch) == 1
            ), "PPO supports only single agent mode now, batched agents should less than 2."
            batch = list(batch.values())[0]

        return batch, {}

    def _after_loss(self, policy: PPO, step_counter: int):
        """Update the policy's target with the configured ``"tau"``.

        Args:
            policy: Policy whose ``update_target`` is called.
            step_counter: Current step count (unused here).
        """
        policy.update_target(tau=self._training_config["tau"])

    def set_pretrain(self, pmode):
        """No-op; ``pmode`` is ignored."""
        pass

    def get_eps(self):
        """No-op; always returns ``None``."""
        pass
