import torch as T
import torch.nn as nn
import numpy as np
import networkx as nx 
from bsgnn.utils.data import load_picked_data
from bsgnn.samplers import BallSampler, AdjSampler, AdjGraph, MultiAdjSampler
from bsgnn.data import InMemory, process_batch
from bsgnn.models.kmpgnn_batch import MPGNN
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import pickle as pkl
import seaborn as sns
from matplotlib import pyplot as plt

# Per-dataset constants, unpacked by the script entry point as ``idim, odim, n``:
#   idim -- input feature dimension fed to the model (first layer's input size)
#   odim -- output dimension of the model (number of classes for CrossEntropyLoss)
#   n    -- presumably the number of graphs in the dataset; unused in this file
datainfo = dict(
    REDDITBINARY=(566, 2, 2000),
    COLLAB=(367, 3, 5000),
    REDDITMULTI5K=(734, 5, 4999),
)

def bsgnn_v1(idim, odim, dropout=0):
    """Build the baseline four-layer ``MPGNN`` configuration.

    Message-passing layers are given as ``(in, 32, 32)`` triples: the first
    one takes ``idim`` inputs, the remaining three take 32. Five readout
    triples ``(in, 32, odim)`` are passed as ``mlp_layers``. The first reads
    ``idim`` inputs and the other four read 32. Presumably that is one readout
    on the raw features plus one per message-passing layer; confirm against
    ``MPGNN``.

    Args:
        idim: input feature dimension.
        odim: output dimension of every readout MLP.
        dropout: dropout rate forwarded to ``MPGNN``.

    Returns:
        A freshly constructed, untrained ``MPGNN``.
    """
    hidden = 32
    depth = 4
    mp_inputs = [idim] + [hidden] * (depth - 1)
    readout_inputs = [idim] + [hidden] * depth
    return MPGNN(
        layers=[(d_in, hidden, hidden) for d_in in mp_inputs],
        mlp_layers=[(d_in, hidden, odim) for d_in in readout_inputs],
        dropout=dropout,
    )


def _transpose_batch(batch):
    """Collate a list of ``(adj, feat, label)`` samples into three tuples.

    A named function rather than an inline ``lambda`` so the DataLoader stays
    picklable if ``num_workers`` is ever raised above 0.
    """
    return tuple(zip(*batch))


def _forward_batch(model, data_batched, device):
    """Run ``model`` on one collated batch.

    Returns:
        ``(output, labels)``. ``output`` holds the model's per-class scores.
        ``labels`` is a ``long`` tensor on ``device``.
    """
    adjs, feats, labels = data_batched
    labels = T.Tensor(labels).long().to(device)
    adjs, feats, graph_sizes = process_batch(adjs, feats)
    output = model(adjs.to(device), feats.to(device), graph_sizes)
    return output, labels


def _count_correct(output, labels):
    """Return how many row-wise argmax predictions of ``output`` equal ``labels``."""
    _, predicted = T.max(output, 1)
    return (predicted == labels).sum().item()


def train_one_fold(model, data, train_idx, test_idx, *, epochs=200,
                   batch_size=32, lr=0.01, weight_decay=1e-3, device=None,
                   log_every=10):
    """Train ``model`` on one cross-validation fold and track per-epoch accuracy.

    Args:
        model: module called as ``model(adjs, feats, graph_sizes)``. It must
            already live on the target device.
        data: indexable collection of ``(adj, feat, label)`` samples.
        train_idx: indices into ``data`` used for training.
        test_idx: indices into ``data`` used for evaluation.
        epochs: number of training epochs.
        batch_size: mini-batch size for both loaders.
        lr: initial Adam learning rate. It is halved every 50 epochs.
        weight_decay: Adam weight decay.
        device: device that batches are moved to. Defaults to the device of
            the model's first parameter.
        log_every: print progress every ``log_every`` epochs; 0 disables it.

    Returns:
        ``(model, train_by_epoch, test_by_epoch)``. The two lists each hold one
        accuracy in ``[0, 1]`` per epoch. Train accuracy is measured on the fly
        during training, so dropout is active when it is computed.
    """
    if device is None:
        device = next(model.parameters()).device

    train_data = [data[i] for i in train_idx]
    test_data = [data[i] for i in test_idx]
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                              num_workers=0, collate_fn=_transpose_batch)
    # Accuracy does not depend on evaluation order, so skip shuffling.
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False,
                             num_workers=0, collate_fn=_transpose_batch)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # step_size counts epochs: scheduler.step() is called once per epoch below.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

    # Guard against division by zero for an empty split.
    n_train = max(len(train_data), 1)
    n_test = max(len(test_data), 1)
    train_by_epoch = []
    test_by_epoch = []

    for epoch in range(epochs):
        model.train()
        corrects = 0
        for data_batched in train_loader:
            optimizer.zero_grad()
            output, labels = _forward_batch(model, data_batched, device)
            corrects += _count_correct(output, labels)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
        scheduler.step()
        train_by_epoch.append(corrects / n_train)

        model.eval()
        corrects = 0
        # No autograd graph is needed for evaluation.
        with T.no_grad():
            for data_batched in test_loader:
                output, labels = _forward_batch(model, data_batched, device)
                corrects += _count_correct(output, labels)
        test_by_epoch.append(corrects / n_test)

        if log_every and (epoch + 1) % log_every == 0:
            print(epoch + 1, train_by_epoch[-1], test_by_epoch[-1])

    return model, train_by_epoch, test_by_epoch

if __name__ == "__main__":
    dname = "COLLAB"
    # Values identifying which pre-built dataset file to load; here they only
    # appear in its filename.
    r = 2
    b = 5
    k = 3
    # NOTE(review): pickle.load can execute arbitrary code. Only load dataset
    # and fold files produced by a trusted preprocessing step.
    with open("../datasets/{}_r{}_b{}_k{}.dataset".format(dname, r, b, k), "rb") as f:
        sampled_data = pkl.load(f)
    # Kept at module scope: train_one_fold may read `device` as a global.
    device = T.device("cuda" if T.cuda.is_available() else "cpu")
    idim, odim, n = datainfo[dname]
    with open("../datasets/{}.folds".format(dname), "rb") as f:
        fold_ids = pkl.load(f)

    final_test_accs = []
    for i in range(10):
        print("Fold {}".format(i))
        model = bsgnn_v1(idim, odim, dropout=0.5)
        model.to(device)
        # fold_ids[i] is unpacked positionally into (train_idx, test_idx).
        _, train_accs, test_accs = train_one_fold(model, sampled_data, *fold_ids[i])
        if test_accs:
            final_test_accs.append(test_accs[-1])
            print("Fold {} final test accuracy: {:.4f}".format(i, test_accs[-1]))

    # Summarise the cross-validation run instead of discarding fold results.
    if final_test_accs:
        print("Mean test accuracy over {} folds: {:.4f} +/- {:.4f}".format(
            len(final_test_accs), np.mean(final_test_accs), np.std(final_test_accs)))
