import numpy as np
from jax import random
import jax.numpy as jnp
from base_agent import BaseAgent
import model as m
from optimizers import tuning_mattern


class DQNAgent(BaseAgent):
    """Deep Q-learning agent whose Q-network lives in the ``model`` module.

    The agent relies on attributes that this class reads but does not set
    itself (``key``, ``params``, ``buffer``, ``lr``, ``num_parall``,
    ``steps_trained``, ``discount_factor`` and the ``epsilon_*`` schedule
    values); they are presumably provided by ``BaseAgent`` -- verify there.
    """

    def __init__(self, layer_spec=None, **kwargs):
        """Set up prediction helpers, optimizer bookkeeping and parameters.

        Args:
            layer_spec: Network specification forwarded to
                ``m.init_network_params``. Its last entry is used by ``act``
                as the exclusive upper bound for random actions. When
                ``None`` no parameters are initialised here.
            **kwargs: Forwarded unchanged to ``BaseAgent.__init__``. Must
                contain ``"method"`` and ``"opt_name"``, and additionally
                ``"edim"`` when ``method == "optex"``.

        Raises:
            KeyError: If a required keyword argument is missing.
        """
        super().__init__(**kwargs)
        # Closures read ``self.params`` at call time, so they always use the
        # latest parameters after an update.
        self.predict = lambda observations: m.predict(self.params, observations)
        self.batched_predict = lambda observations: m.batch_func(m.predict)(
            self.params, observations
        )

        self.opt_state = None
        # Scratch dict shared with ``m.update`` across training steps.
        self.inter_results = {}
        self.method = kwargs["method"]
        self.opt_name = kwargs["opt_name"]
        self.effective_dim = -1
        if self.method == "optex":
            self.effective_dim = kwargs["edim"]
            self.inter_results.update({"length_scale": 1.0})

        if layer_spec is not None:
            # Consume a dedicated subkey for initialisation so the key that is
            # kept (and split again in ``act``) is never used for sampling.
            self.key, init_key = random.split(self.key)
            self.register_param(
                "params", m.init_network_params(layer_spec, init_key)
            )
            self.layer_spec = layer_spec
        # NOTE(review): when ``layer_spec`` is None this relies on
        # ``self.params`` already existing -- confirm against BaseAgent.
        self.params_shape = self.get_shapes(self.params)

    def act(self, observation, explore=True):
        """Choose an action for ``observation`` with epsilon-greedy exploration.

        ``self.epsilon`` is recomputed on every call as
        ``epsilon_decay ** steps_trained * (epsilon_init - epsilon_min)
        + epsilon_min``.

        Args:
            observation: Input passed to ``m.predict``.
            explore: When false the greedy action is always returned.

        Returns:
            The chosen action as a Python ``int``.
        """
        # Three-way split: one key is carried forward, one decides whether to
        # explore, one draws the random action. No key is used twice.
        self.key, explore_key, action_key = random.split(self.key, 3)

        self.epsilon = (self.epsilon_decay ** self.steps_trained) * (
            self.epsilon_init - self.epsilon_min
        ) + self.epsilon_min
        if explore and random.uniform(explore_key) < self.epsilon:
            action = random.randint(action_key, (), 0, self.layer_spec[-1])
        else:
            action = jnp.argmax(m.predict(self.params, observation))

        return int(action)

    def get_shapes(self, nested_params):
        """Return the ``.shape`` of every array in a two-level nested structure.

        Args:
            nested_params: Iterable of iterables of objects exposing
                ``.shape``.

        Returns:
            A list with one tuple of shapes per outer element.
        """
        return [
            tuple(array_param.shape for array_param in tuple_params)
            for tuple_params in nested_params
        ]

    def update(self, batch_size):
        """Sample from the buffer and run one optimisation step.

        Each row returned by ``self.buffer.sample_batch(batch_size)`` is
        indexed as ``(obs, actions, rewards, next_obs, dones)``; the exact
        shape of each entry is not visible here -- presumably one batch per
        row, verify against the buffer implementation.

        When ``method == "optex"`` the kernel length scale stored in
        ``self.inter_results`` is re-tuned on every fifth training step.

        Args:
            batch_size: Forwarded to ``self.buffer.sample_batch``.

        Returns:
            The loss value returned by ``m.update``.
        """

        def get_Q_for_actions(params, observations, actions):
            """Calculate Q values for the actions that were taken."""
            pred_Q_values = m.batch_func(m.predict)(params, observations)
            return index_Q_at_action(pred_Q_values, actions)

        sample = self.buffer.sample_batch(batch_size)
        obs = [row[0] for row in sample]
        actions = [row[1].astype(jnp.int64) for row in sample]
        r = [row[2] for row in sample]
        next_obs = [row[3] for row in sample]
        dones = [row[4] for row in sample]

        # Bootstrapped targets, computed independently for every sampled row.
        max_next_Q_values = [self.get_max_Q_values(o) for o in next_obs]
        target_Q_values = [
            self.get_target_Q_values(reward, done, max_q)
            for reward, done, max_q in zip(r, dones, max_next_Q_values)
        ]

        # Calculate loss and perform gradient descent
        loss, self.params, self.opt_state = m.update(
            self.method,
            self.opt_name,
            get_Q_for_actions,
            self.params,
            obs,
            actions,
            target_Q_values,
            self.opt_state,
            self.lr,
            self.params_shape,
            num_parall=self.num_parall,
            edim=self.effective_dim,
            inter_results=self.inter_results,
        )

        if self.method == "optex" and self.steps_trained % 5 == 0:
            # Histories are read from the shared dict; presumably written by
            # ``m.update`` -- verify there.
            xs = np.concatenate(self.inter_results["x_history"], axis=0)
            ys = np.concatenate(self.inter_results["g_history"], axis=0)
            n_points = len(xs)
            # NOTE(review): np.random.choice samples with replacement by
            # default, so the fitting subset may contain duplicates and the
            # held-out part is usually larger than 20% -- confirm intent.
            indices = np.random.choice(n_points, int(0.8 * n_points)).tolist()
            # Set membership keeps the hold-out construction linear.
            chosen = set(indices)
            target_indices = [i for i in range(n_points) if i not in chosen]

            # NOTE(review): effective_dim is hard-coded here rather than
            # taken from self.effective_dim -- confirm this is intended.
            length_scale = tuning_mattern(
                xs[indices],
                ys[indices],
                xs[target_indices],
                ys[target_indices],
                choice=[0.01, 0.1, 1, 10, 100],
                effective_dim=5000
            )
            self.inter_results.update({
                "length_scale": length_scale,
            })

        self.steps_trained += 1
        return loss

    def get_max_Q_values(self, next_obs):
        """Calculate max Q values for next state"""
        next_Q_values = self.batched_predict(next_obs)
        return jnp.max(next_Q_values, axis=-1)

    def get_target_Q_values(self, rewards, dones, max_next_Q_values):
        """Calculate target Q values based on discounted max next_Q_values"""
        # ``(1 - dones)`` removes the bootstrap term where ``dones`` is 1.
        return rewards + (1 - dones) * self.discount_factor * max_next_Q_values


