#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: ----
"""

import torch
from torch_geometric.data import Data, DataLoader

import argparse
import os.path

from utils import load_predicates, encode_input_dataset, encode_entailed_dataset, create_training_dataset
from gnn_architectures import GNN

# Command-line interface of the training script. Parsing happens at module
# level because the code below reads the module-level `args` directly.
parser = argparse.ArgumentParser(description="Train the GNNs")
# The dataset name is concatenated into `saved_model_name` below, so it must be
# present: without `required=True` a missing value surfaces as a TypeError on
# None instead of a proper usage error.
parser.add_argument('--dataset-name',
                    required=True,
                    help='Name of the dataset, for saving and loading files, including predicates files, '
                         'and the learnt model')
parser.add_argument('--training-scheme',
                    default='from-rules',
                    nargs='?',
                    choices=['from-rules', 'from-data'],
                    help='Choose if GNN trains from data or from ruleset')
parser.add_argument('--encoding-scheme',
                    default='NEC',
                    nargs='?',
                    choices=['NEC', 'EC'],
                    help='Choose whether to encode with edge colours or not')
parser.add_argument('--expand-neighbourhood',
                    action='store_true',
                    help='Choose whether to expand the neighbourhood of the training from rules examples')
# The two filenames below are only read in the 'from-data' training scheme.
parser.add_argument('--train-graph',
                    nargs='?',
                    default=None,
                    help='Filename of training data with graph, including extension.')
parser.add_argument('--train-facts',
                    nargs='?',
                    default=None,
                    help='Filename of training data with facts, including extension.')

args = parser.parse_args()

# Filename stem of the trained model: <dataset>_<training scheme>_<encoding scheme>.
saved_model_name = args.dataset_name + "_" + args.training_scheme + "_" + args.encoding_scheme

if __name__ == "__main__":

    print("Running training of " +
          "{} using {} encoding, {} training scheme...".format(args.dataset_name, args.encoding_scheme,
                                                               args.training_scheme))
    print("Training model {}...".format(saved_model_name))

    if args.training_scheme == 'from-data':
        # Both input files are mandatory in this scheme. They are validated
        # explicitly rather than with `assert`: asserts are stripped under -O,
        # and os.path.exists(None) raises a TypeError when the option is omitted.
        train_data_path = args.train_graph
        if train_data_path is None or not os.path.exists(train_data_path):
            parser.error("--train-graph must name an existing file in the from-data "
                         "training scheme (got {!r})".format(train_data_path))
        print("Loading graph data from {}".format(train_data_path))
        #  This contains the positive examples.
        train_data_output_path = args.train_facts
        if train_data_output_path is None or not os.path.exists(train_data_output_path):
            parser.error("--train-facts must name an existing file in the from-data "
                         "training scheme (got {!r})".format(train_data_output_path))
        print("Loading facts data from {}".format(train_data_output_path))

    if args.expand_neighbourhood and args.training_scheme != 'from-rules':
        parser.error("Expanding neighbourhood doesn't make sense in the from data training scheme")

    # Edge colours are only used by the 'EC' encoding.
    use_edge_colours = args.encoding_scheme == 'EC'

    # Load binary and unary predicates into memory
    binaryPredicates, unaryPredicates = load_predicates(args.dataset_name)
    num_binary = len(binaryPredicates)
    num_unary = len(unaryPredicates)
    # Total number of predicates; the rest of the script checks feature-vector
    # lengths against this value.
    mask_threshold = num_binary + num_unary

    # Select the device where the training will be run
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    if args.training_scheme == 'from-data':

        def _read_facts(path):
            """Return the lines of `path` as a list, one fact per line, newlines removed."""
            # rstrip('\n') instead of asserting a trailing newline and slicing:
            # the last line of a file legitimately may not end with '\n'.
            with open(path, 'r') as facts_file:
                return [line.rstrip('\n') for line in facts_file]

        # Load the training input dataset.
        train_dataset = _read_facts(train_data_path)
        # Load the training positive examples.
        train_output_dataset = _read_facts(train_data_output_path)

        # Encode training initial RDF dataset into graph form, which considers also the entailed dataset.
        # Here is a more detailed list of the output:
        # --- train_x : torch.FloatTensor of size i x j, with i the number of nodes in the graph,
        #               and j the length of feature vectors
        # --- train_edge_index : torch.LongTensor representing all edges in the graph, each edge is a pair of
        #                        integers, with each integer encoding a node of the graph.
        # --- train_edge_types : torch.LongTensor of the colours used in the graph, encoded as
        #                            0: (a,b)-(a),
        #                            1: (a,b)-(b),
        #                            2: (a,b)-(b,a)
        #                            3: (a)-(b) [for (a,b)]
        # --- train_node_to_const_dict: dictionary mapping each node (encoded as an integer) to the constant or
        #                               pair of constants that it represents.
        # --- train_const_to_node_dict: reverse of the dictionary above.
        # --- train_pred_dict: map of every position in the feature vector to the predicate it represents
        # The seventh returned value is not needed here and is discarded.
        (train_x, train_edge_index, train_edge_types,
         train_node_to_const_dict, train_const_to_node_dict,
         train_pred_dict, _) = encode_input_dataset(args.encoding_scheme,
                                                 train_dataset,
                                                 train_output_dataset,
                                                 binaryPredicates,
                                                 unaryPredicates,
                                                 training=True)

        # Encode entailed RDF dataset into vector form. The result is a torch.FloatTensor of size i x j, with i the
        # number of nodes in the graph, and j the length of feature vectors.
        train_y = encode_entailed_dataset(train_output_dataset, train_node_to_const_dict,
                                          train_const_to_node_dict, train_pred_dict)

        # Convert to PyTorch Geometric Data objects
        # Data: "A plain old python object modeling a single graph with various (optional) attributes"
        #        Please note that edge_type is a custom attribute of the function, NOT related to the optional
        #        attribute edge_attr.
        train_data = Data(x=train_x, y=train_y, edge_index=train_edge_index, edge_type=train_edge_types)
        # DataLoader: "Data loader which merges data objects from a torch_geometric.data.dataset to a mini-batch."
        #  The single graph is wrapped in a one-element list, which serves as the dataset: a plain list
        #  provides the __len__ and __getitem__ methods that a map-style dataset needs.
        train_loader = DataLoader(dataset=[train_data.to(device)], batch_size=1)

        # Sanity check: feature vector length must be equal to the number of unary and binary predicates.
        assert len(train_data.x[0]) == mask_threshold

    else:
        # If the training scheme is not 'from-data', it must be 'from-rules'
        assert args.training_scheme == 'from-rules'

        # Construct the training dataset directly from the rules
        dataset, train_node_dicts = create_training_dataset(args.dataset_name,
                                                            args.encoding_scheme,
                                                            unaryPredicates,
                                                            binaryPredicates,
                                                            expand_neighbourhood=args.expand_neighbourhood)

        # Sanity check: feature vector length must be equal to the number of unary and binary predicates.
        assert len(dataset[0].x[0]) == mask_threshold

        # All rule-derived graphs are served as one single batch.
        data_list = [data.to(device) for data in dataset]
        train_loader = DataLoader(data_list, batch_size=len(data_list))

    # Define the GNN. Note that if edge colours are used, there are 4 edge colours.
    model = GNN(mask_threshold, use_edge_colours, num_edge_types=4).to(device)
    # Select Adam as the optimisation algorithm
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    
    def train():
        """Run one optimisation pass over every batch of `train_loader`.

        Uses the enclosing script's `model`, `optimizer`, `train_loader`,
        `device`, `args` and `mask_threshold`. After every optimiser step all
        non-bias parameters of the model are clamped to be non-negative.

        Returns:
            The accumulated loss of the pass: for each batch, the summed
            weighted loss multiplied by the number of graphs in the batch.

        Raises:
            RuntimeError: if the element-wise loss contains a NaN.
        """
        # Set module in training mode (this method is inherited from torch.nn.Module)
        model.train()

        # Loop invariants, built once per pass rather than once per batch.
        # Element-wise BCE so that weights (and masks) can be applied before the reduction.
        loss_func = torch.nn.BCELoss(reduction='none')
        # Lookup table indexed by label value: weight 0.5 for a 0 label, 5.0 for a 1 label.
        class_weights = torch.tensor([0.5, 5.0]).to(device)
        uses_mask = args.training_scheme == 'from-rules'

        total_loss = 0

        # Each element yielded by the loader is a batch object exposing the
        # attributes read below (`y`, `num_graphs`) and consumed whole by the model.
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            if uses_mask:
                # In the from-rules scheme the masks are appended onto the end of
                # each y vector so that they stay with the correct batch: the first
                # `mask_threshold` columns are the labels, the rest are the masks.
                mask = batch.y[:, mask_threshold:]
                y = batch.y[:, :mask_threshold]
            else:
                y = batch.y
            # Weight matrix of the same shape as y: indexing the lookup table with
            # the labels (as integers) picks the weight for every entry.
            weight_ = class_weights[y.long()]
            # Compute GNN output. Calling the module runs its `forward` method on the batch.
            output = model(batch)
            # Compute loss matrix, to be reduced later
            loss = loss_func(output, y)

            # Double check we're not getting NaNs. This is an explicit raise rather
            # than an assert so that the check survives `python -O`.
            if torch.isnan(loss).any():
                raise RuntimeError("NaN encountered in the training loss")

            loss = loss * weight_
            if uses_mask:
                loss = loss * mask
            # Use sum reduction on loss, backpropagate
            batch_loss = loss.sum()
            batch_loss.backward()
            optimizer.step()
            # Any weight components < 0 are immediately "clamped" to 0, but not the bias
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if 'bias' not in name:
                        param.clamp_(min=0)
            # NOTE(review): batch_loss is already a sum over the whole batch, so the
            # num_graphs factor scales it further. Kept so reported values are unchanged.
            total_loss += batch.num_graphs * batch_loss.item()
        return total_loss

    # Implementing a form of early stopping. Keep track of the lowest loss
    # achieved, if we've had n epochs (to be specified) only achieving higher
    # losses than the lowest one recorded, then stop early.
    min_loss = None
    num_bad_iterations = 0

    if args.training_scheme == 'from-data':
        # How often we'll report progress of GNN
        divisor = 200
        # Maximum number of epochs reporting higher loss than lowest achieved
        # before we stop early
        max_num_bad = 50
    else:
        divisor = 500
        max_num_bad = 1000

    # Create the output directories up front: torch.save does not create missing
    # parent directories, and failing on the final save would lose the whole run.
    checkpoint_dir = "./models/checkpoints/"
    os.makedirs(checkpoint_dir, exist_ok=True)

    print("Training...")
    # Train for a maximum of 50000 epochs, but expect to always stop early
    for epoch in range(50001):
        loss = train()
        if min_loss is None:
            min_loss = loss
        if epoch % divisor == 0:
            print('Epoch: {:03d}, Loss: {:.5f}'.
                  format(epoch, loss))
            # Checkpoint every 1000 epochs (1000 is a multiple of both divisors).
            # NOTE(review): the checkpoint name only contains the dataset name, so
            # runs with different schemes on the same dataset overwrite each other.
            if epoch % 1000 == 0:
                torch.save(model,
                           checkpoint_dir +
                           "{}_Epoch{}.pt".format(args.dataset_name, epoch))
        if loss >= min_loss:
            # No strict improvement over the best loss seen so far.
            num_bad_iterations += 1
            if num_bad_iterations > max_num_bad:
                print("Stopping early")
                break
        else:
            num_bad_iterations = 0
            min_loss = loss

    # Save the final model (the whole module object, not just its state dict).
    torch.save(model, './models/' + saved_model_name + '.pt')

