import torch
import torch_geometric
from torch_geometric.data import Data


def to_pyg_batch(node_features,
                 edge_features,
                 edge_index,
                 node2type=None,
                 edge2type=None,
                 direction='forward',
                 label=None,
                 hidden_nodes=None,
                 first_layer_nodes=None
                 ):
    """Wrap node/edge tensors into a ``torch_geometric.data.Data`` graph.

    Args:
        node_features: stored as ``Data.x``.
        edge_features: dense edge-feature tensor indexed as
            ``edge_features[src, dst]``. For ``direction='backward'`` its
            dims ``-3`` and ``-2`` are swapped, so it presumably has shape
            ``(N, N, F)`` (TODO confirm with callers).
        edge_index: ``(2, E)`` tensor of ``[src; dst]`` rows.
        node2type, edge2type, label: stored unchanged on the ``Data``.
        direction: ``'forward'`` keeps edges as given. ``'backward'`` flips
            every edge (and transposes ``edge_features`` so each edge keeps
            its attribute). ``'bidirectional'`` keeps forward edges and adds
            the flipped index as ``bw_edge_index``.
        hidden_nodes: stored as ``Data.mask_hidden``.
        first_layer_nodes: stored as ``Data.mask_first_layer``.

    Returns:
        ``(data, None)``.

    Raises:
        ValueError: if ``direction`` is not one of the values above.
    """
    if direction not in ('forward', 'backward', 'bidirectional'):
        raise ValueError(
            f"direction must be 'forward', 'backward' or 'bidirectional', got {direction!r}"
        )

    if direction == 'backward':
        # Reverse every edge; transposing the dense features keeps
        # edge_attr[k] equal to the attribute of the original edge k.
        edge_features = edge_features.transpose(-2, -3)
        edge_index = torch.flip(edge_index, [0])

    fields = {'x': node_features, 'edge_index': edge_index}
    if direction == 'bidirectional':
        fields['bw_edge_index'] = torch.flip(edge_index, [0])
    fields.update(
        edge_attr=edge_features[edge_index[0], edge_index[1]],
        node2type=node2type,
        edge2type=edge2type,
        label=label,
        mask_hidden=hidden_nodes,
        mask_first_layer=first_layer_nodes,
    )
    return torch_geometric.data.Data(**fields), None


def get_node_types(nodes_per_layer):
    """Assign an integer type id to every node of a layered (MLP) graph.

    Every input-layer node and every output-layer node gets its own unique
    type. All nodes of one hidden layer share a single type. Ids are handed
    out consecutively in layer order, starting at 0.

    Args:
        nodes_per_layer: number of nodes in each layer, input layer first.

    Returns:
        1-D ``torch.long`` tensor with one type id per node, in node order.
        The dtype is long even when the layout is empty.
    """
    node_types = []
    next_type = 0
    last = len(nodes_per_layer) - 1
    for i, n_nodes in enumerate(nodes_per_layer):
        if i == 0 or i == last:
            # Input / output layer: a distinct type per node.
            for _ in range(n_nodes):
                node_types.append(next_type)
                next_type += 1
        else:
            # Hidden layer: one shared type for the whole layer.
            for _ in range(n_nodes):
                node_types.append(next_type)
            next_type += 1
    # Explicit dtype: torch.tensor([]) would otherwise default to float.
    return torch.tensor(node_types, dtype=torch.long)


def get_edge_types(nodes_per_layer):
    """Assign an integer type id to every edge between consecutive layers.

    Type ids per layer pair (ids consecutive, starting at 0):
      * first layer pair: one type per source node, repeated for each of
        its outgoing edges;
      * middle layer pairs: one type shared by all edges of the pair;
      * last layer pair (when there are >= 3 layers): one type per
        destination node, emitted as a contiguous run of ``n_src`` ids.

    NOTE(review): the last-pair ids are grouped per destination
    (destination-major), whereas ``nn_to_edge_index`` in this module lists
    edges source-major. If the two are meant to be aligned position by
    position, the last-pair ids do not match their edges. Verify against
    the caller before relying on this ordering.

    Args:
        nodes_per_layer: number of nodes in each layer, input layer first.

    Returns:
        1-D ``torch.long`` tensor of edge type ids. The dtype is long even
        when there are no edges.
    """
    edge_types = []
    next_type = 0
    last_pair = len(nodes_per_layer) - 2
    for i in range(len(nodes_per_layer) - 1):
        n_src = nodes_per_layer[i]
        n_dst = nodes_per_layer[i + 1]
        if i == 0:
            # One type per source node.
            for _ in range(n_src):
                for _ in range(n_dst):
                    edge_types.append(next_type)
                next_type += 1
        elif i < last_pair:
            # One type for the whole middle layer pair.
            for _ in range(n_src):
                for _ in range(n_dst):
                    edge_types.append(next_type)
            next_type += 1
        else:
            # One type per destination node (contiguous runs of n_src).
            for _ in range(n_dst):
                for _ in range(n_src):
                    edge_types.append(next_type)
                next_type += 1
    # Explicit dtype: torch.tensor([]) would otherwise default to float.
    return torch.tensor(edge_types, dtype=torch.long)


