# Import

import os
import sys
import jax
import flax
import optax
import functools
from tqdm import tqdm
import jax.numpy as jnp
import matplotlib.pyplot as plt
base_dir = os.path.split(os.path.dirname(__file__))[0]
src_dir = os.path.join(base_dir, 'src'); sys.path.append(src_dir)
from lsb import Lsb, init_r_by_samples, init_orthogonal_S, sb_opt
from eot_benchmark.metrics import compute_BW_UVP_by_gt_samples
from util import (
    calcuate_condBW,
    get_indepedent_plan_sample_fn, get_discrete_ot_plan_sample_fn,
    get_gt_plan_sample_fn_EOT, EOTGMMSampler, get_cfg, jax_random_key_with_settings
)
from fn import vi, vi0

# Parameter

# Experiment configuration object; the rest of this script reads its settings from `cfg`.
# NOTE(review): presumably 'cfg' is the name of the config flag and the path is
# the default config file -- confirm against util.get_cfg.
cfg = get_cfg('cfg', f'{base_dir}/cfg/default_eot.py')
# Root PRNG key built from cfg.seed; it is split further down whenever randomness is needed.
key = jax_random_key_with_settings(cfg.seed)


# Data

# Benchmark sampler; its x_sample / y_sample callables are used below and during model init.
sampler = EOTGMMSampler(cfg.dim, cfg.eps, cfg.batch_size, download=cfg.data_preload)

# One builder per supported cfg.plan_type. The builders are wrapped in lambdas
# so that only the selected plan sampler is actually constructed.
_plan_builders = {
    'ind': lambda: get_indepedent_plan_sample_fn(sampler.x_sample, sampler.y_sample),
    'ot': lambda: get_discrete_ot_plan_sample_fn(sampler.x_sample, sampler.y_sample),
    'gt': lambda: get_gt_plan_sample_fn_EOT(sampler),
}
if cfg.plan_type not in _plan_builders:
    raise ValueError('Unknown type of sampling plan')
# sample_plan(n) is called later as `xs, ys = sample_plan(n)`.
sample_plan = _plan_builders[cfg.plan_type]()


# Model

# Model definition (float32).
sb = Lsb(cfg.dim, cfg.n_potential, cfg.eps, cfg.diagonal, cfg.S_init, jnp.float32)

# One key for flax's init, one forwarded to the model call through `key=`.
key, subkey1, subkey2 = jax.random.split(key, 3)
init_batch = sampler.x_sample(cfg.batch_size)
vs = flax.core.FrozenDict(sb.init(subkey1, init_batch, key=subkey2))

# Post-process the fresh variables: first init_r_by_samples with cfg.n_potential
# target samples, then init_orthogonal_S on the result.
target_points = sampler.y_sample(cfg.n_potential)
vs = init_r_by_samples(vs, target_points)
vs = init_orthogonal_S(vs)

opt = sb_opt(optax.adam(cfg.lr))

if "md" in cfg.alg_name:
    # Mirror-descent variants keep three extra copies of the variables and
    # optimise the main and the fast parameters jointly under "ps" / "ps_fast".
    vs_fast, vs_slow, vs_fast_capture = vs.copy(), vs.copy(), vs.copy()
    st = opt.init(flax.core.FrozenDict({"ps": vs["params"], "ps_fast": vs_fast["params"]}))
else:
    # Plain variants optimise a single parameter tree.
    st = opt.init(vs["params"])


# Function

def _jit_apply(method, **jit_kwargs):
    """Bind `method` to the module `sb` with flax.linen.apply and jit the result.

    The returned callable is invoked in this file as f(vs, *args), so any
    `static_argnums` index counts the variables `vs` as argument 0.
    """
    return jax.jit(flax.linen.apply(method, sb), **jit_kwargs)


log_v = _jit_apply(lambda m, x: m.get_log_potential(x))
log_c = _jit_apply(lambda m, x: m.get_log_C(x))
model = _jit_apply(lambda m, x, k: m(x, k))
drift = _jit_apply(lambda m, x, t: m.get_drift(x, t))
# The count argument `n` is marked static in the three samplers below
# (position 2 of (vs, e, n, k) and (vs, x, n, k), position 3 of (vs, e, x, n, k)).
sample_comp = _jit_apply(lambda m, e, n, k: m.sample_comp(e, n, k), static_argnums=2)
sample_comp_each = _jit_apply(
    lambda m, e, x, n, k: m.sample_comp_each(e, x, n, k), static_argnums=3)
