import collections

import numpy as np

from .basics import convert

# Attribute presets for ``Driver``. Each key is copied onto a Driver instance
# with ``setattr``. Judging by the prefixes, ``_m_*`` keys look like model
# settings and ``_e_*`` keys look like environment settings (NOTE(review):
# naming convention only — confirm with the consumers of these attributes).

# Starting preset used by ``Driver`` when ``Driver.WARMUP`` is True (non-eval).
WARMUP_PARAMS = dict(
    _m_unimix=1.0,
    _m_horizon=128,
    _m_ent_scale=1.0,
    _e_command_repeat_counts=512,
    _e_enable_replan=False,
    _e_num_vehicles=50,
)

# Obstacle scenario. Not referenced by ``Driver`` below; presumably swapped in
# by hand for obstacle fine-tuning — confirm.
OBS_FINETUNE_PARAMS = dict(
    _m_unimix=0.0,
    _m_horizon=128,
    _e_command_repeat_counts=512,
    _m_ent_scale=2.0,
    _e_enable_replan=True,
    _e_num_vehicles=0,
)

# Four-lane, 300-vehicle scenario: the default non-warm-up, non-eval preset.
FINETUNE_PARAMS = dict(
    _m_unimix=0.0,
    _m_horizon=16,
    _e_command_repeat_counts=64,
    _m_ent_scale=1.0,
    _e_enable_replan=True,
    _e_num_vehicles=300,
)

# Preset applied when ``Driver`` is constructed with ``mode="eval"``.
EVAL_PARAMS = dict(
    _m_unimix=0.0,
    _m_horizon=1,
    _e_command_repeat_counts=64,
    _m_ent_scale=1.0,
    _e_enable_replan=True,
    _e_num_vehicles=300,
)


