
import os
import random
from sys import argv
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import DataLoader
from torch_scatter import scatter_mean
from tqdm import tqdm

from gnn import *
from utils.argparser import argument_parser
from utils.util import load_data


def __loss_aux(output, loss, data, binary_prediction):
    """Return the label-masked binary cross-entropy of ``output``.

    ``data.node_labels`` is expanded into a one-hot target with the shape of
    ``output``.  Negative labels are clamped to class 0 so the scatter index
    stays valid, and the corresponding rows get weight 0 so they add nothing
    to the summed loss.  The ``mean`` reduction is taken over every element
    of ``output``, including the zero-weighted rows.

    Args:
        output: logits, indexed by class id along dim 1.
        loss: not used by this function; accepted so callers can keep
            passing their criterion.
        data: object exposing an integer ``node_labels`` tensor.
        binary_prediction: must be truthy.

    Raises:
        NotImplementedError: if ``binary_prediction`` is falsy.
    """
    if not binary_prediction:
        raise NotImplementedError()

    node_labels = data.node_labels

    # One-hot targets; negative labels are mapped to index 0 for the scatter.
    class_index = torch.maximum(
        node_labels, torch.zeros_like(node_labels)).unsqueeze(1)
    targets = torch.zeros_like(output).scatter_(1, class_index, 1.)

    # Per-row weight: 1 where a label is present, 0 otherwise.
    row_weight = torch.where(node_labels >= 0, 1, 0).unsqueeze(1)

    weighted_bce = nn.BCEWithLogitsLoss(reduction='mean', weight=row_weight)
    return weighted_bce(output, targets)


def train(
        model,
        device,
        training_data,
        optimizer,
        criterion,
        scheduler,
        binary_prediction=True) -> tuple:
    """Run one optimisation pass over ``training_data``.

    Args:
        model: module called as ``model(x=, edge_index=, edge_attr=, batch=)``.
        device: device every element of a batch is moved to.
        training_data: iterable of batches.  Each batch is a sequence of
            objects exposing ``to(device)``, ``edge_index``, ``edge_attr``
            and ``batch``; the first one also provides ``x`` and the last
            one is handed to the loss as the label carrier.
        optimizer: optimizer stepped once per batch.
        criterion: forwarded to ``__loss_aux``.
        scheduler: learning-rate scheduler, also stepped once per batch.
        binary_prediction: forwarded to ``__loss_aux``.

    Returns:
        ``(average_loss, loss_accum)`` -- the mean of the per-batch losses
        and the list of per-batch losses in iteration order.
    """
    model.train()

    loss_accum = []

    for data in tqdm(training_data):
        parts = [part.to(device) for part in data]

        # Node features come from the first element only; the graph
        # structure of every element is passed as parallel lists.
        output = model(x=parts[0].x,
                       edge_index=[part.edge_index for part in parts],
                       edge_attr=[part.edge_attr for part in parts],
                       batch=[part.batch for part in parts])

        loss = __loss_aux(
            output=output,
            loss=criterion,
            data=parts[-1],
            binary_prediction=binary_prediction)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Stepped per batch, so the scheduler counts batches, not epochs.
        scheduler.step()

        loss_accum.append(loss.detach().cpu().numpy())

    # An empty loader still reports NaN, but without numpy's
    # "mean of empty slice" warning.
    average_loss = np.mean(loss_accum) if loss_accum else float("nan")

    print(f"Train loss: {average_loss}")

    return average_loss, loss_accum


def __accuracy_aux(node_labels, predicted_labels, batch, device):
    """Count the correct predictions among labelled nodes.

    Nodes whose label is negative are ignored.

    Args:
        node_labels: integer label tensor; negative entries mean
            "no label".
        predicted_labels: predicted class ids, aligned with ``node_labels``.
        batch: not used by this function; accepted for call compatibility.
        device: device the resulting count is placed on.

    Returns:
        ``(micro, macro)`` -- both are the same float scalar tensor holding
        the number of correct predictions (a count, not a ratio).
    """
    # Boolean-mask indexing replaces the deprecated ``torch.range`` index
    # construction and also works when the label tensor is empty.
    labelled = node_labels >= 0
    results = torch.eq(
        predicted_labels[labelled],
        node_labels[labelled]).float().to(device)

    # micro: number of correctly predicted labelled nodes
    micro = torch.sum(results)

    # macro: no per-graph averaging is performed; it mirrors ``micro``
    macro = micro

    return micro, macro


