import numpy as np
from sklearn.model_selection import StratifiedKFold

import torch

import sys
sys.path.append("/root/autodl-tmp")


from data.citation.gen_data import CitationDataset  # type: ignore
from data.wiki.gen_data import WikiDataset  # type: ignore
from data.KG.gen_data import KGOFADataset  # type: ignore
from data.chembl.gen_data import MolOFADataset  # type: ignore
from data.RS.gen_data import RecommendationDataset  # type: ignore
from data.GQA.gen_data import GQADataset  # type: ignore
from data.SN.gen_data import SocialDataset  # type: ignore

# Dataset name -> dataset class. Entries are grouped by class; each class is
# instantiated as ``cls(dataset_name, graph_encoder, root)`` by
# UnifiedTaskConstructor.get_ofa_data.
name2dataset = {
    dataset_name: dataset_cls
    for dataset_cls, dataset_names in (
        (CitationDataset, ("arxiv", "cora", "pubmed", "citeseer")),
        (WikiDataset, ("wikics",)),
        (KGOFADataset, ("WN18RR", "FB15K237")),
        (MolOFADataset, ("chempcba", "chemhiv", "chembbbp", "chembace",
                         "chemtoxcast", "chemtox21", "chemcyp450", "chemmuv")),
        (RecommendationDataset, ("bookhis", "sportsfit", "elecomp", "elephoto", "products")),
        (GQADataset, ("expla_graph", "scene_graphs", "webqsp")),
        (SocialDataset, ("instagram", "reddit")),
    )
    for dataset_name in dataset_names
}


def k_fold_ind(labels, fold):
    """Partition sample indices into ``fold`` stratified, shuffled folds.

    The shuffle uses a fixed ``random_state`` of 10, so repeated calls with
    the same labels produce the same folds.

    Arguments:
        labels {array-like} -- per-sample labels; cast to ``int`` for stratification
        fold {int} -- number of folds

    Returns:
        list[numpy.ndarray] -- one array per fold holding the indices of the
        samples assigned to that fold.
    """
    placeholder_features = np.zeros_like(np.array(labels))  # StratifiedKFold only needs X for its length
    int_labels = np.array(labels, dtype=int)
    splitter = StratifiedKFold(n_splits=fold, shuffle=True, random_state=10)
    return [
        held_out
        for _unused_train, held_out in splitter.split(placeholder_features, int_labels)
    ]


def k_fold2_split(folds, data_len):
    """Build one train/test/validation split per fold.

    For split ``i``, fold ``i`` is the test set and fold ``(i + 1) % len(folds)``
    is the validation set. Every remaining index is used for training.

    Arguments:
        folds {list[numpy.ndarray]} -- per-fold index arrays
        data_len {int} -- length of the data

    Returns:
        list[list[numpy.ndarray]] -- ``[train_ind, test_ind, val_ind]`` for each
        fold, in fold order.
    """
    n_folds = len(folds)

    def _as_mask(indices):
        # Boolean membership mask over all ``data_len`` positions.
        mask = np.zeros(data_len, dtype=bool)
        mask[indices] = True
        return mask

    splits = []
    for position, test_fold in enumerate(folds):
        test_mask = _as_mask(test_fold)
        val_mask = _as_mask(folds[(position + 1) % n_folds])
        train_mask = ~(test_mask | val_mask)
        splits.append([
            np.flatnonzero(train_mask),
            np.flatnonzero(test_mask),
            np.flatnonzero(val_mask),
        ])
    return splits


def ArxivSplitter(dataset):
    """Split by stratified 10-fold CV over ``dataset.data.y``, using the first fold split.

    Returns:
        dict -- ``{"train", "valid", "test"}`` mapped to numpy index arrays.
    """
    labels = dataset.data.y
    fold_indices = k_fold_ind(labels, 10)
    train_idx, second_idx, third_idx = k_fold2_split(fold_indices, len(labels))[0]
    # NOTE(review): k_fold2_split orders each split as [train, test, validation],
    # so "valid" below receives its test fold and "test" its validation fold.
    # Both are disjoint single held-out folds; the mapping is kept unchanged so
    # previously generated splits stay reproducible.
    return {"train": train_idx, "valid": second_idx, "test": third_idx}


def CiteSplitter(dataset):
    """Use the first of the citation graph's stored split masks.

    Returns:
        dict -- ``{"train", "valid", "test"}`` mapped to torch index tensors,
        one for each nonzero entry of ``train_masks[0]``, ``val_masks[0]`` and
        ``test_masks[0]`` respectively.
    """
    graph = dataset.data
    mask_lists = {
        "train": graph.train_masks,
        "valid": graph.val_masks,
        "test": graph.test_masks,
    }
    return {
        part: masks[0].nonzero(as_tuple=True)[0]
        for part, masks in mask_lists.items()
    }


def KGSplitter(dataset):
    """Assign each split of a knowledge-graph dataset a contiguous index range.

    ``dataset.get_idx_split()`` maps split names to sequences whose first
    element has one entry per item in that split. Splits are laid out back to
    back in iteration order, and each receives the ``torch.arange`` range
    immediately following the previous split's range.
    """
    converted = dataset.get_idx_split()
    split = {}
    start = 0
    for key in converted:
        stop = start + len(converted[key][0])
        split[key] = torch.arange(start, stop)
        start = stop
    return split


