import os
import torch
import tqdm
from util import set_random_seed, parse_config, generate_data, calculate_residuals
from torch.optim import Adam
from model import AdamModel
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def sample_boundary_batch(x_b, y_b, batch_size):
    """Draw a random subset of boundary points together with their targets.

    Rows are sampled without replacement from *all* rows of ``x_b`` (rather
    than a fixed-size prefix), so every boundary point can be selected and an
    undersized ``x_b`` can never be indexed out of bounds. When ``batch_size``
    is larger than the number of rows, every row is returned in random order.

    Args:
        x_b: Boundary inputs; rows are indexed along dim 0.
        y_b: Boundary targets, row-aligned with ``x_b``.
        batch_size: Maximum number of rows to return.

    Returns:
        ``(x_subset, y_subset)`` with matching rows.
    """
    indices = torch.randperm(x_b.shape[0])[:batch_size].to(x_b.device)
    return x_b[indices], y_b[indices]


def relative_l2_error(y_pred, y_true):
    """Return the relative L2 error ``||y_true - y_pred|| / ||y_true||`` as a float.

    Computed as ``sqrt(mean((y_true - y_pred)**2) / mean(y_true**2))``.
    NOTE(review): assumes ``y_pred`` and ``y_true`` have the same shape;
    mismatched shapes (e.g. (N, 1) vs (N,)) would broadcast silently.
    """
    ratio = torch.mean(torch.square(y_true - y_pred)) / torch.mean(torch.square(y_true))
    return torch.sqrt(ratio).item()


if __name__ == "__main__":
    config = parse_config("poisson_10d_adam")
    random_seed = getattr(config, "random_seed", 42)
    num_f = getattr(config, "num_f", 10000)
    num_b = getattr(config, "num_b", 200)
    batch_size = getattr(config, "batch_size", 100)
    num_test = getattr(config, "num_test", 20000)
    epochs = getattr(config, "epochs", 200)
    lr = getattr(config, "lr", 0.001)
    # Number of boundary points drawn per optimizer step (was hard-coded to 200).
    boundary_batch_size = getattr(config, "boundary_batch_size", 200)
    architecture = getattr(config, "architecture", [10, 128, 1])
    activation = getattr(config, "activation", "tanh")

    # All artifacts go to <script dir>/result; create it so torch.save cannot fail.
    result_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "result")
    os.makedirs(result_dir, exist_ok=True)
    # The two picture paths are not used by the code in this script.
    save_path_pic1 = os.path.join(result_dir, "pcolormesh_adam.png")
    save_path_pic2 = os.path.join(result_dir, "contourf_adam.png")
    save_path_ckpt = os.path.join(result_dir, "ckpt_adam.pth")

    set_random_seed(random_seed)
    dataloader_f, x_f, x_b, y_b, x_test, y_test = generate_data(num_f, num_b, batch_size, num_test)
    x_f = x_f.to(device)
    x_b = x_b.to(device)
    y_b = y_b.to(device)
    model = AdamModel(architecture, activation).to(device)
    optimizer = Adam(model.parameters(), lr)

    pbar = tqdm.tqdm(range(epochs))
    loss_history = []
    for epoch in pbar:
        for (dat_x,) in dataloader_f:
            dat_x = dat_x.to(device)
            data_x_b, data_y_b = sample_boundary_batch(x_b, y_b, boundary_batch_size)
            optimizer.zero_grad()
            _, loss_all = calculate_residuals(model, dat_x, data_x_b, data_y_b)
            loss_all = torch.mean(loss_all ** 2)
            loss_all.backward()
            optimizer.step()

        # Every 10 epochs, record the loss over the full collocation/boundary set.
        # Not wrapped in no_grad: calculate_residuals may need autograd for
        # input derivatives -- TODO confirm before changing.
        if (epoch + 1) % 10 == 0:
            _, loss = calculate_residuals(model, x_f, x_b, y_b)
            loss = torch.mean(loss ** 2).item()
            pbar.set_postfix(loss=loss)
            loss_history.append(loss)

    # Peak CUDA memory is only meaningful when training actually ran on the GPU.
    if device == 'cuda':
        print('memory allocation:', torch.cuda.max_memory_allocated() / (1024 * 1024), 'MB')

    model.eval()
    # Test tensors must live on the same device as the model.
    x_test = x_test.to(device)
    y_test = y_test.to(device)
    with torch.no_grad():
        y_pred = model(x_test)
    test_error = relative_l2_error(y_pred, y_test)
    print(f"L2 error: {test_error}")

    result_dict = {
        "model_state_dict": model.state_dict(),
        "loss_history": loss_history,
        "l2_error": test_error,
    }
    torch.save(result_dict, save_path_ckpt)

    torch.cuda.empty_cache()