import os
import yaml
import random
import torch
import numpy as np
from pyDOE import lhs
from torch.utils.data import TensorDataset, DataLoader


def set_random_seed(seed=42):
    """
    Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    with the same value, and put cuDNN into deterministic mode.

    Args:
        seed: Integer seed shared by every generator (default 42).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_root_dir():
    """
    Return the absolute path of the directory that holds this module,
    with any symbolic links in the module's path resolved.
    """
    resolved_file = os.path.realpath(__file__)
    absolute_file = os.path.abspath(resolved_file)
    return os.path.dirname(absolute_file)


def parse_config(filename):
    """
    Load ``<root>/config/<filename>.yaml`` and expose its keys as attributes.

    ``<root>`` is the directory returned by ``get_root_dir()``.

    Args:
        filename: Name of the config without the ``.yaml`` extension, e.g.
            ``"default"`` resolves to ``<root>/config/default.yaml``.

    Returns:
        A newly created class named ``Config`` whose class attributes are the
        top-level keys of the YAML mapping (read them as ``cfg.key``). An
        empty YAML file yields a ``Config`` with no extra attributes.

    Raises:
        FileNotFoundError: If the resolved path is not an existing file.
        TypeError: If the YAML document's top level is not a mapping.
    """
    config_dir = os.path.join(get_root_dir(), "config")
    filepath = os.path.join(config_dir, f"{filename}.yaml")
    # Raise explicitly instead of using ``assert``: assertions are removed when
    # Python runs with -O, which would surface as a confusing open() error.
    if not os.path.isfile(filepath):
        raise FileNotFoundError(
            f"config path {filepath} does not exist. Please pass in the name "
            f"(without the .yaml extension) of a config file in {config_dir}."
        )

    with open(filepath, "r", encoding="utf-8") as f:
        # FullLoader is kept so existing config files keep loading unchanged.
        # NOTE(review): prefer yaml.safe_load if configs could ever come from
        # an untrusted source.
        config_data = yaml.load(f, Loader=yaml.FullLoader)

    # An empty YAML document parses to None; treat it as an empty config.
    if config_data is None:
        config_data = {}
    if not isinstance(config_data, dict):
        raise TypeError(
            f"config file {filepath} must contain a YAML mapping at the top "
            f"level, got {type(config_data).__name__}."
        )
    return type("Config", (object,), config_data)


def generate_train_data(num_f=1000, num_b=50, batch_size_f=100):
    """
    Build the boundary/initial training set and the collocation set on [0, 1]^2.

    Boundary points (``num_b`` per edge, taken from an evenly spaced grid):
      * edge where column 0 == 0: target ``sin(pi * column1)``;
      * edge where column 1 == 0: target 0;
      * edge where column 1 == 1: target 0.

    Args:
        num_f: Number of collocation points drawn with ``lhs(2, num_f)``.
        num_b: Number of grid nodes per boundary edge.
        batch_size_f: Mini-batch size for the shuffled collocation loader.

    Returns:
        (dataloader_f, x_f, x_b, y_b): a shuffled DataLoader over ``x_f``, the
        full float32 collocation tensor, the stacked float32 boundary points
        and their float32 targets of shape (len(x_b), 1).
    """
    axis = np.linspace(0, 1, num_b)
    grid_t, grid_x = np.meshgrid(axis, axis)

    # Edge column0 = 0 (initial condition) with target sin(pi * column1).
    init_pts = np.concatenate((grid_t[:, :1], grid_x[:, :1]), axis=1)
    init_vals = np.sin(np.pi * init_pts[:, 1:2])

    # Edges column1 = 0 and column1 = 1, both with target 0.
    t_edge = grid_t[:1, :].T
    wall_lo = np.concatenate((t_edge, grid_x[:1, :].T), axis=1)
    wall_hi = np.concatenate((t_edge, grid_x[-1:, :].T), axis=1)
    wall_lo_vals = np.zeros((wall_lo.shape[0], 1))
    wall_hi_vals = np.zeros((wall_hi.shape[0], 1))

    x_b = torch.tensor(
        np.concatenate((init_pts, wall_lo, wall_hi), axis=0),
        dtype=torch.float32,
    )
    y_b = torch.tensor(
        np.concatenate((init_vals, wall_lo_vals, wall_hi_vals), axis=0),
        dtype=torch.float32,
    )

    # Latin-hypercube collocation points in the unit square.
    x_f = torch.tensor(lhs(2, num_f), dtype=torch.float32)
    dataloader_f = DataLoader(
        TensorDataset(x_f), batch_size=batch_size_f, shuffle=True
    )
    return dataloader_f, x_f, x_b, y_b


def generate_test_data(num=100):
    """
    Build a uniform evaluation grid on [0, 1]^2 together with its reference values.

    Args:
        num: Number of nodes per axis; the grid holds ``num ** 2`` points.

    Returns:
        x_test: float32 tensor of shape (num**2, 2).
        y_test: float32 tensor of shape (num**2, 1) equal to
            exp(-pi^2 * column0 / 4) * sin(pi * column1).
    """
    axis = np.linspace(0, 1, num)
    grid_t, grid_x = np.meshgrid(axis, axis)
    points = np.stack((grid_t.ravel(), grid_x.ravel()), axis=1)
    x_test = torch.tensor(points, dtype=torch.float32)

    t, x = x_test[:, 0:1], x_test[:, 1:2]
    decay = torch.exp(-torch.pi ** 2 * t / 4)
    y_test = decay * torch.sin(torch.pi * x)
    return x_test, y_test


def calculate_residuals(model, x_f, x_b, y_b):
    """
    Calculate the PDE and boundary residuals for u_t = u_xx / 4.

    Column 0 of ``x_f`` is differentiated as t and column 1 as x.

    Args:
        model: Callable applied to ``x_b`` and ``x_f``; presumably returns an
            (N, 1) tensor — TODO confirm against the model definition.
        x_f: Collocation points where the PDE residual is evaluated. The
            caller's tensor is not modified; if it does not already require
            grad, a detached alias is used for differentiation.
        x_b: Boundary / initial-condition points.
        y_b: Target values at ``x_b``.

    Returns:
        residual_all: PDE residuals scaled by 1/sqrt(len(x_f)) stacked above
            boundary residuals scaled by 1/sqrt(len(x_b)). Its sum of squares
            equals mean(sk**2) + mean(hk**2).
        loss_all: The same residuals unscaled, in the same order.
    """
    y_b_pred = model(x_b)

    # Differentiate w.r.t. the inputs without flipping requires_grad on the
    # caller's tensor. detach() returns a new alias sharing storage, so the
    # flag is set on the alias only. Inputs that already require grad are
    # used as-is so any upstream graph is preserved.
    if not x_f.requires_grad:
        x_f = x_f.detach().requires_grad_(True)

    # One forward pass, reused both as the output and to shape grad_outputs.
    u = model(x_f)
    grads = torch.autograd.grad(
        u, x_f, grad_outputs=torch.ones_like(u), create_graph=True
    )[0]
    u_t = grads[:, 0:1]
    u_x = grads[:, 1:2]
    u_xx = torch.autograd.grad(
        u_x, x_f, grad_outputs=torch.ones_like(u_x), create_graph=True
    )[0][:, 1:2]

    sk = u_t - u_xx / 4
    hk = y_b_pred - y_b

    # Python scalars keep the result on the residuals' own device/dtype.
    residual_s = sk * x_f.size(0) ** -0.5
    residual_h = hk * x_b.size(0) ** -0.5
    residual_all = torch.cat([residual_s, residual_h], dim=0)

    loss_all = torch.cat([sk, hk], dim=0)
    return residual_all, loss_all
