import math

import torch


def sampled_pts_rdm(N_domain, N_boundary, domain, time_dependent=False):
    """Sample collocation points uniformly at random in a 2D rectangle.

    Args:
        N_domain: number of points drawn inside the rectangle.
        N_boundary: requested number of boundary points. They are split evenly
            over the sampled faces, so the count is rounded down to a multiple
            of the number of faces (4, or 3 when ``time_dependent``).
        domain: bounds ``((x1l, x1r), (x2l, x2r))`` of the two coordinates.
        time_dependent: if True, the second coordinate is treated as time and
            only the faces ``x1 = x1r``, ``x2 = x2l`` and ``x1 = x1l`` are
            sampled (no points on the final-time face ``x2 = x2r``).

    Returns:
        X_domain: ``(N_domain, 2)`` tensor of interior samples.
        X_boundary: ``(F * (N_boundary // F), 2)`` tensor with F faces stacked
            in the order bottom, left, top, right (or right, bottom, left when
            ``time_dependent``).
    """
    x1l, x1r = domain[0]
    x2l, x2r = domain[1]
    bounds = ((x1l, x1r), (x2l, x2r))

    def uniform(size, lo, hi):
        # Affinely map torch.rand samples onto the interval between lo and hi.
        return torch.rand(size) * (hi - lo) + lo

    # Interior samples: the x1 column is drawn before the x2 column.
    X_domain = torch.cat((uniform((N_domain, 1), x1l, x1r),
                          uniform((N_domain, 1), x2l, x2r)), dim=1)

    # Each face is (index of the pinned coordinate, value it is pinned to).
    # Faces are filled, and random numbers consumed, in the listed order.
    if time_dependent:
        # (x, t): right face, initial-time face, left face.
        faces = ((0, x1r), (1, x2l), (0, x1l))
    else:
        # (x, y): bottom, left, top, right.
        faces = ((1, x2l), (0, x1l), (1, x2r), (0, x1r))

    per_face = N_boundary // len(faces)
    X_boundary = torch.zeros((per_face * len(faces), 2))
    for k, (pinned, value) in enumerate(faces):
        free = 1 - pinned
        rows = slice(k * per_face, (k + 1) * per_face)
        X_boundary[rows, pinned] = value
        X_boundary[rows, free] = uniform(per_face, *bounds[free])

    return X_domain, X_boundary

def sampled_pts_grid(N_domain, N_boundary, domain, time_dependent = False):
    """Place collocation points on a uniform tensor-product grid over a 2D rectangle.

    Only the sum ``N_domain + N_boundary`` is used: the grid has
    ``n = floor(sqrt(N_domain + N_boundary))`` nodes per axis (boundary nodes
    included), so the returned counts generally differ from the requested ones.

    Args:
        N_domain: requested number of interior points (see above).
        N_boundary: requested number of boundary points (see above).
        domain: bounds ``((x1l, x1r), (x2l, x2r))`` of the two coordinates.
        time_dependent: if True, the second coordinate is time. The final-time
            row ``x2 = x2r`` is then returned with the interior nodes, and the
            boundary consists of the faces ``x1 = x1r``, ``x2 = x2l`` and
            ``x1 = x1l``.

    Returns:
        X_domain: ``(M, 2)`` tensor of interior grid nodes, with
            ``M = (n-2)**2`` (or ``(n-2)*(n-1)`` when ``time_dependent``).
        X_boundary: ``(K, 2)`` tensor of boundary grid nodes, each node listed
            exactly once, with ``K = 4*(n-1)`` (or ``3*n-2`` when
            ``time_dependent``).
    """
    x1l, x1r = domain[0]
    x2l, x2r = domain[1]

    # Exact integer square root: a float32 torch.sqrt can round up for large
    # totals (e.g. k**2 - 1) and produce a grid that is one node too wide.
    N_pts = math.isqrt(int(N_domain + N_boundary)) - 2
    xx = torch.linspace(x1l, x1r, N_pts + 2)
    yy = torch.linspace(x2l, x2r, N_pts + 2)
    # 'ij' indexing: XX[i, j] = xx[i], YY[i, j] = yy[j].
    XX, YY = torch.meshgrid(xx, yy, indexing='ij')

    if not time_dependent:

        XX_int = XX[1:-1, 1:-1]
        YY_int = YY[1:-1, 1:-1]

        # Edges in the order bottom, left, top, right. Each edge keeps one
        # end corner and drops the other, so all four corners appear exactly
        # once (previously (x1l, x2l) was duplicated and (x1r, x2r) missing).
        XXv_bd = torch.cat((XX[:-1, 0], XX[0, 1:], XX[1:, -1], XX[-1, :-1]))
        YYv_bd = torch.cat((YY[:-1, 0], YY[0, 1:], YY[1:, -1], YY[-1, :-1]))

    else:

        # Interior in x, and every time level after the initial one.
        XX_int = XX[1:-1, 1:]
        YY_int = YY[1:-1, 1:]

        # Right face (t > x2l), initial-time row (all x), left face (t > x2l).
        XXv_bd = torch.cat((XX[-1, 1:], XX[:, 0], XX[0, 1:]))
        YYv_bd = torch.cat((YY[-1, 1:], YY[:, 0], YY[0, 1:]))

    # Vectorized (x, y) coordinates as column vectors.
    XXv_int = XX_int.flatten().unsqueeze(1)
    YYv_int = YY_int.flatten().unsqueeze(1)

    XXv_bd = XXv_bd.unsqueeze(1)
    YYv_bd = YYv_bd.unsqueeze(1)

    X_domain = torch.cat((XXv_int, YYv_int), dim=1)
    X_boundary = torch.cat((XXv_bd, YYv_bd), dim=1)
    return X_domain, X_boundary