from __future__ import division

import time

import torch
import torch.nn.functional as F
from torch import tensor
from torch.optim import Adam
from tqdm import tqdm
import numpy as np
from geom_data_utils import load_geom_datasets
from sklearn.model_selection import train_test_split
import logging
import random
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")


def index_to_mask(index, size):
    """Build a boolean mask of length ``size`` that is True at ``index``.

    Args:
        index: Positions to switch on (anything accepted by tensor
            index-assignment, e.g. a LongTensor of node ids).
        size: Length of the returned one-dimensional mask.

    Returns:
        A ``torch.bool`` tensor of shape ``(size,)``.
    """
    selected = torch.full((size,), False, dtype=torch.bool)
    selected[index] = True
    return selected

def random_planetoid_splits(data, num_classes):
    """Assign a random train/val/test split to ``data`` in place.

    For each class ``0 .. num_classes - 1`` the nodes with that label are
    shuffled and the first 20 become training nodes. All remaining nodes are
    pooled and shuffled again; the first 500 form the validation set and the
    next 1000 the test set.

    Args:
        data: Graph object exposing ``y`` (node labels) and ``num_nodes``;
            receives ``train_mask``, ``val_mask`` and ``test_mask``.
        num_classes: Number of label values to iterate over.

    Returns:
        The same ``data`` object with the three masks set.
    """
    shuffled_by_class = []
    for label in range(num_classes):
        members = (data.y == label).nonzero().view(-1)
        order = torch.randperm(members.size(0))
        shuffled_by_class.append(members[order])

    # 20 labelled nodes per class for training.
    train_index = torch.cat([nodes[:20] for nodes in shuffled_by_class], dim=0)

    # Everything else is pooled and reshuffled before carving out val/test.
    leftover = torch.cat([nodes[20:] for nodes in shuffled_by_class], dim=0)
    leftover = leftover[torch.randperm(leftover.size(0))]

    assignments = (
        ('train_mask', train_index),
        ('val_mask', leftover[:500]),
        ('test_mask', leftover[500:1500]),
    )
    for attr, node_ids in assignments:
        setattr(data, attr, index_to_mask(node_ids, size=data.num_nodes))

    return data

def random_geom_splits(data, num_classes, seed, name):
    """Attach the split returned by ``load_geom_datasets(name, seed)``.

    Args:
        data: Graph object that receives ``train_mask``, ``val_mask`` and
            ``test_mask``.
        num_classes: Unused; kept so the signature matches the other
            split functions.
        seed: Forwarded to ``load_geom_datasets`` as its second argument.
        name: Dataset name forwarded to ``load_geom_datasets``.

    Returns:
        The same ``data`` object with the three masks set.
    """
    # Only the labels (for the mask length) and the index sets are needed.
    _, _, y, idx_train, idx_val, idx_test = load_geom_datasets(name, seed)

    num_labels = y.size(0)
    masks = [index_to_mask(node_ids, size=num_labels)
             for node_ids in (idx_train, idx_val, idx_test)]

    data.train_mask, data.val_mask, data.test_mask = masks
    return data

def random_splits(data, seed):
    """Attach a label-stratified random split to ``data`` in place.

    The split is produced by ``get_train_val_test`` with 32% validation and
    20% test nodes (the rest is training), seeded with ``seed % 10``.

    Args:
        data: Graph object exposing ``x`` and ``y``; receives ``train_mask``,
            ``val_mask`` and ``test_mask``.
        seed: Integer seed; only its value modulo 10 is used.

    Returns:
        The same ``data`` object with the three masks set.
    """
    labels = data.y
    splits = get_train_val_test(data.x.shape[0], val_size=0.32, test_size=0.2,
                                stratify=labels, seed=seed % 10)
    idx_train, idx_val, idx_test = splits

    num_labels = labels.size(0)
    masks = [index_to_mask(node_ids, size=num_labels)
             for node_ids in (idx_train, idx_val, idx_test)]

    data.train_mask, data.val_mask, data.test_mask = masks
    return data


