from time import time
import numpy as np
import torch as tc
from src.utils import Mesh, LinearSolveArgs, SamplingArgs, eval_dxdt, MeshType, flatten, unflatten_TC, grad, flatten_TC
from src.model import GNN
from src.train import sample_and_linear_solve
from src.data import MassSpring

# ---------------------------------------------------------------------------
# Problem definition
# ---------------------------------------------------------------------------
n_obj = 5                     # number of masses in the spring system
dof = 2                       # coordinates per mass (qx, qy below)
n_features = 2 * dof * n_obj  # length of one flattened [q, p] state vector

# ---------------------------------------------------------------------------
# Model / solver settings
# ---------------------------------------------------------------------------
model_config = dict(
    width=512,
    enc_width=32,
    activ_str="softplus",
    local_pooling="sum",
    direct=False,
    global_pooling="sum",
)
# NOTE(review): the "mode" and "device" entries are not read by the visible
# code; `auto_diff_mode` and `device` below are used instead.
train_config = dict(mode="forward", driver="gels", rcond=1e-15, device="cpu")
device = "cpu"
auto_diff_mode = "forward"
dtype_np, dtype_tc = np.float64, tc.float64

# ---------------------------------------------------------------------------
# Data / integration settings
# ---------------------------------------------------------------------------
train_split = 1.0  # not read by the visible code: load_data uses every sample
integrator = "stormer_verlet"  # alternative: "runge_kutta"
dt_obs = 0.1    # spacing of observed samples: dt_true (1e-3) x 100 = 0.1 in the NRI data
dt_true = 1e-03  # simulator step of the NRI data (not used further in the visible code)
len_traj = 100  # steps per rollout, passed to MassSpring / system.integrate;
                # presumably spans len_traj * dt_obs = 10 time units -- TODO confirm

train_data_path = f"./data/{n_obj}-spring/dataset_train.npy"
test_data_path = f"./data/{n_obj}-spring/dataset_test.npy"

# ---------------------------------------------------------------------------
# Seeds
# ---------------------------------------------------------------------------
model_seed = 5165
data_seed = 15694  # intended for shuffling; `rng` is not used in the visible code
rng = np.random.default_rng(data_seed)

def load_data():
    """
    Load the whole training file and assemble the model inputs.

    No train/evaluation split is performed: every sample stored in
    ``train_data_path`` is returned as training data.

    The raw array is split in half along axis 3 into ``q`` and ``p``. Along
    axis 1, index 0 is the state and index 1 its time derivative. Both are
    reshaped to ``(n_points, n_obj, dof)``, concatenated per object as
    ``[q, p]``, and then flattened with ``flatten``.

    Returns:
        x_train:    torch tensor of the flattened states.
        L:          torch tensor of shape (n_features, n_features) holding the
                    block matrix [[0, I], [-I, 0]] (canonical symplectic form)
                    over the flattened ``[q, p]`` layout.
        edge_index: int64 torch tensor of shape (2, n_edges) describing a ring
                    graph over the ``n_obj`` objects.
        dxdt_train: torch tensor of the flattened time derivatives.
    """
    data = np.load(train_data_path).astype(dtype_np)
    print(f"-> Data of shape {data.shape} is read")
    q, p = np.split(data, 2, axis=3)
    q_dot, q = q[:, 1, ...], q[:, 0, ...]
    p_dot, p = p[:, 1, ...], p[:, 0, ...]

    # Reshape into n_points = num_trajs * len_traj. Use `dof` (not a literal 2)
    # so this stays consistent with test_trajectory.
    q, q_dot = q.reshape(-1, n_obj, dof), q_dot.reshape(-1, n_obj, dof)
    p, p_dot = p.reshape(-1, n_obj, dof), p_dot.reshape(-1, n_obj, dof)
    x = np.concatenate([q, p], axis=-1)
    dxdt = np.concatenate([q_dot, p_dot], axis=-1)

    print("-> Read data")
    print(f"x of shape {x.shape}")
    print(f"dxdt of shape {dxdt.shape}")

    x, dxdt = flatten(x), flatten(dxdt)
    x_train, dxdt_train = tc.from_numpy(x), tc.from_numpy(dxdt)

    # L[k, half + k] = 1 and L[half + k, k] = -1 for every k < half
    # (vectorized form of the former nested loop over objects and dof).
    half = dof * n_obj
    idx = np.arange(half)
    L = np.zeros((n_features, n_features), x.dtype)
    L[idx, half + idx] = 1.
    L[half + idx, idx] = -1.

    # Ring graph over the objects. For n_obj = 5 this is exactly the previously
    # hard-coded list [[4, 0], [4, 3], [3, 2], [2, 1], [1, 0]]. The closing
    # edge is omitted for n_obj <= 2, where it would duplicate an edge or be a
    # self loop.
    last = n_obj - 1
    edges = [[i, i - 1] for i in range(last, 0, -1)]
    if n_obj > 2:
        edges.insert(0, [last, 0])
    # reshape(-1, 2) keeps the result 2 x n_edges even when `edges` is empty.
    edge_index = tc.tensor(edges, dtype=tc.long).reshape(-1, 2).T

    return x_train, tc.from_numpy(L), edge_index, dxdt_train