def nn_to_edge_index(layer_layout, device, dtype=torch.long):
    """Build the edge index of a fully connected layered (MLP) graph.

    Nodes are numbered consecutively layer by layer. Every node of layer
    ``l`` is connected to every node of layer ``l + 1``. Edges are listed
    layer pair by layer pair, and within a pair source-major:
    ``(s0, d0), (s0, d1), ..., (s1, d0), ...``.

    Args:
        layer_layout: number of nodes in each layer, input layer first.
        device: device of the returned tensor.
        dtype: dtype of the returned tensor.

    Returns:
        ``(2, E)`` tensor whose row 0 holds sources and row 1 targets. With
        fewer than two layers this is an empty ``(2, 0)`` tensor, so the
        shape is always a valid edge index.
    """
    blocks = []
    offset = 0
    for i in range(1, len(layer_layout)):
        n_src = layer_layout[i - 1]
        n_dst = layer_layout[i]
        src = torch.arange(offset, offset + n_src, device=device)
        dst = torch.arange(offset + n_src, offset + n_src + n_dst, device=device)
        # 'ij' indexing + row-major flatten gives source-major edge order.
        src_idx, dst_idx = torch.meshgrid(src, dst, indexing='ij')
        blocks.append(torch.stack([src_idx.flatten(), dst_idx.flatten()], dim=0))
        offset += n_src

    if not blocks:
        return torch.empty((2, 0), device=device, dtype=dtype)
    return torch.cat(blocks, dim=1).to(dtype)

def mlp_params_to_scalegmn_data(params_list, device):
    """Convert the parameters of a three-layer MLP into a graph ``Data``.

    NOTE(review): another function with this same name (signature
    ``(w_b, test_accuracy, device)``) is defined later in this module and
    rebinds the name at import time, so this definition is not reachable as
    a module attribute. Consider renaming one of them.

    Nodes are the MLP neurons in layer order (inputs, hidden 1, hidden 2,
    outputs). Node features are the neuron biases (0 for input neurons).
    Edge features are the weights of the corresponding connections.

    Args:
        params_list: ``(fc1_w, fc1_b, fc2_w, fc2_b, fc3_w, fc3_b)``. Each
            weight is read as ``(out_features, in_features)``: the input
            size comes from ``size(1)`` and the output size from ``size(0)``.
        device: device on which the zero input-node biases are created.
            NOTE(review): index and mask tensors below are created on the
            default device regardless; confirm this mix is intended.

    Returns:
        ``Data`` with forward/backward edges, edge and node masks, node
        types and the layer layout.
    """
    fc1_w, fc1_b, fc2_w, fc2_b, fc3_w, fc3_b = params_list

    n_in = fc1_w.size(1)
    n_h1 = fc1_w.size(0)
    n_h2 = fc2_w.size(0)
    n_out = fc3_w.size(0)

    # Input neurons have no bias; give them a zero node feature.
    input_biases = torch.zeros(n_in, device=device)
    biases = torch.cat([input_biases, fc1_b, fc2_b, fc3_b])
    x = biases.unsqueeze(1)  # shape (num_nodes, 1)

    def layer_edges(w, src_off, dst_off):
        """Return source-major fully connected edges and their weights."""
        in_dim = w.size(1)
        out_dim = w.size(0)
        src = torch.arange(src_off, src_off + in_dim)
        dst = torch.arange(dst_off, dst_off + out_dim)
        src_idx, dst_idx = torch.meshgrid(src, dst, indexing='ij')
        # Edges are enumerated source-major: (s0, d0), (s0, d1), ..., (s1, d0), ...
        e_idx = torch.stack([src_idx.flatten(), dst_idx.flatten()], dim=0)
        # The weight of edge (input i -> output j) is w[j, i]. Flatten the
        # transpose so attributes line up with the source-major e_idx;
        # flattening w itself would pair edges with the wrong weights.
        e_attr = w.transpose(0, 1).reshape(-1, 1)
        return e_idx, e_attr

    e1_idx, e1_attr = layer_edges(fc1_w, 0, n_in)
    e2_idx, e2_attr = layer_edges(fc2_w, n_in, n_in + n_h1)
    e3_idx, e3_attr = layer_edges(fc3_w, n_in + n_h1, n_in + n_h1 + n_h2)

    edge_index = torch.cat([e1_idx, e2_idx, e3_idx], dim=1)
    edge_attr = torch.cat([e1_attr, e2_attr, e3_attr], dim=0)

    mlp_edge_masks = torch.ones(edge_index.size(1), dtype=torch.long)
    spatial_embed_mask = torch.zeros(x.size(0), dtype=torch.long)
    # One type per layer: 0 = input, 1 = hidden 1, 2 = hidden 2, 3 = output.
    node_types = torch.cat([
        torch.zeros(n_in, dtype=torch.long),
        torch.ones(n_h1, dtype=torch.long),
        torch.full((n_h2,), 2, dtype=torch.long),
        torch.full((n_out,), 3, dtype=torch.long)
    ])
    layer_layout = torch.tensor([n_in, n_h1, n_h2, n_out])
    node2type = node_types.clone()

    # (num_nodes, 1) boolean mask, True for the neurons of both hidden layers.
    mask_hidden = torch.zeros(x.size(0), 1, dtype=torch.bool)
    mask_hidden[n_in:n_in + n_h1] = True
    mask_hidden[n_in + n_h1:n_in + n_h1 + n_h2] = True

    bw_edge_index = torch.flip(edge_index, [0])
    bw_edge_attr = edge_attr.clone()

    sign_mask = False

    return Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        bw_edge_index=bw_edge_index,
        bw_edge_attr=bw_edge_attr,
        mlp_edge_masks=mlp_edge_masks,
        spatial_embed_mask=spatial_embed_mask,
        node_types=node_types,
        layer_layout=layer_layout,
        node2type=node2type,
        mask_hidden=mask_hidden,
        sign_mask=sign_mask
    )

