import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from ogb.nodeproppred import DglNodePropPredDataset
import dgl
from dgl.data import CoraFullDataset, RedditDataset, AmazonCoBuyComputerDataset, RomanEmpireDataset
import copy
from utils.misc import shuffle_tensor
import logging
import os
from torch_geometric.io import fs

class IncrementalGraph(nn.Module):
    """Node-induced subgraph of a fixed base graph that grows one batch of nodes at a time.

    At construction the base graph gets self-loops and reverse edges, and parallel edges are
    merged with ``dgl.to_simple``. ``update_subgraph`` appends a batch of base-graph nodes to
    ``current_subgraph`` together with every base-graph edge between that batch and the nodes
    added so far (the batch itself included).

    Assumes the base graph carries ndata 'label' and 'feat' (read here) and 'split' (read in
    ``update_subgraph``).
    """

    def __init__(self, graph):
        super().__init__()
        # self-loops + reverse edges; to_simple merges the duplicates this creates
        self.graph = dgl.to_simple(dgl.add_reverse_edges(dgl.add_self_loop(graph)))
        self.labels = self.graph.ndata['label']
        self.d_data = self.graph.ndata['feat'].shape[1]
        self.n_cls = self.labels.max().item() + 1
        self.n_nodes = self.labels.shape[0]
        self.current_subgraph = None
        # base-graph node id -> node id inside current_subgraph; -1 while not yet added
        self.old_to_new = torch.full((self.n_nodes,), -1, dtype=torch.int64)
        # True for base-graph nodes that are already part of current_subgraph
        self.mask = np.full((self.n_nodes,), False, dtype=bool)

    def update_subgraph(self, node_ids=None, device=None, task=0):
        """Append a batch of base-graph nodes (and their edges) to ``current_subgraph``.

        Args:
            node_ids: base-graph ids of the nodes to add (list or integer tensor). ``None``
                means an empty batch. Presumably none of these nodes was added by an earlier
                call -- this is not checked.
            device: device the subgraph and the new node data are moved to.
            task: integer stored in ndata['task'] for every node of this batch.

        Returns:
            ``(current_subgraph, train_ids)`` where ``train_ids`` are the subgraph ids (on
            ``device``) of the newly added nodes whose ndata['split'] equals 0.
        """
        if node_ids is None:
            node_ids = []
        n_new = len(node_ids)
        prev_num_nodes = 0 if self.current_subgraph is None else self.current_subgraph.num_nodes('_N')
        # new nodes are appended, so they receive the next consecutive subgraph ids
        ids_current_batch = torch.arange(prev_num_nodes, prev_num_nodes + n_new, dtype=torch.int64)
        self.old_to_new[node_ids] = ids_current_batch
        self.mask[node_ids] = True
        task_data = torch.full((n_new,), task, dtype=torch.int64).to(device)

        if self.current_subgraph is None:
            self.current_subgraph = dgl.node_subgraph(self.graph, node_ids, store_ids=True).to(device)
            self.current_subgraph.ndata['task'] = task_data
        else:
            # Copy every base-graph node field (as node_subgraph did for the first batch) plus
            # the parent ids; any field left out here would be zero-filled by add_nodes.
            new_data = {key: val[node_ids].to(device) for key, val in self.graph.ndata.items()}
            new_data[dgl.NID] = torch.as_tensor(node_ids, dtype=self.graph.idtype).to(device)
            new_data['task'] = task_data
            self.current_subgraph.add_nodes(n_new, new_data)

            src, dst = self.graph.out_edges(node_ids)
            # Keep edges whose destination is already in the subgraph (this batch included).
            # Index the numpy mask with a numpy array: a 1-element torch tensor would otherwise
            # be interpreted by numpy as a scalar index and yield a scalar instead of a mask.
            keep = torch.from_numpy(self.mask[dst.numpy()])
            src, dst = src[keep], dst[keep]
            u, v = self.old_to_new[src].to(device), self.old_to_new[dst].to(device)
            self.current_subgraph.add_edges(u, v)
            # out_edges only yields edges leaving the batch: new->new edges already come in
            # both directions (the base graph is symmetric), but for new->old edges the
            # mirrored old->new edge must be added explicitly.
            to_old = v < prev_num_nodes
            self.current_subgraph.add_edges(v[to_old], u[to_old])

        ids_current_batch = ids_current_batch.to(device)
        train_ids_current_batch = ids_current_batch[self.current_subgraph.ndata['split'][ids_current_batch] == 0]

        return self.current_subgraph, train_ids_current_batch

def _load_raw_graph(name, args):
    """Return the unprocessed DGL graph for dataset `name`, labels stored in ndata['label']."""
    builtin_datasets = {
        'CoraFull': CoraFullDataset,
        'Reddit': lambda: RedditDataset(self_loop=False),
        'AmazonComputer': AmazonCoBuyComputerDataset,
        'RomanEmpire': RomanEmpireDataset,
    }
    if name in builtin_datasets:
        return builtin_datasets[name]()[0]
    if name == 'Arxiv':
        data = DglNodePropPredDataset('ogbn-arxiv', root=f'{args.data_path}/ogb_downloaded')
        graph, label = data[0]
        graph.ndata['label'] = label
        graph.ndata['time'] = graph.ndata['year'].view(-1)
        return graph
    if name == 'Elliptic':
        return get_elliptic_graph(args)
    raise ValueError(f"Dataset {name} not supported.")


