"""The function partition_by_category and subgraphing are borrowed from
https://github.com/FedML-AI/FedGraphNN

Copyright [FedML] [Chaoyang He, Salman Avestimehr]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import os.path as osp
import networkx as nx

import torch
from torch_geometric.data import InMemoryDataset, download_url, Data
from torch_geometric.utils import from_networkx


# RecSys
def read_mapping(path, filename):
    """Load an integer-to-integer mapping from a whitespace-separated file.

    Every non-blank line must start with two integer fields: the first one
    becomes the key and the second one the value. Additional fields on a
    line are ignored. When a key occurs several times, the last line wins.

    Arguments:
        path (string): Directory that contains the file.
        filename (string): Name of the file inside ``path``.

    Returns:
        dict: Mapping from the first to the second field of each line.

    Raises:
        ValueError: If one of the first two fields is not an integer.
        IndexError: If a non-blank line holds fewer than two fields.
    """
    mapping = {}
    with open(os.path.join(path, filename)) as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank / whitespace-only lines carry no entry; skipping them
                # avoids an IndexError on e.g. a stray empty line.
                continue
            mapping[int(fields[0])] = int(fields[1])

    return mapping


def partition_by_category(graph, mapping_item2category):
    """Collect, for every node, the list of categories it is attached to.

    Each key of ``mapping_item2category`` is assigned its own category, and
    that category is additionally appended to the list of every neighbor of
    the key in ``graph``.

    Note that visiting a key *resets* its list to ``[own category]``, so any
    categories the key had received earlier as a neighbor of another key are
    discarded. Categories may appear more than once in a neighbor's list.

    Arguments:
        graph: Object exposing ``neighbors(node)`` that yields the nodes
            adjacent to ``node``.
        mapping_item2category (dict): Maps a node to its category.

    Returns:
        dict: Maps each visited node (keys and their neighbors) to a list of
        categories.
    """
    node2categories = {}
    for item, category in mapping_item2category.items():
        # Plain assignment on purpose: replaces whatever was stored before.
        node2categories[item] = [category]
        for adjacent in graph.neighbors(item):
            node2categories.setdefault(adjacent, []).append(category)
    return node2categories


def subgraphing(g, partion, mapping_item2category):
    """Split ``g`` into one subgraph per category.

    Nodes are grouped by the categories listed for them in ``partion``; for
    every category holding at least two node entries, the subgraph of ``g``
    induced by these nodes is converted with ``from_networkx``.

    Grouping is done with a dict keyed by category id, so category ids do
    not have to be smaller than the number of items (sizing a list by the
    number of items and indexing it by category raised ``IndexError`` for
    such ids).

    Arguments:
        g: Graph accepted by ``networkx.subgraph``.
        partion (dict): Maps a node to an iterable of category ids, e.g.
            the output of ``partition_by_category``.
        mapping_item2category (dict): Not used for the grouping; kept in the
            signature so that existing callers keep working.

    Returns:
        list: Converted subgraphs in ascending order of category id.
        Categories with fewer than two node entries are skipped.
    """
    members = {}
    for node, categories in partion.items():
        for category in categories:
            members.setdefault(category, []).append(node)

    graphs = []
    # Ascending category order matches the order of the former index-based
    # bucket list.
    for category in sorted(members):
        nodes = members[category]
        if len(nodes) < 2:
            continue
        graphs.append(from_networkx(nx.subgraph(g, nodes)))
    return graphs


def read_RecSys(path, FL=False):
    """Read the raw recommender-system files in ``path`` into PyG graphs.

    ``graph.txt`` is parsed line by line as integers; the first field is
    translated through ``user.dict``, the second through ``item.dict`` and
    the third is stored as the ``edge_type`` attribute of the edge between
    the two translated nodes. Every node additionally gets an
    ``index_orig`` attribute holding its own node id.

    Arguments:
        path (string): Directory containing ``user.dict``, ``item.dict``,
            ``graph.txt`` and, if ``FL`` is set, ``category.dict``.
        FL (Bool): If true, split the graph into per-category subgraphs;
            otherwise return the whole graph.

    Returns:
        list: The converted per-category subgraphs when ``FL`` is true,
        else a single-element list holding the converted full graph.
    """
    user_lookup = read_mapping(path, 'user.dict')
    item_lookup = read_mapping(path, 'item.dict')

    interactions = nx.Graph()
    with open(osp.join(path, 'graph.txt')) as handle:
        for record in handle:
            fields = [int(token) for token in record.strip().split()]
            interactions.add_edge(user_lookup[fields[0]],
                                  item_lookup[fields[1]],
                                  edge_type=fields[2])

    # Remember each node's id as a node attribute.
    nx.set_node_attributes(interactions,
                           {node: node
                            for node in interactions.nodes}, "index_orig")

    # Rebuild the graph so that nodes are inserted in sorted order.
    ordered = nx.Graph()
    ordered.add_nodes_from(sorted(interactions.nodes(data=True)))
    ordered.add_edges_from(interactions.edges(data=True))

    if not FL:
        return [from_networkx(ordered)]

    item2category = read_mapping(path, "category.dict")
    node2categories = partition_by_category(ordered, item2category)
    return subgraphing(ordered, node2categories, item2category)


class RecSys(InMemoryDataset):
    r"""Recommender-system graphs from FedGraphNN with random edge splits.

    Arguments:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (:obj:`"epinions"`,
            :obj:`"ciao"`).
        FL (Bool): Federated setting or centralized setting.
        splits (sequence of float, optional): Edge split fractions. The
            first entry is the training fraction and the second one the
            validation fraction; all remaining edges form the test split.
            (default: :obj:`(0.8, 0.1, 0.1)`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
    """
    def __init__(self,
                 root,
                 name,
                 FL=False,
                 splits=(0.8, 0.1, 0.1),
                 transform=None,
                 pre_transform=None):
        self.FL = FL
        # The federated variant gets its own directory via the 'FL' prefix
        # (see `raw_dir` / `processed_dir`).
        self.name = 'FL' + name if self.FL else name
        self._customized_splits = splits
        # `FL`, `name` and `_customized_splits` are set before the base
        # initialiser runs because `raw_dir`, `processed_dir`, `download`
        # and `process` read them.
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Names of the raw files fetched by `download`."""
        return ['user.dict', 'item.dict', 'category.dict', 'graph.txt']

    @property
    def processed_file_names(self):
        """Names of the files written by `process`."""
        return ['data.pt']

    @property
    def raw_dir(self):
        """Directory holding the raw files: ``<root>/<name>/raw``."""
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self):
        """Directory holding processed data: ``<root>/<name>/processed``."""
        return osp.join(self.root, self.name, 'processed')

    def download(self):
        """
            Download raw files to `self.raw_dir` from FedGraphNN.
            Paper: https://arxiv.org/abs/2104.07145
            Repo: https://github.com/FedML-AI/FedGraphNN
        """
        base_url = 'https://raw.githubusercontent.com/FedML-AI/FedGraphNN' \
                   '/82912342950e0cd1be2b683e48ef8bfd5cb0a276/data' \
                   '/recommender_system/'
        # Undo the 'FL' prefix added in `__init__`; keyed on the flag rather
        # than on the string so dataset names that themselves start with
        # "FL" are left intact in the centralized setting.
        suffix = self.name[len('FL'):] if self.FL else self.name
        # URLs are assembled with plain string formatting; `os.path.join`
        # is meant for file-system paths, not URLs.
        for name in self.raw_file_names:
            download_url(f'{base_url}{suffix}/{name}', self.raw_dir)

    @staticmethod
    def _edge_mask(selected, num_edges):
        """Return a bool mask of length `num_edges`, true at `selected`."""
        mask = torch.zeros(num_edges, dtype=torch.bool)
        mask[selected] = True
        return mask

    def process(self):
        """Read the raw graphs, attach edge split masks and save them.

        Every graph with at least one edge receives ``train_edge_mask``,
        ``valid_edge_mask`` and ``test_edge_mask``, obtained by cutting a
        random permutation of its edges at the boundaries given by the
        ``splits`` fractions. Graphs without edges are dropped. The result
        is collated and stored at ``self.processed_paths[0]``.
        """
        train_frac = self._customized_splits[0]
        valid_frac = self._customized_splits[1]
        shift_edge_type = self.name.endswith('epinions')

        data_list = []
        for data in read_RecSys(self.raw_dir, self.FL):
            num_edges = data.num_edges
            # Edgeless graphs are dropped anyway, so skip them before
            # touching any edge attribute.
            if num_edges == 0:
                continue
            if shift_edge_type:
                # NOTE(review): presumably makes the epinions edge types
                # start at 0 -- confirm against the raw data.
                data.edge_type = data.edge_type - 1

            # Cut one random permutation into three disjoint index ranges.
            train_end = round(train_frac * num_edges)
            valid_end = round((train_frac + valid_frac) * num_edges)
            perm = torch.randperm(num_edges)
            data.train_edge_mask = self._edge_mask(perm[:train_end],
                                                   num_edges)
            data.valid_edge_mask = self._edge_mask(perm[train_end:valid_end],
                                                   num_edges)
            data.test_edge_mask = self._edge_mask(perm[valid_end:],
                                                  num_edges)
            data_list.append(data)

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
