import torch
from torch_geometric.data import Data


def mask_edge(edge_index, edge_attr, mask, remove_edge):
    """Select or zero out edges according to a boolean mask.

    Args:
        edge_index: tensor whose second dimension indexes the edges
            (``edge_index[:, mask]`` selects columns).
        edge_attr: tensor whose first dimension indexes the same edges.
        mask: boolean tensor with one entry per edge; ``True`` marks edges to keep.
        remove_edge: if true, drop the edges where ``mask`` is false; otherwise
            keep every edge but set the attributes of masked-out edges to 0.

    Returns:
        ``(edge_index, edge_attr)`` after masking. The caller's ``edge_attr``
        is never modified in place.
    """
    if remove_edge:
        # Boolean indexing already produces fresh tensors.
        return edge_index[:, mask], edge_attr[mask]
    # Clone before the in-place write so the caller's tensor is left untouched.
    edge_attr = edge_attr.clone()
    edge_attr[~mask] = 0.
    return edge_index, edge_attr


def create_node(tensor: torch.Tensor, mode: int) -> torch.Tensor:
    """Build initial node features for the bipartite sample/feature graph.

    The first ``nrow`` rows are sample nodes and the next ``ncol`` rows are
    feature nodes.

    Args:
        tensor: 2-D matrix; only its shape and device are used.
        mode: feature scheme.
            0 -> sample nodes are all-ones rows, feature nodes are one-hot
                 (result shape ``(nrow + ncol, ncol)``).
            1 -> an extra leading column flags sample nodes (1) vs feature
                 nodes (0); feature nodes are one-hot in the remaining columns
                 (result shape ``(nrow + ncol, ncol + 1)``).

    Returns:
        Float tensor of node features on ``tensor``'s device.

    Raises:
        ValueError: if ``mode`` is not 0 or 1.
    """
    device = tensor.device
    nrow, ncol = tensor.shape

    if mode == 0:
        sample_node = torch.ones(nrow, ncol, device=device)
        feature_node = torch.eye(ncol, device=device)
    elif mode == 1:
        sample_node = torch.zeros(nrow, ncol + 1, device=device)
        sample_node[:, 0] = 1
        feature_node = torch.cat(
            [torch.zeros(ncol, 1, device=device), torch.eye(ncol, device=device)],
            dim=1,
        )
    else:
        # Previously this fell through and failed with UnboundLocalError.
        raise ValueError(f"Unsupported node mode {mode!r}; expected 0 or 1.")

    return torch.cat([sample_node, feature_node], dim=0)


def create_edge(tensor: torch.Tensor):
    """Return source/target index vectors for the complete bipartite graph.

    Every entry ``(i, j)`` of the 2-D ``tensor`` yields an edge between sample
    node ``i`` and feature node ``nrow + j``. Entries are enumerated in
    row-major order. The first ``nrow*ncol`` edges go sample -> feature and
    the second ``nrow*ncol`` edges are their reverses (feature -> sample).

    Returns:
        ``(sources, targets)``: two int64 tensors of length ``2*nrow*ncol``.
    """
    dev = tensor.device
    num_samples, num_features = tensor.shape

    # Sample endpoints: 0,...,0, 1,...,1, ... (each repeated num_features times).
    sample_ids = torch.arange(num_samples, device=dev, dtype=torch.int64).repeat_interleave(num_features)
    # Feature endpoints: nodes numbered after all samples, cycled once per sample row.
    feature_ids = (torch.arange(num_features, device=dev, dtype=torch.int64) + num_samples).repeat(num_samples)

    # nrow*ncol sample -> feature edges followed by nrow*ncol feature -> sample edges.
    sources = torch.cat((sample_ids, feature_ids), dim=0)
    targets = torch.cat((feature_ids, sample_ids), dim=0)
    return sources, targets


def create_edge_attr(tensor: torch.Tensor) -> torch.Tensor:
    """Return one attribute per edge, in the edge order produced by ``create_edge``.

    The entries of ``tensor`` are flattened in row-major order into a
    ``(numel, 1)`` column, and that column is stacked twice: once for the
    sample -> feature edges and once for the reverse edges. NaN values are
    passed through unchanged.
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as a transposed matrix.
    column = tensor.reshape(-1, 1)
    return torch.cat([column, column], dim=0)


def get_data(tensor, node_mode):
    """Build a bipartite torch_geometric ``Data`` object from a partially observed matrix.

    Each entry ``tensor[i, j]`` becomes two directed edges between sample node
    ``i`` and feature node ``nrow + j``. The entry value is the edge attribute.
    NaN entries are treated as missing: their edges form the ``test_*`` split,
    and the observed entries form the ``train_*`` split.

    Args:
        tensor: 2-D matrix; NaN marks unobserved entries.
        node_mode: node-feature scheme passed to ``create_node``.

    Returns:
        ``Data`` containing:
        - the full graph: ``x``, ``edge_index``, ``edge_attr``; ``edge_attr``
          still holds NaN for missing entries;
        - the train/test edge splits with their per-entry masks and labels;
        - ``edge_attr_dim``;
        - ``user_num``, the number of rows of ``tensor``.
    """
    edge_start, edge_end = create_edge(tensor)
    edge_index = torch.stack([edge_start, edge_end])
    # The create_* helpers already return fresh tensors. detach() replaces the
    # former torch.tensor(<Tensor>) copy-construct, which PyTorch warns against.
    edge_attr = create_edge_attr(tensor).detach()
    x = create_node(tensor, node_mode).detach()

    # An entry is observed (a training edge) iff it is not NaN. Only the first
    # half of edge_attr (sample -> feature) is inspected; the second half mirrors it.
    num_entries = edge_attr.shape[0] // 2
    train_edge_mask = ~torch.isnan(edge_attr[:num_entries, 0])
    test_edge_mask = ~train_edge_mask
    # Both directions of an entry share its observed/missing status.
    double_train_edge_mask = torch.cat((train_edge_mask, train_edge_mask), dim=0)
    double_test_edge_mask = ~double_train_edge_mask

    # Known (observed) edges for training; unknown (missing) edges for testing.
    train_edge_index, train_edge_attr = mask_edge(edge_index, edge_attr,
                                                  double_train_edge_mask, True)
    # Labels come from the sample -> feature half only: one label per entry.
    train_labels = train_edge_attr[:train_edge_attr.shape[0] // 2, 0]
    test_edge_index, test_edge_attr = mask_edge(edge_index, edge_attr,
                                                double_test_edge_mask, True)
    test_labels = test_edge_attr[:test_edge_attr.shape[0] // 2, 0]

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                train_edge_index=train_edge_index, train_edge_attr=train_edge_attr,
                train_edge_mask=train_edge_mask, train_labels=train_labels,
                test_edge_index=test_edge_index, test_edge_attr=test_edge_attr,
                test_edge_mask=test_edge_mask, test_labels=test_labels,
                edge_attr_dim=train_edge_attr.shape[-1] if edge_attr.shape[0] > 0 else 1,
                user_num=tensor.shape[0]
                )
    return data