import torch
import dgl
import copy
import torch_geometric
from torch_geometric.loader import NeighborLoader

from .model_factory import get_feat_extractor

class NeighborhoodProcessor:
    """Sample neighbourhoods around seed nodes and optionally featurize them.

    The sampling back-end is selected by ``args.backbone``:

    * ``'GRNF'``: one disjoint PyG subgraph per seed node (``extract_disjoint``).
    * anything else: DGL message-flow blocks (``extract_blocks``).

    If ``get_feat_extractor(args)`` returns a model, ``extract`` runs the
    sampled data through it (frozen, eval mode, no autograd) and returns its
    output; otherwise the raw sampled data is returned.
    """

    def __init__(self, args):
        """Build the sampler and the optional frozen feature extractor.

        Reads ``args.gpu``, ``args.backbone``, ``args.n_nbs_sample`` and, for
        non-GRNF backbones, ``args.sample_nbs``. ``n_nbs_sample`` is used both
        as the per-layer fan-out and, via ``len``, as the number of layers —
        presumably a list of ints, one per hop (confirm against the arg parser).
        """
        self.feature_extractor = get_feat_extractor(args)
        if self.feature_extractor is not None:
            # Used purely as a fixed featurizer: on the GPU, no grads, eval mode.
            self.feature_extractor.cuda(args.gpu)
            self.feature_extractor.requires_grad_(False)
            self.feature_extractor.eval()
        self.device = f'cuda:{args.gpu}'
        self.disjoint_subgraphs = args.backbone == 'GRNF'
        self.n_nbs_sample = args.n_nbs_sample

        if not self.disjoint_subgraphs:
            if args.sample_nbs:
                self.neighbor_sampler = dgl.dataloading.NeighborSampler(self.n_nbs_sample)
            else:
                # Same depth as the sampled variant, but keep every neighbour.
                self.neighbor_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(len(self.n_nbs_sample))

    def extract(self, g, node_ids):
        """Sample the neighbourhoods of ``node_ids`` in ``g`` and featurize them.

        Returns the feature extractor's output (computed under
        ``torch.no_grad``) when one is configured. Otherwise it returns the
        sampled data itself: a PyG mini-batch for GRNF, or a list of DGL blocks
        for other backbones.
        """
        if self.disjoint_subgraphs:
            sampled_data = self.extract_disjoint(g, node_ids)
        else:
            sampled_data = self.extract_blocks(g, node_ids)
        if self.feature_extractor is None:
            return sampled_data
        with torch.no_grad():
            return self.feature_extractor(sampled_data)

    def extract_blocks(self, g, node_ids):
        """Sample DGL blocks for ``node_ids`` and move them to ``self.device``."""
        # Seeds are moved to the graph's device before sampling.
        _, _, blocks = self.neighbor_sampler.sample_blocks(g, node_ids.to(device=g.device))
        return [block.to(device=self.device) for block in blocks]

    def extract_disjoint(self, g, node_ids):
        """Sample one disjoint subgraph per seed node and return a PyG mini-batch.

        All seeds are drawn in a single mini-batch (``batch_size=len(node_ids)``,
        no shuffling). Nodes are then reordered so each subgraph is contiguous.
        ``x`` is set to the ``feat`` node attribute, and every edge gets a
        constant ``edge_attr`` of 1 with shape ``(num_edges, 1)``. The result is
        moved to ``self.device``.

        Raises:
            ValueError: if ``node_ids`` is empty.
        """
        if len(node_ids) == 0:
            # Fail early with a clear message instead of deep inside the loader.
            raise ValueError("extract_disjoint requires at least one seed node")
        # NOTE(review): converts the whole DGL graph on every call; consider
        # caching the PyG graph if ``g`` does not change between calls.
        geometric = torch_geometric.utils.from_dgl(g)
        loader = NeighborLoader(
            geometric,
            num_neighbors=self.n_nbs_sample,
            input_nodes=node_ids,
            disjoint=True,
            batch_size=len(node_ids),
            shuffle=False,
        )
        sampled_data = next(iter(loader))
        sampled_data = self._reindex_batch_contiguous(sampled_data)
        sampled_data["x"] = sampled_data.feat
        # Constant edge weights, created beside edge_index so no device is mixed.
        sampled_data["edge_attr"] = torch.ones(
            sampled_data.edge_index.shape[1], 1, device=sampled_data.edge_index.device
        )
        return sampled_data.to(device=self.device)

    def _reindex_batch_contiguous(self, sampled_data):
        """Reorder nodes so each subgraph (``batch`` id) occupies a contiguous range.

        A stable sort on ``batch`` keeps the original relative order within
        each subgraph. The node tensors ``feat``, ``batch`` and, when present,
        ``n_id`` are permuted together, and ``edge_index`` is remapped to the
        new positions. All work stays on the device the sampled tensors are
        already on; moving to ``self.device`` is left to the caller.

        NOTE(review): other node-level attributes carried by ``sampled_data``
        (if any) are not permuted.
        """
        perm = torch.argsort(sampled_data.batch, stable=True)
        sampled_data.feat = sampled_data.feat[perm]
        sampled_data.batch = sampled_data.batch[perm]
        n_id = getattr(sampled_data, 'n_id', None)
        if n_id is not None:
            # Keep the global node ids aligned with the permuted features.
            sampled_data.n_id = n_id[perm]
        # new_position[old_index] == index of that node after sorting.
        new_position = torch.empty_like(perm)
        new_position[perm] = torch.arange(perm.size(0), device=perm.device)
        sampled_data.edge_index = new_position[sampled_data.edge_index]
        return sampled_data

class MixedNeighborhoodProcessor:
    """Combine a GRNF-backed and a UGCN-backed ``NeighborhoodProcessor``.

    Each inner processor is built from its own deep copy of ``args`` with the
    ``backbone`` overridden. Every entry of ``backbone_args['h_dims']`` is
    halved (``int(n / 2)``). ``extract`` concatenates the two outputs along
    dimension 1.

    NOTE(review): for an odd width ``n`` the two halved widths sum to ``n - 1``.
    """

    @staticmethod
    def _halved_copy(args, backbone):
        """Deep-copy ``args``, set ``backbone`` and halve each ``h_dims`` entry."""
        variant = copy.deepcopy(args)
        variant.backbone = backbone
        variant.backbone_args['h_dims'] = [
            int(width / 2) for width in args.backbone_args['h_dims']
        ]
        return variant

    def __init__(self, args):
        # The caller's args is never mutated; each processor gets its own copy.
        self.processor1 = NeighborhoodProcessor(self._halved_copy(args, 'GRNF'))
        self.processor2 = NeighborhoodProcessor(self._halved_copy(args, 'UGCN'))

    def extract(self, g, node_ids):
        """Run both processors on the same seeds and concatenate along dim 1.

        Assumes both inner processors return tensors (i.e. each has a feature
        extractor); ``torch.cat`` cannot join raw sampled data.
        """
        grnf_part = self.processor1.extract(g, node_ids)
        ugcn_part = self.processor2.extract(g, node_ids)
        return torch.cat([grnf_part, ugcn_part], dim=1)
