import json
import numpy as np
import networkx as nx
import torch
from torch.nn import Linear
from torch_scatter import scatter_mean
from torch_geometric.nn import MessagePassing
from torch_geometric.data import DataLoader
from torch_geometric.utils import degree
from torch_geometric.transforms import Compose
from torch_geometric.utils import to_networkx, from_networkx
from torch_geometric.datasets import TUDataset
from tqdm import trange
import os
import pandas as pd

from ts_net import TSNet


# Select CUDA when torch reports it as available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): this assignment unconditionally replaces the value chosen
# above, so every `.to(device)` call in this file targets the CPU. Confirm
# whether this is a deliberate choice or a debugging leftover.
device = "cpu"


class NetworkXTransform(object):
    """Base transform that writes a networkx-derived per-node value into ``data.x``.

    Subclasses implement :meth:`nx_transform`. When ``cat`` is true and the
    incoming ``data.x`` is not ``None``, the new column is appended to the
    existing features; otherwise ``data.x`` is replaced by the new column.
    """

    def __init__(self, cat=False):
        # Whether to concatenate the computed column to existing features.
        self.cat = cat

    def __call__(self, data):
        """Compute the node attribute for ``data`` and store it in ``data.x``.

        The input object is modified in place and also returned.
        """
        previous_x = data.x
        graph = to_networkx(data)
        # Attach the per-node values under the attribute name 'x' so that
        # the conversion back exposes them as ``.x``.
        nx.set_node_attributes(graph, self.nx_transform(graph), 'x')
        new_column = from_networkx(graph).x.view(-1, 1).type(torch.float32)

        if previous_x is None or not self.cat:
            data.x = new_column
            return data

        # Promote 1-D features to a single column before concatenating.
        if previous_x.dim() == 1:
            previous_x = previous_x.view(-1, 1)
        data.x = torch.cat([previous_x, new_column], dim=-1)
        return data

    def nx_transform(self, networkx_data):
        """ returns a node dictionary with a single attribute
        """
        raise NotImplementedError


class Eccentricity(NetworkXTransform):
    """Transform whose per-node value comes from ``nx.eccentricity``."""

    def nx_transform(self, data):
        """Return the result of ``nx.eccentricity`` for the given graph."""
        per_node = nx.eccentricity(data)
        return per_node


class ClusteringCoefficient(NetworkXTransform):
    """Transform whose per-node value comes from ``nx.clustering``."""

    def nx_transform(self, data):
        """Return the result of ``nx.clustering`` for the given graph."""
        per_node = nx.clustering(data)
        return per_node


def get_transform(name):
    """Build the transform registered under ``name``.

    Supported names are "eccentricity", "clustering_coefficient" and
    "scatter" (eccentricity followed by clustering coefficient appended as
    an extra column).

    Raises:
        NotImplementedError: if ``name`` is not one of the supported names.
    """
    if name == "eccentricity":
        return Eccentricity()
    if name == "clustering_coefficient":
        return ClusteringCoefficient()
    if name == "scatter":
        stages = [Eccentricity(), ClusteringCoefficient(cat=True)]
        return Compose(stages)
    raise NotImplementedError("Unknown transform %s" % name)


def split_dataset(dataset, splits=(0.8, 0.1, 0.1), seed=0):
    """ Splits data into non-overlapping datasets of given proportions.

    Args:
        dataset: object supporting ``len()``, slicing, and a ``shuffle()``
            method that returns the shuffled dataset.
        splits: three non-negative weights for (train, validation, test).
            They are normalised to sum to one, so they need not add up to 1.
        seed: value passed to ``torch.manual_seed`` before shuffling.

    Returns:
        A ``(train, validation, test)`` tuple of slices of the shuffled
        dataset. Validation and test sizes are rounded down; the training
        slice receives the remaining elements.
    """
    print("Our splits are ",splits)
    fractions = np.array(splits)
    fractions = fractions / np.sum(fractions)
    n = len(dataset)
    torch.manual_seed(seed)
    val_size = int(fractions[1] * n)
    test_size = int(fractions[2] * n)
    train_size = n - val_size - test_size
    ds = dataset.shuffle()
    # Slice the test part from an explicit start index: a negative-index
    # slice ``ds[-test_size:]`` would return the whole dataset whenever
    # ``test_size`` is 0.
    val_end = train_size + val_size
    return ds[:train_size], ds[train_size:val_end], ds[val_end:]


