import os
import datetime
import torch
import tqdm
import matplotlib.pyplot as plt
from util import set_random_seed, parse_config, generate_data, calculate_residuals
from torch.optim import LBFGS
from model import LBFGSModel
# Pick the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


def sample_boundary_batch(x_b, y_b, batch_size):
    """Draw a random subset of boundary rows without replacement.

    The permutation is drawn from the default (CPU) generator and then moved
    to the tensors' device, so indexing stays on the same device as the data.

    Args:
        x_b: Boundary inputs; rows are indexed along dim 0.
        y_b: Boundary targets, aligned row-for-row with ``x_b``.
        batch_size: Requested number of rows. If it exceeds the number of
            available rows, every row is returned (in random order).

    Returns:
        Tuple ``(x_batch, y_batch)`` containing the selected rows.

    Raises:
        ValueError: If ``x_b`` and ``y_b`` have different numbers of rows.
    """
    num_rows = x_b.shape[0]
    if y_b.shape[0] != num_rows:
        raise ValueError(
            f"x_b has {num_rows} rows but y_b has {y_b.shape[0]} rows"
        )
    indices = torch.randperm(num_rows)[:batch_size].to(x_b.device)
    return x_b[indices], y_b[indices]


def relative_l2_error(y_pred, y_true):
    """Return the relative L2 error ``sqrt(mean((y_true - y_pred)^2) / mean(y_true^2))``.

    The result is a 0-d tensor on the inputs' device. ``y_pred`` and ``y_true``
    should have the same shape; torch broadcasting will silently combine
    mismatched shapes (e.g. ``(N, 1)`` vs ``(N,)``) and give a wrong value.
    """
    squared_error = torch.mean(torch.square(y_true - y_pred))
    squared_norm = torch.mean(torch.square(y_true))
    return torch.sqrt(squared_error / squared_norm)


if __name__ == "__main__":
    Config = parse_config("poisson_10d_lbfgs")
    random_seed = getattr(Config, "random_seed", 42)
    num_f = getattr(Config, "num_f", 3000)
    num_b = getattr(Config, "num_b", 500)
    batch_size = getattr(Config, "batch_size", 100)
    num_test = getattr(Config, "num_test", 100)
    epochs = getattr(Config, "epochs", 200)
    iterations = getattr(Config, "iterations", 50)
    lr = getattr(Config, "lr", 1)
    # NOTE(review): default input width 5 looks odd for a "10d" problem;
    # presumably the config overrides it — confirm.
    architecture = getattr(Config, "architecture", [5, 128, 1])
    activation = getattr(Config, "activation", "tanh")
    # Evaluate and log the full-data loss every `log_every` epochs
    # (1 = every epoch, the previous hard-coded behaviour).
    log_every = getattr(Config, "log_every", 1)

    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    result_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "result")
    # Create the output directory up front so torch.save does not fail after
    # the entire training run has finished.
    os.makedirs(result_dir, exist_ok=True)
    save_path_pic1 = os.path.join(result_dir, f"{timestamp}_pcolormesh_lbfgs.png")
    save_path_pic2 = os.path.join(result_dir, f"{timestamp}_contourf_lbfgs.png")
    save_path_ckpt = os.path.join(result_dir, f"{timestamp}_lbfgs.pth")

    set_random_seed(random_seed)
    dataloader_f, x_f, x_b, y_b, x_test, y_test = generate_data(num_f, num_b, batch_size, num_test)
    x_f = x_f.to(device)
    x_b = x_b.to(device)
    y_b = y_b.to(device)
    model = LBFGSModel(architecture, activation).to(device)
    optimizer = LBFGS(model.parameters(), lr=lr, max_iter=iterations)

    pbar = tqdm.tqdm(range(epochs))
    loss_history = []
    for epoch in pbar:
        for (dat_x,) in dataloader_f:
            dat_x = dat_x.to(device)
            # Pair each collocation mini-batch with a random boundary subset
            # sized from the actual boundary set, not a hard-coded 500/100.
            data_x_b, data_y_b = sample_boundary_batch(x_b, y_b, batch_size)

            def closure():
                optimizer.zero_grad()
                # Mini-batch loss; to train on the full data instead, pass
                # (x_f, x_b, y_b) here.
                _, loss_all = calculate_residuals(model, dat_x, data_x_b, data_y_b)
                loss_all = torch.mean(loss_all ** 2)
                # LBFGS calls this closure repeatedly and each call rebuilds the
                # graph, so retain_graph is unnecessary and would only hold memory.
                loss_all.backward()
                return loss_all

            optimizer.step(closure)

        if (epoch + 1) % log_every == 0:
            # Not wrapped in no_grad: calculate_residuals may need autograd to
            # form PDE derivatives — TODO confirm.
            _, loss = calculate_residuals(model, x_f, x_b, y_b)
            loss = torch.mean(loss ** 2).item()
            pbar.set_postfix(loss=loss)
            loss_history.append(loss)

    if device == 'cuda':
        print('memory allocation:', torch.cuda.max_memory_allocated() / (1024 * 1024), 'MB')

    model.eval()
    # The test set must live on the same device as the model.
    x_test = x_test.to(device)
    y_test = y_test.to(device)
    with torch.no_grad():
        y_pred = model(x_test)
        test_error = relative_l2_error(y_pred, y_test)
        print(f"L2 error: {test_error.item()}")

    result_dict = {
        "model_state_dict": model.state_dict(),
        "loss_history": loss_history,
        # Stored on the CPU so the checkpoint can be loaded without a GPU.
        "l2_error": test_error.cpu(),
    }
    torch.save(result_dict, save_path_ckpt)

    if device == 'cuda':
        torch.cuda.empty_cache()