import numpy as np
from scipy.spatial import distance_matrix
from typing import Sequence

import flax.linen as nn
from flax.training import train_state
import jax
import jax.numpy as jnp
import optax

from opelab.core.baseline import Baseline
from opelab.core.baselines.kernels import sum_of_gaussian_kernels
from opelab.core.data import DataType, to_numpy
from opelab.core.mlp import MLP
from opelab.core.policy import Policy

   
class ModelBasedKernel(Baseline):
    """Off-policy value estimator that learns per-state weights ``w(s)``.

    The weights are produced by an MLP with a softplus output and trained by
    minimising a "Steinalized" kernel loss built from
    ``sum_of_gaussian_kernels`` and the scores returned by ``p_log_grad_fn``.
    The final estimate is the self-normalised weighted average of rewards,
    using ``w(s) * policy_ratio`` as the weight.
    """

    def __init__(self, p_log_grad_fn, penalty_nonzero: float=0.01,
                 learning_rate: float=0.001, iters: int=20000,
                 widths: float | np.ndarray | None=None, batch_size: int=256,
                 layers: Sequence[int]=(32, 32)):
        """Build the weight network and the jitted loss/training functions.

        Args:
            p_log_grad_fn: callable invoked as
                ``p_log_grad_fn(states_un, actions, next_states_un)``; its
                result is reshaped to ``(n, -1)`` and used as the score.
            penalty_nonzero: coefficient of the ``(mean(factors) - 1) ** 2``
                penalty added to the Stein loss during training.
            learning_rate: step size of the Adam optimizer.
            iters: number of training iterations.
            widths: kernel width(s) forwarded to ``sum_of_gaussian_kernels``.
                A scalar is wrapped in a one-element list; ``None`` selects
                the median pairwise distance of a random sample of states.
            batch_size: number of transitions sampled per training step.
            layers: layer sizes passed to ``MLP``; a final size-1 output
                layer is appended.
        """
        # Wrap any scalar (python/numpy float or int, 0-d array) in a list;
        # the original check only recognised ``float`` instances.
        if widths is not None and np.ndim(widths) == 0:
            widths = [widths]
        self.p_log_grad_fn = p_log_grad_fn
        self.penalty_nonzero = penalty_nonzero
        self.iters = iters
        self.optimizer = optax.adam(learning_rate)
        self.widths = widths
        self.batch_size = batch_size
        # list(...) accepts any Sequence (tuple included) and never aliases a
        # shared default; jax.nn.softplus is the overflow-safe form of
        # log(1 + exp(s)).
        self.model = MLP(list(layers) + [1], nn.tanh, output_activation=jax.nn.softplus)

        def steinalized_kernel(x1, x2, score1, score2, widths):
            # Base kernel value and its gradients w.r.t. each argument.
            k0 = sum_of_gaussian_kernels(x1, x2, widths)
            k1, k2 = jax.grad(sum_of_gaussian_kernels, argnums=(0, 1))(x1, x2, widths)
            k1, k2 = k1.reshape((-1,)), k2.reshape((-1,))
            # Trace of the mixed second derivative d^2 k / (dx1 dx2).
            k12 = jax.jacfwd(jax.jacrev(sum_of_gaussian_kernels, argnums=0), argnums=1)(x1, x2, widths)
            k12 = jnp.trace(k12)
            score1, score2 = score1.reshape((-1,)), score2.reshape((-1,))
            return k12 + jnp.sum(k1 * score2) + jnp.sum(k2 * score1) + k0 * jnp.sum(score1 * score2)

        def steinalized_kernel_batch(xs1, xs2, scores1, scores2, widths):
            # Nested vmap: entry [i, j] is the kernel of (xs1[i], xs2[j]).
            return jax.vmap(lambda x1, s1: jax.vmap(
                lambda x2, s2: steinalized_kernel(
                    x1, x2, s1, s2, widths))(xs2, scores2))(xs1, scores1)

        def stein_loss(weights, xs, scores, policy_ratios, widths, mean_pen):
            stein_k = steinalized_kernel_batch(xs, xs, scores, scores, widths)
            factors = policy_ratios.reshape((-1, 1)) * weights.reshape((-1, 1))
            # Quadratic form in the factors, scaled by the mean kernel magnitude.
            loss = jnp.mean(stein_k * factors * factors.T) / jnp.mean(jnp.abs(stein_k))
            # Pushes the mean factor towards 1.
            penalty = mean_pen * (jnp.mean(factors) - 1) ** 2
            return loss + penalty

        def predict_w_fn(params, states):
            return jax.vmap(self.model.apply, in_axes=(None, 0))(params, states)

        def train_fn(mlp_state, states, next_states, scores, policy_ratios, widths):

            def loss_fn(params):
                weights = predict_w_fn(params, states)
                return stein_loss(weights, next_states, scores, policy_ratios, widths, penalty_nonzero)

            loss, grads = jax.value_and_grad(loss_fn)(mlp_state.params)
            mlp_state = mlp_state.apply_gradients(grads=grads)
            return mlp_state, loss

        self.predict_w_fn = jax.jit(predict_w_fn)
        self.train_fn = jax.jit(train_fn)
        self.stein_loss = jax.jit(stein_loss)
        self.kernel_fn = jax.jit(steinalized_kernel_batch)

    def _train_density(self, data, target, behavior, gamma=1.0):
        """Train the weight network on ``data`` and return its parameters.

        ``gamma`` is accepted for interface compatibility but is not used by
        this method. The trained parameters are also stored in ``self.params``.
        """

        # prepare the data for batching
        # NOTE(review): the ``*_un`` arrays are presumably un-normalised
        # copies of the states -- confirm against ``to_numpy``.
        states, states_un, actions, next_states, next_states_un, rewards, policy_ratios = \
            to_numpy(data, target, behavior)
        n = states.shape[0]

        # compute the scores of the dynamics model
        p_scores = self.p_log_grad_fn(states_un, actions, next_states_un).reshape((n, -1))

        # get the optimal width
        if self.widths is None:
            print('\n' + 'finding optimal kernel width using median distance')
            # Random sample (with replacement) bounds the pairwise-distance cost.
            i = np.random.choice(n, size=4096)
            widths = [np.median(distance_matrix(states[i], states[i]))]
            print(f'optimal kernel width for ModelBased {widths}')
        else:
            widths = self.widths

        # initialize model for w(s) and optimizer
        params = self.model.init(jax.random.PRNGKey(0), states[0])
        mlp_state = train_state.TrainState.create(
            apply_fn=self.predict_w_fn,
            params=params,
            tx=self.optimizer
        )

        # Report roughly 25 times per run; clamp to 1 so that runs with fewer
        # than 25 iterations do not divide by zero in the modulo below.
        log_every = max(1, self.iters // 25)

        # train the weighting over the data
        print('\n' + 'training weighting for ModelBased')
        mean_loss, mean_stein, count_loss = 0.0, 0.0, 0
        for t in range(self.iters):
            i = np.random.choice(n, size=min(n, self.batch_size), replace=True)
            mlp_state, loss_val = self.train_fn(
                mlp_state, states[i], next_states[i], p_scores[i], policy_ratios[i], widths)
            # Penalty-free Stein loss of the updated weights, for logging only.
            w = self.predict_w_fn(mlp_state.params, states[i])
            stein_val = self.stein_loss(w, next_states[i], p_scores[i], policy_ratios[i], widths, 0.)
            # Running means since the last report.
            mean_loss = (mean_loss * count_loss + loss_val) / (count_loss + 1)
            mean_stein = (mean_stein * count_loss + stein_val) / (count_loss + 1)
            count_loss += 1

            # evaluation
            if t % log_every == 0:
                w = self.predict_w_fn(mlp_state.params, states).reshape(rewards.shape)
                ratios = w * policy_ratios.reshape(w.shape)
                estimate = np.sum(ratios * rewards) / np.sum(ratios)
                print(f'iter \t {t} \t loss \t {mean_loss:6f} \t estimate \t {estimate:.4f} '
                      f'\t mean \t {np.mean(ratios):.4f} \t stein {mean_stein:.6f}')
                mean_loss, mean_stein, count_loss = 0.0, 0.0, 0

        self.params = mlp_state.params
        return self.params

    # NOTE: disabled debugging helper, kept for reference only.
    # def estimate_accuracy_w(self, it, params, policy, n_rollouts=500, horizon=500):
    #     import matplotlib.pyplot as plt
    #     import gymnasium as gym
    #     env = gym.make('Pendulum-v1')
    #     states = []
    #     states_act = []
    #     for i in range(n_rollouts):
    #         state, _ = env.reset()
    #         state_save = np.array([np.arctan2(state[1], state[0]), state[2]])
    #         states.append(state_save)
    #         states_act.append(state)
    #         for _ in range(horizon):
    #             action = policy.sample(state)
    #             state, _, terminated, *_ = env.step(action)
    #             state_save = np.array([np.arctan2(state[1], state[0]), state[2]])
    #             states.append(state_save)
    #             states_act.append(state)
    #             if terminated:
    #                 break
    #     states = np.stack(states, axis=0)
    #     states_act = np.stack(states_act, axis=0)
    #     s_min = np.min(states_act, axis=0, keepdims=True)
    #     s_max = np.max(states_act, axis=0, keepdims=True)
    #     states_act = (states_act - s_min) / (s_max - s_min)
    #
    #     H, eth, ethdot = np.histogram2d(states[:, 0], states[:, 1], bins=20)
    #     H = np.clip(H, 0, np.quantile(H, 0.95))
    #     H = H / np.sum(H)
    #     plt.imshow(H, origin='upper')
    #     plt.xlabel('theta-dot')
    #     plt.ylabel('theta')
    #     plt.savefig(f'real_{it}.pdf')      
    #
    #     plt.clf()
    #     H2 = np.zeros_like(H)
    #     for i, (th1, th2) in enumerate(zip(eth[:-1], eth[1:])):
    #         for j, (thd1, thd2) in enumerate(zip(ethdot[:-1], ethdot[1:])):
    #             thm = (th1 + th2) / 2
    #             thd = (thd1 + thd2) / 2
    #             state = np.asarray([np.cos(thm), np.sin(thm), thd])
    #             state = (state - s_min) / (s_max - s_min)
    #             H2[i, j] = self.predict_w_fn(params, state).item()
    #     H2 = np.clip(H2, 0, np.quantile(H2, 0.95))
    #     plt.imshow(H2, origin='upper')
    #     plt.xlabel('theta-dot')
    #     plt.ylabel('theta')
    #     plt.savefig(f'gen_{it}.pdf')     
    #     plt.clf()

    def evaluate(self, data:DataType, target:Policy, behavior:Policy, gamma:float=1.0, reward_estimator=None) -> float:
        """Train the weights on ``data`` and return the value estimate.

        The estimate is ``sum(ratios * rewards) / sum(ratios)`` where
        ``ratios = w(s) * policy_ratios``. ``reward_estimator`` is not used.
        """
        states, *_, rewards, policy_ratios = to_numpy(data, target, behavior)
        params = self._train_density(data, target, behavior, gamma)
        w = self.predict_w_fn(params, states).reshape(rewards.shape)
        ratios = w * policy_ratios.reshape(w.shape)
        return np.sum(ratios * rewards) / np.sum(ratios)
    

class ModelBasedGAN(Baseline):
    """Off-policy value estimator trained adversarially against a critic.

    A weight network ``w(s)`` (MLP with softplus output) and a vector-valued
    critic ``f`` are trained in alternation: the critic ascends, and the
    weights descend, the weighted mean of
    ``sum(p_score * f(s')) + trace(df/ds')``. The final estimate is the
    self-normalised weighted average of rewards with weights
    ``w(s) * policy_ratio``.
    """

    def __init__(self, p_log_grad_fn, n_states: int, critic_penalty: float=0.5,
                 learning_rate_w: float=0.001, learning_rate_critic: float=0.001,
                 iters: int=100000, critic_iters: int=3, batch_size: int=256,
                 layers_w: Sequence[int]=(32, 32), layers_critic: Sequence[int]=(32, 32)):
        """Build both networks and the jitted training functions.

        Args:
            p_log_grad_fn: callable invoked as
                ``p_log_grad_fn(states_un, actions, next_states_un)``; its
                result is reshaped to ``(n, -1)`` and used as the score.
            n_states: output size of the critic network.
                NOTE(review): presumably must equal the flattened state
                dimension so the Jacobian trace is well defined -- confirm.
            critic_penalty: coefficient of the ``sum(f ** 2)`` penalty applied
                when training the critic.
            learning_rate_w: Adam step size for the weight network.
            learning_rate_critic: Adam step size for the critic.
            iters: number of weight-network updates.
            critic_iters: critic updates performed before each weight update.
            batch_size: number of transitions sampled per update.
            layers_w: layer sizes passed to the weight ``MLP``; a final
                size-1 output layer is appended.
            layers_critic: layer sizes passed to the critic ``MLP``; a final
                layer of size ``n_states`` is appended.
        """
        self.p_log_grad_fn = p_log_grad_fn
        self.iters = iters
        self.critic_iters = critic_iters
        self.optimizer_w = optax.adam(learning_rate_w)
        self.optimizer_critic = optax.adam(learning_rate_critic)
        self.batch_size = batch_size

        # list(...) accepts any Sequence (tuple included) and never aliases a
        # shared default; jax.nn.softplus is the overflow-safe form of
        # log(1 + exp(s)).
        self.model_w = MLP(list(layers_w) + [1], nn.tanh, output_activation=jax.nn.softplus)
        self.model_critic = MLP(list(layers_critic) + [n_states], lambda s: s * jax.nn.sigmoid(s), output_activation=lambda s: s * jax.nn.sigmoid(s))

        # prediction of weights
        def predict_w_fn(params, states):
            return jax.vmap(self.model_w.apply, in_axes=(None, 0))(params, states)

        # LSDE loss function
        def lsde_loss_fn(params, state, p_score, penalty):
            state = state.reshape((-1,))
            p_score = p_score.reshape((-1,))
            f_state = self.model_critic.apply(params, state).reshape((-1,))
            # Jacobian of the critic output w.r.t. its input state.
            f_jac_state = jax.jacfwd(self.model_critic.apply, argnums=1)(params, state)
            lsde = jnp.sum(p_score * f_state) + jnp.trace(f_jac_state)
            # Quadratic penalty on the critic output magnitude.
            critic_loss = -penalty * jnp.sum(f_state ** 2)
            total_loss = lsde + critic_loss
            return total_loss

        # loss to train critic
        def critic_loss_fn(params_w, params_f, states, next_states, policy_ratios, p_scores):
            weights = predict_w_fn(params_w, states).reshape((-1,))
            losses = jax.vmap(lsde_loss_fn, in_axes=(None, 0, 0, None))(
                params_f, next_states, p_scores, critic_penalty).reshape((-1,))
            # Negated: minimising this maximises the penalised objective.
            loss_val = -jnp.mean(weights * policy_ratios.reshape((-1,)) * losses)
            return loss_val

        # loss to train weights
        def weight_loss_fn(params_w, params_f, states, next_states, policy_ratios, p_scores):
            weights = predict_w_fn(params_w, states).reshape((-1,))
            # The critic penalty is switched off (0.0) for the weight update.
            losses = jax.vmap(lsde_loss_fn, in_axes=(None, 0, 0, None))(
                params_f, next_states, p_scores, 0.0).reshape((-1,))
            loss_val = jnp.mean(weights * policy_ratios.reshape((-1,)) * losses)
            return loss_val

        # training function for critic
        def train_critic_fn(opt_state, params_w, params_f, states, next_states, policy_ratios, p_scores):
            loss, grads = jax.value_and_grad(critic_loss_fn, argnums=1)(
                params_w, params_f, states, next_states, policy_ratios, p_scores)
            updates, opt_state = self.optimizer_critic.update(grads, opt_state)
            params_f = optax.apply_updates(params_f, updates)
            return params_f, opt_state, loss

        # training function for weights
        def train_weight_fn(opt_state, params_w, params_f, states, next_states, policy_ratios, p_scores):
            loss, grads = jax.value_and_grad(weight_loss_fn, argnums=0)(
                params_w, params_f, states, next_states, policy_ratios, p_scores)
            updates, opt_state = self.optimizer_w.update(grads, opt_state)
            params_w = optax.apply_updates(params_w, updates)
            return params_w, opt_state, loss

        self.predict_w_fn = jax.jit(predict_w_fn)
        self.train_w_fn = jax.jit(train_weight_fn)
        self.train_f_fn = jax.jit(train_critic_fn)

    def _train_density(self, data, target, behavior, gamma=1.0):
        """Alternate critic and weight updates; return the weight parameters.

        ``gamma`` is accepted for interface compatibility but is not used by
        this method.
        """

        # prepare the data for batching
        # NOTE(review): the meaning of the positional ``True`` flag is defined
        # by ``to_numpy`` and is not visible here -- confirm.
        states, states_un, actions, next_states, next_states_un, rewards, policy_ratios = \
            to_numpy(data, target, behavior, True)
        n = states.shape[0]

        # initial score of dynamics model
        p_scores = self.p_log_grad_fn(states_un, actions, next_states_un).reshape((n, -1))

        # initialize model for w(s) and optimizer
        key = jax.random.PRNGKey(0)
        key, subkey1, subkey2 = jax.random.split(key, num=3)
        params_w = self.model_w.init(subkey1, states[0])
        params_f = self.model_critic.init(subkey2, states[0])
        opt_state_w = self.optimizer_w.init(params_w)
        opt_state_f = self.optimizer_critic.init(params_f)

        # Report roughly 25 times per run; clamp to 1 so that runs with fewer
        # than 25 iterations do not divide by zero in the modulo below.
        log_every = max(1, self.iters // 25)

        # train the weighting over the data
        print('\n' + 'training weighting for ModelBased')
        mean_loss, count_loss = 0.0, 0
        for t in range(self.iters):
            # Several critic steps, each on a fresh batch, per weight step.
            for _ in range(self.critic_iters):
                i = np.random.choice(n, size=min(n, self.batch_size), replace=True)
                params_f, opt_state_f, _ = self.train_f_fn(
                    opt_state_f, params_w, params_f, states[i], next_states[i],
                    policy_ratios[i], p_scores[i])
            i = np.random.choice(n, size=min(n, self.batch_size), replace=True)
            params_w, opt_state_w, loss_val = self.train_w_fn(
                opt_state_w, params_w, params_f, states[i], next_states[i],
                policy_ratios[i], p_scores[i])
            # Running mean of the weight loss since the last report.
            mean_loss = (mean_loss * count_loss + loss_val) / (count_loss + 1)
            count_loss += 1

            # evaluation
            if t % log_every == 0:
                w = self.predict_w_fn(params_w, states).reshape(rewards.shape)
                ratios = w * policy_ratios.reshape(w.shape)
                estimate = np.sum(ratios * rewards) / np.sum(ratios)
                print(f'iter \t {t} \t loss \t {mean_loss:.6f} \t estimate \t {estimate:.4f} '
                      f'\t mean \t {np.mean(ratios):.4f}')
                mean_loss, count_loss = 0.0, 0

        return params_w

    def evaluate(self, data:DataType, target:Policy, behavior:Policy, gamma:float=1.0, reward_estimator=None) -> float:
        """Train the weights on ``data`` and return the value estimate.

        The estimate is ``sum(ratios * rewards) / sum(ratios)`` where
        ``ratios = w(s) * policy_ratios``. ``reward_estimator`` is not used.
        """
        states, *_, rewards, policy_ratios = to_numpy(data, target, behavior, True)
        params = self._train_density(data, target, behavior, gamma)
        w = self.predict_w_fn(params, states).reshape(rewards.shape)
        ratios = w * policy_ratios.reshape(w.shape)
        return np.sum(ratios * rewards) / np.sum(ratios)
    
