import json
import numpy as np
import pandas as pd
import networkx as nx
import torch
from torch.nn import Linear
from torch_scatter import scatter_mean
from torch_geometric.nn import MessagePassing
from torch_geometric.data import DataLoader
from torch_geometric.utils import degree
from torch_geometric.transforms import Compose
from torch_geometric.utils import to_networkx, from_networkx
from torch_geometric.datasets import TUDataset, MNISTSuperpixels
from tqdm import trange
import os

from ts_net import TSNet
from simple_classifier import SCNet, LinearRegression, RBF_SVM, SVM
from fast_scatter import FastScatterTransform
from early_stopping import EarlyStopping


# Module-wide compute device, read by get_transform, accuracy and train_model.
# The first assignment selects CUDA when torch reports it as available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): this second assignment unconditionally overrides the line above,
# so everything runs on the CPU. Presumably a debugging leftover -- confirm
# before removing.
device = "cpu"

class EarlyStopping(object):
    """Early-stopping helper adapted from Stefano Nardo's PyTorch gist.

    Source: https://gist.github.com/stefanonardo/693d96ceb2f531fa05db530f3e21517d

    Call :meth:`step` once per evaluation with the monitored metric. It
    returns ``True`` when training should stop, i.e. when the metric is NaN
    or has not improved for ``patience`` consecutive calls.

    Args:
        mode: ``'min'`` if a lower metric is better, ``'max'`` if higher is.
        min_delta: Minimum improvement over the best value. Interpreted as a
            percentage of the best value when ``percentage`` is true.
        patience: Number of consecutive non-improving calls tolerated.
            ``0`` disables early stopping (``step`` always returns ``False``).
        percentage: Whether ``min_delta`` is relative (percent) or absolute.

    Attributes:
        best: Best metric seen so far (``None`` before the first call).
        best_model: Object passed as ``model`` to the ``step`` call that
            produced ``best``. It is stored by reference, not copied.
        num_bad_epochs: Consecutive calls without improvement.

    Raises:
        ValueError: If ``mode`` is neither ``'min'`` nor ``'max'``.
    """
    def __init__(self, mode='min', min_delta=0, patience=10, percentage=False):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.best_model = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            # Early stopping disabled: never request a stop. The replacement
            # accepts the optional ``model`` argument like the real method.
            self.is_better = lambda a, b: True
            self.step = lambda a, model=None: False

    def step(self, metrics, model=None):
        """Record one evaluation and report whether training should stop.

        Args:
            metrics: Value of the monitored metric for this evaluation.
            model: Optional payload (e.g. a results dict or state dict) kept
                in ``best_model`` whenever ``metrics`` becomes the new best.

        Returns:
            ``True`` if training should stop, ``False`` otherwise.
        """
        # NaN is the only value unequal to itself. Checked first so that a
        # NaN can never be stored as ``best`` (nothing compares better than it).
        if metrics != metrics:
            return True

        if self.best is None:
            self.best = metrics
            self.best_model = model
            return False

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
            self.best_model = model
        else:
            self.num_bad_epochs += 1

        return self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta, percentage):
        """Install ``self.is_better(a, best)`` for the chosen comparison mode."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if not percentage:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - min_delta
            else:
                self.is_better = lambda a, best: a > best + min_delta
        else:
            # Relative threshold: min_delta is a percentage of the current best.
            if mode == 'min':
                self.is_better = lambda a, best: a < best - (
                            best * min_delta / 100)
            else:
                self.is_better = lambda a, best: a > best + (
                            best * min_delta / 100)

class NetworkXTransform(object):
    """Base transform that derives one scalar node feature via NetworkX.

    The graph is converted with ``to_networkx``, ``nx_transform`` supplies a
    per-node value, and that value is written back to ``data.x`` as a single
    float32 column. With ``cat=True`` (and existing features) the new column
    is appended to the current ``data.x`` instead of replacing it.
    """
    def __init__(self, cat=False):
        self.cat = cat

    def __call__(self, data):
        """Compute the node attribute for ``data`` and store it in ``data.x``.

        ``data`` is modified in place and also returned.
        """
        original_x = data.x
        graph = to_networkx(data)
        # Attach the computed values as node attribute 'x', then round-trip
        # through from_networkx to read them back as a tensor.
        nx.set_node_attributes(graph, self.nx_transform(graph), 'x')
        node_values = from_networkx(graph).x
        new_column = node_values.view(-1, 1).to(torch.float32)

        if original_x is None or not self.cat:
            data.x = new_column
            return data

        # Promote 1-D features to a column so they can be concatenated.
        if original_x.dim() == 1:
            original_x = original_x.view(-1, 1)
        data.x = torch.cat([original_x, new_column], dim=-1)
        return data

    def nx_transform(self, networkx_data):
        """Return a node -> value dictionary; must be provided by subclasses."""
        raise NotImplementedError


class Eccentricity(NetworkXTransform):
    """Transform whose node attribute is the result of ``nx.eccentricity``."""

    def nx_transform(self, data):
        """Return ``nx.eccentricity`` evaluated on the NetworkX graph ``data``."""
        per_node = nx.eccentricity(data)
        return per_node


class ClusteringCoefficient(NetworkXTransform):
    """Transform whose node attribute is the result of ``nx.clustering``."""

    def nx_transform(self, data):
        """Return ``nx.clustering`` evaluated on the NetworkX graph ``data``."""
        per_node = nx.clustering(data)
        return per_node


def get_transform(name):
    """Build the pre-transform registered under ``name``.

    Recognised names are "eccentricity", "clustering_coefficient", "scatter"
    (eccentricity followed by an appended clustering coefficient) and
    "fast_scatter" (the "scatter" pipeline followed by FastScatterTransform
    on the module-level ``device``).

    Raises:
        NotImplementedError: If ``name`` is not one of the names above.
    """
    if name == "eccentricity":
        return Eccentricity()
    if name == "clustering_coefficient":
        return ClusteringCoefficient()
    if name == "scatter" or name == "fast_scatter":
        # Both composite transforms share the same first two stages.
        stages = [Eccentricity(), ClusteringCoefficient(cat=True)]
        if name == "fast_scatter":
            stages.append(FastScatterTransform(device))
        return Compose(stages)
    raise NotImplementedError("Unknown transform %s" % name)


def split_dataset(dataset, splits=(0.8, 0.1, 0.1), seed=0):
    """Split ``dataset`` into non-overlapping train/val/test subsets.

    Args:
        dataset: Object supporting ``len()``, ``shuffle()`` and slicing.
        splits: Relative sizes of (train, val, test); normalised to sum to 1.
        seed: Passed to ``torch.manual_seed`` before ``dataset.shuffle()``.
            This reseeds torch's global RNG as a side effect.

    Returns:
        Tuple ``(train, val, test)``. Validation and test sizes are the
        proportions truncated to integers; the training split receives all
        remaining items.
    """
    splits = np.array(splits)
    splits = splits / np.sum(splits)
    n = len(dataset)
    torch.manual_seed(seed)
    val_size = int(splits[1] * n)
    test_size = int(splits[2] * n)
    train_size = n - val_size - test_size
    ds = dataset.shuffle()
    val_end = train_size + val_size
    # Slice the test split from its explicit start index: ``ds[-test_size:]``
    # would return the *entire* dataset whenever test_size is 0.
    return ds[:train_size], ds[train_size:val_end], ds[val_end:]


def accuracy(model, dataset, name):
    """Compute classification accuracy of ``model`` over ``dataset``.

    The whole dataset is evaluated as one batch on the module-level ``device``.

    Args:
        model: Callable returning per-class scores with classes along dim 1.
        dataset: Dataset accepted by ``DataLoader``; must be non-empty.
        name: Label for the split. Currently unused; kept for compatibility
            with existing callers.

    Returns:
        Tuple ``(acc, pred)`` with the fraction of correct predictions and the
        tensor of predicted class indices.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    n_samples = len(dataset)
    if n_samples == 0:
        raise ValueError("cannot compute accuracy on an empty dataset")

    # One batch holding every sample, so a single fetch covers the dataset.
    loader = DataLoader(dataset, batch_size=n_samples, shuffle=False)
    data = next(iter(loader)).to(device)
    # Evaluation only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        pred = model(data).max(dim=1)[1]
    correct = float(pred.eq(data.y).sum().item())
    return correct / n_samples, pred


