# Import

import os
import sys
import jax
import flax
import optax
import functools
import numpy as np
from tqdm import tqdm
import jax.numpy as jnp
import matplotlib.pyplot as plt
base_dir = os.path.split(os.path.dirname(__file__))[0]
src_dir = os.path.join(base_dir, 'src'); sys.path.append(src_dir)
from lsb import Lsb, init_r_by_samples, init_orthogonal_S, sb_opt
from util import SwissRollSampler, StandardNormalSampler, pca_plot
from util import get_cfg, jax_random_key_with_settings
from fn import vi, vi0

# Parameter
# The run configuration and the root PRNG key (derived from cfg.seed).

cfg = get_cfg('cfg', f'{base_dir}/cfg/default_swiss.py')
key = jax_random_key_with_settings(cfg.seed)


# Data
# x_sampler provides the input-side samples, y_sampler the target-side ones.

x_sampler = StandardNormalSampler(cfg.dim)
y_sampler = SwissRollSampler(cfg.dim)


# Model

sb = Lsb(
    cfg.dim,
    cfg.n_potential,
    cfg.eps,
    cfg.diagonal,
    cfg.S_init,
    jnp.float32,
)

# subkey1 is the init rng; subkey2 is forwarded to the module call as `key`.
key, subkey1, subkey2 = jax.random.split(key, 3)
_init_batch = x_sampler.sample(cfg.batch_size)
vs = flax.core.FrozenDict(sb.init(subkey1, _init_batch, key=subkey2))

# Post-process the freshly initialised variables using target-side samples.
_potential_samples = y_sampler.sample(cfg.n_potential)
vs = init_orthogonal_S(init_r_by_samples(vs, _potential_samples))
opt = sb_opt(optax.adam(cfg.lr))

if "md" in cfg.alg_name:
    # Mirror-descent variants track three extra copies of the variables and
    # optimise the main and the fast parameters jointly.
    vs_fast, vs_slow, vs_fast_capture = (vs.copy() for _ in range(3))
    st = opt.init(flax.core.FrozenDict({"ps": vs["params"], "ps_fast": vs_fast["params"]}))
else:
    st = opt.init(vs["params"])

# Function
#
# Each callable below binds one method of `sb` through flax.linen.apply and
# compiles it with jax.jit. They are all invoked with the variables as the
# first argument, e.g. log_v(vs, y) or sample_comp(vs, eps, n, key).

model = jax.jit(flax.linen.apply(lambda mdl, xs, rng: mdl(xs, rng), sb))
log_v = jax.jit(flax.linen.apply(lambda mdl, xs: mdl.get_log_potential(xs), sb))
log_c = jax.jit(flax.linen.apply(lambda mdl, xs: mdl.get_log_C(xs), sb))
drift = jax.jit(flax.linen.apply(lambda mdl, xs, ts: mdl.get_drift(xs, ts), sb))

# The static positions below count the leading variables argument as index 0,
# so they mark the integer count argument `num` of each wrapped method.
sample_comp = jax.jit(
    flax.linen.apply(lambda mdl, eps, num, rng: mdl.sample_comp(eps, num, rng), sb),
    static_argnums=2,
)
sample_comp_each = jax.jit(
    flax.linen.apply(lambda mdl, eps, xs, num, rng: mdl.sample_comp_each(eps, xs, num, rng), sb),
    static_argnums=3,
)
solve_sde = jax.jit(
    flax.linen.apply(lambda mdl, xs, num, rng: mdl.sample_euler_maruyama(xs, num, rng), sb),
    static_argnums=2,
)

def lsb_loss(vs, x, y):
    """Return mean(log_c(vs, x)) minus mean(log_v(vs, y)).

    Args:
        vs: variables passed to the jitted `log_c` / `log_v` helpers.
        x: batch evaluated with `log_c`.
        y: batch evaluated with `log_v`.
    """
    mean_log_c = jnp.mean(log_c(vs, x))
    mean_log_v = jnp.mean(log_v(vs, y))
    return mean_log_c - mean_log_v

