import os
import yaml
import random
import torch
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
# Compute device shared by the helpers in this module: the GPU when CUDA is
# available, otherwise the CPU. Assign 'cpu' here to force CPU execution.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def set_random_seed(seed=42):
    """
    Seed every random number generator used here so that runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator, the current CUDA device and all CUDA devices). It also
    configures cuDNN to pick deterministic kernels and disables its autotuner.

    Args:
        seed: integer seed applied to every generator (default 42).
    """
    # Python and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch: CPU generator, current CUDA device, then every CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN autotuning may choose different kernels between runs; turn it off
    # and request deterministic algorithms instead.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_root_dir():
    """
    Return the absolute directory that contains this source file.

    Symbolic links are resolved first, so the result is the directory of the
    real file rather than of a link pointing to it.
    """
    real_path = os.path.realpath(__file__)
    absolute_path = os.path.abspath(real_path)
    return os.path.dirname(absolute_path)


def parse_config(filename):
    """
    Load ``<root>/config/<filename>.yaml`` and expose its top-level keys as attributes.

    ``<root>`` is the directory returned by ``get_root_dir()``.

    Args:
        filename: name of the YAML file inside the ``config`` directory,
            without the ``.yaml`` extension.

    Returns:
        A dynamically created class named ``Config``. Its class attributes are
        the top-level keys of the YAML mapping. An empty YAML file yields a
        class with no extra attributes.

    Raises:
        FileNotFoundError: if the config file does not exist.
        TypeError: if the top level of the YAML document is not a mapping.
    """
    filepath = os.path.join(get_root_dir(), "config", f"{filename}.yaml")
    # Raise explicitly: an ``assert`` would be stripped under ``python -O``.
    if not os.path.isfile(filepath):
        raise FileNotFoundError(
            f"config path {filepath} does not exist. Please pass the name of a "
            "YAML file (without the .yaml extension) located in the config directory."
        )
    with open(filepath, 'r', encoding="utf-8") as f:
        # NOTE(review): FullLoader accepts some Python-specific tags that
        # yaml.safe_load rejects; prefer safe_load if config files could come
        # from untrusted sources.
        config_data = yaml.load(f, Loader=yaml.FullLoader)

    # yaml.load returns None for an empty document; treat that as "no settings".
    if config_data is None:
        config_data = {}
    # type() needs a dict namespace; give a clear error for lists/scalars.
    if not isinstance(config_data, dict):
        raise TypeError(
            f"config file {filepath} must contain a YAML mapping at the top level, "
            f"got {type(config_data).__name__}"
        )
    return type("Config", (object,), config_data)


def exact_u(x, dim=10):
    """
    Evaluate the exact solution u(x) = sum_{i < dim} sin(pi * x_i) row-wise.

    Args:
        x: 2-D tensor with one point per row. Only the first ``dim`` columns
            are used.
        dim: number of leading coordinates to sum over. The default of 10 is
            the dimension used throughout this module.

    Returns:
        Tensor of shape ``(x.shape[0], 1)``, on the same device and with the
        same dtype as ``x``.

    Raises:
        ValueError: if ``x`` has fewer than ``dim`` columns.
    """
    if x.shape[1] < dim:
        raise ValueError(f"exact_u expects at least {dim} columns, got {x.shape[1]}")
    # Vectorised over the columns. The result inherits x's device and dtype
    # instead of being forced onto a fixed device, so CPU inputs work on CUDA
    # machines and float64 inputs are not downcast.
    return torch.sin(torch.pi * x[:, :dim]).sum(dim=1, keepdim=True)


def exact_g(x):
    """
    Right-hand side g(x) of -Laplace(u) = g for the exact solution ``exact_u``.

    Every term sin(pi * x_i) of ``exact_u`` has second derivative
    -pi**2 * sin(pi * x_i), so -Laplace(u) = pi**2 * u.
    """
    laplace_factor = torch.pi**2
    return laplace_factor * exact_u(x)


