# Import

import os
import sys
import jax
import flax
import torch
import optax
import functools
import numpy as np
from tqdm import tqdm
import jax.numpy as jnp
import matplotlib.pyplot as plt
base_dir = os.path.split(os.path.dirname(__file__))[0]
src_dir = os.path.join(base_dir, 'src'); sys.path.append(src_dir)
alae_dir = os.path.join(base_dir, 'ALAE'); sys.path.append(alae_dir)
from lsb import Lsb, init_orthogonal_S, sb_opt
from alae_ffhq_inference import load_model, decode
from util import mmd, get_cfg, jax_random_key_with_settings, sample_array
from fn import vi, vi0

# Parameter

cfg = get_cfg('cfg', f'{base_dir}/cfg/default_alae.py')
key = jax_random_key_with_settings(cfg.seed)


# Data

if cfg.data_preload:
    # Download the latents, attribute labels and test images listed below.
    # NOTE: '../data' is resolved against the current working directory,
    # not against base_dir.
    os.makedirs('../data', exist_ok=True)

    import gdown

    urls = {
        "../data/age.npy": "https://drive.google.com/uc?id=1Vi6NzxCsS23GBNq48E-97Z9UuIuNaxPJ",
        "../data/gender.npy": "https://drive.google.com/uc?id=1SEdsmQGL3mOok1CPTBEfc_O1750fGRtf",
        "../data/latents.npy": "https://drive.google.com/uc?id=1ENhiTRsHtSjIjoRu1xYprcpNd8M9aVu8",
        "../data/test_images.npy": "https://drive.google.com/uc?id=1SjBWWlPjq-dxX4kxzW-Zn3iUR3po8Z0i",
    }

    for name, url in urls.items():
        gdown.download(url, name, quiet=False)

data = np.load("../data/latents.npy")
gender = np.load("../data/gender.npy")
age = np.load("../data/age.npy")
test_imgs = np.load("../data/test_images.npy")

# The first cfg.train_size rows form the training split; all remaining rows
# form the test split.
data_train, data_test = data[:cfg.train_size], data[cfg.train_size:]
gender_train, gender_test = gender[:cfg.train_size], gender[cfg.train_size:]
age_train, age_test = age[:cfg.train_size], age[cfg.train_size:]


def _group_mask(group, gender_split, age_split):
    """Return a flat boolean mask of the rows of one split that belong to ``group``.

    Args:
        group: one of ``"male"``, ``"female"``, ``"adult"`` (age >= 18) or
            ``"child"`` (age < 18).
        gender_split: gender labels of the split, compared against the
            strings ``"male"`` / ``"female"``.
        age_split: ages of the split. Rows with age ``-1`` (presumably missing
            labels) are excluded from both age groups.

    Raises:
        ValueError: if ``group`` is not one of the names above.
    """
    if group in ("male", "female"):
        return (gender_split == group).reshape(-1)
    known_age = (age_split != -1).reshape(-1)
    if group == "adult":
        return (age_split >= 18).reshape(-1) & known_age
    if group == "child":
        return (age_split < 18).reshape(-1) & known_age
    raise ValueError(f"Invalid group name: {group!r}")


# Indices are positions within each split. They are derived from the mask
# itself rather than from cfg.train_size / cfg.test_size, so they always
# agree with the split's actual length.
x_train_inds = np.flatnonzero(_group_mask(cfg.input_group, gender_train, age_train))
x_test_inds = np.flatnonzero(_group_mask(cfg.input_group, gender_test, age_test))
y_train_inds = np.flatnonzero(_group_mask(cfg.target_group, gender_train, age_train))
y_test_inds = np.flatnonzero(_group_mask(cfg.target_group, gender_test, age_test))

x_train = jax.device_put(data_train[x_train_inds])
x_test = jax.device_put(data_test[x_test_inds])
y_train = jax.device_put(data_train[y_train_inds])
y_test = jax.device_put(data_test[y_test_inds])