def lsbm_loss(vs, x, y, key):
    """Return half the mean squared error between target and model drift.

    A time t is drawn per row, uniformly scaled into [0, 0.99). The point x_t
    is the linear interpolation between x and y at t plus Gaussian noise scaled
    by sqrt(cfg.eps * t * (1 - t)). The regression target is
    (y - x_t) / (1 - t), compared against `drift(vs, x_t, t)`.

    Args:
        vs: variables passed to the jitted `drift` helper.
        x: batch of start points; its shape is used for t and for the noise.
        y: batch of end points, combined elementwise with x.
        key: PRNG key, split once for t and once for the noise.
    """
    time_key, noise_key = jax.random.split(key)
    # Scaling by 0.99 keeps 1 - t away from zero in the division below.
    t = 0.99 * jax.random.uniform(time_key, (x.shape[0], 1))

    interpolated = y * t + x * (1 - t)
    noise = jax.lax.sqrt(cfg.eps * t * (1 - t)) * jax.random.normal(noise_key, x.shape)
    x_t = interpolated + noise

    drift_target = (y - x_t) / (1 - t)
    residual = drift_target - drift(vs, x_t, jnp.reshape(t, -1))
    # Per-row squared norm of the residual.
    squared_norms = jax.vmap(jnp.inner)(residual, residual)
    return 0.5 * jnp.mean(squared_norms)

def md0_loss(eta, vs, vs_fast, vs_slow, key):
    """Return the eta-weighted sum of two `vi0` terms.

    Two independent sample sets are drawn from `vs` with `sample_comp` and
    wrapped in stop_gradient, so no gradient flows through the sampling. The
    first set is scored against `vs_slow` with weight (1 - eta), the second
    against `vs_fast` with weight eta.

    Args:
        eta: mixing weight between the slow and the fast term.
        vs: variables being optimised.
        vs_fast: variables used in the eta-weighted term.
        vs_slow: variables used in the (1 - eta)-weighted term.
        key: PRNG key, split into one key per sample set.
    """
    slow_key, fast_key = jax.random.split(key, 2)
    n_samples = cfg.md.sample_size

    ys_slow = jax.lax.stop_gradient(sample_comp(vs, cfg.eps, n_samples, slow_key))
    ys_fast = jax.lax.stop_gradient(sample_comp(vs, cfg.eps, n_samples, fast_key))

    slow_term = vi0(cfg.diagonal, cfg.eps, vs, vs_slow, ys_slow)
    fast_term = vi0(cfg.diagonal, cfg.eps, vs, vs_fast, ys_fast)
    return (1 - eta) * slow_term + eta * fast_term

def md_loss(eta, vs, vs_fast, vs_slow, x, key):
    """Return the eta-weighted sum of two batch-averaged `vi` terms.

    For the given batch `x`, two independent sample sets are drawn from `vs`
    with `sample_comp_each` and wrapped in stop_gradient, so no gradient flows
    through the sampling. The first set is scored against `vs_slow` with
    weight (1 - eta), the second against `vs_fast` with weight eta.

    The batch is supplied by the caller; both training steps in this file call
    this function as md_loss(eta, vs, vs_fast, vs_slow, x, key).

    Args:
        eta: mixing weight between the slow and the fast term.
        vs: variables being optimised.
        vs_fast: variables used in the eta-weighted term.
        vs_slow: variables used in the (1 - eta)-weighted term.
        x: batch passed to `sample_comp_each` and `vi`.
        key: PRNG key, split into one key per sample set.
    """
    key1, key2 = jax.random.split(key, 2)
    ys1 = jax.lax.stop_gradient(sample_comp_each(vs, cfg.eps, x, cfg.md.sample_size, key1))
    ys2 = jax.lax.stop_gradient(sample_comp_each(vs, cfg.eps, x, cfg.md.sample_size, key2))
    return \
        (1 - eta) * jnp.mean(vi(cfg.diagonal, cfg.eps, vs, (vs_slow), x, ys1)) + \
        (  eta  ) * jnp.mean(vi(cfg.diagonal, cfg.eps, vs, (vs_fast), x, ys2))


# Train

# Figures for this run are written under <base_dir>/fig/<alg>_swiss_seed_<seed>.
exp_name = f"{cfg.alg_name}_swiss_seed_{cfg.seed}"
fig_path = f'{base_dir}/fig/{exp_name}'
# exist_ok=True avoids the check-then-create race when several runs start at
# once, and is a no-op when the directory is already there.
os.makedirs(fig_path, exist_ok=True)

