# Import

import os
import sys
import jax
import flax
import optax
import functools
import numpy as np
from tqdm import tqdm
import jax.numpy as jnp
import matplotlib.pyplot as plt
# Repository root = parent of the directory holding this script. The path is
# made absolute first: with a bare relative __file__ (e.g. "script.py"),
# os.path.dirname() returns '' and the root would silently become ''.
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Make the project's `src` directory importable for the imports that follow.
src_dir = os.path.join(base_dir, 'src')
sys.path.append(src_dir)
from lsb import Lsb, init_r_by_samples, init_orthogonal_S, sb_opt
from util import mmd, get_cfg, jax_random_key_with_settings, sample_array
from fn import vi, vi0


# Parameter

# Experiment configuration; the default config file lives under <base_dir>/cfg.
default_cfg_path = f'{base_dir}/cfg/default_mcsi.py'
cfg = get_cfg('cfg', default_cfg_path)
# Root PRNG key for the whole run, derived from the configured seed.
key = jax_random_key_with_settings(cfg.seed)


# Data

# Load the three .npy snapshots (start, evaluation and end day) in that order.
snapshot_days = (cfg.day_start, cfg.day_eval, cfg.day_end)
data_start, data_eval, data_end = (
    np.load(f"{base_dir}/data/full_cite_pcas_{cfg.dim}_day_{day}.npy")
    for day in snapshot_days
)

# Position of the evaluation day on the [start, end] interval, as a 1-element array.
days_elapsed = cfg.day_eval - cfg.day_start
days_total = cfg.day_end - cfg.day_start
t_eval = jnp.asarray([days_elapsed / days_total])

# One global scale: the mean over features of the per-feature std across all
# three snapshots. Only the start/end snapshots are rescaled for training;
# data_eval stays in original units.
scale = np.concatenate([data_start, data_eval, data_end]).std(axis=0).mean()
x_train = jnp.asarray(data_start / scale)
y_train = jnp.asarray(data_end / scale)
del data_start, data_end


# Model

sb = Lsb(cfg.dim, cfg.n_potential, cfg.eps, cfg.diagonal, cfg.S_init, jnp.float32)
key, subkey1, subkey2, subkey3, subkey4 = jax.random.split(key, 5)

# Initialise the module variables on a random batch of start-day samples.
init_batch = sample_array(x_train, cfg.batch_size, subkey2)
vs = flax.core.FrozenDict(sb.init(subkey1, init_batch, key=subkey3))

# Re-initialise r from cfg.n_potential random end-day samples, then S.
potential_samples = sample_array(y_train, cfg.n_potential, subkey4)
vs = init_orthogonal_S(init_r_by_samples(vs, potential_samples))

opt = sb_opt(optax.adam(cfg.lr))
if "md" in cfg.alg_name:
    # md* algorithms track three extra variable sets and optimise the main
    # and fast parameters jointly under the keys "ps" / "ps_fast".
    vs_fast = vs.copy()
    vs_slow = vs.copy()
    vs_fast_capture = vs.copy()
    joint_params = flax.core.FrozenDict({"ps": vs["params"], "ps_fast": vs_fast["params"]})
    st = opt.init(joint_params)
else:
    st = opt.init(vs["params"])


# Function

def _jit_apply(method, **jit_options):
    """Bind `method` to the `sb` module via `flax.linen.apply` and jit the result.

    The returned callable takes the variable collection as its first argument,
    followed by the remaining arguments of `method`.
    """
    return jax.jit(flax.linen.apply(method, sb), **jit_options)


log_v = _jit_apply(lambda m, x: m.get_log_potential(x))
log_c = _jit_apply(lambda m, x: m.get_log_C(x))
model = _jit_apply(lambda m, x, k: m(x, k))
sample_t = _jit_apply(lambda m, x, t, k: m.sample_at_time_moment(x, t, k))
drift = _jit_apply(lambda m, x, t: m.get_drift(x, t))
# The sample count `n` is marked static (positions count the leading variables argument).
sample_comp = _jit_apply(lambda m, e, n, k: m.sample_comp(e, n, k), static_argnums=2)
sample_comp_each = _jit_apply(
    lambda m, e, x, n, k: m.sample_comp_each(e, x, n, k), static_argnums=3)

