import numpy as np
import torch
import dgl
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from . import DeepSet, GIN
# import DeepSet, GIN


def init_weights(m):
    """ Re-initialise a linear layer in place; other module types are left untouched.

    Meant to be handed to ``nn.Module.apply``. The weight gets Kaiming-normal
    values (computed for the ``leaky_relu`` non-linearity) and the bias, when
    the layer has one, is filled with zeros.

    :param m: module to (possibly) initialise
    :return: None
    """
    if not isinstance(m, nn.Linear):
        return

    init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')

    bias = m.bias
    if bias is not None:
        init.constant_(bias, 0)


class AutoVisualNet(nn.Module):
    """ DeepSet encoder -> GIN graph encoder -> MLP head.

    For every dataset, the flattened entries of its ``cdists`` tensor are
    encoded by the DeepSet into node features. Those features are fed to the
    GIN on each of the dataset's graph views, the per-view GIN outputs are
    concatenated along the feature axis, and the MLP maps the result to
    ``out_dim`` values.

    NOTE: ``self.device`` is fixed at construction time and is where inputs
    are sent in the forward pass. ``nn.Module.to`` does not update it, so
    keep the ``device`` argument in sync with where the module finally lives.
    """

    def __init__(self, deepset_dim=16, gnn_hidden=128, gnn_out=32, out_dim=2, n_graph_view=5, device='cuda:0'):
        """ Build the three sub-networks and place all of them on ``device``.

        :param deepset_dim: size argument handed to ``DeepSet.DeepSet``; the
            GIN input width is ``2 * deepset_dim``
        :param gnn_hidden: hidden width handed to ``GIN.GIN``
        :param gnn_out: output width handed to ``GIN.GIN``; the MLP input
            width is ``gnn_out * n_graph_view``
        :param out_dim: width of the final linear layer
        :param n_graph_view: number of graph views per dataset
        :param device: device for the sub-networks and for the input tensors
        """
        super().__init__()
        # Number of datasets handled per DeepSet / GIN call.
        self.sub_batch_size = 1

        self.deepset_dim = deepset_dim
        self.gnn_hidden = gnn_hidden
        self.gnn_out = gnn_out
        self.out_dim = out_dim
        self.n_graph_view = n_graph_view

        self.device = device

        self.deepset = DeepSet.DeepSet(1, self.deepset_dim).to(device)
        self.gin = GIN.GIN(self.deepset_dim * 2, self.gnn_hidden, self.gnn_out).to(device)

        # The head is moved to `device` like the two encoders; otherwise a
        # model built for a non-default device and never `.to()`-ed afterwards
        # would fail in forward() with a device mismatch.
        self.mlp = nn.Sequential(
            nn.Linear(self.gnn_out * self.n_graph_view, self.gnn_out),
            nn.SELU(),
            nn.Linear(self.gnn_out, self.out_dim),
        ).to(device)
        # reset_parameters() is intentionally not called here: the layers
        # keep their default initialisation unless the caller invokes it.

    def deepset_forward(self, cdists):
        """ Encode every sub-batch of ``cdists`` with the DeepSet.

        :param cdists: sequence of tensors, one per dataset. Presumably square
            ``(n_i, n_i)`` distance matrices -- the block-size vector below
            lists ``n_i`` blocks of ``n_i`` entries each (TODO confirm).
        :return: list with one item per sub-batch, holding the second value
            returned by ``self.deepset``
        """
        step = self.sub_batch_size
        node_features = []
        # Floor division: a trailing partial sub-batch is skipped.
        for sub_batch_idx in range(len(cdists) // step):
            sub_batch = cdists[sub_batch_idx * step:(sub_batch_idx + 1) * step]

            # A tensor with n rows contributes n block-size entries of value n.
            block_sizes = [c.shape[0] for c in sub_batch for _ in range(c.shape[0])]
            block_sizes = torch.from_numpy(np.array(block_sizes)).to(self.device)

            # All entries of the sub-batch as a single (total, 1) column.
            merged = torch.cat([c.reshape(-1) for c in sub_batch]).to(self.device)

            _, node_f = self.deepset(merged.unsqueeze(1), block_sizes)
            node_features.append(node_f)
        return node_features

    def gin_forward(self, graphs, node_features):
        """ Run the GIN on every dataset's graph views and merge the views.

        :param graphs: ``[[dgl graph_i] * n_graph_view, ...]`` -- one list of
            graph views per dataset
        :param node_features: output of :meth:`deepset_forward`
        :return: tensor stacking, along dim 0, the per-dataset GIN outputs
            whose ``n_graph_view`` chunks were concatenated along dim 1
        """
        outputs = []
        for sub_batch_idx in range(len(graphs) // self.sub_batch_size):
            # TODO: indexes a single dataset, so this only works for
            # sub_batch_size == 1 (original note: crash when sub_batch_size > 1).
            batched = dgl.batch(graphs[sub_batch_idx]).to(self.device)
            # Every graph in the batch receives the same node features.
            feats = node_features[sub_batch_idx].repeat((batched.batch_size, 1))
            output = self.gin(batched, feats)
            # Split the output into one chunk per view (dim 0) and lay the
            # chunks side by side along the feature axis.
            outputs.append(torch.cat(output.chunk(self.n_graph_view), dim=1))
        return torch.cat(outputs)

    def forward(self, cdists, graphs):
        """ Full pass: DeepSet node features -> GIN over all views -> MLP.

        :param cdists: see :meth:`deepset_forward`
        :param graphs: see :meth:`gin_forward`
        :return: MLP output with ``out_dim`` values per row
        """
        set_enc = self.deepset_forward(cdists)
        # Original author's shape note: (n_node_i*n_graph, d*n_view)
        n_view_out = self.gin_forward(graphs, set_enc)
        return self.mlp(n_view_out)

    def reset_parameters(self):
        """ Initialize the weights and bias.
        :return: None
        """
        self.apply(init_weights)


def list_collate_fn(items):
    """ Collate a batch by transposing it, without stacking anything.

    ``[(a0, b0), (a1, b1)]`` becomes ``[(a0, a1), (b0, b1)]``; samples of
    unequal length are truncated to the shortest one, as ``zip`` does.

    :param items: iterable of per-sample sequences
    :return: list of tuples, one tuple per field
    """
    return [*zip(*items)]

if __name__ == '__main__':
    # Smoke test: load one dataset group, pull a single batch of two samples
    # and print the shape of the network output.
    import GraphDatasets

    # Fall back to the CPU so the script also runs on machines without CUDA.
    # The same device goes to the constructor (where inputs are sent) and to
    # `.to()` (where the parameters live).
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    ds = GraphDatasets.DatasetGraphDataset(data_names=['mnist_group1'], cdist_path='../prepare_data/clip/features',
                             visual_path='../prepare_data/bo/res-2')

    net = AutoVisualNet(device=device).to(device)

    # Samples hold variable-sized tensors and graph lists, so they are only
    # transposed into per-field tuples instead of being stacked.
    train_loader = torch.utils.data.DataLoader(ds,
                                   batch_size=2,
                                   shuffle=True,
                                   num_workers=0,
                                   collate_fn=list_collate_fn
                                   )
    batch = next(iter(train_loader))

    # First two fields of each sample: the cdists and the graph views.
    out = net(batch[0], batch[1])

    print(out.shape)



