#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# ----------------------------------------
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense

# Experiment configuration ------------------------------
N_target = 200  # number of target (training) samples drawn from the mixture
N_init = 500  # number of particles that annealed Langevin dynamics evolves
N_dim = 2  # data dimensionality (the plot only shows the first two columns)
# Noise levels: 5 values geometrically spaced from 3 down to 1 (largest first).
# NOTE: these exact float values are formatted into the saved-model file names,
# so changing how they are computed (even rounding them) invalidates cached models.
sigmas = np.exp(np.linspace(np.log(3), 0, 5))
whole_region = True # False: draw the initial samples from a Gaussian (see sample_initial) instead of uniformly over the box
# Alternative configuration (Student-t mixture); swap it with the active line below to use it.
#dataset, dist, epochs = "student_t", 10.0, [int(2000/sigma) for sigma in sigmas]
# Active configuration; epochs[i] = int(1000/sigmas[i]), i.e. smaller noise levels train longer.
dataset, dist, epochs = "gaussian", 8.0, [int(1000/sigma) for sigma in sigmas]

# functions ---------------------------------------------
# sample target data
def sample_Mix_student_t(N_target, N_dim, dist):
    """Draw the target point cloud via ``generate_four_student_t``.

    Thin wrapper around ``util.generate_data.generate_four_student_t``. It is
    called with ``size=(N_target, N_dim)``, the given ``dist``, ``nu=0.5`` and
    a fixed ``random_seed=0``, so repeated calls yield the same samples
    (assuming the helper honours its seed).

    Judging by the helper's name, this is a mixture of four Student-t
    components whose spacing is set by ``dist``. TODO confirm in
    util.generate_data.
    """
    from util.generate_data import generate_four_student_t as _draw

    shape = (N_target, N_dim)
    return _draw(size=shape, dist=dist, nu=0.5, random_seed=0)
    
def sample_Mix_gaussian(N_target, N_dim, dist):
    """Draw the target point cloud via ``generate_four_gaussians``.

    Thin wrapper around ``util.generate_data.generate_four_gaussians``. It is
    called with ``size=(N_target, N_dim)``, the given ``dist``, ``std=1.0``
    and a fixed ``random_seed=0``.

    Judging by the helper's name, this is a mixture of four unit-std
    Gaussians whose spacing is set by ``dist``. TODO confirm in
    util.generate_data.
    """
    from util.generate_data import generate_four_gaussians as _draw

    shape = (N_target, N_dim)
    return _draw(size=shape, dist=dist, std=1.0, random_seed=0)

    
# sample initial data
def sample_initial(N_init, N_dim, dist, whole_region=True):
    """Draw the starting particles for Langevin sampling.

    Args:
        N_init: number of particles.
        N_dim: dimensionality of each particle.
        dist: half-width of the sampling box in the uniform case. In the
            Gaussian case it is forwarded as ``m=``.
        whole_region: if True (default), sample uniformly on
            [-dist, dist)^N_dim using the global ``np.random`` state.
            Otherwise delegate to ``util.generate_data.generate_gaussian``
            with ``std=1.0`` and a fixed ``random_seed=0``.

    Returns:
        In the uniform case, an array of shape (N_init, N_dim). Otherwise,
        whatever ``generate_gaussian`` returns.
    """
    if not whole_region:
        from util.generate_data import generate_gaussian
        return generate_gaussian(size=(N_init, N_dim), m=dist, std=1.0, random_seed=0)

    unit = np.random.uniform(size=(N_init, N_dim))  # samples in [0, 1)
    return dist * (2 * unit - 1)  # affine map onto [-dist, dist) for positive dist
    
# scorenet
def NN(N_dim):
    """Build and compile the MLP used as a score network.

    Architecture: N_dim -> 16 -> 32 -> 16 -> N_dim. Hidden layers use ReLU
    and the output layer is linear, so the network maps a point to a vector
    of the same dimension (the estimated score).

    The model is compiled with the Adam optimizer and a mean-squared-error
    loss; MSE is also tracked as a metric.
    """
    hidden_widths = (16, 32, 16)
    stack = [Dense(hidden_widths[0], input_dim=N_dim, activation='relu')]
    stack += [Dense(width, activation='relu') for width in hidden_widths[1:]]
    stack.append(Dense(N_dim, activation='linear'))

    net = Sequential(stack)
    net.compile(
        optimizer='adam',
        loss=keras.losses.MeanSquaredError(),
        metrics=[keras.metrics.MeanSquaredError()],
    )
    return net
    