def get_train_val_test(nnodes, val_size=0.1, test_size=0.8, stratify=None, seed=None):
    """Split node indices ``0 .. nnodes - 1`` into train/val/test sets.

    The split is done in two stratified steps with ``train_test_split``:
    first the test set is separated from the rest, then the rest is divided
    into training and validation. The training fraction is
    ``1 - val_size - test_size``.

    Args:
        nnodes: Number of nodes to split.
        val_size: Fraction of all nodes used for validation.
        test_size: Fraction of all nodes used for testing.
        stratify: Per-node labels used for stratification; must not be None.
        seed: If given, seeds NumPy's global RNG before splitting (the
            splits themselves are called with ``random_state=None``).

    Returns:
        Tuple ``(idx_train, idx_val, idx_test)`` of index arrays.
    """
    assert stratify is not None, 'stratify cannot be None!'

    if seed is not None:
        np.random.seed(seed)

    train_size = 1 - val_size - test_size
    non_test_size = train_size + val_size
    all_idx = np.arange(nnodes)

    # Step 1: carve the test set off the full index range.
    idx_train_and_val, idx_test = train_test_split(
        all_idx,
        random_state=None,
        train_size=non_test_size,
        test_size=test_size,
        stratify=stratify,
    )

    # Step 2 stratifies on the labels of the nodes that are left.
    remaining_labels = stratify
    if remaining_labels is not None:
        remaining_labels = remaining_labels[idx_train_and_val]

    # Fractions are re-expressed relative to the non-test portion.
    idx_train, idx_val = train_test_split(
        idx_train_and_val,
        random_state=None,
        train_size=(train_size / non_test_size),
        test_size=(val_size / non_test_size),
        stratify=remaining_labels,
    )

    return idx_train, idx_val, idx_test




def inject_random_edges(data, seed, ptb_rate):
    """Return a new graph with randomly sampled extra edges added.

    ``int(ptb_rate * E) // 2`` node pairs are sampled (``E`` being the number
    of columns of ``data.edge_index``); each pair is inserted in both
    directions. Sampled pairs consist of two distinct nodes and are never
    already present in ``data.edge_index`` nor sampled twice.

    Args:
        data: Graph object exposing ``x``, ``y``, ``edge_index`` and the
            three split masks. ``edge_index`` is converted with ``.numpy()``,
            so it presumably has to live on the CPU here — verify against
            the caller.
        seed: Integer seed; NumPy's and ``random``'s global RNGs are seeded
            with ``seed % 10``.
        ptb_rate: Fraction of the existing edge count to add.

    Returns:
        A fresh ``torch_geometric.data.Data`` holding ``x``, ``y``, the
        augmented ``edge_index`` (columns ordered by source node) and the
        original masks. Other attributes of ``data`` are not copied.
    """
    np.random.seed(seed % 10)
    random.seed(seed % 10)

    def sample_zero_forever(data):
        # Endless generator of node pairs that are not yet edges.
        nonzero_or_sampled = set(zip(*data.edge_index.numpy()))
        while True:
            # replace=False guarantees two distinct nodes (no self-loops).
            t = tuple(np.random.choice(data.x.shape[0], 2, replace=False))
            if t not in nonzero_or_sampled:
                yield t
                # Remember both directions so the pair is not drawn again.
                nonzero_or_sampled.add(t)
                nonzero_or_sampled.add((t[1], t[0]))

    def sample_zero_n(data, n=100):
        itr = sample_zero_forever(data)
        return [next(itr) for _ in range(n)]

    # The masks are carried over to the rebuilt Data object below.
    masks = [data.train_mask, data.val_mask, data.test_mask]

    num_pairs = int(ptb_rate * data.edge_index.shape[1]) // 2
    pairs = sample_zero_n(data, n=num_pairs)

    # Force an (n, 2) layout before transposing: with zero sampled pairs the
    # tensor would otherwise be 1-D and indexing its second row would raise.
    added = torch.LongTensor(pairs).reshape(len(pairs), 2).t()
    added_sym = torch.cat((added[1].view(1, -1), added[0].view(1, -1)), dim=0)
    added = torch.cat((added, added_sym), dim=1)

    edge_index_new = torch.cat((data.edge_index, added), dim=1)
    # Reorder columns by source node.
    edge_index = edge_index_new[:, edge_index_new[0].sort()[1]]

    from torch_geometric.data import Data
    data = Data(x=data.x, edge_index=edge_index, y=data.y)
    data.train_mask, data.val_mask, data.test_mask = masks

    return data

