import os
import json
import time
import datetime

import copy

import sklearn.metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
# import torcheval.metrics as torch_metrics
import torch_geometric

import nagphormer_blocks
from ssl_tasks import clustering, pairsim, dgi, partition, pairdis

# %%


class TransformerEncoder(torch.nn.Module):
    """
    Base class for NAGphormer-based U-SSL model. Consists of graph-specific encoders (stems) and universal
    representation learning module (backbone). Can handle any number of (homogeneous graph) datasets at a time.
    NAGphormer is implemented with modules from `nagphormer_blocks.py`, and is based on the official implementation of
    NAGphormer: A Tokenized Graph Transformer for Node Classification in Large Graphs (ICLR 2023)

    Args:
        stems: list of dicts, one per input feature size, each with 'num_node_f' (input feature size) and
            'num_layer_features' (stem output size; presumably must match the backbone hidden size — confirm).
        backbone: dict with 'num_layer_features' (hidden size), 'num_heads' and 'num_layers'.
        num_hops: number of hop tokens per node; the token sequence length is num_hops + 1.
        pos_embed_dim: stored on the instance; not used by this class.
        dropout_rate: dropout passed to every backbone EncoderLayer.
        attention_dropout_rate: attention dropout passed to every backbone EncoderLayer.
        device: stored as `device_`; this class does not move itself to it.
    """
    def __init__(self, stems, backbone, num_hops=2, pos_embed_dim=15, dropout_rate=0., attention_dropout_rate=0.1,
                 device=torch.device('cuda')):
        super().__init__()
        self.configs = {
            'stems': stems,
            'backbone': backbone,
        }
        self.device_ = device
        # one token for the node itself plus one token per hop
        self.sequence_length = num_hops + 1
        self.pos_embed_dim = pos_embed_dim
        self.hidden_dim = backbone['num_layer_features']
        self.ffn_dim = 2 * backbone['num_layer_features']
        self.dropout_rate = dropout_rate
        self.attention_dropout_rate = attention_dropout_rate

        # Construct stems: one linear projection per configured input feature size
        self.stem_in_features = [stem['num_node_f'] for stem in stems]
        self.stems = torch.nn.ModuleList(
            [torch.nn.Linear(in_features=stem['num_node_f'], out_features=stem['num_layer_features'])
             for stem in stems]
        )

        # Construct backbone (shared across all stems)
        backbone_encoders = [
            nagphormer_blocks.EncoderLayer(
                self.hidden_dim,  # attention input dim
                self.ffn_dim,  # MLP input dim
                self.dropout_rate, self.attention_dropout_rate, backbone['num_heads']
            )
            for _ in range(backbone['num_layers'])
        ]
        self.backbone = torch.nn.ModuleList(backbone_encoders)
        self.final_ln = torch.nn.LayerNorm(self.hidden_dim)

        # Construct readout
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim // 2)
        # scores [node token ; hop token] pairs, hence 2 * hidden_dim inputs
        self.attn_layer = nn.Linear(2 * self.hidden_dim, 1)
        self.scaling = nn.Parameter(torch.ones(1) * 0.5)

        # Initialisation
        self.apply(lambda module: nagphormer_blocks.init_params(module, num_layers=backbone['num_layers']))

    def _stem_index(self, num_features):
        """Return the index of the first stem whose input size equals `num_features`; raise ValueError otherwise."""
        try:
            return self.stem_in_features.index(num_features)
        except ValueError:
            raise ValueError(
                f'no stem accepts inputs with {num_features} features; available: {self.stem_in_features}'
            ) from None

    def _readout(self, x):
        """
        Attention-based readout: pool the hop tokens (x[:, 1:, :]) with softmax weights conditioned on the node
        token (x[:, 0, :]) and add the result to the node token. Returns a tensor of shape (nodes, hidden_dim).
        """
        node_tensor, neighbor_tensor = torch.split(x, [1, self.sequence_length - 1], dim=1)
        target = x[:, 0, :].unsqueeze(1).repeat(1, self.sequence_length - 1, 1)
        layer_atten = self.attn_layer(torch.cat((target, neighbor_tensor), dim=2))
        layer_atten = F.softmax(layer_atten, dim=1)  # normalise over the hop tokens
        neighbor_tensor = torch.sum(neighbor_tensor * layer_atten, dim=1, keepdim=True)
        return (node_tensor + neighbor_tensor).squeeze(1)

    def forward(self, x):
        """
        Encode hop-token sequences into node embeddings.

        Args:
            x: tensor indexed as (nodes, sequence_length, num_node_f); the stem is selected by x.shape[-1].

        Returns:
            tensor of shape (nodes, hidden_dim // 2).

        Raises:
            ValueError: if no stem was configured for x.shape[-1] input features.
        """
        x = self.stems[self._stem_index(x.shape[-1])](x)

        # transformer layers
        for backbone_layer in self.backbone:
            x = backbone_layer(x)

        # layer normalisation
        x = self.final_ln(x)

        return self.out_proj(self._readout(x))


