import os
import json
import time
import datetime

import copy

import sklearn.metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric

from ssl_tasks import clustering, pairsim, dgi, partition, pairdis

# %%
# Layer constructors selectable through the ``layer_type`` config key.
# 'lin' is a plain torch Linear layer; every other entry is a torch_geometric
# convolution that consumes ``(x, edge_index)``.
LAYER_MAP = dict(
    lin=torch.nn.Linear,
    gcn=torch_geometric.nn.GCNConv,
    sage=torch_geometric.nn.SAGEConv,
    gat=torch_geometric.nn.GATConv,
    gtx=torch_geometric.nn.TransformerConv,
)

# Activations selectable through the ``act`` config key.
# NOTE: each value is ONE module instance created at import time. PReLU owns a
# learnable weight, so placing the same instance in several layers (or models)
# ties that weight across all of them -- copy an entry before inserting it.
ACT_MAP = dict(
    relu=torch.nn.ReLU(inplace=True),
    prelu=torch.nn.PReLU(),
)

# %%


class GNNEncoder(torch.nn.Module):
    """
    Base class for GNN-based U-SSL. Consists of graph-specific encoders (stems) and a universal representation
    learning module (backbone). Can handle any number of (homogeneous graph) datasets at a time.

    Args:
        stems: list of stem config dicts. Each is read for ``num_node_f`` (input feature size),
            ``num_layer_features`` (output size of each stem layer), ``layer_type`` (a ``LAYER_MAP`` key) and
            ``act`` (an ``ACT_MAP`` key).
        backbone: backbone config dict, read for ``num_in_features``, ``num_layer_features``, ``layer_type``,
            ``act`` and, for 'gat'/'gtx' layers, ``num_heads`` (one entry per backbone layer).
        device: stored as ``self.device_``; the module itself is not moved here.
    """
    def __init__(self, stems, backbone, device=torch.device('cuda')):
        super().__init__()
        self.configs = {
            'stems': stems,
            'backbone': backbone,
        }
        self.device_ = device

        # Construct stems; the input feature size is used in forward() to pick the stem for a dataset
        self.stem_in_features = [stem['num_node_f'] for stem in stems]
        self.stems = torch.nn.ModuleList([self._build_stem(stem) for stem in stems])

        # Construct backbone shared by all stems
        self.backbone = self._build_backbone(backbone)

    @staticmethod
    def _make_activation(name):
        """Return a fresh copy of ``ACT_MAP[name]`` so no activation module (e.g. a PReLU weight) is shared."""
        # ACT_MAP holds single module instances; reusing them would tie PReLU weights across layers and models
        return copy.deepcopy(ACT_MAP[name])

    def _build_stem(self, stem):
        """Build one stem: ``layer_type`` layers of the configured sizes, each followed by an activation."""
        feature_sizes = [stem['num_node_f'], *stem['num_layer_features']]
        is_linear = stem['layer_type'] == 'lin'
        layers = []
        for in_size, out_size in zip(feature_sizes[:-1], feature_sizes[1:]):
            layer_cls = LAYER_MAP[stem['layer_type']]
            if is_linear:
                layers.append(layer_cls(in_features=in_size, out_features=out_size))
            else:
                # graph layers consume edge_index, so they carry a torch_geometric signature string
                layers.append((layer_cls(in_channels=in_size, out_channels=out_size), 'x, edge_index -> x'))
            layers.append(self._make_activation(stem['act']))
        if is_linear:
            return torch.nn.Sequential(*layers)
        return torch_geometric.nn.Sequential('x, edge_index', layers)

    def _build_backbone(self, backbone):
        """Build the shared backbone; each 'gat'/'gtx' conv is followed by a Linear that merges its heads."""
        feature_sizes = [backbone['num_in_features'], *backbone['num_layer_features']]
        multi_head = backbone['layer_type'] in ['gtx', 'gat']
        layers = []
        for i, (in_size, out_size) in enumerate(zip(feature_sizes[:-1], feature_sizes[1:])):
            layer_cls = LAYER_MAP[backbone['layer_type']]
            if multi_head:
                conv = layer_cls(in_channels=in_size, out_channels=out_size, heads=backbone['num_heads'][i])
                layers.append((conv, 'x, edge_index -> x'))
                # the projection expects out_channels * heads inputs, i.e. concatenated head outputs
                layers.append(torch.nn.Linear(in_features=conv.out_channels * conv.heads,
                                              out_features=out_size))
            else:
                layers.append((layer_cls(in_channels=in_size, out_channels=out_size), 'x, edge_index -> x'))
            layers.append(self._make_activation(backbone['act']))
        return torch_geometric.nn.Sequential('x, edge_index', layers)

    def forward(self, x, edge_index):
        """
        Embed nodes: the stem whose input size equals ``x.shape[-1]``, then the shared backbone.

        If several stems share that input size, the first one is used.

        Raises:
            ValueError: if no stem was built for ``x.shape[-1]`` input features.
        """
        num_features = x.shape[-1]
        if num_features not in self.stem_in_features:
            raise ValueError(f'no stem accepts {num_features} input features; '
                             f'available stem input sizes: {self.stem_in_features}')
        stem_id = self.stem_in_features.index(num_features)
        if self.configs['stems'][stem_id]['layer_type'] != 'lin':
            x = self.stems[stem_id](x, edge_index)
        else:
            x = self.stems[stem_id](x)
        x = self.backbone(x, edge_index)
        return x


