import torch
import copy
from torch_sparse import SparseTensor
import os
from typing import Optional
from torch_geometric.loader import ClusterData
import sys

from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

from torch_geometric.datasets import PPI
from torch_geometric.loader import DataLoader

class GraphSelect:
    """Load a graph dataset and build train/val/test inputs from its decomposition.

    ``args`` must expose the attributes read below: ``dataset``, ``type``,
    ``nprocs``, ``seed`` and ``inductive``.
    """

    def __init__(self, root, args):
        """Load the dataset named by ``args.dataset`` from ``root``.

        Raises:
            ValueError: if ``args.dataset`` is not a supported dataset name.
        """
        self.args = args
        if args.dataset in ['Cora', 'CiteSeer', 'PubMed']:
            self.dataset = Planetoid(root, args.dataset, transform=T.NormalizeFeatures())
        elif args.dataset in ['PPI']:
            # PPI ships as three separate splits; keep them in train/val/test order.
            self.dataset = [PPI(root, split='train'), PPI(root, split='val'), PPI(root, split='test')]
        else:
            # Fail here instead of leaving ``self.dataset`` unset and crashing
            # later with an unrelated AttributeError.
            raise ValueError(
                f'Unsupported dataset {args.dataset!r}; '
                "expected one of 'Cora', 'CiteSeer', 'PubMed', 'PPI'")

    def graph_config(self, save_dir, types='transductive'):
        """Dispatch to the transductive or inductive configuration.

        Args:
            save_dir: directory used to cache the decomposition, or ``None``
                to disable caching.
            types: ``'transductive'`` or ``'inductive'``.

        Returns:
            The ``(train, val, test)`` triple built by the selected method.

        Raises:
            ValueError: if ``types`` is neither of the two supported modes.
        """
        if types == 'transductive':
            return self.config_transductive(save_dir)
        if types == 'inductive':
            return self.config_inductive(save_dir)
        # Returning None here would break callers that unpack three values.
        raise ValueError(
            f"Unknown types {types!r}; select 'transductive' or 'inductive'")

    def _cache_path(self, save_dir):
        """Return the cache file path under ``save_dir``, or ``None`` if caching is off.

        The directory (including missing parents) is created when needed.
        """
        if save_dir is None:
            return None
        os.makedirs(save_dir, exist_ok=True)
        return os.path.join(save_dir, f'{self.args.type}_{self.args.nprocs}.pt')

    def config_transductive(self, save_dir):
        """Decompose the first graph of the dataset and wrap it for training.

        The decomposition is loaded from the cache file when it exists and is
        written there after being computed otherwise. For the ``'cluster'``
        and ``'cluster_ran'`` types the result is wrapped in a shuffled
        ``DataLoader``; for every other type it is returned as is. Validation
        and test data equal the training data unless ``args.inductive`` is
        set, in which case they are built from the undecomposed graph.
        """
        raw_data = self.dataset[0]
        path = self._cache_path(save_dir)
        if path is not None and os.path.exists(path):
            data = torch.load(path)
        else:
            data = data_decomposition(raw_data, self.args.nprocs, self.args.type)
            if path is not None:
                torch.save(data, path)

        uses_loader = self.args.type in ['cluster', 'cluster_ran']
        if uses_loader:
            train_data = DataLoader(data, batch_size=1, shuffle=True,
                                    generator=torch.Generator().manual_seed(self.args.seed))
        else:
            train_data = data

        val_data = test_data = train_data
        if self.args.inductive:
            if uses_loader:
                val_data = test_data = DataLoader([raw_data], batch_size=1, shuffle=False)
            else:
                val_data = test_data = raw_data
        return train_data, val_data, test_data

    def config_inductive(self, save_dir):
        """Decompose every graph of each split and return three loaders.

        ``self.dataset`` must unpack into ``(train, val, test)`` datasets.
        The training loader always iterates over decomposed graphs in a
        seeded shuffled order. Validation and test loaders iterate over the
        raw graphs when ``args.inductive`` is set and over the decomposed
        graphs otherwise.
        """
        train_dataset, val_dataset, test_dataset = self.dataset
        train_raw_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
        val_raw_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
        test_raw_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

        path = self._cache_path(save_dir)
        if path is not None and os.path.exists(path):
            train_data, val_data, test_data = torch.load(path)
        else:
            nprocs, method = self.args.nprocs, self.args.type
            train_data, val_data, test_data = (
                [data_decomposition(raw_data, nprocs, method) for raw_data in loader]
                for loader in (train_raw_loader, val_raw_loader, test_raw_loader)
            )
            if path is not None:
                torch.save([train_data, val_data, test_data], path)

        train_loader = DataLoader(train_data, batch_size=1, shuffle=True,
                                  generator=torch.Generator().manual_seed(self.args.seed))
        if self.args.inductive:
            val_loader = val_raw_loader
            test_loader = test_raw_loader
        else:
            val_loader = DataLoader(val_data, batch_size=1, shuffle=False)
            test_loader = DataLoader(test_data, batch_size=1, shuffle=False)
        return train_loader, val_loader, test_loader

