"""
The Mutagenicity (MUTAG) molecular graph dataset, with preprocessing adapted from the GSAT implementation in
`GMT <https://github.com/LFhase/GMT/blob/main/GSAT/datasets/mutag.py>`_.
"""
import math
import os
import os.path as osp
import pickle as pkl
from collections import defaultdict
import matplotlib.pyplot as plt

import torch
from munch import Munch
from torch_geometric.data import Data, InMemoryDataset, extract_zip
from torch_geometric.utils import from_networkx, shuffle_node
from torch_geometric.data.separate import separate

from tqdm import tqdm
from sklearn.model_selection import train_test_split

from GOOD import register
from GOOD.utils.synthetic_data.BA3_loc import *


@register.dataset_register
class MUTAG(InMemoryDataset):
    r"""
    The MUTAGENICITY dataset.

    This class is registered as a dataset entry point and is meant to be used only through the
    :meth:`load` staticmethod; constructing an instance always fails (see :meth:`__init__`).

    Node type mapping:
        0	C
        1	O
        2	Cl
        3	H
        4	N
        5	F
        6	Br
        7	S
        8	P
        9	I
        10	Na
        11	K
        12	Li
        13	Ca

    Label mapping:
        0: mutagen
        1: non-mutagen
    """

    def __init__(self, root: str, domain: str, shift: str = 'no_shift', subset: str = 'train', transform=None,
                 pre_transform=None, generate: bool = False, debias=False):
        r"""
        Record a few metadata attributes and then abort: direct instantiation is not supported.

        Raises:
            AssertionError: Always. Use :meth:`load` to obtain the dataset splits.
        """
        self.name = self.__class__.__name__
        self.domain = domain
        self.minority_class = None
        self.metric = 'Accuracy'
        self.task = 'Multi-label classification'
        self.url = ''

        # Raised explicitly instead of `assert False` so the guard is not stripped under `python -O`.
        raise AssertionError(f"{self.name} cannot be instantiated directly; use {self.name}.load() instead")

    @property
    def raw_dir(self):
        """Directory holding the raw files; this is the dataset root itself."""
        return osp.join(self.root)

    @staticmethod
    def load(dataset_root: str, domain: str, shift: str = 'no_shift', generate: bool = False, debias: bool = False,
             model_name=None, add_pos_feat=None):
        r"""
        A staticmethod for dataset loading. It builds the :class:`Mutagenicity` dataset stored under
        ``dataset_root + "/Mutagenicity/"``, splits it randomly into 80% train / 10% validation / 10% test
        (using the global NumPy random state, so the split depends on the current NumPy seed), and collects
        dataset meta information.

        Args:
            dataset_root (str): The dataset saving root.
            domain (str): The domain selection. Only 'basis' is supported.
            shift (str): The distributional shift. Only 'no_shift' is supported.
            generate (bool): Not used by this loader.
            debias (bool): Not used by this loader.
            model_name (str, optional): Name of the model that will consume the data. If it contains "DIR",
                labels are squeezed along the last dimension and cast to ``long``, ``num_classes`` is the number
                of distinct labels and the task is 'Multi-label classification'. Otherwise (including ``None``)
                ``num_classes`` is 1 and the task is 'Binary classification'.
            add_pos_feat: Not used by this loader.

        Returns:
            A tuple ``(splits, meta_info)``. ``splits`` is a dict with the keys 'train', 'id_val', 'id_test',
            'val', 'test', 'metric' and 'task'; 'val' and 'test' refer to the same objects as 'id_val' and
            'id_test'. ``meta_info`` is a :class:`Munch` with ``dataset_type``, ``model_level``, ``dim_node``,
            ``dim_edge``, ``edge_feat_dims``, ``num_envs`` and ``num_classes``.

        Raises:
            AssertionError: If ``domain`` is not 'basis' or ``shift`` is not 'no_shift'.
        """
        # Explicit raise (same exception type and message as the former assert) so it survives `python -O`.
        if not (domain == "basis" and shift == "no_shift"):
            raise AssertionError(f"{domain} - {shift} not supported")

        # `model_name` defaults to None; treat that as "not a DIR model" instead of failing on the `in` test.
        is_dir_model = model_name is not None and "DIR" in model_name

        meta_info = Munch()
        meta_info.dataset_type = 'syn'
        meta_info.model_level = 'graph'

        dataset = Mutagenicity(dataset_root + "/Mutagenicity/")

        if is_dir_model:
            dataset._data.y = dataset._data.y.squeeze(-1).long()
        else:
            # NOTE(review): unlike the DIR branch, this assigns `y` on the dataset object and not on
            # `dataset._data` -- confirm this has the intended effect on the stored labels.
            dataset.y = dataset.y.squeeze().float()

        idx = np.arange(len(dataset))
        np.random.shuffle(idx)

        # Split coefficients from https://github.com/LFhase/GMT/blob/main/GSAT/configs/GIN-mutag.yml
        n_train, n_valid = int(0.8 * len(idx)), int(0.1 * len(idx))
        index_train = idx[:n_train]
        index_val = idx[n_train:n_train + n_valid]
        index_test = idx[n_train + n_valid:]

        train_dataset = dataset[index_train]
        id_val_dataset = dataset[index_val]
        id_test_dataset = dataset[index_test]

        meta_info.dim_node = train_dataset.num_node_features
        meta_info.dim_edge = train_dataset.num_edge_features
        meta_info.edge_feat_dims = 0

        meta_info.num_envs = 1
        meta_info.num_classes = 1

        # Define networks' output shape.
        if is_dir_model:
            print("The task in 'Multi-label classification'")
            meta_info.num_classes = torch.unique(train_dataset._data.y).shape[0]
            task = 'Multi-label classification'
        else:
            task = 'Binary classification'

        for split in (train_dataset, id_val_dataset, id_test_dataset):
            split.minority_class = None
            split.metric = 'Accuracy'

        # --- clear buffer dataset._data_list ---
        train_dataset._data_list = None
        if id_val_dataset:
            id_val_dataset._data_list = None
            id_test_dataset._data_list = None

        return {'train': train_dataset, 'id_val': id_val_dataset, 'id_test': id_test_dataset,
                'metric': 'Accuracy', 'task': task,
                'val': id_val_dataset, 'test': id_test_dataset}, meta_info