class GNNNodeClassifier(GNNEncoder):
    """
    Class used for node classification. Can handle only one (homogeneous graph) dataset at a time.

    Only the stem whose input size equals ``num_features`` is kept; it is followed by the backbone and a linear
    prediction head with ``num_classes`` outputs.
    """
    save_freq = 5

    def __init__(self, stems, backbone, num_features, num_classes, device, state_dict=None):
        super().__init__(stems, backbone, device)
        self.training_type = None
        self.lr_scheduler = None
        self.optimizer = None
        self.pretrained = False
        if state_dict is not None:
            # loaded before the unused stems are dropped, so the full encoder's keys still match
            self.load_state_dict(state_dict)
            self.pretrained = True
        stem_id = [stem['num_node_f'] for stem in self.configs['stems']].index(num_features)
        # remember which stem was kept so forward() dispatches on that stem's layer type
        self.stem_id = stem_id
        self.stems = self.stems[stem_id]
        # the backbone ends with an activation, so [-2] is its last feature-producing layer:
        # the head-merging Linear for 'gat'/'gtx', otherwise the graph convolution itself
        last_layer = self.backbone[-2]
        if self.configs['backbone']['layer_type'] in ['gat', 'gtx']:
            embedding_size = last_layer.out_features
        else:
            embedding_size = last_layer.out_channels
        self.predictor = torch.nn.Linear(in_features=embedding_size, out_features=num_classes)

    def forward(self, x, edge_index):
        """Return per-node class scores: selected stem -> backbone -> linear predictor."""
        if self.configs['stems'][self.stem_id]['layer_type'] != 'lin':
            x = self.stems(x, edge_index)
        else:
            x = self.stems(x)
        x = self.backbone(x, edge_index)
        x = self.predictor(x)
        return x

    def train_model(self, dataset, loss, optimizer, num_epochs, freeze_encoder, save_path=None, verbose=50,
                    lr_adapt=True):
        """
        Full-batch training on the ``train_mask`` nodes, followed by evaluation on ``test_mask``.

        Args:
            dataset: object whose ``data`` holds ``x``, ``edge_index``, ``y`` and ``train_mask``/``val_mask``/
                ``test_mask`` (masks assumed boolean — TODO confirm).
            loss: callable ``loss(predictions, targets)`` returning a scalar tensor.
            optimizer: optimizer over this model's parameters; stored as ``self.optimizer``.
            num_epochs: number of update steps.
            freeze_encoder: if True, stem and backbone parameters get ``requires_grad=False``.
            save_path: if given, every ``save_freq`` epochs the model is saved there when its validation loss is
                the best seen so far.
            verbose: print the training loss every ``verbose`` epochs (plus the first and last).
            lr_adapt: if True, reduce the learning rate when the training loss plateaus (patience 50).

        Returns:
            ``(eval_loss, eval_acc)`` on the test mask.
        """
        self.optimizer = optimizer
        if freeze_encoder:
            self.freeze_encoder()
        self.train()
        self.to(self.device_)
        dataset.data.to(self.device_)
        # always (re)assign: a scheduler from an earlier call would be bound to a different optimizer
        if lr_adapt:
            self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=50)
        else:
            self.lr_scheduler = None
        best_loss = float('inf')
        print('starting training of node classifier')
        start_time = time.time()
        for epoch in range(num_epochs):
            # zero grad
            self.optimizer.zero_grad()

            # forward
            mask = dataset.data.train_mask
            preds = self.forward(x=dataset.data.x, edge_index=dataset.data.edge_index)
            train_loss = loss(preds[mask], dataset.data.y[mask])

            # backward and update
            train_loss.backward()
            self.optimizer.step()

            # save: score the post-update weights, i.e. exactly the weights that would be written
            if save_path is not None and (epoch + 1) % self.save_freq == 0:
                val_loss = self._validation_loss(dataset, loss)
                if val_loss < best_loss:
                    best_loss = val_loss
                    self.save_model(epoch, save_path)

            # print
            if (epoch + 1) % verbose == 0 or epoch == 0 or (epoch + 1) == num_epochs:
                print(
                    '[' + datetime.datetime.now().strftime(format='%d.%m.%y %H:%M:%S') + ']' +
                    ' epoch: ' + str(epoch + 1) +
                    '; classification_loss: ' + str(round(train_loss.item(), 4))
                )

            # lr schedule
            if self.lr_scheduler is not None:
                self.lr_scheduler.step(train_loss)

        # print
        print('training completed. elapsed time: ', str(time.time() - start_time))

        eval_loss, eval_acc = self.evaluate_model(dataset, loss)
        print('evaluation loss: ', str(round(eval_loss, 4)))
        print('evaluation accuracy: ', str(round(eval_acc, 4)))

        return eval_loss, eval_acc

    def _validation_loss(self, dataset, loss):
        """Return the loss on ``val_mask`` for the current weights, without tracking gradients."""
        was_training = self.training
        self.eval()
        with torch.no_grad():
            preds = self.forward(x=dataset.data.x, edge_index=dataset.data.edge_index)
            mask = dataset.data.val_mask
            val_loss = loss(preds[mask], dataset.data.y[mask]).item()
        self.train(was_training)
        return val_loss

    def freeze_encoder(self):
        """Disable gradients for the stem and backbone so only the predictor is trained."""
        for p in self.stems.parameters():
            p.requires_grad = False
        for p in self.backbone.parameters():
            p.requires_grad = False

    def save_model(self, epoch, save_path):
        """Save the model/optimizer state dicts, configs and epoch to ``save_path`` via ``torch.save``."""
        save_dict = {
            'model_state_dict': self.state_dict(),
            'optimizer': {
                'type': 'adam',
                'state_dict': self.optimizer.state_dict(),
            },
            'configs': self.configs,
            'epoch': epoch,
        }
        torch.save(save_dict, save_path)

    def evaluate_model(self, dataset, loss):
        """
        Return ``(loss, accuracy)`` on ``dataset.data.test_mask``; leaves the model in eval mode.

        Accuracy is NaN when the test mask selects no nodes.
        """
        self.eval()
        mask = dataset.data.test_mask
        with torch.no_grad():
            preds = self.forward(x=dataset.data.x, edge_index=dataset.data.edge_index)
            eval_loss = loss(preds[mask], dataset.data.y[mask])
        y_hat = preds.argmax(dim=1)
        test_correct = y_hat[mask] == dataset.data.y[mask]
        num_test = int(mask.sum())
        test_acc = int(test_correct.sum()) / num_test if num_test else float('nan')
        return eval_loss.item(), test_acc