def create_gnn(n_obj, edge_index, model_seed):
    """
    Construct a fresh GNN configured from the module-level ``model_config``.

    The message width is the total ``width`` minus the encoder width.
    ``n_obj`` is forwarded unchanged; the callers in this script pass a
    one-element list.
    """
    cfg = model_config
    gnn_kwargs = {
        "dof": dof,
        "n_obj": n_obj,
        "edge_index": edge_index,
        "direct": cfg["direct"],
        "msg_width": cfg["width"] - cfg["enc_width"],
        "enc_width": cfg["enc_width"],
        "local_pooling": cfg["local_pooling"],
        "global_pooling": cfg["global_pooling"],
        "activ_str": cfg["activ_str"],
        "seed": model_seed,
        "dtype": dtype_tc,
    }
    return GNN(**gnn_kwargs)

def sample_fit(model, x_train, L, dxdt_train, param_sampler, sample_uniformly, sampling_seed):
    """
    Fit ``model`` in place via ``sample_and_linear_solve`` and score it.

    The sampling arguments (sampler, seed, uniform sampling, duplicate
    resampling) and the linear-solve arguments (autodiff mode, LAPACK driver,
    rcond, device) are built from the call arguments and the module-level
    configuration.

    Returns:
        (dxdt_mse, dxdt_rel2): the error of ``eval_dxdt`` on the training data.
    """
    sampler_cfg = SamplingArgs(
        param_sampler=param_sampler,
        seed=sampling_seed,
        sample_uniformly=sample_uniformly,
        resample_duplicates=True,
        dtype=dtype_np,
    )
    solver_cfg = LinearSolveArgs(
        mode=auto_diff_mode,
        driver=train_config["driver"],
        rcond=train_config["rcond"],
        device=device,
    )
    sample_and_linear_solve(model, x_train, L, dxdt_train, sampler_cfg, solver_cfg)

    mse, rel2 = eval_dxdt(model, x_train, L, dxdt_train, mode=auto_diff_mode, verbose=False)
    return mse, rel2


# ---------------------------------------------------------------------------
# Load training data, run one untimed warm-up fit, then the timed fit.
# ---------------------------------------------------------------------------
x_train, L, edge_index, dxdt_train = load_data()
print("Warming up...")
# Warm-up fit with a different sampler/seed; its result is discarded so the
# timed fit below presumably excludes one-off first-call overhead.
model = create_gnn([n_obj], edge_index, model_seed)
_, _ = sample_fit(model, x_train, L, dxdt_train, param_sampler="random", sample_uniformly=True, sampling_seed=198456)
del model
# Releases cached CUDA memory; only relevant on a GPU (device is "cpu" here).
tc.cuda.empty_cache()
print("Warm-up complete.")
#x_train, x_test, L, edge_index, dxdt_train, dxdt_test = load_data()
# Sanity check that inputs and model parameters share the configured dtype.
print(x_train.dtype)
print(L.dtype)
print(edge_index.dtype)
model = create_gnn([n_obj], edge_index, model_seed)  # fresh GNN for the timed fit
print(model.linear.weight.dtype)
print(model.node_encoder.weight.dtype)
time0 = time()
train_mse, train_rel2 = sample_fit(model, x_train, L, dxdt_train, "relu", sample_uniformly=True, sampling_seed=51234)
time1 = time()
print(f"-> train mse :  {train_mse:.2e}")
print(f"-> train rel2:  {train_rel2:.2e}")
train_time = time1 - time0  # wall-clock seconds
print(f"-> Training took {train_time:.2f}")

