import os
import time
import torch
import numpy as np
from ray import tune
import torch.nn as nn
from copy import deepcopy
from args import get_parser
import torch.nn.functional as F
from utils import set_seeds, accuracy
import torch_geometric.transforms as T
from ray.tune.suggest.optuna import OptunaSearch
from torch_geometric.utils import k_hop_subgraph
from torch_geometric.data import DataLoader, Data, Dataset
from model import Policy, Augmentations, GNN, Discriminator, LogReg, NodeEncoder
from torch_geometric.datasets import Planetoid, WikiCS, Coauthor, Amazon


class Model(object):
    """Trainer that alternates updates between a base encoder and an augmenter.

    The object owns five sub-modules built from ``cfg``: a ``Policy``, a ``GNN``
    used to drive the augmentations (``encoder_augment``), a ``NodeEncoder``
    that produces the final embeddings (``encoder_base``), an ``Augmentations``
    bank and a ``Discriminator``.
    """

    def __init__(self, cfg):
        """Build all sub-modules and move them to ``cfg['device']``.

        Args:
            cfg: Mapping with the keys ``lr``, ``device``, ``verbose``,
                ``temperature``, ``hid_dim``, ``dropout``, ``feat_dim``,
                ``layer``, ``ratio`` and ``n_hops``. The optional key
                ``patience`` sets the early-stopping patience used by
                :meth:`train`; when absent, early stopping is disabled.
        """
        self.lr = cfg['lr']
        self.device = cfg['device']
        self.verbose = cfg['verbose']
        self.temperature = cfg['temperature']
        # Read patience from the configuration handed to this instance rather
        # than from a module-level global, so the class works when imported.
        self.patience = cfg.get('patience')
        self.policy = Policy(cfg['hid_dim'], cfg['dropout'], n_aug=4)
        self.encoder_augment = GNN(cfg['feat_dim'], cfg['hid_dim'], cfg['layer'], cfg['dropout'])
        self.encoder_base = NodeEncoder(cfg['feat_dim'], cfg['hid_dim'], cfg['dropout'])
        self.augments = Augmentations(cfg['feat_dim'], cfg['hid_dim'], cfg['ratio'], cfg['n_hops'])
        self.discriminator = Discriminator(cfg['hid_dim'])
        self.policy.to(cfg['device'])
        self.encoder_augment.to(cfg['device'])
        self.encoder_base.to(cfg['device'])
        self.augments.to(cfg['device'])
        self.discriminator.to(cfg['device'])

    def train(self, data_loader):
        """Run one pass over ``data_loader`` and keep the best base encoder.

        For every non-empty batch a fair coin decides whether the base encoder
        or the augmentation encoder is in train mode and which optimizer
        steps. After the loop, ``encoder_base`` is restored to the weights
        that achieved the lowest loss on an encoder step.

        Args:
            data_loader: Iterable of batches exposing ``x``, ``edge_index``,
                ``batch`` and ``to(device)``.
        """
        shared = (list(self.augments.augmentations.parameters()) +
                  list(self.policy.parameters()) + list(self.discriminator.parameters()))
        params_aug = list(self.encoder_augment.parameters()) + shared
        params_enc = list(self.encoder_base.parameters()) + shared
        optimizer_enc = torch.optim.Adam(params_enc, lr=self.lr)
        optimizer_aug = torch.optim.Adam(params_aug, lr=self.lr)

        self.policy.train()
        self.augments.train()
        self.discriminator.train()

        best_loss, patience_count = 1e9, 0
        # state_dict() returns references to the live parameter tensors, so a
        # deep copy is required to obtain an actual snapshot.
        best_state = deepcopy(self.encoder_base.state_dict())
        for idx, batch in enumerate(data_loader):

            # Batches without any edge are skipped entirely.
            if batch.edge_index.size(1) == 0:
                continue

            loss_enc, loss_aug = .0, .0
            start = time.time()

            # coin == 1: encoder step, coin == 0: augmenter step.
            coin = int(torch.FloatTensor([0.5]).bernoulli())
            if coin:
                self.encoder_base.train()
                self.encoder_augment.eval()
            else:
                self.encoder_base.eval()
                self.encoder_augment.train()

            batch = batch.to(self.device)
            batch.edge_weight = None

            # backward() below populates gradients for *both* encoders, but
            # only one optimizer steps. Clear everything up front so that the
            # encoder that is not stepping does not accumulate stale gradients.
            optimizer_enc.zero_grad()
            optimizer_aug.zero_grad()

            h_aug, g_aug = self.encoder_augment(batch.x, batch.edge_index, batch.edge_weight, batch.batch, None)
            action, prob = self.policy(g_aug, self.temperature, 2)
            # Row-wise argmax of the two action rows picks the two augmentations.
            i, j = tuple(map(int, action.max(1)[1]))
            score = (action * prob).sum(1)
            batch1, mask1 = self.augments.augmentations[i](deepcopy(batch), h_aug, g_aug, self.temperature)
            batch2, mask2 = self.augments.augmentations[j](deepcopy(batch), h_aug, g_aug, self.temperature)
            h1, g1 = self.encoder_base(batch1.x, batch1.edge_index, batch1.edge_weight, batch1.batch, mask1)
            h2, g2 = self.encoder_base(batch2.x, batch2.edge_index, batch2.edge_weight, batch2.batch, mask2)
            # Scaling by the policy score keeps the policy in the autograd graph.
            g1 = g1 * score[0]
            g2 = g2 * score[1]

            loss = (self.jsd_loss(h1, g2, batch1.batch) + self.jsd_loss(h2, g1, batch2.batch)) / 10
            loss.backward()

            if coin:
                loss_enc = loss.item()
                nn.utils.clip_grad_norm_(params_enc, 2.)
                optimizer_enc.step()
            else:
                loss_aug = loss.item()
                nn.utils.clip_grad_norm_(params_aug, 2.)
                optimizer_aug.step()

            end = time.time()

            if self.verbose:
                print(f'Step {idx + 1}/ | Time: {end - start:.2f}s | '
                      f'Encoder Loss: {loss_enc:.4f} | Augmenter Loss: {loss_aug:.4f}')

            # Track the best encoder on encoder steps only. The coin is used as
            # the flag because the JSD loss may legitimately be <= 0.
            if coin:
                if loss_enc < best_loss:
                    best_loss = loss_enc
                    best_state = deepcopy(self.encoder_base.state_dict())
                    patience_count = 0
                else:
                    patience_count += 1

                if self.patience is not None and patience_count >= self.patience:
                    if self.verbose:
                        print('Early stopping ...')
                    break

            # str() so that both 'cuda:0' strings and torch.device objects work.
            if str(self.device).startswith('cuda'):
                torch.cuda.empty_cache()

        self.encoder_base.load_state_dict(best_state)

    @torch.no_grad()
    def encode(self, data):
        """Return the node-level output of ``encoder_base`` for ``data``.

        The encoder is switched to eval mode and called without edge weights
        and without a mask; only the first of its two outputs is returned.
        """
        self.encoder_base.eval()
        h, __ = self.encoder_base(data.x, data.edge_index, None, data.batch, None)
        return h.detach()

    @staticmethod
    def jsd_loss(enc1, enc2, indices):
        """Jensen-Shannon style contrastive loss between two embedding sets.

        Args:
            enc1: Tensor of shape ``(n1, d)``.
            enc2: Tensor of shape ``(n2, d)``.
            indices: Long tensor of length ``n1``; when ``n1 != n2`` entry ``k``
                names the row of ``enc2`` that is the positive for row ``k`` of
                ``enc1``. Ignored when ``n1 == n2`` (positives are then the
                diagonal pairs).

        Returns:
            Scalar tensor: mean negative-pair term minus mean positive-pair term.
        """
        log_2 = float(np.log(2.))
        pos_mask = torch.eye(enc1.shape[0], enc2.shape[0], device=enc1.device)
        if enc1.shape[0] != enc2.shape[0]:
            # Row r (r < n2) of the rectangular identity is one-hot at column
            # r, so gathering rows by ``indices`` yields an assignment matrix.
            pos_mask = pos_mask[indices]
        neg_mask = 1. - pos_mask
        logits = enc1 @ enc2.t()
        Epos = (log_2 - F.softplus(- logits))
        Eneg = (F.softplus(-logits) + logits - log_2)
        Epos = (Epos * pos_mask).sum() / pos_mask.sum()
        Eneg = (Eneg * neg_mask).sum() / neg_mask.sum()
        return Eneg - Epos


