import numpy as np
from scipy.spatial import distance_matrix

from cvxopt import matrix
from cvxopt.solvers import qp
import jax
import jax.numpy as jnp

from opelab.core.baseline import Baseline
from opelab.core.baselines.kernels import sum_of_gaussian_kernels
from opelab.core.data import DataType, to_numpy
from opelab.core.policy import Policy

    
class ModelBasedQP(Baseline):
    """Off-policy evaluation baseline that fits per-sample weights with a QP.

    A kernel matrix is built from a Gaussian-sum base kernel combined with the
    scores returned by ``p_log_grad_fn`` (see ``steinalized_kernel`` below),
    and a quadratic program over non-negative sample weights is solved with
    ``cvxopt``.  The resulting weights, multiplied by the policy ratios, are
    used as self-normalised importance weights on the rewards.
    """

    def __init__(self, p_log_grad_fn, widths:float | np.ndarray | None=None, error_mean: float=0.2):
        """Store the configuration and build the jitted kernel function.

        Args:
            p_log_grad_fn: Callable invoked as
                ``p_log_grad_fn(states_un, actions, next_states_un)``; its
                output is reshaped to ``(n, -1)`` and used as per-sample
                scores.  Presumably the gradient of the log-density of a
                dynamics model -- confirm against the caller.
            widths: Kernel width(s) forwarded to ``sum_of_gaussian_kernels``.
                A single scalar is wrapped into a one-element list.  ``None``
                selects the width with a median-distance heuristic at
                training time.
            error_mean: Relative tolerance ``e`` of the QP constraints, which
                keep the weight sums within ``[(1 - e) * n, (1 + e) * n]``.
        """
        # Accept every real scalar, not only the builtin ``float``: integers
        # are promoted to float, and NumPy floating scalars (for example
        # ``np.float32``, which is not a ``float`` subclass) are wrapped too.
        if isinstance(widths, (int, np.integer)):
            widths = float(widths)
        if isinstance(widths, (float, np.floating)):
            widths = [widths]
        self.p_log_grad_fn = p_log_grad_fn
        self.widths = widths
        self.error_mean = error_mean

        def steinalized_kernel(x1, x2, score1, score2, widths):
            """Evaluate the score-augmented kernel for a single pair of points.

            With ``k = sum_of_gaussian_kernels`` this returns
            ``trace(d2k/dx1 dx2) + <dk/dx1, score2> + <dk/dx2, score1>
            + k * <score1, score2>``.
            """
            k0 = sum_of_gaussian_kernels(x1, x2, widths)
            # First-order derivatives with respect to each argument, flattened.
            k1, k2 = jax.grad(sum_of_gaussian_kernels, argnums=(0, 1))(x1, x2, widths)
            k1, k2 = k1.reshape((-1,)), k2.reshape((-1,))
            # Mixed second derivative; its trace sums d2k / dx1_i dx2_i.
            # NOTE(review): assumes x1 and x2 are 1-D vectors so that this
            # Jacobian is a square 2-D matrix -- confirm.
            k12 = jax.jacfwd(jax.jacrev(sum_of_gaussian_kernels, argnums=0), argnums=1)(x1, x2, widths)
            k12 = jnp.trace(k12)
            score1, score2 = score1.reshape((-1,)), score2.reshape((-1,))
            return k12 + jnp.sum(k1 * score2) + jnp.sum(k2 * score1) + k0 * jnp.sum(score1 * score2)

        def steinalized_kernel_batch(xs1, xs2, scores1, scores2, widths):
            """Evaluate ``steinalized_kernel`` on all pairs of rows of ``xs1`` and ``xs2``."""
            # Outer vmap runs over rows of xs1, inner vmap over rows of xs2,
            # so entry [i, j] pairs xs1[i] with xs2[j].
            return jax.vmap(lambda x1, s1: jax.vmap(
                lambda x2, s2: steinalized_kernel(
                    x1, x2, s1, s2, widths))(xs2, scores2))(xs1, scores1)

        self.kernel_fn = jax.jit(steinalized_kernel_batch)

    def solve_qp(self, kernel_matrix, policy_ratios):
        """Solve the quadratic program for the per-sample weights.

        Minimises ``0.5 * x^T P x`` with
        ``P[i, j] = K[i, j] / mean(|K|) * r[i] * r[j]`` subject to

        * ``x[i] >= 0`` for every ``i``,
        * ``(1 - e) * n <= sum_i x[i] <= (1 + e) * n``,
        * ``(1 - e) * n <= sum_i r[i] * x[i] <= (1 + e) * n``,

        where ``K`` is ``kernel_matrix``, ``r`` is ``policy_ratios``,
        ``e`` is ``self.error_mean`` and ``n`` is the number of samples.

        Args:
            kernel_matrix: Square ``(n, n)`` kernel matrix.
            policy_ratios: ``n`` policy ratios (any shape with ``n`` entries).

        Returns:
            1-D ``numpy`` array of length ``n`` with the QP solution.
        """
        policy_ratios = policy_ratios.reshape((-1, 1))
        n = kernel_matrix.shape[0]

        # Rescale the kernel so the objective has entries of order one, then
        # fold the ratios into the quadratic form.
        mat = kernel_matrix / np.mean(np.abs(kernel_matrix))
        mat = mat * policy_ratios * policy_ratios.T
        mat = np.asarray(mat, dtype=np.double)
        P = matrix(mat)
        q = matrix(0.0, (n, 1))

        # Inequality constraints G x <= h, assembled in NumPy and converted
        # once instead of writing the cvxopt matrix element by element.
        ratios = np.asarray(policy_ratios, dtype=np.double).reshape((-1,))
        diag = np.arange(n)
        G_np = np.zeros((n + 4, n), dtype=np.double)
        G_np[diag, diag] = -1.0      # -x[i] <= 0
        G_np[n, :] = -1.0            # -sum(x) <= -(1 - e) * n
        G_np[n + 1, :] = 1.0         #  sum(x) <=  (1 + e) * n
        G_np[n + 2, :] = -ratios     # -sum(r * x) <= -(1 - e) * n
        G_np[n + 3, :] = ratios      #  sum(r * x) <=  (1 + e) * n
        G = matrix(G_np)

        lower = (1.0 - self.error_mean) * n
        upper = (1.0 + self.error_mean) * n
        h_np = np.zeros((n + 4, 1), dtype=np.double)
        h_np[n:, 0] = [-lower, upper, -lower, upper]
        h = matrix(h_np)

        sol = qp(P, q, G, h)['x']
        sol = np.array(sol).reshape((-1,))
        return sol

    def _train_density(self, data, target, behavior, gamma=1.0):
        """Compute the QP weights for ``data``.

        Args:
            data: Dataset forwarded to ``to_numpy``.
            target: Target policy forwarded to ``to_numpy``.
            behavior: Behavior policy forwarded to ``to_numpy``.
            gamma: Accepted for interface compatibility; not used here.

        Returns:
            1-D ``numpy`` array with one weight per sample.
        """
        # The ``*_un`` arrays are the ones handed to ``p_log_grad_fn``
        # (presumably the un-normalised versions -- confirm in ``to_numpy``).
        states, states_un, actions, next_states, next_states_un, _, policy_ratios = \
            to_numpy(data, target, behavior)
        n = states.shape[0]

        # compute the scores of the dynamics model
        p_scores = self.p_log_grad_fn(states_un, actions, next_states_un).reshape((n, -1))

        # get the optimal width
        if self.widths is None:
            print('\n' + 'finding optimal kernel width using median distance')
            # Sample at most 4096 distinct points: drawing with replacement
            # (or more points than exist) duplicates samples, and the
            # resulting zero distances pull the median towards zero.
            m = min(n, 4096)
            i = np.random.choice(n, size=m, replace=False)
            dists = distance_matrix(states[i], states[i])
            if m > 1:
                # Ignore the zero self-distances on the diagonal.
                dists = dists[np.triu_indices(m, k=1)]
            # NOTE(review): the width is estimated on ``states`` while the
            # kernel below is evaluated on ``next_states`` -- confirm this is
            # intended.
            widths = [np.median(dists)]
            print(f'optimal kernel width for ModelBased {widths}')
        else:
            widths = self.widths

        # solve the QP to compute the optimal weights
        kernel_mat = self.kernel_fn(next_states, next_states, p_scores, p_scores, widths)
        w = self.solve_qp(kernel_mat, policy_ratios)
        return w

    def evaluate(self, data:DataType, target:Policy, behavior:Policy, gamma:float=1.0, reward_estimator=None) -> float:
        """Estimate the target policy's value as a weighted average of rewards.

        Args:
            data: Logged dataset.
            target: Policy being evaluated.
            behavior: Policy that generated ``data``.
            gamma: Forwarded to ``_train_density`` (which does not use it).
            reward_estimator: Unused by this baseline.

        Returns:
            ``sum(w * r * rewards) / sum(w * r)`` where ``w`` are the QP
            weights and ``r`` the policy ratios.
        """
        *_, rewards, policy_ratios = to_numpy(data, target, behavior)
        w = self._train_density(data, target, behavior, gamma).reshape(rewards.shape)
        ratios = w * policy_ratios.reshape(w.shape)
        # Self-normalised estimate: divide by the total weight.
        return np.sum(ratios * rewards) / np.sum(ratios)
    