class Driver:
    """Step a batched environment with a policy and collect per-env episodes.

    On construction the driver copies one parameter preset (EVAL_PARAMS,
    WARMUP_PARAMS or FINETUNE_PARAMS) onto itself as attributes. During
    warm-up, ``_update_parameters`` may later overwrite some of them based on
    the recent average episode reward. NOTE(review): how these attributes are
    consumed (presumably through a reference to this driver held by the policy
    or environment) is not visible here — confirm with callers.
    """

    # When True, non-eval drivers start from WARMUP_PARAMS and
    # ``_update_parameters`` follows the warm-up adaptive schedule.
    WARMUP = False

    # Not referenced by any method of this class; kept for external users.
    _CONVERSION = {
        np.floating: np.float32,
        np.signedinteger: np.int32,
        np.uint8: np.uint8,
        bool: bool,
    }

    def __init__(self, env, mode="train", **kwargs):
        """Create a driver.

        Args:
            env: batched environment. ``len(env)`` is the number of parallel
                envs (must be > 0). ``env.act_space`` maps action keys to
                objects with ``.shape``/``.dtype``, and ``env.step(acts)``
                returns ``(obs, info)``.
            mode: ``"eval"`` selects EVAL_PARAMS and enables the policy-state
                step reset after each finished episode. Any other value
                selects WARMUP_PARAMS or FINETUNE_PARAMS depending on
                ``Driver.WARMUP``.
            **kwargs: forwarded to the policy and to every step/episode
                callback.
        """
        assert len(env) > 0
        self._env = env
        self._kwargs = kwargs
        self._on_steps = []
        self._on_episodes = []
        self._mode = mode

        # Sum of rewards of every finished episode, in completion order.
        self._reward_history = []

        if mode == "eval":
            preset = EVAL_PARAMS
        elif Driver.WARMUP:
            preset = WARMUP_PARAMS
        else:
            preset = FINETUNE_PARAMS
        for key, value in preset.items():
            setattr(self, key, value)

        self.reset()

    def reset(self):
        """Zero pending actions, flag every env for reset, and drop partial episodes.

        Also discards the policy state so the next policy call starts from ``None``.
        """
        num_envs = len(self._env)
        self._acts = {
            k: convert(np.zeros((num_envs,) + v.shape, v.dtype))
            for k, v in self._env.act_space.items()
        }
        self._acts["reset"] = np.ones(num_envs, bool)
        self._eps = [collections.defaultdict(list) for _ in range(num_envs)]
        self._eps_info = [collections.defaultdict(list) for _ in range(num_envs)]
        self._state = None

    def on_step(self, callback):
        """Register ``callback(transition, info, env_index, **kwargs)``, called once per env per step."""
        self._on_steps.append(callback)

    def on_episode(self, callback):
        """Register ``callback(episode, episode_info, env_index, **kwargs)``, called when an env's episode ends."""
        self._on_episodes.append(callback)

    def __call__(self, policy, steps=0, episodes=0):
        """Drive the environment until at least ``steps`` batched steps AND
        ``episodes`` finished episodes (summed over all envs) have been reached.

        Args:
            policy: callable ``policy(obs, state, **kwargs) -> (acts, state)``.
            steps: minimum number of batched ``env.step`` calls.
            episodes: minimum number of finished episodes.
        """
        step, episode = 0, 0
        while step < steps or episode < episodes:
            step, episode = self._step(policy, step, episode)

    def _step(self, policy, step, episode):
        """Advance every env by one step and dispatch step/episode callbacks.

        Returns:
            ``(step + 1, episode + k)``, where ``k`` is the number of envs whose
            episode ended (``is_last``) on this step.
        """
        assert all(len(x) == len(self._env) for x in self._acts.values())
        # "log_*" entries are kept in the recorded transitions but are not sent
        # to the environment.
        acts = {k: v for k, v in self._acts.items() if not k.startswith("log_")}
        obs, info = self._env.step(acts)
        obs = {k: convert(v) for k, v in obs.items()}
        info = {k: convert(v) for k, v in info.items()}
        assert all(len(x) == len(self._env) for x in obs.values()), obs

        acts, self._state = policy(obs, self._state, **self._kwargs)
        acts = {k: convert(v) for k, v in acts.items()}
        if obs["is_last"].any():
            # Zero the actions of envs whose episode just ended; those envs get
            # a reset flag instead on the next env.step call.
            mask = 1 - obs["is_last"]
            acts = {k: v * self._expand(mask, len(v.shape)) for k, v in acts.items()}
        acts["reset"] = obs["is_last"].copy()
        self._acts = acts

        trns = {**obs, **acts}
        if obs["is_first"].any():
            # A new episode started: drop anything buffered for that env.
            for i, first in enumerate(obs["is_first"]):
                if first:
                    self._eps[i].clear()
                    self._eps_info[i].clear()

        for i in range(len(self._env)):
            trn = {k: v[i] for k, v in trns.items()}
            inf = {k: v[i] for k, v in info.items()}
            for k, v in trn.items():
                self._eps[i][k].append(v)
            for k, v in inf.items():
                self._eps_info[i][k].append(v)
            for fn in self._on_steps:
                fn(trn, inf, i, **self._kwargs)

        step += 1

        if obs["is_last"].any():
            for i, done in enumerate(obs["is_last"]):
                if done:
                    self._finish_episode(i)
                    episode += 1
                    if self._mode == "eval":
                        self._reset_eval_step_counters()

        return step, episode

    def _finish_episode(self, i):
        """Hand env ``i``'s buffered episode to callbacks, log its reward, and adapt parameters."""
        ep = {k: convert(v) for k, v in self._eps[i].items()}
        ep_info = {k: convert(v) for k, v in self._eps_info[i].items()}
        for fn in self._on_episodes:
            # Each callback receives its own shallow copies of the dicts.
            fn(ep.copy(), ep_info.copy(), i, **self._kwargs)
        ep_reward = ep["reward"].astype(np.float64).sum()
        self._reward_history.append(ep_reward)
        print("Episode reward: ", ep_reward, "Average reward: ", np.mean(self._reward_history[-100:]))
        self._update_parameters()

    def _reset_eval_step_counters(self):
        """Zero the ``"step"`` entry of policy-state elements 1 and 2.

        Each slot is zeroed with its OWN shape. The previous code rebuilt slot 2
        from slot 1's array, which is wrong whenever the two shapes differ.
        NOTE(review): the result is forced to int8 regardless of the original
        dtype, as before — confirm that is intended. Does nothing when the
        policy returned no state (``None``).
        """
        if self._state is None:
            return
        for slot in (1, 2):
            entry = self._state[slot]
            entry["step"] = np.zeros_like(entry["step"], np.int8)

    def _expand(self, value, dims):
        """Append trailing singleton axes to ``value`` until it has ``dims`` dimensions (for broadcasting)."""
        while len(value.shape) < dims:
            value = value[..., None]
        return value

    def _update_parameters(self):
        """Adapt preset attributes from the mean reward of the last 100 episodes.

        Does nothing until 100 episode rewards have been recorded. Then it
        applies the FIRST event whose threshold the average strictly exceeds.
        Events are listed in descending threshold order, so the highest
        exceeded threshold wins. Only the warm-up schedule has events.
        NOTE(review): this also runs in eval mode, so with ``Driver.WARMUP``
        set it could override EVAL_PARAMS — confirm that is intended.
        """
        if len(self._reward_history) < 100:
            return

        avg_reward = np.mean(self._reward_history[-100:])
        print(f"[Driver] Average reward: {avg_reward}")

        if Driver.WARMUP:
            adaptive_events = [
                (450.0, {"_m_ent_scale": 1.0}),
                (250.0, {"_e_num_vehicles": 300, "_e_enable_replan": True}),
                (120.0, {"_m_ent_scale": 1.5}),
                (100.0, {"_m_unimix": 0.0, "_m_horizon": 16, "_e_command_repeat_counts": 64, "_m_ent_scale": 3.0}),
            ]
        else:
            # No adaptive schedule outside warm-up. Previously considered:
            # (500.0, {'_m_ent_scale': 1.0, '_m_horizon': 16, '_e_command_repeat_counts': 64}),
            adaptive_events = []

        for threshold, params in adaptive_events:
            if avg_reward > threshold:
                for key, value in params.items():
                    setattr(self, key, value)
                print(f"[Driver] Adaptively adjusted parameters to {params} as reward exceeded {threshold}!")
                break
