#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : rawpixels.py
# Author : Anonymous2, Anonymous1
# Email  : anonymous2@anon, anonymous1@anon
#
# Distributed under terms of the MIT license.

import dgl
from dgl.data.utils import download, save_graphs, load_graphs
from dgl import backend as F
from dgl.data.dgl_dataset import DGLDataset

import os
import numpy as np
from tqdm import tqdm

import torch
from torchvision import datasets, transforms

from megraph.datasets.utils.graph_generation import grid_graph

from . import register_function


class RawPixDGL(DGLDataset):
    r"""Graph-classification dataset built from raw MNIST / CIFAR-10 pixels.

    Every image is turned into one graph whose structure comes from
    ``grid_graph(img_size, img_size)`` and whose node feature ``"feat"`` holds
    the normalized channel values of the image, optionally followed by two
    normalized grid coordinates. The torchvision training split is stored
    first and the test split after it; :meth:`get_idx_split` relies on that
    ordering.

    Parameters
    ----------
    name : str
        Either ``"raw_CIFAR10"`` or ``"raw_MNIST"``.
    sub_name : str
        One prefix character followed by an integer percentage in ``1..100``
        (e.g. ``"s100"``, ``"s10"``) giving the fraction of each split to keep.
    add_pos_feat : bool
        Whether to append the two position columns to the node features.
    raw_dir : str, optional
        Forwarded to :class:`DGLDataset`; the torchvision files are downloaded
        below ``<raw_dir>/<name>``.
    verbose : bool
        Forwarded to :class:`DGLDataset`.
    transform : callable, optional
        Forwarded to :class:`DGLDataset`; applied to a graph in
        :meth:`__getitem__`.

    Raises
    ------
    ValueError
        If ``name`` is not supported or the percentage encoded in ``sub_name``
        is not an integer in ``1..100``.
    """

    # File (inside ``save_path``) holding the processed graphs and labels.
    _GRAPH_FILE = "dgl_graph.bin"
    # Side length, in pixels, of the square images of each supported dataset.
    _IMG_SIZES = {"raw_CIFAR10": 32, "raw_MNIST": 28}

    def __init__(
        self,
        name,
        sub_name="s100",
        add_pos_feat=True,
        raw_dir=None,
        verbose=False,
        transform=None,
    ):
        # Raise instead of assert: assertions are stripped under ``python -O``.
        if name not in self._IMG_SIZES:
            raise ValueError(
                f"name must be one of {sorted(self._IMG_SIZES)}, got {name!r}"
            )
        percent = int(sub_name[1:])
        # Ratios above 1 keep the full data but make get_idx_split() return
        # indices past the end of the dataset, so reject them up front.
        if not 0 < percent <= 100:
            raise ValueError(
                "sub_name must encode a percentage in 1..100 after its first "
                f"character (e.g. 's50'), got {sub_name!r}"
            )
        self.raw_name = name
        self.is_mnist = name == "raw_MNIST"
        self.img_size = self._IMG_SIZES[name]
        self.sub_ratio = percent / 100.0
        self.add_pos_feat = add_pos_feat
        full_name = name
        if not add_pos_feat:
            full_name += "_nopos"
        full_name += "_" + sub_name
        # Everything above must be set before this call; process() and load()
        # read these attributes.
        # NOTE(review): presumably DGLDataset.__init__ triggers them — confirm.
        super().__init__(
            name=full_name,
            raw_dir=raw_dir,
            verbose=verbose,
            transform=transform,
        )

    def _cache_file(self):
        """Return the path of the cached graph file inside ``save_path``."""
        return os.path.join(self.save_path, self._GRAPH_FILE)

    def has_cache(self):
        """Return whether the processed graph file already exists on disk."""
        return os.path.exists(self._cache_file())

    def _get_sub_dataset(self, data, labels):
        """Keep roughly ``sub_ratio`` of the samples of every class.

        For each class the first ``ceil(count * sub_ratio)`` samples are
        selected; the union is shuffled with NumPy's global RNG and truncated
        to ``int(len(labels) * sub_ratio)`` entries.
        """
        idxs = []
        for i in range(self.num_classes):
            class_idx = np.where(labels == i)[0]
            num = np.ceil(len(class_idx) * self.sub_ratio).astype(np.int64)
            idxs.extend(class_idx[:num])
        idxs = np.array(idxs)
        np.random.shuffle(idxs)
        sub_num = int(len(labels) * self.sub_ratio)
        idxs = idxs[:sub_num]
        # Internal invariant: the per-class ceilings sum to at least sub_num.
        assert len(idxs) == sub_num
        return data[idxs], labels[idxs]

    def get_data_and_labels(self, is_train=True):
        """Download one torchvision split and return it as tensors.

        Parameters
        ----------
        is_train : bool
            Selects the torchvision ``train`` split when true, otherwise the
            test split.

        Returns
        -------
        (Tensor, Tensor)
            Normalized images stacked along a new first axis in
            channels-last layout, and the matching labels. Both are
            subsampled when ``sub_ratio < 1``.
        """
        if self.is_mnist:
            data_mean, data_std = [0.1307], [0.3081]
        else:
            data_mean, data_std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=data_mean, std=data_std),
            ]
        )
        base_dataset = datasets.MNIST if self.is_mnist else datasets.CIFAR10
        dataset = base_dataset(
            root=os.path.join(self.raw_dir, self.raw_name),
            transform=transform,
            train=is_train,
            download=True,
        )
        data, labels = [], []
        for (x, label) in dataset:
            data.append(x.permute(1, 2, 0))  # [C, H, W] -> [H, W, C]
            labels.append(label)
        data = torch.stack(data, dim=0)
        labels = np.array(labels)
        if self.sub_ratio < 1:
            data, labels = self._get_sub_dataset(data, labels)
        return data, torch.from_numpy(labels)

    def prepare(self):
        """Load both splits and concatenate them, training samples first."""
        train_data, train_labels = self.get_data_and_labels(is_train=True)
        test_data, test_labels = self.get_data_and_labels(is_train=False)

        self.data = F.cat([train_data, test_data], dim=0)
        self.label = F.cat([train_labels, test_labels], dim=0)
        self.n_samples = len(train_labels) + len(test_labels)

    def process(self):
        """Build one grid graph per image and attach its node features."""
        self.prepare()
        print(f"processing {self.n_samples} graphs")
        sz = self.img_size
        num_nodes = sz**2

        # The grid topology is the same for every image, so build it once.
        # NOTE(review): assumes grid_graph is deterministic and that
        # dgl.from_networkx does not modify its argument — confirm.
        grid = grid_graph(sz, sz)

        pos_feat = None
        if self.add_pos_feat:
            # Row and column index of every node, scaled to [0, 1].
            # NOTE(review): assumes grid_graph numbers nodes in row-major
            # order so that node k matches pixel (k // sz, k % sz) — confirm.
            node_ids = np.arange(num_nodes)
            pos_x = torch.from_numpy(node_ids // sz / (sz - 1)).unsqueeze(-1)
            pos_y = torch.from_numpy(node_ids % sz / (sz - 1)).unsqueeze(-1)
            # NumPy produces float64 here; cast to the pixel dtype so the
            # concatenation below neither promotes the features to float64
            # nor fails on PyTorch versions without type promotion in cat.
            pos_feat = F.cat([pos_x, pos_y], dim=1).to(self.data.dtype)

        self.graphs = []
        for i in tqdm(range(self.n_samples)):
            g = dgl.from_networkx(grid)
            node_feat = self.data[i].reshape(num_nodes, -1)
            if pos_feat is not None:
                node_feat = F.cat([node_feat, pos_feat], dim=1)
            g.ndata["feat"] = node_feat
            self.graphs.append(g)

    def __getitem__(self, idx):
        r"""Get graph and label by index

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        (:class:`dgl.DGLGraph`, Tensor)
        """
        if self._transform is None:
            g = self.graphs[idx]
        else:
            g = self._transform(self.graphs[idx])
        return g, self.label[idx]

    def save(self):
        """Save the graph list and the labels to the cache file."""
        save_graphs(str(self._cache_file()), self.graphs, {"labels": self.label})

    def load(self):
        """Load the graph list and the labels from the cache file."""
        graphs, label_dict = load_graphs(self._cache_file())
        self.graphs = graphs
        self.label = label_dict["labels"]

    def get_idx_split(self):
        """Return contiguous train / valid / test index tensors.

        The split sizes are the original training size (60000 for MNIST,
        50000 for CIFAR-10), 5000 and 10000, each scaled by ``sub_ratio``
        and truncated to an integer. The validation indices are carved from
        the end of the stored training samples.

        Returns
        -------
        dict
            Keys ``"train"``, ``"valid"`` and ``"test"`` mapping to int64
            index tensors.
        """
        num_raw_train = 60000 if self.is_mnist else 50000
        splits = [int(s * self.sub_ratio) for s in (num_raw_train, 5000, 10000)]
        splits[0] -= splits[1]  # num_train = num_raw_train - num_val

        idxs, cur_idx = [], 0
        for s in splits:
            idxs.append(np.arange(cur_idx, cur_idx + s))
            cur_idx += s

        return {
            "train": F.tensor(idxs[0], dtype=F.data_type_dict["int64"]),
            "valid": F.tensor(idxs[1], dtype=F.data_type_dict["int64"]),
            "test": F.tensor(idxs[2], dtype=F.data_type_dict["int64"]),
        }

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return self.label.shape[0]

    def statistics(self):
        """Return a 3-tuple whose middle entry is the number of classes."""
        return None, self.num_classes, None

    @property
    def node_feat_size(self):
        """Width of ``ndata["feat"]``: image channels plus 2 if positions are added."""
        feat_size = 1 if self.is_mnist else 3
        if self.add_pos_feat:
            feat_size += 2
        return feat_size

    @property
    def num_classes(self):
        """Number of label classes (10 for both supported datasets)."""
        return 10


@register_function(
    "raw_cifar10",
    {"task": "gpred", "inductive": True, "need_subname": True},
)
def get_raw_cifar10(raw_dir=None, name=None, **kwargs):
    """Create the raw-pixel CIFAR-10 dataset with position features.

    ``name`` is used as the dataset sub-name and falls back to ``"s100"``
    when it is empty or ``None``; extra keyword arguments are forwarded to
    ``RawPixDGL``.
    """
    if not name:
        name = "s100"
    return RawPixDGL(name="raw_CIFAR10", sub_name=name, raw_dir=raw_dir, **kwargs)


@register_function(
    "raw_mnist",
    {"task": "gpred", "inductive": True, "need_subname": True},
)
def get_raw_mnist(raw_dir=None, name=None, **kwargs):
    """Create the raw-pixel MNIST dataset with position features.

    ``name`` is used as the dataset sub-name and falls back to ``"s100"``
    when it is empty or ``None``; extra keyword arguments are forwarded to
    ``RawPixDGL``.
    """
    if not name:
        name = "s100"
    return RawPixDGL(name="raw_MNIST", sub_name=name, raw_dir=raw_dir, **kwargs)


@register_function(
    "raw_cifar10_nopos", dict(task="gpred", inductive=True, need_subname=True)
)
def get_raw_cifar10_nopos(raw_dir=None, name=None, **kwargs):
    """Create the raw-pixel CIFAR-10 dataset without position features.

    The function has its own name so that it no longer redefines the
    module-level ``get_raw_cifar10``; the registry key
    ``"raw_cifar10_nopos"`` is unchanged.

    ``name`` is used as the dataset sub-name and falls back to ``"s100"``
    when it is empty or ``None``; extra keyword arguments are forwarded to
    ``RawPixDGL``.
    """
    sub_name = name or "s100"
    return RawPixDGL(
        name="raw_CIFAR10",
        sub_name=sub_name,
        raw_dir=raw_dir,
        add_pos_feat=False,
        **kwargs,
    )


@register_function(
    "raw_mnist_nopos", dict(task="gpred", inductive=True, need_subname=True)
)
def get_raw_mnist_nopos(raw_dir=None, name=None, **kwargs):
    """Create the raw-pixel MNIST dataset without position features.

    The function has its own name so that it no longer redefines the
    module-level ``get_raw_mnist``; the registry key ``"raw_mnist_nopos"``
    is unchanged.

    ``name`` is used as the dataset sub-name and falls back to ``"s100"``
    when it is empty or ``None``; extra keyword arguments are forwarded to
    ``RawPixDGL``.
    """
    sub_name = name or "s100"
    return RawPixDGL(
        name="raw_MNIST",
        sub_name=sub_name,
        raw_dir=raw_dir,
        add_pos_feat=False,
        **kwargs,
    )