def _evaluate_loader(
        model,
        device,
        criterion,
        loader,
        binary_prediction,
        compute_loss):
    """Evaluate ``model`` on one loader and return ``(avg_loss, micro, macro)``.

    A ``None`` loader, or one that yields no batches, produces zeros.  When
    ``compute_loss`` is false the loss is skipped and reported as ``0.``.
    """
    if loader is None:
        return 0., 0., 0.

    micro_total = 0.
    macro_total = 0.
    losses = []
    n_nodes = 0
    n_graphs = 0

    for data in loader:
        parts = [part.to(device) for part in data]
        # The last element of the batch carries labels and graph counts.
        last = parts[-1]

        with torch.no_grad():
            output = model(
                x=parts[0].x,
                edge_index=[part.edge_index for part in parts],
                edge_attr=[part.edge_attr for part in parts],
                batch=[part.batch for part in parts])

        if compute_loss:
            loss = __loss_aux(
                output=output,
                loss=criterion,
                data=last,
                binary_prediction=binary_prediction)
            losses.append(loss.detach().cpu().numpy())

        output = torch.sigmoid(output)
        _, predicted_labels = output.max(dim=1)

        micro, macro = __accuracy_aux(
            node_labels=last.node_labels,
            predicted_labels=predicted_labels,
            batch=last.batch, device=device)

        micro_total += micro.cpu().numpy()
        macro_total += macro.cpu().numpy()
        n_nodes += last.num_nodes
        n_graphs += last.num_graphs

    # NOTE(review): ``__accuracy_aux`` only counts nodes with a non-negative
    # label, while ``n_nodes`` is the total node count -- confirm that
    # unlabelled nodes are meant to be part of the denominator.
    avg_loss = np.mean(losses) if losses else 0.
    micro_avg = micro_total / n_nodes if n_nodes else 0.
    macro_avg = macro_total / n_graphs if n_graphs else 0.

    return avg_loss, micro_avg, macro_avg


def test(
        model,
        device,
        criterion,
        epoch,
        train_data,
        test_data1,
        test_data2=None,
        binary_prediction=True):
    """Evaluate ``model`` on the train set and up to two test sets.

    Args:
        model: module to evaluate; switched to eval mode.
        device: device the batches are moved to.
        criterion: forwarded to ``__loss_aux`` for the test losses.
        epoch: not used by this function; accepted for call compatibility.
        train_data: loader used for train accuracy only, or ``None``.
        test_data1: first test loader, or ``None``.
        test_data2: second test loader, or ``None``.
        binary_prediction: forwarded to ``__loss_aux``.

    Returns:
        ``(train_micro, train_macro)``,
        ``(test1_loss, test1_micro, test1_macro)``,
        ``(test2_loss, test2_micro, test2_macro)``.
        Entries for a missing or empty loader are ``0.``.
    """
    model.eval()

    # Only accuracy is reported for the training set, so its loss is skipped.
    _, train_micro_avg, train_macro_avg = _evaluate_loader(
        model, device, criterion, train_data,
        binary_prediction, compute_loss=False)

    test1_avg_loss, test1_micro_avg, test1_macro_avg = _evaluate_loader(
        model, device, criterion, test_data1,
        binary_prediction, compute_loss=True)

    test2_avg_loss, test2_micro_avg, test2_macro_avg = _evaluate_loader(
        model, device, criterion, test_data2,
        binary_prediction, compute_loss=True)

    print(
        f"Train accuracy: micro: {train_micro_avg}\tmacro: {train_macro_avg}")
    print(f"Test1 loss: {test1_avg_loss}")
    print(f"Test2 loss: {test2_avg_loss}")
    print(f"Test accuracy: micro: {test1_micro_avg}\tmacro: {test1_macro_avg}")
    print(f"Test accuracy: micro: {test2_micro_avg}\tmacro: {test2_macro_avg}")

    return (train_micro_avg, train_macro_avg), \
        (test1_avg_loss, test1_micro_avg, test1_macro_avg), \
        (test2_avg_loss, test2_micro_avg, test2_macro_avg)