class TransformerNodeClassifier(TransformerEncoder):
    """
    Class used for node classification. Can handle only one (homogeneous graph) dataset at a time.

    If `state_dict` is given it is loaded into the multi-stem encoder first; afterwards only the stem whose input
    size equals `num_features` is kept and a linear classification head is added on top of the encoder output.
    """
    save_freq = 5

    def __init__(self, stems, backbone, num_hops, num_features, num_classes, device, state_dict=None):
        super().__init__(stems, backbone, num_hops=num_hops, device=device)
        self.training_type = None
        self.lr_scheduler = None
        self.optimizer = None
        self.pretrained = False
        if state_dict is not None:
            # loaded before the stem selection / head creation, so it must match the multi-stem encoder
            self.load_state_dict(state_dict)
            self.pretrained = True
        stem_id = [stem['num_node_f'] for stem in self.configs['stems']].index(num_features)
        self.stems = self.stems[stem_id]  # from here on `stems` is a single Linear layer
        self.predictor = torch.nn.Linear(self.hidden_dim // 2, num_classes)

    def forward(self, x):
        """Return class logits of shape (nodes, num_classes) for hop-token input x of shape (nodes, seq, features)."""
        # stem
        x = self.stems(x)

        # transformer layers
        for backbone_layer in self.backbone:
            x = backbone_layer(x)

        # layer normalisation
        x = self.final_ln(x)

        # attention-based readout
        target = x[:, 0, :].unsqueeze(1).repeat(1, self.sequence_length - 1, 1)
        split_tensor = torch.split(x, [1, self.sequence_length - 1], dim=1)
        node_tensor = split_tensor[0]
        neighbor_tensor = split_tensor[1]
        layer_atten = self.attn_layer(torch.cat((target, neighbor_tensor), dim=2))
        layer_atten = F.softmax(layer_atten, dim=1)
        neighbor_tensor = neighbor_tensor * layer_atten
        neighbor_tensor = torch.sum(neighbor_tensor, dim=1, keepdim=True)
        output = (node_tensor + neighbor_tensor).squeeze(1)
        output = self.out_proj(output)
        output = self.predictor(output)
        return output

    def train_model(self, dataset, loss, optimizer, num_epochs, freeze_encoder, save_path=None, verbose=50,
                    lr_adapt=True):
        """
        Train full-batch on `dataset.data.train_mask`, then evaluate on the test mask.

        Every `save_freq` epochs (and on the last epoch) the validation loss is computed in eval mode on the
        current weights; whenever it improves, the model is saved to `save_path`.

        Returns:
            (test loss, test accuracy) as returned by `evaluate_model`.
        """
        self.optimizer = optimizer
        if freeze_encoder:
            self.freeze_encoder()
        self.train()
        self.to(self.device_)
        dataset.data.to(self.device_)
        # reset on every call so a scheduler from a previous call is not reused when lr_adapt=False
        self.lr_scheduler = (
            torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=50) if lr_adapt else None
        )
        best_loss = float('inf')
        print('starting training of node classifier')
        start_time = time.time()
        for epoch in range(num_epochs):
            # zero grad
            self.optimizer.zero_grad()

            # forward
            mask = dataset.data.train_mask
            preds = self.forward(x=dataset.data.x)
            train_loss = loss(preds[mask], dataset.data.y[mask])

            # backward and update
            train_loss.backward()
            self.optimizer.step()

            # save: validate the weights that are actually saved (post-step, eval mode)
            if save_path is not None and ((epoch + 1) % self.save_freq == 0 or (epoch + 1) == num_epochs):
                val_loss = self._masked_loss(dataset, loss, dataset.data.val_mask)
                if val_loss < best_loss:
                    best_loss = val_loss
                    self.save_model(epoch, save_path)

            # print
            if (epoch + 1) % verbose == 0 or epoch == 0 or (epoch + 1) == num_epochs:
                print(
                    '[' + datetime.datetime.now().strftime(format='%d.%m.%y %H:%M:%S') + ']' +
                    ' epoch: ' + str(epoch + 1) +
                    '; classification_loss: ' + str(round(train_loss.item(), 4))
                )

            # lr schedule
            if self.lr_scheduler is not None:
                self.lr_scheduler.step(train_loss.item())

        # print
        print('training completed. elapsed time: ', str(time.time() - start_time))

        eval_loss, eval_acc = self.evaluate_model(dataset, loss)
        print('evaluation loss: ', str(round(eval_loss, 4)))
        print('evaluation accuracy: ', str(round(eval_acc, 4)))

        return eval_loss, eval_acc

    def _masked_loss(self, dataset, loss, mask):
        """Return `loss` on `mask` from a gradient-free eval-mode forward pass; restores the previous mode."""
        was_training = self.training
        self.eval()
        with torch.no_grad():
            preds = self.forward(x=dataset.data.x)
            value = loss(preds[mask], dataset.data.y[mask]).item()
        self.train(was_training)
        return value

    def freeze_encoder(self):
        """Disable gradients for the stem and backbone; readout, out_proj and predictor stay trainable."""
        for p in self.stems.parameters():
            p.requires_grad = False
        for p in self.backbone.parameters():
            p.requires_grad = False

    def save_model(self, epoch, save_path):
        """Save model and optimizer state, encoder configs and the epoch index to `save_path`."""
        save_dict = {
            'model_state_dict': self.state_dict(),
            'optimizer': {
                'type': 'adam',
                'state_dict': self.optimizer.state_dict(),
            },
            'configs': self.configs,
            'epoch': epoch,
        }
        torch.save(save_dict, save_path)

    def evaluate_model(self, dataset, loss):
        """
        Evaluate on `dataset.data.test_mask` without tracking gradients; leaves the model in eval mode.

        Returns:
            (test loss as float, test accuracy as float).
        """
        self.eval()
        mask = dataset.data.test_mask
        with torch.no_grad():
            preds = self.forward(x=dataset.data.x)
            eval_loss = loss(preds[mask], dataset.data.y[mask])
        test_correct = preds.argmax(dim=1)[mask] == dataset.data.y[mask]
        test_acc = int(test_correct.sum()) / int(mask.sum())
        return eval_loss.item(), test_acc