# Held-out split is disabled (load_data returns only training data); the
# held-out evaluation happens further below on test_data_path instead.
#test_mse, test_rel2 = eval_dxdt(model, x_test, L, dxdt_test, mode=auto_diff_mode, verbose=False)
#print(f"-> test mse :   {test_mse:.2e}")
#print(f"-> test rel2:   {test_rel2:.2e}")

def rollout_error(true, pred, n_steps=100):
    """
    Symmetric relative error between a reference and a predicted rollout.

    Both arrays are indexed by time step along axis 0 (each row is the state
    at one step, e.g. shape ``(len_traj, n_values)``). For arrays ``a`` and
    ``b`` the error is ``||b - a|| / (||b|| + ||a||)`` with the Euclidean
    (Frobenius) norm, so it lies in ``[0, 1]``.

    Args:
        true: reference trajectory.
        pred: predicted trajectory, same shape as ``true``.
        n_steps: per-step errors are computed for rows ``1 .. n_steps - 1``
            (row 0 is skipped). Clipped to ``len(true)``. Defaults to 100.

    Returns:
        ``(err, all_err)``: the error over the whole trajectory, and the list
        of per-step errors (``all_err[k]`` belongs to row ``k + 1``).

    Prints the per-step error at fixed checkpoints.
    """
    def get_err(x_true, x_pred, step_idx):
        row_true, row_pred = x_true[step_idx], x_pred[step_idx]
        # The denominator uses norms (sqrt of the sum of SQUARES), matching the
        # whole-trajectory formula below. A plain sqrt(sum(x)) is not a norm
        # and yields NaN whenever the row sums to a negative value.
        return np.linalg.norm(row_pred - row_true) / (
            np.linalg.norm(row_pred) + np.linalg.norm(row_true)
        )

    last = min(n_steps, len(true))
    all_err = [get_err(true, pred, idx) for idx in range(1, last)]

    # Checkpoints as (label, row index). "Step 1" reports row 1, while
    # "Step k" for k >= 10 reports row k - 1, as in the original report.
    checkpoints = [(1, 1)] + [(label, label - 1) for label in range(10, 101, 10)]
    for label, idx in checkpoints:
        if 1 <= idx < last:
            print(f"-> Step: {label:<11}{all_err[idx - 1]}")

    return np.linalg.norm(pred - true) / (np.linalg.norm(pred) + np.linalg.norm(true)), all_err

def integrate(system, integration_grad_H):
    """
    Roll ``system`` forward using the Hamiltonian-gradient callable.

    Runs ``len_traj`` steps of size ``dt_obs`` with the module-level
    ``integrator``. It then prints the integration time, the min/max of the
    first ``dof`` and remaining channels of the state array, and the range of
    ``system.H()``.

    Returns:
        (qx, qy, hamiltonian): channels 0 and 1 of the last axis of
        ``system.to_array()``, and ``system.H()``.
    """
    start = time()
    system.integrate(len_traj, dt_obs, integration_grad_H, integration_method=integrator)
    elapsed = time() - start
    print(f"Integration took: {elapsed:.2f} seconds")

    states = system.to_array()
    q_part = states[..., :dof]
    p_part = states[..., dof:]
    print(f"q limits: [{np.min(q_part)}, {np.max(q_part)}]")
    print(f"p limits: [{np.min(p_part)}, {np.max(p_part)}]")
    qx = states[..., 0]
    qy = states[..., 1]

    energy = system.H()
    print(f"hamiltonian limits: [{np.min(energy):.2f}, {np.max(energy):.2f}]")
    return qx, qy, energy