def adj_crop(adj, partptr, N):
    """Keep only the entries of ``adj`` that connect two different partitions.

    Partition ``i`` owns the index range ``[partptr[i], partptr[i + 1])``.
    For each of the first ``N`` partitions the rows it owns are taken, and
    only the columns before and after its own range are retained. Entry
    values are discarded; the result has the same sparse sizes as ``adj``.

    Column indices of the left-hand slice are not shifted back, so the
    result is only correct when ``partptr[0]`` is 0.
    """
    pieces = []
    for part in range(N):
        lo, hi = partptr[part], partptr[part + 1]
        own_rows = adj.narrow(0, lo, hi - lo)

        # Columns that precede this partition's own range.
        left_row, left_col, _ = own_rows.narrow(1, partptr[0], lo - partptr[0]).coo()
        pieces.append(torch.stack([left_row + lo, left_col], dim=0))

        # Columns that follow this partition's own range.
        right_row, right_col, _ = own_rows.narrow(1, hi, partptr[-1] - hi).coo()
        pieces.append(torch.stack([right_row + lo, right_col + hi], dim=0))

    merged = torch.cat(pieces, dim=1)
    return SparseTensor(row=merged[0], col=merged[1], sparse_sizes=adj.sizes())

def data_decomposition(raw_data, nprocs, method='primal_dual', seed=0):
    """Decompose ``raw_data`` into ``nprocs`` parts using the named method.

    Supported methods:
        * ``'metis'`` / ``'primal_dual'``: ``NonOverlapData``.
        * ``'cluster'``: ``MultiplicativeData``.
        * ``'metis_ran'`` / ``'primal_dual_ran'``: ``NonOverlapSplitData``
          built with ``seed``.
        * ``'cluster_ran'``: ``MultiplicativeSplitData`` built with ``seed``.

    For ``'metis'`` and ``'metis_ran'`` the ``edge_index_interface``
    attribute of the result is reset to ``None``. Any other method name
    returns ``raw_data`` unchanged.
    """
    if method in ('metis', 'primal_dual'):
        partitioner = NonOverlapData(raw_data, nprocs)
    elif method == 'cluster':
        partitioner = MultiplicativeData(raw_data, nprocs)
    elif method in ('metis_ran', 'primal_dual_ran'):
        partitioner = NonOverlapSplitData(raw_data, nprocs, seed=seed)
    elif method == 'cluster_ran':
        partitioner = MultiplicativeSplitData(raw_data, nprocs, seed=seed)
    else:
        # Unknown method: hand the input back untouched.
        return raw_data

    decomposed = partitioner.__decompose__()
    if method in ('metis', 'metis_ran'):
        # The plain METIS variants drop the interface edges.
        decomposed.edge_index_interface = None
    return decomposed

class NonOverlapData(ClusterData):
    """ClusterData partitioning exposed as one graph with split edge sets."""

    def __init__(self, data, num_parts: int, recursive: bool = False,
                 save_dir: Optional[str] = None, log: bool = True):
        # All arguments are forwarded unchanged to ClusterData.__init__.
        super().__init__(data, num_parts, recursive, save_dir, log)

    def __decompose__(self):
        """Return a deep copy of ``self.data`` with its edges split by partition.

        The copy carries three edge tensors of shape ``(2, num_edges)``:
            * ``total_edge_index``: every entry of ``self.data.adj``.
            * ``edge_index``: entries whose row and column fall in the same
              partition range.
            * ``edge_index_interface``: the entries kept by ``adj_crop``,
              i.e. those whose column lies outside the row's partition.

        ``adj`` is set to ``None`` on the returned copy.
        """
        def as_edge_index(sparse):
            # Stack the COO coordinates and ignore the stored values.
            src, dst, _ = sparse.coo()
            return torch.stack([src, dst], dim=0)

        result = copy.deepcopy(self.data)
        result.total_edge_index = as_edge_index(self.data.adj)

        local_blocks = []
        for part in range(self.num_parts):
            begin = int(self.partptr[part])
            size = int(self.partptr[part + 1]) - begin
            diagonal_block = result.adj.narrow(0, begin, size).narrow(1, begin, size)
            # narrow() yields block-local indices; shift them back to global ones.
            local_blocks.append(as_edge_index(diagonal_block) + begin)
        result.edge_index = torch.cat(local_blocks, dim=1)

        crossing = adj_crop(self.data.adj, self.partptr, self.num_parts)
        result.edge_index_interface = as_edge_index(crossing)
        result.adj = None
        return result

