import sys
# Make modules from the parent directory importable; the project-local imports
# below (HyperSINDy, baseline, library_utils, Datasets, other) presumably live
# there -- verify against the repository layout.
sys.path.append("../")
import os
# Ask CUDA to number GPUs by PCI bus ID. This is set before torch is imported so
# that it is already in the environment when CUDA initialises.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

import torch
import torch.nn as nn
import numpy as np
import time

from HyperSINDy import Net
from baseline import Trainer
from library_utils import Library
from Datasets import SyntheticDataset
from other import init_weights, set_random_seed

"""
Train HyperSINDy on the "bayesian" dataset and record how long the whole run
(data loading, model setup and training) takes.
"""

def pipeline(library, trainset, epochs, lr,
             lmda_init, lmda_max, lmda_max_epoch, lmda_spike, lmda_spike_epoch,
             beta_init, beta_max, beta_max_epoch, beta_spike, beta_spike_epoch,
             adam_reg, gamma_factor, batch_size,
             thresh_interval, eval_interval, hard_thresh,
             run_name, runs, noise_dim, hidden_dim, stat_size, device,
             num_hidden, batch_norm):
    """Build a HyperSINDy ``Net``, initialise its weights and train it.

    The network is moved to CUDA device ``device``. Because this function calls
    ``torch.cuda.set_device``, a CUDA-capable GPU is required. Training itself
    is delegated to ``Trainer`` from ``baseline``.

    Args:
        library: feature ``Library`` passed to both ``Net`` and ``Trainer``.
        trainset: dataset handed to ``Trainer.train``.
        epochs, lr, batch_size, adam_reg, gamma_factor: optimisation settings
            forwarded to ``Trainer`` (``lr`` is passed as ``learning_rate``).
        lmda_init, lmda_max, lmda_max_epoch, lmda_spike, lmda_spike_epoch:
            lambda-related hyperparameters, forwarded unchanged to ``Trainer``.
        beta_init, beta_max, beta_max_epoch, beta_spike, beta_spike_epoch:
            beta-related hyperparameters, forwarded unchanged to ``Trainer``.
        thresh_interval, hard_thresh, eval_interval: forwarded to ``Trainer``
            as ``threshold_interval``, ``hard_threshold`` and ``eval_interval``.
        run_name: printed first, then used to build the two path arguments
            given to ``Trainer``.
        runs: prefix string-concatenated in front of those paths. No separator
            is inserted, so it should end with one.
        noise_dim, hidden_dim, stat_size, num_hidden, batch_norm: ``Net``
            options (``stat_size`` is passed as ``statistic_batch_size``).
        device: CUDA device index to make current.

    Returns:
        None. Any outputs are produced by ``Trainer``.
    """
    print(run_name)

    # Make the requested GPU current, then use the index CUDA reports back.
    torch.cuda.set_device(device=device)
    cuda_index = torch.cuda.current_device()

    hypersindy = Net(
        library,
        noise_dim=noise_dim,
        hidden_dim=hidden_dim,
        statistic_batch_size=stat_size,
        num_hidden=num_hidden,
        batch_norm=batch_norm,
    )
    hypersindy = hypersindy.to(cuda_index)
    hypersindy.apply(init_weights)

    # Positional path arguments for Trainer. The ".pt" file with the "cp_"
    # prefix is presumably a checkpoint -- confirm in baseline.Trainer.
    run_path = runs + run_name
    checkpoint_path = runs + "cp_" + run_name + ".pt"

    beta_schedule = {
        "beta_init": beta_init,
        "beta_max": beta_max,
        "beta_max_epoch": beta_max_epoch,
        "beta_spike": beta_spike,
        "beta_spike_epoch": beta_spike_epoch,
    }
    lmda_schedule = {
        "lmda_init": lmda_init,
        "lmda_max": lmda_max,
        "lmda_max_epoch": lmda_max_epoch,
        "lmda_spike": lmda_spike,
        "lmda_spike_epoch": lmda_spike_epoch,
    }

    trainer = Trainer(
        hypersindy, library, run_path, checkpoint_path,
        **beta_schedule,
        **lmda_schedule,
        learning_rate=lr,
        adam_reg=adam_reg,
        gamma_factor=gamma_factor,
        epochs=epochs,
        batch_size=batch_size,
        device=cuda_index,
        hard_threshold=hard_thresh,
        threshold_interval=thresh_interval,
        eval_interval=eval_interval,
    )
    trainer.train(trainset)

