from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx
import os

class BaseDataset(Dataset):
    """Base dataset with hooks for regeneration, curricula and visualization.

    ``regen``, ``set_curriculum_size`` and ``visualize`` raise
    ``NotImplementedError`` here and are meant to be overridden by
    subclasses. ``plot_graph`` is a shared helper that draws a graph and
    either saves it as a PDF or sends it to an image logger.
    """

    def __init__(self):
        super().__init__()

    def regen(self):
        """Regenerate the dataset; the base implementation raises ``NotImplementedError``."""
        raise NotImplementedError

    def set_curriculum_size(self, size):
        """Set the curriculum size; the base implementation raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    def normalize(data):
        """Scale ``data`` by its maximum and move the last axis to the front.

        Args:
            data: 3-D array, e.g. an ``(H, W, C)`` image as produced in
                ``plot_graph``.

        Returns:
            ``data / max(data)`` transposed with axes ``(2, 0, 1)``, i.e.
            ``(H, W, C) -> (C, H, W)``. An all-zero input comes back as
            zeros rather than NaNs.
        """
        peak = np.max(data)
        # An all-zero image would otherwise become 0/0 = NaN everywhere;
        # dividing by 1 keeps it zero and goes through the same np.divide
        # dtype promotion as the normal path.
        data = np.divide(data, peak if peak != 0 else 1)
        return np.transpose(data, (2, 0, 1))

    def plot_graph(self, G, title, writer, savepath=None):
        """Draw graph ``G`` and either save it as a PDF or log it as an image.

        Args:
            G: networkx graph, drawn with a spring layout.
            title: Used as the PDF file stem and as the image tag.
            writer: Object exposing ``add_image(tag, img)`` (presumably a
                TensorBoard ``SummaryWriter`` — confirm with callers). Only
                used when ``savepath`` is ``None``; it receives a float
                ``(3, H, W)`` array scaled by its maximum.
            savepath: Directory in which ``<title>.pdf`` is written; created
                if missing. When ``None`` the figure is sent to ``writer``.
        """
        fig = plt.figure(figsize=(8, 6), dpi=300)
        try:
            nx.draw_networkx(G, pos=nx.spring_layout(G), with_labels=False,
                             edge_color='grey', width=1, node_size=200,
                             cmap=plt.get_cmap('Set1'), alpha=0.8)

            if savepath is not None:
                os.makedirs(savepath, exist_ok=True)
                fig.savefig(os.path.join(savepath, title + '.pdf'))
            else:
                # Rasterize only when an image is actually needed.
                writer.add_image(title, self.normalize(self._figure_to_rgb(fig)))
        finally:
            # Release this specific figure even if drawing/saving/logging
            # raised, so repeated calls don't accumulate open figures.
            plt.close(fig)

    @staticmethod
    def _figure_to_rgb(fig):
        """Render ``fig`` and return its pixels as an ``(H, W, 3)`` uint8 array."""
        fig.canvas.draw()
        # buffer_rgba() already has the renderer's true (H, W, 4) pixel shape,
        # unlike get_width_height() (logical pixels), and it replaces the
        # deprecated tostring_rgb()/np.fromstring combination. np.array copies
        # so the result does not alias the figure's buffer after close.
        return np.array(fig.canvas.buffer_rgba())[..., :3]

    def visualize(self, writer):
        """Produce dataset visualizations (presumably via ``writer``).

        The base implementation raises ``NotImplementedError``.
        """
        raise NotImplementedError