"""
This file is used to test implicit optimisation with N+L particles theta_l and theta^i
N y_t are updated with langevin dynamics and for each i,
To run langevin on theta_l we average over score function gradients of theta_i
"""

import argparse
import datetime
import time

from torch.utils.tensorboard import SummaryWriter
from dataclasses import dataclass
from functools import partial
import os

from jax_tqdm import scan_tqdm
import jax
import jax.experimental
from blackjax.util import generate_gaussian_noise

from jax import numpy as np
from jaxtyping import Array, PRNGKeyArray, PyTree
import optax

from exptax.models.base import BaseExperiment
from exptax.models.model_sources import Sources
from exptax.run_utils import init_part_approx
from exptax.optimizers.implicit import ImplicitState
from exptax.estimators import bounds_eig_fix_shape
from exptax.run_utils import create_meas_array, update_hist
from exptax.optimizers.base import Optim
from exptax.optimizers.implicit import (
    init_averaged,
    logger_implicit,
    step_diffusion_gibbs,
)
from exptax.base import ParticlesApprox


def print_metrics(xi: Array, n_meas: int, spce_val: float, snmc_val: float, wass_value):
    """Print the metrics of one measurement round on a single line.

    Each value is rendered with ``repr`` and prefixed by its argument name,
    in the order ``n_meas, xi, spce_val, snmc_val, wass_value``.

    Args:
        xi: Design associated with the round.
        n_meas: Index of the measurement round.
        spce_val: Value reported under the ``spce_val`` label.
        snmc_val: Value reported under the ``snmc_val`` label.
        wass_value: Value reported under the ``wass_value`` label.
    """
    labelled = (
        ("n_meas", n_meas),
        ("xi", xi),
        ("spce_val", spce_val),
        ("snmc_val", snmc_val),
        ("wass_value", wass_value),
    )
    print(", ".join(f"{label}={value!r}" for label, value in labelled))


# Logging
def log_scalar(scalar_dict: dict, n_meas: int, writer: SummaryWriter):
    """Record every entry of ``scalar_dict`` on ``writer`` at step ``n_meas``.

    Args:
        scalar_dict: Mapping from scalar tag to value; each value is converted
            with ``float`` before being written.
        n_meas: Step index passed to ``writer.add_scalar`` for every entry.
        writer: Object exposing ``add_scalar(tag, value, step)``.
    """
    for tag in scalar_dict:
        writer.add_scalar(tag, float(scalar_dict[tag]), n_meas)

@dataclass
class ImplicitDiffusion:
    """
    Optimizer with implicit diffusion on p(theta| xi, y)

    Calling ``ImplicitDiffusion(...)`` does not produce an instance of this
    class: ``__new__`` builds three closures (``init``, ``run``, ``logger``)
    over its arguments and returns them bundled in an ``Optim``.
    """

    def __new__(
        cls,
        model: BaseExperiment,
        opt_steps: int,
        optx_opt: optax.GradientTransformation,
        log_dir: str = "logs",
    ) -> Optim:
        """Build the ``Optim`` triple for ``model``.

        Args:
            model: Experiment model forwarded to the implicit-optimiser helpers.
            opt_steps: Number of scan iterations performed by one ``run`` call.
            optx_opt: Optax gradient transformation forwarded to the helpers.
            log_dir: Directory forwarded to ``logger_implicit``.

        Returns:
            ``Optim(init, run, logger)``.
        """

        def init(
            rng_key: PRNGKeyArray,
            particles: ParticlesApprox,
            design: Array,
            n_prior: int = 200,
        ) -> ImplicitState:
            """Create the initial optimiser state via ``init_averaged``."""
            return init_averaged(
                model,
                optx_opt,
                rng_key,
                particles,
                design,
                n_prior,
            )

        def run(
            rng_key: PRNGKeyArray,
            state: ImplicitState,
            hist: PyTree,
            n_meas: int,
        ) -> tuple[ImplicitState, PyTree]:
            """Run ``opt_steps`` iterations of ``step_diffusion_gibbs``.

            Args:
                rng_key: PRNG key; split into one sub-key per iteration.
                state: Optimiser state used as the scan carry.
                hist: Measurement history forwarded unchanged to every step.
                n_meas: Index of the current measurement round.

            Returns:
                The final carry and the per-iteration outputs stacked by
                ``jax.lax.scan``.
            """

            @scan_tqdm(opt_steps)
            def step(state, tup):
                # The first element is the iteration index, which only the
                # progress-bar decorator consumes.
                _, key = tup
                return step_diffusion_gibbs(key, hist, state, n_meas, model, optx_opt)

            keys = jax.random.split(rng_key, opt_steps)
            # Kept under a separate name so the input `hist` is not shadowed
            # by the scan outputs.
            end_state, hist_opt = jax.lax.scan(
                step, state, (np.arange(0, opt_steps), keys)
            )
            return end_state, hist_opt

        def logger(
            particles: ParticlesApprox,
            particles_prior: ParticlesApprox,
            design: PyTree,
            hist: PyTree,
            n_meas: int,
        ):
            """Forward the current particles and history to ``logger_implicit``.

            A leading axis of length one is added to every leaf of both
            particle approximations before the call.
            """
            particles_prior = jax.tree.map(lambda x: x[None, :], particles_prior)
            particles = jax.tree.map(lambda x: x[None, :], particles)
            return logger_implicit(
                model, log_dir, particles, particles_prior, design, hist, n_meas
            )

        return Optim(init, run, logger)