# Model

sb = Lsb(cfg.dim, cfg.n_potential, cfg.eps, cfg.diagonal, cfg.S_init, jnp.float32)

# One split feeds three consumers: init RNG, the dummy batch, and the
# module's own `key` argument.
key, subkey1, subkey2, subkey3 = jax.random.split(key, 4)
init_batch = sample_array(x_train, cfg.batch_size, subkey2)
vs = flax.core.FrozenDict(sb.init(subkey1, init_batch, key=subkey3))
vs = init_orthogonal_S(vs)
opt = sb_opt(optax.adam(cfg.lr))

if "md" in cfg.alg_name:
    # Mirror-descent variants keep three extra copies of the variables. One
    # optimiser state covers both the main ("ps") and fast ("ps_fast")
    # parameter trees.
    vs_fast = vs.copy()
    vs_slow = vs.copy()
    vs_fast_capture = vs.copy()
    st = opt.init(flax.core.FrozenDict({"ps": vs["params"], "ps_fast": vs_fast["params"]}))
else:
    st = opt.init(vs["params"])


# Function

def _apply_method(name):
    """Return a pure ``f(variables, *args)`` that calls ``sb.<name>(*args)``."""
    return flax.linen.apply(lambda m, *args: getattr(m, name)(*args), sb)


log_v = jax.jit(_apply_method("get_log_potential"))
log_c = jax.jit(_apply_method("get_log_C"))
model = jax.jit(flax.linen.apply(lambda m, x, k: m(x, k), sb))
drift = jax.jit(_apply_method("get_drift"))
# The sample-count argument `n` is static. Positions count the leading
# `variables` argument: (vs, eps, n, key) and (vs, eps, x, n, key).
# Presumably `n` fixes an output shape, hence must be static under jit.
sample_comp = jax.jit(_apply_method("sample_comp"), static_argnums=2)
sample_comp_each = jax.jit(_apply_method("sample_comp_each"), static_argnums=3)

def lsb_loss(vs, key):
    """Return mean(log_c(vs, x)) - mean(log_v(vs, y)) on random batches.

    ``x`` and ``y`` are ``cfg.batch_size`` rows sampled from ``x_train`` and
    ``y_train`` with two keys split from ``key``.
    """
    key_x, key_y = jax.random.split(key)
    batch_x = sample_array(x_train, cfg.batch_size, key_x)
    batch_y = sample_array(y_train, cfg.batch_size, key_y)
    mean_log_c = jnp.mean(log_c(vs, batch_x))
    mean_log_v = jnp.mean(log_v(vs, batch_y))
    return mean_log_c - mean_log_v

def lsbm_loss(vs, key):
    """Drift-matching loss along a noisy interpolation between x and y.

    For paired random batches ``x`` (from ``x_train``) and ``y`` (from
    ``y_train``) and times ``t`` uniform in ``[0, 0.999)``, it builds
    ``x_t = t*y + (1-t)*x + sqrt(eps*t*(1-t)) * noise``. It returns
    ``0.5 * mean(||(y - x_t)/(1 - t) - drift(vs, x_t, t)||^2)``.
    """
    key_x, key_y, key_t, key_noise = jax.random.split(key, 4)
    x = sample_array(x_train, cfg.batch_size, key_x)
    y = sample_array(y_train, cfg.batch_size, key_y)
    # Scaling by 0.999 keeps 1 - t away from zero in the division below.
    t = jax.random.uniform(key_t, (x.shape[0], 1)) * 0.999
    noise = jax.random.normal(key_noise, x.shape)
    bridge_std = jax.lax.sqrt(cfg.eps * t * (1 - t))
    x_t = y * t + x * (1 - t) + bridge_std * noise
    target = (y - x_t) / (1 - t)
    residual = target - drift(vs, x_t, t.reshape(-1))
    sq_norms = jax.vmap(jnp.inner)(residual, residual)
    return 0.5 * jnp.mean(sq_norms)

