import os
import torch
import tqdm
from util import set_random_seed, parse_config, generate_data, calculate_residuals
from optimizer import NGDOptimizer
from model import NGDModel


if __name__ == "__main__":
    # Hyper-parameters are read from the parsed config object; every getattr
    # default below applies when the config lacks that attribute.
    Config = parse_config("poisson_10d_ngd")
    random_seed = getattr(Config, "random_seed", 42)
    num_f = getattr(Config, "num_f", 10000)
    num_b = getattr(Config, "num_b", 200)
    batch_size = getattr(Config, "batch_size", 100)
    num_test = getattr(Config, "num_test", 20000)
    epochs = getattr(Config, "epochs", 200)
    lr = getattr(Config, "lr", 0.01)
    architecture = getattr(Config, "architecture", [10, 128, 1])
    activation = getattr(Config, "activation", "tanh")

    # Number of boundary samples drawn for every optimizer step, and how often
    # (in epochs) the full-data loss is evaluated and recorded.
    boundary_batch_size = 200
    log_interval = 1

    # All artifacts are written to a "result" directory next to this script.
    result_dir = os.path.join(os.path.dirname(os.path.abspath(os.path.realpath(__file__))), "result")
    save_path_pic1 = os.path.join(result_dir, "pcolormesh_ngd.png")
    save_path_pic2 = os.path.join(result_dir, "contourf_ngd.png")
    save_path_ckpt = os.path.join(result_dir, "ckpt_ngd.pth")
    # Create the output directory up front: torch.save does not create missing
    # parent directories, and failing only after training would lose the run.
    os.makedirs(result_dir, exist_ok=True)

    set_random_seed(random_seed)
    dataloader_f, x_f, x_b, y_b, x_test, y_test = generate_data(num_f, num_b, batch_size, num_test)
    model = NGDModel(architecture, activation)
    optimizer = NGDOptimizer(model, lr)
    #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 20, gamma=0.5)

    # Size the random permutation by the actual number of boundary rows so the
    # indices are always valid and every boundary row can be selected.
    num_boundary_points = x_b.shape[0]

    pbar = tqdm.tqdm(range(epochs))
    loss_history = []
    for epoch in pbar:
        for (batch_x_f,) in dataloader_f:
            # Fresh random subset of the boundary data for each step.
            indices = torch.randperm(num_boundary_points)[:boundary_batch_size]
            batch_x_b = x_b[indices]
            batch_y_b = y_b[indices]
            optimizer.step(batch_x_f, batch_x_b, batch_y_b)

        #scheduler.step()
        if (epoch + 1) % log_interval == 0:
            # Only the second return value of calculate_residuals is used
            # (presumably the residual vector -- confirm against util); its
            # mean square over the full data is the logged loss.
            _, residuals = calculate_residuals(model, x_f, x_b, y_b)
            loss = torch.mean(residuals ** 2).item()
            pbar.set_postfix(loss=loss)
            loss_history.append(loss)

    # Relative L2 error on the test set:
    # sqrt(mean((y_test - y_pred)^2) / mean(y_test^2)).
    model.eval()
    with torch.no_grad():
        y_pred = model(x_test)
        test_error = torch.mean(torch.square(y_test-y_pred)) / torch.mean(torch.square(y_test))
        test_error = torch.sqrt(test_error)
        print(f"L2 error: {test_error}")

    # Persist the trained weights together with the recorded loss curve and
    # the final test error.
    result_dict = {
        "model_state_dict": model.state_dict(),
        "loss_history": loss_history,
        "l2_error": test_error
    }
    torch.save(result_dict, save_path_ckpt)
    