def load_data(library, data_folder, dataset, t, dt, model):
    """Load ``yobs_norm.npy`` for ``dataset`` and wrap it in a ``SyntheticDataset``.

    Args:
        library: feature library forwarded to ``SyntheticDataset``.
        data_folder: directory containing one sub-directory per dataset.
            A trailing path separator is optional.
        dataset: name of the sub-directory to read from. It is also forwarded
            to ``SyntheticDataset``.
        t: forwarded to ``SyntheticDataset`` (``main`` passes ``None``).
        dt: forwarded to ``SyntheticDataset``.
        model: forwarded to ``SyntheticDataset``.

    Returns:
        A ``SyntheticDataset`` built from the array stored in
        ``<data_folder>/<dataset>/yobs_norm.npy``.
    """
    # os.path.join works whether or not data_folder ends with a separator.
    # Plain string concatenation silently built paths such as
    # "../databayesian/yobs_norm.npy" when the trailing "/" was missing.
    x = np.load(os.path.join(data_folder, dataset, "yobs_norm.npy"))
    return SyntheticDataset(x=x, t=t, library=library, dataset=dataset, dt=dt, model=model)

def main():
    """Run and time one HyperSINDy experiment on the "bayesian" dataset.

    Builds the feature library, loads the data and trains run "1" via
    ``pipeline``. It then writes the elapsed time in seconds, covering the whole
    body of this function, to ``../results/timer.txt``. The results directory
    is created if it does not exist yet.
    """
    # perf_counter is monotonic and high-resolution. time.time() can jump if
    # the system clock is adjusted during a long training run.
    start_time = time.perf_counter()

    # Globals
    data_folder = "../data/"
    model = "HyperSINDy"
    dt = 0.4897959183673469
    hidden_dim = 64
    stat_size = 250
    num_hidden = 5
    x_dim = 2
    adam_reg = 1e-2
    include_constant = True
    poly_order = 3
    runs = "../runs/bayesian/"
    t = None
    seed = 0
    timer_path = "../results/timer.txt"

    # Individual experiments
    set_random_seed(seed)
    run_name = "1"
    device = 0
    gamma_factor = 0.999999
    # The library is built once, after seeding. The earlier unused
    # construction has been removed.
    library = Library(n=x_dim, poly_order=poly_order, include_constant=include_constant)
    trainset = load_data(library, data_folder, "bayesian", t, dt, model)
    pipeline(library=library, trainset=trainset, epochs=3099, lr=5e-3, batch_size=50,
        lmda_init=1e-2, lmda_max=1e-2, lmda_max_epoch=1, lmda_spike=1e0, lmda_spike_epoch=1000,
        beta_init=0.01, beta_max=1.0, beta_max_epoch=100, beta_spike=25, beta_spike_epoch=1000,
        noise_dim=4, thresh_interval=500, hard_thresh=0.05, batch_norm=False, eval_interval=50,
        adam_reg=adam_reg, gamma_factor=gamma_factor, runs=runs, run_name=run_name, device=device,
        hidden_dim=hidden_dim, stat_size=stat_size, num_hidden=num_hidden)

    elapsed = time.perf_counter() - start_time
    # Ensure the output directory exists. Otherwise open() would raise only
    # after the full training run, and the measurement would be lost.
    os.makedirs(os.path.dirname(timer_path), exist_ok=True)
    with open(timer_path, "w", encoding="utf-8") as f:
        print(str(elapsed), file=f)

# Run the timed experiment only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()