import torch
from torch.autograd import Variable
import dgl

class ReservoirSamplingBuffer:
    """Fixed-capacity replay buffer of (feature, label) pairs filled by reservoir sampling.

    The first ``budget`` examples are stored directly. After that, the example
    at stream position ``t`` (0-based, counted over every call to
    :meth:`update`) replaces a uniformly chosen slot with probability
    ``budget / (t + 1)``. The buffer therefore stays a uniform random sample
    of the stream (Algorithm R).
    """

    def __init__(self, budget):
        self.budget = budget
        # Plain lists until the first update; afterwards tensors stacked on dim 0.
        self.aux_features = []
        self.aux_labels = []
        # Number of examples ever offered to update(), whether stored or not.
        self.n_seen_examples = 0

    def __len__(self):
        return len(self.aux_labels)

    def sample(self, n_samples):
        """Return up to ``n_samples`` stored (features, labels), drawn without replacement.

        An empty buffer yields ``([], [])`` instead of failing on list indexing.
        """
        if len(self) == 0:
            return [], []
        sampled_indices = torch.randperm(len(self))[:n_samples]
        return self.aux_features[sampled_indices], self.aux_labels[sampled_indices]

    def update(self, features, labels):
        """Offer a batch of examples to the reservoir.

        Args:
            features: tensor whose first dimension indexes examples.
            labels: tensor with one entry per example, aligned with ``features``.
        """
        n_nodes = len(labels)

        # Phase 1: fill free capacity with the leading examples of the batch.
        offset = min(max(0, self.budget - len(self)), n_nodes)
        if offset > 0:
            # Detached copies: slicing returns views, and the in-place
            # replacements below would otherwise overwrite the caller's tensors
            # (and keep autograd history alive inside the buffer).
            new_features = features[:offset].detach().clone()
            new_labels = labels[:offset].detach().clone()
            if len(self) > 0:
                self.aux_features = torch.cat([self.aux_features, new_features], dim=0)
                self.aux_labels = torch.cat([self.aux_labels, new_labels], dim=0)
            else:
                self.aux_features = new_features
                self.aux_labels = new_labels

        # Phase 2: the buffer is full, so the remaining examples go through
        # reservoir replacement.
        for i in range(offset, n_nodes):
            # This example is item number n_seen + i (0-based) of the stream.
            # Draw j uniformly from [0, n_seen + i] inclusive (randint's upper
            # bound is exclusive), so the item is kept with probability
            # budget / (n_seen + i + 1).
            j = int(torch.randint(0, self.n_seen_examples + i + 1, (1,)))
            if j < self.budget:
                self.aux_features[j] = features[i].detach()
                self.aux_labels[j] = labels[i]

        self.n_seen_examples += n_nodes

class ByClassReservoirSamplingBuffer:
    """Class-balanced replay buffer made of one ReservoirSamplingBuffer per class.

    Each of the ``num_classes`` sub-buffers receives ``budget // num_classes``
    slots and only sees examples whose label equals its index.
    """

    def __init__(self, budget, num_classes):
        self.budget = budget
        self.num_classes = num_classes
        per_class_budget = budget // num_classes
        self.aux_buffers = [
            ReservoirSamplingBuffer(per_class_budget) for _ in range(num_classes)
        ]

    def __len__(self):
        return sum(len(buf) for buf in self.aux_buffers)

    def sample(self, n_samples):
        """Draw up to ``n_samples`` examples uniformly over all sub-buffers.

        The sub-buffers are treated as one concatenated index space. Returns
        concatenated (features, labels) tensors, or two empty lists when
        nothing was drawn.
        """
        picked = torch.randperm(len(self))[:n_samples]

        # bounds[k] .. bounds[k + 1] is sub-buffer k's slice of the global index space.
        sizes = torch.tensor([0] + [len(buf) for buf in self.aux_buffers])
        bounds = torch.cumsum(sizes, dim=0)

        feature_chunks, label_chunks = [], []
        for buf, lo, hi in zip(self.aux_buffers, bounds[:-1], bounds[1:]):
            local = picked[(picked >= lo) & (picked < hi)] - lo
            if local.numel() == 0:
                continue
            feature_chunks.append(buf.aux_features[local])
            label_chunks.append(buf.aux_labels[local])

        if not feature_chunks:
            return feature_chunks, label_chunks
        return torch.cat(feature_chunks, dim=0), torch.cat(label_chunks, dim=0)

    def update(self, features, labels):
        """Route each example to the sub-buffer matching its label.

        Labels outside ``range(num_classes)`` are ignored.
        """
        for cls in range(self.num_classes):
            members = labels == cls
            if members.any():
                self.aux_buffers[cls].update(features[members], labels[members])