class TransformerLinkPredictor(TransformerEncoder):
    """
    Class used for link prediction. Can handle only one (homogeneous graph) dataset at a time.

    `dataset.data` is expected to expose `train_data`, `val_data` and `test_data` splits, each carrying `x`,
    `edge_label_index` and `edge_label` (presumably produced by torch_geometric's RandomLinkSplit — confirm).
    """
    save_freq = 5

    def __init__(self, stems, backbone, num_hops, num_features, device, state_dict=None):
        super().__init__(stems, backbone, num_hops=num_hops, device=device)
        self.training_type = None
        self.lr_scheduler = None
        self.optimizer = None
        self.pretrained = False
        if state_dict is not None:
            # loaded before the stem selection / head creation, so it must match the multi-stem encoder
            self.load_state_dict(state_dict)
            self.pretrained = True
        stem_id = [stem['num_node_f'] for stem in self.configs['stems']].index(num_features)
        self.stems = self.stems[stem_id]  # from here on `stems` is a single Linear layer
        self.predictor = torch.nn.Linear(self.hidden_dim // 2, 1)

    def forward(self, x):
        """
        Score candidate edges.

        Args:
            x: tuple (node hop-token features, edge_label_index); edge_label_index[0] / [1] index source / target
               nodes.

        Returns:
            one logit per candidate edge.
        """
        x, edge_label_index = x
        x = self.stems(x)

        # transformer layers
        for backbone_layer in self.backbone:
            x = backbone_layer(x)

        # layer normalisation
        x = self.final_ln(x)

        # attention-based readout
        target = x[:, 0, :].unsqueeze(1).repeat(1, self.sequence_length - 1, 1)
        split_tensor = torch.split(x, [1, self.sequence_length - 1], dim=1)
        node_tensor = split_tensor[0]
        neighbor_tensor = split_tensor[1]
        layer_atten = self.attn_layer(torch.cat((target, neighbor_tensor), dim=2))
        layer_atten = F.softmax(layer_atten, dim=1)
        neighbor_tensor = neighbor_tensor * layer_atten
        neighbor_tensor = torch.sum(neighbor_tensor, dim=1, keepdim=True)
        output = (node_tensor + neighbor_tensor).squeeze(1)
        output = self.out_proj(output)
        # NOTE(review): predictor maps each node to a single value, so the dot product below multiplies two
        # scalars per edge; confirm this is intended rather than a dot product of the out_proj embeddings.
        output = self.predictor(output)

        # link prediction
        source_embeddings = output[edge_label_index[0]]
        target_embeddings = output[edge_label_index[1]]
        logits = (source_embeddings * target_embeddings).sum(dim=-1)
        return logits

    def train_model(self, dataset, loss, optimizer, num_epochs, freeze_encoder, save_path=None, verbose=50,
                    lr_adapt=True):
        """
        Train full-batch on `dataset.data.train_data`, then evaluate on `dataset.data.test_data`.

        Every `save_freq` epochs (and on the last epoch) the validation loss on `dataset.data.val_data` is
        computed; whenever it improves, the model is saved to `save_path`.

        Returns:
            (test loss, test ROC-AUC) as returned by `evaluate_model`.
        """
        self.optimizer = optimizer
        if freeze_encoder:
            self.freeze_encoder()
        self.train()
        self.to(self.device_)
        dataset.data.to(self.device_)
        # reset on every call so a scheduler from a previous call is not reused when lr_adapt=False
        self.lr_scheduler = (
            torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=50) if lr_adapt else None
        )
        best_loss = float('inf')
        print('starting training of link predictor')
        start_time = time.time()
        for epoch in range(num_epochs):
            # zero grad
            self.optimizer.zero_grad()

            # forward
            train_data = dataset.data.train_data
            preds = self.forward(x=(train_data.x, train_data.edge_label_index))
            train_loss = loss(preds, train_data.edge_label)

            # backward and update
            train_loss.backward()
            self.optimizer.step()

            # save
            if save_path is not None and ((epoch + 1) % self.save_freq == 0 or (epoch + 1) == num_epochs):
                val_loss, _ = self.evaluate_model(dataset, loss, self.device_, split='val')
                self.train()  # evaluate_model switched to eval mode
                if val_loss < best_loss:
                    best_loss = val_loss
                    self.save_model(epoch, save_path)

            # print
            if (epoch + 1) % verbose == 0 or epoch == 0 or (epoch + 1) == num_epochs:
                print(
                    '[' + datetime.datetime.now().strftime(format='%d.%m.%y %H:%M:%S') + ']' +
                    ' epoch: ' + str(epoch + 1) +
                    '; prediction_loss: ' + str(round(train_loss.item(), 4))
                )

            # lr schedule
            if self.lr_scheduler is not None:
                self.lr_scheduler.step(train_loss.item())

        # release device memory held by the training graph before the final evaluation
        dataset.data.to(torch.device('cpu'))
        self.to(torch.device('cpu'))
        torch.cuda.empty_cache()
        print('training completed. elapsed time: ', str(time.time() - start_time))

        eval_loss, eval_auc = self.evaluate_model(dataset, loss, self.device_)
        print('evaluation loss: ', str(round(eval_loss, 4)))
        print('evaluation AUC: ', str(round(eval_auc, 4)))

        return eval_loss, eval_auc

    def freeze_encoder(self):
        """Disable gradients for the stem and backbone; readout, out_proj and predictor stay trainable."""
        for p in self.stems.parameters():
            p.requires_grad = False
        for p in self.backbone.parameters():
            p.requires_grad = False

    def save_model(self, epoch, save_path):
        """Save model and optimizer state, encoder configs and the epoch index to `save_path`."""
        save_dict = {
            'model_state_dict': self.state_dict(),
            'optimizer': {
                'type': 'adam',
                'state_dict': self.optimizer.state_dict(),
            },
            'configs': self.configs,
            'epoch': epoch,
        }
        torch.save(save_dict, save_path)

    def evaluate_model(self, dataset, loss, device=None, split='test'):
        """
        Evaluate on one split without tracking gradients; moves the model to `device` and leaves it in eval mode.

        Args:
            dataset: object whose `data.<split>_data` holds `x`, `edge_label_index` and `edge_label`.
            loss: loss function applied to (logits, edge labels).
            device: device to evaluate on; defaults to `self.device_`.
            split: which split to use ('test' or 'val').

        Returns:
            (loss as float, ROC-AUC of the logits against the edge labels).
        """
        device = self.device_ if device is None else device
        split_data = getattr(dataset.data, split + '_data')
        self.to(device)
        self.eval()
        with torch.no_grad():
            preds = self.forward(x=(split_data.x.to(device), split_data.edge_label_index.to(device)))
            labels = split_data.edge_label.to(device)
            eval_loss = loss(preds, labels)
        auc = sklearn.metrics.roc_auc_score(labels.cpu().numpy(), preds.cpu().numpy())
        return eval_loss.item(), auc