def load_data(data_dir, data_name, split_idx=0):
    """Load a single-graph dataset together with train/val/test node indices.

    Args:
        data_dir: Root directory handed to the PyTorch Geometric dataset class.
        data_name: One of ``'Wiki-CS'``, ``'Cora'``, ``'Citeseer'``,
            ``'Pubmed'``, ``'Amazon-Computers'``, ``'Amazon-Photo'``,
            ``'Coauthor-CS'`` or ``'Coauthor-Physics'``.
        split_idx: Column of the train/val masks to use; only read for
            ``'Wiki-CS'``.

    Returns:
        ``(graph, split)`` where ``split`` maps ``'train'``, ``'val'`` and
        ``'test'`` to 1-D index tensors. Wiki-CS and the Planetoid datasets use
        the masks shipped with the data; the Amazon and Coauthor datasets get a
        random 10% / 10% / 80% split drawn with ``torch.randperm``.

    Raises:
        ValueError: If ``data_name`` is not one of the supported names.
    """
    planetoid = ('Cora', 'Citeseer', 'Pubmed')
    amazon = ('Amazon-Computers', 'Amazon-Photo')
    coauthor = ('Coauthor-CS', 'Coauthor-Physics')
    # Fail early with a clear message instead of an UnboundLocalError below.
    if data_name != 'Wiki-CS' and data_name not in planetoid + amazon + coauthor:
        raise ValueError(f'Unknown dataset name: {data_name!r}')

    transform = T.NormalizeFeatures()

    # NOTE: the graph is fetched with ``dataset[0]`` rather than
    # ``dataset.data``: ``transform`` is only applied on item access, so
    # reading ``.data`` would silently skip the feature normalisation.
    if data_name == 'Wiki-CS':
        graph = WikiCS(root=data_dir, transform=transform)[0]
        return graph, {
            'train': graph.train_mask[:, split_idx].nonzero().view(-1),
            'test': graph.test_mask.nonzero().view(-1),
            'val': graph.val_mask[:, split_idx].nonzero().view(-1)
        }

    if data_name in planetoid:
        graph = Planetoid(root=data_dir, name=data_name, transform=transform)[0]
        return graph, {
            'train': graph.train_mask.nonzero().view(-1),
            'test': graph.test_mask.nonzero().view(-1),
            'val': graph.val_mask.nonzero().view(-1)
        }

    if data_name in amazon:
        dataset = Amazon(root=data_dir, name=data_name.replace('Amazon-', ''), transform=transform)
    else:
        dataset = Coauthor(root=data_dir, name=data_name.replace('Coauthor-', ''), transform=transform)
    graph = dataset[0]

    # No masks are read for these datasets: draw a random 10/10/80 split.
    num_nodes = graph.x.size(0)
    train_size = int(num_nodes * .1)
    indices = torch.randperm(num_nodes)
    return graph, {
        'train': indices[:train_size],
        'val': indices[train_size:2 * train_size],
        'test': indices[2 * train_size:]
    }


