import gwot
from gwot import sim, ts, util
import autograd
import autograd.numpy as np
import importlib
import math
import torch
import sys
import sklearn as sk

# setup potential function
def Psi(x, t, dim):
    """Double-well potential evaluated over the last axis of ``x``.

    V(x) = 0.9 * |x - x0|^2 * |x - x1|^2 + 10 * |x[..., 2:]|^2
    with x0 = (1, 1, 0, ..., 0) and x1 = -x0. V is non-negative and vanishes
    exactly at x0 and x1 (the two wells); coordinates beyond the first two are
    pulled towards 0 by the quadratic confinement term.

    Parameters
    ----------
    x : array, shape (..., dim)
        Evaluation points. The last axis holds coordinates; any leading axes
        (including none, for a single point) are treated as batch axes.
    t : any
        Unused; accepted so the signature follows the ``(x, t)`` convention
        of the other model functions in this file.
    dim : int
        Ambient dimension; should equal ``x.shape[-1]``.

    Returns
    -------
    array, shape x.shape[:-1]
        Potential value at each point.
    """
    x0 = np.array([1, 1] + [0, ]*(dim - 2))
    x1 = -x0
    # Product of squared distances to the two wells: zero at either well.
    double_well = 0.9*np.sum((x - x0)*(x - x0), axis = -1) * np.sum((x - x1)*(x - x1), axis = -1)
    # `...` (not `:`) so single points and arbitrary batch shapes both work.
    confinement = 10*np.sum(x[..., 2:]**2, axis = -1)
    return double_well + confinement

def beta(x, t, betamax):
    """Birth (branching) rate: a smooth tanh step in ``x[0]``.

    Returns ``betamax * (tanh(2 * x[0]) + 1) / 2``, which rises from 0 towards
    ``betamax`` as ``x[0]`` increases and equals ``betamax / 2`` at ``x[0] == 0``.
    ``x[0]`` indexes the leading axis, so callers evaluating many particles pass
    a coordinates-first array (e.g. ``sim.x.T`` in this script). ``t`` is unused.
    """
    step = (np.tanh(2*x[0]) + 1)/2
    return betamax*step
def delta(x, t):
    """Death rate: the constant 0, independent of position ``x`` and time ``t``."""
    return 0

def _make_fate_fn(simulation, t_remaining):
    """Build the fate map handed to ``evals.get_centroid_probs``.

    The returned callable converts particle positions ``x`` to a float32 tensor
    and integrates ``simulation``'s SDE (drift ``simulation.dV``, diffusivity
    ``simulation.D``) forward for ``t_remaining`` time units with branching
    disabled. It returns the state at snapshot index 99 as a float32 tensor.
    NOTE(review): ``100`` is presumably the step count, making snapshot 99 the
    final step — confirm against ``gwot.sim.sde_integrate``.

    ``t_remaining`` is bound here rather than read from an enclosing loop
    variable at call time, so each snapshot's horizon is explicit.
    """
    def _get_fate(x):
        x0 = torch.tensor(x, dtype = torch.float32)
        y = gwot.sim.sde_integrate(simulation.dV, simulation.D, x0, t_remaining, 100,
                                   snaps = np.array([99, ]), birth_death = False)[0][0]
        return torch.tensor(y, dtype = torch.float32)
    return _get_fate


if __name__ == '__main__':
    # Usage: python <script> DIM N T BETAMAX
    if len(sys.argv) < 5:
        sys.exit(f"usage: {sys.argv[0]} DIM N T BETAMAX")
    SRAND = 42
    np.random.seed(SRAND)
    dim = int(sys.argv[1])        # ambient dimension
    N = int(sys.argv[2])          # presumably particles per snapshot (see np.repeat below)
    T = int(sys.argv[3])          # number of snapshots
    betamax = float(sys.argv[4])  # maximum branching rate; 0 disables birth/death
    D = (0.5)**2 # diffusivity
    # setup simulation parameters
    sim_steps = 1000 # number of steps to use for Euler-Maruyama method
    t_final = 1 # simulation run on [0, t_final]
    # gradient of the potential, reused below for the ground-truth velocity field
    dPsi = autograd.elementwise_grad(lambda x, t: Psi(x, t, dim))
    # function for particle initialisation
    ic_func = lambda N, d: np.random.randn(N, d)*0.1 # 0.25
    # setup simulation object (named `simulation` so it does not shadow the
    # `sim` module imported from gwot at the top of the file)
    simulation = gwot.sim.Simulation(V = Psi, dV = dPsi, birth_death = betamax > 0,
                                     birth = lambda x, t: beta(x, t, betamax),
                                     death = delta,
                                     N = np.repeat(N, T),
                                     T = T,
                                     d = dim,
                                     D = D,
                                     t_final = t_final,
                                     ic_func = ic_func,
                                     pool = None)
    # sample snapshots, then ground-truth trajectories (same step scaling for both)
    steps_scale = int(sim_steps/simulation.T)
    simulation.sample(steps_scale = steps_scale)
    paths_gt = simulation.sample_trajectory(steps_scale = steps_scale, N = 1000)

    import matplotlib.pyplot as plt
    t_grid = np.linspace(0, simulation.t_final, simulation.T)
    plt.figure(figsize = (3, 3))
    # small horizontal jitter so particles at the same snapshot do not overlap
    jitter = np.random.randn(len(simulation.t_idx))*0.01
    plt.scatter(t_grid[simulation.t_idx] + jitter, simulation.x[:, 0], alpha = 0.2, s = 1, color = "red")
    plt.xlabel("t"); plt.ylabel("x")
    plt.title("Sampled particles")
    plt.tight_layout()
    plt.savefig("twowell_data.pdf")
    plt.close()

    # Save ground truth potential, velocities (-grad Psi) and branching rates
    pot_gt = Psi(simulation.x, None, dim = dim)
    vf_gt = -dPsi(simulation.x, None)
    g_gt = beta(simulation.x.T, None, betamax)

    data = {'x' : simulation.x,
            't_idx' : simulation.t_idx,
            't_final' : simulation.t_final,
            'x_paths' : paths_gt,
            'potential_true' : pot_gt,
            'v_true' : vf_gt,
            'g_true' : g_gt
            }

    # Compute centroids of the final snapshot.
    # Explicit submodule import: `import sklearn` alone does not guarantee
    # that `sklearn.cluster` is loaded.
    print("Computing centroids")
    from sklearn.cluster import KMeans
    clust = KMeans(n_clusters = 2)
    clust.fit(data['x'][data['t_idx'] == T-1])
    # Compute fate probs (evals is resolved relative to the current working directory)
    sys.path.append("../../src/")
    import evals
    from tqdm import tqdm
    _centroids = torch.tensor(clust.cluster_centers_, dtype = torch.float32)
    print("Computing fate probabilities")
    probs = []
    for i in tqdm(range(T)):
        # NOTE(review): the remaining horizon uses i/T, whereas the snapshot
        # grid plotted above is np.linspace(0, t_final, T), i.e. i/(T-1);
        # confirm which matches gwot's sampling times.
        fate_fn = _make_fate_fn(simulation, simulation.t_final*(1-i/simulation.T))
        x_i = torch.tensor(data['x'][data['t_idx'] == i, :])
        probs.append(evals.get_centroid_probs(x_i, fate_fn, _centroids, n_sample = 25))
    probs = torch.vstack(probs)
    data['probs'] = probs
    data['centroids'] = _centroids

    out_path = f"sim_twowell_N_{N}_T_{T}_dim_{dim}_D_{D}_beta_{betamax}.pkl"
    print(f"Writing to : {out_path}")
    torch.save(data, out_path)



