import torch
from torch_geometric.data import Data


def _block_bounds(num_nodes):
    """Return the block boundaries ``[q0, q1, q2, q3, q4]`` for ``num_nodes`` nodes.

    The cut points are the 5:5 layout 0 / 2750 / 5000 / 7750 / 10000 (out of
    10000) scaled to ``num_nodes`` with integer arithmetic, so
    ``num_nodes == 10000`` yields exactly ``[0, 2750, 5000, 7750, 10000]`` and
    ``q4`` always equals ``num_nodes``.
    """
    # 5:5 layout (cuts out of 10000). An alternative 2:8 layout used
    # (0, 560, 2000, 8560, 10000).
    cuts = (0, 2750, 5000, 7750, 10000)
    return [num_nodes * cut // 10000 for cut in cuts]


def _sample_block_edges(row_start, row_end, col_start, col_end, degree,
                        diagonal, device):
    """Let every node in [row_start, row_end) pick ``degree`` distinct nodes
    from [col_start, col_end), uniformly and without replacement.

    Returns ``(sources, targets)`` as flat int64 tensors of global node ids:
    ``targets[k]`` is the picking node and ``sources[k]`` the node it picked.
    When ``diagonal`` is true (row range == column range) a node cannot pick
    itself.
    """
    # Uniform weights; this dense matrix exists only as torch.multinomial
    # input and is freed when the helper returns.
    prob = torch.ones((row_end - row_start, col_end - col_start), device=device)
    if diagonal:
        prob.fill_diagonal_(0)  # zero probability -> no self-loops
    idx = torch.multinomial(prob, degree, replacement=False)  # (rows, degree)
    # Row-major flattening: entry (i, k) -> position i * degree + k.
    targets = torch.arange(row_start, row_end, device=idx.device).repeat_interleave(degree)
    sources = idx.reshape(-1) + col_start
    return sources, targets


def csbms_graph(args):
    """Build a four-block contextual SBM graph with label ``y`` and sensitive attribute ``s``.

    Nodes occupy four contiguous blocks (bounds from ``_block_bounds``, scaled
    to ``args.num_nodes``; identical to the original hard-coded layout when
    ``num_nodes == 10000``)::

        block 0  [q0, q1)  y=0, s=0
        block 1  [q1, q2)  y=0, s=1
        block 2  [q2, q3)  y=1, s=1
        block 3  [q3, q4)  y=1, s=0

    Every node picks, uniformly without replacement, a fixed number of
    distinct neighbours from each block:

    * ``d_pp`` from its own block (same y, same s; itself excluded)
    * ``d_pn`` from the block with the same y and the other s
    * ``d_nn`` from the block with the other y and the other s
    * ``d_np`` from the block with the other y and the same s

    Args:
        args: namespace-like object; reads ``num_nodes``, ``sigma``,
            ``device``, ``d_pp``, ``d_pn``, ``d_nn``, ``d_np``, ``mu_y`` and
            ``mu_s``.

    Returns:
        ``torch_geometric.data.Data`` with

        * ``x``: ``(num_nodes, 2)`` float. Column 0 is ``N(0, 1) - mu_y`` for
          y=0 and ``N(0, sigma) + mu_y`` for y=1 (sigma is the std). Column 1
          is ``N(0, 1) + mu_s`` for s=1 and ``N(0, 1) - mu_s`` for s=0.
        * ``y``: ``(num_nodes,)`` long, created on the default device (it is
          not moved to ``args.device``).
        * ``s``: ``(num_nodes,)`` long on ``args.device``.
        * ``edge_index``: ``(2, num_nodes * (d_pp + d_pn + d_nn + d_np))``
          long on ``args.device``. Row 0 is the picked neighbour and row 1 is
          the node that picked it; columns are sorted by (row 0, row 1).

    ``torch.multinomial`` errors if a degree exceeds the candidates available
    in a block (block size, or block size - 1 for the node's own block).
    """
    num_nodes = args.num_nodes
    device = args.device
    bounds = _block_bounds(num_nodes)
    _, q1, q2, q3, _ = bounds

    # Labels: [0, q2) is class 0, the rest class 1.
    # NOTE(review): y stays on the default device while the other tensors are
    # on args.device -- kept as-is for compatibility with existing callers.
    y = torch.zeros((num_nodes,), dtype=torch.long)
    y[q2:] = 1

    # Gaussian noise for the class-dependent feature. Drawn before any edge
    # sampling so the global RNG is consumed in the same order as before.
    x0 = torch.distributions.normal.Normal(0, 1).sample((q2, 1)).to(device)
    x1 = torch.distributions.normal.Normal(0, args.sigma).sample((num_nodes - q2, 1)).to(device)

    # Sensitive attribute: [q1, q3) has s=1, the rest s=0.
    s = torch.zeros((num_nodes,), dtype=torch.long).to(device)
    s[q1:q3] = 1

    sens_x = torch.distributions.normal.Normal(0, 1).sample((num_nodes, 1)).to(device)

    # (row block, column block, degree). The order fixes the sequence of
    # torch.multinomial calls, and therefore the RNG stream; keep it stable.
    block_pairs = (
        (0, 0, args.d_pp), (1, 1, args.d_pp), (2, 2, args.d_pp), (3, 3, args.d_pp),
        (0, 1, args.d_pn), (1, 0, args.d_pn), (2, 3, args.d_pn), (3, 2, args.d_pn),
        (0, 2, args.d_nn), (1, 3, args.d_nn), (2, 0, args.d_nn), (3, 1, args.d_nn),
        (0, 3, args.d_np), (1, 2, args.d_np), (2, 1, args.d_np), (3, 0, args.d_np),
    )
    sources, targets = [], []
    for row_block, col_block, degree in block_pairs:
        src, dst = _sample_block_edges(
            bounds[row_block], bounds[row_block + 1],
            bounds[col_block], bounds[col_block + 1],
            degree, row_block == col_block, device,
        )
        sources.append(src)
        targets.append(dst)

    # Sort edges by (source, target): the row-major order produced by
    # nonzero() on the transposed dense adjacency, without materialising it.
    # Picks are distinct within a block and blocks are disjoint, so keys are
    # unique and the order is fully determined.
    keys = torch.cat(sources) * num_nodes + torch.cat(targets)
    keys, _ = torch.sort(keys)
    edge_index = torch.stack([torch.div(keys, num_nodes, rounding_mode="floor"),
                              keys % num_nodes])

    # Class-dependent feature: shift class 0 by -mu_y and class 1 by +mu_y.
    x = torch.cat([x0 - args.mu_y, x1 + args.mu_y], dim=0)

    # Sensitive-dependent feature: shift s=1 nodes by +mu_s and s=0 by -mu_s.
    s_col = s.view(-1, 1)
    sens_x = sens_x + (s_col == 1).float() * args.mu_s - (s_col == 0).float() * args.mu_s
    x = torch.cat([x, sens_x], dim=1)

    return Data(x=x, y=y, s=s, edge_index=edge_index)