class Mutagenicity(InMemoryDataset):
    """
        From https://github.com/LFhase/GMT/blob/main/GSAT/datasets/mutag.py

        In-memory dataset built from the raw ``Mutagenicity_*.txt`` files and ``Mutagenicity.pkl``.
        Each graph is stored as a ``Data`` object with the fields ``x``, ``y`` (float, shape ``(1, 1)``),
        ``edge_index``, ``node_label``, ``edge_label`` and ``node_type``.
    """
    def __init__(self, root):
        """Build (or reuse) the processed dataset under ``root`` and load it into memory."""
        print("root=", root)
        super().__init__(root=root)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Names of the raw files expected in ``raw_dir``."""
        return ['Mutagenicity_A.txt', 'Mutagenicity_edge_gt.txt', 'Mutagenicity_edge_labels.txt',
                'Mutagenicity_graph_indicator.txt', 'Mutagenicity_graph_labels.txt', 'Mutagenicity_label_readme.txt',
                'Mutagenicity_node_labels.txt', 'Mutagenicity.pkl']

    @property
    def processed_file_names(self):
        """Name of the single processed file written by :meth:`process`."""
        return ['data.pt']

    def process(self):
        """
        Convert the raw files into a collated list of ``Data`` objects and save it.

        Node features come from ``Mutagenicity.pkl`` (truncated to the real number of nodes); edges, graph
        labels, ground-truth edge labels and node types come from :meth:`get_graph_data`. Ground-truth edge
        labels are kept only for graphs with label 0 and zeroed otherwise; ``node_label`` marks the endpoints
        of ground-truth edges for label-0 graphs.

        Raises:
            AssertionError: If the padded part of a feature matrix is not all zero, or if the node types of
                the ground-truth nodes are not one of the expected combinations.
        """
        # NOTE(security): unpickling can execute arbitrary code -- only use a trusted Mutagenicity.pkl.
        with open(self.raw_dir + '/Mutagenicity.pkl', 'rb') as fin:
            _, original_features, original_labels = pkl.load(fin)

        edge_lists, graph_labels, edge_label_lists, node_type_lists = self.get_graph_data()

        print(original_labels.shape)

        data_list = []
        for i in range(original_labels.shape[0]):
            node_type = torch.tensor(node_type_lists[i])
            num_nodes = len(node_type_lists[i])
            # reshape(-1, 2) keeps a well-formed (2, 0) edge_index for a graph without edges.
            edge_index = torch.tensor(edge_lists[i], dtype=torch.long).reshape(-1, 2).T

            y = torch.tensor(graph_labels[i]).float().reshape(-1, 1)
            x = torch.tensor(original_features[i][:num_nodes]).float()
            # Rows beyond `num_nodes` are expected to be zero padding.
            assert original_features[i][num_nodes:].sum() == 0

            is_label_zero = y.item() == 0
            edge_label = torch.tensor(edge_label_lists[i]).float()
            if not is_label_zero:
                # Ground-truth edges are only kept for graphs with label 0.
                edge_label = torch.zeros_like(edge_label).float()

            node_label = torch.zeros(x.shape[0])
            signal_nodes = list(set(edge_index[:, edge_label.bool()].reshape(-1).tolist()))
            if is_label_zero:
                node_label[signal_nodes] = 1

            if signal_nodes:
                signal_types = set(node_type[signal_nodes].tolist())
                assert signal_types in ({4, 1}, {4, 3}, {4, 1, 3})  # NO or NH

            # The original code skipped label-0 graphs without any ground-truth node, but then the
            # statistics in Table 8 do not match. In our implementation, we keep all graphs.
            data_list.append(Data(x=x, y=y, edge_index=edge_index, node_label=node_label,
                                  edge_label=edge_label, node_type=node_type))

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def get_graph_data(self):
        """
        Parse the raw ``Mutagenicity_*.txt`` files into per-graph Python lists.

        Node ids in the raw files are 1-based and global; they are converted to 0-based ids local to each
        graph. If the edge ground-truth file or the node label file cannot be loaded, all-zero labels are
        used instead (best effort, the error is printed).

        Returns:
            A tuple ``(edge_lists, graph_labels, edge_label_lists, node_label_lists)`` where ``edge_lists[g]``
            is the list of ``(source, target)`` pairs of graph ``g``, ``graph_labels`` is the array of graph
            labels, ``edge_label_lists[g]`` holds one label per edge of graph ``g`` and ``node_label_lists[g]``
            holds one label per node of graph ``g``. Every graph gets an entry, even if it has no edges.

        Raises:
            ValueError: If an edge connects nodes that belong to different graphs.
        """
        prefix = self.raw_dir + '/Mutagenicity_'

        # `ndmin` keeps the expected rank even when a file contains a single row.
        edges = np.loadtxt(prefix + 'A.txt', delimiter=',', ndmin=2).astype(np.int32)
        try:
            edge_labels = np.loadtxt(prefix + 'edge_gt.txt', delimiter=',', ndmin=1).astype(np.int32)
        except Exception as e:  # deliberate best-effort fallback
            print(e)
            print('use edge label 0')
            edge_labels = np.zeros(edges.shape[0]).astype(np.int32)

        graph_indicator = np.loadtxt(prefix + 'graph_indicator.txt', delimiter=',', ndmin=1).astype(np.int32)
        graph_labels = np.loadtxt(prefix + 'graph_labels.txt', delimiter=',', ndmin=1).astype(np.int32)

        try:
            node_labels = np.loadtxt(prefix + 'node_labels.txt', delimiter=',', ndmin=1).astype(np.int32)
        except Exception as e:  # deliberate best-effort fallback
            print(e)
            print('use node label 0')
            node_labels = np.zeros(graph_indicator.shape[0]).astype(np.int32)

        # starts[g] is the (1-based) id of the first node of graph g; node2graph maps node id -> graph index.
        graph_id = 1
        starts = [1]
        node2graph = {}
        for pos, indicator in enumerate(graph_indicator):
            if indicator != graph_id:
                graph_id = indicator
                starts.append(pos + 1)
            node2graph[pos + 1] = len(starts) - 1
        num_graphs = len(starts)

        # One bucket per graph, so a graph without edges keeps its slot and later graphs are not shifted.
        edge_lists = [[] for _ in range(num_graphs)]
        edge_label_lists = [[] for _ in range(num_graphs)]
        for (s, t), edge_label in zip(edges, edge_labels):
            sgid = node2graph[s]
            tgid = node2graph[t]
            if sgid != tgid:
                raise ValueError('edges connecting different graphs, error here, please check. '
                                 f'{s} {t} graph id {sgid} {tgid}')
            start = starts[sgid]
            edge_lists[sgid].append((s - start, t - start))
            edge_label_lists[sgid].append(edge_label)

        # node labels
        node_label_lists = [[] for _ in range(num_graphs)]
        for pos, node_label in enumerate(node_labels):
            node_label_lists[node2graph[pos + 1]].append(node_label)

        return edge_lists, graph_labels, edge_label_lists, node_label_lists