class MultiplicativeData(ClusterData):
    """ClusterData wrapper that returns every partition as a separate item."""

    def __init__(self, data, num_parts: int, recursive: bool = False,
                 save_dir: Optional[str] = None, log: bool = True):
        # All arguments are forwarded unchanged to ClusterData.__init__.
        super().__init__(data, num_parts, recursive, save_dir, log)

    def __decompose__(self):
        """Return a list holding item ``0 .. num_parts - 1`` of this dataset."""
        fetch = self.__getitem__
        return list(map(fetch, range(self.num_parts)))

class RandomSplitData(ClusterData):
    """ClusterData variant whose partitions come from a seeded random permutation.

    ``ClusterData.__init__`` is intentionally not called; this constructor
    sets ``num_parts``, ``data``, ``partptr`` and ``perm`` itself.

    Args:
        data: graph with a non-``None`` ``edge_index``.
        num_parts: number of partitions.
        recursive: only affects the cache file name; kept for signature
            compatibility with ``ClusterData``.
        save_dir: directory for caching ``(adj, partptr, perm)``, or
            ``None`` to disable caching.
        log: print progress messages to stderr.
        seed: seed of the generator that draws the node permutation.
    """

    def __init__(self, data, num_parts: int, recursive: bool = False,
                 save_dir: Optional[str] = None, log: bool = True, seed: int = 0):
        assert data.edge_index is not None

        self.num_parts = num_parts

        recursive_str = '_recursive' if recursive else ''
        # The cache name encodes both the split kind and the seed. A bare
        # 'partition_<n>.pt' would be shared with other partitioners using
        # the same directory and would return a stale permutation when the
        # seed changes.
        filename = f'random_partition_{num_parts}{recursive_str}_seed{seed}.pt'
        path = os.path.join(save_dir or '', filename)
        if save_dir is not None and os.path.exists(path):
            adj, partptr, perm = torch.load(path)
        else:
            if log:  # pragma: no cover
                print('Computing Random split partitioning...', file=sys.stderr)

            N, E = data.num_nodes, data.num_edges
            # Edge positions are stored as values so they survive the permutation.
            adj = SparseTensor(
                row=data.edge_index[0], col=data.edge_index[1],
                value=torch.arange(E, device=data.edge_index.device),
                sparse_sizes=(N, N))
            perm = torch.randperm(N, generator=torch.Generator().manual_seed(seed))
            # Evenly spaced partition boundaries over the permuted node order.
            partptr = torch.linspace(0, N, num_parts + 1, dtype=torch.long)
            adj = adj.permute(perm)
            if save_dir is not None:
                torch.save((adj, partptr, perm), path)

            if log:  # pragma: no cover
                print('Done!', file=sys.stderr)

        self.data = self.__permute_data__(data, perm, adj)
        self.partptr = partptr
        self.perm = perm

class NonOverlapSplitData(RandomSplitData):
    """Random-split partitioning exposed as one graph with split edge sets."""

    def __init__(self, data, num_parts: int, recursive: bool = False,
                 save_dir: Optional[str] = None, log: bool = True, seed: int = 0):
        # All arguments are forwarded unchanged to RandomSplitData.__init__.
        super().__init__(data, num_parts, recursive, save_dir, log, seed)

    def __decompose__(self):
        """Return a deep copy of ``self.data`` with its edges split by partition.

        Attributes set on the copy:
            * ``total_edge_index``: every entry of ``self.data.adj``.
            * ``edge_index``: entries whose row and column fall in the same
              partition range.
            * ``edge_index_interface``: the entries kept by ``adj_crop``.
            * ``adj``: reset to ``None``.
        """
        def coo_pairs(sparse):
            # Stack the COO coordinates and ignore the stored values.
            r, c, _ = sparse.coo()
            return torch.stack([r, c], dim=0)

        out = copy.deepcopy(self.data)
        out.total_edge_index = coo_pairs(self.data.adj)

        intra = []
        for k in range(self.num_parts):
            offset = int(self.partptr[k])
            extent = int(self.partptr[k + 1]) - offset
            sub_adj = out.adj.narrow(0, offset, extent).narrow(1, offset, extent)
            # Indices inside the block are local; add the offset to make them global.
            intra.append(coo_pairs(sub_adj) + offset)
        out.edge_index = torch.cat(intra, dim=1)

        between = adj_crop(self.data.adj, self.partptr, self.num_parts)
        out.edge_index_interface = coo_pairs(between)
        out.adj = None
        return out

