import os
import time
import torch
import numpy as np
import torch.nn as nn
from copy import deepcopy
from args import get_parser
from sklearn.svm import LinearSVC
from torch.nn import functional as F
import torch_geometric.transforms as T
from torch_geometric.utils import degree
from utils import set_seeds, mute_warning
from sklearn.metrics import accuracy_score
from torch_geometric.data import DataLoader
from model import Policy, Augmentations, GNN
from torch_geometric.datasets import TUDataset
from sklearn.preprocessing import StandardScaler
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils._testing import ignore_warnings
from sklearn.model_selection import GridSearchCV, StratifiedKFold


class Model(object):
    """Contrastive graph-representation learner with a learned augmentation policy.

    Two GNN encoders are kept: ``encoder_augment`` produces the node/graph
    embeddings that feed ``policy`` and the augmentation modules, while
    ``encoder_base`` embeds the two augmented views that are contrasted with
    :meth:`jsd_loss`. The policy's output is used to select two augmentations
    per batch.

    Args:
        cfg: configuration mapping; the keys read here are ``lr``, ``epoch``,
            ``device``, ``verbose``, ``temperature``, ``hid_dim``, ``dropout``,
            ``feat_dim``, ``layer``, ``ratio``, ``n_hops`` and ``random``.
    """

    def __init__(self, cfg):
        """Build the policy, both GNN encoders and the augmentation module.

        All sub-modules are moved to ``cfg['device']``.
        """
        self.lr = cfg['lr']
        self.epoch = cfg['epoch']
        self.device = cfg['device']
        self.verbose = cfg['verbose']
        self.temperature = cfg['temperature']
        self.policy = Policy(cfg['hid_dim'], cfg['dropout'])
        self.encoder_augment = GNN(cfg['feat_dim'], cfg['hid_dim'], cfg['layer'], cfg['dropout'])
        self.encoder_base = GNN(cfg['feat_dim'], cfg['hid_dim'], cfg['layer'], cfg['dropout'])
        self.augments = Augmentations(cfg['feat_dim'], cfg['hid_dim'], cfg['ratio'], cfg['n_hops'], cfg['random'])
        self.policy.to(cfg['device'])
        self.encoder_augment.to(cfg['device'])
        self.encoder_base.to(cfg['device'])
        self.augments.to(cfg['device'])

    @staticmethod
    def _mean(total, count):
        """Return ``total / count``, or ``0.0`` when ``count`` is zero (no batches)."""
        return total / count if count else 0.0

    def train(self, data):
        """Train encoders, augmentations and policy for ``self.epoch`` epochs.

        Every batch flips a fair coin. On heads the encoder-side parameter group
        (``encoder_base``, all of ``augments``, ``policy``) is stepped; on tails
        the augmenter-side group (``encoder_augment``, ``augments.augmentations``,
        ``policy``) is stepped. Both groups minimise the same JSD loss, each
        with its own Adam optimizer and gradient-norm clipping at 2.

        Args:
            data: iterable of batches (e.g. a PyG ``DataLoader``); each batch must
                expose ``x``, ``edge_index`` and ``batch`` and support ``.to``.
        """
        params_aug = (list(self.encoder_augment.parameters()) + list(self.augments.augmentations.parameters()) +
                      list(self.policy.parameters()))
        params_enc = (list(self.encoder_base.parameters()) + list(self.augments.parameters()) +
                      list(self.policy.parameters()))
        optimizer_enc = torch.optim.Adam(params_enc, lr=self.lr)
        optimizer_aug = torch.optim.Adam(params_aug, lr=self.lr)

        self.policy.train()
        self.augments.train()
        for epoch in range(self.epoch):
            loss_enc, loss_aug = .0, .0
            ctr_enc, ctr_aug = 0, 0
            start = time.perf_counter()

            for batch in data:
                # One backward pass populates gradients for BOTH parameter
                # groups, but only one optimizer steps per batch. Clear both
                # groups up front so gradients from batches in which a group
                # was not stepped do not accumulate into its next update.
                optimizer_enc.zero_grad()
                optimizer_aug.zero_grad()

                coin = int(torch.FloatTensor([0.5]).bernoulli())
                if coin:
                    self.encoder_base.train()
                    self.encoder_augment.eval()
                else:
                    self.encoder_base.eval()
                    self.encoder_augment.train()

                batch = batch.to(self.device)
                batch.edge_weight = None

                h_aug, g_aug = self.encoder_augment(batch.x, batch.edge_index, batch.edge_weight, batch.batch, None)
                action, prob = self.policy(g_aug, self.temperature, 2)
                # Indices of the two augmentations selected by the policy.
                i, j = tuple(map(int, action.max(1)[1]))
                score = (action * prob).sum(1)
                # Each augmentation gets its own copy so the views are independent.
                batch1, mask1 = self.augments.augmentations[i](deepcopy(batch), h_aug, g_aug, self.temperature)
                batch2, mask2 = self.augments.augmentations[j](deepcopy(batch), h_aug, g_aug, self.temperature)
                __, g1 = self.encoder_base(batch1.x, batch1.edge_index, batch1.edge_weight, batch1.batch, mask1)
                __, g2 = self.encoder_base(batch2.x, batch2.edge_index, batch2.edge_weight, batch2.batch, mask2)
                # Scaling by the policy score keeps the policy in the autograd graph.
                g1 = g1 * score[0]
                g2 = g2 * score[1]
                loss = self.jsd_loss(g1, g2, None) / 100
                loss.backward()

                if coin:
                    ctr_enc += 1
                    loss_enc += loss.item()
                    nn.utils.clip_grad_norm_(params_enc, 2.)
                    optimizer_enc.step()
                else:
                    ctr_aug += 1
                    loss_aug += loss.item()
                    nn.utils.clip_grad_norm_(params_aug, 2.)
                    optimizer_aug.step()

            # Exact per-group mean over the batches that group was stepped on.
            loss_enc = self._mean(loss_enc, ctr_enc)
            loss_aug = self._mean(loss_aug, ctr_aug)
            end = time.perf_counter()

            if self.verbose:
                print(f'Epoch {epoch + 1}/{self.epoch} | Time: {end - start:.2f}s | '
                      f'Encoder Loss: {loss_enc:.4f} | Augmenter Loss: {loss_aug:.4f}')

            if self.device.startswith('cuda'):
                torch.cuda.empty_cache()

    @torch.no_grad()
    def encode(self, data):
        """Embed every graph with ``encoder_base`` and standardise the result.

        Args:
            data: iterable of batches; embeddings are concatenated in iteration
                order, so pass a non-shuffled loader to keep rows aligned with
                the dataset.

        Returns:
            ``np.ndarray`` of graph embeddings, standardised per feature with a
            ``StandardScaler`` fitted on these same embeddings.
        """
        scale = StandardScaler()
        self.encoder_base.eval()
        enc = []
        for batch in data:
            batch = batch.to(self.device)
            __, g = self.encoder_base(batch.x, batch.edge_index, None, batch.batch, None)
            enc.append(g.detach().cpu())
        enc = torch.cat(enc, 0).numpy()
        enc = scale.fit_transform(enc)
        return enc

    @staticmethod
    def jsd_loss(enc1, enc2, indices):
        """Jensen-Shannon contrastive loss between two sets of embeddings.

        Scores are the dot products ``enc1 @ enc2.T``; pairs on the diagonal
        are positives and every other pair is a negative.

        Args:
            enc1: tensor of shape ``(n, d)``.
            enc2: tensor of shape ``(m, d)``.
            indices: row indices applied to the positive mask when ``n != m``;
                unused when ``n == m``.

        Returns:
            Scalar tensor ``E_neg - E_pos`` (lower is better).
        """
        pos_mask = torch.eye(enc1.shape[0], enc2.shape[0], device=enc1.device)
        if enc1.shape[0] != enc2.shape[0]:
            pos_mask = pos_mask[indices]
        neg_mask = 1. - pos_mask
        logits = (enc1 @ enc2.t())
        # log2 - softplus(-x) and softplus(x) - log2 (written as
        # softplus(-x) + x - log2) are the JSD positive/negative terms.
        Epos = (np.log(2.) - F.softplus(- logits))
        Eneg = (F.softplus(-logits) + logits - np.log(2.))
        Epos = (Epos * pos_mask).sum() / pos_mask.sum()
        Eneg = (Eneg * neg_mask).sum() / neg_mask.sum()
        return Eneg - Epos