# Evaluate test trajectory
def test_trajectory(model, L, data):
    """
    Score the fitted model on ``data`` and set up a rollout system.

    ``data`` is decoded like the training file. It is split in half along
    axis 3 into q / p; along axis 1, index 0 is the state and index 1 its time
    derivative. The function prints the derivative MSE and relative error
    from ``eval_dxdt``, then builds a ``MassSpring`` system initialised with
    the observed ``q`` and ``p``.

    Returns:
        (system, qx_true, qy_true): the system, plus channels 0 and 1 of the
        unflattened states, each of shape (n_points, n_obj).
    """
    q_all, p_all = np.split(data, 2, axis=3)
    per_object = (-1, n_obj, dof)
    q, q_dot = q_all[:, 0, ...].reshape(per_object), q_all[:, 1, ...].reshape(per_object)
    p, p_dot = p_all[:, 0, ...].reshape(per_object), p_all[:, 1, ...].reshape(per_object)

    states = np.concatenate([q, p], axis=-1)
    derivs = np.concatenate([q_dot, p_dot], axis=-1)
    qx_true = states[..., 0]
    qy_true = states[..., 1]

    flat_states, flat_derivs = flatten(states), flatten(derivs)
    mse, rel2 = eval_dxdt(
        model,
        tc.from_numpy(flat_states),
        L,
        tc.from_numpy(flat_derivs),
        mode=auto_diff_mode,
        verbose=False,
    )
    print(f"-> test mse :   {mse:.2e}")
    print(f"-> test rel2:   {rel2:.2e}")

    half_features = n_features // 2
    system = MassSpring(
        len_traj,
        n_features,
        q.reshape(-1, half_features),
        p.reshape(-1, half_features),
        n_obj=[n_obj],
        dof=dof,
        meshing=Mesh(MeshType("rectangular")),
    )
    return system, qx_true, qy_true

# ---------------------------------------------------------------------------
# Held-out evaluation: derivative error on the test file, then a rollout from
# the observed initial states, compared against the observed trajectory.
# ---------------------------------------------------------------------------
data = np.load(test_data_path).astype(dtype_np)
print(f"-> Data of shape {data.shape} is read")
system, qx_true, qy_true = test_trajectory(model, L, data)
print("-> got qx_true of shape", qx_true.shape)
# Per row: all qx values followed by all qy values.
q_true = np.concatenate([qx_true, qy_true], axis=-1)
print("-> q_true of shape", q_true.shape)
# numpy state -> torch -> model gradient -> numpy, in the per-object layout
# produced by unflatten_TC (exact shape defined in src.utils -- TODO confirm).
model_grad_H = lambda x: unflatten_TC(grad(model, flatten_TC(tc.from_numpy(x))), model.n_obj, dof).detach().numpy()
print("-> Integrating")
qx_pred, qy_pred, hamil_pred = integrate(system, model_grad_H)
print("-> got qx_pred of shape", qx_pred.shape)
q_pred = np.concatenate([qx_pred, qy_pred], axis=-1)
print("-> got q_pred of shape", q_pred.shape)
err, all_err_test = rollout_error(q_true, q_pred)
np.save(f"./benchmark_test_traj_spring{n_obj}_rollout_error.npy", all_err_test)
print(f"-> Rollout error:   {err:.2e}")
# Squared position error summed over x/y and averaged over axis 1 (objects,
# given qx_true is (n_points, n_obj)); assumes qx_pred has the same shape.
mse = np.mean((qx_true - qx_pred)**2 + (qy_true - qy_pred)**2, axis=1)
print(f"=trajectory mean of all (q) MSE: {mse.mean():.2e}")

# Stops the script here; the train-trajectory check that follows never runs.
exit(0)

# NOTE(review): unreachable -- the exit(0) just above ends the script. Remove
# that call to run this rollout check on the first training trajectory.
print("========= Testing in one train traj ==============")
data = np.load(train_data_path).astype(dtype_np)
# Keep only the first trajectory, preserving the leading (trajectory) axis.
data = data[0][np.newaxis, ...]
print(f"-> Data of shape {data.shape} is read")
system, qx_true, qy_true = test_trajectory(model, L, data)
print("-> got qx_true of shape", qx_true.shape)
q_true = np.concatenate([qx_true, qy_true], axis=-1)
print("-> q_true of shape", q_true.shape)
model_grad_H = lambda x: unflatten_TC(grad(model, flatten_TC(tc.from_numpy(x))), model.n_obj, dof).detach().numpy()
print("-> Integrating")
qx_pred, qy_pred, hamil_pred = integrate(system, model_grad_H)
print("-> got qx_pred of shape", qx_pred.shape)
q_pred = np.concatenate([qx_pred, qy_pred], axis=-1)
print("-> got q_pred of shape", q_pred.shape)
err, all_err_train = rollout_error(q_true, q_pred)
# NOTE(review): unlike the test-set output, this file name has no
# "_rollout_error" suffix.
np.save(f"./benchmark_train_traj_spring{n_obj}.npy", all_err_train)
print(f"-> Rollout error:   {err:.2e}")
mse = np.mean((qx_true - qx_pred)**2 + (qy_true - qy_pred)**2, axis=1)
print(f"=trajectory mean of all (q) MSE: {mse.mean():.2e}")


exit(0)
