import os
import pickle
import random

import matplotlib.pyplot as plt
import networkx as nx
from torch.utils.data import Dataset


def get_previous(i, v_max):
    """Return the index preceding ``i``, wrapping index 0 around to ``v_max``.

    Args:
        i: current index.
        v_max: index returned when ``i`` is 0.

    Returns:
        ``v_max`` if ``i`` is 0, otherwise ``i - 1``.
    """
    return v_max if i == 0 else i - 1


def get_next(i, v_max):
    """Return the index following ``i``, wrapping ``v_max`` around to 0.

    Args:
        i: current index.
        v_max: index after which the result wraps back to 0.

    Returns:
        0 if ``i`` equals ``v_max``, otherwise ``i + 1``.
    """
    return 0 if i == v_max else i + 1


def is_cycle(g):
    """Check whether ``g`` is one cycle through consecutively numbered nodes.

    Args:
        g: graph object providing ``number_of_nodes()`` and
            ``successors(node)``; the value returned by ``successors`` must
            support ``len`` and the ``in`` operator.

    Returns:
        ``True`` only if the graph has at least three nodes and the
        successors of every node ``k`` are exactly ``k - 1`` and ``k + 1``,
        with the indices wrapping around at both ends of the node range.
    """
    num_nodes = g.number_of_nodes()

    # Fewer than three nodes are never accepted as a cycle.
    if num_nodes < 3:
        return False

    last = num_nodes - 1

    def _has_ring_neighbours(v):
        # Exactly two successors, namely the ring predecessor and the ring
        # successor of v.
        succ = g.successors(v)
        return (
            len(succ) == 2
            and get_previous(v, last) in succ
            and get_next(v, last) in succ
        )

    # all() stops at the first node that breaks the pattern.
    return all(_has_ring_neighbours(v) for v in range(num_nodes))


def get_decision_sequence(size):
    """
    Build the decision sequence that makes DGMG generate a cycle with
    ``size`` nodes, for teacher forcing optimization.

    For every node the sequence holds ``0`` (add node), then for every node
    but the first ``0, node - 1`` (add an edge to the previous node), then
    for the last node ``0, 0`` (add an edge back to the root), and finally
    ``1`` (stop adding edges). A trailing ``1`` stops adding nodes.

    Args:
        size: number of nodes of the cycle.

    Returns:
        The decisions as a flat list of integers.
    """
    steps = []
    last = size - 1

    for node in range(size):
        steps.append(0)  # Add node

        if node > 0:
            # Add edge, destination is the previous node.
            steps.extend((0, node - 1))

        if node == last:
            # Add edge, destination is the root; this closes the cycle.
            steps.extend((0, 0))

        steps.append(1)  # Stop adding edge

    steps.append(1)  # Stop adding node

    return steps


def generate_dataset(v_min, v_max, n_samples, fname):
    """Write ``n_samples`` cycle decision sequences to a pickle file.

    Each sample is the decision sequence of a cycle whose size is drawn with
    ``random.randint(v_min, v_max)``, i.e. both bounds are inclusive.

    Args:
        v_min: smallest cycle size.
        v_max: largest cycle size.
        n_samples: number of sequences to generate.
        fname: path of the pickle file to write; it is opened in ``"wb"``
            mode.
    """
    sequences = [
        get_decision_sequence(random.randint(v_min, v_max))
        for _ in range(n_samples)
    ]

    with open(fname, "wb") as out:
        pickle.dump(sequences, out)