def remove_isolated(data, split):
    """Drop nodes without any incident edge and re-index everything else.

    ``data`` is modified in place: ``x`` and ``y`` keep only nodes that occur
    in ``edge_index``, and ``edge_index`` is rewritten to the new contiguous
    numbering. Every index list in ``split`` is filtered to the surviving
    nodes and translated to the new numbering.

    Args:
        data: Object with tensor attributes ``x``, ``y`` and ``edge_index``.
        split: Mapping from split name to a 1-D tensor of node indices that
            refer to the original numbering.

    Returns:
        ``(data, split)`` with the updated graph and the remapped split dict.
    """
    edge_index = data.edge_index
    num_nodes = data.x.size(0)

    # keep[v] is True iff node v is an endpoint of at least one edge.
    keep = torch.zeros(num_nodes, dtype=torch.bool, device=edge_index.device)
    keep[edge_index.flatten()] = True

    # node_idx maps old ids to new ids; removed nodes stay at -1.
    node_idx = edge_index.new_full((num_nodes,), -1)
    kept = torch.nonzero(keep).flatten()
    node_idx[kept] = torch.arange(kept.size(0), dtype=node_idx.dtype, device=edge_index.device)

    data.edge_index = node_idx[edge_index]
    data.x = data.x[keep]
    data.y = data.y[keep]

    # Filter and translate with tensor indexing. Testing 0-d tensors for
    # membership in a Python set of ints never matches (tensors hash by
    # identity), which would leave removed nodes in the split with wrong ids.
    new_split = {}
    for name, nodes in split.items():
        nodes = torch.as_tensor(nodes, dtype=torch.long, device=edge_index.device).view(-1)
        new_split[name] = node_idx[nodes[keep[nodes]]]
    return data, new_split


class SubGraph(Dataset):
    """Dataset whose item ``i`` is the ``n_hop`` neighbourhood of node ``i``."""

    def __init__(self, data, n_hop):
        """Store the full graph ``data`` and the neighbourhood radius ``n_hop``."""
        self.data = data
        self.n_hop = n_hop

    def __getitem__(self, index):
        """Return the relabelled ``n_hop`` subgraph around node ``index``.

        The returned ``Data`` carries the features and labels of the selected
        nodes, the relabelled edges and ``center``, the position of the query
        node inside the subgraph as reported by ``k_hop_subgraph``.
        """
        # num_nodes is passed explicitly: otherwise it is inferred from the
        # largest id in edge_index, which is too small when the highest
        # numbered nodes have no edges, although __len__ advertises them.
        subset, sub_edge_index, center, __ = k_hop_subgraph(
            node_idx=int(index), num_hops=self.n_hop, edge_index=self.data.edge_index,
            relabel_nodes=True, num_nodes=self.data.x.size(0))
        return Data(x=self.data.x[subset], edge_index=sub_edge_index, center=center,
                    y=self.data.y[subset])

    def __len__(self):
        """Number of items, i.e. the number of rows of ``data.x``."""
        return self.data.x.size(0)


