""" Trains a model

"""
import json
import numpy as np
import torch
import time
from torch_geometric.data import DataLoader
import ogb

from trainable_scattering.models.ts_net import TSNet
from trainable_scattering.models.GCN import GCN
from trainable_scattering.models.simple_classifier import SCNet, RBF_SVM, LinearRegression
from trainable_scattering.models.baseline import Baseline
from trainable_scattering.models.attention_scatter_rbf import attention_rbf
from trainable_scattering.models.second_attention_scatter_rbf import (
    second_attention_rbf,
)  # Only learns the diffusion scales for the second level.
from trainable_scattering.models.GraphSAGE import GraphSAGE
from trainable_scattering.models.fast_scatter_sort import SortNet, SortRBFNet
from trainable_scattering.models.early_stopping import EarlyStopping
from trainable_scattering.utils import get_dataset, split_dataset
from trainable_scattering.models.tsnet_attention import TSNetAttention
from sklearn import metrics

# cuDNN settings for GPU runs: keep the backend switched on and enable
# benchmark mode, in which cuDNN tries candidate algorithms and reuses the
# fastest one per input shape. The speed-up is unmeasured ("hypothetical"),
# and per PyTorch's reproducibility notes benchmark mode may pick different
# algorithms from run to run.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

def accuracy(model, dataset, name, loss_fn, device):
    """Evaluate ``model`` on ``dataset``; return loss, predictions and metrics.

    Args:
        model: Module called as ``model(data)`` on a torch_geometric batch.
        dataset: Dataset iterated in batches of 32 without shuffling.
        name: Split label such as ``"Train"``; not used by this function.
        loss_fn: Loss called per batch as ``loss_fn(out, y)``.
        device: Device that the model and every batch are moved to.

    Returns:
        ``(acc, pred, loss, y, out, ap, auc)``, tensors on ``device``:

        * ``acc`` -- fraction of dataset items whose argmax prediction equals
          the target; ``None`` when targets are float32 (regression).
        * ``pred`` -- argmax over dim 1 of the outputs.
        * ``loss`` -- 0-d tensor: per-batch losses averaged with batch-size
          weights, i.e. the exact dataset mean for mean-reduced losses.
        * ``y`` -- targets after the reshaping described in the loop below.
        * ``out`` -- concatenated raw model outputs.
        * ``ap`` / ``auc`` -- average precision and ROC AUC computed from the
          softmax probability of column 1 when ``out`` has two columns;
          ``auc`` is ``"n/a"`` if the targets contain a single class. Both are
          ``"n/a"`` for more than two columns and ``None`` for regression.

    Raises:
        ValueError: if ``dataset`` yields no batches.
    """
    model = model.to(device)  # loop-invariant; no need to repeat per batch
    loader = DataLoader(dataset, batch_size=32, shuffle=False)
    outs, preds, ys, losses, sizes = [], [], [], [], []
    is_regression = False
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data)
            y = data.y
            if y.dim() > 1:
                # ogbg-molhiv wraps targets in an extra dimension. Drop only
                # the non-batch singleton dims: torch.squeeze would also drop
                # the batch dim of a size-1 final batch, leaving a 0-d tensor
                # that breaks torch.cat and the loss.
                y = y.reshape(y.shape[0], *[d for d in y.shape[1:] if d != 1])
            # Float32 targets mean MSE regression, which expects (batch, 1).
            is_regression = y.dtype == torch.float32
            if is_regression:
                y = y[:, None]
            outs.append(out)
            preds.append(out.max(dim=1)[1])
            ys.append(y)
            losses.append(loss_fn(out, y))
            sizes.append(out.shape[0])

    if not losses:
        raise ValueError("accuracy() received an empty dataset")

    out = torch.cat(outs, dim=0)
    pred = torch.cat(preds, dim=0)
    y = torch.cat(ys, dim=0)
    # Weight each batch's mean loss by its size so a short final batch does
    # not skew the dataset-level average.
    batch_losses = torch.stack(losses)
    weights = torch.tensor(
        sizes, dtype=batch_losses.dtype, device=batch_losses.device
    )
    loss = (batch_losses * weights).sum() / weights.sum()

    acc = ap = auc = None
    if not is_regression:
        acc = float(pred.eq(y).sum().item()) / len(dataset)
        if out.shape[1] == 2:  # binary classification
            # AP and ROC AUC rank examples by a continuous score; hard 0/1
            # predictions would collapse the curves to a single threshold.
            # sklearn needs CPU arrays, so convert copies only.
            y_np = y.cpu().numpy()
            scores = torch.softmax(out, dim=1)[:, 1].cpu().numpy()
            ap = metrics.average_precision_score(y_np, scores)
            if torch.unique(y).numel() < 2:
                # roc_auc_score raises ValueError when only one class occurs.
                auc = "n/a"
            else:
                auc = metrics.roc_auc_score(y_np, scores)
        else:
            # Multiclass AUC is possible (multi_class="ovr") but needs
            # per-class probability inputs; it is not computed here.
            ap = "n/a"
            auc = "n/a"
    return acc, pred, loss, y, out, ap, auc




