# Import

import os
import sys
import jax
import flax
import optax
import functools
import numpy as np
from tqdm import tqdm
import jax.numpy as jnp
base_dir = os.path.split(os.path.dirname(__file__))[0]
src_dir = os.path.join(base_dir, 'src'); sys.path.append(src_dir)
from lsb import Lsb, init_r_by_samples, init_orthogonal_S, sb_opt
from util import StandardNormalSampler, ScurveSampler, MoonsSampler, SwissRollSampler, EightGaussianSampler
from util import StandardNormalSampler, ScurveSamplerV2, MoonsSamplerV2, SwissRollSamplerV2, EightGaussianSamplerV2
from util import get_cfg, jax_random_key_with_settings, mmd
from fn import vi, vi0

# Parameter

# Load the run configuration; the second argument points at the config file
# shipped under <base_dir>/cfg (presumably the defaults -- see util.get_cfg).
default_cfg_path = f'{base_dir}/cfg/default_online.py'
cfg = get_cfg('cfg', default_cfg_path)

# Root PRNG key for the whole script, derived from the configured seed.
key = jax_random_key_with_settings(cfg.seed)

# Optionally pin JAX to the CPU platform.
if cfg.cpu:
    jax.config.update("jax_platform_name", "cpu")

# Data

# Lookup table from distribution name to sampler class. Only the class that
# is actually selected gets instantiated (with cfg.dim as its sole argument).
# NOTE(review): 'scurve' and 'moons' resolve to the V2 samplers while 'swiss'
# and 'gaussian8' resolve to the non-V2 ones, even though V2 variants of the
# latter are imported too -- confirm that this mix is intended.
_SAMPLER_CLASSES = {
    'normal': StandardNormalSampler,
    'scurve': ScurveSamplerV2,
    'moons': MoonsSamplerV2,
    'swiss': SwissRollSampler,
    'gaussian8': EightGaussianSampler,
}

# Source distribution (cfg.mu).
if cfg.mu not in _SAMPLER_CLASSES:
    raise ValueError("Invalid mu name")
x_sampler = _SAMPLER_CLASSES[cfg.mu](cfg.dim)

# Target distribution (cfg.nu).
if cfg.nu not in _SAMPLER_CLASSES:
    raise ValueError("Invalid nu name")
y_sampler = _SAMPLER_CLASSES[cfg.nu](cfg.dim)

# Model

# Build the Lsb module and initialise its variables from one batch of source
# samples; init_orthogonal_S then post-processes the freshly created variables.
sb = Lsb(cfg.dim, cfg.n_potential, cfg.eps, cfg.diagonal, cfg.S_init, jnp.float32)
key, subkey1, subkey2 = jax.random.split(key, 3)
init_batch = x_sampler.sample(cfg.batch_size)
vs = init_orthogonal_S(flax.core.FrozenDict(sb.init(subkey1, init_batch, key=subkey2)))
opt = sb_opt(optax.adam(cfg.lr))

if "md" in cfg.alg_name:
    # Any algorithm whose name contains "md" keeps three extra copies of the
    # variables and optimises the main and the "fast" parameters jointly, so
    # the optimiser state is built over both parameter trees.
    vs_fast = vs.copy()
    vs_slow = vs.copy()
    vs_fast_capture = vs.copy()
    st = opt.init(flax.core.FrozenDict({"ps": vs["params"], "ps_fast": vs_fast["params"]}))
else:
    # Every other algorithm optimises a single parameter tree.
    st = opt.init(vs["params"])

# Function

def _jit_apply(fn, **jit_kwargs):
    """Bind `fn(module, *args)` to `sb` via flax.linen.apply and jit the result.

    The returned callable takes the variables as its first positional
    argument, followed by the remaining arguments of `fn`; any
    `static_argnums` therefore index into that (variables, *args) list.
    """
    return jax.jit(flax.linen.apply(fn, sb), **jit_kwargs)


# Jitted entry points into the Lsb module; each is called as f(vs, ...).
model = _jit_apply(lambda module, x, k: module(x, k))
log_v = _jit_apply(lambda module, x: module.get_log_potential(x))
log_c = _jit_apply(lambda module, x: module.get_log_C(x))
drift = _jit_apply(lambda module, x, t: module.get_drift(x, t))

# In the three wrappers below the count argument `n` is marked static.
sample_comp = _jit_apply(
    lambda module, e, n, k: module.sample_comp(e, n, k), static_argnums=2)
sample_comp_each = _jit_apply(
    lambda module, e, x, n, k: module.sample_comp_each(e, x, n, k), static_argnums=3)
