from __future__ import division

import time

import torch
import torch.nn.functional as F
from torch import tensor
from torch.optim import Adam
from tqdm import tqdm
import numpy as np
from geom_data_utils import load_geom_datasets
from sklearn.model_selection import train_test_split
import logging
import random
# Compute device shared by the training loops in this module.
# The second GPU ("cuda:1") is preferred when at least two CUDA devices are
# visible; on a single-GPU host we fall back to "cuda:0" instead of selecting
# a device ordinal that does not exist, and to the CPU when CUDA is missing.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# def index_to_mask(index, size):
#     mask = torch.zeros(size, dtype=torch.bool, device=index.device)
#     mask[index] = 1
#     return mask

def index_to_mask(index, size):
    """Return a boolean mask of length ``size`` that is True at ``index``.

    Args:
        index: Positions to switch on; used directly for indexed assignment
            into the mask.
        size: Length of the one-dimensional mask.

    Returns:
        A ``torch.bool`` tensor of shape ``(size,)``. It is created with
        ``torch.zeros`` without an explicit device argument.
    """
    selected = torch.zeros((size, ), dtype=torch.bool)
    selected[index] = True
    return selected

def random_planetoid_splits(data, num_classes):
    """Assign fresh random Planetoid-style masks to ``data``.

    For every class label in ``range(num_classes)`` the nodes carrying that
    label are shuffled; the first 20 of each class form the training set.
    The remaining nodes of all classes are pooled and shuffled again, then
    the first 500 become validation nodes and the following 1000 test nodes.

    Args:
        data: Graph object exposing ``y`` (labels) and ``num_nodes``. Its
            ``train_mask``, ``val_mask`` and ``test_mask`` are overwritten.
        num_classes: Number of distinct labels to sample from.

    Returns:
        The same ``data`` object with the three masks replaced.
    """
    shuffled_per_class = []
    for label in range(num_classes):
        members = (data.y == label).nonzero().view(-1)
        order = torch.randperm(members.size(0))
        shuffled_per_class.append(members[order])

    # 20 labelled nodes per class go to training, everything else is pooled.
    train_index = torch.cat([ids[:20] for ids in shuffled_per_class], dim=0)
    leftover = torch.cat([ids[20:] for ids in shuffled_per_class], dim=0)
    leftover = leftover[torch.randperm(leftover.size(0))]

    data.train_mask = index_to_mask(train_index, size=data.num_nodes)
    data.val_mask = index_to_mask(leftover[:500], size=data.num_nodes)
    data.test_mask = index_to_mask(leftover[500:1500], size=data.num_nodes)

    return data

def random_geom_splits(data, num_classes, seed, name):
    """Attach the split returned by ``load_geom_datasets(name, seed)``.

    The train/val/test index sets come from ``load_geom_datasets``; only
    those indices and the length of the loaded label tensor are used here.

    Args:
        data: Graph object whose ``train_mask``, ``val_mask`` and
            ``test_mask`` attributes are overwritten.
        num_classes: Unused; kept so the signature matches the other split
            functions.
        seed: Forwarded to ``load_geom_datasets``.
        name: Dataset name forwarded to ``load_geom_datasets``.

    Returns:
        The same ``data`` object with the three masks replaced.
    """
    loaded = load_geom_datasets(name, seed)
    _, _, labels, idx_train, idx_val, idx_test = loaded
    total = labels.size(0)

    # Build every mask before touching ``data`` so a failure leaves it intact.
    masks = [index_to_mask(idx, size=total)
             for idx in (idx_train, idx_val, idx_test)]
    data.train_mask, data.val_mask, data.test_mask = masks
    return data

def random_splits(data, seed):
    """Assign a stratified random train/val/test split to ``data``.

    The split is produced by ``get_train_val_test`` with ``val_size=0.32``
    and ``test_size=0.2`` (the rest is training), stratified on ``data.y``
    and seeded with ``seed % 10``.

    Args:
        data: Graph object exposing ``x`` and ``y``. Its ``train_mask``,
            ``val_mask`` and ``test_mask`` are overwritten.
        seed: Integer seed; only its value modulo 10 is used.

    Returns:
        The same ``data`` object with the three masks replaced.
    """
    labels = data.y
    splits = get_train_val_test(data.x.shape[0], val_size=0.32, test_size=0.2,
                                stratify=labels, seed=seed%10)
    total = labels.size(0)

    # Build every mask before touching ``data`` so a failure leaves it intact.
    masks = [index_to_mask(idx, size=total) for idx in splits]
    data.train_mask, data.val_mask, data.test_mask = masks
    return data


