import jax
import jax.numpy as jnp
from approxml.utils import grad_log_normal, gen_simulation_samples
from functools import partial
import numpy as np
from sklearn.linear_model import LogisticRegression
from approxml.simulators import mvt_norm_simulator


def _split_fold(arr, start, size):
    """Return ``(train, val)``: ``val`` is ``arr[start:start + size]``, ``train`` is the rest (axis 0)."""
    stop = start + size
    val = arr[start:stop]
    train = jnp.concatenate((arr[:start], arr[stop:]), axis=0)
    return train, val


def _classifier_density_ratio(obs, sims_flat):
    """Classifier odds P(obs | x) / P(sim | x) evaluated at each row of ``sims_flat``.

    A class-balanced logistic regression is trained to separate ``obs`` (label 0)
    from ``sims_flat`` (label 1); with balanced class weights the odds serve as an
    estimate of the density ratio between the two samples.
    """
    labels = np.concatenate([
        np.zeros(obs.shape[0], dtype=int),
        np.ones(sims_flat.shape[0], dtype=int),
    ])
    features = np.vstack([obs, sims_flat])
    clf = LogisticRegression(class_weight='balanced')
    clf.fit(features, labels)
    probs = clf.predict_proba(sims_flat)
    return probs[:, 0] / probs[:, 1]


def cross_val_sm(prop_sigma_values,
                 theta_t,
                 simulator_fn,
                 n_prop,
                 n_sim_dst,
                 obs,
                 key,
                 n_param_dim,
                 n_data_dim,
                 num_folds=5,
                 lamb=0.5,
                 weight_by_ratio=False):
    """Pick the proposal scale whose linear score-matching fit validates best.

    For every candidate ``prop_sigma`` a fresh set of ``n_prop`` proposal
    parameters (covariance ``prop_sigma * I``) with ``n_sim_dst`` simulations each
    is generated. The set is split into ``num_folds`` contiguous folds of
    ``n_prop // num_folds`` rows. Any ``n_prop % num_folds`` trailing rows always
    stay in the training set. On each fold ``fit_linear_sm`` is trained on the
    remaining rows, and ``score_matching_loss`` is evaluated on the held-out
    fold. Fold losses are averaged per candidate.

    Args:
        prop_sigma_values: Sequence of candidate scalar proposal scales.
        theta_t: Reference parameter vector. It is passed to the simulation
            generator, the linear fit and the loss.
        simulator_fn: Simulator forwarded to ``gen_simulation_samples``.
        n_prop: Number of proposal parameters generated per candidate.
        n_sim_dst: Number of simulations per proposal parameter.
        obs: Observed data of shape ``(n_obs, n_data_dim)``. It is only used
            when ``weight_by_ratio`` is True.
        key: JAX PRNG key.
        n_param_dim: Parameter dimension (size of the proposal covariance).
        n_data_dim: Dimension of a single simulated data point.
        num_folds: Number of cross-validation folds.
        lamb: Ridge regulariser forwarded to ``fit_linear_sm``.
        weight_by_ratio: If True, weight each held-out simulation's loss term
            by a classifier-estimated density ratio between ``obs`` and the
            held-out simulations (passed as ``WF`` to ``score_matching_loss``).
            Defaults to False (unweighted loss).

    Returns:
        ``(best_prop_sigma, mean_val_losses)``. ``mean_val_losses`` is a JAX
        array holding one mean validation loss per candidate.

    Raises:
        ValueError: If ``num_folds < 1`` or ``n_prop < num_folds``, since the
            folds would otherwise be empty.
    """
    if num_folds < 1 or n_prop < num_folds:
        raise ValueError(
            f"need 1 <= num_folds <= n_prop, got num_folds={num_folds}, n_prop={n_prop}")
    fold_size = n_prop // num_folds

    mean_val_losses = []
    for prop_sigma in prop_sigma_values:
        prop_cov = prop_sigma * jnp.eye(n_param_dim)

        key, subkey = jax.random.split(key)
        gen_sim_fn = partial(
            gen_simulation_samples,
            simulator_fn=simulator_fn,
            prop_sim_fn=partial(mvt_norm_simulator, cov=prop_cov),
            n_prop=n_prop,
            n_sim_dst=n_sim_dst,
        )
        thetas_q, sims_q, _ = gen_sim_fn(subkey, theta_t)

        val_losses = []
        for fold in range(num_folds):
            start = fold * fold_size
            thetas_q_train, thetas_q_val = _split_fold(thetas_q, start, fold_size)
            sims_q_train, sims_q_val = _split_fold(sims_q, start, fold_size)

            # fit_linear_sm ignores the key when data is supplied. The split is
            # kept so the RNG stream (and thus later candidates' simulations)
            # stays reproducible.
            key, subkey = jax.random.split(key)
            W, _, _, _ = fit_linear_sm(
                subkey,
                theta_t,
                gen_sim_fn=None,
                grad_log_prop_fn=partial(grad_log_normal, cov=prop_cov),
                n_prop=thetas_q_train.shape[0],
                n_sim_dst=n_sim_dst,
                lamb=lamb,
                thetas_q=thetas_q_train,
                sims_q=sims_q_train,
            )

            # Append a constant 1 feature so W's last row acts as an intercept.
            sims_q_val_aug = jnp.concatenate(
                [sims_q_val, jnp.ones_like(sims_q_val[..., :1])], axis=-1)
            est_grad = jnp.einsum('mk,jik->jim', W.T, sims_q_val_aug)

            weights = None
            if weight_by_ratio:
                sims_flat = np.asarray(
                    sims_q_val.reshape(fold_size * n_sim_dst, n_data_dim))
                weights = jnp.asarray(
                    _classifier_density_ratio(obs, sims_flat)
                ).reshape(fold_size, n_sim_dst)

            val_losses.append(score_matching_loss(
                est_grad, thetas_q_val, theta_t, fold_size, n_sim_dst, prop_cov,
                WF=weights))

        mean_val_losses.append(jnp.array(val_losses).mean())

    mean_val_losses = jnp.array(mean_val_losses)
    best_index = int(jnp.argmin(mean_val_losses))
    return prop_sigma_values[best_index], mean_val_losses