def load_data(data_dir, data_name):
    """Load a TUDataset with a feature-normalising transform attached.

    Datasets that have no node features get one-hot node-degree features
    (``T.OneHotDegree``) followed by ``T.NormalizeFeatures``; datasets with
    features only get ``T.NormalizeFeatures``.

    Args:
        data_dir: root directory passed to ``TUDataset``.
        data_name: TUDataset name.

    Returns:
        The ``TUDataset`` with the chosen transform.
    """
    data = TUDataset(root=data_dir, name=data_name, use_node_attr=True, use_edge_attr=False)
    if data.num_features == 0:
        # num_nodes makes edgeless graphs yield an all-zero degree vector
        # instead of an empty tensor (whose .max() would raise).
        max_degree = int(max(degree(d.edge_index[0], num_nodes=d.num_nodes).max().item() for d in data)) + 1
        trans = T.Compose([T.OneHotDegree(max_degree), T.NormalizeFeatures()])
    else:
        trans = T.Compose([T.NormalizeFeatures()])
    data = TUDataset(root=data_dir, name=data_name, use_node_attr=True, transform=trans)
    return data


def run(config, data_dir, seeds=(123, 321, 231, 132, 312)):
    """Train and evaluate a fresh ``Model`` once per seed.

    For every seed a new model is trained on the whole dataset, graph
    embeddings are extracted with ``Model.encode`` and scored with
    ``linear_classifier``; the per-fold accuracies of all seeds are pooled and
    a summary dict is printed.

    Args:
        config: hyper-parameter dict. ``config['feat_dim']`` is set here from
            the dataset before any model is built (the dict is mutated).
        data_dir: dataset root passed to ``load_data``.
        seeds: random seeds, one training run each (defaults to the five
            seeds this experiment has always used).

    Returns:
        Mean accuracy over every fold of every seed.
    """
    data = load_data(data_dir, config['name'])
    config['feat_dim'] = data.num_features
    # Labels are seed-independent, so collect them once.
    y = np.array([d.y.item() for d in data])
    a = []
    for seed in seeds:
        set_seeds(seed)
        data_loader = DataLoader(data, shuffle=True, batch_size=config['batch'])
        model = Model(config)
        model.train(data_loader)
        # Non-shuffled loader keeps embedding rows aligned with `y`.
        x = model.encode(DataLoader(data, shuffle=False, batch_size=config['batch']))
        a.extend(linear_classifier(x, y))

    mean_acc = np.mean(a)
    print({'hid_dim': config['hid_dim'], 'batch': config['batch'], 'epoch': config['epoch'],
           'layer': config['layer'], 'lr': config['lr'], 'temperature': config['temperature'],
           'n_hops': config['n_hops'], 'ratio': config['ratio'], 'dropout': config['dropout'],
           'acc': a, 'mean': mean_acc, 'std': np.std(a)})
    return mean_acc