def run(config, data_dir, n_runs=20):
    """Train and evaluate the model over several seeds and report accuracy.

    For each run the seed is set to the run index, the dataset is loaded,
    isolated nodes are removed, a fresh ``Model`` is trained on k-hop
    subgraphs, and the resulting embeddings are scored by
    ``linear_classifier``.

    Args:
        config: Dict of hyper-parameters. ``config['feat_dim']`` is
            overwritten with the feature dimension of the loaded data.
        data_dir: Dataset root directory forwarded to ``load_data``.
        n_runs: Number of repetitions (default 20). The run index is also
            passed to ``load_data`` as ``split_idx``, so for datasets with
            pre-defined split columns it must not exceed the number of
            available splits.

    Returns:
        Mean accuracy over all runs.
    """
    from tqdm import tqdm
    accuracies = []
    for idx in tqdm(range(n_runs), unit_scale=True, desc='Running experiments...'):
        set_seeds(idx)
        data, split = load_data(data_dir, config['name'], idx)
        config['feat_dim'] = data.x.size(1)
        data, split = remove_isolated(data, split)
        data_loader = DataLoader(SubGraph(data, config['n_hops']), batch_size=config['batch'], shuffle=True)
        model = Model(config)
        model.train(data_loader)
        # Collate the whole graph as a single batch so that it carries a
        # ``batch`` attribute, which ``Model.encode`` reads.
        data = next(iter(DataLoader([data], batch_size=1)))
        data = data.to(config['device'])
        x = model.encode(data)
        accuracies.append(linear_classifier(x, data.y, split))

    reported = ('hid_dim', 'batch', 'layer', 'lr', 'temperature', 'n_hops', 'ratio', 'dropout')
    summary = {key: config[key] for key in reported}
    summary['acc'] = accuracies
    summary['mean'] = np.mean(accuracies)
    summary['std'] = np.std(accuracies)
    print(summary)

    return np.mean(accuracies)


def linear_classifier(x, y, split):
    """Fit a ``LogReg`` probe on fixed embeddings and report test accuracy.

    The probe is trained for 300 full-batch Adam steps (lr 0.01) on
    ``split['train']``. After every step it is scored on the validation and
    test indices.

    Args:
        x: Embedding matrix; ``x.size(1)`` is the probe's input dimension.
        y: Integer class labels; the number of classes is ``y.max() + 1``.
        split: Mapping with index tensors under ``'train'``, ``'val'`` and
            ``'test'``.

    Returns:
        The test accuracy recorded at the step with the highest validation
        accuracy (first such step on ties).
    """
    n_hidden = x.size(1)
    n_class = y.max().item() + 1
    classifier = LogReg(n_hidden, n_class).to(x.device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.01)
    # CrossEntropyLoss applies log-softmax itself. Log-softmax is idempotent,
    # so no explicit LogSoftmax layer is needed in front of it.
    criterion = nn.CrossEntropyLoss()

    # The inputs never change, so slice them once instead of on every step.
    x_train, y_train = x[split['train']], y[split['train']]
    x_val, y_val = x[split['val']], y[split['val']].view(-1, 1)
    x_test, y_test = x[split['test']], y[split['test']].view(-1, 1)

    best_test_acc = 0
    best_val_acc = 0
    for __ in range(300):
        classifier.train()
        optimizer.zero_grad()
        loss = criterion(classifier(x_train), y_train)
        loss.backward()
        optimizer.step()

        # Score in eval mode and without building an autograd graph.
        classifier.eval()
        with torch.no_grad():
            test_acc = accuracy(y_test, classifier(x_test).argmax(-1).view(-1, 1))
            val_acc = accuracy(y_val, classifier(x_val).argmax(-1).view(-1, 1))
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc
    return best_test_acc


if __name__ == '__main__':
    # Script entry point: parse the options, force CPU execution and launch
    # the experiments on the dataset directory ./data/<name>.
    args = get_parser()
    print(args.name)
    args.device = 'cpu'
    # ``config`` is deliberately a module-level name holding the parsed
    # options as a plain dict.
    config = vars(args)
    dataset_root = os.path.abspath(os.path.join('./data', config['name']))
    run(config, dataset_root)




