import os
import datetime
import torch
import tqdm
import matplotlib.pyplot as plt
from util import set_random_seed, parse_config, generate_train_data, generate_test_data, calculate_residuals
from torch.optim import SGD
from model import SGDModel


def _to_grid(tensor, num_test):
    """Return ``tensor`` as a detached CPU NumPy array of shape (num_test, num_test).

    matplotlib cannot convert tensors that live on an accelerator or that still
    require grad, so every plotted quantity is routed through this helper.
    """
    return tensor.detach().cpu().reshape(num_test, num_test).numpy()


def _plot_comparison(plot_kind, X1, X2, pred, exact, save_path):
    """Draw prediction, exact solution and absolute error side by side, then save.

    Args:
        plot_kind: name of the ``Axes`` method used for each panel,
            e.g. ``"pcolormesh"`` or ``"contourf"``.
        X1, X2: 2-D coordinate grids (NumPy arrays).
        pred, exact: 2-D value grids with the same shape as ``X1``.
        save_path: destination image file.
    """
    panels = (
        ("Prediction", pred),
        ("Exact solution", exact),
        ("Absolute error", abs(pred - exact)),
    )
    fig, axs = plt.subplots(1, 3, figsize=(18, 6))
    for ax, (title, values) in zip(axs, panels):
        mappable = getattr(ax, plot_kind)(X1, X2, values)
        fig.colorbar(mappable, ax=ax)
        ax.set_title(title)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
    fig.tight_layout()
    fig.savefig(save_path)
    plt.show()
    # Under a non-interactive backend show() is a no-op; close explicitly so
    # figures do not accumulate.
    plt.close(fig)


if __name__ == "__main__":
    # ---- configuration (each value falls back to the default given here) ----
    Config = parse_config("helmholtz_2d_sgd")
    random_seed = getattr(Config, "random_seed", 42)
    num_f = getattr(Config, "num_f", 1000)
    num_b = getattr(Config, "num_b", 50)
    batch_size = getattr(Config, "batch_size", 100)
    num_test = getattr(Config, "num_test", 100)
    epochs = getattr(Config, "epochs", 200)
    lr = getattr(Config, "lr", 0.01)
    architecture = getattr(Config, "architecture", [2, 128, 128, 1])
    activation = getattr(Config, "activation", "tanh")
    # Period (in epochs) of the full-collocation-set loss log; 10 was the
    # previously hard-coded value.
    log_every = getattr(Config, "log_every", 10)

    # ---- output locations: <script dir>/result/<timestamp>_* ----
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    result_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "result")
    # savefig / torch.save do not create missing parent directories.
    os.makedirs(result_dir, exist_ok=True)
    save_path_pic1 = os.path.join(result_dir, f"{timestamp}_pcolormesh_sgd.png")
    save_path_pic2 = os.path.join(result_dir, f"{timestamp}_contourf_sgd.png")
    save_path_ckpt = os.path.join(result_dir, f"{timestamp}_sgd.pth")

    # ---- data, model, optimizer ----
    set_random_seed(random_seed)
    dataloader_f, x_f, x_b, y_b = generate_train_data(num_f, num_b, batch_size)
    x_test, y_test = generate_test_data(num_test)
    model = SGDModel(architecture, activation)
    optimizer = SGD(model.parameters(), lr)

    # ---- training ----
    model.train()
    pbar = tqdm.tqdm(range(epochs))
    loss_history = []
    for epoch in pbar:
        for (batch_x,) in dataloader_f:
            optimizer.zero_grad()
            # The second return value of calculate_residuals is treated as the
            # residual vector (inferred from usage here; see util).
            _, residual = calculate_residuals(model, batch_x, x_b, y_b)
            loss = torch.mean(residual ** 2)
            loss.backward()
            optimizer.step()

        if (epoch + 1) % log_every == 0:
            # Not wrapped in no_grad: the residual presumably needs autograd
            # for PDE derivatives. No backward() is called, so parameter
            # gradients are untouched.
            _, residual = calculate_residuals(model, x_f, x_b, y_b)
            full_loss = torch.mean(residual ** 2).item()
            pbar.set_postfix(loss=full_loss)
            loss_history.append(full_loss)

    # ---- evaluation: relative L2 error on the test grid ----
    model.eval()
    with torch.no_grad():
        y_pred = model(x_test)
        test_error = torch.sqrt(
            torch.mean(torch.square(y_test - y_pred)) / torch.mean(torch.square(y_test))
        )
        print(f"L2 error: {test_error.item()}")

    # ---- plots (reuse the no_grad prediction instead of a second forward pass) ----
    X1 = _to_grid(x_test[:, 0:1], num_test)
    X2 = _to_grid(x_test[:, 1:2], num_test)
    pred_grid = _to_grid(y_pred, num_test)
    exact_grid = _to_grid(y_test, num_test)
    _plot_comparison("pcolormesh", X1, X2, pred_grid, exact_grid, save_path_pic1)
    _plot_comparison("contourf", X1, X2, pred_grid, exact_grid, save_path_pic2)

    # ---- checkpoint ----
    result_dict = {
        "model_state_dict": model.state_dict(),
        "loss_history": loss_history,
        "l2_error": test_error
    }
    torch.save(result_dict, save_path_ckpt)
    