def mlp_params_to_scalegmn_data(w_b, test_accuracy, device):
    """Convert an MLP given as per-layer weights/biases into a graph ``Data``.

    Nodes are the MLP neurons in layer order (inputs first). Input neurons
    get node state 1; every other neuron gets its bias. Edges connect
    consecutive layers (via ``nn_to_edge_index``) and carry the weights.

    Args:
        w_b: object exposing ``weights`` and ``biases`` sequences. From the
            unpacking below, ``weights[l]`` is ``(num_in, num_out, feat)``
            and ``biases[l]`` is ``(num_out, feat)``.
        test_accuracy: stored unchanged on the returned ``Data``.
        device: device of ``mask_hidden`` only. The other tensors are built
            on CPU (``edge_index`` explicitly so). NOTE(review): confirm the
            mixed placement is intended.

    Returns:
        ``Data`` with ``x``, ``edge_index``, ``edge_attr``, ``bw_edge_index``,
        ``node2type``, ``mask_hidden`` and ``test_accuracy``.
    """
    weights = w_b.weights
    biases = w_b.biases

    # Derive the layout from the weights instead of assuming a fixed
    # architecture, so edge_index / node types always match x.
    layer_layout = [weights[0].shape[0]] + [w.shape[1] for w in weights]
    num_nodes = sum(layer_layout)
    num_inputs = layer_layout[0]

    x = torch.zeros(num_nodes, biases[0].shape[-1])
    edge_features = torch.zeros(num_nodes, num_nodes, weights[0].shape[-1])

    # Place each layer's weight block at [its source rows, its target cols].
    row_offset = 0
    col_offset = num_inputs
    for w in weights:
        num_in, num_out, _ = w.shape
        edge_features[row_offset:row_offset + num_in, col_offset:col_offset + num_out] = w
        row_offset += num_in
        col_offset += num_out

    # Input nodes have no bias: set their node state (rows, not feature
    # columns) to 1.
    x[:num_inputs] = 1.

    row_offset = num_inputs
    for b in biases:
        num_out, _ = b.shape
        x[row_offset:row_offset + num_out] = b
        row_offset += num_out

    edge_index = nn_to_edge_index(layer_layout, "cpu", dtype=torch.long)
    edge_attr = edge_features[edge_index[0], edge_index[1]]
    node2type = get_node_types(layer_layout)

    # Hidden neurons: every node after the inputs and before the outputs.
    mask_hidden = torch.zeros(num_nodes, 1, dtype=torch.bool, device=device)
    mask_hidden[num_inputs:num_nodes - layer_layout[-1]] = True

    # Backward edges
    bw_edge_index = torch.flip(edge_index, [0])

    return Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        bw_edge_index=bw_edge_index,
        node2type=node2type,
        mask_hidden=mask_hidden,
        test_accuracy=test_accuracy
    )