class SSLTransformer(torch.nn.Module):
    """
    Class for constructing a U-SSL model with an object of class TransformerEncoder and heads for SSL task learning. Can
    handle any number of datasets and any number of SSL tasks.

    The encoder is optimised by `self.optimizer`; every SSL head (each partition predictor separately) has its own
    Adam optimizer in `self.head_optimizers`.
    """
    base_lr = 1e-3
    save_freq = 20
    task_weights = {  # taken from results of "Automated Self-supervised Learning for Graphs (ICLR 2022)"
        'graph partitioning': 0.9,
        'deep graph infomax': 0.65,
        'pair-wise attribute similarity': 0.65,
        'pair-wise distance': 0.1,
        'clustering': 0.1,
    }

    def __init__(self, encoder, data, ssl_tasks):
        super().__init__()
        # model
        self.encoder = encoder
        self.encoder.to(self.encoder.device_)

        # data
        self.data = data

        # ssl tasks (kept in a plain list, so their heads are not part of self.parameters())
        self.ssl_tasks = ssl_tasks
        self.ssl_objects = self.instantiate_ssl()

        # optimization
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.base_lr)
        self.head_optimizers = [torch.optim.Adam(ssl_object.predictor.parameters(), lr=self.base_lr)
                                for ssl_object in self.ssl_objects if ssl_object.name != 'graph partitioning']
        if 'partition' in self.ssl_tasks:
            # the partition task holds several predictors; each gets its own optimizer
            partition_task = next(ssl_object for ssl_object in self.ssl_objects
                                  if ssl_object.name == 'graph partitioning')
            self.head_optimizers.extend(torch.optim.Adam(predictor.parameters(), lr=self.base_lr)
                                        for predictor in partition_task.predictor)
        self.plateau_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=50,
                                                                               factor=1 / 3., verbose=True)
        self.lin_lr_scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer)
        self.lr_scheduler = [self.plateau_lr_scheduler, self.lin_lr_scheduler]

    def instantiate_ssl(self):
        """Create one SSL task object per recognised key in `self.ssl_tasks`, sized to the encoder output."""
        ssl_objects = []
        embedding_size = self.encoder.out_proj.out_features
        if 'clustering' in self.ssl_tasks:
            ssl_objects.append(
                clustering.ClusteringTask(data=self.data, embedding_size=embedding_size, device=self.encoder.device_)
            )
        if 'pairsim' in self.ssl_tasks:
            ssl_objects.append(
                pairsim.PairwiseAttrSimTask(data=self.data, embedding_size=embedding_size, device=self.encoder.device_)
            )
        if 'dgi' in self.ssl_tasks:
            ssl_objects.append(
                dgi.DGITask(data=self.data, encoder=self.encoder, embedding_size=embedding_size,
                            device=self.encoder.device_)
            )
        if 'partition' in self.ssl_tasks:
            ssl_objects.append(
                partition.PartitionTask(data=self.data, embedding_size=embedding_size, device=self.encoder.device_)
            )
        if 'pairdis' in self.ssl_tasks:
            ssl_objects.append(
                pairdis.PairwiseDistanceTask(data=self.data, embedding_size=embedding_size, device=self.encoder.device_)
            )
        return ssl_objects

    def _write_config(self, folder, num_epochs, weighted):
        """Write a human-readable description of this pre-training run to `<folder>/config.txt`."""
        with open(os.path.join(folder, 'config.txt'), 'w') as f:
            f.write(json.dumps(self.encoder.configs))
            dataset_names = [dataset.name for dataset in self.data.datasets]
            f.write("\nepochs: " + str(num_epochs))
            f.write("\ndatasets: " + "; ".join(dataset_names))
            f.write("\nSSL tasks: " + ";".join(self.ssl_tasks))
            f.write("\nmodel type: node aggregation graph transformer")
            f.write("\n" + json.dumps({'weighted': weighted}))

    def pretrain(self, num_epochs=1000, weighted=False, lr_adapt=True, verbose_freq=50, batched=False, save_path=None):
        """
        Jointly pre-train the encoder and SSL heads on every dataset in `self.data.datasets`.

        Each epoch accumulates gradients from all datasets and then takes one step with the encoder optimizer and
        every head optimizer. The epoch loss is the sum of the per-dataset SSL losses; it is used for checkpoint
        selection (checked every `save_freq` epochs), logging and, if `lr_adapt`, the LR schedulers.

        Args:
            num_epochs: number of epochs.
            weighted: scale each task loss by `task_weights[task name]`.
            lr_adapt: step the LinearLR and ReduceLROnPlateau schedulers after each epoch.
            verbose_freq: print the loss every `verbose_freq` epochs (and on the first and last epoch).
            batched: currently unused.
            save_path: optional dict with 'folder' (where config.txt is written) and 'file' (checkpoint path).

        Raises:
            ValueError: if no SSL task was instantiated.
        """
        if not lr_adapt:
            self.lr_scheduler = None
        if not self.ssl_objects:
            raise ValueError('no SSL tasks instantiated; check ssl_tasks: ' + str(self.ssl_tasks))

        if save_path is not None:
            self._write_config(save_path['folder'], num_epochs, weighted)

        best_loss = float('inf')
        self.train()
        print('starting pre-training')
        start_time = time.time()
        printed = False
        for epoch in range(num_epochs):
            # zero grad
            self.optimizer.zero_grad()
            for head_optimizer in self.head_optimizers:
                head_optimizer.zero_grad()

            # forward + gradients (accumulated over datasets)
            epoch_loss = 0.0
            for dataset in self.data.datasets:
                embeddings = self.encoder(x=dataset.data.x.to(self.encoder.device_))

                ssl_loss = 0
                for ssl_object in self.ssl_objects:
                    task_loss = ssl_object.get_loss(embeddings, dataset.name)
                    if weighted:
                        task_loss = self.task_weights[ssl_object.name] * task_loss
                    ssl_loss = ssl_loss + task_loss

                ssl_loss.backward()
                epoch_loss += ssl_loss.item()

            # update
            self.optimizer.step()
            for head_optimizer in self.head_optimizers:
                head_optimizer.step()

            # save and overwrite model
            if (epoch + 1) % self.save_freq == 0 and epoch_loss < best_loss:
                best_loss = epoch_loss
                if save_path is not None:
                    self.save_model(epoch, save_path['file'])
                elif not printed:
                    print('no save_path provided. skipping saving')
                    printed = True

            # print
            if (epoch + 1) % verbose_freq == 0 or epoch == 0 or epoch == num_epochs - 1:
                print(
                    '[' + datetime.datetime.now().strftime(format='%d.%m.%y %H:%M:%S') + ']' +
                    ' epoch: ' + str(epoch + 1) +
                    '; ssl_loss: ' + str(round(epoch_loss, 4))
                )

            # lr schedule (previously ran even with lr_adapt=False)
            if lr_adapt:
                self.adjust_lr(epoch_loss)

        # print
        print('pre-training completed. elapsed time: ', str(time.time() - start_time))

    def save_model(self, epoch, save_path):
        """Save the encoder state, encoder optimizer state and run metadata to `save_path`."""
        save_dict = {
            'data': [dataset.name for dataset in self.data.datasets],
            'model_state_dict': self.encoder.state_dict(),
            'optimizer': {
                'type': 'adam',
                'base_lr': self.base_lr,
                'state_dicts': self.optimizer.state_dict(),
            },
            'configs': self.encoder.configs,
            'num_hops': self.encoder.sequence_length - 1,
            'ssl_tasks': self.ssl_tasks,
            'epoch': epoch,
        }
        torch.save(save_dict, save_path)

    def adjust_lr(self, ssl_loss):
        """Step the LinearLR scheduler, then the plateau scheduler with the epoch loss."""
        self.lin_lr_scheduler.step()
        self.plateau_lr_scheduler.step(ssl_loss)
