import os
import datetime
import torch
import tqdm
import matplotlib.pyplot as plt
from util import set_random_seed, parse_config, generate_train_data, generate_test_data, calculate_residuals
from optimizer import NGDOptimizer
from model import NGDModel


if __name__ == "__main__":
    Config = parse_config("poisson_1d_ngd")
    random_seed = getattr(Config, "random_seed", 42)
    num_f = getattr(Config, "num_f", 500)
    batch_size = getattr(Config, "batch_size", 100)
    num_test = getattr(Config, "num_test", 5000)
    epochs = getattr(Config, "epochs", 100)
    lr = getattr(Config, "lr", 0.1)
    archicture = getattr(Config, "architecture", [1, 128, 1])
    activation = getattr(Config, "activation", "tanh")
    
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    save_path_pic = os.path.join(os.path.dirname(os.path.abspath(os.path.realpath(__file__))), "result", f"{timestamp}_ngd.png")
    save_path_ckpt = os.path.join(os.path.dirname(os.path.abspath(os.path.realpath(__file__))), "result", f"{timestamp}_ngd.pth")
    
    set_random_seed(random_seed)
    dataloader_f, x_f, x_b, y_b = generate_train_data(num_f, batch_size)
    x_test, y_test = generate_test_data(num_test)
    model = NGDModel(archicture, activation)
    optimizer = NGDOptimizer(model, lr)
    
    pbar = tqdm.tqdm(range(epochs))
    loss_history = []
    for epoch in pbar:
        for _, (dat_x,) in enumerate(dataloader_f):
            optimizer.step(dat_x, x_b, y_b)
        
        if (epoch + 1) % 1 == 0:
            _, loss = calculate_residuals(model, x_f, x_b, y_b)
            loss = torch.mean(loss**2).item()
            pbar.set_postfix(loss=loss)
            loss_history.append(loss)
    
    model.eval()
    with torch.no_grad():
        y_pred = model(x_test)
        test_error = torch.mean(torch.square(y_test-y_pred)) / torch.mean(torch.square(y_test))
        test_error = torch.sqrt(test_error)
        print(f"L2 error: {test_error}")
    
# =============================================================================
#     plt.plot(x_test, y_test, ls="-", label="Reference solution")
#     plt.plot(x_test, y_pred, ls="--", label="Predcited solution")
#     plt.legend()
#     plt.savefig(save_path_pic)
#     plt.show()
# =============================================================================
    
    result_dict = {
        "model_state_dict": model.state_dict(),
        "loss_history": loss_history,
        "l2_error": test_error
    }
    torch.save(result_dict, save_path_ckpt)
    
    loss_history_theory = []
    for k in range(len(loss_history)):
        loss_history_theory.append(loss_history[0]*(1-lr)**k)
    plt.figure(1)
    plt.plot(loss_history, label="experimental decay")
    plt.plot(loss_history_theory, label="theoretical decay")
    #plt.ylim(1e-6, 1e+3)
    plt.yscale("log")
    plt.legend()
    plt.xlabel("epochs")
    plt.ylabel("training loss")
    plt.show()
    
    # compute the smallest and largest eigenvalue of Gram matrix.
    from util import calculate_residuals

    sk_all, _ = calculate_residuals(model, x_f, x_b, y_b)    
    gradients = []
    for i in range(sk_all.size(0)):
        sk_i = sk_all[i:i+1, :]
        grad_i = torch.autograd.grad(sk_i, model.parameters(), retain_graph=True)
        grad_flat = torch.cat([g.view(-1) for g in grad_i])
        gradients.append(grad_flat)   
    Jk = torch.stack(gradients)  
    
    U, S, Vt = torch.linalg.svd(Jk, full_matrices=False)
    lambda_min = S[-1]**2
    lambda_max = S[0]**2
    print(lambda_min, lambda_max)
    
    #H = Jk @Jk.T
    #eigvals = torch.linalg.eigvalsh(H)    
    #lambda_min = eigvals[0]
    #lambda_max = eigvals[-1]    
    #print(lambda_min, lambda_max)