import os
import yaml
import random
import torch
import numpy as np
from pyDOE import lhs
from torch.utils.data import TensorDataset, DataLoader


def set_random_seed(seed=42):
    """
    Seed every random number generator used by this project.

    Seeds torch (CPU, the current CUDA device and all CUDA devices), NumPy's
    global generator and Python's ``random`` module with ``seed``.  It also
    switches cuDNN to deterministic mode with autotuning disabled, so that
    repeated runs produce the same results.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    # Autotuning picks kernels by timing, which can vary run to run.
    cudnn.benchmark = False
    cudnn.deterministic = True


def get_root_dir():
    """
    Return the absolute directory that contains this module.

    Symlinks are resolved first, so the result is the directory of the real
    file rather than of any link that points to it.
    """
    module_path = os.path.abspath(os.path.realpath(__file__))
    directory, _filename = os.path.split(module_path)
    return directory


def parse_config(filename):
    """
    Parse ``config/<filename>.yaml`` (next to this module) into a Config class.

    Args:
        filename: name of the config file, without the ``.yaml`` extension,
            inside the ``config`` directory returned by ``get_root_dir()``.

    Returns:
        A new class named ``Config`` whose class attributes are the top-level
        keys of the YAML mapping.  Note that it is a class built with
        ``type(...)``, not an instance.  An empty YAML file yields a Config
        with no attributes.

    Raises:
        FileNotFoundError: if the config file does not exist.
        TypeError: if the top-level YAML document is not a mapping.
    """
    filepath = os.path.join(get_root_dir(), "config", f"{filename}.yaml")
    # Explicit raise rather than `assert`: asserts are stripped under `python -O`.
    if not os.path.isfile(filepath):
        raise FileNotFoundError(
            f"config path {filepath} does not exist. Please pass in the name "
            f"(without the .yaml extension) of a config file in the 'config' directory."
        )
    with open(filepath, 'r', encoding="utf-8") as f:
        config_data = yaml.load(f, Loader=yaml.FullLoader)

    # An empty YAML document loads as None; treat it as an empty config.
    if config_data is None:
        config_data = {}
    if not isinstance(config_data, dict):
        raise TypeError(
            f"config file {filepath} must contain a YAML mapping at the top level, "
            f"got {type(config_data).__name__}"
        )
    return type("Config", (object,), config_data)
    
    
def generate_train_data(num_f=1000, num_b=50, batch_size_f=100):
    """
    Build the training set on the unit square [0, 1]^2.

    Args:
        num_f: number of interior collocation points, drawn with pyDOE's
            Latin-hypercube sampler ``lhs(2, num_f)``.
        num_b: number of evenly spaced points per edge of the square.
        batch_size_f: batch size of the shuffled loader over the interior points.

    Returns:
        dataloader_f: shuffled DataLoader over a TensorDataset wrapping x_f.
        x_f: (num_f, 2) float32 interior points.
        x_b: (4 * num_b, 2) float32 boundary points.  The edges come in the
            order y=0, x=0, y=1, x=1.  Each corner appears twice, once per
            adjacent edge.
        y_b: (4 * num_b, 1) values of sin(pi*x) * sin(pi*y) at x_b.
    """
    grid = np.linspace(0, 1, num_b)
    gx, gy = np.meshgrid(grid, grid)
    edges = (
        (gx[0:1, :].T, gy[0:1, :].T),  # first grid row: y = 0
        (gx[:, 0:1], gy[:, 0:1]),      # first grid column: x = 0
        (gx[-1:, :].T, gy[-1:, :].T),  # last grid row: y = 1
        (gx[:, -1:], gy[:, -1:]),      # last grid column: x = 1
    )
    boundary_pts = torch.tensor(np.vstack([np.hstack(edge) for edge in edges]), dtype=torch.float32)
    sin_x = torch.sin(torch.pi * boundary_pts[:, 0:1])
    sin_y = torch.sin(torch.pi * boundary_pts[:, 1:2])
    boundary_vals = sin_x * sin_y

    interior_pts = torch.tensor(lhs(2, num_f), dtype=torch.float32)
    loader = DataLoader(TensorDataset(interior_pts), batch_size=batch_size_f, shuffle=True)
    return loader, interior_pts, boundary_pts, boundary_vals


def generate_test_data(num=100):
    """
    Build a uniform ``num x num`` evaluation grid on [0, 1]^2.

    Args:
        num: number of grid points along each axis.

    Returns:
        x_test: (num * num, 2) float32 grid points, with x varying fastest
            (row-major flattening of the meshgrid).
        y_test: (num * num, 1) values of sin(pi*x) * sin(pi*y) at x_test.
    """
    axis = np.linspace(0, 1, num)
    grid_x, grid_y = np.meshgrid(axis, axis)
    points = torch.tensor(np.column_stack((grid_x.ravel(), grid_y.ravel())), dtype=torch.float32)
    sin_x = torch.sin(torch.pi * points[:, 0:1])
    sin_y = torch.sin(torch.pi * points[:, 1:2])
    return points, sin_x * sin_y


def calculate_residuals(model, x_f, x_b, y_b):
    """
    Calculate the interior (PDE) and boundary residuals of ``model``.

    The interior residual is that of the Poisson problem
    ``-(u_xx + u_yy) = f`` with ``f = 2*pi^2 * sin(pi*x) * sin(pi*y)``.
    The boundary residual is ``model(x_b) - y_b``.

    Args:
        model: callable mapping an (N, 2) tensor of points to an (N, 1)
            prediction (assumed from the column slicing below -- confirm).
        x_f: (N_f, 2) interior collocation points.  The caller's tensor is
            not modified.  If it does not already require grad, a detached
            alias that does is used for differentiation.
        x_b: (N_b, 2) boundary points.
        y_b: (N_b, 1) target values at ``x_b``.

    Returns:
        residual_all: interior residuals scaled by ``1/sqrt(N_f)``, stacked
            above boundary residuals scaled by ``1/sqrt(N_b)``.  Its sum of
            squares therefore equals ``mean(sk**2) + mean(hk**2)``.
        loss_all: the same residuals stacked without scaling.
    """
    f_xy = 2 * torch.pi ** 2 * torch.sin(torch.pi * x_f[:, 0:1]) * torch.sin(torch.pi * x_f[:, 1:2])
    y_b_pred = model(x_b)

    # Differentiate w.r.t. a local leaf instead of flipping requires_grad on the
    # caller's tensor in place.  The in-place flip would also make later backward()
    # calls accumulate gradients into the caller's x_f.grad.
    if not x_f.requires_grad:
        x_f = x_f.detach().requires_grad_(True)

    # Evaluate the model once and reuse the output for grad_outputs.
    u_pred = model(x_f)
    grad_u = torch.autograd.grad(u_pred, x_f, grad_outputs=torch.ones_like(u_pred), create_graph=True)[0]
    u_x = grad_u[:, 0:1]
    u_y = grad_u[:, 1:2]
    # create_graph=True keeps the second derivatives differentiable w.r.t. the model parameters.
    u_xx = torch.autograd.grad(u_x, x_f, grad_outputs=torch.ones_like(u_x), create_graph=True)[0][:, 0:1]
    u_yy = torch.autograd.grad(u_y, x_f, grad_outputs=torch.ones_like(u_y), create_graph=True)[0][:, 1:2]

    sk = - u_xx - u_yy - f_xy
    hk = y_b_pred - y_b

    # Python-float scale factors work on any device and dtype without allocating a tensor.
    residual_s = sk * (1.0 / x_f.size(0)) ** 0.5
    residual_h = hk * (1.0 / x_b.size(0)) ** 0.5
    residual_all = torch.cat([residual_s, residual_h], dim=0)

    loss_all = torch.cat([sk, hk], dim=0)
    return residual_all, loss_all