def seed_everything(seed):
    """Seed Python, PyTorch and NumPy RNGs and pin ``PYTHONHASHSEED``.

    When CUDA is available, all CUDA generators are seeded as well and
    cuDNN is switched to deterministic mode with benchmarking disabled.
    """
    for seeder in (random.seed, torch.manual_seed, np.random.seed):
        seeder(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def main(
        args,
        manual,
        train_data=None,
        test1_data=None,
        test2_data=None,
        n_classes=None,
        save_model=None,
        load_model=None,
        train_model=True,
        plot=None,
        truncated_fn=None):
    """Build a network from ``args`` and train and/or evaluate it.

    Args:
        args: parsed command-line namespace (seed, device, network,
            hyper-parameters, ``filename`` for logging, ...).
        manual: must be truthy; the non-manual path is not implemented.
        train_data, test1_data, test2_data: preloaded datasets; all three
            are required.
        n_classes: number of output classes; required.
        save_model: path the state dict is saved to after training, or
            ``None``.
        load_model: path of a state dict to load before training or
            evaluation, or ``None``.
        train_model: train for ``args.epochs`` epochs when true, otherwise
            evaluate once.
        plot: path of the loss figure to write after training, or ``None``.
        truncated_fn: forwarded to the network constructor.

    Returns:
        The last CSV result line.  It ends with ``"\\n"`` after training and
        with ``","`` after an evaluation-only run.

    Raises:
        NotImplementedError: if ``manual`` is falsy or ``args.task_type`` is
            not ``"node"``.
        ValueError: if ``args.network`` is not a known network name.
    """
    # set up seeds and gpu device
    seed_everything(args.seed)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        device = torch.device("cuda:" + str(args.device))
    else:
        device = torch.device("cpu")

    if not manual:
        raise NotImplementedError()

    assert train_data is not None
    assert test1_data is not None
    assert test2_data is not None
    assert n_classes is not None
    # manual settings
    print("Using preloaded data")
    train_graphs = train_data
    test_graphs1 = test1_data
    test_graphs2 = test2_data

    if args.task_type == "node":
        num_classes = n_classes
    else:
        raise NotImplementedError()

    # Pinned host memory only speeds up host-to-GPU copies; requesting it
    # on a CPU-only machine just triggers a DataLoader warning.
    pin = use_cuda
    train_loader = DataLoader(
        train_graphs,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=pin,
        num_workers=0)
    test1_loader = DataLoader(
        test_graphs1,
        batch_size=512,
        pin_memory=pin,
        num_workers=0)
    test2_loader = DataLoader(
        test_graphs2,
        batch_size=512,
        pin_memory=pin,
        num_workers=0)

    if args.network == "acgnn":
        _model = ACGNN
    elif args.network == "acrgnn":
        _model = ACRGNN
    elif args.network == "acrgnn-single":
        _model = SingleACRGNN
    elif args.network == "gin":
        _model = GIN
    else:
        raise ValueError(f"Unknown network: {args.network!r}")

    model = _model(
        input_dim=train_graphs[0][0].num_features,
        hidden_dim=args.hidden_dim,
        output_dim=num_classes,
        num_layers=args.num_layers,
        aggregate_type=args.aggregate,
        readout_type=args.readout,
        combine_type=args.combine,
        combine_layers=args.combine_layers,
        num_mlp_layers=args.num_mlp_layers,
        task=args.task_type,
        time_range=args.time_range,
        num_relation=args.num_relation,
        truncated_fn=truncated_fn)

    if load_model is not None:
        print("Loading Model")
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only host (or on a different GPU index).
        model.load_state_dict(torch.load(load_model, map_location=device))

    model = model.to(device)

    criterion = nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

    log_to_file = args.filename != ""

    if log_to_file:
        # Start each of the three log files with its header line.
        with open(args.filename, 'w') as f:
            f.write(
                "train_loss,test1_loss,test2_loss,train_micro,train_macro,test1_micro,test1_macro,test2_micro,test2_macro\n")
        with open(args.filename + ".train", 'w') as f:
            f.write(
                "train_loss\n")
        with open(args.filename + ".test", 'w') as f:
            f.write(
                "test1_loss,test2_loss\n")

    if train_model:
        # Defined up front so a run with zero epochs still returns a line.
        file_line = ""

        # `epoch` is only for printing purposes
        for epoch in range(1, args.epochs + 1):

            print(f"Epoch {epoch}/{args.epochs}")

            # TODO: binary prediction
            avg_loss, loss_iter = train(
                model=model,
                device=device,
                training_data=train_loader,
                optimizer=optimizer,
                criterion=criterion,
                scheduler=scheduler,
                binary_prediction=True)

            (train_micro, train_macro), (test1_loss, test1_micro, test1_macro), (test2_loss, test2_micro, test2_macro) = test(
                model=model, device=device, train_data=train_loader, test_data1=test1_loader, test_data2=test2_loader, epoch=epoch, criterion=criterion)

            file_line = f"{avg_loss: .10f}, {test1_loss: .10f}, {test2_loss: .10f}, {train_micro: .8f}, {train_macro: .8f}, {test1_micro: .8f}, {test1_macro: .8f}, {test2_micro: .8f}, {test2_macro: .8f}"

            if log_to_file:
                with open(args.filename, 'a') as f:
                    f.write(file_line + "\n")

                with open(args.filename + ".train", 'a') as f:
                    for batch_loss in loss_iter:
                        f.write(f"{batch_loss: .15f}\n")

                with open(args.filename + ".test", 'a') as f:
                    f.write(f"{test1_loss: .15f}, {test2_loss: .15f}\n")

        if save_model is not None:
            torch.save(model.state_dict(), save_model)

        if plot is not None:
            # The curves are read back from the ".train"/".test" log files
            # written above, so plotting relies on a non-empty args.filename.
            iter_losses = np.loadtxt(args.filename + ".train", skiprows=1)
            epoch_t1_losses, epoch_t2_losses = np.loadtxt(
                args.filename + ".test", delimiter=",", skiprows=1).T

            iters = np.arange(len(iter_losses))

            # Number of training iterations per epoch; used to place each
            # per-epoch test loss at the end of its epoch on the x axis.
            batch = (len(iter_losses) / len(epoch_t1_losses))
            epochs = np.arange(len(epoch_t1_losses)) * batch + batch

            plt.figure(figsize=(16, 10))
            plt.plot(
                iters,
                iter_losses,
                color="#377eb8",
                marker="*",
                linestyle="-",
                label="Train")
            plt.plot(
                epochs,
                epoch_t1_losses,
                color="#ff7f00",
                marker="o",
                linestyle="-",
                label="Test1")
            plt.plot(
                epochs,
                epoch_t2_losses,
                color="#4daf4a",
                marker="x",
                linestyle="-",
                label="Test2")

            plt.title(
                f"{plot.split('/')[-1].split('.')[0]} - H{args.hidden_dim} - B{args.batch_size} - L{args.num_layers} - Epochs{args.epochs}")

            plt.ylim(bottom=0)
            plt.legend(loc='upper right')
            plt.savefig(plot, dpi=150, bbox_inches='tight')
            plt.close()

        return file_line + "\n"

    else:

        (train_micro, train_macro), (test1_loss, test1_micro, test1_macro), (test2_loss, test2_micro, test2_macro) = test(
            model=model, device=device, train_data=train_loader, test_data1=test1_loader, test_data2=test2_loader, epoch=-1, criterion=criterion)

        # No training happened, so the train-loss column is filled with -1.
        file_line = f" {-1: .8f}, {test1_loss: .10f}, {test2_loss: .10f}, {train_micro: .8f}, {train_macro: .8f}, {test1_micro: .8f}, {test1_macro: .8f}, {test2_micro: .8f}, {test2_macro: .8f}"

        if log_to_file:
            with open(args.filename, 'a') as f:
                f.write(file_line + "\n")

        return file_line + ","


if __name__ == '__main__':

    # Validate the command line up front so a missing or malformed argument
    # fails before any dataset is loaded, not inside the grid search.
    if len(argv) < 4:
        raise SystemExit(
            f"usage: {argv[0]} <dataset_key> <time_range> <num_relation>")
    dataset_key = argv[1]
    time_range = int(argv[2])
    num_relation = int(argv[3])

    # Each row is (aggregate, readout, combine); every entry maps the option
    # value passed to the argument parser to the abbreviation used in file
    # names.
    _networks = [
        [{"mean": "A"}, {"mean": "A"}, {"simple": "T"}],
        [{"mean": "A"}, {"mean": "A"}, {"mlp": "MLP"}],
        [{"mean": "A"}, {"max": "M"}, {"simple": "T"}],
        [{"mean": "A"}, {"max": "M"}, {"mlp": "MLP"}],
        [{"mean": "A"}, {"add": "S"}, {"simple": "T"}],
        [{"mean": "A"}, {"add": "S"}, {"mlp": "MLP"}],

        [{"max": "M"}, {"mean": "A"}, {"simple": "T"}],
        [{"max": "M"}, {"mean": "A"}, {"mlp": "MLP"}],
        [{"max": "M"}, {"max": "M"}, {"simple": "T"}],
        [{"max": "M"}, {"max": "M"}, {"mlp": "MLP"}],
        [{"max": "M"}, {"add": "S"}, {"simple": "T"}],
        [{"max": "M"}, {"add": "S"}, {"mlp": "MLP"}],

        [{"add": "S"}, {"mean": "A"}, {"simple": "T"}],
        [{"add": "S"}, {"mean": "A"}, {"mlp": "MLP"}],
        [{"add": "S"}, {"max": "M"}, {"simple": "T"}],
        [{"add": "S"}, {"max": "M"}, {"mlp": "MLP"}],
        [{"add": "S"}, {"add": "S"}, {"simple": "T"}],
        [{"add": "S"}, {"add": "S"}, {"mlp": "MLP"}],
    ]

    file_path = "data"
    data_path = "datasets"
    extra_name = "results/"

    print("Start running")
    data_dir = '.'
    for key in [dataset_key]:
        for enum, _set in enumerate([
            [(f"{data_dir}/{key}/train-random-erdos-5000-40-50",
              f"{data_dir}/{key}/test-random-erdos-500-40-50",)
             ],
        ]):

            for index, (_train, _test1) in enumerate(_set):

                print(f"Start for dataset {_train}-{_test1}")

                _train_graphs, (_, _, _n_node_labels) = load_data(
                    dataset=f"{file_path}/{data_path}/{_train}.txt",
                    degree_as_node_label=False)

                _test_graphs, _ = load_data(
                    dataset=f"{file_path}/{data_path}/{_test1}.txt",
                    degree_as_node_label=False)

                for _net_class in [
                    "acgnn",
                    # "gin",
                    "acrgnn",
                    # "acrgnn-single"
                ]:

                    filename = f"./logging/{extra_name}{key}-{enum}-{index}.mix"
                    # The per-run logs share this directory; create it so
                    # the first open() does not fail on a fresh checkout.
                    os.makedirs(os.path.dirname(filename), exist_ok=True)

                    for a, r, c in _networks:
                        _agg, _agg_abr = next(iter(a.items()))
                        _read, _read_abr = next(iter(r.items()))
                        _comb, _comb_abr = next(iter(c.items()))

                        for comb_layers in [1, 2]:

                            # Skip option combinations that are excluded for
                            # the given network class.
                            if _net_class == "acgnn" and (
                                    _read == "max" or _read == "add"):
                                continue
                            elif _net_class == "gin" and (_agg == "mean" or _agg == "max" or _comb == "mlp" or _read == "max" or _read == "add" or comb_layers > 1):
                                continue

                            if _comb == "mlp" and comb_layers > 1:
                                continue

                            for num_layers in range(1, 4):
                                for lr in [0.01]:
                                    for h in [10, 64, 100]:
                                        print(a, r, c, _net_class, num_layers, comb_layers)

                                        # The hidden size is part of the name:
                                        # main() opens this file with 'w', so
                                        # without it the runs for different
                                        # `h` overwrite one another's log.
                                        run_filename = f"./logging/{extra_name}{key}-{enum}-{index}-{_net_class}-agg{_agg_abr}-read{_read_abr}-comb{_comb_abr}-cl{comb_layers}-L{num_layers}-H{h}.log"

                                        _args = argument_parser().parse_args(
                                            [
                                                f"--readout={_read}",
                                                f"--aggregate={_agg}",
                                                f"--combine={_comb}",
                                                f"--network={_net_class}",
                                                f"--filename={run_filename}",
                                                "--epochs=100",
                                                f"--batch_size=128",
                                                f"--hidden_dim={h}",
                                                f"--num_layers={num_layers}",
                                                f"--combine_layers={comb_layers}",
                                                f"--num_mlp_layers=2",
                                                "--device=0",
                                                f"--lr={lr}",
                                                f"--time_range={time_range}",
                                                f"--num_relation={num_relation}"
                                            ])

                                        # The single test set is passed as
                                        # both test1 and test2.
                                        line = main(
                                            _args,
                                            manual=True,
                                            train_data=_train_graphs,
                                            test1_data=_test_graphs,
                                            test2_data=_test_graphs,
                                            n_classes=_n_node_labels,
                                            # save_model=f"{file_path}/saved_models/{extra_name}{key}/MODEL-{_net_class}-{enum}-agg{_agg_abr}-read{_read_abr}-comb{_comb_abr}-cl{comb_layers}-L{l}-H{h}.pth",
                                            train_model=True,
                                            # load_model=f"saved_models/h32/MODEL-{_net_class}-{key}-{enum}-agg{_agg_abr}-read{_read_abr}-comb{_comb_abr}-L{l}.pth",
                                            # plot=f"plots/{run_filename}.png",
                                            truncated_fn=None
                                        )

                                        # append the final result line of
                                        # this run to the summary file
                                        with open(filename, 'a') as f:
                                            f.write(f"{_net_class} {lr} {h} {num_layers} {comb_layers}:{line}\n")
