import logging
import os
import pickle

import hydra
import jax
import jax.numpy as jnp
from omegaconf import DictConfig

from priorg.sim.distributions import Independent, Normal, Uniform
from priorg.sim.distributions.continuous import Normal
from priorg.sim.methods.guidance import prior_guide_theta_prior_only
from priorg.sim.methods.metrics import compute_mmd_unweighted, compute_rmse
from priorg.sim.tasks.task import get_task

# Module-level logger used by the evaluation entry point below.
logger = logging.getLogger(__name__)
# Attach a default stderr handler to the root logger at INFO level
# (stdlib behavior: a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)


@hydra.main(version_base=None, config_path="cfg", config_name="evaluate.yaml")
def evaluate(cfg: DictConfig) -> None:
    """Evaluate a trained model with and without prior guidance.

    Loads ``model_{seed}.pkl`` from ``<cfg.model_path>/<task name>/`` and draws
    ``cfg.eval.num_observations`` observations from the task's observation
    generator. For each observation it draws ``cfg.eval.num_samples`` samples
    two ways:

    * with ``prior_guide_theta_prior_only``, using a Gaussian prior whose mean
      is a noisy draw around the true ``theta_o``;
    * with the model's plain ``sample`` method (no guidance).

    For both sample sets it computes an unweighted MMD against the true
    ``theta_o`` (per observation) and an RMSE over all observations. It logs
    these and saves them to ``<cfg.save_path>/<task name>/seed_<seed>.npz``.

    Args:
        cfg: Hydra config. Reads ``task.name``, ``seed``, ``model_path``,
            ``save_path``, ``eval.condition_mask_fn``,
            ``eval.num_observations``, ``eval.num_samples`` and
            ``prior_guide.{num_steps, rho, langevin_steps, langevin_ratio}``.
    """
    task_name = cfg.task.name
    seed = cfg.seed
    logger.info(f"evaluating {task_name} with seed {seed}")

    # NOTE(review): pickle.load can execute arbitrary code; the path comes
    # from the config, so only point it at trusted checkpoints.
    model_path = os.path.join(cfg.model_path, task_name, f"model_{seed}.pkl")
    with open(model_path, "rb") as f:
        model = pickle.load(f)

    key = jax.random.PRNGKey(seed)

    task = get_task(task_name)
    theta_dim = task.get_theta_dim()

    observation_generator = task.get_observation_generator(
        condition_mask_fn=cfg.eval.condition_mask_fn
    )
    key, key_obs = jax.random.split(key)
    observation_stream = observation_generator(key_obs)

    # Per-dimension std of the informative Gaussian prior. It is used both to
    # perturb the true theta into the prior mean and as the prior covariance.
    prior_scale = jnp.array([0.1, 0.2])
    cov_prior_q = jnp.diag(prior_scale**2)

    def sample_once_prior(sample_key, x_o, theta_prior_mean):
        # NOTE(review): 27 is hard-coded; presumably the joint (theta, x)
        # dimension of this task. Confirm before reusing for other tasks.
        x_T = jnp.zeros([27])
        # NOTE(review): the guided sampler uses the model's own
        # condition_mask, while the unguided baseline below uses the
        # per-observation one. Confirm this asymmetry is intended.
        return prior_guide_theta_prior_only(
            model=model,
            key=sample_key,
            condition_mask=model.condition_mask,
            x_o=x_o,
            x_T=x_T,
            theta_prior_mean=theta_prior_mean,
            theta_prior_cov=cov_prior_q,
            prior_dim=theta_dim,
            num_steps=cfg.prior_guide.num_steps,
            rho=cfg.prior_guide.rho,
            langevin_steps=cfg.prior_guide.langevin_steps,
            langevin_ratio=cfg.prior_guide.langevin_ratio,
        )

    # Build and compile once, outside the loop. The observation and prior mean
    # are passed as arguments (broadcast over the sample axis) instead of
    # being captured by a per-iteration closure, which forced a fresh trace
    # and compilation for every observation.
    batched_sample_prior = jax.jit(
        jax.vmap(sample_once_prior, in_axes=(0, None, None))
    )

    ground_truth = []
    all_samples_prior_guide = []
    all_samples_wo_guide = []
    mmd_wo_guide = []
    mmd_prior_guide = []

    for _ in range(cfg.eval.num_observations):
        condition_mask, x_o, theta_o = next(observation_stream)
        ground_truth.append(theta_o)
        theta_ref = jnp.asarray(theta_o).reshape(1, -1)

        # Give every random consumer its own subkey so no key is reused.
        key, key_q, key_guide, key_wo = jax.random.split(key, 4)
        mean_prior_q = Normal(loc=theta_o, scale=prior_scale).sample(key_q)

        sample_keys = jax.random.split(key_guide, cfg.eval.num_samples)
        samples_prior_guide = batched_sample_prior(sample_keys, x_o, mean_prior_q)
        # Keep only the theta coordinates of the guided joint samples.
        samples_prior_guide = jnp.asarray(samples_prior_guide[:, :theta_dim])
        all_samples_prior_guide.append(samples_prior_guide)

        samples_wo_guide = jnp.asarray(
            model.sample(
                cfg.eval.num_samples,
                x_o=x_o,
                condition_mask=condition_mask,
                rng=key_wo,
            )
        )
        all_samples_wo_guide.append(samples_wo_guide)

        mmd_wo_guide.append(
            compute_mmd_unweighted(samples_wo_guide, theta_ref, lengthscale=1)
        )
        mmd_prior_guide.append(
            compute_mmd_unweighted(samples_prior_guide, theta_ref, lengthscale=1)
        )

    mmd_wo_guide = jnp.array(mmd_wo_guide)
    mmd_prior_guide = jnp.array(mmd_prior_guide)

    ground_truth = jnp.array(ground_truth)
    all_samples_prior_guide = jnp.array(all_samples_prior_guide)
    all_samples_wo_guide = jnp.array(all_samples_wo_guide)

    logger.info(f"MMD without guide: {jnp.mean(mmd_wo_guide)}")
    logger.info(f"MMD with guide: {jnp.mean(mmd_prior_guide)}")

    rmse_wo_guide = compute_rmse(
        ground_truth=ground_truth, samples=all_samples_wo_guide
    )
    rmse_prior_guide = compute_rmse(
        ground_truth=ground_truth, samples=all_samples_prior_guide
    )

    logger.info(f"RMSE without guide: {rmse_wo_guide}")
    logger.info(f"RMSE with guide: {rmse_prior_guide}")

    save_path = os.path.join(cfg.save_path, task_name, f"seed_{seed}.npz")
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    jnp.savez(
        save_path,
        mmd_wo_guide=mmd_wo_guide,
        mmd_prior_guide=mmd_prior_guide,
        rmse_wo_guide=rmse_wo_guide,
        rmse_prior_guide=rmse_prior_guide,
    )


if __name__ == "__main__":
    # Called with no arguments: the @hydra.main decorator builds ``cfg``
    # from cfg/evaluate.yaml plus any command-line overrides.
    evaluate()