def get_train_val_test(nnodes, val_size=0.1, test_size=0.8, stratify=None, seed=None):
    """Split ``range(nnodes)`` into stratified train/val/test index arrays.

    The split is done in two passes of ``train_test_split``: first the test
    set is carved off, then the remainder is divided into train and
    validation. Both passes use ``random_state=None``, so reproducibility
    relies on the global NumPy seed set here when ``seed`` is given.

    Args:
        nnodes: Total number of nodes to split.
        val_size: Fraction of all nodes used for validation.
        test_size: Fraction of all nodes used for testing.
        stratify: Per-node labels used for stratification; must not be None.
        seed: Optional value passed to ``np.random.seed`` before splitting.

    Returns:
        Tuple ``(idx_train, idx_val, idx_test)`` of index arrays.

    Raises:
        AssertionError: If ``stratify`` is None.
    """
    assert stratify is not None, 'stratify cannot be None!'

    if seed is not None:
        np.random.seed(seed)

    train_size = 1 - val_size - test_size
    non_test_size = train_size + val_size

    # Pass 1: separate the test nodes from everything else.
    idx_train_and_val, idx_test = train_test_split(
        np.arange(nnodes),
        random_state=None,
        train_size=non_test_size,
        test_size=test_size,
        stratify=stratify,
    )

    # Stratification labels restricted to the nodes that are left.
    remaining_labels = stratify[idx_train_and_val] if stratify is not None else stratify

    # Pass 2: divide the remainder, rescaling the fractions so they sum to 1.
    idx_train, idx_val = train_test_split(
        idx_train_and_val,
        random_state=None,
        train_size=(train_size / non_test_size),
        test_size=(val_size / non_test_size),
        stratify=remaining_labels,
    )

    return idx_train, idx_val, idx_test

def run_(dataset, model, runs, epochs, lr, weight_decay, early_stopping,
        permute_masks=None, logger=None, **kwargs):
    """Train and evaluate ``model`` for ``runs`` repetitions, selecting by validation loss.

    Each repetition seeds NumPy and PyTorch with the run index, optionally
    re-splits the data, resets the model parameters and trains with Adam.
    The test accuracy reported for a run is the one observed at the epoch
    with the lowest validation loss. A summary line is printed at the end.

    Args:
        dataset: Indexable dataset; ``dataset[0]`` is the graph and
            ``dataset.num_classes`` is forwarded to some split functions.
        model: Model exposing ``reset_parameters()``.
        runs: Number of repetitions.
        epochs: Maximum number of epochs per repetition.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        early_stopping: Window length for early stopping; ``0`` or less
            disables it.
        permute_masks: Optional split function applied to the data each run.
        logger: Optional callable invoked with the per-epoch metrics dict.
        **kwargs: ``name`` is required when ``permute_masks`` is
            ``random_geom_splits``.

    Returns:
        None. Results are printed.
    """
    best_losses, final_accs, elapsed = [], [], []

    for run_id in tqdm(range(runs)):
        # Fix the random seed for each run.
        np.random.seed(run_id)
        torch.manual_seed(run_id)
        torch.cuda.manual_seed(run_id)

        data = dataset[0]
        if permute_masks is not None:
            # The split functions take different arguments; pick the right set.
            if permute_masks == random_geom_splits:
                split_args = (dataset.num_classes,)
                split_kwargs = {'seed': run_id, 'name': kwargs['name']}
            elif permute_masks == random_splits:
                split_args = ()
                split_kwargs = {'seed': run_id}
            else:
                split_args = (dataset.num_classes,)
                split_kwargs = {}
            data = permute_masks(data, *split_args, **split_kwargs)

        data = data.to(device)
        model.to(device).reset_parameters()
        optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        started = time.perf_counter()

        lowest_val_loss = float('inf')
        acc_at_best = 0
        loss_trace = []

        for epoch in range(1, epochs + 1):
            train(model, optimizer, data)
            eval_info = evaluate(model, data)
            eval_info['epoch'] = epoch

            if logger is not None:
                logger(eval_info)

            current = eval_info['val_loss']
            if current < lowest_val_loss:
                lowest_val_loss = current
                acc_at_best = eval_info['test_acc']

            loss_trace.append(current)

            # Early stopping is only considered in the second half of training:
            # stop once the loss is worse than the mean of the preceding window.
            if early_stopping > 0 and epoch > epochs // 2:
                window = tensor(loss_trace[-(early_stopping + 1):-1])
                if current > window.mean().item():
                    break

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        finished = time.perf_counter()

        best_losses.append(lowest_val_loss)
        final_accs.append(acc_at_best)
        elapsed.append(finished - started)

    loss = tensor(best_losses)
    acc = tensor(final_accs)
    duration = tensor(elapsed)

    summary = 'Val Loss: {:.4f}, Test Accuracy: {:.4f} ± {:.4f}, Duration: {:.3f}'.format(
        loss.mean().item(),
        acc.mean().item(),
        acc.std().item(),
        duration.mean().item())
    print(summary)