def evaluate(model, train_ds, val_ds, test_ds, epoch, loss_fn, device):
    """Score ``model`` on all three splits and snapshot its weights.

    The model is put in eval mode while scoring. Afterwards its previous
    train/eval mode is restored.

    Args:
        model: Model to evaluate.
        train_ds, val_ds, test_ds: Dataset splits passed to ``accuracy``.
        epoch: Epoch number, stored under ``"epoch"``.
        loss_fn: Loss forwarded to ``accuracy``.
        device: Device forwarded to ``accuracy``.

    Returns:
        Dict with, per split, accuracy, predictions, loss, targets, average
        precision and ROC AUC (``"train_acc"``, ``"val_loss"``, ...). It also
        holds the test-set outputs (``"test_outs"``) and an independent copy
        of the model weights (``"state_dict"``).
    """
    import copy  # local import keeps this block self-contained

    was_training = model.training
    model.eval()
    try:
        train_acc, train_pred, train_loss, train_true, train_out, train_ap, train_auc = accuracy(model, train_ds, "Train", loss_fn, device)
        val_acc, val_pred, val_loss, val_true, val_out, val_ap, val_auc = accuracy(model, val_ds, "Val", loss_fn, device)
        test_acc, test_pred, test_loss, test_true, test_out, test_ap, test_auc = accuracy(model, test_ds, "Test", loss_fn, device)
        # state_dict() returns references that share storage with the live
        # parameters, and optimizer.step() updates those in place. Deep-copy
        # so a snapshot kept as the "best model" is not silently overwritten
        # by later training steps.
        state_dict = copy.deepcopy(model.state_dict())
    finally:
        # Restore the caller's mode rather than unconditionally forcing train().
        model.train(was_training)
    results = {
        "epoch": epoch,
        "train_acc": train_acc,
        "train_pred": train_pred,
        "train_loss": train_loss,
        "train_true": train_true,
        "train_average_precision": train_ap,
        "train_roc_auc": train_auc,
        "val_acc": val_acc,
        "val_pred": val_pred,
        "val_loss": val_loss,
        "val_true": val_true,
        "val_average_precision": val_ap,
        "val_roc_auc": val_auc,
        "test_acc": test_acc,
        "test_pred": test_pred,
        "test_loss": test_loss,
        "test_true": test_true,
        "test_outs": test_out,
        "test_average_precision": test_ap,
        "test_roc_auc": test_auc,
        "state_dict": state_dict,
    }
    return results