def WikiSplitter(dataset):
    """Split WikiCS using column 0 of its stacked train/val masks.

    ``train_mask`` and ``val_mask`` are indexed as 2-D (one column per
    provided split), while ``test_mask`` is used as-is.

    Returns:
        dict -- ``{"train", "valid", "test"}`` mapped to numpy index arrays.
    """
    graph = dataset.data
    column = 0  # which of the stacked train/val masks to use

    def _indices(mask):
        return torch.where(mask)[0].numpy()

    return {
        "train": _indices(graph.train_mask[:, column]),
        "valid": _indices(graph.val_mask[:, column]),
        "test": _indices(graph.test_mask),
    }


def _official_split(dataset):
    """Return the split the dataset object itself provides via ``get_idx_split()``."""
    provided = dataset.get_idx_split()
    return provided


def MolSplitter(dataset):
    """Use the molecule dataset's own ``get_idx_split()`` split unchanged."""
    return _official_split(dataset)


def GQASplitter(dataset):
    """Use the graph-QA dataset's own ``get_idx_split()`` split unchanged."""
    return _official_split(dataset)


def SocialSplitter(dataset):
    """Use the social-network dataset's own ``get_idx_split()`` split unchanged."""
    return _official_split(dataset)


# Dataset name -> function that maps a loaded dataset to its split.
# Keys must match the names used in ``name2dataset``.
# NOTE(review): "citeseer" and the recommendation datasets ("bookhis",
# "sportsfit", "elecomp", "elephoto", "products") are registered in
# ``name2dataset`` but have no splitter here, so requesting their split
# raises KeyError.
name2split = {"arxiv": ArxivSplitter, "cora": CiteSplitter, "pubmed": CiteSplitter, "wikics": WikiSplitter,
              "WN18RR": KGSplitter, "FB15K237": KGSplitter,
              "chempcba": MolSplitter, "chemhiv": MolSplitter, "chembbbp": MolSplitter, "chembace": MolSplitter,
              "chemtoxcast": MolSplitter, "chemtox21": MolSplitter, "chemcyp450": MolSplitter, "chemmuv": MolSplitter,
              "scene_graphs": GQASplitter, "expla_graph": GQASplitter, "webqsp": GQASplitter,
              # TODO: recommendation datasets still need a splitter, e.g.
              # "bookhis": OfficialSplitter, "sportsfit": OfficialSplitter,
              # Was misspelled "instgram", which never matched name2dataset's "instagram".
              "reddit": SocialSplitter, "instagram": SocialSplitter,
}


class UnifiedTaskConstructor:
    """Load datasets by name and compute their splits, caching both.

    Dataset classes come from the module-level ``name2dataset`` table and
    split functions from ``name2split``. Results are memoized per dataset
    name in ``self.dataset`` and ``self.split``.
    """

    def __init__(self, graph_encoder, root):
        """
        Arguments:
            graph_encoder -- forwarded unchanged to every dataset constructor.
            root -- forwarded unchanged to every dataset constructor
                (presumably the data root directory; confirm against the
                dataset classes).
        """
        self.root = root
        self.graph_encoder = graph_encoder
        self.dataset = {}  # dataset name -> constructed dataset object
        self.split = {}  # dataset name -> result of that dataset's splitter

    def get_ofa_data(self, dataset_name):
        """Return the dataset registered as ``dataset_name``, building it on first use.

        Raises:
            KeyError -- if ``dataset_name`` is not in ``name2dataset``.
        """
        if dataset_name not in self.dataset:
            if dataset_name not in name2dataset:
                raise KeyError(
                    f"Unknown dataset {dataset_name!r}; "
                    f"available datasets: {sorted(name2dataset)}"
                )
            dataset_cls = name2dataset[dataset_name]
            self.dataset[dataset_name] = dataset_cls(dataset_name, self.graph_encoder, self.root)
        return self.dataset[dataset_name]

    def get_data_split(self, dataset_name):
        """Return the split for ``dataset_name``, computing it on first use.

        The dataset is loaded through ``get_ofa_data`` if needed, so callers
        do not have to load it beforehand.

        Returns:
            The object produced by the dataset's splitter. The splitters in this
            module that build their own result return a dict keyed by split name.

        Raises:
            KeyError -- if ``dataset_name`` has no splitter in ``name2split`` or
            no dataset class in ``name2dataset``.
        """
        if dataset_name not in self.split:
            # Check for a splitter before loading, so an unsplittable name does
            # not trigger an expensive dataset build.
            if dataset_name not in name2split:
                raise KeyError(
                    f"No splitter registered for dataset {dataset_name!r}; "
                    f"available splitters: {sorted(name2split)}"
                )
            splitter = name2split[dataset_name]
            # Load lazily. Indexing self.dataset directly raised a bare KeyError
            # whenever get_ofa_data() had not been called first.
            self.split[dataset_name] = splitter(self.get_ofa_data(dataset_name))
        return self.split[dataset_name]