class CycleDataset(Dataset):
    """Dataset whose items are the entries of a pickled sequence.

    Args:
        fname: path of the pickle file to load.
    """

    def __init__(self, fname):
        super().__init__()

        # NOTE: unpickling can run arbitrary code, so only load trusted files.
        with open(fname, "rb") as source:
            self.dataset = pickle.load(source)

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the sample stored at position ``index``."""
        return self.dataset[index]

    def collate_single(self, batch):
        """Unwrap a batch that holds exactly one sample.

        An ``AssertionError`` is raised for any other batch size.
        """
        assert len(batch) == 1, "Currently we do not support batched training"
        return batch[0]

    def collate_batch(self, batch):
        """Return ``batch`` unchanged."""
        return batch


def dglGraph_to_adj_list(g):
    """Convert a graph into a ``{node: [successor, ...]}`` dictionary.

    Args:
        g: graph object providing ``number_of_nodes()`` and
            ``successors(node)``; the value returned by ``successors`` must
            offer ``tolist()``.

    Returns:
        A dict with one key per node ``0 .. number_of_nodes() - 1`` that maps
        to the list of the node's successors.
    """
    # For an undirected graph successors and predecessors are equivalent, so
    # the successor lists alone describe the adjacency.
    return {v: g.successors(v).tolist() for v in range(g.number_of_nodes())}


class CycleModelEvaluation(object):
    """Sample graphs from a model and measure how many are valid cycles.

    Args:
        v_min: smallest node count that counts as a valid size.
        v_max: largest node count that counts as a valid size.
        dir: directory that receives ``model_eval.txt`` and the ``samples``
            sub-directory holding the plotted graphs.
    """

    # Number of sampled graphs drawn on one saved figure (2 x 2 grid).
    _GRAPHS_PER_FIGURE = 4

    def __init__(self, v_min, v_max, dir):
        super(CycleModelEvaluation, self).__init__()

        self.v_min = v_min
        self.v_max = v_max

        self.dir = dir

    def _save_sample_grid(self, adj_lists, index):
        """Draw the given adjacency lists on a 2 x 2 grid and save the figure.

        The figure is written to ``<dir>/samples/<index>``.
        """
        samples_dir = os.path.join(self.dir, "samples")
        # savefig does not create missing directories, so make sure the
        # target exists before writing.
        os.makedirs(samples_dir, exist_ok=True)

        fig, axes = plt.subplots(2, 2)
        for ax, adj_list in zip(axes.flat, adj_lists):
            nx.draw_circular(
                nx.from_dict_of_lists(adj_list),
                with_labels=True,
                ax=ax,
            )

        fig.savefig(os.path.join(samples_dir, "{:d}".format(index)))
        plt.close(fig)

    def rollout_and_examine(self, model, num_samples):
        """Generate ``num_samples`` graphs with ``model`` and gather statistics.

        Every four sampled graphs are plotted together and saved under
        ``<dir>/samples``; a trailing group of fewer than four graphs is not
        plotted. The results are stored on the instance as
        ``num_samples_examined``, ``average_size``, ``valid_size_ratio``,
        ``cycle_ratio`` and ``valid_ratio``.

        Args:
            model: callable without arguments that returns a graph, or a
                list of graphs of which the first one is used. Its
                ``training`` attribute must be false.
            num_samples: number of graphs to sample; must be positive.

        Raises:
            ValueError: if ``num_samples`` is not positive.
        """
        assert not model.training, "You need to call model.eval()."

        # The ratios below divide by num_samples, so reject empty runs with a
        # clear message instead of failing with a ZeroDivisionError later.
        if num_samples <= 0:
            raise ValueError(
                "num_samples must be positive, got {}".format(num_samples)
            )

        num_total_size = 0
        num_valid_size = 0
        num_cycle = 0
        num_valid = 0
        plot_times = 0
        adj_lists_to_plot = []

        for _ in range(num_samples):
            sampled_graph = model()
            if isinstance(sampled_graph, list):
                # When the model is a batched implementation, a list of
                # DGLGraph objects is returned. Note that with model(),
                # we generate a single graph as with the non-batched
                # implementation. We actually support batched generation
                # during the inference so feel free to modify the code.
                sampled_graph = sampled_graph[0]

            adj_lists_to_plot.append(dglGraph_to_adj_list(sampled_graph))

            graph_size = sampled_graph.number_of_nodes()
            valid_size = self.v_min <= graph_size <= self.v_max
            cycle = is_cycle(sampled_graph)

            num_total_size += graph_size

            if valid_size:
                num_valid_size += 1

            if cycle:
                num_cycle += 1

            if valid_size and cycle:
                num_valid += 1

            if len(adj_lists_to_plot) >= self._GRAPHS_PER_FIGURE:
                plot_times += 1
                self._save_sample_grid(adj_lists_to_plot, plot_times)
                adj_lists_to_plot = []

        self.num_samples_examined = num_samples
        self.average_size = num_total_size / num_samples
        self.valid_size_ratio = num_valid_size / num_samples
        self.cycle_ratio = num_cycle / num_samples
        self.valid_ratio = num_valid / num_samples

    def write_summary(self):
        """Write the collected statistics to ``<dir>/model_eval.txt``.

        Each line holds a statistic name and its value separated by a tab;
        floats are written with four decimals. The statistics are set by
        ``rollout_and_examine``, which therefore has to run first.
        """

        def _format_value(v):
            if isinstance(v, float):
                return "{:.4f}".format(v)
            elif isinstance(v, int):
                return "{:d}".format(v)
            else:
                return "{}".format(v)

        statistics = {
            "num_samples": self.num_samples_examined,
            "v_min": self.v_min,
            "v_max": self.v_max,
            "average_size": self.average_size,
            "valid_size_ratio": self.valid_size_ratio,
            "cycle_ratio": self.cycle_ratio,
            "valid_ratio": self.valid_ratio,
        }

        model_eval_path = os.path.join(self.dir, "model_eval.txt")

        with open(model_eval_path, "w") as f:
            for key, value in statistics.items():
                msg = "{}\t{}\n".format(key, _format_value(value))
                f.write(msg)

        print("Saved model evaluation statistics to {}".format(model_eval_path))


class CyclePrinting(object):
    """Print one progress line per processed batch.

    Args:
        num_epochs: total number of epochs, shown in every line.
        num_batches: number of batches per epoch; the batch counter wraps
            around after this many updates.
    """

    def __init__(self, num_epochs, num_batches):
        super(CyclePrinting, self).__init__()

        self.num_epochs = num_epochs
        self.num_batches = num_batches
        self.batch_count = 0

    def update(self, epoch, metrics):
        """Advance the batch counter and print the current progress.

        Args:
            epoch: epoch number shown in the line.
            metrics: mapping from metric name to a numeric value; every
                value is printed with four decimals.
        """
        # Runs through 1..num_batches and then starts over at 1.
        self.batch_count = self.batch_count % self.num_batches + 1

        parts = [
            "epoch {:d}/{:d}, batch {:d}/{:d}".format(
                epoch, self.num_epochs, self.batch_count, self.num_batches
            )
        ]
        # "{:.4f}" fixes the precision; the former "{:4f}" only set a minimum
        # width of 4 and therefore still printed six decimals.
        parts.extend(
            "{}: {:.4f}".format(key, value) for key, value in metrics.items()
        )
        print(", ".join(parts))
