import os
import yaml
import random
import torch
import numpy as np
from pyDOE import lhs
from torch.utils.data import TensorDataset, DataLoader

# Wavenumber shared by the reference solution sin(k0 * x) and the forcing
# term k0**2 * sin(k0 * x) used when computing residuals.
k0 = 4.0

def set_random_seed(seed=42):
    """
    Seed every random number generator used here for reproducible runs.

    Seeds PyTorch (CPU, the current CUDA device and all CUDA devices), NumPy's
    global generator and Python's ``random`` module, then turns off cuDNN
    benchmarking and requests deterministic cuDNN algorithms.

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    # PyTorch generators: CPU, current CUDA device, then every CUDA device.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Global generators of NumPy and the standard library.
    np.random.seed(seed)
    random.seed(seed)

    # Autotuning picks kernels at runtime, so disable it and ask for determinism.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True


def get_root_dir():
    """
    Return the absolute directory that contains this source file.

    Symbolic links in the path to this file are resolved first, so the result
    is the directory of the real file on disk.
    """
    real_file = os.path.abspath(os.path.realpath(__file__))
    return os.path.dirname(real_file)


def parse_config(filename):
    """
    Parse a YAML config file and return a Config class.

    The file is looked up as ``<root>/config/<filename>.yaml``, where ``<root>``
    is the directory returned by ``get_root_dir()``.

    Args:
        filename: Config name inside the ``config`` directory, without the
            ``.yaml`` extension.

    Returns:
        A newly created class named ``Config`` whose class attributes are the
        top-level keys of the YAML mapping. An empty YAML file yields a
        ``Config`` class with no extra attributes.

    Raises:
        FileNotFoundError: If the config file does not exist.
        TypeError: If the top level of the YAML document is not a mapping.
    """
    filepath = os.path.join(get_root_dir(), "config", f"{filename}.yaml")
    # Explicit check instead of `assert`, which is stripped under `python -O`
    # (in which case `open` would raise FileNotFoundError anyway).
    if not os.path.isfile(filepath):
        raise FileNotFoundError(
            f"config path {filepath} does not exist. Please pass in the name of a "
            f"YAML file in the config directory, without the .yaml extension."
        )
    with open(filepath, 'r', encoding="utf-8") as f:
        # NOTE: FullLoader can construct some Python-specific types; only load
        # trusted config files (yaml.safe_load is the stricter alternative).
        config_data = yaml.load(f, Loader=yaml.FullLoader)

    # An empty YAML document parses to None; treat it as an empty config.
    if config_data is None:
        config_data = {}
    # type() needs a dict namespace; give a clear error instead of a cryptic one.
    if not isinstance(config_data, dict):
        raise TypeError(
            f"config file {filepath} must contain a YAML mapping at the top level, "
            f"got {type(config_data).__name__}"
        )

    return type("Config", (object,), config_data)
    
    
def generate_train_data(num_f=500, batch_size_f=100):
    """
    Generate collocation and boundary training data on the interval [0, pi].

    Collocation points come from a one-dimensional Latin hypercube design
    (``lhs``) multiplied by pi. Assuming ``lhs`` samples the unit interval,
    they lie in [0, pi]. The boundary points are the two interval endpoints,
    each with a zero target value.

    Args:
        num_f: Number of collocation points to sample.
        batch_size_f: Batch size of the shuffled collocation-point loader.

    Returns:
        A tuple ``(dataloader_f, x_f, x_b, y_b)``:
        - a shuffled ``DataLoader`` over a ``TensorDataset`` of ``x_f``;
        - the float32 collocation points (presumably shape (num_f, 1));
        - the boundary points ``[[0.0], [pi]]`` with shape (2, 1);
        - zeros shaped like the boundary points.
    """
    # Boundary data: both endpoints, with zero targets.
    boundary_x = torch.tensor([[0.0], [torch.pi]])
    boundary_y = torch.zeros_like(boundary_x)

    # Collocation data: stretch the Latin hypercube samples by pi.
    design = lhs(1, num_f)
    collocation_x = torch.tensor(design, dtype=torch.float32) * torch.pi

    loader = DataLoader(TensorDataset(collocation_x), batch_size=batch_size_f, shuffle=True)
    return loader, collocation_x, boundary_x, boundary_y


def generate_test_data(num=5000):
    """
    Generate an evaluation grid and the reference solution on it.

    Args:
        num: Number of evenly spaced grid points on [0, pi], endpoints included.

    Returns:
        x_test: Grid points as a column tensor of shape (num, 1).
        y_test: ``sin(k0 * x_test)``, with the same shape as ``x_test``.
    """
    grid = torch.linspace(0, torch.pi, num)[:, None]
    return grid, torch.sin(k0 * grid)
    
    
def calculate_residuals(model, x_f, x_b, y_b):
    """
    Calculate PDE and boundary residuals for -u'' = k0**2 * sin(k0 * x).

    Args:
        model: Callable network evaluated on ``x_f`` and ``x_b``. Its output is
            assumed to have one column per point, e.g. shape (N, 1) — TODO confirm.
        x_f: Collocation points. Derivatives are taken with respect to column 0.
            The caller's tensor is not modified.
        x_b: Boundary points.
        y_b: Target values of ``model`` at ``x_b``.

    Returns:
        residual_all: PDE residuals scaled by ``sqrt(1 / x_f.size(0))``
            concatenated (along dim 0) with boundary residuals scaled by
            ``sqrt(1 / x_b.size(0))``. Assuming one residual per point, its
            sum of squares is mean(sk**2) + mean(hk**2).
        loss_all: The same PDE and boundary residuals, unscaled.
    """
    # Forcing term, computed from the points before gradient tracking is enabled.
    f_xy = k0 ** 2 * torch.sin(k0 * x_f)
    y_b_pred = model(x_b)

    # Differentiating w.r.t. x_f needs it to track gradients. Rather than
    # flipping requires_grad on the caller's tensor in place, use a detached
    # alias as a fresh leaf; a tensor that already requires grad is used as is.
    if not x_f.requires_grad:
        x_f = x_f.detach().requires_grad_(True)

    # A single forward pass, reused both as the output and to shape grad_outputs.
    u_pred = model(x_f)
    u_x = torch.autograd.grad(
        u_pred, x_f, grad_outputs=torch.ones_like(u_pred), create_graph=True
    )[0][:, 0].unsqueeze(-1)
    # create_graph=True keeps u_xx differentiable for the backward pass.
    u_xx = torch.autograd.grad(
        u_x, x_f, grad_outputs=torch.ones_like(u_x), create_graph=True
    )[0][:, 0].unsqueeze(-1)

    sk = - u_xx - f_xy
    hk = y_b_pred - y_b

    # Plain Python scalars: no per-call tensor allocation, and the weight is
    # not rounded to float32 when the model runs in higher precision.
    residual_s = (1.0 / x_f.size(0)) ** 0.5 * sk
    residual_h = (1.0 / x_b.size(0)) ** 0.5 * hk
    residual_all = torch.cat([residual_s, residual_h], dim=0)

    loss_all = torch.cat([sk, hk], dim=0)
    return residual_all, loss_all