def get_model(args, dataset, train_ds):
    """Build the model and loss function selected by ``args["model"]``.

    Args:
        args: Parsed config dict. ``args["model"]`` selects the architecture;
            ``ts_net``, ``attention_plain``, ``attention_rbf`` and
            ``second_attention_rbf`` also unpack ``args["model_args"]`` as
            keyword arguments.
        dataset: Dataset providing ``num_node_features`` and ``num_classes``.
        train_ds: Training split. Its length caps the RBF centre count, and
            models that call ``init_centers`` draw one batch from it.

    Returns:
        ``(model, loss_fn)``. Names ending in ``_regression`` pass ``1`` where
        the classifiers pass ``dataset.num_classes`` and pair it with
        ``MSELoss``. The other names use ``CrossEntropyLoss`` or
        ``MultiMarginLoss``.

    Raises:
        NotImplementedError: if ``args["model"]`` is not a known model name.
    """
    def init_centers(model, train_ds):
        """Seed the model's RBF centres from one random training batch."""
        # This is hacky (anonymous).
        # NOTE(review): the batch holds up to 500 *graphs*, so ``data.x`` has
        # one row per node of those graphs. Confirm set_centres expects that
        # rather than the ``min(500, len(train_ds))`` count passed to the
        # model constructor.
        loader = DataLoader(train_ds, batch_size=500, shuffle=True)
        for data in loader:
            model.set_centres(data.x)
            break
    if args["model"] == "ts_net":
        model = TSNet(
            dataset.num_node_features, dataset.num_classes, **args["model_args"]
        )
        loss_fn = torch.nn.CrossEntropyLoss()
    elif args["model"] == "fast_ts_net":
        model = SCNet(dataset.num_node_features, 64, 2, dataset.num_classes,)
        loss_fn = torch.nn.CrossEntropyLoss()
    elif args["model"] == "fast_ts_net_regression":
        model = SCNet(dataset.num_node_features, 64, 2, 1)
        loss_fn = torch.nn.MSELoss()
    elif args["model"] == "ts_svm":
        model = LinearRegression(dataset.num_node_features, dataset.num_classes,)
        loss_fn = torch.nn.MultiMarginLoss()
    elif args["model"] == "fast_rbf_net":
        model = RBF_SVM(
            dataset.num_node_features, dataset.num_classes, min(500, len(train_ds))
        )
        loss_fn = torch.nn.MultiMarginLoss()
        init_centers(model, train_ds)
    elif args["model"] == "fast_rbf_regression":
        model = RBF_SVM(
            dataset.num_node_features, 1, min(500, len(train_ds))
        )
        loss_fn = torch.nn.MSELoss()
        init_centers(model, train_ds)
    elif args["model"] == "attention_plain":
        model = TSNetAttention(
            dataset.num_node_features, dataset.num_classes, **args["model_args"]
        )
        loss_fn = torch.nn.CrossEntropyLoss()
    elif args["model"] == "attention_rbf":
        model = attention_rbf(
            dataset.num_node_features,
            dataset.num_classes,
            min(500, len(train_ds)),
            **args["model_args"]
        )
        loss_fn = torch.nn.MultiMarginLoss()
        init_centers(model, train_ds)
    elif args["model"] == "second_attention_rbf":
        model = second_attention_rbf(
            dataset.num_node_features,
            dataset.num_classes,
            min(500, len(train_ds)),
            **args["model_args"]
        )
        loss_fn = torch.nn.MultiMarginLoss()
        init_centers(model, train_ds)
    elif args["model"] == "fast_sort":
        model = SortNet(dataset.num_node_features, 64, 2, 5, dataset.num_classes)
        loss_fn = torch.nn.CrossEntropyLoss()
    elif args["model"] == "fast_sort_regression":
        model = SortNet(dataset.num_node_features, 64, 2, 5, 1)
        loss_fn = torch.nn.MSELoss()
    elif args["model"] == "fast_rbf_sort":
        # NOTE(review): unlike the other RBF models, init_centers is not
        # called here; confirm SortRBFNet initialises its own centres.
        model = SortRBFNet(dataset.num_node_features, 64, 5, min(500, len(train_ds)), dataset.num_classes)
        loss_fn = torch.nn.MultiMarginLoss()
    elif args["model"] == "baseline":
        model = Baseline(dataset.num_node_features, dataset.num_classes)
        loss_fn = torch.nn.CrossEntropyLoss()
    elif args["model"] == "gcn":
        model = GCN(dataset.num_node_features, 64, dataset.num_classes)
        loss_fn = torch.nn.CrossEntropyLoss()
    elif args["model"] == "graph_sage":
        model = GraphSAGE(dataset.num_node_features, dataset.num_classes)
        loss_fn = torch.nn.CrossEntropyLoss()
    elif args["model"] == "gcn_regression":
        model = GCN(dataset.num_node_features, 64, 1)
        loss_fn = torch.nn.MSELoss()
    elif args["model"] == "graph_sage_regression":
        model = GraphSAGE(dataset.num_node_features, 1)
        loss_fn = torch.nn.MSELoss()
    elif args["model"] == "baseline_regression":
        model = Baseline(dataset.num_node_features, 1)
        loss_fn = torch.nn.MSELoss()
    else:
        # Name the offending value so config typos are easy to spot.
        raise NotImplementedError("Unknown model %r" % (args["model"],))
    return model, loss_fn


