import numpy as np
import torch
from torch.nn import Parameter
from torch.utils.data.sampler import Sampler
from copy import deepcopy
from collections import OrderedDict

from core.utils import batch_tensor_with_padding


class GraphDataset(torch.utils.data.Dataset):
    """Map-style dataset of graphs with a bounded cache of collated batches.

    Each item is a ``(graph, label, index)`` triple. ``collate_fn`` pads the
    per-graph attribute/relation tensors into batch tensors (via
    ``batch_tensor_with_padding``) and memoizes the result keyed on the tuple
    of dataset indices that formed the batch.
    """

    def __init__(self, graphs, labels, cached_num=100, device=None):
        """
        Args:
            graphs: sequence of graph objects. Each must expose ``n_nodes``;
                ``collate_fn`` additionally reads ``attribute_names``,
                ``relation_names``, ``attributes``, ``relations`` and
                ``requires_grad``.
            labels: per-graph labels, aligned index-for-index with ``graphs``.
            cached_num: maximum number of collated batches kept in
                ``cached_results``.
            device: passed to ``Tensor.to`` for the label tensor ``y`` built
                in ``collate_fn``.
        """
        self.graphs = graphs
        self.n_nodes = [g.n_nodes for g in graphs]
        self.labels = labels
        self.n_graphs = len(graphs)
        self.cached_num = cached_num
        # Insertion-ordered so the oldest entry can be evicted first.
        self.cached_results = OrderedDict()
        self.device = device

    def filter_empty(self):
        """Drop graphs that have no nodes.

        ``graphs``, ``labels`` and ``n_nodes`` stay aligned. Because dataset
        indices shift, cached collated batches (keyed on the old indices)
        are discarded.
        """
        keep = [
            i for i, graph in enumerate(self.graphs)
            if graph.n_nodes > 0
        ]
        self.graphs = [self.graphs[i] for i in keep]
        self.labels = [self.labels[i] for i in keep]
        self.n_nodes = [graph.n_nodes for graph in self.graphs]
        self.n_graphs = len(self.graphs)
        # Cache keys are index tuples; after re-indexing they would point at
        # different graphs, so stale entries must not be served.
        self.cached_results.clear()

    def __len__(self):
        return self.n_graphs

    def __getitem__(self, index):
        """Return ``(graph, label, index)``, or a list of them for a list."""
        if isinstance(index, list):
            return [self[i] for i in index]
        else:
            return self.graphs[index], self.labels[index], index

    def collate_fn(self, item_list):
        """Collate ``(graph, label, index)`` items into a padded batch.

        Results are memoized, keyed on the tuple of item indices (taken
        before zero-node graphs are dropped). The cache holds at most
        ``cached_num`` entries. When it is full, the oldest inserted entry is
        evicted; hits do not refresh an entry's position. A cache hit returns
        the very same tuple, including the same ``Parameter`` objects.

        Returns:
            ``((attributes, attributes_mask, relations, relations_mask,
            attribute_names, relation_names), y, indices)``. The attribute /
            relation entries are ``None`` when the first graph has no
            attribute / relation names. ``requires_grad`` of all returned
            Parameters is taken from the first graph. If every item has zero
            nodes, returns ``((None,) * 6, None, None)`` without caching.
        """
        cache_key = tuple(item[2] for item in item_list)
        if cache_key in self.cached_results:
            return self.cached_results[cache_key]

        item_list = [item for item in item_list if item[0].n_nodes > 0]
        if not item_list:
            return (None,) * 6, None, None
        graph_list = [item[0] for item in item_list]
        graph0 = graph_list[0]
        # NOTE(review): only `y` is moved to self.device below; the attribute
        # and relation tensors stay where batch_tensor_with_padding put them.
        # Confirm callers move them.
        if len(graph0.attribute_names) != 0:
            attributes, attributes_mask = batch_tensor_with_padding(
                [graph.attributes for graph in graph_list], 0
            )
            # Fresh leaf copy so gradients do not flow into per-graph tensors.
            attributes = Parameter(
                deepcopy(attributes.data),
                requires_grad=graph0.requires_grad)
        else:
            attributes, attributes_mask = None, None
        if len(graph0.relation_names) != 0:
            relations, relations_mask = batch_tensor_with_padding(
                [graph.relations for graph in graph_list], 0
            )
            relations = Parameter(
                deepcopy(relations.data),
                requires_grad=graph0.requires_grad)
        else:
            relations, relations_mask = None, None
        y = Parameter(
            torch.tensor([item[1] for item in item_list]),
            requires_grad=graph0.requires_grad).to(self.device)
        indices = [item[2] for item in item_list]
        results = (attributes, attributes_mask,
                   relations, relations_mask,
                   graph0.attribute_names, graph0.relation_names), y, indices

        if len(self.cached_results) >= self.cached_num:
            self.cached_results.popitem(last=False)
        self.cached_results[cache_key] = results
        return results

    @staticmethod
    def uncollate_fn(batch, all_graphs):
        """Write batched attribute/relation tensors back into the graphs.

        ``batch`` can take either of two forms:

        * the flattened 8-tuple ``(attributes, attributes_mask, relations,
          relations_mask, attribute_names, relation_names, y, indices)``;
        * the nested 3-tuple returned by ``collate_fn``.

        For each batch position ``k``, the leading ``n_nodes`` entries along
        the node axis (axes) are copied in place into
        ``all_graphs[indices[k]].attributes.data`` and ``.relations.data``.
        Padding beyond ``n_nodes`` is ignored. A ``None`` attributes or
        relations tensor is skipped.
        """
        if len(batch) == 3:
            graph_part, _, indices = batch
            attributes, _, relations = graph_part[:3]
        else:
            attributes, _, relations, _, _, _, _, indices = batch

        # Presumably per-graph attributes are (n_attr, n_nodes) and relations
        # (n_rel, n_nodes, n_nodes) -- TODO confirm against the graph class.
        if attributes is not None:
            for _attr, i in zip(attributes, indices):
                _graph = all_graphs[i]
                _graph.attributes.data[:, :] = _attr[:, :_graph.n_nodes]
        if relations is not None:
            for _rel, i in zip(relations, indices):
                _graph = all_graphs[i]
                n_nodes = _graph.n_nodes
                _graph.relations.data[:, :] = _rel[:, :n_nodes, :n_nodes]