class SSLGNN(torch.nn.Module):
    """
    Class for constructing a U-SSL model with an object of class GNNEncoder and heads for SSL task learning. Can handle
    any number of datasets and any number of SSL tasks.

    The encoder is updated by ``self.optimizer``. Every SSL head, and every predictor of the partition task, has its
    own Adam optimizer in ``self.head_optimizers``. Heads are reached through ``self.ssl_objects``, a plain list.
    """
    base_lr = 1e-3
    save_freq = 20
    task_weights = {  # taken from results of "Automated Self-supervised Learning for Graphs (ICLR 2022)"
        'graph partitioning': 0.9,
        'deep graph infomax': 0.65,
        'pair-wise attribute similarity': 0.65,
        'pair-wise distance': 0.1,
        'clustering': 0.1,
    }

    def __init__(self, encoder, data, ssl_tasks):
        """
        Args:
            encoder: GNNEncoder that embeds every dataset; moved to ``encoder.device_``.
            data: container exposing ``datasets`` (items with ``.data`` and ``.name``); also handed to each task.
            ssl_tasks: task keys to enable, among 'clustering', 'pairsim', 'dgi', 'partition' and 'pairdis'.
        """
        super().__init__()
        # model
        self.encoder = encoder
        self.encoder.to(self.encoder.device_)

        # data
        self.data = data

        # ssl tasks
        self.ssl_tasks = ssl_tasks
        self.ssl_objects = self.instantiate_ssl()

        # optimization: the partition task holds a list of predictors; their optimizers are appended last
        self.optimizer = torch.optim.Adam(self.encoder.parameters(), lr=self.base_lr)
        self.head_optimizers = [torch.optim.Adam(ssl_object.predictor.parameters(), lr=self.base_lr)
                                for ssl_object in self.ssl_objects if ssl_object.name != 'graph partitioning']
        if 'partition' in self.ssl_tasks:
            self.head_optimizers.extend(torch.optim.Adam(predictor.parameters(), lr=self.base_lr)
                                        for predictor in self._partition_task().predictor)
        self.plateau_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=50,
                                                                               factor=1 / 3., verbose=True)
        self.lin_lr_scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer)
        self.lr_scheduler = [self.plateau_lr_scheduler, self.lin_lr_scheduler]

    def _partition_task(self):
        """Return the SSL object named 'graph partitioning' (IndexError if it was not instantiated)."""
        return [ssl_object for ssl_object in self.ssl_objects if ssl_object.name == 'graph partitioning'][0]

    def instantiate_ssl(self):
        """Create one SSL task object per enabled key, sized to the encoder's output embedding."""
        # the backbone ends with an activation, so [-2] is the last feature-producing layer
        last_layer = self.encoder.backbone[-2]
        if self.encoder.configs['backbone']['layer_type'] in ['gat', 'gtx']:
            embedding_size = last_layer.out_features
        else:
            embedding_size = last_layer.out_channels
        device = self.encoder.device_
        ssl_objects = []
        if 'clustering' in self.ssl_tasks:
            ssl_objects.append(
                clustering.ClusteringTask(data=self.data, embedding_size=embedding_size, device=device)
            )
        if 'pairsim' in self.ssl_tasks:
            ssl_objects.append(
                pairsim.PairwiseAttrSimTask(data=self.data, embedding_size=embedding_size, device=device)
            )
        if 'dgi' in self.ssl_tasks:
            ssl_objects.append(
                dgi.DGITask(data=self.data, encoder=self.encoder, embedding_size=embedding_size, device=device)
            )
        if 'partition' in self.ssl_tasks:
            ssl_objects.append(
                partition.PartitionTask(data=self.data, embedding_size=embedding_size, device=device)
            )
        if 'pairdis' in self.ssl_tasks:
            ssl_objects.append(
                pairdis.PairwiseDistanceTask(data=self.data, embedding_size=embedding_size, device=device)
            )
        return ssl_objects

    def pretrain(self, num_epochs=1000, weighted=False, lr_adapt=True, verbose_freq=50, save_path=None):
        """
        Jointly pre-train the encoder and SSL heads: one optimizer step per epoch, with gradients from all datasets.

        Args:
            num_epochs: number of update steps.
            weighted: if True, scale each task loss by ``task_weights[task.name]``.
            lr_adapt: if False, neither LR scheduler is stepped (and ``self.lr_scheduler`` is set to None).
            verbose_freq: print the epoch loss every ``verbose_freq`` epochs (plus the first and last).
            save_path: optional dict with 'folder' (config.txt is written there) and 'file' (checkpoint path).
                Every ``save_freq`` epochs the model is saved if the epoch loss is the best seen so far.
        """
        if not lr_adapt:
            self.lr_scheduler = None

        if save_path is not None:
            with open(os.path.join(save_path['folder'], 'config.txt'), 'w') as f:
                f.write(json.dumps(self.encoder.configs))
                f.write("\n" + json.dumps({'weighted': weighted}))

        best_loss = float('inf')
        self.train()
        print('starting pre-training')
        start_time = time.time()
        for epoch in range(num_epochs):
            # zero grad
            self.optimizer.zero_grad()
            for head_optimizer in self.head_optimizers:
                head_optimizer.zero_grad()

            # forward + backward per dataset; gradients accumulate until the single update below
            epoch_loss = 0.0  # total SSL loss over all datasets this epoch
            for dataset in self.data.datasets:
                # embeddings
                dataset.data.to(self.encoder.device_)
                x = self.encoder(x=dataset.data.x, edge_index=dataset.data.edge_index)

                # ssl
                ssl_loss = 0
                for ssl_object in self.ssl_objects:
                    if weighted:
                        ssl_loss = ssl_loss + self.task_weights[ssl_object.name] * ssl_object.get_loss(x, dataset.name)
                    else:
                        ssl_loss = ssl_loss + ssl_object.get_loss(x, dataset.name)

                dataset.data.to(torch.device('cpu'))  # remove data from gpu

                ssl_loss.backward()
                epoch_loss += ssl_loss.item()

            # update parameters
            self.optimizer.step()
            for head_optimizer in self.head_optimizers:
                head_optimizer.step()

            # save and overwrite model
            if epoch_loss < best_loss and (epoch + 1) % self.save_freq == 0:
                best_loss = epoch_loss
                if save_path is not None:
                    self.save_model(epoch, save_path['file'])
                else:
                    print('no save_path provided. skipping saving')

            # print
            if (epoch + 1) % verbose_freq == 0 or epoch == 0 or epoch == num_epochs - 1:
                print(
                    '[' + datetime.datetime.now().strftime(format='%d.%m.%y %H:%M:%S') + ']' +
                    ' epoch: ' + str(epoch + 1) +
                    '; ssl_loss: ' + str(round(epoch_loss, 4))
                )

            # lr schedule (driven by the total loss over all datasets)
            if lr_adapt:
                self.adjust_lr(epoch_loss)
        # print
        print('pre-training completed. elapsed time: ', str(time.time() - start_time))

    def save_model(self, epoch, save_path):
        """Save data, encoder/head state dicts, optimizer states, configs, tasks and epoch via ``torch.save``."""
        head_state_dict = [ssl_object.predictor.state_dict() for ssl_object in self.ssl_objects
                           if ssl_object.name != 'graph partitioning']
        if 'partition' in self.ssl_tasks:
            partition_dicts = [predictor.state_dict() for predictor in self._partition_task().predictor]
            # nested layout: [non-partition head dicts, partition predictor dicts]
            head_state_dict = [head_state_dict, partition_dicts]
        save_dict = {
            'data': self.data,
            'model_state_dict': {
                'heads': head_state_dict,
                'encoder': self.encoder.state_dict(),
            },
            'optimizer': {
                'type': 'adam',
                'base_lr': self.base_lr,
                'state_dicts': {
                    'encoder': self.optimizer.state_dict(),
                    'heads': [head_optimizer.state_dict() for head_optimizer in self.head_optimizers],
                },
            },
            'configs': self.encoder.configs,
            'ssl_tasks': self.ssl_tasks,
            'epoch': epoch,
        }
        torch.save(save_dict, save_path)

    def adjust_lr(self, ssl_loss):
        """Step the linear warm-up scheduler, then the plateau scheduler on ``ssl_loss``."""
        self.lin_lr_scheduler.step()
        self.plateau_lr_scheduler.step(ssl_loss)