def md0_loss(eta, vs, vs_fast, vs_slow, key):
    """Eta-weighted mix of two ``vi0`` terms on samples from ``vs``.

    Two independent sample sets are drawn with ``sample_comp`` (gradients
    stopped). The result is ``(1 - eta) * vi0(..., vs_slow, ys_slow) +
    eta * vi0(..., vs_fast, ys_fast)``. The meaning of ``vi0`` lives in ``fn``.
    """
    key_slow, key_fast = jax.random.split(key, 2)
    ys_slow = jax.lax.stop_gradient(sample_comp(vs, cfg.eps, cfg.md.sample_size, key_slow))
    ys_fast = jax.lax.stop_gradient(sample_comp(vs, cfg.eps, cfg.md.sample_size, key_fast))
    term_slow = vi0(cfg.diagonal, cfg.eps, vs, vs_slow, ys_slow)
    term_fast = vi0(cfg.diagonal, cfg.eps, vs, vs_fast, ys_fast)
    return (1 - eta) * term_slow + eta * term_fast

def md_loss(eta, vs, vs_fast, vs_slow, key):
    """Eta-weighted mix of two batch-averaged ``vi`` terms.

    Samples ``cfg.batch_size`` inputs ``x`` from ``x_train``. For each input
    it draws ``cfg.md.sample_size`` samples with ``sample_comp_each``
    (gradients stopped) and returns ``(1 - eta) * mean(vi(..., vs_slow, x, ys_slow))
    + eta * mean(vi(..., vs_fast, x, ys_fast))``.
    """
    # A 4-way split is kept (4th key unused) so the PRNG stream is unchanged.
    key_x, key_slow, key_fast, _ = jax.random.split(key, 4)
    x = sample_array(x_train, cfg.batch_size, key_x)

    def draw(k):
        return jax.lax.stop_gradient(sample_comp_each(vs, cfg.eps, x, cfg.md.sample_size, k))

    ys_slow = draw(key_slow)
    ys_fast = draw(key_fast)
    term_slow = jnp.mean(vi(cfg.diagonal, cfg.eps, vs, vs_slow, x, ys_slow))
    term_fast = jnp.mean(vi(cfg.diagonal, cfg.eps, vs, vs_fast, x, ys_fast))
    return (1 - eta) * term_slow + eta * term_fast


# Train

exp_name = f"{cfg.alg_name}_alae_seed_{cfg.seed}"
fig_path = f'{base_dir}/fig/{exp_name}'
# exist_ok=True replaces a separate exists() check, avoiding a check-then-create race.
os.makedirs(fig_path, exist_ok=True)

@functools.partial(jax.jit, static_argnums=3)
def train_lsb(st, vs, key, update_fn):
    """One jitted optimiser step on ``lsb_loss``.

    Args:
        st: optimiser state for ``vs["params"]``.
        vs: model variables (FrozenDict with a ``"params"`` collection).
        key: PRNG key. It is split; one half drives the loss, the other is returned.
        update_fn: optimiser update function, e.g. ``opt.update`` (static under jit).

    Returns:
        ``(new_state, new_vs, loss, next_key)``.
    """
    next_key, loss_key = jax.random.split(key)
    params = vs["params"]
    loss, grads = jax.value_and_grad(
        lambda p: lsb_loss({**vs, 'params': p}, loss_key)
    )(params)
    updates, st_new = update_fn(grads, st, params)
    new_params = optax.apply_updates(params, updates)
    return st_new, flax.core.FrozenDict({**vs, "params": new_params}), loss, next_key

