import copy
from collections import Counter
from typing import Dict

import jax
import jax.numpy as jnp
import numpy as np
import optax
from ConfigSpace import Configuration, ConfigurationSpace, Float, Integer

from slimdqn.networks.base_dqn import BaseDQN


class HPGenerator:
    """Population-based hyperparameter search over the discount factor ``gamma``.

    Each population member is described by an ``hp_detail`` dict (here only ``"gamma"``) and is
    materialised by :meth:`from_hp_detail` into jitted functions, network parameters and an
    optimizer state. :meth:`exploit_and_explore` overwrites poorly-scoring members with perturbed
    copies of well-scoring ones.
    """

    # Fraction of the population overwritten per round by the "truncation" strategy.
    TRUNCATION_FRACTION = 0.2
    # Number of members competing in each "elitism" tournament (capped at the population size).
    TOURNAMENT_SIZE = 3
    # Half-width of the uniform relative perturbation applied to gamma in `explore`.
    GAMMA_PERTURBATION = 0.0005

    def __init__(
        self,
        key,
        observation_dim,
        n_actions,
        features,
        cnn,
        learning_rate,
        adam_eps,
        hp_space,
        exploitation_type,
        gamma_validation,
    ) -> None:
        """Store the network/optimizer settings and build the gamma configuration space.

        Args:
            key: PRNG key; ``int(key[1])`` seeds the ConfigSpace sampler (assumes a raw uint32
                key array such as ``jax.random.PRNGKey`` returns — TODO confirm for typed keys).
            observation_dim: shape of a single observation, used to initialise the network.
            n_actions: number of discrete actions, forwarded to ``BaseDQN``.
            features: forwarded to ``BaseDQN``.
            cnn: forwarded to ``BaseDQN``.
            learning_rate: Adam learning rate.
            adam_eps: Adam epsilon.
            hp_space: mapping whose ``"gamma_range"`` entry holds the (lower, upper) gamma bounds.
            exploitation_type: ``"elitism"`` or ``"truncation"`` (see :meth:`exploit`).
            gamma_validation: fixed discount factor used to build the validation targets.
        """
        self.observation_dim = observation_dim
        self.n_actions = n_actions
        self.features = features
        self.cnn = cnn
        self.learning_rate = learning_rate
        self.adam_eps = adam_eps
        self.hp_space = hp_space
        self.exploitation_type = exploitation_type
        self.gamma_validation = gamma_validation
        space = {"gamma": Float("gamma", bounds=(hp_space["gamma_range"][0], hp_space["gamma_range"][1]))}

        self.config_space = ConfigurationSpace(seed=int(key[1]), space=space)

    def from_hp_detail(self, key, hp_detail: Dict, params=None):
        """Build the jitted functions, parameters and optimizer state for one hyperparameter setting.

        Args:
            key: PRNG key used to initialise the network; unused when ``params`` is given.
            hp_detail: hyperparameters; ``hp_detail["gamma"]`` is the discount factor trained on.
            params: optional existing network parameters to start from instead of a fresh init.

        Returns:
            ``(hp_fn, params, optimizer_state)`` where ``hp_fn`` maps ``"apply_fn"``,
            ``"update_and_loss_fn"`` and ``"best_action_fn"`` to jitted callables, and
            ``optimizer_state`` is a freshly initialised Adam state for ``params``.
        """
        q = BaseDQN(self.features, self.cnn, self.n_actions)
        optimizer = optax.adam(self.learning_rate, eps=self.adam_eps)
        # Read the discount factors now: jit bakes closure values in at first trace, so binding
        # them here makes the function independent of any later mutation of `hp_detail`/`self`.
        gamma = hp_detail["gamma"]
        gamma_validation = self.gamma_validation

        def update_and_loss_fn(params, rewards, next_values, batch_samples, optimizer_state):
            # Bootstrapped targets with this member's gamma and with the fixed validation gamma;
            # both are handed to `value_and_grad`.
            targets = rewards + gamma * next_values
            targets_validation = rewards + gamma_validation * next_values

            loss, grad_loss = q.value_and_grad(params, targets, targets_validation, batch_samples)
            updates, optimizer_state = optimizer.update(grad_loss, optimizer_state, params)
            params = optax.apply_updates(params, updates)

            return loss, params, optimizer_state

        hp_fn = {
            "apply_fn": jax.jit(q.q_network.apply),
            "update_and_loss_fn": jax.jit(update_and_loss_fn),
            "best_action_fn": jax.jit(q.best_action),
        }
        if params is None:
            params = q.q_network.init(key, jnp.zeros(self.observation_dim))
        optimizer_state = optimizer.init(params)

        return hp_fn, params, optimizer_state

    def sample(self, key):
        """Sample a random hyperparameter setting and build a freshly initialised member for it.

        Returns:
            ``(hp_fn, params, optimizer_state, hp_detail)``.
        """
        hp_detail = dict(self.config_space.sample_configuration())

        hp_fn, params, optimizer_state = self.from_hp_detail(key, hp_detail)

        return hp_fn, params, optimizer_state, hp_detail

    @staticmethod
    def _nan_safe_argmax(values):
        """Index of the largest non-NaN entry of ``values``; 0 when every entry is NaN (or empty)."""
        values = np.asarray(values, dtype=float)
        if np.all(np.isnan(values)):
            return 0
        return int(np.nanargmax(values))

    def _exploit_elitism(self, key, metrics, n_networks):
        """Tournament selection that always keeps the best member; returns ``(key, new, replacing)``."""
        # Make sure the best HP is kept
        selected_indices = [self._nan_safe_argmax(metrics)]
        # Sampling is without replacement, so a tournament cannot exceed the population size.
        tournament_size = min(self.TOURNAMENT_SIZE, n_networks)
        for _ in range(n_networks - 1):
            key, selection_key = jax.random.split(key)
            random_indices = np.asarray(
                jax.random.choice(selection_key, jnp.arange(n_networks), (tournament_size,), replace=False)
            )
            selected_indices.append(int(random_indices[self._nan_safe_argmax(metrics[random_indices])]))

        selected_indices_counter = Counter(selected_indices)

        indices_new_hps = []
        indices_replacing_hps = []
        for idx in range(n_networks):
            n_selected = selected_indices_counter[idx]
            if n_selected == 0:
                # Never selected: this member gets overwritten.
                indices_new_hps.append(idx)
            else:
                # Selected k times: one copy stays in place, the other k - 1 copies are donated.
                indices_replacing_hps.extend([idx] * (n_selected - 1))
        return key, indices_new_hps, indices_replacing_hps

    def _exploit_truncation(self, metrics, n_networks):
        """Pair the worst ``TRUNCATION_FRACTION`` of members with the best; returns ``(new, replacing)``."""
        n_replaced = int(np.around(n_networks * self.TRUNCATION_FRACTION))
        # NaN scores rank as the worst possible value so they are replaced first. A full stable
        # sort (instead of argpartition + roll) guarantees both slices are exactly the extremes,
        # and also handles n_replaced == 0 without an out-of-bounds kth.
        ranking = np.argsort(np.where(np.isnan(metrics), -np.inf, metrics), kind="stable")
        indices_new_hps = ranking[:n_replaced]
        indices_replacing_hps = ranking[n_networks - n_replaced :]
        return indices_new_hps, indices_replacing_hps

    def exploit(self, key, metrics):
        """Choose which members to overwrite and which members to copy from.

        Args:
            key: PRNG key consumed by the tournaments of the "elitism" strategy.
            metrics: one score per population member (higher is better); NaN scores are treated
                as the worst and are replaced first.

        Returns:
            ``(key, indices_new_hps, indices_replacing_hps)``: the advanced key, the indices to
            overwrite, and, position by position, the indices whose state is copied into them.
            Both index sequences have the same length and are disjoint.

        Raises:
            ValueError: if ``self.exploitation_type`` is neither "elitism" nor "truncation".
        """
        metrics = np.asarray(metrics, dtype=float)
        n_networks = len(metrics)

        if self.exploitation_type == "elitism":
            return self._exploit_elitism(key, metrics, n_networks)
        if self.exploitation_type == "truncation":
            indices_new_hps, indices_replacing_hps = self._exploit_truncation(metrics, n_networks)
            return key, indices_new_hps, indices_replacing_hps
        raise ValueError(f"Unknown exploitation_type: {self.exploitation_type!r}")

    def explore(self, key, indices_new_hps, indices_replacing_hps, hp_fns, params, optimizer_states, hp_details):
        """Overwrite each member of ``indices_new_hps`` with a perturbed copy of its paired donor.

        Member ``indices_new_hps[i]`` receives a deep copy of the ``"params"`` collection of
        ``indices_replacing_hps[i]``, that donor's gamma multiplied by ``1 + u`` with
        ``u ~ Uniform(-GAMMA_PERTURBATION, GAMMA_PERTURBATION)`` and clipped to the configured
        gamma bounds, freshly jitted functions and a fresh optimizer state. ``hp_fns``,
        ``params``, ``optimizer_states`` and ``hp_details`` are updated in place.

        Returns:
            ``(indices_new_hps, hp_fns, params, optimizer_states, hp_details)``.
        """
        gamma_hp = self.config_space["gamma"]
        for new_idx, replacing_idx in zip(indices_new_hps, indices_replacing_hps):
            new_hp_detail = hp_details[replacing_idx].copy()
            old_params = {"params": copy.deepcopy(params[replacing_idx]["params"])}

            key, scale_key = jax.random.split(key)
            scale = float(
                jax.random.uniform(scale_key, (), minval=-self.GAMMA_PERTURBATION, maxval=self.GAMMA_PERTURBATION)
            )
            # Keep a plain Python float (not a device array) in hp_details.
            new_hp_detail["gamma"] = float(
                np.clip(hp_details[replacing_idx]["gamma"] * (1 + scale), gamma_hp.lower, gamma_hp.upper)
            )

            key, hp_key = jax.random.split(key)
            hp_fns[new_idx], params[new_idx], optimizer_states[new_idx] = self.from_hp_detail(
                hp_key, new_hp_detail, old_params
            )
            hp_details[new_idx] = new_hp_detail

        return indices_new_hps, hp_fns, params, optimizer_states, hp_details

    def exploit_and_explore(self, key, metrics, hp_fns, params, optimizer_states, hp_details):
        """Run :meth:`exploit` then :meth:`explore`; returns what :meth:`explore` returns."""
        explore_key, indices_new_hps, indices_replacing_hps = self.exploit(key, metrics)
        return self.explore(
            explore_key, indices_new_hps, indices_replacing_hps, hp_fns, params, optimizer_states, hp_details
        )