@functools.partial(jax.jit, static_argnums=4)
def train_lsb(st, vs, x, y, update_fn):
    """Run one optimisation step on `lsb_loss`.

    Args:
        st: optimizer state.
        vs: FrozenDict of variables; only the 'params' entry is optimised.
        x: batch forwarded to `lsb_loss` as x.
        y: batch forwarded to `lsb_loss` as y.
        update_fn: optimizer update function (static under jit), called as
            update_fn(grads, state, params).

    Returns:
        (new optimizer state, new variables, loss before the update).
    """
    _, params = vs.pop('params')

    def objective(candidate):
        # Only 'params' is swapped; every other collection comes from vs.
        return lsb_loss({**vs, 'params': candidate}, x, y)

    loss, grads = jax.value_and_grad(objective)(params)
    updates, st_new = update_fn(grads, st, params)
    params_new = optax.apply_updates(params, updates)
    return st_new, flax.core.FrozenDict({**vs, "params": params_new}), loss

@functools.partial(jax.jit, static_argnums=5)
def train_lsbm(st, vs, x, y, key, update_fn):
    """Run one optimisation step on `lsbm_loss`.

    Args:
        st: optimizer state.
        vs: FrozenDict of variables; only the 'params' entry is optimised.
        x: batch forwarded to `lsbm_loss` as x.
        y: batch forwarded to `lsbm_loss` as y.
        key: PRNG key; split once, one half feeds the loss and the other is
            returned for the next step.
        update_fn: optimizer update function (static under jit), called as
            update_fn(grads, state, params).

    Returns:
        (new optimizer state, new variables, loss before the update, new key).
    """
    key, loss_key = jax.random.split(key)
    _, params = vs.pop('params')

    def objective(candidate):
        # Only 'params' is swapped; every other collection comes from vs.
        return lsbm_loss({**vs, 'params': candidate}, x, y, loss_key)

    loss, grads = jax.value_and_grad(objective)(params)
    updates, st_new = update_fn(grads, st, params)
    params_new = optax.apply_updates(params, updates)
    return st_new, flax.core.FrozenDict({**vs, "params": params_new}), loss, key