def lsb_loss(vs, key):
    """Return mean(log_c) on a start batch minus mean(log_v) on an end batch.

    Args:
        vs: variable collection passed to `log_c` / `log_v`.
        key: PRNG key; split once to draw the two independent batches.

    Returns:
        Scalar loss.
    """
    key_x, key_y = jax.random.split(key)
    batch_x = sample_array(x_train, cfg.batch_size, key_x)
    batch_y = sample_array(y_train, cfg.batch_size, key_y)
    mean_log_c = jnp.mean(log_c(vs, batch_x))
    mean_log_v = jnp.mean(log_v(vs, batch_y))
    return mean_log_c - mean_log_v

def lsbm_loss(vs, key):
    """Return half the mean squared error between target and model drift.

    A start batch and an end batch are paired, a time t is drawn per pair, and
    an intermediate point x_t is sampled as the linear interpolation of the
    pair plus Gaussian noise of variance cfg.eps * t * (1 - t). The model
    drift at (x_t, t) is regressed onto (y - x_t) / (1 - t).

    Args:
        vs: variable collection passed to `drift`.
        key: PRNG key; split into four (two batches, times, noise).

    Returns:
        Scalar loss.
    """
    key_x, key_y, key_t, key_noise = jax.random.split(key, 4)
    batch = cfg.batch_size
    start = sample_array(x_train, batch, key_x)
    end = sample_array(y_train, batch, key_y)

    # t is kept strictly below 1 so the division by (1 - t) stays finite.
    t = 0.999 * jax.random.uniform(key_t, (start.shape[0], 1))
    noise_std = jax.lax.sqrt(cfg.eps * t * (1 - t))
    noise = jax.random.normal(key_noise, start.shape)
    x_t = end * t + start * (1 - t) + noise_std * noise

    drift_target = (end - x_t) / (1 - t)
    residual = drift_target - drift(vs, x_t, jnp.reshape(t, -1))
    squared_norms = jax.vmap(jnp.inner)(residual, residual)
    return 0.5 * jnp.mean(squared_norms)

def md0_loss(eta, vs, vs_fast, vs_slow, key):
    """Return the eta-weighted mix of two `vi0` terms (slow and fast reference).

    Two independent sample sets are drawn from `vs` via `sample_comp`, with
    gradients stopped through the samples. One set is scored against
    `vs_slow` (weight 1 - eta), the other against `vs_fast` (weight eta).

    Args:
        eta: mixing weight of the fast term.
        vs: variables being optimised.
        vs_fast: fast reference variables.
        vs_slow: slow reference variables.
        key: PRNG key; split in two, one per sample set.

    Returns:
        The weighted sum of the two `vi0` values.
    """
    key_slow, key_fast = jax.random.split(key, 2)
    n_samples = cfg.md.sample_size
    ys_slow = jax.lax.stop_gradient(sample_comp(vs, cfg.eps, n_samples, key_slow))
    ys_fast = jax.lax.stop_gradient(sample_comp(vs, cfg.eps, n_samples, key_fast))
    slow_term = vi0(cfg.diagonal, cfg.eps, vs, vs_slow, ys_slow)
    fast_term = vi0(cfg.diagonal, cfg.eps, vs, vs_fast, ys_fast)
    return (1 - eta) * slow_term + eta * fast_term

def md_loss(eta, vs, vs_fast, vs_slow, key):
    """Return the eta-weighted mix of two mean `vi` terms on a start batch.

    A batch `x` is drawn from `x_train`. Two independent sample sets are then
    drawn with `sample_comp_each` given `x`, with gradients stopped through
    the samples. One set is scored against `vs_slow` (weight 1 - eta), the
    other against `vs_fast` (weight eta).

    Args:
        eta: mixing weight of the fast term.
        vs: variables being optimised.
        vs_fast: fast reference variables.
        vs_slow: slow reference variables.
        key: PRNG key; split in three (batch, slow samples, fast samples).

    Returns:
        Scalar loss.
    """
    # The sampling calls previously used undefined names `key2` / `key3`,
    # which raised NameError; each draw now gets its own sub-key.
    key_x, key_slow, key_fast = jax.random.split(key, 3)
    x = sample_array(x_train, cfg.batch_size, key_x)
    n_samples = cfg.md.sample_size
    ys_slow = jax.lax.stop_gradient(sample_comp_each(vs, cfg.eps, x, n_samples, key_slow))
    ys_fast = jax.lax.stop_gradient(sample_comp_each(vs, cfg.eps, x, n_samples, key_fast))
    slow_term = jnp.mean(vi(cfg.diagonal, cfg.eps, vs, vs_slow, x, ys_slow))
    fast_term = jnp.mean(vi(cfg.diagonal, cfg.eps, vs, vs_fast, x, ys_fast))
    return (1 - eta) * slow_term + eta * fast_term