solve_sde = _jit_apply(
    lambda m, x, n, k: m.sample_euler_maruyama(x, n, k), static_argnums=2)

def lsb_loss(vs, x, y):
    """Return mean(log_c(vs, x)) - mean(log_v(vs, y)).

    Args:
        vs: model variables passed to the jitted `log_c` / `log_v` wrappers.
        x: batch fed to `log_c`.
        y: batch fed to `log_v`.
    """
    log_c_term = jnp.mean(log_c(vs, x))
    log_v_term = jnp.mean(log_v(vs, y))
    return log_c_term - log_v_term

def lsbm_loss(vs, x, y, key):
    """Drift-regression loss on noisy interpolations between `x` and `y`.

    For each pair a time t is drawn, x_t is formed as a t-weighted mix of x and y
    plus Gaussian noise of per-coordinate variance cfg.eps * t * (1 - t), and the
    model drift at (x_t, t) is regressed onto (y - x_t) / (1 - t).

    Args:
        vs: model variables passed to the jitted `drift` wrapper.
        x, y: paired batches; x.shape[0] sets the number of sampled times.
        key: PRNG key, split once for the times and once for the noise.

    Returns:
        Half the batch mean of the squared drift residual norm.
    """
    time_key, noise_key = jax.random.split(key)
    # Times are scaled by 0.99, which keeps 1 - t away from zero in the division below.
    t = 0.99 * jax.random.uniform(time_key, (x.shape[0], 1))
    noise = jax.lax.sqrt(cfg.eps * t * (1 - t)) * jax.random.normal(noise_key, x.shape)
    x_t = y * t + x * (1 - t) + noise
    target = (y - x_t) / (1 - t)
    # The drift wrapper receives the times as a flat vector.
    residual = target - drift(vs, x_t, jnp.reshape(t, -1))
    squared_norm = jax.vmap(jnp.inner)(residual, residual)
    return 0.5 * jnp.mean(squared_norm)

def md0_loss(eta, vs, vs_fast, vs_slow, key):
    """Blend of two `vi0` terms: weight (1 - eta) on `vs_slow`, weight eta on `vs_fast`.

    Each term uses its own batch of cfg.md.sample_size samples drawn from the
    current model through `sample_comp`; the samples are wrapped in
    stop_gradient so no gradient flows through the sampling.

    Args:
        eta: mixing weight of the `vs_fast` term.
        vs: variables of the model being optimised.
        vs_fast, vs_slow: reference variables passed to `vi0`.
        key: PRNG key, split into one key per sample batch.
    """
    slow_key, fast_key = jax.random.split(key, 2)
    n_samples = cfg.md.sample_size
    ys_slow = jax.lax.stop_gradient(sample_comp(vs, cfg.eps, n_samples, slow_key))
    ys_fast = jax.lax.stop_gradient(sample_comp(vs, cfg.eps, n_samples, fast_key))
    slow_term = vi0(cfg.diagonal, cfg.eps, vs, vs_slow, ys_slow)
    fast_term = vi0(cfg.diagonal, cfg.eps, vs, vs_fast, ys_fast)
    return (1 - eta) * slow_term + eta * fast_term

def md_loss(eta, vs, vs_fast, vs_slow, x, key):
    """Blend of two mean `vi` terms: weight (1 - eta) on `vs_slow`, eta on `vs_fast`.

    Each term uses its own batch of cfg.md.sample_size samples drawn for the
    inputs `x` through `sample_comp_each`; the samples are wrapped in
    stop_gradient so no gradient flows through the sampling.

    Args:
        eta: mixing weight of the `vs_fast` term.
        vs: variables of the model being optimised.
        vs_fast, vs_slow: reference variables passed to `vi`.
        x: input batch passed to both `sample_comp_each` and `vi`.
        key: PRNG key; it is split three ways and the third sub-key is unused.
    """
    # The 3-way split is kept on purpose: a 2-way split would yield different keys.
    slow_key, fast_key, _ = jax.random.split(key, 3)
    n_samples = cfg.md.sample_size
    ys_slow = jax.lax.stop_gradient(sample_comp_each(vs, cfg.eps, x, n_samples, slow_key))
    ys_fast = jax.lax.stop_gradient(sample_comp_each(vs, cfg.eps, x, n_samples, fast_key))
    slow_term = jnp.mean(vi(cfg.diagonal, cfg.eps, vs, vs_slow, x, ys_slow))
    fast_term = jnp.mean(vi(cfg.diagonal, cfg.eps, vs, vs_fast, x, ys_fast))
    return (1 - eta) * slow_term + eta * fast_term