class ReservoirSSM:
    """Reservoir-sampled memory of sparsified computation subgraphs.

    Each stored example is a small graph built by ``sparsify_blocks`` around
    one target node of a sampled mini-batch, paired with that node's label.
    The constructor divides ``budget`` by ``sum(nei_budget)``. This is
    presumably meant to turn a node budget into a subgraph count, since
    ``sparsify_blocks`` keeps at most ``nei_budget[l]`` edges per layer.
    """

    def __init__(self, budget, nei_budget):
        self.budget = budget // sum(nei_budget)
        # Plain lists until the first update; aux_labels then becomes a tensor.
        self.aux_subgraphs = []
        self.aux_labels = []
        # Number of target nodes ever offered to update(), stored or not.
        self.n_seen_examples = 0
        self.nei_budget = nei_budget

    def __len__(self):
        return len(self.aux_labels)

    def sample(self, n_samples):
        """Return up to ``n_samples`` stored (subgraphs, labels), drawn without replacement.

        An empty memory yields ``([], [])`` instead of failing on list indexing.
        """
        if len(self) == 0:
            return [], []
        sampled_indices = torch.randperm(len(self))[:n_samples]
        return [self.aux_subgraphs[i] for i in sampled_indices], self.aux_labels[sampled_indices]

    def update(self, blocks, labels, node_ids=None):
        """Offer the target nodes of a mini-batch to the reservoir.

        Args:
            blocks: DGL blocks of the mini-batch, forwarded to ``sparsify_blocks``.
            labels: one label per offered node.
            node_ids: optional position of each offered node among the
                destination nodes of ``blocks[-1]``. Defaults to
                ``0 .. len(labels) - 1``, i.e. ``labels[k]`` belongs to target
                node ``k``.

        Raises:
            ValueError: if ``node_ids`` and ``labels`` differ in length.
        """
        n_nodes = len(labels)
        if node_ids is None:
            node_ids = list(range(n_nodes))
        else:
            node_ids = [int(v) for v in node_ids]
            if len(node_ids) != n_nodes:
                raise ValueError(
                    f"node_ids has {len(node_ids)} entries but labels has {n_nodes}"
                )

        # Phase 1: fill free capacity with the leading nodes of the batch.
        offset = min(max(0, self.budget - len(self)), n_nodes)
        if offset > 0:
            new_subgraphs = [
                sparsify_blocks(blocks, node_ids[k], self.nei_budget) for k in range(offset)
            ]
            # Copy: a slice is a view, and the in-place replacements below
            # would otherwise write into the caller's label tensor.
            new_labels = labels[:offset].clone()
            if len(self) > 0:
                self.aux_subgraphs.extend(new_subgraphs)
                self.aux_labels = torch.cat([self.aux_labels, new_labels], dim=0)
            else:
                self.aux_subgraphs = new_subgraphs
                self.aux_labels = new_labels

        # Phase 2: reservoir replacement. Draw j uniformly from
        # [0, n_seen + k] inclusive (randint's upper bound is exclusive).
        for k in range(offset, n_nodes):
            j = int(torch.randint(0, self.n_seen_examples + k + 1, (1,)))
            if j < self.budget:
                self.aux_subgraphs[j] = sparsify_blocks(blocks, node_ids[k], self.nei_budget)
                self.aux_labels[j] = labels[k]

        self.n_seen_examples += n_nodes