def score_matching_loss(
    pred_q,
    thetas_q,
    theta_t,
    n_prop,
    n_sim_dst,
    prop_theta_cov,
    WF=None,
):
    """Monte Carlo score-matching objective for predicted gradients.

    For every proposal parameter ``i`` and simulation ``j`` the summand is
    ``sum_d(pred_q[i, j, d] ** 2 + 2 * pred_q[i, j, d] * g[i, d])``. Here ``g[i]``
    is ``grad_log_normal(thetas_q[i], theta_t, prop_theta_cov)``, presumably the
    gradient of the Gaussian proposal log-density (see ``approxml.utils``).
    The summands are optionally multiplied by ``WF`` and then averaged over
    simulations and proposals.

    Args:
        pred_q: Predicted gradients, compatible with shape
            ``(n_prop, n_sim_dst, theta_dim)``.
        thetas_q: Proposal parameters, one row per proposal (``n_prop`` rows).
        theta_t: Reference parameter vector (length ``theta_dim``).
        n_prop: Number of proposal parameters.
        n_sim_dst: Number of simulations per proposal parameter.
        prop_theta_cov: Proposal covariance passed to ``grad_log_normal``.
        WF: Optional per-(proposal, simulation) weights, broadcastable to
            ``(n_prop, n_sim_dst)``.

    Returns:
        Scalar mean loss.
    """
    n_theta = theta_t.shape[0]
    per_theta_grad = jax.vmap(grad_log_normal, in_axes=(0, None, None))(
        thetas_q, theta_t, prop_theta_cov
    )
    # Share each proposal's gradient across all of its simulations.
    tiled_grad = jnp.repeat(per_theta_grad[:, None, :], n_sim_dst, axis=1)
    tiled_grad = tiled_grad.reshape(n_prop, n_sim_dst, n_theta)

    integrand = pred_q ** 2 + 2 * pred_q * tiled_grad
    per_draw = integrand.sum(axis=-1)
    if WF is not None:
        per_draw = WF * per_draw
    return per_draw.mean(axis=1).mean(axis=0)

