import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import random
# from torch_geometric.loader import DataLoader
from torch_scatter import scatter
from torch_geometric.utils import to_undirected, remove_isolated_nodes
torch.autograd.set_detect_anomaly(True)
from logger import Logger
from dataset import load_dataset,Large_Dataset
from data_utils import evaluate, eval_acc,to_sparse_tensor, eval_rocauc, sample_neighborhood,load_fixed_splits, sample_neg_neighborhood
from encoders import LINK, GCN, MLP, SGC, GAT, SGCMem, MultiLP, MixHop, GCNJK, GATJK, H2GCN, APPNP_Net, LINK_Concat, LINKX, GPRGNN, GCNII
import faulthandler

faulthandler.enable()

from models import DSSL,LogisticRegression
from os import path
# Data directory: "<parent of this file's directory>/data/".
# (No trailing line-continuation here: a dangling backslash would splice this
# assignment onto whatever line happens to follow it.)
DATAPATH = path.dirname(path.dirname(path.abspath(__file__))) + '/data/'

### Parse args ###
def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including 'False') becomes True. This converter maps the
    usual spellings explicitly and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, matching what bool('') produced before.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=0, help='gpu id')
parser.add_argument('--runs', type=int, default=1)
parser.add_argument('--dataset', type=str, default='Cora')
parser.add_argument('--sub_dataset', type=str, default='DE')
parser.add_argument('--epochs', type=int, default=200)
# --lr is an integer preset code, translated right after parsing:
# 0 -> 0.001, 1 -> 0.01; any other integer is used as the learning rate as-is.
parser.add_argument('--lr', type=int, default=0,
                    help='learning-rate preset: 0 -> 0.001, 1 -> 0.01')
parser.add_argument('--weight_decay', type=float, default=1e-3)
parser.add_argument('--hidden_channels', type=int, default=64)
parser.add_argument('--num_layers', type=int, default=2)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--directed', action='store_true', help='set to not symmetrize adjacency')
parser.add_argument('--normalize_features', type=_str2bool, default=True)
parser.add_argument('--seed', type=int, default=0, help='Random seed.')
parser.add_argument('--display_step', type=int, default=25, help='how often to print')
parser.add_argument('--train_prop', type=float, default=.48, help='training label proportion')
parser.add_argument('--valid_prop', type=float, default=.32, help='validation label proportion')
parser.add_argument('--batch_size', type=int, default=1024, help="batch size")
parser.add_argument('--rand_split', type=_str2bool, default=True, help='use random splits')
parser.add_argument('--embedding_dim', type=int, default=10, help="embedding dim")
parser.add_argument('--neighbor_max', type=int, default=5, help="neighbor num max")
parser.add_argument('--cluster_num', type=int, default=6, help="cluster num")
parser.add_argument('--no_bn', action='store_true', help='do not use batchnorm')
parser.add_argument('--alpha', type=float, default=1)
parser.add_argument('--gamma', type=float, default=0.1)
parser.add_argument('--entropy', type=float, default=0.0)
parser.add_argument('--tau', type=float, default=0.99)
parser.add_argument('--encoder', type=str, default='MLP')
parser.add_argument('--mlp_bool', type=int, default=1, help="embedding with mlp predictor")
parser.add_argument('--tao', type=float, default=1)
parser.add_argument('--beta', type=float, default=1)
parser.add_argument('--mlp_inference_bool', type=int, default=1, help="embedding with mlp predictor")
parser.add_argument('--neg_alpha', type=int, default=0, help="negative alpha ")
parser.add_argument('--load_json', type=int, default=0, help="load json")

args = parser.parse_args()
print(args)
class Bunch(object):
    """Attribute-style wrapper: every entry of the given mapping becomes an instance attribute."""

    def __init__(self, adict):
        # Copy the entries straight into the instance namespace.
        vars(self).update(adict)


# --lr arrives as an integer preset code: 0 selects 0.001 and 1 selects 0.01;
# any other value is left untouched.
args.lr = {0: 0.001, 1: 0.01}.get(args.lr, args.lr)


def extract_args_from_json(json_file_path,args_dict):
    """Overlay settings stored in ``json/<json_file_path>.json`` onto ``args_dict``.

    The path is relative to the current working directory. When the file
    exists, each of its top-level keys is written into ``args_dict``
    (overwriting existing entries); when it does not, nothing is changed.

    Returns the same ``args_dict`` object that was passed in (mutated in place).
    """
    import json
    import os

    config_path = 'json/' + json_file_path + '.json'
    # Missing file: hand the dictionary back untouched.
    if not os.path.isfile(config_path):
        return args_dict

    with open(config_path) as handle:
        overrides = json.load(handle)
    for name, value in overrides.items():
        args_dict[name] = value
    return args_dict