@functools.partial(jax.jit, static_argnums=3)
def train_lsbm(st, vs, key, update_fn):
    """One jitted optimiser step on ``lsbm_loss``.

    Args:
        st: optimiser state for ``vs["params"]``.
        vs: model variables (FrozenDict with a ``"params"`` collection).
        key: PRNG key. It is split; one half drives the loss, the other is returned.
        update_fn: optimiser update function, e.g. ``opt.update`` (static under jit).

    Returns:
        ``(new_state, new_vs, loss, next_key)``.
    """
    next_key, loss_key = jax.random.split(key)
    params = vs["params"]
    loss, grads = jax.value_and_grad(
        lambda p: lsbm_loss({**vs, 'params': p}, loss_key)
    )(params)
    updates, st_new = update_fn(grads, st, params)
    new_params = optax.apply_updates(params, updates)
    return st_new, flax.core.FrozenDict({**vs, "params": new_params}), loss, next_key


@functools.partial(jax.jit, static_argnums=7)
def train_md0(step, st, vs, vs_fast, vs_slow, vs_fast_capture, key, update_fn):
    """One jitted step of the "md0" variant.

    The main model ``vs`` is trained on ``md0_loss`` against the snapshots
    ``vs_fast_capture`` / ``vs_slow``, while ``vs_fast`` is trained on
    ``lsb_loss``. Both parameter trees share one optimiser state keyed
    ``"ps"`` / ``"ps_fast"``. On every step divisible by ``cfg.md.itvl`` the
    snapshots are refreshed from the freshly updated models.

    ``eta`` is set so that ``1/eta`` moves linearly from ``1/eta_start`` to
    ``1/eta_end``, advancing once per ``cfg.md.itvl``-step interval.

    The argument order ``(..., vs_slow, vs_fast_capture, ...)`` matches the
    sibling ``train_md*`` functions and the training loop. The loop passes
    the snapshots in this order and unpacks the returned tuple in this order.
    (The previous order swapped the two snapshots on every call.)

    Returns:
        ``(st, vs, vs_fast, vs_slow, vs_fast_capture, loss + loss_fast, key)``.
    """
    key, subkey, subkey2 = jax.random.split(key, 3)
    eta = 1/(1/cfg.md.eta_start+
        (((step-1)//cfg.md.itvl)/(cfg.max_step//cfg.md.itvl-1))*
        (1/cfg.md.eta_end-1/cfg.md.eta_start))

    def loss_fn(ps):
        return md0_loss(eta,
            {**vs, 'params': ps}, vs_fast_capture, vs_slow, subkey)

    def loss_fast_fn(ps_fast):
        return lsb_loss({**vs, 'params': ps_fast}, subkey2)

    _, ps = vs.pop('params')
    _, ps_fast = vs_fast.pop('params')

    loss, grads = jax.value_and_grad(loss_fn)(ps)
    loss_fast, grads_fast = jax.value_and_grad(loss_fast_fn)(ps_fast)

    ps = flax.core.FrozenDict({"ps": ps, "ps_fast": ps_fast})
    grads = flax.core.FrozenDict({"ps": grads, "ps_fast": grads_fast})

    us, st_new = update_fn(grads, st, ps)
    ps_new = optax.apply_updates(ps, us)
    vs_new = flax.core.FrozenDict({**vs, "params": ps_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": ps_new["ps_fast"]})
    # Refresh the snapshots at the end of each interval.
    vs_slow_new = jax.lax.cond(step % cfg.md.itvl == 0, lambda: vs_new, lambda: vs_slow)
    vs_fast_capture_new = jax.lax.cond(step % cfg.md.itvl == 0, lambda: vs_fast_new, lambda: vs_fast_capture)
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss + loss_fast, key

@functools.partial(jax.jit, static_argnums=7)
def train_md0_lsbm(step, st, vs, vs_fast, vs_slow, vs_fast_capture, key, update_fn):
    """One jitted step of "md0_lsbm": ``md0_loss`` for ``vs``, ``lsbm_loss`` for ``vs_fast``.

    Both parameter trees share one optimiser state keyed ``"ps"`` /
    ``"ps_fast"``. On every step divisible by ``cfg.md.itvl`` the snapshots
    ``vs_slow`` / ``vs_fast_capture`` are replaced by the updated models.

    Returns:
        ``(st, vs, vs_fast, vs_slow, vs_fast_capture, loss + loss_fast, key)``.
    """
    key, main_key, fast_key = jax.random.split(key, 3)

    # 1/eta moves linearly from 1/eta_start to 1/eta_end, once per interval.
    interval = (step - 1) // cfg.md.itvl
    n_intervals = cfg.max_step // cfg.md.itvl - 1
    eta = 1 / (1 / cfg.md.eta_start + (interval / n_intervals) * (1 / cfg.md.eta_end - 1 / cfg.md.eta_start))

    params = vs["params"]
    params_fast = vs_fast["params"]

    loss, grads = jax.value_and_grad(
        lambda p: md0_loss(eta, {**vs, 'params': p}, vs_fast_capture, vs_slow, main_key)
    )(params)
    loss_fast, grads_fast = jax.value_and_grad(
        lambda p: lsbm_loss({**vs, 'params': p}, fast_key)
    )(params_fast)

    joint_params = flax.core.FrozenDict({"ps": params, "ps_fast": params_fast})
    joint_grads = flax.core.FrozenDict({"ps": grads, "ps_fast": grads_fast})
    updates, st_new = update_fn(joint_grads, st, joint_params)
    joint_new = optax.apply_updates(joint_params, updates)

    vs_new = flax.core.FrozenDict({**vs, "params": joint_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": joint_new["ps_fast"]})

    refresh = step % cfg.md.itvl == 0
    vs_slow_new = jax.lax.cond(refresh, lambda: vs_new, lambda: vs_slow)
    vs_fast_capture_new = jax.lax.cond(refresh, lambda: vs_fast_new, lambda: vs_fast_capture)
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss + loss_fast, key

@functools.partial(jax.jit, static_argnums=7)
def train_md(step, st, vs, vs_fast, vs_slow, vs_fast_capture, key, update_fn):
    """One jitted step of "md": accumulated ``md_loss`` for ``vs``, ``lsb_loss`` for ``vs_fast``.

    The ``md_loss`` value and gradients are averaged over
    ``cfg.md.accumulation`` evaluations, each with its own key. Each term is
    divided by the count before summing. Both parameter trees share one
    optimiser state keyed ``"ps"`` / ``"ps_fast"``. On every step divisible by
    ``cfg.md.itvl`` the snapshots ``vs_slow`` / ``vs_fast_capture`` are
    replaced by the updated models.

    Returns:
        ``(st, vs, vs_fast, vs_slow, vs_fast_capture, loss + loss_fast, key)``.
    """
    key, main_key, fast_key = jax.random.split(key, 3)

    # 1/eta moves linearly from 1/eta_start to 1/eta_end, once per interval.
    interval = (step - 1) // cfg.md.itvl
    n_intervals = cfg.max_step // cfg.md.itvl - 1
    eta = 1 / (1 / cfg.md.eta_start + (interval / n_intervals) * (1 / cfg.md.eta_end - 1 / cfg.md.eta_start))

    def md_loss_of(p, k):
        return md_loss(eta, {**vs, 'params': p}, vs_fast_capture, vs_slow, k)

    params = vs["params"]
    params_fast = vs_fast["params"]

    n_acc = cfg.md.accumulation
    acc_keys = jax.random.split(main_key, n_acc)

    def accumulate(i, carry):
        loss_i, grads_i = jax.value_and_grad(md_loss_of)(params, acc_keys[i])
        return loss_i / n_acc + carry[0], jax.tree_util.tree_map(lambda g, s: g / n_acc + s, grads_i, carry[1])

    zero_grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    loss, grads = jax.lax.fori_loop(0, n_acc, accumulate, (0.0, zero_grads))
    loss_fast, grads_fast = jax.value_and_grad(
        lambda p: lsb_loss({**vs, 'params': p}, fast_key)
    )(params_fast)

    joint_params = flax.core.FrozenDict({"ps": params, "ps_fast": params_fast})
    joint_grads = flax.core.FrozenDict({"ps": grads, "ps_fast": grads_fast})
    updates, st_new = update_fn(joint_grads, st, joint_params)
    joint_new = optax.apply_updates(joint_params, updates)

    vs_new = flax.core.FrozenDict({**vs, "params": joint_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": joint_new["ps_fast"]})

    refresh = step % cfg.md.itvl == 0
    vs_slow_new = jax.lax.cond(refresh, lambda: vs_new, lambda: vs_slow)
    vs_fast_capture_new = jax.lax.cond(refresh, lambda: vs_fast_new, lambda: vs_fast_capture)
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss + loss_fast, key

@functools.partial(jax.jit, static_argnums=7)
def train_md_lsbm(step, st, vs, vs_fast, vs_slow, vs_fast_capture, key, update_fn):
    """One jitted step of "md_lsbm": accumulated ``md_loss`` for ``vs``, ``lsbm_loss`` for ``vs_fast``.

    The ``md_loss`` value and gradients are averaged over
    ``cfg.md.accumulation`` evaluations, each with its own key. Each term is
    divided by the count before summing. Both parameter trees share one
    optimiser state keyed ``"ps"`` / ``"ps_fast"``. On every step divisible by
    ``cfg.md.itvl`` the snapshots ``vs_slow`` / ``vs_fast_capture`` are
    replaced by the updated models.

    Returns:
        ``(st, vs, vs_fast, vs_slow, vs_fast_capture, loss + loss_fast, key)``.
    """
    key, main_key, fast_key = jax.random.split(key, 3)

    # 1/eta moves linearly from 1/eta_start to 1/eta_end, once per interval.
    interval = (step - 1) // cfg.md.itvl
    n_intervals = cfg.max_step // cfg.md.itvl - 1
    eta = 1 / (1 / cfg.md.eta_start + (interval / n_intervals) * (1 / cfg.md.eta_end - 1 / cfg.md.eta_start))

    def md_loss_of(p, k):
        return md_loss(eta, {**vs, 'params': p}, vs_fast_capture, vs_slow, k)

    params = vs["params"]
    params_fast = vs_fast["params"]

    n_acc = cfg.md.accumulation
    acc_keys = jax.random.split(main_key, n_acc)

    def accumulate(i, carry):
        loss_i, grads_i = jax.value_and_grad(md_loss_of)(params, acc_keys[i])
        return loss_i / n_acc + carry[0], jax.tree_util.tree_map(lambda g, s: g / n_acc + s, grads_i, carry[1])

    zero_grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    loss, grads = jax.lax.fori_loop(0, n_acc, accumulate, (0.0, zero_grads))
    loss_fast, grads_fast = jax.value_and_grad(
        lambda p: lsbm_loss({**vs, 'params': p}, fast_key)
    )(params_fast)

    joint_params = flax.core.FrozenDict({"ps": params, "ps_fast": params_fast})
    joint_grads = flax.core.FrozenDict({"ps": grads, "ps_fast": grads_fast})
    updates, st_new = update_fn(joint_grads, st, joint_params)
    joint_new = optax.apply_updates(joint_params, updates)

    vs_new = flax.core.FrozenDict({**vs, "params": joint_new["ps"]})
    vs_fast_new = flax.core.FrozenDict({**vs_fast, "params": joint_new["ps_fast"]})

    refresh = step % cfg.md.itvl == 0
    vs_slow_new = jax.lax.cond(refresh, lambda: vs_new, lambda: vs_slow)
    vs_fast_capture_new = jax.lax.cond(refresh, lambda: vs_fast_new, lambda: vs_fast_capture)
    return st_new, vs_new, vs_fast_new, vs_slow_new, vs_fast_capture_new, loss + loss_fast, key

progress_bar = tqdm(range(1, cfg.max_step + 1))

# Step functions by algorithm name. Single-model algorithms carry
# (st, vs, key). Mirror-descent ("md*") algorithms also carry the fast model
# and the two snapshots.
single_model_steps = {
    'lsb': train_lsb,
    'lsbm': train_lsbm,
}
md_steps = {
    'md0': train_md0,
    'md0_lsbm': train_md0_lsbm,
    'md': train_md,  # the gradient-accumulating variant, not train_md0
    'md_lsbm': train_md_lsbm,
}

if cfg.alg_name in single_model_steps:

    train_step = single_model_steps[cfg.alg_name]
    for step in progress_bar:
        st, vs, loss, key = train_step(st, vs, key, opt.update)
        progress_bar.set_description(f'loss [{loss: >13.5f}]')

elif cfg.alg_name in md_steps:

    train_step = md_steps[cfg.alg_name]
    for step in progress_bar:
        st, vs, vs_fast, vs_slow, vs_fast_capture, loss, key = \
            train_step(step, st, vs, vs_fast, vs_slow, vs_fast_capture, key, opt.update)
        progress_bar.set_description(f'loss [{loss: >13.5f}]')

else: raise ValueError("Invalid algorithm name")


# Result

alae_model = load_model(f'{alae_dir}/configs/ffhq.yaml', training_artifacts_dir=f'{alae_dir}/training_artifacts/ffhq/')

# Choose 10 distinct test inputs among those whose test-split index is < 300.
# x_test_inds is sorted ascending, so positions [0, (x_test_inds < 300).sum())
# are exactly those rows. NOTE(review): presumably test_images.npy only covers
# the first 300 test rows — confirm.
inds_to_map = np.random.choice((x_test_inds < 300).sum(), size=10, replace=False)
n_sample = 3
n_pic = cfg.n_pic  # number of figure rows (e.g. 2)
if n_pic > len(inds_to_map):
    raise ValueError(f"cfg.n_pic={n_pic} exceeds the {len(inds_to_map)} sampled test inputs")

data_samples = jax.device_put(data_test[x_test_inds[inds_to_map]])
img_samples = test_imgs[x_test_inds[inds_to_map]]

# Draw n_sample independent mappings of every selected input, then stack
# them so that mapped[:, k] is the k-th mapping of all inputs.
mapped_all = []
for _ in range(n_sample):
    key, subkey = jax.random.split(key)
    mapped = model(vs, data_samples, subkey)
    mapped_all.append(torch.from_numpy(np.asarray(jax.device_put(mapped, jax.devices("cpu")[0])).copy()))

mapped = torch.stack(mapped_all, dim=1)

decoded_all = []
with torch.no_grad():
    for k in range(n_sample):
        decoded = decode(alae_model, mapped[:, k])
        # Rescale (presumably [-1, 1]) decoder output to uint8 and move
        # channels last (NCHW -> NHWC) for imshow.
        decoded = ((decoded * 0.5 + 0.5) * 255).type(torch.long).clamp(0, 255).cpu().type(torch.uint8).permute(0, 2, 3, 1).numpy()
        decoded_all.append(decoded)

decoded_all = np.stack(decoded_all, axis=1)

# squeeze=False keeps `axes` 2-D, so axes[i][j] also works when n_pic == 1.
fig, axes = plt.subplots(n_pic, n_sample+1, figsize=(n_sample+1, n_pic), dpi=200, squeeze=False)

for i in range(n_pic):
    row = axes[i]
    row[0].imshow(img_samples[i])  # column 0: the original test image
    for k in range(n_sample):
        row[k+1].imshow(decoded_all[i, k])
    for ax in row:
        ax.get_xaxis().set_visible(False)
        ax.set_yticks([])

fig.tight_layout(pad=0.1)
test_mmd = mmd(np.asarray(model(vs, x_test, key)), y_test)
fig.savefig(f'{fig_path}/result_eps-{cfg.eps}_{cfg.input_group}-to-{cfg.target_group}_mmd-{test_mmd}.png')