def fit_linear_sm(
    key,
    theta_t,
    gen_sim_fn,
    grad_log_prop_fn,
    n_prop,
    n_sim_dst,
    lamb=1e-3,
    thetas_q=None,
    sims_q=None):
    """Fit an affine gradient model by ridge-regularised least squares.

    Each simulation ``x`` is augmented with a trailing 1. Then ``W`` solves
    ``(X^T X + lamb * I) W = -X^T G``, where ``X`` stacks all augmented
    simulations and ``G`` holds, for every simulation, the proposal gradient
    ``grad_log_prop_fn(thetas_q[i], theta_t)`` of the proposal it came from.

    Args:
        key: JAX PRNG key. It is only used when data has to be generated with
            ``gen_sim_fn``.
        theta_t: Reference parameter vector (length ``theta_dim``).
        gen_sim_fn: Callable ``(key, theta_t) -> (thetas_q, sims_q, _)``. It is
            used only when ``thetas_q`` and ``sims_q`` are not supplied.
        grad_log_prop_fn: Callable ``(theta, theta_t) -> gradient`` of length
            ``theta_dim``.
        n_prop: Number of proposal parameters (rows of ``thetas_q``).
        n_sim_dst: Number of simulations per proposal parameter.
        lamb: Ridge regularisation strength.
        thetas_q: Optional proposal parameters, ``n_prop`` rows.
        sims_q: Optional simulations of shape ``(n_prop, n_sim_dst, data_dim)``.

    Returns:
        ``(W, sims_q, sims_q_aug, thetas_q)``. ``W`` has shape
        ``(data_dim + 1, theta_dim)``, and ``sims_q_aug`` is ``sims_q`` with a
        trailing column of ones.

    Raises:
        ValueError: If only one of ``thetas_q``/``sims_q`` is given, or neither
            is given and ``gen_sim_fn`` is None.
    """
    if (thetas_q is None) != (sims_q is None):
        raise ValueError("thetas_q and sims_q must be supplied together")
    if thetas_q is None:
        if gen_sim_fn is None:
            raise ValueError(
                "supply thetas_q and sims_q, or a gen_sim_fn to generate them")
        thetas_q, sims_q, _ = gen_sim_fn(key, theta_t)

    theta_dim = theta_t.shape[0]
    grad_log_q = jax.vmap(grad_log_prop_fn, in_axes=(0, None))(thetas_q, theta_t)
    # Repeat each proposal's gradient for each of its n_sim_dst simulations.
    grad_log_q = jnp.repeat(grad_log_q, n_sim_dst, axis=0).reshape(
        n_prop, n_sim_dst, theta_dim)

    # Trailing constant feature gives the linear model an intercept row.
    sims_q_aug = jnp.concatenate([sims_q, jnp.ones_like(sims_q[..., :1])], axis=-1)

    # Gram matrix and cross term summed over every (proposal, simulation) pair.
    gram = jnp.einsum('ijk,ijl->kl', sims_q_aug, sims_q_aug)
    cross = jnp.einsum('ijk,ijl->kl', sims_q_aug, grad_log_q)
    reg_term = lamb * jnp.eye(gram.shape[0], M=gram.shape[1])

    # solve() is more stable and cheaper than forming the explicit inverse.
    W = -jnp.linalg.solve(gram + reg_term, cross)
    return W, sims_q, sims_q_aug, thetas_q