print(args.dataset)
args_dict = vars(args)
if args.load_json:
    # Overlay values from json/<dataset>.json (when that file exists), then
    # re-wrap the dictionary so settings stay reachable as attributes.
    args_dict = extract_args_from_json(args.dataset, args_dict)
    args = Bunch(args_dict)

### Seeds ###
# Seed the Python, NumPy and torch generators with the same value.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(args.seed)

### device ###
# Use the requested GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu')

### Load and preprocess data ###
dataset = load_dataset(args.dataset, args.sub_dataset)

# Keep labels 2-D: a 1-D label vector gets a trailing axis of size 1.
if len(dataset.label.shape) == 1:
    dataset.label = dataset.label.unsqueeze(1)

if args.dataset in ['snap-patents', 'CiteSeer','genius']:
    # Pass the real node count so the returned mask has exactly one entry per
    # feature/label row. Without it the count is inferred from the largest
    # index in edge_index, which yields a too-short mask whenever the
    # highest-numbered nodes are isolated.
    dataset.graph['edge_index'], edge_attr, mask = remove_isolated_nodes(
        dataset.graph['edge_index'],
        num_nodes=dataset.graph['node_feat'].shape[0])
    dataset.graph['node_feat'] = dataset.graph['node_feat'][mask]
    dataset.label = dataset.label[mask]

# One split per run: random splits when requested (or always for the listed
# datasets), otherwise the fixed splits loaded via load_fixed_splits.
if args.rand_split or args.dataset in ['snap-patents','ogbn-proteins', 'wiki','Cora', 'PubMed','genius']:
    split_idx_lst = [dataset.get_idx_split(train_prop=args.train_prop, valid_prop=args.valid_prop)
                for _ in range(args.runs)]
else:
    split_idx_lst = load_fixed_splits(args.dataset, args.sub_dataset)



dataset.graph['num_nodes']=dataset.graph['node_feat'].shape[0]

n = dataset.graph['num_nodes']
# infer the number of classes for non one-hot and one-hot labels
c = max(dataset.label.max().item() + 1, dataset.label.shape[1])
d = dataset.graph['node_feat'].shape[1]

# Symmetrize the adjacency unless --directed was given (never for ogbn-proteins).
if not args.directed and args.dataset != 'ogbn-proteins':
    dataset.graph['edge_index'] = to_undirected(dataset.graph['edge_index'])


dataset.graph['edge_index'], dataset.graph['node_feat'] = \
    dataset.graph['edge_index'].to(
        device), dataset.graph['node_feat'].to(device)

dataset.label = dataset.label.to(device)


sampled_neighborhoods = sample_neighborhood(dataset, device, args)

# Negative neighborhoods are only sampled when --neg_alpha is non-zero.
if args.neg_alpha:
    sampled_neg_neighborhoods = sample_neg_neighborhood(dataset, device, args)
    print('sample_neg_neighborhoods')

def _probe_accuracy(classifier, features, labels):
    """Return the classifier's accuracy on (features, labels) in percent, as a NumPy scalar."""
    logits, _ = classifier(features, labels)
    preds = torch.argmax(logits, dim=1)
    return (torch.sum(preds == labels).float() / labels.shape[0]).detach().cpu().numpy() * 100