def _drop_small_classes(graph, min_count=2):
    """Remove nodes of classes with fewer than `min_count` members, relabelling classes 0..K-1.

    Stratified splitting needs at least two members per class. Negative labels (Elliptic uses
    -1 for "unknown") are neither treated as classes nor relabelled. Returns `graph` unchanged
    when no class is too small, otherwise the node-induced subgraph of the kept nodes.
    """
    labels = graph.ndata['label'].reshape(-1).numpy()
    n_cls = labels.max().item() + 1
    small_classes = [c for c in range(n_cls) if np.count_nonzero(labels == c) < min_count]
    if not small_classes:
        return graph

    logging.info(f"Removing small classes: {small_classes}")
    keep = ~np.isin(labels, small_classes)
    graph = graph.subgraph(torch.from_numpy(keep))
    kept_labels = labels[keep]
    unique_labels = np.unique(kept_labels[kept_labels >= 0]).tolist()
    label_map = {old_label: new_label for new_label, old_label in enumerate(unique_labels)}
    graph.ndata['label'] = torch.tensor([label_map.get(lbl, lbl) for lbl in kept_labels.tolist()])
    return graph


def get_graph_dataset(name, ratio_valid_test=None, args=None):
    """Load dataset `name` as a DGL graph with contiguous labels and a stratified node split.

    Args:
        name: one of 'CoraFull', 'Reddit', 'Arxiv', 'AmazonComputer', 'RomanEmpire', 'Elliptic'.
        ratio_valid_test: (valid_ratio, test_ratio); the remaining nodes are used for training.
        args: namespace providing `data_path` (only read for 'Arxiv' and 'Elliptic').

    Returns:
        The graph with ndata['split'] set to 0 (train), 1 (valid) or 2 (test); for 'Elliptic',
        nodes with unknown label (-1) get split 3.

    Raises:
        ValueError: unsupported `name`, or `ratio_valid_test` not given.
    """
    if ratio_valid_test is None:
        raise ValueError("ratio_valid_test must be given as (valid_ratio, test_ratio).")

    graph = _load_raw_graph(name, args)
    graph = _drop_small_classes(graph)

    if name == 'RomanEmpire':
        # very unbalanced tasks, therefore we reorder the classes
        unique_labels = torch.unique(graph.ndata['label'])
        shuffled = shuffle_tensor(unique_labels, random_seed=476)
        label_map = {old_label.item(): new_label for new_label, old_label in enumerate(shuffled)}
        graph.ndata['label'] = torch.tensor([label_map[lbl] for lbl in graph.ndata['label'].tolist()])

    # Split into train/valid/test using the graph's *current* labels, so the split matches the
    # graph after small-class removal and relabelling.
    labels = graph.ndata['label'].reshape(-1).numpy()
    ratio_sum = sum(ratio_valid_test)
    _, temp_ids = train_test_split(np.arange(len(labels)), test_size=ratio_sum,
                                   stratify=labels, random_state=42)
    valid_ids, test_ids = train_test_split(temp_ids, test_size=ratio_valid_test[1] / ratio_sum,
                                           stratify=labels[temp_ids], random_state=42)
    split = np.zeros(graph.number_of_nodes(), dtype=int)
    split[valid_ids] = 1
    split[test_ids] = 2
    graph.ndata['split'] = torch.tensor(split)

    if name == 'Elliptic':
        graph.ndata['split'][graph.ndata['label'] == -1] = 3  # do not use unknown labels for training/testing

    return graph

def get_elliptic_graph(args):
    """Build the Elliptic Bitcoin transaction graph as a DGL graph.

    The three raw CSV files live in ``{args.data_path}/elliptic``. Any file that is not there
    yet is fetched as a zip from the PyG dataset mirror and extracted in place.

    Returns:
        A ``dgl.graph`` with one node per row of the features file and one edge per row of the
        edge list (transaction ids translated to feature-file row positions), carrying:
          - ndata['feat']:  float32 tensor of feature-file columns 2 onward,
          - ndata['label']: 1 = illicit, 0 = licit, -1 = unknown,
          - ndata['time']:  the feature file's second column (time step).
    """
    source_url = 'https://data.pyg.org/datasets/elliptic'
    target_dir = f'{args.data_path}/elliptic'
    os.makedirs(target_dir, exist_ok=True)

    features_csv = 'elliptic_txs_features.csv'
    edges_csv = 'elliptic_txs_edgelist.csv'
    classes_csv = 'elliptic_txs_classes.csv'

    for csv_name in (features_csv, edges_csv, classes_csv):
        if os.path.exists(os.path.join(target_dir, csv_name)):
            continue
        fs.cp(f'{source_url}/{csv_name}.zip', target_dir, extract=True)

    raw_features = pd.read_csv(os.path.join(target_dir, features_csv), header=None)
    raw_edges = pd.read_csv(os.path.join(target_dir, edges_csv))
    raw_classes = pd.read_csv(os.path.join(target_dir, classes_csv))

    # The features file has no header: column 0 is the transaction id, column 1 the time
    # step, and all remaining columns are features.
    tx_ids = raw_features[0].values
    time_steps = torch.from_numpy(raw_features[1].values)
    features = torch.from_numpy(raw_features.iloc[:, 2:].values).to(torch.float)

    # 0=licit,  1=illicit, -1=unknown
    class_codes = raw_classes['class'].map({'unknown': -1, '1': 1, '2': 0})
    node_labels = torch.from_numpy(class_codes.values)

    # translate raw transaction ids into row positions of the features file
    row_of_tx = dict(zip(tx_ids, range(len(tx_ids))))
    src = torch.from_numpy(raw_edges['txId1'].map(row_of_tx).values)
    dst = torch.from_numpy(raw_edges['txId2'].map(row_of_tx).values)

    graph = dgl.graph((src, dst), num_nodes=len(raw_features))
    graph.ndata['feat'] = features
    graph.ndata['label'] = node_labels
    graph.ndata['time'] = time_steps
    return graph