@functools.partial(jax.jit, static_argnums=9)
def train_md0(step, st, vs, vs_fast, vs_slow, vs_fast_capture, x, y, key, update_fn):
    """Run one joint step on `md0_loss` (main params) and `lsb_loss` (fast params).

    Args:
        step: current step number; drives the eta schedule and the refresh of
            the slow / captured variables every cfg.md.itvl steps.
        st: optimizer state over {"ps", "ps_fast"}.
        vs: FrozenDict of main variables.
        vs_fast: FrozenDict of fast variables.
        vs_slow: variables used in the (1 - eta) term of `md0_loss`.
        vs_fast_capture: variables used in the eta term of `md0_loss`.
        x: batch forwarded to `lsb_loss` as x.
        y: batch forwarded to `lsb_loss` as y.
        key: PRNG key; split once, one half feeds `md0_loss`.
        update_fn: optimizer update function (static under jit).

    Returns:
        (optimizer state, vs, vs_fast, vs_slow, vs_fast_capture, summed loss, key),
        all after the update.
    """
    key, loss_key = jax.random.split(key)

    # 1/eta is interpolated linearly in the block index (step-1)//itvl between
    # 1/eta_start and 1/eta_end.
    inv_start = 1/cfg.md.eta_start
    inv_end = 1/cfg.md.eta_end
    progress = ((step-1)//cfg.md.itvl)/(cfg.max_step//cfg.md.itvl-1)
    eta = 1/(inv_start+progress*(inv_end-inv_start))

    _, params = vs.pop('params')
    _, params_fast = vs_fast.pop('params')

    def main_objective(candidate):
        return md0_loss(eta,
            {**vs, 'params': candidate}, vs_fast_capture, vs_slow, loss_key)

    def fast_objective(candidate_fast):
        # The non-param collections are taken from vs, not vs_fast.
        return lsb_loss({**vs, 'params': candidate_fast}, x, y)

    loss, grads = jax.value_and_grad(main_objective)(params)
    loss_fast, grads_fast = jax.value_and_grad(fast_objective)(params_fast)

    # Both parameter sets go through a single optimizer update.
    joint_params = flax.core.FrozenDict({"ps": params, "ps_fast": params_fast})
    joint_grads = flax.core.FrozenDict({"ps": grads, "ps_fast": grads_fast})
    updates, st_new = update_fn(joint_grads, st, joint_params)
    joint_new = optax.apply_updates(joint_params, updates)

    vs_new = flax.core.FrozenDict({**vs, "params": joint_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": joint_new["ps_fast"]})

    # At the end of each itvl-step block the slow / captured copies are
    # replaced by the freshly updated variables; otherwise they are kept.
    vs_slow_new, vs_fast_capture_new = jax.lax.cond(
        step % cfg.md.itvl == 0,
        lambda: (vs_new, vs_fast_new),
        lambda: (vs_slow, vs_fast_capture),
    )
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss + loss_fast, key

@functools.partial(jax.jit, static_argnums=9)
def train_md0_lsbm(step, st, vs, vs_fast, vs_slow, vs_fast_capture, x, y, key, update_fn):
    """Run one joint step on `md0_loss` (main params) and `lsbm_loss` (fast params).

    Args:
        step: current step number; drives the eta schedule and the refresh of
            the slow / captured variables every cfg.md.itvl steps.
        st: optimizer state over {"ps", "ps_fast"}.
        vs: FrozenDict of main variables.
        vs_fast: FrozenDict of fast variables.
        vs_slow: variables used in the (1 - eta) term of `md0_loss`.
        vs_fast_capture: variables used in the eta term of `md0_loss`.
        x: batch forwarded to `lsbm_loss` as x.
        y: batch forwarded to `lsbm_loss` as y.
        key: PRNG key; split into a carry key and one key per stochastic loss.
        update_fn: optimizer update function (static under jit).

    Returns:
        (optimizer state, vs, vs_fast, vs_slow, vs_fast_capture, summed loss, key),
        all after the update.
    """
    # Derive BOTH loss keys from the per-step key. Reading a module-level
    # `subkey2` here would bake one constant key into the compiled function,
    # so `lsbm_loss` would see the same t and noise on every step.
    key, subkey, subkey2 = jax.random.split(key, 3)

    # 1/eta is interpolated linearly in the block index (step-1)//itvl between
    # 1/eta_start and 1/eta_end.
    eta = 1/(1/cfg.md.eta_start+
        (((step-1)//cfg.md.itvl)/(cfg.max_step//cfg.md.itvl-1))*
        (1/cfg.md.eta_end-1/cfg.md.eta_start))

    def loss_fn(ps):
        return md0_loss(eta,
            {**vs, 'params': ps}, vs_fast_capture, vs_slow, subkey)

    def loss_fast_fn(ps_fast):
        # The non-param collections are taken from vs, not vs_fast.
        return lsbm_loss({**vs, 'params': ps_fast}, x, y, subkey2)

    _, ps = vs.pop('params')
    _, ps_fast = vs_fast.pop('params')

    loss, grads = jax.value_and_grad(loss_fn)(ps)
    loss_fast, grads_fast = jax.value_and_grad(loss_fast_fn)(ps_fast)

    # Both parameter sets go through a single optimizer update.
    ps = flax.core.FrozenDict({"ps": ps, "ps_fast": ps_fast})
    grads = flax.core.FrozenDict({"ps": grads, "ps_fast": grads_fast})

    us, st_new = update_fn(grads, st, ps)
    ps_new = optax.apply_updates(ps, us)
    vs_new = flax.core.FrozenDict({**vs, "params": ps_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": ps_new["ps_fast"]})
    # At the end of each itvl-step block the slow / captured copies are
    # replaced by the freshly updated variables; otherwise they are kept.
    vs_slow_new = jax.lax.cond(step % cfg.md.itvl == 0, lambda: vs_new, lambda: vs_slow)
    vs_fast_capture_new = jax.lax.cond(step % cfg.md.itvl == 0, lambda: vs_fast_new, lambda: vs_fast_capture)
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss + loss_fast, key

@functools.partial(jax.jit, static_argnums=9)
def train_md(step, st, vs, vs_fast, vs_slow, vs_fast_capture, x, y, key, update_fn):
    """Run one joint step on `md_loss` (main params) and `lsb_loss` (fast params).

    The `md_loss` value and gradient are averaged over cfg.md.accumulation
    evaluations, each with its own PRNG key.

    Args:
        step: current step number; drives the eta schedule and the refresh of
            the slow / captured variables every cfg.md.itvl steps.
        st: optimizer state over {"ps", "ps_fast"}.
        vs: FrozenDict of main variables.
        vs_fast: FrozenDict of fast variables.
        vs_slow: variables used in the (1 - eta) term of `md_loss`.
        vs_fast_capture: variables used in the eta term of `md_loss`.
        x: batch forwarded to both `md_loss` and `lsb_loss`.
        y: batch forwarded to `lsb_loss` as y.
        key: PRNG key; split once, one half seeds the accumulation keys.
        update_fn: optimizer update function (static under jit).

    Returns:
        (optimizer state, vs, vs_fast, vs_slow, vs_fast_capture, summed loss, key),
        all after the update.
    """
    key, accum_key = jax.random.split(key)

    # 1/eta is interpolated linearly in the block index (step-1)//itvl between
    # 1/eta_start and 1/eta_end.
    inv_start = 1/cfg.md.eta_start
    inv_end = 1/cfg.md.eta_end
    progress = ((step-1)//cfg.md.itvl)/(cfg.max_step//cfg.md.itvl-1)
    eta = 1/(inv_start+progress*(inv_end-inv_start))

    _, params = vs.pop('params')
    _, params_fast = vs_fast.pop('params')

    def main_objective(candidate, rng):
        return md_loss(eta,
            {**vs, 'params': candidate}, vs_fast_capture, vs_slow, x, rng)

    def fast_objective(candidate_fast):
        # The non-param collections are taken from vs, not vs_fast.
        return lsb_loss({**vs, 'params': candidate_fast}, x, y)

    n_accum = cfg.md.accumulation
    accum_keys = jax.random.split(accum_key, n_accum)

    def accumulate(i, carry):
        # Running average: every evaluation contributes 1/n_accum of its value.
        loss_acc, grads_acc = carry
        loss_i, grads_i = jax.value_and_grad(main_objective)(params, accum_keys[i])
        grads_sum = jax.tree_util.tree_map(lambda g, acc: g/n_accum + acc, grads_i, grads_acc)
        return loss_i/n_accum+loss_acc, grads_sum

    zero_grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    loss, grads = jax.lax.fori_loop(0, n_accum, accumulate, (0.0, zero_grads))
    loss_fast, grads_fast = jax.value_and_grad(fast_objective)(params_fast)

    # Both parameter sets go through a single optimizer update.
    joint_params = flax.core.FrozenDict({"ps": params, "ps_fast": params_fast})
    joint_grads = flax.core.FrozenDict({"ps": grads, "ps_fast": grads_fast})
    updates, st_new = update_fn(joint_grads, st, joint_params)
    joint_new = optax.apply_updates(joint_params, updates)

    vs_new = flax.core.FrozenDict({**vs, "params": joint_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": joint_new["ps_fast"]})

    # At the end of each itvl-step block the slow / captured copies are
    # replaced by the freshly updated variables; otherwise they are kept.
    vs_slow_new, vs_fast_capture_new = jax.lax.cond(
        step % cfg.md.itvl == 0,
        lambda: (vs_new, vs_fast_new),
        lambda: (vs_slow, vs_fast_capture),
    )
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss + loss_fast, key

@functools.partial(jax.jit, static_argnums=9)
def train_md_lsbm(step, st, vs, vs_fast, vs_slow, vs_fast_capture, x, y, key, update_fn):
    """Run one joint step on `md_loss` (main params) and `lsbm_loss` (fast params).

    The `md_loss` value and gradient are averaged over cfg.md.accumulation
    evaluations, each with its own PRNG key.

    Args:
        step: current step number; drives the eta schedule and the refresh of
            the slow / captured variables every cfg.md.itvl steps.
        st: optimizer state over {"ps", "ps_fast"}.
        vs: FrozenDict of main variables.
        vs_fast: FrozenDict of fast variables.
        vs_slow: variables used in the (1 - eta) term of `md_loss`.
        vs_fast_capture: variables used in the eta term of `md_loss`.
        x: batch forwarded to both `md_loss` and `lsbm_loss`.
        y: batch forwarded to `lsbm_loss` as y.
        key: PRNG key; split into a carry key, a key seeding the accumulation
            keys, and a key for `lsbm_loss`.
        update_fn: optimizer update function (static under jit).

    Returns:
        (optimizer state, vs, vs_fast, vs_slow, vs_fast_capture, summed loss, key),
        all after the update.
    """
    # Derive BOTH loss keys from the per-step key. Reading a module-level
    # `subkey2` here would bake one constant key into the compiled function,
    # so `lsbm_loss` would see the same t and noise on every step.
    key, subkey, subkey2 = jax.random.split(key, 3)

    # 1/eta is interpolated linearly in the block index (step-1)//itvl between
    # 1/eta_start and 1/eta_end.
    eta = 1/(1/cfg.md.eta_start+
        (((step-1)//cfg.md.itvl)/(cfg.max_step//cfg.md.itvl-1))*
        (1/cfg.md.eta_end-1/cfg.md.eta_start))

    def loss_fn(ps, key):
        return md_loss(eta,
            {**vs, 'params': ps}, vs_fast_capture, vs_slow, x, key)

    def loss_fast_fn(ps_fast):
        # The non-param collections are taken from vs, not vs_fast.
        return lsbm_loss({**vs, 'params': ps_fast}, x, y, subkey2)

    _, ps = vs.pop('params')
    _, ps_fast = vs_fast.pop('params')

    n = cfg.md.accumulation
    ks = jax.random.split(subkey, n)
    def body_fun(i, val):
        # Running average: every evaluation contributes 1/n of its value.
        loss, grads = jax.value_and_grad(loss_fn)(ps, ks[i])
        return loss/n+val[0], jax.tree_util.tree_map(lambda x, y: x/n + y, grads, val[1])

    loss, grads = jax.lax.fori_loop(0, n, body_fun, (0.0, jax.tree_util.tree_map(jnp.zeros_like, ps)))
    loss_fast, grads_fast = jax.value_and_grad(loss_fast_fn)(ps_fast)

    # Both parameter sets go through a single optimizer update.
    ps = flax.core.FrozenDict({"ps": ps, "ps_fast": ps_fast})
    grads = flax.core.FrozenDict({"ps": grads, "ps_fast": grads_fast})

    us, st_new = update_fn(grads, st, ps)
    ps_new = optax.apply_updates(ps, us)
    vs_new = flax.core.FrozenDict({**vs, "params": ps_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": ps_new["ps_fast"]})
    # At the end of each itvl-step block the slow / captured copies are
    # replaced by the freshly updated variables; otherwise they are kept.
    vs_slow_new = jax.lax.cond(step % cfg.md.itvl == 0, lambda: vs_new, lambda: vs_slow)
    vs_fast_capture_new = jax.lax.cond(step % cfg.md.itvl == 0, lambda: vs_fast_new, lambda: vs_fast_capture)
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss + loss_fast, key

def evaluate(step, key):
    """Save a `pca_plot` figure for the current module-level `vs`.

    Fresh evaluation batches of size cfg.eval_size are drawn from both
    samplers, the x batch is pushed through `model`, and the figure is saved
    in `fig_path` under the zero-padded step number.

    Args:
        step: step number used for the file name.
        key: PRNG key; split once, one half is consumed by `model`.

    Returns:
        The other half of the split key, for the caller to carry on with.
    """
    xs = x_sampler.sample(cfg.eval_size)
    ys = y_sampler.sample(cfg.eval_size)
    next_key, model_key = jax.random.split(key)
    mapped = model(vs, xs, model_key)
    out_file = f'{fig_path}/{str(step).zfill(5)}.png'
    pca_plot(xs, ys, mapped, n_plot=cfg.eval_size, save_name=out_file)
    return next_key

progress_bar = tqdm(range(1, cfg.max_step + 1))
# Step-0 figure of the untrained model.
key = evaluate(0, key)

# The four mirror-descent variants share one signature and one loop body;
# only the step function differs.
_md_train_steps = {
    'md0': train_md0,
    'md0_lsbm': train_md0_lsbm,
    'md': train_md,
    'md_lsbm': train_md_lsbm,
}

if cfg.alg_name == 'lsb':

    for step in progress_bar:
        xs = x_sampler.sample(cfg.batch_size)
        ys = y_sampler.sample(cfg.batch_size)
        st, vs, loss = train_lsb(st, vs, xs, ys, opt.update)
        progress_bar.set_description(f'loss [{loss: >11.5f}]')
        if step % cfg.eval_itvl == 0:
            key = evaluate(step, key)

elif cfg.alg_name == 'lsbm':

    for step in progress_bar:
        xs = x_sampler.sample(cfg.batch_size)
        ys = y_sampler.sample(cfg.batch_size)
        st, vs, loss, key = train_lsbm(st, vs, xs, ys, key, opt.update)
        progress_bar.set_description(f'loss [{loss: >11.5f}]')
        if step % cfg.eval_itvl == 0:
            key = evaluate(step, key)

elif cfg.alg_name in _md_train_steps:

    _train_step = _md_train_steps[cfg.alg_name]

    for step in progress_bar:
        xs = x_sampler.sample(cfg.batch_size)
        ys = y_sampler.sample(cfg.batch_size)
        st, vs, vs_fast, vs_slow, vs_fast_capture, loss, key = _train_step(step,
            st, vs, vs_fast, vs_slow, vs_fast_capture, xs, ys, key, opt.update)
        progress_bar.set_description(f'loss [{loss: >11.5f}]')
        if step % cfg.eval_itvl == 0:
            key = evaluate(step, key)

else:
    raise ValueError("Invalid algorithm name")


# Result
#
# Two-panel summary figure saved as result.png: the left panel shows samples
# from both samplers, the right panel shows the model output for the x samples
# together with simulated trajectories from a few fixed start points.

fig, axes = plt.subplots(1, 2, figsize=(15, 6.75), dpi=200)

for ax in axes: ax.grid(zorder=-20)

# Name the right-hand panel explicitly instead of relying on the loop
# variable above still pointing at the last axis.
fit_ax = axes[1]

x_samples = x_sampler.sample(2048)
y_samples = y_sampler.sample(2048)

tr_samples = jnp.asarray([[0.0, 0.0], [1.75, -1.75], [-1.5, 1.5], [2, 2]])

# Repeat the four start points three times, so each one gets three
# independently simulated trajectories.
tr_samples = jnp.reshape(jnp.tile(jnp.expand_dims(tr_samples, 0), (3, 1, 1)), (12, 2))

axes[0].scatter(x_samples[:, 0], x_samples[:, 1], alpha=0.3,
    c="g", s=32, edgecolors="black", label = r"Input distribution $p_0$")
axes[0].scatter(y_samples[:, 0], y_samples[:, 1],
    c="orange", s=32, edgecolors="black", label = r"Target distribution $p_1$")

key, subkey = jax.random.split(key)
ypred = model(vs, x_samples, subkey)

fit_ax.scatter(ypred[:, 0], ypred[:, 1], c="yellow", s=32, edgecolors="black", label = "Fitted distribution", zorder=1)

key, subkey = jax.random.split(key)
traj = np.asarray(solve_sde(vs, tr_samples, 1000, subkey))

fit_ax.scatter(tr_samples[:, 0], tr_samples[:, 1],
   c="g", s=128, edgecolors="black", label = r"Trajectory start ($x \sim p_0$)", zorder=3)

fit_ax.scatter(traj[:, -1, 0], traj[:, -1, 1],
   c="red", s=64, edgecolors="black", label = r"Trajectory end (fitted)", zorder=3)

for i in range(tr_samples.shape[0]):
    # A thick black line under a thin grey one gives each path an outline.
    fit_ax.plot(traj[i, ::1, 0], traj[i, ::1, 1], "black", markeredgecolor="black", linewidth=1.5, zorder=2)
    # Only the first path carries a label, so the legend lists it once.
    fit_ax.plot(traj[i, ::1, 0], traj[i, ::1, 1], "grey", markeredgecolor="black", linewidth=0.5, zorder=2,
        label=r"traj of $T_{\theta}$" if i == 0 else None)

for ax in axes:
    ax.set_xlim([-2.5, 2.5])
    ax.set_ylim([-2.5, 2.5])
    ax.legend(loc="lower left")

fig.tight_layout(pad=0.1); fig.savefig(f'{fig_path}/result.png')