solve_sde = _jit_apply(
    lambda module, x, n, k: module.sample_euler_maruyama(x, n, k), static_argnums=2)

def lsb_loss(vs, x, y):
    """Return mean(log_c(vs, x)) - mean(log_v(vs, y)).

    Args:
        vs: Variables passed to the jitted `log_c` / `log_v` wrappers.
        x: Batch fed to `log_c`.
        y: Batch fed to `log_v`.

    Returns:
        Scalar loss.
    """
    mean_log_c = jnp.mean(log_c(vs, x))
    mean_log_v = jnp.mean(log_v(vs, y))
    return mean_log_c - mean_log_v

def lsbm_loss(vs, x, y, key):
    """Drift-regression loss on randomly interpolated points between x and y.

    For each row a time t is drawn uniformly from [0, 0.99), the point
    x_t = t*y + (1-t)*x + sqrt(eps*t*(1-t)) * noise is formed, and the model
    drift at (x_t, t) is regressed onto (y - x_t) / (1 - t).

    Args:
        vs: Variables passed to the jitted `drift` wrapper.
        x: Start points; `x.shape[0]` is used as the batch size.
        y: End points, combined elementwise with `x`, so the two batches
            must be broadcast-compatible.
        key: PRNG key, split into one key for t and one for the noise.

    Returns:
        Scalar: half the batch mean of the squared drift error.
    """
    time_key, noise_key = jax.random.split(key)

    # Scaling by 0.99 keeps t strictly below 1, so 1 / (1 - t) stays finite.
    t = 0.99 * jax.random.uniform(time_key, (x.shape[0], 1))
    one_minus_t = 1 - t

    interpolant = y * t + x * one_minus_t
    noise_scale = jax.lax.sqrt(cfg.eps * t * one_minus_t)
    x_t = interpolant + noise_scale * jax.random.normal(noise_key, x.shape)

    target = (y - x_t) / one_minus_t
    residual = target - drift(vs, x_t, jnp.reshape(t, -1))

    # Row-wise inner product of the residual with itself (squared norm).
    squared_error = jax.vmap(jnp.inner)(residual, residual)
    return 0.5 * jnp.mean(squared_error)

def md_loss(eta, vs, vs_fast, vs_slow, x, key):
    """Convex combination of two `vi` terms, weighted by (1 - eta) and eta.

    Two independent sample sets are drawn from the current model with
    `sample_comp_each` (gradients are stopped through the samples). The first
    set is scored by `vi` against `vs_slow`, the second against `vs_fast`.

    Args:
        eta: Weight of the `vs_fast` term; the `vs_slow` term gets 1 - eta.
        vs: Variables of the model being optimised.
        vs_fast: Reference variables for the eta-weighted term.
        vs_slow: Reference variables for the (1 - eta)-weighted term.
        x: Batch of inputs the samples are conditioned on.
        key: PRNG key for the sampling.

    Returns:
        Scalar loss.
    """
    # Three-way split; only the first two subkeys are consumed.
    slow_key, fast_key, _ = jax.random.split(key, 3)
    n_samples = cfg.md.sample_size

    ys_slow = jax.lax.stop_gradient(
        sample_comp_each(vs, cfg.eps, x, n_samples, slow_key))
    ys_fast = jax.lax.stop_gradient(
        sample_comp_each(vs, cfg.eps, x, n_samples, fast_key))

    slow_term = jnp.mean(vi(cfg.diagonal, cfg.eps, vs, vs_slow, x, ys_slow))
    fast_term = jnp.mean(vi(cfg.diagonal, cfg.eps, vs, vs_fast, x, ys_fast))
    return (1 - eta) * slow_term + eta * fast_term

# Train

@functools.partial(jax.jit, static_argnums=4)
def train_lsb(st, vs, x, y, update_fn):
    """Run one optimisation step on `lsb_loss`.

    Args:
        st: Optimiser state.
        vs: Model variables (a FrozenDict with a 'params' entry).
        x: Batch passed as `x` to `lsb_loss`.
        y: Batch passed as `y` to `lsb_loss`.
        update_fn: Optimiser update function, called as
            `update_fn(grads, st, params)`; static under jit.

    Returns:
        Tuple `(new optimiser state, new variables, loss)`, where the loss
        is evaluated at the parameters before the update.
    """
    params = vs['params']

    def objective(candidate):
        # Only 'params' is swapped out; all other collections come from vs.
        return lsb_loss({**vs, 'params': candidate}, x, y)

    loss, grads = jax.value_and_grad(objective)(params)
    updates, st_new = update_fn(grads, st, params)
    params_new = optax.apply_updates(params, updates)
    return st_new, flax.core.FrozenDict({**vs, "params": params_new}), loss