def run(dataset, model, runs, epochs, lr, weight_decay, early_stopping,
        permute_masks=None, logger=None, **kwargs):
    """Train ``model`` ``runs`` times and print aggregate accuracies.

    Each run re-seeds the RNGs, optionally perturbs the graph and re-draws
    the split, resets the model, trains for up to ``epochs`` epochs and keeps
    the test accuracy of the epoch with the best validation accuracy.

    Args:
        dataset: Indexable dataset; ``dataset[0]`` is the graph and
            ``dataset.num_classes`` is forwarded to the split function.
        model: Module with ``reset_parameters()`` and a ``conv1.save_alpha``
            attribute (saved to the file ``'bestmodel'`` on improvement).
        runs: Number of repetitions.
        epochs: Maximum number of epochs per run.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        early_stopping: Window size for early stopping; ``0`` disables it.
        permute_masks: Optional split function (``random_geom_splits``,
            ``random_splits`` or one taking ``(data, num_classes)``).
        logger: Optional callable invoked with the per-epoch metrics dict.
        **kwargs: ``ptb_rate`` (> 0 enables random edge injection) and
            ``name`` (dataset name, required for ``random_geom_splits``).

    Returns:
        None. Mean best validation accuracy, mean/std test accuracy and mean
        duration over all runs are printed.
    """
    # NOTE(review): this code used to re-seed with 0..3 in an inner loop that
    # reused the run-loop variable, so the effective seed of every run (and
    # the seed handed to the split / perturbation helpers) was always 3.
    # That behaviour is kept to preserve results — confirm whether distinct
    # per-run seeds were intended.
    seed = 3

    # All three lists must outlive the loop so the summary covers every run.
    val_accs, accs, durations = [], [], []
    for _ in range(runs):
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        data = dataset[0]
        if kwargs.get('ptb_rate', 0) > 0:
            data = inject_random_edges(data, seed=seed, ptb_rate=kwargs['ptb_rate'])

        if permute_masks is not None:
            if permute_masks == random_geom_splits:
                data = permute_masks(data, dataset.num_classes, seed=seed, name=kwargs['name'])
            elif permute_masks == random_splits:
                data = permute_masks(data, seed=seed)
            else:
                data = permute_masks(data, dataset.num_classes)

        data = data.to(device)

        model.to(device).reset_parameters()
        optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

        # Synchronize so pending GPU work does not leak into the timing.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        t_start = time.perf_counter()
        best_val_acc = 0
        test_acc = 0
        val_acc_history = []

        for epoch in range(1, epochs + 1):
            train(model, optimizer, data)
            eval_info = evaluate(model, data)
            eval_info['epoch'] = epoch

            if logger is not None:
                logger(eval_info)

            # Model selection: report the test accuracy at the best val epoch.
            if eval_info['val_acc'] > best_val_acc:
                best_val_acc = eval_info['val_acc']
                test_acc = eval_info['test_acc']
                torch.save(model.conv1.save_alpha, 'bestmodel')

            val_acc_history.append(eval_info['val_acc'])
            # In the second half of training, stop once the current val
            # accuracy drops below the mean of the previous
            # `early_stopping` epochs.
            if early_stopping > 0 and epoch > epochs // 2:
                recent = tensor(val_acc_history[-(early_stopping + 1):-1])
                if eval_info['val_acc'] < recent.mean().item():
                    break

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t_end = time.perf_counter()

        val_accs.append(best_val_acc)
        accs.append(test_acc)
        durations.append(t_end - t_start)

    val_acc, acc, duration = tensor(val_accs), tensor(accs), tensor(durations)

    print('Val Acc: {:.2f}, Test Accuracy: {:.2f} ± {:.2f}, Duration: {:.3f}'.
          format(100 * val_acc.mean().item(),
                 100 * acc.mean().item(),
                 100 * acc.std().item(),
                 duration.mean().item()))


def train(model, optimizer, data):
    """Perform one full-batch optimisation step on the training nodes.

    Args:
        model: Module called as ``model(data)``; its output is fed to
            ``F.nll_loss``, so it is assumed to return log-probabilities.
        optimizer: Optimizer over the model's parameters.
        data: Graph object exposing ``y`` and ``train_mask``.
    """
    model.train()
    optimizer.zero_grad()

    predictions = model(data)
    train_loss = F.nll_loss(predictions[data.train_mask], data.y[data.train_mask])

    train_loss.backward()
    optimizer.step()


def evaluate(model, data):
    """Compute loss and accuracy on the train, val and test nodes.

    The model is switched to eval mode and run once without gradients.

    Args:
        model: Module called as ``model(data)``; its output is fed to
            ``F.nll_loss``, so it is assumed to return log-probabilities.
        data: Graph object exposing ``y`` and supporting item access for
            ``'train_mask'``, ``'val_mask'`` and ``'test_mask'``.

    Returns:
        Dict with ``'<split>_loss'`` and ``'<split>_acc'`` (Python floats)
        for each of the three splits.
    """
    model.eval()

    with torch.no_grad():
        logits = model(data)

    metrics = {}
    for split in ['train', 'val', 'test']:
        split_mask = data['{}_mask'.format(split)]
        split_logits = logits[split_mask]
        split_labels = data.y[split_mask]

        # Predicted class is the index of the largest score in each row.
        predicted = split_logits.max(1)[1]
        num_correct = predicted.eq(split_labels).sum().item()

        metrics['{}_loss'.format(split)] = F.nll_loss(split_logits, split_labels).item()
        metrics['{}_acc'.format(split)] = num_correct / split_mask.sum().item()

    return metrics