def evaluate(model,dataset,split_idx):
    """Score the online encoder's embeddings with a logistic-regression probe.

    The embeddings are computed once (model in eval mode, no gradient
    tracking). Ten independent ``LogisticRegression`` probes are then each
    trained for 100 full-batch AdamW steps on the training split and scored on
    all three splits.

    NOTE: this definition shadows the ``evaluate`` imported from ``data_utils``
    at the top of the file.

    Args:
        model: object exposing ``online_encoder(dataset)`` and ``eval()``.
        dataset: object with a ``label`` tensor; ``label`` is replaced by its
            int64 version on the module-level ``device`` as a side effect.
        split_idx: mapping with 'train', 'valid' and 'test' index tensors.

    Returns:
        (train_acc, dev_acc, test_acc): mean accuracies in percent over the
        ten probes.
    """
    model.eval()
    # Inference only: no autograd graph is needed for the encoder pass.
    with torch.no_grad():
        embedding = model.online_encoder(dataset)
    embedding = embedding.detach()
    emb_dim, num_class = embedding.shape[1], dataset.label.unique().shape[0]

    # Everything below is identical for every probe, so build it once.
    dataset.label = dataset.label.to(device=device, dtype=torch.long)
    features, labels = {}, {}
    for split_name in ('train', 'valid', 'test'):
        idx = np.array(split_idx[split_name].cpu())
        features[split_name] = embedding[idx, :]
        labels[split_name] = dataset.label[idx].squeeze()

    train_accs, dev_accs, test_accs = [], [], []
    for _ in range(10):
        classifier = LogisticRegression(emb_dim, num_class).to(device)
        optimizer_LR = torch.optim.AdamW(classifier.parameters(), lr=0.01, weight_decay=0.01)

        classifier.train()
        for _step in range(100):
            _, loss = classifier(features['train'], labels['train'])
            optimizer_LR.zero_grad()
            loss.backward()
            optimizer_LR.step()

        # Scoring needs predictions only, so skip gradient bookkeeping.
        with torch.no_grad():
            train_accs.append(_probe_accuracy(classifier, features['train'], labels['train']))
            dev_accs.append(_probe_accuracy(classifier, features['valid'], labels['valid']))
            test_accs.append(_probe_accuracy(classifier, features['test'], labels['test']))

    train_acc = np.stack(train_accs).mean()
    dev_acc = np.stack(dev_accs).mean()
    test_acc = np.stack(test_accs).mean()

    return train_acc, dev_acc, test_acc


### Choose encoder ###

# Constructor arguments shared by both supported encoders.
encoder_kwargs = dict(in_channels=d,
                      hidden_channels=args.hidden_channels,
                      out_channels=args.hidden_channels,
                      num_layers=args.num_layers,
                      dropout=args.dropout)

# 'GCN' selects the GCN encoder; every other value of --encoder selects the MLP.
if args.encoder == 'GCN':
    encoder = GCN(use_bn=not args.no_bn, **encoder_kwargs)
else:
    encoder = MLP(**encoder_kwargs)
encoder = encoder.to(device)

# Build the DSSL model around the chosen encoder.
# Fixed: `beta` previously read from the undefined name `arg` (NameError at
# start-up); it now reads from the parsed `args`.
model = DSSL(encoder=encoder,
             hidden_channels=args.hidden_channels,
             dataset=dataset,
             device=device,
             cluster_num=args.cluster_num,
             alpha=args.alpha,
             gamma=args.gamma,
             tao=args.tao,
             beta=args.beta,
             moving_average_decay=args.tau).to(device)

# A falsy --mlp_bool / --mlp_inference_bool clears the matching model flag
# (0 = embedding without mlp predictor).
for flag_value, model_attr in ((args.mlp_bool, 'Embedding_mlp'),
                               (args.mlp_inference_bool, 'inference_mlp')):
    if not flag_value:
        setattr(model, model_attr, False)

### logger ###
logger = Logger(args.runs, args)

model.train()
print('MODEL:', model)

# print (split_idx_lst)
import datetime

# Wall-clock reference used for the elapsed-time printouts during training.
time_now = datetime.datetime.now()
print('start training')
print(time_now)

# Running sum of each run's best test accuracy.
meanAcc = 0


def _gather_rows(embedding, neighbor_index):
    """Concatenate ``embedding[i, :]`` for each entry ``i`` of ``neighbor_index`` along a new leading axis."""
    rows = [embedding[i, :].unsqueeze(0) for i in neighbor_index]
    return torch.cat(rows, dim=0).to(device)