@ignore_warnings(category=ConvergenceWarning)
def linear_classifier(x, y, *, n_splits=10, random_state=0, n_jobs=5):
    """Score features with a grid-searched linear SVM under stratified k-fold CV.

    For each outer fold, ``C`` is chosen by a 5-fold inner ``GridSearchCV`` on
    the training split, and the fitted search is scored on the held-out split.
    The decorator ignores ``ConvergenceWarning``.

    Args:
        x: feature matrix indexable by integer arrays (e.g. ``np.ndarray``).
        y: label array aligned with ``x``.
        n_splits: number of outer stratified folds.
        random_state: seed for shuffling the outer folds.
        n_jobs: parallel jobs for the inner grid search.

    Returns:
        List of per-fold test accuracies, one per outer fold.
    """
    kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    params = {'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]}
    accuracies = []
    for train_index, test_index in kf.split(x, y):
        x_train, x_test = x[train_index], x[test_index]
        y_train, y_test = y[train_index], y[test_index]
        classifier = GridSearchCV(LinearSVC(dual=False), params, cv=5, scoring='accuracy', verbose=0, n_jobs=n_jobs)
        classifier.fit(x_train, y_train)
        accuracies.append(accuracy_score(y_test, classifier.predict(x_test)))
    return accuracies


if __name__ == '__main__':
    # Script entry: parse CLI options, then force CPU execution and verbose
    # per-epoch logging before launching the multi-seed experiment.
    mute_warning()
    args = get_parser()
    args.device, args.verbose = 'cpu', True
    config = vars(args)

    print(args.name)
    dataset_root = os.path.join('./data', config['name'])
    run(config, os.path.abspath(dataset_root))