def accuracy(model, dataset, name):
    """Compute classification accuracy of ``model`` on ``dataset``.

    The whole dataset is loaded as a single batch (``batch_size`` equals
    ``len(dataset)``). The forward pass runs without gradient tracking and,
    when ``model`` is a ``torch.nn.Module``, in evaluation mode; the model's
    previous train/eval mode is restored afterwards.

    Args:
        model: callable mapping a batch to per-class scores, one row per
            graph (the prediction is the arg-max over dim 1).
        dataset: dataset accepted by ``DataLoader``; batches must expose
            ``.to(device)`` and ``.y``.
        name: label of the split; currently unused and kept for interface
            compatibility.

    Returns:
        ``(acc, pred)``: the fraction of correct predictions and the tensor
        of predicted class indices.
    """
    loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
    is_module = isinstance(model, torch.nn.Module)
    was_training = model.training if is_module else False
    if is_module:
        model.eval()
    try:
        # Gradients are never needed for scoring; skipping them avoids
        # building an autograd graph over the full dataset.
        with torch.no_grad():
            for data in loader:
                data = data.to(device)
                pred = model(data).max(dim=1)[1]
    finally:
        if is_module:
            # Leave the model in the mode the caller had it in.
            model.train(was_training)
    correct = float(pred.eq(data.y).sum().item())
    acc = correct / len(dataset)
    return acc, pred

