import os
import datetime
import torch
import tqdm
import matplotlib.pyplot as plt
from util import set_random_seed, parse_config, generate_train_data, generate_test_data, calculate_residuals
from optimizer import NGDOptimizer
from model import NGDModel


def _result_path(filename):
    """Return ``result/<filename>`` next to this script, creating ``result/`` if needed.

    Neither ``savefig`` nor ``torch.save`` creates missing parent directories,
    so the directory is made here before any artifact is written.
    """
    result_dir = os.path.join(
        os.path.dirname(os.path.abspath(os.path.realpath(__file__))), "result"
    )
    os.makedirs(result_dir, exist_ok=True)
    return os.path.join(result_dir, filename)


def _plot_comparison(plot_method, X1, X2, pred_grid, exact_grid, save_path):
    """Draw prediction / exact solution / absolute error side by side and save it.

    Args:
        plot_method: name of the ``Axes`` method used to draw each panel
            (``"pcolormesh"`` or ``"contourf"``).
        X1, X2: CPU tensors of grid coordinates, shape (num_test, num_test).
        pred_grid, exact_grid: CPU tensors of field values on that grid,
            same shape as ``X1``.
        save_path: file the figure is written to.
    """
    fig, axs = plt.subplots(1, 3, figsize=(18, 6))
    panels = (
        ("Prediction", pred_grid),
        ("Exact solution", exact_grid),
        # Taken on the reshaped grids so both operands have identical shapes
        # and no accidental broadcasting can occur.
        ("Absolute error", torch.abs(pred_grid - exact_grid)),
    )
    for ax, (title, values) in zip(axs, panels):
        mappable = getattr(ax, plot_method)(X1, X2, values)
        fig.colorbar(mappable, ax=ax)
        ax.set_title(title)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
    fig.tight_layout()
    fig.savefig(save_path)
    plt.show()
    # Release the figure: this runs once per config and per plot type, and
    # unclosed figures would otherwise accumulate for the whole run.
    plt.close(fig)


def _run_config(cfg):
    """Train, evaluate, plot and checkpoint the ``poisson_2d_<cfg>`` configuration.

    Missing config attributes fall back to the defaults given below. Two
    figures (pcolormesh and contourf) and one checkpoint are written to
    ``result/``, all prefixed with a per-run timestamp.
    """
    Config = parse_config("poisson_2d_" + cfg)
    random_seed = getattr(Config, "random_seed", 42)
    num_f = getattr(Config, "num_f", 1000)
    num_b = getattr(Config, "num_b", 50)
    batch_size = getattr(Config, "batch_size", 100)
    num_test = getattr(Config, "num_test", 100)
    epochs = getattr(Config, "epochs", 200)
    lr = getattr(Config, "lr", 0.01)
    architecture = getattr(Config, "architecture", [2, 128, 128, 1])
    activation = getattr(Config, "activation", "tanh")

    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    save_path_pic1 = _result_path(f"{timestamp}_pcolormesh_{cfg}.png")
    save_path_pic2 = _result_path(f"{timestamp}_contourf_{cfg}.png")
    save_path_ckpt = _result_path(f"{timestamp}_{cfg}.pth")

    set_random_seed(random_seed)
    dataloader_f, x_f, x_b, y_b = generate_train_data(num_f, num_b, batch_size)
    x_test, y_test = generate_test_data(num_test)
    model = NGDModel(architecture, activation)
    optimizer = NGDOptimizer(model, lr)

    pbar = tqdm.tqdm(range(epochs))
    loss_history = []
    for epoch in pbar:
        for (dat_x,) in dataloader_f:
            optimizer.step(dat_x, x_b, y_b)

        # Every 10 epochs, log the mean squared residual over the full training
        # set. Deliberately not under no_grad: the residual presumably needs
        # autograd derivatives of the model output — confirm in util.
        if (epoch + 1) % 10 == 0:
            _, residual = calculate_residuals(model, x_f, x_b, y_b)
            loss = torch.mean(residual ** 2).item()
            pbar.set_postfix(loss=loss)
            loss_history.append(loss)

    model.eval()
    with torch.no_grad():
        y_pred = model(x_test)
        # Relative L2 error: sqrt(mean((y - y_hat)^2) / mean(y^2)).
        test_error = torch.sqrt(
            torch.mean(torch.square(y_test - y_pred)) / torch.mean(torch.square(y_test))
        )
        print(f"L2 error: {test_error}")

    # matplotlib needs CPU tensors and the test data's device is not visible
    # here, so move everything explicitly. The no_grad prediction above is
    # reused rather than running a second, graph-building forward pass.
    pred = y_pred.detach().cpu()
    x_test_cpu = x_test.detach().cpu()
    X1 = x_test_cpu[:, 0:1].reshape(num_test, num_test)
    X2 = x_test_cpu[:, 1:2].reshape(num_test, num_test)
    pred_grid = pred.reshape(X1.shape)
    exact_grid = y_test.detach().cpu().reshape(num_test, num_test)

    _plot_comparison("pcolormesh", X1, X2, pred_grid, exact_grid, save_path_pic1)
    _plot_comparison("contourf", X1, X2, pred_grid, exact_grid, save_path_pic2)

    result_dict = {
        "model_state_dict": model.state_dict(),
        "loss_history": loss_history,
        "l2_error": test_error,
    }
    torch.save(result_dict, save_path_ckpt)


if __name__ == "__main__":
    config_list = ["ngd1", "ngd2", "ngd3", "ngd4", "ngd5", "ngd6"]
    for cfg in config_list:
        _run_config(cfg)
        