# train scorenets
def train_scorenets(target, sigmas, epochs, build_model=None):
    """Train one score network per noise level with denoising score matching.

    For each sigma, the clean points are perturbed with isotropic Gaussian
    noise, x~ = x + sigma * z with z ~ N(0, I). The network is then regressed
    onto (x - x~) / sigma**2, the score of that Gaussian perturbation kernel.

    Args:
        target: 2-D array of clean samples (rows are points, columns are
            dimensions).
        sigmas: noise levels; one network is trained per entry.
        epochs: epochs[i] is the number of training epochs for sigmas[i].
        build_model: optional factory that takes the data dimension and
            returns an untrained model with a Keras-style ``fit``. Defaults
            to ``NN``.

    Returns:
        List of trained models, aligned with ``sigmas``.
    """
    if build_model is None:
        build_model = NN
    n_dim = target.shape[1]
    scorenets = [build_model(n_dim) for _ in sigmas]

    for i, sigma in enumerate(sigmas):
        # Independent noise for every coordinate (N(0, sigma^2 I)). Broadcasting
        # one draw per point across all dimensions would only perturb along the
        # diagonal, so the nets would never learn the score in other directions.
        # The Langevin sampler injects isotropic noise, so it needs those
        # directions.
        perturbed = target + sigma * np.random.randn(*target.shape)
        scorenets[i].fit(perturbed, (target - perturbed) / sigma**2, epochs=epochs[i], batch_size=64)

    return scorenets
    
# annealed_langevin dynamics
def annealed_langevin(scorenets, init, sigmas, lr=0.1, n_steps_each=100, rng=None):
    """Run annealed Langevin dynamics through the noise levels in ``sigmas``.

    Levels are visited in the order given. For level i the step size is
    ``step = lr * (sigmas[i] / sigmas[-1]) ** 2``, and ``n_steps_each``
    updates are applied:

        x <- x + step / 2 * score_i(x) + sqrt(step) * z,   z ~ N(0, I)

    where ``score_i(x)`` is ``scorenets[i].predict(x)``.

    Args:
        scorenets: sequence of models, one per entry of ``sigmas``. Each must
            provide ``predict(x, verbose=...)`` returning an array shaped
            like ``x``.
        init: starting samples (array); it is not modified in place.
        sigmas: noise levels; the last entry is the reference for the step
            size scaling.
        lr: base step size (used as-is at the last noise level).
        n_steps_each: number of Langevin updates per noise level.
        rng: optional ``numpy.random.Generator`` for the injected noise. If
            None, the global ``np.random`` state is used.

    Returns:
        Array of final samples with the same shape as ``init``.
    """
    x = init
    for i, sigma in enumerate(sigmas):
        net = scorenets[i]
        # Constant within a noise level, so compute it once per level.
        step = lr * (sigma / sigmas[-1]) ** 2
        noise_scale = np.sqrt(step)
        for _ in range(n_steps_each):
            # verbose=0: otherwise Keras prints a progress bar on every single step.
            x = x + step / 2 * net.predict(x, verbose=0)
            z = np.random.randn(*x.shape) if rng is None else rng.standard_normal(x.shape)
            x = x + z * noise_scale
    return x
    

# Target samples for the selected dataset.
if dataset == "student_t":
    target = sample_Mix_student_t(N_target, N_dim, dist)
elif dataset == "gaussian":
    target = sample_Mix_gaussian(N_target, N_dim, dist)
else:
    raise ValueError("unknown dataset {!r}; expected 'student_t' or 'gaussian'".format(dataset))

# Starting particles: pass the configured whole_region flag through rather than a literal.
init = sample_initial(N_init, N_dim, dist, whole_region=whole_region)

# One score network per noise level, cached on disk as <model_path>_<sigma>.
model_path = "util/saved_model/ncsn/{}".format(dataset)
try:
    scorenets = [keras.models.load_model(model_path + "_{}".format(sigma), compile=False) for sigma in sigmas]
except Exception as err:
    # Best-effort cache: any load failure falls back to training. Catching
    # Exception (not a bare except) lets Ctrl-C / SystemExit still stop the script.
    print("Could not load cached score networks ({}); training from scratch.".format(err))
    scorenets = train_scorenets(target, sigmas, epochs)
    for sigma, net in zip(sigmas, scorenets):
        net.save(model_path + "_{}".format(sigma))

output = annealed_langevin(scorenets, init, sigmas, lr=0.1, n_steps_each=100)

# Plot outputs --------------------------------------------
import matplotlib.pyplot as plt

# Overlay the reference samples and the Langevin output in one scatter plot.
for _points, _label in ((target, "target"), (output, "output")):
    plt.scatter(_points[:, 0], _points[:, 1], label=_label, s=5, alpha=0.7)

# Dataset-specific square viewing window; any other dataset keeps matplotlib's autoscaling.
_view = {"student_t": [-3*dist, 3*dist], "gaussian": [-3, dist+3]}.get(dataset)
if _view is not None:
    plt.xlim(_view)
    plt.ylim(_view)

plt.legend()
plt.tight_layout()
plt.show()