class ByClassReservoirSSM:
    """Class-balanced subgraph memory: one ReservoirSSM per class.

    Each sub-memory is built with ``budget // num_classes`` and the shared
    ``nei_budget``.
    """

    def __init__(self, budget, num_classes, nei_budget):
        self.budget = budget
        self.num_classes = num_classes
        self.aux_buffers = [ReservoirSSM(budget // num_classes, nei_budget) for _ in range(num_classes)]

    def __len__(self):
        return sum(len(buffer) for buffer in self.aux_buffers)

    def sample(self, n_samples):
        """Draw up to ``n_samples`` (subgraph, label) pairs uniformly over all sub-memories.

        Returns a list of subgraphs and a concatenated label tensor, or two
        empty lists when nothing was drawn.
        """
        sampled_subgraphs = []
        sampled_labels = []

        global_indices = torch.randperm(len(self))[:n_samples]

        # buffer_offsets[i] .. buffer_offsets[i + 1] is sub-memory i's index range.
        buffer_offsets = [0] + [len(buffer) for buffer in self.aux_buffers]
        buffer_offsets = torch.cumsum(torch.tensor(buffer_offsets), dim=0)

        for i, buffer in enumerate(self.aux_buffers):
            in_buffer = (global_indices >= buffer_offsets[i]) & (global_indices < buffer_offsets[i + 1])
            local_indices = global_indices[in_buffer] - buffer_offsets[i]

            if len(local_indices) > 0:
                sampled_subgraphs.extend([buffer.aux_subgraphs[j] for j in local_indices])
                sampled_labels.append(buffer.aux_labels[local_indices])

        if sampled_labels:
            sampled_labels = torch.cat(sampled_labels, dim=0)

        return sampled_subgraphs, sampled_labels

    def update(self, blocks, labels):
        """Route each target node of the batch to the sub-memory of its class.

        Assumes, as ``ReservoirSSM.update`` does, that ``labels[k]`` is the
        label of destination node ``k`` of ``blocks[-1]``.
        """
        for cls in range(self.num_classes):
            mask = labels == cls
            if mask.any():
                # Pass each node's position in the batch. Without it the
                # sub-memory would build the subgraph for the k-th target of the
                # whole batch rather than the k-th node of this class.
                positions = mask.nonzero(as_tuple=True)[0]
                self.aux_buffers[cls].update(blocks, labels[mask], node_ids=positions)


def sparsify_blocks(blocks, target_node, fanouts):
    """Build a small standalone DGL graph rooted at one output node of a mini-batch.

    Walks ``blocks`` from the last block back to the first. At block ``i`` it:

    - takes every edge whose destination is in the current frontier;
    - keeps a random subset of at most ``fanouts[i]`` of those edges, counted
      over the whole frontier rather than per node;
    - copies newly reached source nodes (with their ``'feat'``) into the new
      graph;
    - makes the kept sources the next frontier.

    A self-loop is added to every node at the end.

    Node ids are matched across blocks by their block-local integer ids. This
    assumes the destination ids of ``blocks[i - 1]`` coincide with the source
    ids of ``blocks[i]`` (DGL message-flow-graph convention; verify for the
    sampler in use).

    Args:
        blocks: sequence of DGL blocks; ``blocks[-1].dstdata['feat']`` and each
            ``block.srcdata['feat']`` must exist.
        target_node: int (or one-element tensor) id of the root among the
            destination nodes of ``blocks[-1]``.
        fanouts: per-block edge budget; ``fanouts[i]`` applies to ``blocks[i]``.

    Returns:
        A ``dgl.graph`` on ``blocks[-1].device`` with node data ``'feat'`` and a
        boolean ``'target'`` that is True only for node 0 (the root).
    """
    device = blocks[-1].device
    # Normalise to a Python int so it matches the int keys used in node_mapping.
    target_node = int(target_node)
    current_target_nodes = torch.tensor([target_node], device=device)

    # Start from an empty graph holding only the root, flagged as target.
    new_graph = dgl.graph(([], []), num_nodes=0, device=device)
    new_graph.add_nodes(1, {'feat': blocks[-1].dstdata['feat'][current_target_nodes],
                            'target': torch.ones(1, dtype=torch.bool, device=device)})

    # Block-local node id -> node id in new_graph.
    node_mapping = {target_node: 0}

    for i in range(len(blocks) - 1, -1, -1):
        block = blocks[i]
        fanout = fanouts[i]

        # Edges of this block that point into the current frontier.
        src_nodes, dst_nodes = block.edges()
        mask = torch.isin(dst_nodes, current_target_nodes)
        connected_src_nodes = src_nodes[mask]
        connected_dst_nodes = dst_nodes[mask]

        # Keep at most `fanout` of those edges in total.
        if len(connected_src_nodes) > fanout:
            perm = torch.randperm(len(connected_src_nodes))
            sampled_src_nodes = connected_src_nodes[perm[:fanout]]
            sampled_dst_nodes = connected_dst_nodes[perm[:fanout]]
        else:
            sampled_src_nodes = connected_src_nodes
            sampled_dst_nodes = connected_dst_nodes

        src_ids = sampled_src_nodes.tolist()
        dst_ids = sampled_dst_nodes.tolist()

        # Register newly reached source nodes in first-seen order, then add
        # them with their features in a single add_nodes call.
        first_new_index = new_graph.num_nodes()
        new_ids = []
        for node in src_ids:
            if node not in node_mapping:
                node_mapping[node] = first_new_index + len(new_ids)
                new_ids.append(node)
        if new_ids:
            new_index = sampled_src_nodes.new_tensor(new_ids)
            new_graph.add_nodes(len(new_ids), {
                'feat': block.srcdata['feat'][new_index].to(device),
                'target': torch.zeros(len(new_ids), dtype=torch.bool, device=device),
            })

        # Connect the kept sources to their frontier destinations. Build the
        # ids with the graph's idtype: when the frontier has no in-edges,
        # torch.tensor([]) would be float32, which DGLGraph.add_edges rejects.
        if src_ids:
            remapped_src_nodes = torch.tensor([node_mapping[n] for n in src_ids],
                                              dtype=new_graph.idtype, device=device)
            remapped_dst_nodes = torch.tensor([node_mapping[n] for n in dst_ids],
                                              dtype=new_graph.idtype, device=device)
            new_graph.add_edges(remapped_src_nodes, remapped_dst_nodes)

        # The kept sources become the frontier for the previous block.
        current_target_nodes = sampled_src_nodes

    # Self-loops so every node also receives its own message.
    new_graph.add_edges(new_graph.nodes(), new_graph.nodes())

    return new_graph


def store_grad(pp, grads, grad_dims, tid):
    """Copy the current parameter gradients into column ``tid`` of ``grads``.

    Args:
        pp: zero-argument callable returning an iterable of parameters
            (e.g. ``model.parameters``).
        grads: 2-D tensor. The k-th parameter occupies rows
            ``sum(grad_dims[:k])`` to ``sum(grad_dims[:k + 1])``.
        grad_dims: element count of each parameter, in ``pp()`` order.
        tid: index of the column to overwrite.

    The column is zeroed first, so a parameter whose ``.grad`` is None leaves
    zeros in its slice.
    """
    column = grads[:, tid]
    column.fill_(0.0)
    for position, param in enumerate(pp()):
        if param.grad is None:
            continue
        start = sum(grad_dims[:position])
        stop = sum(grad_dims[:position + 1])
        grads[start:stop, tid].copy_(param.grad.data.view(-1))

def overwrite_grad(pp, newgrad, grad_dims):
    """Write slices of the flat vector ``newgrad`` back into parameter gradients.

    Args:
        pp: zero-argument callable returning an iterable of parameters.
        newgrad: 1-D tensor laid out as consecutive per-parameter slices;
            the k-th slice spans ``sum(grad_dims[:k])`` to ``sum(grad_dims[:k + 1])``.
        grad_dims: element count of each parameter, in ``pp()`` order.

    Parameters whose ``.grad`` is None are skipped. Existing gradients are
    updated in place.
    """
    for position, param in enumerate(pp()):
        if param.grad is None:
            continue
        start = sum(grad_dims[:position])
        stop = sum(grad_dims[:position + 1])
        target_shape = param.grad.data.size()
        reshaped = newgrad[start:stop].contiguous().view(target_shape)
        param.grad.data.copy_(reshaped)

def MultiClassCrossEntropy(logits, labels, T):
    """Soft-target cross-entropy between temperature-scaled student and teacher scores.

    Both inputs are divided by ``T``. ``labels`` (teacher scores, treated as
    logits) are turned into probabilities with softmax, and ``logits``
    (student scores) into log-probabilities with log-softmax.

    Args:
        logits: student scores; the softmax is taken over dim 1.
        labels: teacher scores with the same shape; no gradient flows into them.
        T: softmax temperature.

    Returns:
        0-dim tensor: the batch mean (over dim 0) of
        ``-sum_c softmax(labels / T) * log_softmax(logits / T)``.
    """
    # Detach the teacher side and follow the student's device instead of
    # hard-coding CUDA, so the loss also works on CPU or another GPU.
    targets = labels.detach().to(logits.device)
    log_probs = torch.log_softmax(logits / T, dim=1)
    soft_targets = torch.softmax(targets / T, dim=1)
    per_example = torch.sum(log_probs * soft_targets, dim=1, keepdim=False)
    return -torch.mean(per_example, dim=0, keepdim=False)

def kaiming_normal_init(m):
    """Kaiming-normal initialise the weight of a Conv2d or Linear module in place.

    Conv2d weights use ``nonlinearity='relu'`` and Linear weights use
    ``nonlinearity='sigmoid'``. Any other module is left untouched, and biases
    are never modified. Intended for use with ``module.apply``.
    """
    if isinstance(m, torch.nn.Conv2d):
        nonlinearity = 'relu'
    elif isinstance(m, torch.nn.Linear):
        nonlinearity = 'sigmoid'
    else:
        return
    torch.nn.init.kaiming_normal_(m.weight, nonlinearity=nonlinearity)