def train_model(in_dir, out_file):
    """Train the model described by a JSON config and save the best results.

    Args:
        in_dir: Path of the JSON config file (a file, despite the name). The
            parsed dict is passed to ``get_dataset`` and ``get_model``; this
            function itself reads ``args["model"]``.
        out_file: Path for the final ``es.best_model`` dict. Every 100 epochs,
            and when early stopping fires, the current evaluation results are
            also saved next to it as ``<stem>_<epoch><ext>``.

    Training uses Adam (lr=0.01) for up to 1000 epochs. It evaluates every 10
    epochs and stops early on validation loss.
    """
    import os  # local import: only needed to split the output path

    with open(str(in_dir), "r") as fp:
        args = json.load(fp)
    # os.path.splitext instead of str.split("."), which raised ValueError for
    # any path without exactly one dot (e.g. "./runs/model.pt").
    out_name, out_ext = os.path.splitext(str(out_file))

    def snapshot_path(epoch):
        """Per-epoch results path, e.g. ``model.pt`` -> ``model_100.pt``."""
        return "%s_%d%s" % (out_name, epoch, out_ext)

    dev_count = torch.cuda.device_count()
    if dev_count == 0:
        device = torch.device("cpu")
    else:
        # Random visible GPU (presumably to spread concurrent runs).
        device = torch.device("cuda:%d" % np.random.randint(dev_count))
    train_ds, val_ds, test_ds, train_loader, dataset = get_dataset(args, device)
    print(dataset)
    print("Device:", device)
    model, loss_fn = get_model(args, dataset, train_ds)
    # Move the model before building the optimizer, as PyTorch recommends.
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Early stopping on validation loss, checked every 10 epochs; the original
    # intent is a patience of 100 epochs (10 checks).
    es = EarlyStopping(patience=10, mode="min", percentage=False)

    model.train()
    start = time.time()
    for epoch in range(1, 1000 + 1):
        for data in train_loader:
            optimizer.zero_grad()
            data = data.to(device)
            out = model(data)
            y = data.y
            if y.dim() > 1:
                # ogbg-molhiv wraps targets in an extra dimension; drop only
                # non-batch singleton dims so a size-1 batch keeps its batch
                # dimension (torch.squeeze would make it 0-d).
                y = y.reshape(y.shape[0], *[d for d in y.shape[1:] if d != 1])
            if y.dtype == torch.float32:  # MSE regression expects (batch, 1)
                y = y[:, None]
            loss = loss_fn(out, y)
            loss.backward()
            optimizer.step()
        if epoch % 10 == 0:
            results = evaluate(model, train_ds, val_ds, test_ds, epoch, loss_fn, device)
            if epoch % 100 == 0:
                torch.save(results, snapshot_path(epoch))
            end = time.time()
            total = int(end - start)
            start = end
            print(f'''MODEL {args['model']} EPOCH {epoch} finishes in {total}s
            TRAIN: Loss {results['train_loss']}. Accuracy {results['train_acc']}. Average Precision {results['train_average_precision']}. ROC AUC {results['train_roc_auc']}.
            VAL: Loss {results['val_loss']}. Accuracy {results['val_acc']}. Average Precision {results['val_average_precision']}. ROC AUC {results['val_roc_auc']}.
            ''')
            if es.step(results["val_loss"], results):
                torch.save(results, snapshot_path(epoch))
                print("early stopping at epoch %d" % epoch)
                print("best model was at epoch %d" % es.best_model["epoch"])
                print(f''' Achieved TRAIN: Accuracy
                                   {es.best_model["train_acc"]}. Average Precision {es.best_model["train_average_precision"]}.ROC AUC {es.best_model["train_roc_auc"]}.\n
                                   VAL Accuracy {es.best_model["val_acc"]}. Average Precision {es.best_model["val_average_precision"]}. ROC AUC {es.best_model["val_roc_auc"]}.\n
                                   TEST Accuracy {es.best_model["test_acc"]}. Average Precision {es.best_model["test_average_precision"]}. ROC AUC {es.best_model["test_roc_auc"]}.\n
                                   '''
                               )
                print(es.best_model)
                break
    # The best snapshot (not the final weights) is what gets persisted.
    torch.save(es.best_model, str(out_file))