@functools.partial(jax.jit, static_argnums=5)
def train_lsbm(st, vs, x, y, key, update_fn):
    """Run one optimisation step on `lsbm_loss`.

    Args:
        st: Optimiser state.
        vs: Model variables (a FrozenDict with a 'params' entry).
        x: Batch passed as `x` to `lsbm_loss`.
        y: Batch passed as `y` to `lsbm_loss`.
        key: PRNG key; split into a carry-over key and a key for the loss.
        update_fn: Optimiser update function, called as
            `update_fn(grads, st, params)`; static under jit.

    Returns:
        Tuple `(new optimiser state, new variables, loss, carry-over key)`,
        where the loss is evaluated at the parameters before the update.
    """
    next_key, loss_key = jax.random.split(key)
    params = vs['params']

    def objective(candidate):
        # Only 'params' is swapped out; all other collections come from vs.
        return lsbm_loss({**vs, 'params': candidate}, x, y, loss_key)

    loss, grads = jax.value_and_grad(objective)(params)
    updates, st_new = update_fn(grads, st, params)
    params_new = optax.apply_updates(params, updates)
    return st_new, flax.core.FrozenDict({**vs, "params": params_new}), loss, next_key

@functools.partial(jax.jit, static_argnums=9)
def train_md(step, st, vs, vs_fast, vs_slow, vs_fast_capture, x, y, key, update_fn):
    """Run one joint step: `md_loss` on the main params, `lsb_loss` on the fast ones.

    Args:
        step: Current training step (the caller counts from 1).
        st: Optimiser state over the joint {"ps", "ps_fast"} parameter tree.
        vs: Main model variables.
        vs_fast: Variables of the "fast" model trained with `lsb_loss`.
        vs_slow: Snapshot of the main variables, refreshed whenever
            `step % cfg.md.itvl == 0`.
        vs_fast_capture: Snapshot of the fast variables, refreshed on the
            same schedule.
        x: Batch used by both losses.
        y: Batch used only by the fast (`lsb_loss`) objective.
        key: PRNG key; split into a carry-over key and a key for the loss.
        update_fn: Optimiser update function; static under jit.

    Returns:
        Tuple `(st, vs, vs_fast, vs_slow, vs_fast_capture, loss, key)` with
        updated values; `loss` is the sum of the accumulated main loss and
        the fast loss.
    """
    next_key, loss_key = jax.random.split(key)

    # eta is interpolated linearly in 1/eta between eta_start and eta_end and
    # only changes once per block of cfg.md.itvl steps.
    inv_eta_start = 1 / cfg.md.eta_start
    inv_eta_end = 1 / cfg.md.eta_end
    progress = ((step - 1) // cfg.md.itvl) / (cfg.max_step // cfg.md.itvl - 1)
    eta = 1 / (inv_eta_start + progress * (inv_eta_end - inv_eta_start))

    main_ps = vs['params']
    fast_ps = vs_fast['params']

    def main_loss(candidate, sample_key):
        return md_loss(eta,
            {**vs, 'params': candidate}, vs_fast_capture, vs_slow, x, sample_key)

    def fast_loss(candidate):
        # The non-'params' collections are taken from vs, not from vs_fast.
        return lsb_loss({**vs, 'params': candidate}, x, y)

    # Average loss and gradient of the main objective over n sampling keys.
    n = cfg.md.accumulation
    sample_keys = jax.random.split(loss_key, n)

    def accumulate(i, carry):
        loss_acc, grads_acc = carry
        loss_i, grads_i = jax.value_and_grad(main_loss)(main_ps, sample_keys[i])
        grads_acc = jax.tree_util.tree_map(lambda g, acc: g / n + acc, grads_i, grads_acc)
        return loss_i / n + loss_acc, grads_acc

    zero_grads = jax.tree_util.tree_map(jnp.zeros_like, main_ps)
    loss, grads = jax.lax.fori_loop(0, n, accumulate, (0.0, zero_grads))
    loss_fast, grads_fast = jax.value_and_grad(fast_loss)(fast_ps)

    # One optimiser update over both parameter trees at once.
    joint_ps = flax.core.FrozenDict({"ps": main_ps, "ps_fast": fast_ps})
    joint_grads = flax.core.FrozenDict({"ps": grads, "ps_fast": grads_fast})
    updates, st_new = update_fn(joint_grads, st, joint_ps)
    joint_ps_new = optax.apply_updates(joint_ps, updates)

    vs_new = flax.core.FrozenDict({**vs, "params": joint_ps_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": joint_ps_new["ps_fast"]})

    # Refresh both snapshots at the end of every itvl-step block.
    refresh = step % cfg.md.itvl == 0
    vs_slow_new = jax.lax.cond(refresh, lambda: vs_new, lambda: vs_slow)
    vs_fast_capture_new = jax.lax.cond(refresh, lambda: vs_fast_new, lambda: vs_fast_capture)
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss + loss_fast, next_key

@functools.partial(jax.jit, static_argnums=9)
def train_md_lsbm(step, st, vs, vs_fast, vs_slow, vs_fast_capture, x, y, key, update_fn):
    """Run one joint step: `md_loss` on the main params, `lsbm_loss` on the fast ones.

    Args:
        step: Current training step (the caller counts from 1).
        st: Optimiser state over the joint {"ps", "ps_fast"} parameter tree.
        vs: Main model variables.
        vs_fast: Variables of the "fast" model trained with `lsbm_loss`.
        vs_slow: Snapshot of the main variables, refreshed whenever
            `step % cfg.md.itvl == 0`.
        vs_fast_capture: Snapshot of the fast variables, refreshed on the
            same schedule.
        x: Batch used by both losses.
        y: Batch used only by the fast (`lsbm_loss`) objective.
        key: PRNG key; split into a carry-over key, a key for the main loss
            and a key for the fast loss.
        update_fn: Optimiser update function; static under jit.

    Returns:
        Tuple `(st, vs, vs_fast, vs_slow, vs_fast_capture, loss, key)` with
        updated values; `loss` is the sum of the accumulated main loss and
        the fast loss.
    """
    next_key, main_key, fast_key = jax.random.split(key, 3)

    # eta is interpolated linearly in 1/eta between eta_start and eta_end and
    # only changes once per block of cfg.md.itvl steps.
    inv_eta_start = 1 / cfg.md.eta_start
    inv_eta_end = 1 / cfg.md.eta_end
    progress = ((step - 1) // cfg.md.itvl) / (cfg.max_step // cfg.md.itvl - 1)
    eta = 1 / (inv_eta_start + progress * (inv_eta_end - inv_eta_start))

    main_ps = vs['params']
    fast_ps = vs_fast['params']

    def main_loss(candidate, sample_key):
        return md_loss(eta,
            {**vs, 'params': candidate}, vs_fast_capture, vs_slow, x, sample_key)

    def fast_loss(candidate):
        # The non-'params' collections are taken from vs, not from vs_fast.
        return lsbm_loss({**vs, 'params': candidate}, x, y, fast_key)

    # Average loss and gradient of the main objective over n sampling keys.
    n = cfg.md.accumulation
    sample_keys = jax.random.split(main_key, n)

    def accumulate(i, carry):
        loss_acc, grads_acc = carry
        loss_i, grads_i = jax.value_and_grad(main_loss)(main_ps, sample_keys[i])
        grads_acc = jax.tree_util.tree_map(lambda g, acc: g / n + acc, grads_i, grads_acc)
        return loss_i / n + loss_acc, grads_acc

    zero_grads = jax.tree_util.tree_map(jnp.zeros_like, main_ps)
    loss, grads = jax.lax.fori_loop(0, n, accumulate, (0.0, zero_grads))
    loss_fast, grads_fast = jax.value_and_grad(fast_loss)(fast_ps)

    # One optimiser update over both parameter trees at once.
    joint_ps = flax.core.FrozenDict({"ps": main_ps, "ps_fast": fast_ps})
    joint_grads = flax.core.FrozenDict({"ps": grads, "ps_fast": grads_fast})
    updates, st_new = update_fn(joint_grads, st, joint_ps)
    joint_ps_new = optax.apply_updates(joint_ps, updates)

    vs_new = flax.core.FrozenDict({**vs, "params": joint_ps_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": joint_ps_new["ps_fast"]})

    # Refresh both snapshots at the end of every itvl-step block.
    refresh = step % cfg.md.itvl == 0
    vs_slow_new = jax.lax.cond(refresh, lambda: vs_new, lambda: vs_slow)
    vs_fast_capture_new = jax.lax.cond(refresh, lambda: vs_fast_new, lambda: vs_fast_capture)
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss + loss_fast, next_key


def rotating_screen(x, step, period=25, n_sectors=8):
    """Keep only the points of `x` that lie in the currently active angular sector.

    The plane is divided into `n_sectors` equal wedges around the origin,
    measured counter-clockwise from the positive first axis. The active wedge
    advances by one every `period` steps and wraps around after `n_sectors`
    wedges.

    Args:
        x: 2-D array of points; columns 0 and 1 are used as the planar
            coordinates that determine the angle.
        step: Training step that selects the active wedge.
        period: Number of consecutive steps each wedge stays active.
        n_sectors: Number of equal wedges the full circle is divided into.

    Returns:
        The rows of `x` whose angle lies in the half-open interval
        [start_angle, end_angle) of the active wedge. The number of rows
        varies with the data and may be zero.
    """
    # Polar angle of every point, mapped from [-pi, pi] to [0, 2*pi).
    angles = np.arctan2(x[:, 1], x[:, 0])
    angles = np.where(angles < 0, angles + 2 * np.pi, angles)

    sector_width = 2 * np.pi / n_sectors
    active_sector = (step // period) % n_sectors
    start_angle = active_sector * sector_width
    end_angle = (active_sector + 1) * sector_width

    mask = np.logical_and(angles >= start_angle, angles < end_angle)
    return x[mask]

# Main training loop. Steps are counted from 1 so that the step-dependent
# schedules (rotating_screen, the eta schedule in train_md*) start at phase 0.
progress_bar = tqdm(range(1, cfg.max_step + 1))

if cfg.alg_name == 'lsb':

    for step in progress_bar:
        # NOTE(review): the source batch is batch_size // 8, presumably to
        # match the ~1/8 of target samples that survive the 8-sector
        # rotating screen -- confirm.
        xs = x_sampler.sample(cfg.batch_size // 8)
        ys = rotating_screen(y_sampler.sample(cfg.batch_size), step)
        st, vs, loss = train_lsb(st, vs, xs, ys, opt.update)
        progress_bar.set_description(f'loss [{loss: >11.5f}]')

elif cfg.alg_name == 'lsbm':

    for step in progress_bar:
        xs = x_sampler.sample(cfg.batch_size)
        ys = rotating_screen(y_sampler.sample(cfg.batch_size), step)
        # lsbm_loss combines x and y row by row, so the source batch is cut
        # down to the number of target samples that passed the screen.
        xs = xs[:ys.shape[0],:]
        st, vs, loss, key = train_lsbm(st, vs, xs, ys, key, opt.update)
        progress_bar.set_description(f'loss [{loss: >11.5f}]')

elif cfg.alg_name in ('md0', 'md0_lsbm'):

    # These modes used to call train_md0 / train_md0_lsbm, which are neither
    # defined nor imported in this script, so they could only fail with a
    # NameError on the first step. Report that explicitly instead.
    raise NotImplementedError(
        f"Algorithm '{cfg.alg_name}' is not available: its training step "
        "(train_md0 / train_md0_lsbm) is not defined in this script")

elif cfg.alg_name == 'md':

    for step in progress_bar:
        xs = x_sampler.sample(cfg.batch_size // 8)
        ys = rotating_screen(y_sampler.sample(cfg.batch_size), step)
        st, vs, vs_fast, vs_slow, vs_fast_capture, loss, key = train_md(step,
            st, vs, vs_fast, vs_slow, vs_fast_capture, xs, ys, key, opt.update)
        progress_bar.set_description(f'loss [{loss: >11.5f}]')

elif cfg.alg_name == 'md_lsbm':

    for step in progress_bar:
        xs = x_sampler.sample(cfg.batch_size)
        ys = rotating_screen(y_sampler.sample(cfg.batch_size), step)
        # The fast objective is lsbm_loss, which combines x and y row by row,
        # so the source batch is cut down to the surviving target samples.
        xs = xs[:ys.shape[0],:]
        st, vs, vs_fast, vs_slow, vs_fast_capture, loss, key = train_md_lsbm(step,
            st, vs, vs_fast, vs_slow, vs_fast_capture, xs, ys, key, opt.update)
        progress_bar.set_description(f'loss [{loss: >11.5f}]')

else: raise ValueError("Invalid algorithm name")


# Result

# Evaluate the trained model: MMD between fresh target samples and model
# outputs generated from fresh source samples.
num_eval = 50000

# Target samples are drawn before the source samples, so the order of the
# sampler calls is the same as in the training loop's evaluation-free path.
target_samples = np.array(y_sampler.sample(num_eval))
model_samples = np.array(model(vs, x_sampler.sample(num_eval), key))
mmd_value = mmd(target_samples, model_samples)

# `exp_name` was never assigned anywhere in this script, so the final print
# raised NameError after training had finished and the result was lost.
# Label the result with the configured algorithm name.
exp_name = cfg.alg_name
print(f"{exp_name}: {mmd_value}")