# Train

# Experiment tag and output directory for the figure saved at the end of the script.
exp_name = f"{cfg.alg_name}_eot_seed_{cfg.seed}"
fig_path = f'{base_dir}/fig/{exp_name}'
# exist_ok avoids the check-then-create race when several runs start at once.
os.makedirs(fig_path, exist_ok=True)

@functools.partial(jax.jit, static_argnums=4)
def train_lsb(st, vs, x, y, update_fn):
    """Run one optimiser step on `lsb_loss`.

    Args:
        st: optimiser state.
        vs: model variables; only the 'params' entry is optimised.
        x, y: batches forwarded to `lsb_loss`.
        update_fn: optimiser update function, called as update_fn(grads, st, params).
            It is a static jit argument.

    Returns:
        (new optimiser state, new variables, loss before the update).
    """
    _, params = vs.pop('params')

    def objective(candidate):
        # Swap the candidate parameters into the otherwise unchanged variables.
        return lsb_loss({**vs, 'params': candidate}, x, y)

    loss, grads = jax.value_and_grad(objective)(params)
    updates, next_st = update_fn(grads, st, params)
    next_params = optax.apply_updates(params, updates)
    next_vs = flax.core.FrozenDict({**vs, "params": next_params})
    return next_st, next_vs, loss

@functools.partial(jax.jit, static_argnums=5)
def train_lsbm(st, vs, x, y, key, update_fn):
    """Run one optimiser step on `lsbm_loss`.

    Args:
        st: optimiser state.
        vs: model variables; only the 'params' entry is optimised.
        x, y: batches forwarded to `lsbm_loss`.
        key: PRNG key; one half of its split feeds the loss, the other is returned.
        update_fn: optimiser update function, called as update_fn(grads, st, params).
            It is a static jit argument.

    Returns:
        (new optimiser state, new variables, loss before the update, next key).
    """
    next_key, loss_key = jax.random.split(key)
    _, params = vs.pop('params')

    def objective(candidate):
        # Swap the candidate parameters into the otherwise unchanged variables.
        return lsbm_loss({**vs, 'params': candidate}, x, y, loss_key)

    loss, grads = jax.value_and_grad(objective)(params)
    updates, next_st = update_fn(grads, st, params)
    next_params = optax.apply_updates(params, updates)
    next_vs = flax.core.FrozenDict({**vs, "params": next_params})
    return next_st, next_vs, loss, next_key