class EarlyStopping(object):
    """ Early Stopping pytorch implementation from Stefano Nardo https://gist.github.com/stefanonardo/693d96ceb2f531fa05db530f3e21517d

    Tracks the best metric seen so far and signals a stop once the metric
    has failed to improve for ``patience`` consecutive calls to ``step``.

    Args:
        mode: 'min' if lower metrics are better, 'max' if higher are better.
        min_delta: minimum change that counts as an improvement; an absolute
            amount, or a percentage of the best value when ``percentage``
            is true.
        patience: number of consecutive non-improving steps that triggers a
            stop. With ``patience == 0`` early stopping is disabled and
            ``step`` always returns False.
        percentage: interpret ``min_delta`` as a percentage of ``best``.

    Raises:
        ValueError: if ``mode`` is neither 'min' nor 'max'.
    """
    def __init__(self, mode='min', min_delta=0, patience=10, percentage=False):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            # Disable early stopping entirely.
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        """Record a new metric value; return True when training should stop."""
        # NaN is the only value that differs from itself. Check it before
        # anything else so that a NaN is never stored as ``best`` (every
        # later comparison against a NaN best would be False).
        if metrics != metrics:
            return True

        if self.best is None:
            self.best = metrics
            return False

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        return self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta, percentage):
        """Install the comparison ``self.is_better(value, best)``."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if not percentage:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - min_delta
            if mode == 'max':
                self.is_better = lambda a, best: a > best + min_delta
        else:
            # Use the magnitude of ``best`` so the margin keeps its sign when
            # the best value is negative; identical to the plain product for
            # non-negative values.
            if mode == 'min':
                self.is_better = lambda a, best: a < best - (
                            abs(best) * min_delta / 100)
            if mode == 'max':
                self.is_better = lambda a, best: a > best + (
                            abs(best) * min_delta / 100)

def evaluate(model, train_ds, test_ds, val_ds):
    """Score ``model`` on the train, test and validation datasets.

    Note the parameter order: train, test, validation.

    Returns:
        dict with the accuracy and predictions of each split under the keys
        ``"<split>_acc"`` / ``"<split>_pred"`` for train, test and val, plus
        the model's ``state_dict()`` under ``"state_dict"``.
    """
    train_metrics = accuracy(model, train_ds, "Train")
    test_metrics = accuracy(model, test_ds, "Test")
    val_metrics = accuracy(model, val_ds, "Validation")
    return {
        "train_acc": train_metrics[0],
        "train_pred": train_metrics[1],
        "test_acc": test_metrics[0],
        "test_pred": test_metrics[1],
        "val_acc": val_metrics[0],
        "val_pred": val_metrics[1],
        "state_dict": model.state_dict(),
    }

def train_model(args, out_file):
    """Train a model as described by ``args`` and return evaluation results.

    Args:
        args: mapping with at least the keys ``"dataset"`` and ``"model"``,
            plus ``"split"`` (train/validation/test weights) for the
            TUDataset datasets. It is copied, so the caller's mapping is
            left untouched.
        out_file: currently unused; kept for interface compatibility.

    Returns:
        The dict produced by ``evaluate`` for the model after training.

    Raises:
        NotImplementedError: for an unknown dataset or model name.
    """
    # Shallow copy so that the override below does not leak into the
    # caller's dict.
    run_args = dict(args)
    # NOTE(review): the transform is forced to "scatter" here, so any
    # "transform" value supplied by the caller is ignored.
    run_args["transform"] = "scatter"
    transform = get_transform(run_args["transform"])

    dataset_name = run_args["dataset"]
    if dataset_name in ["COLLAB", "REDDIT-MULTI-5K", "IMDB-BINARY","IMDB-MULTI","BZR","OHSU"]:
        dataset = TUDataset(
            root="", name=dataset_name, pre_transform=transform, use_node_attr=True
        )
        train_ds, val_ds, test_ds = split_dataset(dataset,splits=run_args["split"])
    elif dataset_name in ["ogbg-molhiv"]:
        from ogb.graphproppred import PygGraphPropPredDataset
        dataset = PygGraphPropPredDataset(name=dataset_name)

        # Use the split shipped with the dataset. The three subsets must be
        # bound here because the evaluation calls below need them.
        split_idx = dataset.get_idx_split()
        train_ds = dataset[split_idx["train"]]
        val_ds = dataset[split_idx["valid"]]
        test_ds = dataset[split_idx["test"]]
    else:
        raise NotImplementedError("Unknown dataset %s" % dataset_name)

    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=8)

    if run_args["model"] == "ts_net":
        model = TSNet(dataset.num_node_features, dataset.num_classes, trainable_laziness=False)
    else:
        raise NotImplementedError()

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    loss_fn = torch.nn.CrossEntropyLoss()
    # The stopper is stepped once per evaluation (every 10 epochs), so
    # ``patience`` counts evaluations, not epochs.
    # NOTE(review): 1000 epochs give only 100 evaluations and the first one
    # just records the baseline, so at most 99 non-improving steps can
    # accumulate; with patience=100 the stop can only fire on a NaN metric.
    early_stopper = EarlyStopping(mode = 'max',patience=100,percentage=True)

    model.train()
    for epoch in trange(1, 1000 + 1):
        for data in train_loader:
            optimizer.zero_grad()
            data = data.to(device)
            out = model(data)
            loss = loss_fn(out, data.y)
            loss.backward()
            optimizer.step()
        if epoch % 10 == 0:
            # Score in evaluation mode without gradients, then resume
            # training mode for the next epochs.
            model.eval()
            with torch.no_grad():
                results = evaluate(model, train_ds, test_ds, val_ds)
            model.train()
            print('Epoch:', epoch, results['train_acc'], results['test_acc'])
            if early_stopper.step(results['val_acc']):
                print("Early stopping criterion met. Ending training.")
                break
    # Final scoring of the trained model, again in evaluation mode.
    model.eval()
    with torch.no_grad():
        results = evaluate(model, train_ds, test_ds,val_ds)
    return results
if __name__ == '__main__':
    # Experiment driver: train once with the baseline split, then seven more
    # times while moving 9% of the data per step from train to test, and
    # write the test accuracies to a CSV file.
    model_path = '../../../../data/fast_scatter_tests/model.pth'
    args = {
        "dataset": "IMDB-BINARY",
        "model": "ts_net",
        "intermediate_channels":64,
        "num layers":0,
        "split":(0.8,0.1,0.1),
        "model_args": {
            "epsilon": 1e-16,
            "num_layers": 0
        },
        "model_dir": "/home/anonymous/trainable_scattering/models/v1/0",
        "transform": "fast_scatter"
    }

    # Baseline run with the split given in ``args``.
    results = train_model(args, model_path)
    tests_compiled = {
        'Experiment':'Changing Splits',
        'Lin1 Out Dimension':args['intermediate_channels'],
        'Train-Validation-Test Split':[args["split"]],
        'Number of Linear Layers':2+args["num layers"],
        'Final accuracy':[results['test_acc']],
        'Dataset':args["dataset"]
    }
    split_column = tests_compiled['Train-Validation-Test Split']
    accuracy_column = tests_compiled['Final accuracy']

    for i in range(1,8):
        shift = 0.09*i
        args['split'] = (0.8-shift, 0.1, 0.1+shift)
        print(f"Testing with {args['split']} train-validation-test split")
        results = train_model(args, model_path)
        split_column.append(args['split'])
        accuracy_column.append(results['test_acc'])

    df = pd.DataFrame(tests_compiled)
    csv_name = 'splitting_'+args['dataset']+'-'+tests_compiled['Experiment']+'.csv'
    df.to_csv(csv_name)

