import numpy as np
from scipy.spatial import distance_matrix
from typing import Sequence

import flax.linen as nn
from flax.training import train_state
import jax
import jax.numpy as jnp
import optax

from opelab.core.baseline import Baseline
from opelab.core.baselines.kernels import sum_of_gaussian_kernels
from opelab.core.data import DataType, to_numpy
from opelab.core.mlp import MLP
from opelab.core.policy import Policy


class IPSKernelJAX(Baseline):
    """Self-normalized importance-sampling estimator with a learned state-density ratio.

    A positive network ``w(s)`` is trained by minimizing a kernel quadratic form
    of the residual ``w(s) * policy_ratio - w(s')`` (kernel evaluated on next
    states). The value estimate is the self-normalized average of rewards
    weighted by ``w(s) * policy_ratio``, where ``policy_ratio`` is whatever
    ``to_numpy`` returns (presumably target/behavior action-probability ratios
    — confirm against ``opelab.core.data``).
    """

    def __init__(self, layers: Sequence[int],
                 learning_rate: float=0.001, widths: float | np.ndarray | None=None,
                 batch_size: int=256, iters: int=20000, reg: float=0.0) -> None:
        """Build the ``w(s)`` network, optimizer and jitted train/predict functions.

        Args:
            layers: hidden-layer sizes of the ``w(s)`` MLP; a width-1 output layer
                is appended. Any sequence (list or tuple) is accepted.
            learning_rate: Adam step size.
            widths: Gaussian kernel bandwidth(s). A scalar is wrapped in a list;
                ``None`` picks a bandwidth from the data (median pairwise
                distance) at training time.
            batch_size: transitions sampled per gradient step (capped at the
                dataset size).
            iters: number of gradient steps.
            reg: coefficient of the L2 penalty on the network parameters.
        """
        # np.isscalar also accepts ints and numpy scalars (the old float-only
        # check left e.g. ``widths=1`` unwrapped).
        if np.isscalar(widths):
            widths = [float(widths)]
        self.widths = widths
        self.reg = reg
        # softplus keeps w(s) strictly positive; jax.nn.softplus is the
        # overflow-safe form of log(1 + exp(s)) and has the same gradient.
        self.model = MLP(list(layers) + [1], nn.tanh, output_activation=jax.nn.softplus)
        self.iters = iters
        self.optimizer = optax.adam(learning_rate)
        self.batch_size = batch_size

        def predict_w_fn(params, states):
            # Apply the network to each state independently (vmap over axis 0).
            return jax.vmap(self.model.apply, in_axes=(None, 0))(params, states)

        def delta_w_fn(w, next_w, policy_ratio):
            # Residual w(s) * ratio - w(s'), each w rescaled to batch mean 1.
            w = w / jnp.mean(w)
            next_w = next_w / jnp.mean(next_w)
            policy_ratio = policy_ratio.reshape(w.shape)
            return w * policy_ratio - next_w

        def train_fn(mlp_state, batch1, batch2, widths):
            """Take one optimizer step on the kernel loss; return (new_state, loss)."""

            def loss_fn(params):
                states1, next_states1, policy_ratio1 = batch1
                states2, next_states2, policy_ratio2 = batch2
                kernel_mat = sum_of_gaussian_kernels(next_states1, next_states2, widths)
                w1 = mlp_state.apply_fn(params, states1)
                w2 = mlp_state.apply_fn(params, states2)
                next_w1 = mlp_state.apply_fn(params, next_states1)
                next_w2 = mlp_state.apply_fn(params, next_states2)
                # Column x row vectors so the product broadcasts against kernel_mat.
                delta_w1 = delta_w_fn(w1, next_w1, policy_ratio1).reshape((-1, 1))
                delta_w2 = delta_w_fn(w2, next_w2, policy_ratio2).reshape((1, -1))
                loss_val = jnp.sum(delta_w1 * delta_w2 * kernel_mat) / states1.shape[0]
                # L2 penalty, averaged over the parameter arrays.
                leaves = jax.tree_util.tree_leaves(params)
                for param in leaves:
                    loss_val += reg * jnp.sum(jnp.square(param)) / float(len(leaves))
                return loss_val

            loss, grads = jax.value_and_grad(loss_fn)(mlp_state.params)
            mlp_state = mlp_state.apply_gradients(grads=grads)
            return mlp_state, loss

        self.predict_w_fn = jax.jit(predict_w_fn)
        self.train_fn = jax.jit(train_fn)

    @staticmethod
    def _median_distance(states, max_samples=4096):
        """Return the median pairwise Euclidean distance over a random subsample of ``states``.

        Sampling is without replacement and capped at the dataset size, so small
        datasets don't produce duplicate rows (spurious zero distances) or an
        oversized distance matrix. The zero self-distances on the diagonal are
        excluded whenever at least two points are available.
        """
        m = min(states.shape[0], max_samples)
        idx = np.random.choice(states.shape[0], size=m, replace=False)
        sub = states[idx].reshape(m, -1)
        dists = distance_matrix(sub, sub)
        if m < 2:
            # Only the self-distance exists; fall back to it.
            return float(np.median(dists))
        return float(np.median(dists[np.triu_indices(m, k=1)]))

    def _weighted_reward_estimate(self, params, states, rewards, policy_ratios):
        """Self-normalized reward average with weights ``w(s) * policy_ratio``."""
        w = self.predict_w_fn(params, states).reshape(rewards.shape)
        ratios = w * policy_ratios.reshape(w.shape)
        return np.sum(ratios * rewards) / np.sum(ratios)

    def _train_density(self, data, target, behavior, gamma=1.0):
        """Fit ``w(s)`` on ``data`` and return the trained network parameters.

        ``gamma`` is accepted for interface compatibility; it is not used here.
        """
        states, _, _, next_states, _, rewards, policy_ratios = to_numpy(data, target, behavior)
        n = states.shape[0]

        # choose the kernel bandwidth from the data when none was given
        if self.widths is None:
            widths = [self._median_distance(states)]
            print('\n' + f'optimal kernel width for IPS {widths}')
        else:
            widths = self.widths

        # initialize model for w(s) and optimizer
        params = self.model.init(jax.random.key(0), states[0])
        mlp_state = train_state.TrainState.create(
            apply_fn=self.model.apply,
            params=params,
            tx=self.optimizer
        )

        # train w(s)
        print('training w(s) for IPS')
        # max(1, ...) avoids a modulo-by-zero when iters < 25
        log_every = max(1, self.iters // 25)
        batch_size = min(n, self.batch_size)
        mean_loss, count_loss = 0.0, 0
        for t in range(self.iters):
            i = np.random.choice(n, size=batch_size, replace=True)
            batch = (states[i], next_states[i], policy_ratios[i])
            # The same minibatch is used on both sides of the kernel double sum.
            mlp_state, loss_val = self.train_fn(mlp_state, batch, batch, widths)
            mean_loss = (mean_loss * count_loss + loss_val) / (count_loss + 1)
            count_loss += 1

            if t % log_every == 0:
                estimate = self._weighted_reward_estimate(
                    mlp_state.params, states, rewards, policy_ratios)
                print(f'iter {t} loss {mean_loss:.6f} estimate {estimate:.4f}')
                mean_loss, count_loss = 0.0, 0

        return mlp_state.params

    def evaluate(self, data:DataType, target:Policy, behavior:Policy, gamma:float=1.0, reward_estimator=None) -> float:
        """Train ``w(s)`` and return the self-normalized weighted reward estimate.

        ``reward_estimator`` is accepted for interface compatibility and is unused.
        """
        params = self._train_density(data, target, behavior, gamma)
        states, *_, rewards, policy_ratios = to_numpy(data, target, behavior)
        return self._weighted_reward_estimate(params, states, rewards, policy_ratios)


class IPSGANJAX(Baseline):
    """Adversarial variant of the learned-density-ratio IPS estimator.

    A positive network ``w(s)`` and a critic ``f(s')`` play a minimax game on
    ``mean(delta_w * f_hat(s'))**2``, where ``delta_w = w(s) * policy_ratio - w(s')``
    and ``f_hat`` is ``f`` rescaled to unit RMS over the batch. The critic
    maximizes this quantity and ``w`` minimizes it. The value estimate is the
    self-normalized average of rewards weighted by ``w(s) * policy_ratio``.
    """

    def __init__(self, layers_w: Sequence[int], layers_f: Sequence[int],
                 learning_rate_w: float=0.001, learning_rate_f: float=0.001,
                 batch_size: int=256, iters: int=20000, f_iters: int=5) -> None:
        """Build both networks, their Adam optimizers and jitted step functions.

        Args:
            layers_w: hidden-layer sizes of ``w(s)``; a width-1 output is appended.
            layers_f: hidden-layer sizes of the critic ``f(s')``; a width-1 output
                is appended.
            learning_rate_w: Adam step size for ``w``.
            learning_rate_f: Adam step size for ``f``.
            batch_size: transitions sampled per outer iteration (capped at the
                dataset size).
            iters: number of outer iterations (one ``w`` step each).
            f_iters: critic steps taken per outer iteration, on the same batch.
        """
        # softplus keeps w(s) strictly positive; jax.nn.softplus is the
        # overflow-safe form of log(1 + exp(s)) and has the same gradient.
        self.model = MLP(list(layers_w) + [1], nn.tanh, output_activation=jax.nn.softplus)
        self.f = MLP(list(layers_f) + [1], nn.tanh)
        self.iters = iters
        self.f_iters = f_iters
        self.optimizer_w = optax.adam(learning_rate_w)
        self.optimizer_f = optax.adam(learning_rate_f)
        self.batch_size = batch_size

        def predict_w_fn(params, states):
            # Apply the network to each state independently (vmap over axis 0).
            return jax.vmap(self.model.apply, in_axes=(None, 0))(params, states)

        def delta_w_fn(w, next_w, policy_ratio):
            # Residual w(s) * ratio - w(s'), each w rescaled to batch mean 1.
            w = w / jnp.mean(w)
            next_w = next_w / jnp.mean(next_w)
            policy_ratio = policy_ratio.reshape(w.shape)
            return w * policy_ratio - next_w

        def minimax_residual(w_apply, f_apply, w_params, f_params, batch):
            # Shared game objective before squaring: mean(delta_w * f_hat(s')).
            states, next_states, policy_ratio = batch
            w = w_apply(w_params, states)
            next_w = w_apply(w_params, next_states)
            delta_w = delta_w_fn(w, next_w, policy_ratio)
            next_f = f_apply(f_params, next_states)
            # Unit-RMS critic so it cannot win by simply scaling up; the epsilon
            # guards against division by zero.
            norm_f = jnp.sqrt(jnp.mean(next_f ** 2)) + 1e-15
            next_f = next_f / norm_f
            return jnp.mean(delta_w * next_f)

        def train_w_fn(w_state, f_state, batch):
            """Take one step of w minimizing the squared residual; return (new_w_state, loss)."""

            def loss_fn(w_params):
                return minimax_residual(w_state.apply_fn, f_state.apply_fn,
                                        w_params, f_state.params, batch) ** 2

            loss, grads = jax.value_and_grad(loss_fn)(w_state.params)
            w_state = w_state.apply_gradients(grads=grads)
            return w_state, loss

        def train_f_fn(w_state, f_state, batch):
            """Take one step of f maximizing the squared residual; return (new_f_state, loss)."""

            def loss_fn(f_params):
                return -minimax_residual(w_state.apply_fn, f_state.apply_fn,
                                         w_state.params, f_params, batch) ** 2

            loss, grads = jax.value_and_grad(loss_fn)(f_state.params)
            f_state = f_state.apply_gradients(grads=grads)
            return f_state, loss

        self.predict_w_fn = jax.jit(predict_w_fn)
        self.train_w_fn = jax.jit(train_w_fn)
        self.train_f_fn = jax.jit(train_f_fn)

    def _weighted_reward_estimate(self, params, states, rewards, policy_ratios):
        """Self-normalized reward average with weights ``w(s) * policy_ratio``."""
        w = self.predict_w_fn(params, states).reshape(rewards.shape)
        ratios = w * policy_ratios.reshape(w.shape)
        return np.sum(ratios * rewards) / np.sum(ratios)

    def _train_density(self, data, target, behavior, gamma=1.0):
        """Jointly train ``w(s)`` and ``f(s')`` on ``data``; return the ``w`` parameters.

        ``gamma`` is accepted for interface compatibility; it is not used here.
        """
        states, _, _, next_states, _, rewards, policy_ratios = to_numpy(data, target, behavior)
        n = states.shape[0]

        # Use independent PRNG keys: reusing one key for both networks would
        # correlate their initializations.
        key_w, key_f = jax.random.split(jax.random.key(0))

        # initialize model for w(s) and optimizer
        w_params = self.model.init(key_w, states[0])
        w_state = train_state.TrainState.create(
            apply_fn=self.model.apply,
            params=w_params,
            tx=self.optimizer_w
        )

        # initialize model for f(s) and optimizer
        f_params = self.f.init(key_f, states[0])
        f_state = train_state.TrainState.create(
            apply_fn=self.f.apply,
            params=f_params,
            tx=self.optimizer_f
        )

        # train w(s) and f(s) jointly
        print('training w(s) and f(s) for IPS')
        # max(1, ...) avoids a modulo-by-zero when iters < 20
        log_every = max(1, self.iters // 20)
        batch_size = min(n, self.batch_size)
        mean_loss, count_loss = 0.0, 0
        for t in range(self.iters):
            i = np.random.choice(n, size=batch_size, replace=True)
            batch = (states[i], next_states[i], policy_ratios[i])
            # Several critic (ascent) steps, then one w (descent) step.
            for _ in range(self.f_iters):
                f_state, _ = self.train_f_fn(w_state, f_state, batch)
            w_state, loss_val = self.train_w_fn(w_state, f_state, batch)
            mean_loss = (mean_loss * count_loss + loss_val) / (count_loss + 1)
            count_loss += 1

            if t % log_every == 0:
                estimate = self._weighted_reward_estimate(
                    w_state.params, states, rewards, policy_ratios)
                print(f'iter {t} loss {mean_loss:.6f} estimate {estimate:.4f}')
                mean_loss, count_loss = 0.0, 0

        return w_state.params

    def evaluate(self, data:DataType, target:Policy, behavior:Policy, gamma:float=1.0, reward_estimator=None) -> float:
        """Train ``w(s)`` and return the self-normalized weighted reward estimate.

        ``reward_estimator`` is accepted for interface compatibility and is unused.
        """
        params = self._train_density(data, target, behavior, gamma)
        states, *_, rewards, policy_ratios = to_numpy(data, target, behavior)
        return self._weighted_reward_estimate(params, states, rewards, policy_ratios)

