import time
from collections import namedtuple

import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit, vmap
from jax.example_libraries import optimizers
from jax.random import PRNGKey, permutation, split
from scipy.optimize import minimize
from tqdm import tqdm
import haiku as hk
import optax
import optax
from typing import Iterator, Mapping, Tuple
import wandb
import utils

Batch = Mapping[str, np.ndarray]

# todo: conditional imputation: SMC with posterior correction
# todo: permute over feature orderings
# todo: octane
# todo: bernoulli density estimation

import models.old_copula.ori_copula_AR_functions as cop_AR

# import models.copula_AR_functions as cop_AR

# Compute overhead v_{1:n}, return fit copula object for prediction
def fit_copula_density(y, n_perm=10, seed=20, n_perm_optim=None, n_optim=None):
    """Fit the copula density to ``y`` over random row permutations.

    The routine (1) draws ``n_perm`` random permutations of ``y``,
    (2) optimises the bandwidth ``rho`` and the length scales with SLSQP on
    an unconstrained scale, and (3) runs ``cop_AR.update_pn_loop_perm`` on all
    permutations with the optimised hyperparameters.

    Configuration is read from ``wandb.config``: ``model["init_rho"]``,
    ``model["init_length"]``, ``model["maxiter"]`` and ``base["compile"]``.
    Timings and fitted hyperparameters are printed and sent to ``wandb.log``.

    Args:
        y: Training data; ``jnp.shape(y)[1]`` is taken as the number of
            features ``d``, so it must have at least two axes.
        n_perm: Number of random permutations of ``y`` to generate.
        seed: Seed for both NumPy's global RNG and the JAX PRNG key.
        n_perm_optim: If given, only the first ``n_perm_optim`` permutations
            are used when optimising the hyperparameters.
        n_optim: If given, only the first ``n_optim`` entries along axis 1 of
            the permuted data are used when optimising the hyperparameters.
            Combines with ``n_perm_optim`` when both are set.

    Returns:
        A ``copula_density_obj`` namedtuple with fields ``vn_perm`` (first
        output of ``cop_AR.update_pn_loop_perm``), ``rho_lengths_opt``
        (index 0 is ``rho``, the remaining entries are length scales),
        ``preq_loglik`` (negated optimiser objective, or ``-999`` when the
        optimisation is skipped because ``maxiter`` is falsy) and ``y_perm``
        (all permuted copies of ``y``).
    """
    # Set seed for scipy
    np.random.seed(seed)

    # Generate random permutations: one PRNG subkey per permutation.
    key = PRNGKey(seed)
    _, *subkeys = split(key, n_perm + 1)
    y_perm = vmap(permutation, (0, None))(jnp.array(subkeys), y)

    # Initialize parameters: index 0 is the bandwidth rho, the rest are
    # d-1 length scales.
    d = jnp.shape(y)[1]
    rho_lengths_init = jnp.ones(d) * wandb.config.model["init_length"]
    rho_lengths_init = rho_lengths_init.at[0].set(wandb.config.model["init_rho"])

    # Map to the unconstrained scale used by the optimiser:
    # rho -> log(1 / rho - 1), length scales -> log(length scale).
    hyperparam_init = rho_lengths_init.at[0].set(jnp.log(1 / rho_lengths_init[0] - 1))
    hyperparam_init = hyperparam_init.at[1:].set(jnp.log(rho_lengths_init[1:]))

    # Select the data used to fit the hyperparameters: optionally a subset of
    # permutations, and optionally a prefix along axis 1 of each permutation.
    y_perm_opt = y_perm if n_perm_optim is None else y_perm[0:n_perm_optim]
    if n_optim is not None:
        # Slice the already-selected permutations so that n_perm_optim is not
        # discarded when both options are given.
        y_perm_opt = y_perm_opt[:, 0:n_optim]

    if wandb.config.base["compile"]:
        # Warm-up calls so later timings exclude compilation; results unused.
        print("Compiling...")
        start = time.time()
        cop_AR.fun_jll_perm_sp(hyperparam_init, y_perm_opt)
        cop_AR.grad_jll_perm_sp(hyperparam_init, y_perm_opt)
        cop_AR.update_pn_loop_perm(
            rho_lengths_init[0], rho_lengths_init[1:], y_perm
        )[0].block_until_ready()
        elapsed = round(time.time() - start, 3)
        print("Compilation time: {}s".format(elapsed))
        wandb.log({"Compilation time": elapsed})

    # Optimize with SLSQP (Quasi-Newton)
    print("Optimizing...")
    start = time.time()

    if wandb.config.model["maxiter"]:
        opt = minimize(
            fun=cop_AR.fun_jll_perm_sp,
            x0=hyperparam_init,
            args=(y_perm_opt,),
            jac=cop_AR.grad_jll_perm_sp,
            method="SLSQP",
            options={"maxiter": wandb.config.model["maxiter"]},
        )
        opt_fun = opt.fun
        # Best-effort: report a failed optimisation but keep its last iterate.
        if not opt.success:
            print("Optimization failed")
        hyperparam_opt = opt.x
    else:
        # Optimisation disabled: keep the initial values and use a sentinel
        # objective value.
        hyperparam_opt = hyperparam_init
        opt_fun = 999

    # Undo the unconstrained parametrisation (inverse of the mapping above).
    rho_lengths_opt = jnp.array(hyperparam_opt)
    rho_lengths_opt = rho_lengths_opt.at[0].set(1 / (1 + jnp.exp(hyperparam_opt[0])))
    rho_lengths_opt = rho_lengths_opt.at[1:].set(jnp.exp(hyperparam_opt[1:]))
    elapsed = round(time.time() - start, 3)

    print("Optimization time: {}s".format(elapsed))
    wandb.log({"Optimization time": elapsed})

    rho_report = {"rho": rho_lengths_opt[0]}
    lengthscale_report = {
        f"lengthscale_{i}": rho_lengths_opt[i] for i in range(1, len(rho_lengths_opt))
    }
    print(rho_report)
    print(lengthscale_report)
    wandb.log(rho_report)
    wandb.log(lengthscale_report)

    print("Fitting...")
    start = time.time()
    vn_perm = cop_AR.update_pn_loop_perm(
        rho_lengths_opt[0], rho_lengths_opt[1:], y_perm
    )[0].block_until_ready()  # wait for the result so the timing is accurate
    elapsed = round(time.time() - start, 3)
    print("Fit time: {}s".format(elapsed))

    copula_density_obj = namedtuple(
        "copula_density_obj", ["vn_perm", "rho_lengths_opt", "preq_loglik", "y_perm"]
    )
    return copula_density_obj(vn_perm, rho_lengths_opt, -opt_fun, y_perm)


# Predict on test data using copula object
def predict_copula_density(copula_density_obj, y_test):
    """Evaluate a fitted copula density on test data.

    Args:
        copula_density_obj: Object exposing ``vn_perm``, ``rho_lengths_opt``
            and ``y_perm`` attributes, as returned by ``fit_copula_density``.
            ``rho_lengths_opt[0]`` is passed as the bandwidth and
            ``rho_lengths_opt[1:]`` as the length scales.
        y_test: Test points forwarded to ``cop_AR.update_ptest_loop_perm_av``.

    Returns:
        The pair ``(logcdf_conditionals, logpdf_joints)`` produced by
        ``cop_AR.update_ptest_loop_perm_av``.
    """
    rho_lengths = copula_density_obj.rho_lengths_opt
    logcdf_conditionals, logpdf_joints = cop_AR.update_ptest_loop_perm_av(
        copula_density_obj.vn_perm,
        rho_lengths[0],
        rho_lengths[1:],
        copula_density_obj.y_perm,
        y_test,
    )
    # Wait for the computation to finish before returning, so that callers
    # timing this function measure the actual work.
    logcdf_conditionals = logcdf_conditionals.block_until_ready()
    return logcdf_conditionals, logpdf_joints