class SizeBasedSampler(Sampler):
    """Batch sampler that packs dataset indices into cost-bounded batches.

    Items are sorted by ascending size. They are then appended greedily to
    the current batch until adding the next one would push the batch's
    summed cost (``size_fn(size)``) above ``total_size``, at which point a
    new batch is started. An item whose cost alone exceeds ``total_size``
    ends up in a batch of its own. Iteration yields one list of dataset
    indices per batch, in batch-creation order or shuffled.
    """

    def __init__(self, sizes, size_fn, total_size, shuffle=False):
        """
        Args:
            sizes: per-item sizes, indexed like the dataset.
            size_fn: maps an item's size to its cost.
            total_size: maximum summed cost per batch (soft for single
                oversized items).
            shuffle: if true, batch order is re-permuted on every iteration.
        """
        self.sizes = sizes
        self.size_fn = size_fn
        self.total_size = total_size  # per-batch cost cap
        self.shuffle = shuffle

        self.batch_indices = [[]]
        self.total_sizes = [0]  # accumulated cost of each batch
        for i, _size in sorted(enumerate(sizes), key=lambda x: x[1]):
            this_size = size_fn(_size)
            overflows = self.total_sizes[-1] + this_size > total_size
            # Only start a new batch when the current one is non-empty;
            # otherwise an oversized first item would leave an empty batch.
            if self.batch_indices[-1] and overflows:
                self.batch_indices.append([])
                self.total_sizes.append(0)
            self.batch_indices[-1].append(i)
            self.total_sizes[-1] += this_size
        if not self.batch_indices[-1]:
            # Only reachable when `sizes` is empty.
            self.batch_indices.pop()
            self.total_sizes.pop()

    def __iter__(self):
        if self.shuffle:
            permutation = np.random.permutation(len(self))
        else:
            permutation = list(range(len(self)))
        for i in permutation:
            yield self.batch_indices[i]

    def __len__(self):
        return len(self.batch_indices)


def get_size_fn(constant, power):
    """Build a cost function computing ``constant * n ** power``.

    Args:
        constant: multiplicative factor applied to the power term.
        power: exponent applied to the input ``n``.

    Returns:
        A one-argument callable named ``size_fn``.
    """
    def size_fn(n):
        powered = n ** power
        return constant * powered

    return size_fn


def get_net_size_fn(n_predicates, args):
    """Return ``get_size_fn(multiplier, args.n_variables)``.

    The multiplier is ``n_predicates + args.n_width * args.n_layers + 1``.

    Args:
        n_predicates: count added into the multiplier.
        args: object exposing ``n_width``, ``n_layers`` and ``n_variables``.
    """
    multiplier = n_predicates + args.n_width * args.n_layers + 1
    exponent = args.n_variables
    return get_size_fn(multiplier, exponent)