def evaluate(model, train_ds, val_ds, test_ds, epoch):
    """Evaluate ``model`` on the three splits and snapshot its weights.

    The model is switched to eval mode for the measurements and put back
    into train mode before returning.

    Args:
        model: The torch module to evaluate.
        train_ds, val_ds, test_ds: Datasets passed to ``accuracy``.
        epoch: Stored unchanged under the ``"epoch"`` key.

    Returns:
        Dict with keys ``epoch``, ``{train,val,test}_acc``,
        ``{train,val,test}_pred`` and ``state_dict``.
    """
    import copy

    model.eval()
    train_acc, train_pred = accuracy(model, train_ds, "Train")
    val_acc, val_pred = accuracy(model, val_ds, "Val")
    test_acc, test_pred = accuracy(model, test_ds, "Test")
    results = {
        "epoch": epoch,
        "train_acc": train_acc,
        "train_pred": train_pred,
        "val_acc": val_acc,
        "val_pred": val_pred,
        "test_acc": test_acc,
        "test_pred": test_pred,
        # Deep copy: state_dict() tensors share storage with the live
        # parameters, so without a copy later optimizer steps would silently
        # overwrite this snapshot.
        "state_dict": copy.deepcopy(model.state_dict()),
    }
    model.train()
    return results


def _load_datasets(run_args, transform):
    """Load the dataset named by ``run_args["dataset"]`` and split it.

    Returns:
        Tuple ``(dataset, train_ds, val_ds, test_ds)``.

    Raises:
        NotImplementedError: If the dataset name is not supported.
    """
    name = run_args["dataset"]
    if name in ["COLLAB", "REDDIT-MULTI-5K", "IMDB-BINARY", "IMDB-MULTI", "BZR", "OHSU"]:
        dataset = TUDataset(
            root="", name=name,
            pre_transform=transform, use_node_attr=True
        )
        train_ds, val_ds, test_ds = split_dataset(dataset)
    elif name in ["ogbg-molhiv"]:
        # Imported lazily so ogb is only required when this dataset is used.
        from ogb.graphproppred import PygGraphPropPredDataset

        dataset = PygGraphPropPredDataset(name=name)
        # Use the split indices supplied by the dataset itself.
        split_idx = dataset.get_idx_split()
        train_ds = dataset[split_idx["train"]]
        val_ds = dataset[split_idx["valid"]]
        test_ds = dataset[split_idx["test"]]
    elif name in ["MNISTSuperpixels"]:
        # "transform" is optional in run_args, so do not index it directly.
        dataset = MNISTSuperpixels(
            root="/data/anonymous/pytorch_geometric_datasets/MNIST/%s" % run_args.get("transform"),
            pre_transform=transform,
        )
        train_ds, val_ds, test_ds = split_dataset(dataset)
    else:
        raise NotImplementedError("Dataset %s not implemented" % name)
    return dataset, train_ds, val_ds, test_ds


