import os
import datetime
import torch
import tqdm
import matplotlib.pyplot as plt
from util import set_random_seed, parse_config, generate_train_data, generate_test_data, calculate_residuals
from torch.optim import SGD
from model import SGDModel


if __name__ == "__main__":
    Config = parse_config("poisson_1d_sgd")
    random_seed = getattr(Config, "random_seed", 42)
    num_f = getattr(Config, "num_f", 500)
    batch_size = getattr(Config, "batch_size", 100)
    num_test = getattr(Config, "num_test", 5000)
    epochs = getattr(Config, "epochs", 100)
    lr = getattr(Config, "lr", 0.1)
    archicture = getattr(Config, "architecture", [1, 128, 1])
    activation = getattr(Config, "activation", "tanh")
    
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    save_path_pic = os.path.join(os.path.dirname(os.path.abspath(os.path.realpath(__file__))), "result", f"{timestamp}_sgd.png")
    save_path_ckpt = os.path.join(os.path.dirname(os.path.abspath(os.path.realpath(__file__))), "result", f"{timestamp}_sgd.pth")
    
    set_random_seed(random_seed)
    dataloader_f, x_f, x_b, y_b = generate_train_data(num_f, batch_size)
    x_test, y_test = generate_test_data(num_test)
    model = SGDModel(archicture, activation)
    optimizer = SGD(model.parameters(), lr)
    
    pbar = tqdm.tqdm(range(epochs))
    loss_history = []
    for epoch in pbar:
        for _, (dat_x,) in enumerate(dataloader_f):
            optimizer.zero_grad()
            _, loss_all = calculate_residuals(model, x_f, x_b, y_b)
            loss_all = torch.mean(loss_all**2)
            loss_all.backward()
            optimizer.step()
        
        if (epoch + 1) % 10 == 0:
            _, loss = calculate_residuals(model, x_f, x_b, y_b)
            loss = torch.mean(loss**2).item()
            pbar.set_postfix(loss=loss)
            loss_history.append(loss)
    
    model.eval()
    with torch.no_grad():
        y_pred = model(x_test)
        test_error = torch.mean(torch.square(y_test-y_pred)) / torch.mean(torch.square(y_test))
        test_error = torch.sqrt(test_error)
        print(f"L2 error: {test_error}")
    
    plt.plot(x_test, y_test, ls="-", label="Reference solution")
    plt.plot(x_test, y_pred, ls="--", label="Predcited solution")
    plt.legend()
    plt.savefig(save_path_pic)
    plt.show()
    
    result_dict = {
        "model_state_dict": model.state_dict(),
        "loss_history": loss_history,
        "l2_error": test_error
    }
    torch.save(result_dict, save_path_ckpt)
    