# Train

# NOTE(review): the experiment name spells "msci" while the default config file
# is "default_mcsi.py"; left unchanged because it determines the output path.
exp_name = f"{cfg.alg_name}_msci_seed_{cfg.seed}"
fig_path = f'{base_dir}/fig/{exp_name}'
# exist_ok avoids the check-then-create race when several runs start together.
os.makedirs(fig_path, exist_ok=True)


@functools.partial(jax.jit, static_argnums=3)
def train_lsb(st, vs, key, update_fn):
    """Run one jitted optimisation step on `lsb_loss`.

    Args:
        st: optimiser state.
        vs: variable collection; only its 'params' entry is updated.
        key: PRNG key; split into a carried key and a loss key.
        update_fn: optimiser update function (static under jit).

    Returns:
        (st_new, vs_new, loss, key) with `key` the advanced PRNG key.
    """
    next_key, loss_key = jax.random.split(key)

    def objective(params):
        return lsb_loss({**vs, 'params': params}, loss_key)

    _, current = vs.pop('params')
    loss, grads = jax.value_and_grad(objective)(current)
    updates, st_new = update_fn(grads, st, current)
    updated = optax.apply_updates(current, updates)
    vs_new = flax.core.FrozenDict({**vs, "params": updated})
    return st_new, vs_new, loss, next_key

@functools.partial(jax.jit, static_argnums=3)
def train_lsbm(st, vs, key, update_fn):
    """Run one jitted optimisation step on `lsbm_loss`.

    Args:
        st: optimiser state.
        vs: variable collection; only its 'params' entry is updated.
        key: PRNG key; split into a carried key and a loss key.
        update_fn: optimiser update function (static under jit).

    Returns:
        (st_new, vs_new, loss, key) with `key` the advanced PRNG key.
    """
    next_key, loss_key = jax.random.split(key)

    def objective(params):
        return lsbm_loss({**vs, 'params': params}, loss_key)

    _, current = vs.pop('params')
    loss, grads = jax.value_and_grad(objective)(current)
    updates, st_new = update_fn(grads, st, current)
    updated = optax.apply_updates(current, updates)
    vs_new = flax.core.FrozenDict({**vs, "params": updated})
    return st_new, vs_new, loss, next_key