def _build_model(run_args, dataset):
    """Instantiate the classifier selected by ``run_args["model"]``.

    Raises:
        NotImplementedError: If the model name is not supported.
    """
    kind = run_args["model"]
    if kind == "ts_net":
        # NOTE(review): "ts_net" builds an SCNet and reads the key
        # "num layers" (with a space) -- confirm both are intended.
        return SCNet(
            dataset[0].x.shape[1],
            run_args["intermediate_channels"],
            run_args["num layers"],
            dataset.num_classes,
        )
    if kind == "fast_rbf_net":
        return RBF_SVM(dataset.num_node_features, dataset.num_classes)
    if kind == "fast_svm_net":
        return SVM(dataset.num_node_features, dataset.num_classes)
    raise NotImplementedError("Model %s not implemented" % kind)


def train_model(in_dir, out_file):
    """Train a classifier as described by a JSON config and save the results.

    The config is read from ``in_dir`` and uses the keys ``"dataset"``,
    ``"model"``, optionally ``"transform"``, and for the ``"ts_net"`` model
    ``"intermediate_channels"`` and ``"num layers"``.

    Training runs for up to 1000 epochs with Adam (lr=0.01) and cross-entropy
    loss. Every 10 epochs the model is evaluated; every 100 epochs the
    evaluation results are written to ``<stem>_<epoch><ext>``. If early
    stopping triggers, the current results and the best results so far
    (``<stem>_best<ext>``) are written as well. The final evaluation is always
    saved to ``out_file``.

    Args:
        in_dir: Path of the JSON file with the run arguments.
        out_file: Path of the final results file; checkpoint names are
            derived from it.
    """
    import copy

    with open(str(in_dir), "r") as fp:
        run_args = json.load(fp)
    print(run_args)

    # splitext copes with dots in directory names and with a missing
    # extension, unlike a bare ``split('.')``.
    out_stem, out_ext = os.path.splitext(os.path.abspath(str(out_file)))
    # Create the output directory up front rather than failing at save time.
    os.makedirs(os.path.dirname(out_stem), exist_ok=True)

    def checkpoint_path(tag):
        return "%s_%s%s" % (out_stem, tag, out_ext)

    transform = get_transform(run_args["transform"]) if "transform" in run_args else None
    dataset, train_ds, val_ds, test_ds = _load_datasets(run_args, transform)
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=8)

    model = _build_model(run_args, dataset).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    loss_fn = torch.nn.CrossEntropyLoss()
    # NOTE(review): patience counts evaluations (one per 10 epochs), so with
    # at most 100 evaluations a patience of 100 can never trigger -- confirm
    # the intended value.
    es = EarlyStopping(patience=100, mode='max', percentage=True)
    best_results = None

    model.train()
    for epoch in range(1, 1000 + 1):
        for data in train_loader:
            optimizer.zero_grad()
            data = data.to(device)
            out = model(data)
            loss = loss_fn(out, data.y)
            loss.backward()
            optimizer.step()

        if epoch % 10 != 0:
            continue

        results = evaluate(model, train_ds, val_ds, test_ds, epoch)
        print('Epoch:', epoch, results['train_acc'], results['val_acc'], results['test_acc'])
        if epoch % 100 == 0:
            torch.save(results, checkpoint_path(epoch))

        metric = results['val_acc']
        if best_results is None or metric > best_results['val_acc']:
            # Deep copy so later optimizer steps cannot alter the stored weights.
            best_results = copy.deepcopy(results)

        if es.step(metric):
            torch.save(results, checkpoint_path(epoch))
            torch.save(best_results, checkpoint_path("best"))
            print('early stopping at epoch %d' % epoch)
            print('best model was at epoch %d' % best_results['epoch'])
            print('Achieved', best_results['train_acc'], best_results['val_acc'], best_results['test_acc'])
            break

    model.eval()
    results = evaluate(model, train_ds, val_ds, test_ds, epoch)
    torch.save(results, str(out_file))


# Script entry point: run one training job using hard-coded relative paths for
# the JSON run arguments and the output results file.
if __name__ == '__main__':
    train_model('../experimental/args.json', '../../models/tmp/model.pth')
