import numpy as np
from typing import Sequence
from tqdm import tqdm

import flax.linen as nn
from flax.training import train_state
import jax 
import jax.numpy as jnp
import jax.random as jrandom
from jax import vmap, grad, jit, value_and_grad
import optax 

from opelab.core.baseline import Baseline
from opelab.core.data import DataType, to_numpy
from opelab.core.mlp import MLP
from opelab.core.policy import Policy


class BestDice(Baseline):
    """Saddle-point off-policy estimator trained on logged transitions.

    Three networks and one scalar are learned jointly:

    * ``Q``     -- linear-output MLP on concatenated (state, action) pairs,
    * ``delta`` -- linear-output MLP on (state, action) whose squared output
      is penalised in the objective,
    * ``zeta``  -- softplus-output (hence non-negative) MLP on (state, action),
    * ``lamb``  -- scalar multiplier added to the objective.

    ``Q``, ``delta`` and ``lamb`` take descent steps on the objective built in
    ``loss_fn``; ``zeta`` takes an ascent step on it. The policy value is then
    estimated from the zeta-weighted rewards (see ``evaluate``).
    """

    def __init__(self,
                 alpha_r,
                 lr_q: float = 1e-4,
                 lr_delta: float = 1e-4,
                 lr_zeta: float = 1e-4,
                 lr_lambda: float = 1e-4,
                 layers: Sequence[int] = (500, 500),
                 epochs: int = 50,
                 batch_size: int = 256,
                 verbose: int = 0,
                 seed: int = 0):
        """Build the networks, optimisers and jitted train/predict functions.

        Args:
            alpha_r: Scale applied to rewards inside the objective.
            lr_q: Adam learning rate for ``Q``.
            lr_delta: Adam learning rate for ``delta``.
            lr_zeta: Adam learning rate for ``zeta``.
            lr_lambda: Step size of the plain gradient step on ``lamb``.
            layers: Hidden-layer widths shared by all three MLPs (any sequence,
                e.g. list or tuple).
            epochs: Number of passes over the transition data.
            batch_size: Transitions (and sampled initial states) per step.
            verbose: If truthy, ``evaluate`` prints diagnostic statistics.
            seed: Seed for parameter initialisation and batch shuffling.
        """
        self.lr_q = lr_q
        self.lr_delta = lr_delta
        self.lr_zeta = lr_zeta
        self.lr_lambda = lr_lambda
        self.epochs = epochs
        self.batch_size = batch_size
        self.verbose = verbose
        self.seed = seed
        self.q_optimizer = optax.adam(lr_q)
        self.delta_opt = optax.adam(lr_delta)
        self.zeta_opt = optax.adam(lr_zeta)
        # NOTE: kept for interface compatibility; ``lamb`` is updated with a
        # plain gradient step in ``train_fn``, not with this optimiser.
        self.lambda_opt = optax.adam(lr_lambda)
        self.alpha_r = alpha_r
        self.processed_data = False

        # list(...) so that tuples (or any Sequence) work with the "+ [1]"
        # output head, and so the default is not a shared mutable list.
        hidden = list(layers)
        self.Q = MLP(hidden + [1], nn.relu, output_activation=lambda s: s)
        self.delta = MLP(hidden + [1], nn.relu, output_activation=lambda s: s)
        self.zeta = MLP(hidden + [1], nn.relu, output_activation=lambda s: nn.softplus(s))
        self.lamb = jnp.array(0.0)

        @jit
        def predict_w_fn(params, states, actions):
            """Apply ``self.Q``'s architecture row-wise to (state, action) pairs.

            Also used with ``delta`` parameters (same architecture). 1-D action
            arrays are promoted to a trailing feature axis before concatenation.
            """
            actions = actions[:, None] if actions.ndim == 1 else actions
            xus = jnp.concatenate([states, actions], axis=-1)
            return vmap(self.Q.apply, in_axes=(None, 0))(params, xus)

        @jit
        def predict_zeta_fn(params, states, actions):
            """Apply the softplus-output ``zeta`` network row-wise to (state, action) pairs."""
            actions = actions[:, None] if actions.ndim == 1 else actions
            xus = jnp.concatenate([states, actions], axis=-1)
            return vmap(self.zeta.apply, in_axes=(None, 0))(params, xus)

        def train_fn(q_state, delta_state, zeta_state, lamb, initial_states, initial_actions,
                     states, actions, rewards, next_states, next_actions, gamma):
            """Run one alternating update: descend on Q/delta/lamb, then ascend on zeta.

            Returns the updated ``(q_state, delta_state, zeta_state, lamb, loss)``,
            where ``loss`` is the objective value before any update.
            """
            # Network outputs are (B, 1). Shape rewards as a (B, 1) column so the
            # products below stay elementwise; a 1-D (B,) reward vector would
            # otherwise broadcast against them into a (B, B) matrix.
            rewards = jnp.reshape(rewards, (-1, 1))

            def loss_fn(q_params, delta_params, zeta_params, lamb):
                # first line of equation 10
                first_component = (1 - gamma) * jnp.mean(
                    predict_w_fn(q_params, initial_states, initial_actions)) + lamb

                # second line of equation 10
                zetas = predict_zeta_fn(zeta_params, states, actions)
                delta_sa = predict_w_fn(delta_params, states, actions)
                rhs = (self.alpha_r * rewards
                       + gamma * predict_w_fn(q_params, next_states, next_actions)
                       - predict_w_fn(q_params, states, actions)
                       - lamb
                       - delta_sa)
                second_component = jnp.mean(zetas * rhs)

                # third line
                third_component = jnp.mean(jnp.square(delta_sa))
                return first_component + second_component + third_component

            # Descent step on Q, delta and lamb (argnums 0, 1, 3).
            loss, grads = value_and_grad(loss_fn, argnums=(0, 1, 3))(
                q_state.params, delta_state.params, zeta_state.params, lamb)
            q_state = q_state.apply_gradients(grads=grads[0])
            delta_state = delta_state.apply_gradients(grads=grads[1])
            # Plain gradient-descent step; the previous bound-less jnp.clip was a no-op.
            lamb = lamb - self.lr_lambda * grads[2]

            # Ascent step on zeta, evaluated at the freshly updated Q/delta/lamb.
            zeta_grad = grad(lambda a, b, c, d: -loss_fn(a, b, c, d), argnums=2)(
                q_state.params, delta_state.params, zeta_state.params, lamb)
            zeta_state = zeta_state.apply_gradients(grads=zeta_grad)

            return q_state, delta_state, zeta_state, lamb, loss

        self.predict_w_fn = predict_w_fn
        self.predict_zeta_fn = jit(predict_zeta_fn)
        self.train_fn = jit(train_fn)

    def train_networks(self, data, traj_data, target: Policy, behavior: Policy, gamma: float):
        """Train Q, delta, zeta and lamb from scratch on ``data``.

        Args:
            data: Tuple ``(states, states_un, actions, next_states,
                next_states_un, rewards, policy_ratio, terminals)``; only
                states, actions, next_states and rewards are used here.
            traj_data: Iterable of trajectories; each must provide
                ``tau['states'][0]`` (its first state).
            target: Policy whose ``sample`` supplies initial and next actions.
            behavior: Unused here; kept for interface compatibility.
            gamma: Discount factor passed to the objective.

        Returns:
            ``(q_params, delta_params, zeta_params, lamb)``, also stored on ``self``.

        Raises:
            ValueError: If ``data`` contains no transitions.
        """
        key = jrandom.PRNGKey(self.seed)

        # Initial-state pool: the first state of every trajectory, paired with
        # a single action sampled once from the target policy.
        init_states = np.stack([tau['states'][0] for tau in traj_data])
        init_actions = target.sample(init_states)

        states, states_un, actions, next_states, next_states_un, rewards, policy_ratio, terminals = data

        num_samples = len(states)
        if num_samples == 0:
            raise ValueError("BestDice.train_networks requires at least one transition")

        # Parameter init only needs one example input. Build it from the first
        # transition instead of concatenating the whole dataset, and promote a
        # scalar action to 1-D the same way predict_w_fn handles 1-D actions.
        sample_xu = jnp.concatenate(
            [jnp.asarray(states[0]), jnp.atleast_1d(jnp.asarray(actions[0]))], axis=-1)

        q_params = self.Q.init(jrandom.PRNGKey(self.seed), sample_xu)
        delta_params = self.delta.init(jrandom.PRNGKey(self.seed), sample_xu)
        zeta_params = self.zeta.init(jrandom.PRNGKey(self.seed), sample_xu)
        lamb = self.lamb  # NOTE(review): carried over from any previous call, not reset

        q_state = train_state.TrainState.create(
            apply_fn=self.predict_w_fn, params=q_params, tx=self.q_optimizer)
        delta_state = train_state.TrainState.create(
            apply_fn=self.predict_w_fn, params=delta_params, tx=self.delta_opt)
        zeta_state = train_state.TrainState.create(
            apply_fn=self.predict_w_fn, params=zeta_params, tx=self.zeta_opt)

        # Full batches only (a trailing partial batch is dropped, which avoids
        # recompiling train_fn for a second shape), but always at least one
        # batch so the per-epoch statistics are never taken over empty lists.
        num_batches = max(1, num_samples // self.batch_size)

        with tqdm(range(self.epochs)) as tp:
            for _ in tp:
                key, b_key = jrandom.split(key)
                batch_ordering = jrandom.permutation(b_key, jnp.arange(num_samples))
                epoch_loss = []
                density_ratios_list = []
                density_ratios_list_max = []
                for j in range(num_batches):
                    batch = batch_ordering[j * self.batch_size:(j + 1) * self.batch_size]
                    batch_states = states[batch]
                    batch_actions = actions[batch]
                    batch_rewards = rewards[batch]
                    batch_next_states = next_states[batch]
                    batch_next_actions = target.sample(batch_next_states)

                    # Initial states are resampled with replacement each step.
                    key, i_key = jrandom.split(key)
                    batch_init = jrandom.randint(i_key, (self.batch_size,), 0, len(init_states))
                    batch_init_states = init_states[batch_init]
                    batch_init_actions = init_actions[batch_init]

                    q_state, delta_state, zeta_state, lamb, loss_val = self.train_fn(
                        q_state, delta_state, zeta_state, lamb,
                        batch_init_states, batch_init_actions,
                        batch_states, batch_actions, batch_rewards,
                        batch_next_states, batch_next_actions, gamma)

                    density_ratios = self.predict_zeta_fn(zeta_state.params, batch_states, batch_actions)

                    epoch_loss.append(loss_val)
                    density_ratios_list_max.append(jnp.max(density_ratios))
                    density_ratios_list.append(jnp.mean(density_ratios))

                tp.set_postfix(loss=np.mean(epoch_loss),
                               dratios=np.max(density_ratios_list_max),
                               dratios_mean=np.mean(density_ratios_list))

        self.q_params = q_state.params
        self.delta_params = delta_state.params
        self.zeta_params = zeta_state.params
        self.lamb = lamb
        return self.q_params, self.delta_params, self.zeta_params, self.lamb

    def evaluate(self, data, target, behavior, gamma=1, reward_estimator=None):
        """Train the networks and return the zeta-weighted reward estimate.

        Args:
            data: Raw trajectory data. It is used directly for the initial
                states, and converted via ``to_numpy`` for the transitions.
            target: Target policy.
            behavior: Behaviour policy (forwarded to ``to_numpy``).
            gamma: Discount factor. For ``gamma < 1`` the mean weighted reward
                is scaled by ``1 / (1 - gamma)``. At ``gamma == 1`` (the
                default) that scaling is undefined; previously it raised
                ``ZeroDivisionError``. The unscaled mean is returned instead,
                presumably a per-step estimate (TODO confirm with callers).
            reward_estimator: Unused; kept for interface compatibility.

        Returns:
            The estimate as a Python float.
        """
        traj_data = data
        # NOTE(review): the converted data is cached after the first call, so
        # later calls reuse it even if given different data/policies.
        if not self.processed_data:
            self.data = to_numpy(data, target=target, behavior=behavior, return_terminals=True)
            self.processed_data = True
        data = self.data

        states, states_un, actions, next_states, next_states_un, rewards, policy_ratio, terminals = data

        q_params, delta_params, zeta_params, lamb = self.train_networks(
            data, traj_data, target, behavior, gamma)
        density_ratios = self.predict_zeta_fn(zeta_params, states, actions)
        # Align rewards with the (N, 1) ratios so the product is elementwise
        # rather than an (N, N) broadcast when rewards are 1-D.
        weighted_rewards = density_ratios * jnp.reshape(jnp.asarray(rewards), density_ratios.shape)
        estimate = float(jnp.mean(weighted_rewards))
        if gamma < 1:
            estimate /= (1 - gamma)
        if self.verbose:
            print(jnp.mean(density_ratios), jnp.max(density_ratios))
            print(estimate)
        return estimate