### Training loop ###
for run in range(args.runs):
    split_idx = split_idx_lst[run]
    model.reset_parameters()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    best_val = float('-inf')
    # Bound before the epoch loop so the accumulation below cannot hit an
    # undefined name when no evaluation improves on best_val (e.g. --epochs 0).
    besttest_acc = 0.0
    loss_lst = []
    best_loss = float('inf')

    for epoch in range(args.epochs):
        # pre-training
        model.train()
        batch_size = args.batch_size
        perm = torch.randperm(n)
        epoch_loss = 0.0
        for batch in range(0, n, batch_size):
            optimizer.zero_grad()
            # Parameters change after every step, so both encoders are re-run per batch.
            online_embedding = model.online_encoder(dataset)
            target_embedding = model.target_encoder(dataset)
            batch_idx = perm[batch:batch + batch_size].to(device)
            batch_embedding = F.normalize(online_embedding[batch_idx].to(device), dim=-1, p=2)
            batch_neighbor_embedding = F.normalize(
                _gather_rows(target_embedding, sampled_neighborhoods[batch_idx]), dim=-1, p=2)
            main_loss, context_loss, entropy_loss, k_node = model(batch_embedding, batch_neighbor_embedding)
            loss = main_loss + args.gamma * (context_loss + entropy_loss)
            if args.neg_alpha:
                # Second forward pass against the sampled negative neighborhoods;
                # only its main loss is added to the objective.
                batch_neg_neighbor_embedding = F.normalize(
                    _gather_rows(target_embedding, sampled_neg_neighborhoods[batch_idx]), dim=-1, p=2)
                main_neg_loss, _, _, _ = model(batch_embedding, batch_neg_neighbor_embedding)
                loss = loss + main_neg_loss
            print("run : {}, batch : {}, main_loss: {}, context_loss: {}, entropy_loss: {}".format(run,batch,main_loss, context_loss, entropy_loss))
            loss.backward()
            optimizer.step()
            model.update_moving_average()
            # Accumulate a plain float: adding the loss tensor itself would keep
            # every batch's autograd history referenced for the whole epoch.
            epoch_loss += loss.item()

        # Cluster refresh, performed after every epoch.
        model.eval()
        with torch.no_grad():
            # In eval mode the parameters are fixed for this whole pass, so the
            # encoders are evaluated once instead of once per batch. The
            # embeddings stay on the same device as the index tensors.
            online_embedding = model.online_encoder(dataset)
            target_embedding = model.target_encoder(dataset)
            cluster, batch_sum = None, None
            for batch in range(0, n, batch_size):
                batch_idx = perm[batch:batch + batch_size].to(device)
                batch_embedding = online_embedding[batch_idx].to(device)
                batch_neighbor_embedding = _gather_rows(target_embedding, sampled_neighborhoods[batch_idx])
                _, _, _, k_node = model(batch_embedding, batch_neighbor_embedding)
                # One-hot of each row's argmax over k_node.
                assignment = F.one_hot(torch.argmax(k_node, dim=1), num_classes=args.cluster_num).to(
                    device=device, dtype=torch.float)
                batch_cluster = torch.matmul(batch_embedding.t(), assignment)
                batch_count = torch.reshape(torch.sum(assignment, 0), (-1, 1))
                if cluster is None:
                    cluster, batch_sum = batch_cluster, batch_count
                else:
                    cluster += batch_cluster
                    batch_sum += batch_count
        cluster = F.normalize(cluster, dim=-1, p=2)
        model.update_cluster(cluster,batch_sum)
        print("epoch: {}, loss: {}".format(epoch, epoch_loss))

        if epoch %10 ==0:
            time_epoch = datetime.datetime.now()
            print('cost time')
            timetmp = time_epoch-time_now
            print(timetmp)
            train_acc, dev_acc, test_acc=evaluate(model,dataset,split_idx)
            print("run: {}, train acc: {}, val acc: {}, test acc: {}".format(run, train_acc,dev_acc,test_acc))

            # Model selection on validation accuracy; report the matching test accuracy.
            if best_val < dev_acc:
                best_val = dev_acc
                besttest_acc = test_acc

    meanAcc +=besttest_acc


# Average of the per-run best test accuracies.
besttest_acc = meanAcc/args.runs
import os

# The output directory gets its own name so it no longer rebinds `path`,
# which `from os import path` defines at the top of the file.
result_dir = 'result/'+ args.dataset + '/'
# makedirs also creates the parent 'result/' directory when it is missing and
# does not fail when the directory already exists.
os.makedirs(result_dir, exist_ok=True)
# The result is encoded in the file name (accuracy plus hyper-parameters); the
# file content itself is only a placeholder.
checkpt_file = result_dir + '%.4fresult_' % (
    besttest_acc) +  '_encoder_%s' % args.encoder + '_cluster_num_%s' % args.cluster_num + '_neighbor_max_%s' % args.neighbor_max +  '_neg_alpha_%s' % args.neg_alpha +  '_mlp_inference_bool_%s' % args.mlp_inference_bool +  '_tao_%s' % args.tao+ '_beta_%s' % args.beta + '_lr_%s' % args.lr + '_gamma%s_mlp%s_epoch%s_weight_decay%s_hidden%s_dropout%s' % (
                   args.gamma, args.mlp_bool, args.epochs, args.weight_decay, args.hidden_channels,
                   args.dropout) + '.txt'
np.savetxt(checkpt_file, np.zeros(1), delimiter=' ')