class DQNFixedTarget(DQNAgent):
    """DQN variant that bootstraps from a separate ``target_params`` copy.

    ``target_params`` is overwritten with the online ``params`` whenever
    ``steps_trained`` is a multiple of ``update_every``.
    """

    def __init__(self, layer_spec=None, update_every=100, **kwargs):
        """Initialise the online agent and a second set of target parameters.

        Args:
            layer_spec: Network specification forwarded to ``DQNAgent`` and
                to ``m.init_network_params`` for the target network. When
                ``None`` no target parameters are initialised here.
            update_every: Interval, in training steps, at which the target
                parameters are replaced by the online parameters.
            **kwargs: Forwarded unchanged to ``DQNAgent.__init__``.
        """
        super().__init__(layer_spec=layer_spec, **kwargs)
        self.update_every = update_every
        # Draw a dedicated subkey for the target network: it differs from the
        # key used for the online network (so target_params != params) and
        # the key kept on the agent is never consumed for sampling.
        self.key, target_key = random.split(self.key)
        if layer_spec is not None:
            self.register_param(
                "target_params", m.init_network_params(layer_spec, target_key)
            )

        # Target functions
        self.batched_predict_target = lambda observations: m.batch_func(m.predict)(
            self.target_params, observations
        )

    def get_max_Q_values(self, next_obs):
        """Return the max target-network Q value for each entry of ``next_obs``.

        Also refreshes ``target_params`` from ``params`` when
        ``steps_trained`` is a multiple of ``update_every``.
        """
        if self.steps_trained % self.update_every == 0:
            # Jax arrays are immutable, so sharing the reference is safe.
            self.target_params = self.params
        next_Q_values = self.batched_predict_target(next_obs)
        return jnp.max(next_Q_values, axis=-1)


class DDQN(DQNFixedTarget):
    """Double DQN: the online network picks the action, the target scores it."""

    def __init__(self, **kwargs):
        """Forward every keyword argument to ``DQNFixedTarget.__init__``."""
        super().__init__(**kwargs)

    def get_max_Q_values(self, next_obs):
        """Return target-network Q values at the online network's argmax action.

        Refreshes ``target_params`` from ``params`` first when
        ``steps_trained`` is a multiple of ``update_every``.
        """
        sync_due = self.steps_trained % self.update_every == 0
        if sync_due:
            # Rebinding is enough because Jax arrays are immutable.
            self.target_params = self.params

        target_q = self.batched_predict_target(next_obs)
        greedy_actions = jnp.argmax(self.batched_predict(next_obs), axis=-1)
        return index_Q_at_action(target_q, greedy_actions)


def index_Q_at_action(Q_values, actions):
    """Select, along the last axis of ``Q_values``, the entry given by ``actions``.

    Args:
        Q_values: Array of Q values, e.g. ``[bsz, n_actions]``.
        actions: Integer array of action indices with one fewer trailing
            axis than ``Q_values``, e.g. ``[bsz]``.

    Returns:
        Array shaped like ``actions`` holding the selected Q values. A batch
        of size one keeps its batch axis.
    """
    # Give the indices a trailing axis so they line up with the action axis.
    idx = jnp.expand_dims(actions, -1)
    # Squeeze only the axis added above; a bare squeeze() would also drop a
    # batch axis of length one.
    return jnp.take_along_axis(Q_values, idx, -1).squeeze(-1)