@functools.partial(jax.jit, static_argnums=7)
def train_md0(step, st, vs, vs_fast, vs_slow, vs_fast_capture, key, update_fn):
    """Run one jitted step of the 'md0' algorithm.

    The main parameters follow the gradient of `md0_loss`, which mixes the
    `vs_slow` and `vs_fast_capture` snapshots with weight `eta`. The fast
    parameters follow the gradient of `lsb_loss`. Both gradients go through a
    single `update_fn` call on a {"ps", "ps_fast"} tree. Whenever `step` is a
    multiple of cfg.md.itvl, the two snapshots are replaced by the freshly
    updated main / fast variables.

    Returns:
        (st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new,
        loss + loss_fast, key) with `key` the advanced PRNG key.
    """
    next_key, main_key, fast_key = jax.random.split(key, 3)

    # 1/eta moves linearly from 1/eta_start towards 1/eta_end and is constant
    # within each block of cfg.md.itvl steps.
    inv_start = 1 / cfg.md.eta_start
    inv_end = 1 / cfg.md.eta_end
    progress = ((step - 1) // cfg.md.itvl) / (cfg.max_step // cfg.md.itvl - 1)
    eta = 1 / (inv_start + progress * (inv_end - inv_start))

    def main_objective(params):
        candidate = {**vs, 'params': params}
        return md0_loss(eta, candidate, vs_fast_capture, vs_slow, main_key)

    def fast_objective(params_fast):
        # Non-parameter collections are taken from `vs`, not `vs_fast`.
        return lsb_loss({**vs, 'params': params_fast}, fast_key)

    _, main_params = vs.pop('params')
    _, fast_params = vs_fast.pop('params')

    loss, main_grads = jax.value_and_grad(main_objective)(main_params)
    loss_fast, fast_grads = jax.value_and_grad(fast_objective)(fast_params)

    joint_params = flax.core.FrozenDict({"ps": main_params, "ps_fast": fast_params})
    joint_grads = flax.core.FrozenDict({"ps": main_grads, "ps_fast": fast_grads})
    updates, st_new = update_fn(joint_grads, st, joint_params)
    joint_new = optax.apply_updates(joint_params, updates)

    vs_new = flax.core.FrozenDict({**vs, "params": joint_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": joint_new["ps_fast"]})

    refresh = step % cfg.md.itvl == 0
    vs_slow_new = jax.lax.cond(refresh, lambda: vs_new, lambda: vs_slow)
    vs_fast_capture_new = jax.lax.cond(refresh, lambda: vs_fast_new, lambda: vs_fast_capture)
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss + loss_fast, next_key

@functools.partial(jax.jit, static_argnums=7)
def train_md0_lsbm(step, st, vs, vs_fast, vs_slow, vs_fast_capture, key, update_fn):
    """Run one jitted step of the 'md0_lsbm' algorithm.

    The main parameters follow the gradient of `md0_loss`, which mixes the
    `vs_slow` and `vs_fast_capture` snapshots with weight `eta`. The fast
    parameters follow the gradient of `lsbm_loss`. Both gradients go through a
    single `update_fn` call on a {"ps", "ps_fast"} tree. Whenever `step` is a
    multiple of cfg.md.itvl, the two snapshots are replaced by the freshly
    updated main / fast variables.

    Returns:
        (st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new,
        loss + loss_fast, key) with `key` the advanced PRNG key.
    """
    next_key, main_key, fast_key = jax.random.split(key, 3)

    # 1/eta moves linearly from 1/eta_start towards 1/eta_end and is constant
    # within each block of cfg.md.itvl steps.
    inv_start = 1 / cfg.md.eta_start
    inv_end = 1 / cfg.md.eta_end
    progress = ((step - 1) // cfg.md.itvl) / (cfg.max_step // cfg.md.itvl - 1)
    eta = 1 / (inv_start + progress * (inv_end - inv_start))

    def main_objective(params):
        candidate = {**vs, 'params': params}
        return md0_loss(eta, candidate, vs_fast_capture, vs_slow, main_key)

    def fast_objective(params_fast):
        # Non-parameter collections are taken from `vs`, not `vs_fast`.
        return lsbm_loss({**vs, 'params': params_fast}, fast_key)

    _, main_params = vs.pop('params')
    _, fast_params = vs_fast.pop('params')

    loss, main_grads = jax.value_and_grad(main_objective)(main_params)
    loss_fast, fast_grads = jax.value_and_grad(fast_objective)(fast_params)

    joint_params = flax.core.FrozenDict({"ps": main_params, "ps_fast": fast_params})
    joint_grads = flax.core.FrozenDict({"ps": main_grads, "ps_fast": fast_grads})
    updates, st_new = update_fn(joint_grads, st, joint_params)
    joint_new = optax.apply_updates(joint_params, updates)

    vs_new = flax.core.FrozenDict({**vs, "params": joint_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": joint_new["ps_fast"]})

    refresh = step % cfg.md.itvl == 0
    vs_slow_new = jax.lax.cond(refresh, lambda: vs_new, lambda: vs_slow)
    vs_fast_capture_new = jax.lax.cond(refresh, lambda: vs_fast_new, lambda: vs_fast_capture)
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss + loss_fast, next_key


@functools.partial(jax.jit, static_argnums=7)
def train_md(step, st, vs, vs_fast, vs_slow, vs_fast_capture, key, update_fn):
    """Run one jitted step of the 'md' algorithm.

    The main-parameter gradient is the average of `md_loss` gradients over
    cfg.md.accumulation independent sub-keys; `md_loss` mixes the `vs_slow`
    and `vs_fast_capture` snapshots with weight `eta`. The fast parameters
    follow the gradient of `lsb_loss`. Both gradients go through a single
    `update_fn` call on a {"ps", "ps_fast"} tree. Whenever `step` is a
    multiple of cfg.md.itvl, the two snapshots are replaced by the freshly
    updated main / fast variables.

    Args:
        step: 1-based training step.
        st: optimiser state for the {"ps", "ps_fast"} tree.
        vs: main variable collection.
        vs_fast: fast variable collection.
        vs_slow: slow snapshot of the main variables.
        vs_fast_capture: snapshot of the fast variables.
        key: PRNG key.
        update_fn: optimiser update function (static under jit).

    Returns:
        (st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new,
        loss + loss_fast, key) with `key` the advanced PRNG key.
    """
    # The signature previously lacked the comma between `vs_fast_capture` and
    # `key`, a SyntaxError that stopped the whole module from being parsed.
    key, subkey, subkey2 = jax.random.split(key, 3)

    # 1/eta moves linearly from 1/eta_start towards 1/eta_end and is constant
    # within each block of cfg.md.itvl steps.
    # NOTE(review): the denominator is zero when cfg.max_step // cfg.md.itvl == 1.
    inv_start = 1 / cfg.md.eta_start
    inv_end = 1 / cfg.md.eta_end
    progress = ((step - 1) // cfg.md.itvl) / (cfg.max_step // cfg.md.itvl - 1)
    eta = 1 / (inv_start + progress * (inv_end - inv_start))

    def loss_fn(ps, key):
        return md_loss(eta,
            {**vs, 'params': ps}, vs_fast_capture, vs_slow, key)

    def loss_fast_fn(ps_fast):
        # Non-parameter collections are taken from `vs`, not `vs_fast`.
        return lsb_loss({**vs, 'params': ps_fast}, subkey2)

    _, ps = vs.pop('params')
    _, ps_fast = vs_fast.pop('params')

    # Gradient accumulation: average loss and gradients over n sub-keys.
    n = cfg.md.accumulation
    ks = jax.random.split(subkey, n)

    def body_fun(i, val):
        loss_i, grads_i = jax.value_and_grad(loss_fn)(ps, ks[i])
        return loss_i/n + val[0], jax.tree_util.tree_map(lambda x, y: x/n + y, grads_i, val[1])

    zero_grads = jax.tree_util.tree_map(jnp.zeros_like, ps)
    loss, grads = jax.lax.fori_loop(0, n, body_fun, (0.0, zero_grads))
    loss_fast, grads_fast = jax.value_and_grad(loss_fast_fn)(ps_fast)

    ps = flax.core.FrozenDict({"ps": ps, "ps_fast": ps_fast})
    grads = flax.core.FrozenDict({"ps": grads, "ps_fast": grads_fast})

    us, st_new = update_fn(grads, st, ps)
    ps_new = optax.apply_updates(ps, us)
    vs_new = flax.core.FrozenDict({**vs, "params": ps_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": ps_new["ps_fast"]})

    refresh = step % cfg.md.itvl == 0
    vs_slow_new = jax.lax.cond(refresh, lambda: vs_new, lambda: vs_slow)
    vs_fast_capture_new = jax.lax.cond(refresh, lambda: vs_fast_new, lambda: vs_fast_capture)
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss + loss_fast, key

@functools.partial(jax.jit, static_argnums=7)
def train_md_lsbm(step, st, vs, vs_fast, vs_slow, vs_fast_capture, key, update_fn):
    """Run one jitted step of the 'md_lsbm' algorithm.

    The main-parameter gradient is the average of `md_loss` gradients over
    cfg.md.accumulation independent sub-keys; `md_loss` mixes the `vs_slow`
    and `vs_fast_capture` snapshots with weight `eta`. The fast parameters
    follow the gradient of `lsbm_loss`. Both gradients go through a single
    `update_fn` call on a {"ps", "ps_fast"} tree. Whenever `step` is a
    multiple of cfg.md.itvl, the two snapshots are replaced by the freshly
    updated main / fast variables.

    Returns:
        (st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new,
        loss + loss_fast, key) with `key` the advanced PRNG key.
    """
    next_key, accum_key, fast_key = jax.random.split(key, 3)

    # 1/eta moves linearly from 1/eta_start towards 1/eta_end and is constant
    # within each block of cfg.md.itvl steps.
    inv_start = 1 / cfg.md.eta_start
    inv_end = 1 / cfg.md.eta_end
    progress = ((step - 1) // cfg.md.itvl) / (cfg.max_step // cfg.md.itvl - 1)
    eta = 1 / (inv_start + progress * (inv_end - inv_start))

    def main_objective(params, rng):
        candidate = {**vs, 'params': params}
        return md_loss(eta, candidate, vs_fast_capture, vs_slow, rng)

    def fast_objective(params_fast):
        # Non-parameter collections are taken from `vs`, not `vs_fast`.
        return lsbm_loss({**vs, 'params': params_fast}, fast_key)

    _, main_params = vs.pop('params')
    _, fast_params = vs_fast.pop('params')

    # Gradient accumulation: average loss and gradients over n_accum sub-keys.
    n_accum = cfg.md.accumulation
    accum_keys = jax.random.split(accum_key, n_accum)

    def accumulate(i, carry):
        loss_sum, grad_sum = carry
        loss_i, grads_i = jax.value_and_grad(main_objective)(main_params, accum_keys[i])
        grad_sum = jax.tree_util.tree_map(lambda g, acc: g/n_accum + acc, grads_i, grad_sum)
        return loss_i/n_accum + loss_sum, grad_sum

    zero_grads = jax.tree_util.tree_map(jnp.zeros_like, main_params)
    loss, main_grads = jax.lax.fori_loop(0, n_accum, accumulate, (0.0, zero_grads))
    loss_fast, fast_grads = jax.value_and_grad(fast_objective)(fast_params)

    joint_params = flax.core.FrozenDict({"ps": main_params, "ps_fast": fast_params})
    joint_grads = flax.core.FrozenDict({"ps": main_grads, "ps_fast": fast_grads})
    updates, st_new = update_fn(joint_grads, st, joint_params)
    joint_new = optax.apply_updates(joint_params, updates)

    vs_new = flax.core.FrozenDict({**vs, "params": joint_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": joint_new["ps_fast"]})

    refresh = step % cfg.md.itvl == 0
    vs_slow_new = jax.lax.cond(refresh, lambda: vs_new, lambda: vs_slow)
    vs_fast_capture_new = jax.lax.cond(refresh, lambda: vs_fast_new, lambda: vs_fast_capture)
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss + loss_fast, next_key

# Evaluation history, filled in by evaluate(): step indices and their scores.
eval_steps = []
eval_scores = []

def evaluate(step, key):
    """Record the MMD between model samples at `t_eval` and `data_eval`.

    Samples are drawn from the current module-level `vs`, starting from
    `x_train`, and multiplied by `scale` before being compared with the
    evaluation-day data.

    Args:
        step: training step recorded alongside the score.
        key: PRNG key; split into a carried key and a sampling key.

    Returns:
        The advanced PRNG key.
    """
    next_key, sample_key = jax.random.split(key)
    eval_steps.append(step)
    samples = np.asarray(sample_t(vs, x_train, t_eval, sample_key)) * scale
    eval_scores.append(mmd(samples, data_eval))
    return next_key

progress_bar = tqdm(range(1, cfg.max_step + 1))
key = evaluate(0, key)  # baseline score before any training step

# Train steps that carry only (st, vs), keyed by algorithm name.
single_state_steps = {
    'lsb': train_lsb,
    'lsbm': train_lsbm,
}
# Train steps that also take the step index and carry the fast / slow /
# captured variable sets, keyed by algorithm name.
multi_state_steps = {
    'md0': train_md0,
    'md0_lsbm': train_md0_lsbm,
    'md': train_md,
    'md_lsbm': train_md_lsbm,
}

if cfg.alg_name in single_state_steps:

    train_step = single_state_steps[cfg.alg_name]
    for step in progress_bar:
        st, vs, loss, key = train_step(st, vs, key, opt.update)
        progress_bar.set_description(f'loss [{loss: >11.5f}]')
        if step % cfg.eval_itvl == 0:
            key = evaluate(step, key)

elif cfg.alg_name in multi_state_steps:

    train_step = multi_state_steps[cfg.alg_name]
    for step in progress_bar:
        st, vs, vs_fast, vs_slow, vs_fast_capture, loss, key = train_step(
            step, st, vs, vs_fast, vs_slow, vs_fast_capture, key, opt.update)
        progress_bar.set_description(f'loss [{loss: >11.5f}]')
        if step % cfg.eval_itvl == 0:
            key = evaluate(step, key)

else:
    raise ValueError("Invalid algorithm name")


# Result

# Plot the recorded MMD scores against the training step and save the figure.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(eval_steps, eval_scores, '-o', color='blue')
ax.set_title('MMD')
ax.set_xlabel('Step')
ax.grid(True)
fig.tight_layout(pad=0.1)
fig.savefig(f'{fig_path}/result.png')

print(eval_scores)