def generate_data(num_f=10000, num_b=1000, batch_size_f=100, num_test=10000):
    """
    Generate training and test data on the unit cube [0, 1]^10.

    Args:
        num_f: number of interior collocation points, sampled uniformly.
        num_b: number of boundary points. Rows are split as evenly as possible
            across the 10 coordinates. For coordinate ``i``, the first half of
            its rows is placed on the face ``x_i = 0`` and the second half on
            ``x_i = 1``.
        batch_size_f: mini-batch size of the interior-point DataLoader.
        num_test: number of uniformly sampled test points.

    Returns:
        ``(dataloader_f, x_f, x_bound, y_bound, x_test, y_test)``, where:

        - ``dataloader_f`` shuffles ``x_f`` in batches of ``batch_size_f``.
        - ``x_f`` stays on the CPU; the other tensors are on ``device``.
        - ``y_bound`` and ``y_test`` are ``exact_u`` evaluated at
          ``x_bound`` and ``x_test``.
    """
    x_f = torch.rand(num_f, 10)

    # Create training dataset and dataloader
    dataset_f = TensorDataset(x_f)
    dataloader_f = DataLoader(dataset_f, batch_size=batch_size_f, shuffle=True)

    x_bound = torch.rand(num_b, 10).to(device)
    # Partition the rows [0, num_b) into 10 contiguous groups, one per
    # coordinate. Each group is split in half and its coordinate pinned to 0
    # or 1. The groups tile all rows, so every row lies on the boundary for
    # any num_b. For num_b=1000 this reproduces blocks of 100 rows split 50/50.
    for i in range(10):
        start = i * num_b // 10
        stop = (i + 1) * num_b // 10
        mid = (start + stop) // 2
        x_bound[start:mid, i] = 0.0
        x_bound[mid:stop, i] = 1.0
    y_bound = exact_u(x_bound)

    x_test = torch.rand(num_test, 10).to(device)
    y_test = exact_u(x_test).to(device)

    return dataloader_f, x_f, x_bound, y_bound, x_test, y_test


def calculate_residuals(model, x_f, x_b, y_b):
    """
    Calculate the PDE residual on ``x_f`` and the boundary residual on ``x_b``.

    Args:
        model: callable mapping points to predictions. Presumably an
            ``nn.Module`` returning shape ``(N, 1)`` — TODO confirm.
        x_f: interior collocation points with (at least) 10 columns. NOTE:
            ``requires_grad`` is switched on for this tensor in place, because
            second derivatives with respect to it are needed.
        x_b: boundary points.
        y_b: target values of u at ``x_b``.

    Returns:
        ``(residual_all, loss_all)``:

        - ``residual_all`` is ``[sk / sqrt(len(x_f)); hk / sqrt(len(x_b))]``
          concatenated along dim 0, so its sum of squares equals
          ``mean(sk**2) + mean(hk**2)``.
        - ``loss_all`` is the unscaled ``[sk; hk]``.

        Here ``sk = -Laplace(u_pred) - g`` on ``x_f`` and
        ``hk = u_pred - y_b`` on ``x_b``.
    """
    x_f.requires_grad_(True)
    # Single forward pass: the same output is differentiated and used to shape
    # grad_outputs.
    u_pred = model(x_f)
    grad_u = torch.autograd.grad(
        u_pred, x_f, grad_outputs=torch.ones_like(u_pred), create_graph=True
    )[0]

    # Laplacian = sum of the pure second derivatives d^2u/dx_i^2 over the 10
    # coordinates. create_graph=True keeps the graph so the loss can be
    # back-propagated into the model parameters.
    second_derivs = []
    for idx in range(10):
        u_x_i = grad_u[:, idx].unsqueeze(-1)
        u_xx_i = torch.autograd.grad(
            u_x_i, x_f, grad_outputs=torch.ones_like(u_x_i), create_graph=True
        )[0][:, idx].unsqueeze(-1)
        second_derivs.append(u_xx_i)
    # Accumulated in the model's own dtype (no forced float32 buffer).
    u_xx = sum(second_derivs)

    g = exact_g(x_f).to(device)
    y_b_pred = model(x_b)

    sk = -u_xx - g
    hk = y_b_pred - y_b

    # 1/sqrt(N) weights, as plain Python floats: no per-call CPU tensor is
    # created, and the tensor's precision (e.g. float64) is preserved.
    residual_s = sk * (1.0 / x_f.size(0)) ** 0.5
    residual_h = hk * (1.0 / x_b.size(0)) ** 0.5
    residual_all = torch.cat([residual_s, residual_h], dim=0)

    loss_all = torch.cat([sk, hk], dim=0)
    return residual_all, loss_all