@functools.partial(jax.jit, static_argnums=9)
def train_md0_pre(step, st, vs, vs_fast, vs_slow, vs_fast_capture, x, y, key, update_fn):
    """Burn-in step: train only the fast parameters on `lsb_loss`.

    The main parameters take part in the joint optimiser update with an
    all-zero gradient, so only the "ps_fast" entry receives a real gradient.

    Args:
        step: current step; snapshots are refreshed when step % cfg.md.itvl == 0.
        st: optimiser state over the joint {"ps", "ps_fast"} tree.
        vs: main model variables.
        vs_fast: fast model variables (the ones trained here).
        vs_slow: last snapshot of the main variables.
        vs_fast_capture: last snapshot of the fast variables.
        x, y: batches forwarded to `lsb_loss`.
        key: PRNG key; it is advanced by one split so the caller's stream moves on.
        update_fn: optimiser update function (static jit argument).

    Returns:
        (st, vs, vs_fast, vs_slow, vs_fast_capture, fast loss, next key),
        all after this step.
    """
    # No randomness is consumed here; the second half of the split is discarded.
    key, _ = jax.random.split(key)

    def loss_fast_fn(ps_fast):
        # Non-parameter collections are taken from `vs`; only 'params' is swapped.
        return lsb_loss({**vs, 'params': ps_fast}, x, y)

    _, ps = vs.pop('params')
    _, ps_fast = vs_fast.pop('params')

    loss_fast, grads_fast = jax.value_and_grad(loss_fast_fn)(ps_fast)

    # Zero gradient for the main parameters keeps the joint tree structure intact.
    grads = flax.core.FrozenDict({
        "ps": jax.tree_util.tree_map(jnp.zeros_like, ps),
        "ps_fast": grads_fast,
    })
    ps_all = flax.core.FrozenDict({"ps": ps, "ps_fast": ps_fast})
    us, st_new = update_fn(grads, st, ps_all)
    ps_new = optax.apply_updates(ps_all, us)
    vs_new = flax.core.FrozenDict({**vs, "params": ps_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": ps_new["ps_fast"]})

    # Refresh both snapshots once per interval, otherwise carry the old ones over.
    is_capture_step = step % cfg.md.itvl == 0
    vs_slow_new = jax.lax.cond(is_capture_step, lambda: vs_new, lambda: vs_slow)
    vs_fast_capture_new = jax.lax.cond(
        is_capture_step, lambda: vs_fast_new, lambda: vs_fast_capture)
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss_fast, key

@functools.partial(jax.jit, static_argnums=9)
def train_md0(step, st, vs, vs_fast, vs_slow, vs_fast_capture, x, y, key, update_fn):
    """One joint step: main parameters on `md0_loss`, fast parameters on `lsb_loss`.

    Args:
        step: current step; it drives the eta schedule and the snapshot refresh.
        st: optimiser state over the joint {"ps", "ps_fast"} tree.
        vs: main model variables.
        vs_fast: fast model variables.
        vs_slow: last snapshot of the main variables.
        vs_fast_capture: last snapshot of the fast variables.
        x, y: batches forwarded to `lsb_loss`.
        key: PRNG key; one half of its split feeds `md0_loss`, the other is returned.
        update_fn: optimiser update function (static jit argument).

    Returns:
        (st, vs, vs_fast, vs_slow, vs_fast_capture, summed loss, next key),
        all after this step.
    """
    key, md_key = jax.random.split(key)

    # 1/eta moves linearly from 1/eta_start to 1/eta_end with the interval index.
    inv_eta_start = 1 / cfg.md.eta_start
    inv_eta_end = 1 / cfg.md.eta_end
    progress = ((step - 1) // cfg.md.itvl) / (cfg.max_step // cfg.md.itvl - 1)
    eta = 1 / (inv_eta_start + progress * (inv_eta_end - inv_eta_start))

    _, main_params = vs.pop('params')
    _, fast_params = vs_fast.pop('params')

    def md_objective(candidate):
        return md0_loss(eta, {**vs, 'params': candidate}, vs_fast_capture, vs_slow, md_key)

    def fast_objective(candidate):
        return lsb_loss({**vs, 'params': candidate}, x, y)

    loss, main_grads = jax.value_and_grad(md_objective)(main_params)
    loss_fast, fast_grads = jax.value_and_grad(fast_objective)(fast_params)

    # Both parameter trees go through a single optimiser update.
    joint_params = flax.core.FrozenDict({"ps": main_params, "ps_fast": fast_params})
    joint_grads = flax.core.FrozenDict({"ps": main_grads, "ps_fast": fast_grads})
    updates, st_new = update_fn(joint_grads, st, joint_params)
    joint_new = optax.apply_updates(joint_params, updates)

    vs_new = flax.core.FrozenDict({**vs, "params": joint_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": joint_new["ps_fast"]})

    # Refresh both snapshots once per interval, otherwise carry the old ones over.
    is_capture_step = step % cfg.md.itvl == 0
    vs_slow_new = jax.lax.cond(is_capture_step, lambda: vs_new, lambda: vs_slow)
    vs_fast_capture_new = jax.lax.cond(
        is_capture_step, lambda: vs_fast_new, lambda: vs_fast_capture)
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss + loss_fast, key


@functools.partial(jax.jit, static_argnums=9)
def train_md0_lsbm_pre(step, st, vs, vs_fast, vs_slow, vs_fast_capture, x, y, key, update_fn):
    """Burn-in step: train only the fast parameters on `lsbm_loss`.

    The main parameters take part in the joint optimiser update with an
    all-zero gradient, so only the "ps_fast" entry receives a real gradient.

    Args:
        step: current step; snapshots are refreshed when step % cfg.md.itvl == 0.
        st: optimiser state over the joint {"ps", "ps_fast"} tree.
        vs: main model variables.
        vs_fast: fast model variables (the ones trained here).
        vs_slow: last snapshot of the main variables.
        vs_fast_capture: last snapshot of the fast variables.
        x, y: batches forwarded to `lsbm_loss`.
        key: PRNG key; it is split into the returned key and the loss key.
        update_fn: optimiser update function (static jit argument).

    Returns:
        (st, vs, vs_fast, vs_slow, vs_fast_capture, fast loss, next key),
        all after this step.
    """
    # The 3-way split is kept so the loss key matches the stream used by
    # train_md0_lsbm; the middle sub-key is not needed during burn-in.
    key, _, fast_key = jax.random.split(key, 3)

    def loss_fast_fn(ps_fast):
        # Non-parameter collections are taken from `vs`; only 'params' is swapped.
        return lsbm_loss({**vs, 'params': ps_fast}, x, y, fast_key)

    _, ps = vs.pop('params')
    _, ps_fast = vs_fast.pop('params')

    loss_fast, grads_fast = jax.value_and_grad(loss_fast_fn)(ps_fast)

    # Zero gradient for the main parameters keeps the joint tree structure intact.
    grads = flax.core.FrozenDict({
        "ps": jax.tree_util.tree_map(jnp.zeros_like, ps),
        "ps_fast": grads_fast,
    })
    ps_all = flax.core.FrozenDict({"ps": ps, "ps_fast": ps_fast})
    us, st_new = update_fn(grads, st, ps_all)
    ps_new = optax.apply_updates(ps_all, us)
    # `vs_new` was previously never defined, so the snapshot line below raised
    # NameError on the first trace; it is now built exactly as in train_md0_pre.
    vs_new = flax.core.FrozenDict({**vs, "params": ps_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": ps_new["ps_fast"]})

    # Refresh both snapshots once per interval, otherwise carry the old ones over.
    is_capture_step = step % cfg.md.itvl == 0
    vs_slow_new = jax.lax.cond(is_capture_step, lambda: vs_new, lambda: vs_slow)
    vs_fast_capture_new = jax.lax.cond(
        is_capture_step, lambda: vs_fast_new, lambda: vs_fast_capture)
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss_fast, key

@functools.partial(jax.jit, static_argnums=9)
def train_md0_lsbm(step, st, vs, vs_fast, vs_slow, vs_fast_capture, x, y, key, update_fn):
    """One joint step: main parameters on `md0_loss`, fast parameters on `lsbm_loss`.

    Args:
        step: current step; it drives the eta schedule and the snapshot refresh.
        st: optimiser state over the joint {"ps", "ps_fast"} tree.
        vs: main model variables.
        vs_fast: fast model variables.
        vs_slow: last snapshot of the main variables.
        vs_fast_capture: last snapshot of the fast variables.
        x, y: batches forwarded to `lsbm_loss`.
        key: PRNG key, split into the returned key and one key per loss.
        update_fn: optimiser update function (static jit argument).

    Returns:
        (st, vs, vs_fast, vs_slow, vs_fast_capture, summed loss, next key),
        all after this step.
    """
    key, md_key, fast_key = jax.random.split(key, 3)

    # 1/eta moves linearly from 1/eta_start to 1/eta_end with the interval index.
    inv_eta_start = 1 / cfg.md.eta_start
    inv_eta_end = 1 / cfg.md.eta_end
    progress = ((step - 1) // cfg.md.itvl) / (cfg.max_step // cfg.md.itvl - 1)
    eta = 1 / (inv_eta_start + progress * (inv_eta_end - inv_eta_start))

    _, main_params = vs.pop('params')
    _, fast_params = vs_fast.pop('params')

    def md_objective(candidate):
        return md0_loss(eta, {**vs, 'params': candidate}, vs_fast_capture, vs_slow, md_key)

    def fast_objective(candidate):
        return lsbm_loss({**vs, 'params': candidate}, x, y, fast_key)

    loss, main_grads = jax.value_and_grad(md_objective)(main_params)
    loss_fast, fast_grads = jax.value_and_grad(fast_objective)(fast_params)

    # Both parameter trees go through a single optimiser update.
    joint_params = flax.core.FrozenDict({"ps": main_params, "ps_fast": fast_params})
    joint_grads = flax.core.FrozenDict({"ps": main_grads, "ps_fast": fast_grads})
    updates, st_new = update_fn(joint_grads, st, joint_params)
    joint_new = optax.apply_updates(joint_params, updates)

    vs_new = flax.core.FrozenDict({**vs, "params": joint_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": joint_new["ps_fast"]})

    # Refresh both snapshots once per interval, otherwise carry the old ones over.
    is_capture_step = step % cfg.md.itvl == 0
    vs_slow_new = jax.lax.cond(is_capture_step, lambda: vs_new, lambda: vs_slow)
    vs_fast_capture_new = jax.lax.cond(
        is_capture_step, lambda: vs_fast_new, lambda: vs_fast_capture)
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss + loss_fast, key


@functools.partial(jax.jit, static_argnums=9)
def train_md(step, st, vs, vs_fast, vs_slow, vs_fast_capture, x, y, key, update_fn):
    """One joint step: main parameters on `md_loss`, fast parameters on `lsb_loss`.

    The `md_loss` value and gradient are averaged over cfg.md.accumulation
    evaluations, each with its own PRNG key.

    Args:
        step: current step; it drives the eta schedule and the snapshot refresh.
        st: optimiser state over the joint {"ps", "ps_fast"} tree.
        vs: main model variables.
        vs_fast: fast model variables.
        vs_slow: last snapshot of the main variables.
        vs_fast_capture: last snapshot of the fast variables.
        x, y: batches; `x` feeds both losses, `y` only `lsb_loss`.
        key: PRNG key; one half of its split feeds `md_loss`, the other is returned.
        update_fn: optimiser update function (static jit argument).

    Returns:
        (st, vs, vs_fast, vs_slow, vs_fast_capture, summed loss, next key),
        all after this step.
    """
    key, md_key = jax.random.split(key)

    # 1/eta moves linearly from 1/eta_start to 1/eta_end with the interval index.
    inv_eta_start = 1 / cfg.md.eta_start
    inv_eta_end = 1 / cfg.md.eta_end
    progress = ((step - 1) // cfg.md.itvl) / (cfg.max_step // cfg.md.itvl - 1)
    eta = 1 / (inv_eta_start + progress * (inv_eta_end - inv_eta_start))

    _, main_params = vs.pop('params')
    _, fast_params = vs_fast.pop('params')

    def md_objective(candidate, sample_key):
        return md_loss(eta, {**vs, 'params': candidate}, vs_fast_capture, vs_slow, x, sample_key)

    def fast_objective(candidate):
        return lsb_loss({**vs, 'params': candidate}, x, y)

    n = cfg.md.accumulation
    sample_keys = jax.random.split(md_key, n)

    def accumulate(i, carry):
        # Add this evaluation's 1/n share to the running loss and gradient.
        loss_acc, grad_acc = carry
        loss_i, grads_i = jax.value_and_grad(md_objective)(main_params, sample_keys[i])
        grad_acc = jax.tree_util.tree_map(lambda g, acc: g / n + acc, grads_i, grad_acc)
        return loss_i / n + loss_acc, grad_acc

    zero_grads = jax.tree_util.tree_map(jnp.zeros_like, main_params)
    loss, main_grads = jax.lax.fori_loop(0, n, accumulate, (0.0, zero_grads))
    loss_fast, fast_grads = jax.value_and_grad(fast_objective)(fast_params)

    # Both parameter trees go through a single optimiser update.
    joint_params = flax.core.FrozenDict({"ps": main_params, "ps_fast": fast_params})
    joint_grads = flax.core.FrozenDict({"ps": main_grads, "ps_fast": fast_grads})
    updates, st_new = update_fn(joint_grads, st, joint_params)
    joint_new = optax.apply_updates(joint_params, updates)

    vs_new = flax.core.FrozenDict({**vs, "params": joint_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": joint_new["ps_fast"]})

    # Refresh both snapshots once per interval, otherwise carry the old ones over.
    is_capture_step = step % cfg.md.itvl == 0
    vs_slow_new = jax.lax.cond(is_capture_step, lambda: vs_new, lambda: vs_slow)
    vs_fast_capture_new = jax.lax.cond(
        is_capture_step, lambda: vs_fast_new, lambda: vs_fast_capture)
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss + loss_fast, key

@functools.partial(jax.jit, static_argnums=9)
def train_md_lsbm(step, st, vs, vs_fast, vs_slow, vs_fast_capture, x, y, key, update_fn):
    """One joint step: main parameters on `md_loss`, fast parameters on `lsbm_loss`.

    The `md_loss` value and gradient are averaged over cfg.md.accumulation
    evaluations, each with its own PRNG key.

    Args:
        step: current step; it drives the eta schedule and the snapshot refresh.
        st: optimiser state over the joint {"ps", "ps_fast"} tree.
        vs: main model variables.
        vs_fast: fast model variables.
        vs_slow: last snapshot of the main variables.
        vs_fast_capture: last snapshot of the fast variables.
        x, y: batches; `x` feeds both losses, `y` only `lsbm_loss`.
        key: PRNG key, split into the returned key and one key per loss.
        update_fn: optimiser update function (static jit argument).

    Returns:
        (st, vs, vs_fast, vs_slow, vs_fast_capture, summed loss, next key),
        all after this step.
    """
    # `subkey2` must be a fresh local key. It used to be missing from the split,
    # so `loss_fast_fn` silently picked up the module-level `subkey2` left over
    # from model initialisation and reused the same noise on every step.
    key, subkey, subkey2 = jax.random.split(key, 3)

    # 1/eta moves linearly from 1/eta_start to 1/eta_end with the interval index.
    eta = 1/(1/cfg.md.eta_start+
        (((step-1)//cfg.md.itvl)/(cfg.max_step//cfg.md.itvl-1))*
        (1/cfg.md.eta_end-1/cfg.md.eta_start))

    def loss_fn(ps, key):
        return md_loss(eta,
            {**vs, 'params': ps}, vs_fast_capture, vs_slow, x, key)

    def loss_fast_fn(ps_fast):
        return lsbm_loss({**vs, 'params': ps_fast}, x, y, subkey2)

    _, ps = vs.pop('params')
    _, ps_fast = vs_fast.pop('params')

    n = cfg.md.accumulation
    ks = jax.random.split(subkey, n)

    def body_fun(i, val):
        # Add this evaluation's 1/n share to the running loss and gradient.
        loss, grads = jax.value_and_grad(loss_fn)(ps, ks[i])
        return loss/n+val[0], jax.tree_util.tree_map(lambda g, acc: g/n + acc, grads, val[1])

    loss, grads = jax.lax.fori_loop(0, n, body_fun, (0.0, jax.tree_util.tree_map(jnp.zeros_like, ps)))
    loss_fast, grads_fast = jax.value_and_grad(loss_fast_fn)(ps_fast)

    # Both parameter trees go through a single optimiser update.
    ps = flax.core.FrozenDict({"ps": ps, "ps_fast": ps_fast})
    grads = flax.core.FrozenDict({"ps": grads, "ps_fast": grads_fast})

    us, st_new = update_fn(grads, st, ps)
    ps_new = optax.apply_updates(ps, us)
    vs_new = flax.core.FrozenDict({**vs, "params": ps_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": ps_new["ps_fast"]})
    # Refresh both snapshots once per interval, otherwise carry the old ones over.
    vs_slow_new = jax.lax.cond(step % cfg.md.itvl == 0, lambda: vs_new, lambda: vs_slow)
    vs_fast_capture_new = jax.lax.cond(step % cfg.md.itvl == 0, lambda: vs_fast_new, lambda: vs_fast_capture)
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss + loss_fast, key


# Evaluation history, filled by evaluate() and plotted at the end of the script.
eval_steps, eval_scores1, eval_scores2 = [], [], []


def evaluate(step, key):
    """Score the current global `vs` and append the results to the history lists.

    Two metrics are recorded: `compute_BW_UVP_by_gt_samples` on model outputs for
    cfg.eval_size plan samples (into eval_scores1), and `calcuate_condBW`
    (into eval_scores2).

    Args:
        step: step index stored alongside the scores.
        key: PRNG key; it is split into the returned key and one key per metric.

    Returns:
        The PRNG key to use after this evaluation.
    """
    key, model_key, cond_key = jax.random.split(key, 3)

    xs, ys = sample_plan(cfg.eval_size)
    predicted = model(vs, xs, model_key)
    bw_uvp = compute_BW_UVP_by_gt_samples(predicted, ys)

    # calcuate_condBW receives eps as an int when cfg.eps >= 1, unchanged otherwise.
    eps = int(cfg.eps) if cfg.eps >= 1 else cfg.eps
    cond_bw_uvp = calcuate_condBW(model, vs, cond_key, cfg.dim, eps, n_samples=1000)

    eval_steps.append(step)
    eval_scores1.append(bw_uvp.tolist())
    eval_scores2.append(cond_bw_uvp.tolist())
    return key


# Training driver. The untrained model is evaluated once (recorded as step 0),
# then the loop matching cfg.alg_name runs for cfg.max_step steps and
# evaluates every cfg.eval_itvl steps.
progress_bar, key = tqdm(range(1, cfg.max_step + 1)), evaluate(0, key)

# Mirror-descent variants share one loop shape and differ only in their step
# functions: (main step, burn-in step or None when the variant has no burn-in).
md_train_fns = {
    'md0': (train_md0, train_md0_pre),
    'md0_lsbm': (train_md0_lsbm, train_md0_lsbm_pre),
    'md': (train_md, None),
    'md_lsbm': (train_md_lsbm, None),
}

if cfg.alg_name == 'lsb':

    for step in progress_bar:
        xs, ys = sample_plan(cfg.batch_size)
        st, vs, loss = train_lsb(st, vs, xs, ys, opt.update)
        progress_bar.set_description(f'loss [{loss: >11.5f}]')
        if step % cfg.eval_itvl == 0:
            key = evaluate(step, key)

elif cfg.alg_name == 'lsbm':

    for step in progress_bar:
        xs, ys = sample_plan(cfg.batch_size)
        st, vs, loss, key = train_lsbm(st, vs, xs, ys, key, opt.update)
        progress_bar.set_description(f'loss [{loss: >11.5f}]')
        if step % cfg.eval_itvl == 0:
            key = evaluate(step, key)

elif cfg.alg_name in md_train_fns:

    train_fn, burn_in_fn = md_train_fns[cfg.alg_name]

    if burn_in_fn is not None and cfg.burn_in > 0.0:
        # Burn-in: cfg.burn_in is a fraction of cfg.max_step; no evaluation happens here.
        progress_bar2 = tqdm(range(1, int(cfg.max_step * cfg.burn_in) + 1))
        for step in progress_bar2:
            xs, ys = sample_plan(cfg.batch_size)
            st, vs, vs_fast, vs_slow, vs_fast_capture, loss, key = burn_in_fn(
                step, st, vs, vs_fast, vs_slow, vs_fast_capture, xs, ys, key, opt.update)
            # Report on the burn-in bar itself; the main bar has not started yet.
            progress_bar2.set_description(f'loss [{loss: >11.5f}]')
        # Start the main phase from the burned-in fast variables.
        vs, vs_slow = vs_fast.copy(), vs_fast.copy()

    for step in progress_bar:
        xs, ys = sample_plan(cfg.batch_size)
        st, vs, vs_fast, vs_slow, vs_fast_capture, loss, key = train_fn(
            step, st, vs, vs_fast, vs_slow, vs_fast_capture, xs, ys, key, opt.update)
        progress_bar.set_description(f'loss [{loss: >11.5f}]')
        if step % cfg.eval_itvl == 0:
            key = evaluate(step, key)

else:
    raise ValueError("Invalid algorithm name")

# Result

# Plot both evaluation curves against the recorded steps and save the figure.
plt.figure(figsize=(8, 4))
for scores, color in ((eval_scores1, 'blue'), (eval_scores2, 'green')):
    plt.plot(eval_steps, scores, '-o', color=color)
plt.title('BW-UVP / cBW-UVP')
plt.xlabel('Step')
plt.grid(True)
plt.tight_layout(pad=0.1)
plt.savefig(f'{fig_path}/result.png')

# Also dump the raw score histories to stdout.
print(f"BW-UVP: {eval_scores1}")
print(f"cBW-UVP: {eval_scores2}")
