import numpy as np
from typing import Sequence

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

from opelab.core.baseline import Baseline
from opelab.core.data import DataType, to_numpy
from opelab.core.mlp import MLP


class BiasedDice(Baseline):
    """Off-policy value estimator based on a regularized Bellman-residual Q-fit.

    A Q-network is trained by minimizing, over its parameters::

        (1 - gamma) * E[Q(s0, a0)]
        + alpha_q * E[Q(s, a)^2]
        + alpha_xi * E[((alpha_r * r + gamma * Q(s', a') - Q(s, a)) / alpha_xi)^2]

    where ``s0`` is an initial state with ``a0`` drawn from the target policy,
    ``(s, a, r, s')`` is a dataset transition and ``a'`` is drawn from the
    target policy at ``s'``. The reported value is the mean of ``Q(s0, a0)``
    over all initial states in the data.
    """

    def __init__(self, alpha_r: float = 1.0, alpha_xi: float = 1.0, alpha_q: float = 1e-6,
                 lr: float = 1e-3,
                 layers: Sequence[int] = (128, 128),
                 epochs: int = 30, iter_per_epoch: int = 1000,
                 batch_size: int = 128,
                 seed: int = 0):
        """Build the Q-network, the Adam optimizer and the jitted update step.

        Args:
            alpha_r: scale applied to rewards inside the Bellman backup.
            alpha_xi: scale of the squared Bellman-residual term (the residual
                is divided by it, and the mean square is multiplied by it).
            alpha_q: weight of the L2 penalty on Q(s, a).
            lr: Adam learning rate.
            layers: hidden-layer widths of the Q-network; a width-1 output
                layer is appended. Any sequence (list or tuple) is accepted.
            epochs: number of training epochs.
            iter_per_epoch: gradient steps per epoch.
            batch_size: minibatch size for initial-state and transition samples.
            seed: seed for Q-network initialization and minibatch index sampling.
        """
        self.epochs = epochs
        self.iter_per_epoch = iter_per_epoch
        self.batch_size = batch_size
        self.seed = seed
        self.optimizer = optax.adam(lr)
        # list(...) accepts tuples/other sequences (``tuple + list`` would raise)
        # and never aliases the caller's object.
        self.Q = MLP(list(layers) + [1], nn.relu, output_activation=lambda s: s)

        def q_fn(params, states, actions):
            # Q takes the concatenated (state, action) vector as input.
            xus = jnp.concatenate([states, actions], axis=-1)
            return self.Q.apply(params, xus)

        def loss_fn(params, s0, a0, s, a, r, s2, a2, gamma):
            q0 = q_fn(params, s0, a0).reshape((-1, 1))
            q = q_fn(params, s, a).reshape((-1, 1))
            q2 = q_fn(params, s2, a2).reshape((-1, 1))
            backup = alpha_r * r + gamma * q2
            loss1 = (1 - gamma) * jnp.mean(q0)
            loss2 = alpha_q * jnp.mean(jnp.square(q))
            loss3 = alpha_xi * jnp.mean(jnp.square((backup - q) / alpha_xi))
            return loss1 + loss2 + loss3

        def update_fn(opt_state, params, s0, a0, s, a, r, s2, a2, gamma):
            # One Adam step on loss_fn; returns the pre-update loss.
            loss, grad = jax.value_and_grad(loss_fn)(params, s0, a0, s, a, r, s2, a2, gamma)
            updates, opt_state = self.optimizer.update(grad, opt_state, params=params)
            params = optax.apply_updates(params, updates)
            return opt_state, params, loss

        self.q_fn = jax.jit(q_fn)
        self.update_fn = jax.jit(update_fn)

    def evaluate(self, data, target, behavior, gamma=1, reward_estimator=None):
        """Train a fresh Q-network on ``data`` and estimate the value of ``target``.

        Prints the mean training loss and the current value estimate after
        every epoch.

        Args:
            data: iterable of trajectories; ``tau['states'][0]`` is taken as
                each trajectory's initial state.
            target: target policy; ``target.sample(states)`` must return one
                action per state (reshaped to ``(-1, action_dim)``).
            behavior: behavior policy, forwarded to ``to_numpy``.
            gamma: discount factor.
            reward_estimator: unused; kept for interface compatibility.

        Returns:
            Mean of Q(s0, a0) over all initial states, with ``a0`` sampled
            from ``target``, as a Python float. If ``epochs == 0``, this is
            the estimate of the untrained network.
        """
        init_states = np.stack([tau['states'][0].reshape((-1,)) for tau in data], axis=0)
        _, states_un, actions, _, next_states_un, rewards, _ = to_numpy(data, target, behavior)
        rewards = rewards.reshape((-1, 1))
        action_dim = actions.shape[1]

        def sample_actions(states):
            # Target-policy actions, flattened to match the dataset's action layout.
            return target.sample(states).reshape((-1, action_dim))

        def estimate_value(params):
            q0 = self.q_fn(params, init_states, sample_actions(init_states))
            return float(np.mean(q0))

        # Honor ``self.seed`` (previously ignored): it drives parameter init and
        # minibatch index sampling. Randomness inside ``target.sample`` is not
        # controlled here -- presumably the policy owns its own RNG.
        key = jax.random.PRNGKey(self.seed)
        rng = np.random.default_rng(self.seed)
        params = self.Q.init(key, jnp.ones((self.batch_size, states_un.shape[1] + action_dim)))
        opt_state = self.optimizer.init(params)

        value_est = None
        for epoch in range(self.epochs):
            mean_loss = 0.0
            for _ in range(self.iter_per_epoch):
                idx0 = rng.integers(0, init_states.shape[0], size=self.batch_size)
                idx = rng.integers(0, states_un.shape[0], size=self.batch_size)
                s0 = init_states[idx0]
                a0 = sample_actions(s0)
                s2 = next_states_un[idx]
                a2 = sample_actions(s2)
                opt_state, params, loss = self.update_fn(
                    opt_state, params, s0, a0,
                    states_un[idx], actions[idx], rewards[idx],
                    s2, a2, gamma)
                # Accumulate on-device; convert to float only for printing.
                mean_loss += loss / float(self.iter_per_epoch)
            value_est = estimate_value(params)
            print(f'epoch {epoch}, mean loss {float(mean_loss):.4f}, value {value_est:.4f}')

        if value_est is None:
            # epochs == 0: previously an UnboundLocalError; report the
            # untrained network's estimate instead.
            value_est = estimate_value(params)
        return value_est
            
        