class MultiplicativeSplitData(RandomSplitData):
    """Random-split partitioning that returns every partition as a separate item."""

    def __init__(self, data, num_parts: int, recursive: bool = False,
                 save_dir: Optional[str] = None, log: bool = True, seed: int = 0):
        # All arguments are forwarded unchanged to RandomSplitData.__init__.
        super().__init__(data, num_parts, recursive, save_dir, log, seed)

    def __decompose__(self):
        """Return a list holding item ``0 .. num_parts - 1`` of this dataset."""
        parts = []
        for part_id in range(self.num_parts):
            parts.append(self.__getitem__(part_id))
        return parts

class OrderedSplitData(ClusterData):
    """ClusterData variant that cuts the nodes into contiguous ranges in their given order.

    ``ClusterData.__init__`` is intentionally not called; this constructor
    sets ``num_parts``, ``data``, ``partptr`` and ``perm`` itself.

    Args:
        data: graph with a non-``None`` ``edge_index``.
        num_parts: number of partitions.
        recursive: only affects the cache file name; kept for signature
            compatibility with ``ClusterData``.
        save_dir: directory for caching ``(adj, partptr, perm)``, or
            ``None`` to disable caching.
        log: print progress messages to stderr.
        seed: unused; accepted so the signature matches the random-split
            classes.
    """

    def __init__(self, data, num_parts: int, recursive: bool = False,
                 save_dir: Optional[str] = None, log: bool = True, seed: int = 0):
        assert data.edge_index is not None

        self.num_parts = num_parts

        recursive_str = '_recursive' if recursive else ''
        # Use a name specific to this split kind so a shared cache directory
        # cannot hand back a partition computed by a different partitioner.
        filename = f'ordered_partition_{num_parts}{recursive_str}.pt'
        path = os.path.join(save_dir or '', filename)
        if save_dir is not None and os.path.exists(path):
            adj, partptr, perm = torch.load(path)
        else:
            if log:  # pragma: no cover
                print('Computing ordered split partitioning...', file=sys.stderr)

            N, E = data.num_nodes, data.num_edges
            # Edge positions are stored as the values of the adjacency.
            adj = SparseTensor(
                row=data.edge_index[0], col=data.edge_index[1],
                value=torch.arange(E, device=data.edge_index.device),
                sparse_sizes=(N, N))
            # Identity permutation: nodes keep their original order, so the
            # adjacency does not need to be permuted.
            perm = torch.arange(N)
            # Evenly spaced partition boundaries over the node order.
            partptr = torch.linspace(0, N, num_parts + 1, dtype=torch.long)

            if save_dir is not None:
                torch.save((adj, partptr, perm), path)

            if log:  # pragma: no cover
                print('Done!', file=sys.stderr)

        self.data = self.__permute_data__(data, perm, adj)
        self.partptr = partptr
        self.perm = perm

class NonOverlapOrderedSplitData(OrderedSplitData):
    """Ordered-split partitioning exposed as one graph with split edge sets."""

    def __init__(self, data, num_parts: int, recursive: bool = False,
                 save_dir: Optional[str] = None, log: bool = True, seed: int = 0):
        # All arguments are forwarded unchanged to OrderedSplitData.__init__.
        super().__init__(data, num_parts, recursive, save_dir, log, seed)

    def __decompose__(self):
        """Return a deep copy of ``self.data`` with its edges split by partition.

        Attributes set on the copy:
            * ``total_edge_index``: every entry of ``self.data.adj``.
            * ``edge_index``: entries whose row and column fall in the same
              partition range.
            * ``edge_index_interface``: the entries kept by ``adj_crop``.
            * ``adj``: reset to ``None``.
        """
        def to_index(sparse):
            # Stack the COO coordinates and ignore the stored values.
            rows, cols, _ = sparse.coo()
            return torch.stack([rows, cols], dim=0)

        graph = copy.deepcopy(self.data)
        graph.total_edge_index = to_index(self.data.adj)

        bounds = self.partptr
        within = []
        for p in range(self.num_parts):
            first = int(bounds[p])
            count = int(bounds[p + 1]) - first
            square = graph.adj.narrow(0, first, count).narrow(1, first, count)
            # Shift block-local indices back into the global numbering.
            within.append(to_index(square) + first)
        graph.edge_index = torch.cat(within, dim=1)

        across = adj_crop(self.data.adj, self.partptr, self.num_parts)
        graph.edge_index_interface = to_index(across)
        graph.adj = None
        return graph