def inject_noise(data, sigma=0.05, ratio=0.25):
    """Add Gaussian noise to ``data.x``, leaving a random subset of channels clean.

    Noise with standard deviation ``sigma`` is drawn for every entry of the
    feature matrix. A fraction ``ratio`` of the feature channels (columns)
    is then chosen without replacement and its noise is set to zero, so only
    the remaining channels are perturbed.

    Args:
        data: Object whose ``x`` attribute is a 2-D feature tensor with
            feature channels along dimension 1.
        sigma: Standard deviation of the injected noise.
        ratio: Fraction of feature channels that stay unperturbed.

    Returns:
        The same ``data`` object with ``x`` replaced by the noisy features.
    """
    num_channels = data.x.shape[1]
    noise = (sigma) * torch.randn(data.x.shape)

    # Both the sample size and the index range must refer to the channel
    # dimension; mixing in the row count makes the draw fail or hit rows.
    columns = np.random.choice(num_channels, int(ratio * num_channels), replace=False)
    noise[:, torch.as_tensor(columns, dtype=torch.long)] = 0

    data.x = data.x + noise
    return data

def _inject_random_edges(data, seed, ptb_rate):
    """Add random edges to ``data`` using DeepRobust's ``Random`` attacker.

    NOTE: a second function with the same name is defined further down in
    this module and replaces this one at import time; this variant is only
    reachable if that later definition is removed or renamed.

    Args:
        data: Graph object exposing ``x`` and ``edge_index``;
            ``edge_index`` is overwritten.
        seed: Integer seed; NumPy and ``random`` are seeded with
            ``seed % 10``.
        ptb_rate: Number of added edges as a fraction of half the adjacency
            sum (i.e. of the undirected edge count for a symmetric,
            duplicate-free ``edge_index``).

    Returns:
        The same ``data`` object with the perturbed ``edge_index``.
    """
    np.random.seed(seed%10)
    random.seed(seed%10)

    # Optional heavy dependencies are imported lazily, as in the rest of the file.
    import scipy.sparse as sp
    from deeprobust.graph.global_attack import Random

    n = data.x.shape[0]
    adj = sp.csr_matrix((np.ones(data.edge_index.shape[1]), (data.edge_index[0], data.edge_index[1])), shape=(n, n))
    adj = adj.tolil()

    attacker = Random()
    n_perturbations = int(ptb_rate * (adj.sum()//2))
    attacker.attack(adj, n_perturbations, type='add')

    # Read the result from the attacker rather than from the input matrix,
    # so correctness does not depend on the attack mutating ``adj`` in place.
    perturbed_adj = attacker.modified_adj
    data.edge_index = torch.LongTensor(perturbed_adj.nonzero())
    return data

def _inject_random_edges(data, seed, ptb_rate):
    """Add random symmetric edges to ``data`` through a sparse adjacency matrix.

    The edge list is converted to a LIL adjacency matrix, which must be
    symmetric. Node pairs are drawn with ``np.random.randint`` until
    ``int(ptb_rate * #undirected_edges)`` pairs have been found that are
    neither existing entries nor already drawn (in either direction); both
    directions of each pair are then set to 1.

    Args:
        data: Graph object exposing ``x`` and ``edge_index``;
            ``edge_index`` is overwritten.
        seed: Integer seed; NumPy and ``random`` are seeded with
            ``seed % 10``.
        ptb_rate: Number of added pairs as a fraction of the entries with
            row index smaller than column index.

    Returns:
        The same ``data`` object with the perturbed ``edge_index``.

    Raises:
        AssertionError: If the adjacency matrix is not symmetric.
    """
    np.random.seed(seed%10)
    random.seed(seed%10)

    num_nodes = data.x.shape[0]
    import scipy.sparse as sp
    weights = np.ones(data.edge_index.shape[1])
    endpoints = (data.edge_index[0], data.edge_index[1])
    adj = sp.csr_matrix((weights, endpoints), shape=(num_nodes, num_nodes))
    adj = adj.tolil()

    assert np.abs(adj - adj.T).sum() == 0, "Input graph is not symmetric"
    upper_entries = [(x, y) for x, y in np.argwhere(adj != 0) if x < y]
    budget = int(ptb_rate * len(upper_entries))

    # Rejection-sample node pairs; a pair blocks itself and its reverse.
    occupied = set(zip(*adj.nonzero()))
    new_pairs = []
    while len(new_pairs) < budget:
        candidate = tuple(np.random.randint(0, adj.shape[0], 2))
        if candidate in occupied:
            continue
        new_pairs.append(candidate)
        occupied.add(candidate)
        occupied.add((candidate[1], candidate[0]))

    for u, v in new_pairs:
        adj[u, v] = 1
        adj[v, u] = 1

    data.edge_index = torch.LongTensor(adj.nonzero())
    return data


def inject_random_edges(data, seed, ptb_rate):
    """Return a copy of ``data`` with random symmetric edges appended.

    ``int(ptb_rate * E) // 2`` node pairs (``E`` = number of columns of
    ``edge_index``) are drawn with ``np.random.choice(..., replace=False)``,
    so the two endpoints of a pair always differ. Pairs already present in
    ``edge_index`` or already drawn (in either direction) are rejected.
    Each accepted pair is added in both directions and the resulting edge
    list is sorted by source node.

    Only ``x``, ``y``, the new ``edge_index`` and the three split masks are
    carried over to the returned object; the input ``data`` is not modified.

    Args:
        data: Graph object exposing ``x``, ``y``, ``edge_index``,
            ``train_mask``, ``val_mask`` and ``test_mask``.
        seed: Integer seed; NumPy and ``random`` are seeded with
            ``seed % 10``.
        ptb_rate: Fraction of the current edge count to add.

    Returns:
        A new ``torch_geometric.data.Data`` object with the perturbed edges.
    """
    from torch_geometric.data import Data

    np.random.seed(seed%10)
    random.seed(seed%10)

    masks = (data.train_mask, data.val_mask, data.test_mask)
    num_nodes = data.x.shape[0]
    num_new = int(ptb_rate * data.edge_index.shape[1])//2

    # Rejection-sample node pairs; a pair blocks itself and its reverse.
    occupied = set(zip(*data.edge_index.numpy()))
    sampled = []
    while len(sampled) < num_new:
        candidate = tuple(np.random.choice(num_nodes, 2, replace=False))
        if candidate in occupied:
            continue
        sampled.append(candidate)
        occupied.add(candidate)
        occupied.add((candidate[1], candidate[0]))

    # Give the tensor an explicit (k, 2) shape so that k == 0 yields an
    # empty (2, 0) edge block instead of a 1-D tensor that cannot be indexed.
    added = torch.LongTensor(sampled).view(len(sampled), 2).t()
    added_sym = torch.cat((added[1].view(1, -1), added[0].view(1, -1)), dim=0)
    added = torch.cat((added, added_sym), dim=1)

    edge_index_new = torch.cat((data.edge_index, added), dim=1)
    edge_index = edge_index_new[:, edge_index_new[0].sort()[1]]

    data = Data(x=data.x, edge_index=edge_index, y=data.y)
    data.train_mask, data.val_mask, data.test_mask = masks
    return data

def run(dataset, model, runs, epochs, lr, weight_decay, early_stopping,
        permute_masks=None, logger=None, **kwargs):
    """Train and evaluate ``model`` for ``runs`` repetitions, selecting by validation accuracy.

    Each repetition seeds NumPy and PyTorch with the run index, optionally
    injects random edges (when ``kwargs['ptb_rate'] > 0``), optionally
    re-splits the data, resets the model parameters and trains with Adam.
    The test accuracy reported for a run is the one observed at the epoch
    with the highest validation accuracy. The per-epoch metrics are printed,
    and a summary line is both logged at DEBUG level and printed.

    Args:
        dataset: Indexable dataset; ``dataset[0]`` is the graph and
            ``dataset.num_classes`` is forwarded to some split functions.
        model: Model exposing ``reset_parameters()``.
        runs: Number of repetitions.
        epochs: Maximum number of epochs per repetition.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        early_stopping: Window length for early stopping; ``0`` or less
            disables it.
        permute_masks: Optional split function applied to the data each run.
        logger: Optional callable invoked with the per-epoch metrics dict.
        **kwargs: ``ptb_rate`` enables edge injection; ``name`` is required
            when ``permute_masks`` is ``random_geom_splits``.

    Returns:
        None. Results are printed and logged.
    """
    best_val_accs, final_accs, elapsed = [], [], []

    for run_id in tqdm(range(runs)):
        # Fix the random seed for each run.
        np.random.seed(run_id)
        torch.manual_seed(run_id)
        torch.cuda.manual_seed(run_id)

        data = dataset[0]
        if kwargs.get('ptb_rate', 0) > 0:
            data = inject_random_edges(data, seed=run_id, ptb_rate=kwargs['ptb_rate'])

        if permute_masks is not None:
            # The split functions take different arguments; pick the right set.
            if permute_masks == random_geom_splits:
                split_args = (dataset.num_classes,)
                split_kwargs = {'seed': run_id, 'name': kwargs['name']}
            elif permute_masks == random_splits:
                split_args = ()
                split_kwargs = {'seed': run_id}
            else:
                split_args = (dataset.num_classes,)
                split_kwargs = {}
            data = permute_masks(data, *split_args, **split_kwargs)

        data = data.to(device)
        model.to(device).reset_parameters()
        optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        started = time.perf_counter()

        top_val_acc = 0
        acc_at_best = 0
        acc_trace = []

        for epoch in range(1, epochs + 1):
            train(model, optimizer, data)
            eval_info = evaluate(model, data)
            eval_info['epoch'] = epoch

            if logger is not None:
                logger(eval_info)

            current = eval_info['val_acc']
            if current > top_val_acc:
                top_val_acc = current
                acc_at_best = eval_info['test_acc']

            acc_trace.append(current)

            # Early stopping is only considered in the second half of training:
            # stop once accuracy drops below the mean of the preceding window.
            if early_stopping > 0 and epoch > epochs // 2:
                window = tensor(acc_trace[-(early_stopping + 1):-1])
                if current < window.mean().item():
                    break

            print(eval_info)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        finished = time.perf_counter()

        best_val_accs.append(top_val_acc)
        final_accs.append(acc_at_best)
        elapsed.append(finished - started)

    val_acc = tensor(best_val_accs)
    acc = tensor(final_accs)
    duration = tensor(elapsed)

    # Accuracies are reported as percentages.
    summary = 'Val Acc: {:.2f}, Test Accuracy: {:.2f} ± {:.2f}, Duration: {:.3f}'.format(
        100 * val_acc.mean().item(),
        100 * acc.mean().item(),
        100 * acc.std().item(),
        duration.mean().item())
    logging.debug(summary)
    print(summary)


def train(model, optimizer, data):
    """Run a single optimisation step on the training nodes.

    The model is switched to training mode, gradients are cleared, the
    negative log-likelihood loss is computed on the entries selected by
    ``data.train_mask`` and one optimizer step is taken.

    Args:
        model: Callable module invoked as ``model(data)``; its output is fed
            to ``F.nll_loss`` (presumably log-probabilities, to be confirmed
            against the model).
        optimizer: Optimizer holding the model parameters.
        data: Object exposing ``y`` and ``train_mask``.

    Returns:
        None.
    """
    model.train()
    optimizer.zero_grad()

    predictions = model(data)
    selected = data.train_mask
    objective = F.nll_loss(predictions[selected], data.y[selected])

    objective.backward()
    optimizer.step()


def evaluate(model, data):
    """Compute loss and accuracy on the train, val and test masks.

    The model is put in evaluation mode and run once without gradient
    tracking; the same output is then scored on each split.

    Args:
        model: Callable module invoked as ``model(data)``; its output is fed
            to ``F.nll_loss``.
        data: Object exposing ``y`` and supporting item access for
            ``'train_mask'``, ``'val_mask'`` and ``'test_mask'``.

    Returns:
        Dict with keys ``'<split>_loss'`` and ``'<split>_acc'`` for each of
        the three splits, holding Python floats.
    """
    model.eval()

    with torch.no_grad():
        logits = model(data)

    outs = {}
    for key in ('train', 'val', 'test'):
        mask = data['{}_mask'.format(key)]
        scores, target = logits[mask], data.y[mask]

        # The predicted class is the index of the largest score per row.
        _, pred = scores.max(1)
        correct = pred.eq(target).sum().item()

        outs['{}_loss'.format(key)] = F.nll_loss(scores, target).item()
        outs['{}_acc'.format(key)] = correct / mask.sum().item()

    return outs