if __name__ == "__main__":
    # Script entry point: sequential experimental design on the `Sources`
    # model. Each round optimises the design with `ImplicitDiffusion`, takes a
    # measurement, optionally logs evaluation metrics, then resamples both
    # particle sets and draws the starting design for the next round.

    # --- Command-line interface --------------------------------------------
    parser = argparse.ArgumentParser(description="SMC experiment design")
    parser.add_argument("--n_prior", default=500, type=int)
    parser.add_argument("--n_contrastive", default=500, type=int)
    parser.add_argument("--num_sources", default=2, type=int)
    parser.add_argument("--name", default="", type=str)
    parser.add_argument("--iter_per_meas", default=1000, type=int)
    parser.add_argument("--num_meas", default=30, type=int)
    parser.add_argument("--prefix", default="", type=str)
    parser.add_argument("--plot_post", action=argparse.BooleanOptionalAction)
    parser.add_argument("--evals", action=argparse.BooleanOptionalAction)
    parser.add_argument("--rng_key", default=1, type=int)
    # NOTE(review): the value of --logging is parsed but not used anywhere in
    # this block.
    parser.add_argument(
        "--logging",
        default="warning",
        help="Provide logging level. Example --logging debug, default=warning",
    )
    args = parser.parse_args()

    # --- Random keys ---------------------------------------------------------
    # One independent sub-key per consumer: feeding the same key to several
    # samplers yields correlated draws. `rng_key` itself is still handed to
    # the model so that the model built for a given seed is unchanged.
    rng_key = jax.random.PRNGKey(args.rng_key)
    key_particles, key_hist, key_design, key_init, key_scan = jax.random.split(
        rng_key, 5
    )

    # --- Output locations ----------------------------------------------------
    log_dir = "runs/implicit/gibbs/"

    dir_name = "runs/" + args.prefix + "sources/" + args.name + "/"
    tensorboard_name = (
        dir_name
        + datetime.datetime.now().strftime("%H_%M_%S_%d_%m")
        + f"_{args.rng_key}_inner_{args.n_prior}_outer_{args.n_contrastive}"
    )
    writer = SummaryWriter(tensorboard_name)
    print("Logging to: ", tensorboard_name)

    # --- Model and initial particles ----------------------------------------
    model = Sources(
        max_signal=1e-4,
        base_signal=0.1,
        num_sources=args.num_sources,
        source_var=5.,
        rng_key=rng_key,
        noise_var=0.5,
        d=2,
    )

    opt_steps = args.iter_per_meas
    n_prior = args.n_prior
    n_contrastive = args.n_contrastive
    particles = init_part_approx(key_particles, model, n_contrastive, 1)
    # Merge the first two axes of every leaf into a single axis.
    particles = jax.tree.map(lambda x: jax.lax.collapse(x, 0, 2), particles)

    meas = args.num_meas
    hist = create_meas_array(key_hist, model, meas)
    xi = model.xi(key_design)

    # --- Optimisers ----------------------------------------------------------
    # `adam` (with its decaying learning rate) is an alternative that is
    # currently not selected; the run uses `sam` below.
    exponential_decay_scheduler = optax.exponential_decay(
        init_value=3e-2,
        transition_steps=opt_steps,
        decay_rate=0.97,
        transition_begin=int(opt_steps * 0.25),
        staircase=False,
    )
    adam = optax.chain(
        optax.zero_nans(),
        optax.adam(learning_rate=exponential_decay_scheduler),
        optax.scale(-1),
    )
    sam = optax.contrib.sam(
        optax.chain(optax.adam(2e-2), optax.scale(-1)),
        optax.chain(optax.adam(2e-2), optax.scale(-1)),
    )
    optx = sam

    # `exist_ok=True` avoids the check-then-create race of a separate
    # existence test.
    os.makedirs(log_dir, exist_ok=True)

    writer_logger = partial(log_scalar, writer=writer)
    start_time = time.time()

    implicit = ImplicitDiffusion(model, opt_steps, optx, log_dir=log_dir)
    state = implicit.init(key_init, particles, xi, n_prior)
    rng_keys = jax.random.split(key_scan, meas)

    def step(carry, inputs):
        """Run one measurement round; used as the body of ``jax.lax.scan``.

        Args:
            carry: ``(state, hist)`` — optimiser state and measurement history.
            inputs: ``(idx, key)`` — round index and the PRNG key of the round.

        Returns:
            ``((state, hist), vals_bench)`` where ``vals_bench`` is the dict of
            evaluation metrics when ``--evals`` is set and ``None`` otherwise.
        """
        idx, key = inputs
        # Every random operation of the round gets its own sub-key.
        (
            key_run,
            key_meas,
            key_eval,
            key_prior,
            key_post,
            key_next_xi,
        ) = jax.random.split(key, 6)
        state, hist = carry
        state, hist_opt = implicit.run(key_run, state, hist, idx)

        # NOTE(review): six fields are unpacked here while `ImplicitState` is
        # rebuilt from five arguments below — confirm the last field has a
        # default.
        particles, particles_prior, xi, ys, opt_state, _ = state

        y_0 = model.measure(key_meas, xi)
        hist = update_hist(hist, xi, y_0, idx)

        vals_bench = None
        if args.evals:
            spce_val, snmc_val = bounds_eig_fix_shape(
                model, model.ground_truth, hist, key_eval, idx + 1
            )

            wass_value = model.wasserstein_eval(
                jax.tree.map(lambda x: x[None, :], particles)
            )
            vals_bench = {"SPCE": spce_val, "SNMC": snmc_val, **wass_value}

            jax.experimental.io_callback(writer_logger, None, vals_bench, idx)
            jax.experimental.io_callback(
                print_metrics, None, xi, idx, spce_val, snmc_val, wass_value
            )

        if args.plot_post:
            jax.experimental.io_callback(
                implicit.logger, None, particles, particles_prior, xi, hist_opt, idx
            )

        # Within one `tree.map` the SAME key is deliberately shared by all
        # leaves so that every leaf is resampled with the same draw; the two
        # resampling passes use different keys.
        # NOTE(review): sampling without replacement presumably requires
        # n_prior <= number of particles — confirm.
        particles_prior = jax.tree.map(
            lambda x: jax.random.choice(
                key_prior, x, shape=(n_prior,), replace=False, p=particles.weights
            ),
            particles,
        )
        particles = jax.tree.map(
            lambda x: jax.random.choice(
                key_post,
                x,
                shape=(n_contrastive,),
                replace=True,
                p=particles_prior.weights,
            ),
            particles_prior,
        )

        # Starting design of the next round: one row drawn from the
        # (axis-merged) "theta" entries of the resampled prior particles.
        collapsed = jax.lax.collapse(particles_prior.thetas["theta"], 0, 2)
        xi = jax.tree.map(
            lambda x: jax.random.choice(key_next_xi, collapsed, shape=(1,)), xi
        )
        # Re-initialise the optimiser state for the new design.
        opt_state = optx.init(xi)
        state = ImplicitState(particles, particles_prior, xi, ys, opt_state)
        return (state, hist), vals_bench

    try:
        final_carry, _ = jax.lax.scan(
            step, (state, hist), (np.arange(0, meas), rng_keys)
        )
        # Wait for the computation to finish so the timing below is meaningful.
        jax.block_until_ready(final_carry)
        end = time.time()
    finally:
        # Flush and release the TensorBoard event file even if the run fails.
        writer.close()
    print("Total time: ", end - start_time)
