import argparse
import math
import os
import shutil

import numpy as np
import torch
import torch.nn.functional as F
from scipy.sparse import csr_matrix
from tensorboardX import SummaryWriter
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch_geometric.transforms as T
import wandb
from dataset import MyNodePropPredDataset
from logger import Logger
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
from optim_schedule import ScheduledOptim
from torch_geometric.utils import add_remaining_self_loops, to_undirected
from transformer import TransformerModel


class NodeClassificationDataset(torch.utils.data.Dataset):
    """Index-only dataset over precomputed ego graphs.

    ``__getitem__`` returns the sample position unchanged; the actual tensors
    are gathered once per batch by the collate function built in ``batcher``,
    which reads the attributes stored here.

    Args:
        x: Node feature matrix, indexed by node id.
        y: Node label tensor, indexed by node id.
        ego_graphs: One row of node ids per sample; ``-1`` marks padding.
        pe: Per-node positional-embedding table indexed by node id, or ``None``.
        args: Parsed command-line arguments (read by the collate function).
        num_classes: Number of target classes.
        adj: Optional attention-mask tensor, or ``None``.
        y_partial: Optional per-node label features (label trick), or ``None``.
    """

    def __init__(self, x, y, ego_graphs, pe, args, num_classes, adj=None, y_partial=None):
        super().__init__()
        self.x = x
        self.y = y
        self.ego_graphs = ego_graphs
        self.pe = pe
        self.args = args
        self.num_classes = num_classes
        self.adj = adj
        self.y_partial = y_partial

    def __len__(self):
        """Number of samples, i.e. rows of ``ego_graphs``."""
        return len(self.ego_graphs)

    def __getitem__(self, idx):
        # Deliberately returns only the index: batching/gathering is done in
        # the collate function so that tensor indexing happens per batch.
        return idx

def batcher(dataset):
    """Build the ``collate_fn`` that assembles a batch from sample indices.

    ``NodeClassificationDataset.__getitem__`` returns only the sample position,
    so every tensor is gathered here, once per batch.

    Args:
        dataset: A ``NodeClassificationDataset``. Each row of
            ``dataset.ego_graphs`` lists node ids with ``-1`` as padding; the
            first column is used as the id of the row's centre node.

    Returns:
        A function mapping a list of sample indices to
        ``[src, src_padding, batch_idx, y]`` followed, in this order, by
        ``src_mask`` (only if ``dataset.adj`` is set), ``pe_batch`` (only if
        ``dataset.pe`` is set) and ``src_label`` (only if
        ``dataset.y_partial`` is set).
    """
    def batcher_dev(idx):
        idx = torch.LongTensor(idx)
        # Indexing with a LongTensor copies, so the in-place padding fill below
        # leaves dataset.ego_graphs untouched.
        src = dataset.ego_graphs[idx]
        # Centre-node ids. These index full-graph tensors (dataset.y,
        # dataset.adj), whereas idx is only a position inside this (possibly
        # split) dataset.
        batch_idx = src[:, 0]
        src_padding = (src == -1)
        # Map padding to node 0 (an arbitrary valid id); padded slots remain
        # identifiable through src_padding.
        src[src_padding] = 0
        shape = src.shape

        # main() passes the full, unsplit adj array while ego_graphs is split,
        # so masks must be gathered by centre-node id rather than by idx.
        # Each mask is repeated once per attention head.
        src_mask = [torch.repeat_interleave(dataset.adj[batch_idx], dataset.args.num_heads, dim=0)] if dataset.adj is not None else []
        pe_batch = [dataset.pe[src.view(-1)].view(shape[0], shape[1], -1)] if dataset.pe is not None else []
        if dataset.pe is not None and dataset.args.isolated_pe > 0:
            # Zero the centre node's positional encoding in ego graphs that
            # contain at most `isolated_pe` real (non-padding) nodes.
            pe_batch[0][torch.sum(~src_padding, dim=-1) <= dataset.args.isolated_pe, 0, :] = 0
        src_label = [dataset.y_partial[src.view(-1)].view(shape[0], shape[1], -1)] if dataset.y_partial is not None else []
        if dataset.args.use_label == 1:
            # Hide the centre node's own label features from the model.
            src_label[0][:, 0, :] = 0
        src = dataset.x[src.view(-1)].view(shape[0], shape[1], -1)
        y = dataset.y.squeeze(1)[batch_idx].long()

        return [src, src_padding, batch_idx, y] + src_mask + pe_batch + src_label

    return batcher_dev

def cross_entropy(x, labels):
    """Label-trick loss: per-sample cross-entropy passed through ``log(0.5 + ce) - log(0.5)``.

    The transform is 0 when a sample's loss is 0 and grows only
    logarithmically with it, damping the contribution of high-loss samples
    compared with plain cross-entropy.

    Returns:
        Scalar tensor: the mean of the transformed per-sample losses.
    """
    per_sample = F.cross_entropy(x, labels, reduction="none")
    damped = torch.log(0.5 + per_sample) - math.log(0.5)
    return damped.mean()

def train(model, loader, device, optimizer, args):
    """Run one training epoch over ``loader`` and return the mean batch loss.

    Batches follow the layout produced by ``batcher``: ``src, src_padding,
    batch_idx, y`` first, then the optional attention mask, positional
    encoding and partial-label tensors. Those are picked out here according
    to ``args.mask``, ``args.pe_type`` and ``args.use_label``.

    When ``args.label_trick == 1`` the log-damped ``cross_entropy`` from this
    module is the loss; otherwise plain ``F.cross_entropy`` is used.
    """
    model.train()
    criterion = cross_entropy if args.label_trick == 1 else F.cross_entropy

    loss_sum = 0
    for raw in tqdm(loader, desc="Iteration"):
        optimizer.zero_grad()

        tensors = [t.to(device) for t in raw]
        src, src_padding, _, y = tensors[:4]
        # Optional inputs sit after the four fixed tensors; args.mask (0/1)
        # shifts the position of the positional encoding.
        optional = {
            'src_mask': tensors[4] if args.mask else None,
            'pe': tensors[4 + args.mask] if args.pe_type else None,
            'src_label': tensors[-1] if args.use_label else None,
        }

        logits = model(src, padding=src_padding, **optional)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()

    return loss_sum / len(loader)

@torch.no_grad()
def test(model, loader, device, evaluator, args, datay, num_nodes, split_idx, eval_type='valid'):
    """Predict with ``model`` over ``loader`` and score the predictions with ``evaluator``.

    Predicted classes are written into a ``(num_nodes, 1)`` buffer at the
    centre-node ids (``batch_idx``) of each batch. Nodes that ``loader`` never
    visits keep the placeholder class 0.

    Args:
        datay: ``(y_train, y_valid, y_test)`` ground-truth label tensors.
        split_idx: Mapping with ``'train'``, ``'valid'`` and ``'test'`` node ids.
        eval_type: ``'valid'`` to score the validation split only; any other
            value scores the train and test splits.

    Returns:
        Validation accuracy when ``eval_type == 'valid'``, otherwise the tuple
        ``(train_acc, test_acc)``.
    """
    model.eval()
    y_pred = torch.zeros(num_nodes, 1).long().to(device)

    for raw in tqdm(loader, desc="Iteration"):
        tensors = [t.to(device) for t in raw]
        src, src_padding, batch_idx, _ = tensors[:4]
        logits = model(
            src,
            src_mask=tensors[4] if args.mask else None,
            padding=src_padding,
            pe=tensors[4 + args.mask] if args.pe_type else None,
            src_label=tensors[-1] if args.use_label else None,
        )
        y_pred[batch_idx] = logits.argmax(dim=-1, keepdim=True)

    y_train, y_valid, y_test = datay

    def score(split, y_true):
        # Accuracy of the buffered predictions restricted to one split.
        return evaluator.eval({
            'y_true': y_true,
            'y_pred': y_pred[split_idx[split]],
        })['acc']

    if eval_type == 'valid':
        return score('valid', y_valid)
    return score('train', y_train), score('test', y_test)

def get_exp_name(dataset, para_dic, input_exp_name):
    """Compose the experiment name and clear any stale log directory for it.

    The name is ``<dataset>_<key><value>_..._<input_exp_name>``, with the
    entries of ``para_dic`` in iteration order.

    Side effect: if ``runs/<name>`` already exists it is deleted recursively,
    so a rerun starts from an empty log directory.
    """
    tokens = [dataset]
    tokens.extend(key + str(value) for key, value in para_dic.items())
    exp_name = '_'.join(tokens) + '_' + input_exp_name

    log_dir = 'runs/' + exp_name
    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)
    return exp_name

def get_lr(optimizer):
    """Return the learning rate of the first parameter group, or ``None`` if there are no groups."""
    return next((group['lr'] for group in optimizer.param_groups), None)

def main():
    """Train the ego-graph transformer, or only evaluate it when ``--load_path`` is given.

    Steps:
    1. Parse CLI arguments and seed numpy/torch.
    2. Derive an experiment name, used for TensorBoard, wandb and the checkpoint file.
    3. Load the OGB dataset, the precomputed padded ego graphs, and the optional
       positional embeddings / attention masks.
    4. Build the data loaders and the model.
    5. Run ``args.runs`` training runs with periodic validation, best-checkpoint
       saving and early stopping.
    """
    parser = argparse.ArgumentParser(description='OGBN (GNN)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--dataset', type=str, default='arxiv')
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--num_layers', type=int, default=4)
    parser.add_argument('--num_heads', type=int, default=2)
    parser.add_argument('--ego_size', type=int, default=64)
    parser.add_argument('--hidden_size', type=int, default=64)
    parser.add_argument('--input_dropout', type=float, default=0.2)
    parser.add_argument('--hidden_dropout', type=float, default=0.4)
    parser.add_argument('--label_dropout', type=float, default=0.5)
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lr_scale', type=float, default=1.0)
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--early_stopping', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--eval_batch_size', type=int, default=2048)
    parser.add_argument('--layer_norm', type=int, default=0)
    parser.add_argument('--use_label', type=int, default=0)
    parser.add_argument('--label_trick', type=int, default=0)
    parser.add_argument('--src_scale', type=int, default=0)
    parser.add_argument('--isolated_pe', type=int, default=0)
    parser.add_argument('--num_workers', type=int, default=4, help='number of workers')
    parser.add_argument('--runs', type=int, default=1, help='only support one run...')
    parser.add_argument('--pe_type', type=int, default=1)
    parser.add_argument('--mask', type=int, default=0)
    parser.add_argument("--optimizer", type=str, default='adamw', choices=['adam', 'adamw'], help="optimizer")
    parser.add_argument('--warmup', type=int, default=10000)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--load_path', type=str, default='')
    parser.add_argument('--exp_name', type=str, default='')
    args = parser.parse_args()
    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Short keys keep the experiment name compact while encoding every
    # hyper-parameter that distinguishes one run from another.
    para_dic = {'nl': args.num_layers, 'nh': args.num_heads, 'es': args.ego_size, 'hs': args.hidden_size,
                'id': args.input_dropout, 'hd': args.hidden_dropout, 'bs': args.batch_size, 'pe': args.pe_type,
                'op': args.optimizer, 'lr': args.lr, 'wd': args.weight_decay, 'ls': args.lr_scale,
                'ln': args.layer_norm, 'ul': args.use_label, 'lt': args.label_trick, 'lp': args.isolated_pe,
                'sc': args.src_scale, 'ld': args.label_dropout, 'sd': args.seed}
    para_dic['warm'] = args.warmup
    para_dic['mask'] = args.mask
    exp_name = get_exp_name(args.dataset, para_dic, args.exp_name)

    # The seed is dropped from the wandb run name (presumably so runs that
    # differ only by seed share one name).
    wandb_name = exp_name.replace('_sd'+str(args.seed), '')
    wandb.init(name=wandb_name, project="xfmr4gl")
    wandb.config.update(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    if args.dataset == 'papers100M':
        dataset = MyNodePropPredDataset(name=args.dataset)
    else:
        dataset = PygNodePropPredDataset(name=f'ogbn-{args.dataset}')

    # Rows are indexed by node id (they are sliced with split_idx below);
    # -1 entries are treated as padding by the collate function.
    ego_graphs = np.load(f'data/{args.dataset}-ego-graphs-padding-{args.ego_size}.npy')
    ego_graphs = torch.LongTensor(ego_graphs)

    pe = None
    if args.pe_type == 1:
        pe = torch.load(f'data/{args.dataset}-embedding-{args.hidden_size}.pt')
    elif args.pe_type == 2:
        pe = np.fromfile("data/paper100m.pro", dtype=np.float32).reshape(-1, 128)
        pe = torch.FloatTensor(pe)
        if args.hidden_size < 128:
            pe = pe[:, :args.hidden_size]

    data = dataset[0]

    adj = None
    if args.mask:
        # Inverted on load. NOTE(review): presumably the file stores True for
        # edges while the attention mask expects True for blocked positions.
        adj = torch.BoolTensor(~np.load(f'data/{args.dataset}-ego-graphs-adj-{args.ego_size}.npy'))

    split_idx = dataset.get_idx_split()
    num_classes = dataset.num_classes

    if args.use_label == 1:
        # One-hot labels for training nodes only; all other rows stay zero.
        onehot = torch.zeros([data.x.shape[0], num_classes])
        idx = split_idx['train']
        onehot[idx, data.y[idx, 0]] = 1
        data.y_partial = onehot
    else:
        data.y_partial = None

    y_train = data.y[split_idx['train']].to(device)
    y_valid = data.y[split_idx['valid']].to(device)
    y_test = data.y[split_idx['test']].to(device)

    y_train, y_valid, y_test = y_train.long(), y_valid.long(), y_test.long()
    datay = (y_train, y_valid, y_test)

    train_dataset = NodeClassificationDataset(data.x, data.y, ego_graphs[split_idx['train']], pe, args, num_classes, adj, data.y_partial)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=batcher(train_dataset), pin_memory=True)

    valid_dataset = NodeClassificationDataset(data.x, data.y, ego_graphs[split_idx['valid']], pe, args, num_classes, adj, data.y_partial)
    valid_loader = DataLoader(valid_dataset, batch_size=args.eval_batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=batcher(valid_dataset), pin_memory=True)

    # The "test" loader covers train + test nodes so a single pass yields
    # both the train and the test accuracy.
    test_dataset = NodeClassificationDataset(data.x, data.y, ego_graphs[torch.cat((split_idx['train'], split_idx['test']))], pe, args, num_classes, adj, data.y_partial)
    test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=batcher(test_dataset), pin_memory=True)

    model = TransformerModel(data.x.size(1), args.hidden_size,
                             args.num_heads, args.hidden_size,
                             args.num_layers, num_classes,
                             args.input_dropout, args.hidden_dropout,
                             layer_norm=args.layer_norm, use_label=args.use_label,
                             ldropout=args.label_dropout, src_scale=args.src_scale).to(device)
    wandb.watch(model, log='all')

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        model = nn.DataParallel(model)
    # Checkpoints are written from, and loaded into, the unwrapped module so
    # their keys never carry DataParallel's "module." prefix.
    base_model = model.module if isinstance(model, nn.DataParallel) else model

    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('model parameters:', pytorch_total_params)

    evaluator = Evaluator(name=f'ogbn-{args.dataset}')
    logger = Logger(args.runs, args)

    os.makedirs('runs', exist_ok=True)
    os.makedirs('saved', exist_ok=True)
    ckpt_path = 'saved/' + exp_name + '.pt'

    try:
        writer = SummaryWriter('runs/' + exp_name)
    except Exception as e:
        # TensorBoard logging is optional; continue without it.
        print(e)
        writer = None

    if args.load_path:
        # Evaluation-only mode. Map tensors onto the selected device (not a
        # fixed GPU) so CPU-only machines and --device != 0 work.
        base_model.load_state_dict(torch.load(args.load_path, map_location=device))

        valid_acc = test(model, valid_loader, device, evaluator, args, datay, data.num_nodes, split_idx, 'valid')
        valid_output = f'Valid: {100 * valid_acc:.2f}% '

        cor_train_acc, cor_test_acc = test(model, test_loader, device, evaluator, args, datay, data.num_nodes, split_idx, 'test')
        train_output = f'Train: {100 * cor_train_acc:.2f}%, '
        test_output = f'Test: {100 * cor_test_acc:.2f}%'

        print(train_output + valid_output + test_output)
        if writer is not None:
            writer.close()
        return

    for run in range(args.runs):
        base_model.init_weights()
        best_val_acc = 0
        cor_train_acc = 0
        cor_test_acc = 0
        patience = 0

        if args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer == 'adamw':
            optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        else:
            raise NotImplementedError
        lr_scheduler = None
        if args.warmup > 0:
            # Hand LR control to ScheduledOptim with `args.warmup` warmup steps.
            optimizer = ScheduledOptim(optimizer, args.hidden_size if args.hidden_size > 0 else data.x.size(1), n_warmup_steps=args.warmup, init_lr_scale=args.lr_scale)
        elif args.warmup < 0:
            # Halve the LR when validation accuracy plateaus; the first 50
            # epochs additionally get a manual linear warmup below.
            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='max', factor=0.5, patience=100, verbose=True, min_lr=args.lr / 10
            )

        for epoch in range(1, 1 + args.epochs):
            if lr_scheduler is not None:
                if epoch <= 50:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = args.lr * epoch / 50

            loss = train(model, train_loader, device, optimizer, args)

            train_output = valid_output = test_output = ''
            # Evaluate from epoch 20 on, every `log_steps` epochs.
            if epoch >= 20 and epoch % args.log_steps == 0:
                valid_acc = test(model, valid_loader, device, evaluator, args, datay, data.num_nodes, split_idx, 'valid')
                if lr_scheduler is not None:
                    lr_scheduler.step(valid_acc)
                valid_output = f'Valid: {100 * valid_acc:.2f}% '

                if valid_acc > best_val_acc:
                    best_val_acc = valid_acc
                    cor_train_acc, cor_test_acc = test(model, test_loader, device, evaluator, args, datay, data.num_nodes, split_idx, 'test')
                    logger.add_result(run, (cor_train_acc, valid_acc, cor_test_acc))
                    train_output = f'Train: {100 * cor_train_acc:.2f}%, '
                    test_output = f'Test: {100 * cor_test_acc:.2f}%'
                    patience = 0
                    try:
                        torch.save(base_model.state_dict(), ckpt_path)
                        wandb.save(ckpt_path)
                    except FileNotFoundError as e:
                        print(e)
                else:
                    # Patience counts evaluations (not epochs) without improvement.
                    patience += 1
                    if patience >= args.early_stopping:
                        print('Early stopping...')
                        break
                if writer is not None:
                    writer.add_scalar('loss', loss, epoch)
                    writer.add_scalar('acc/valid', valid_acc, epoch)
                    writer.add_scalar('acc/best_val', best_val_acc, epoch)
                    writer.add_scalar('acc/cor_train', cor_train_acc, epoch)
                    writer.add_scalar('acc/cor_test', cor_test_acc, epoch)
                    writer.add_scalar('lr', get_lr(optimizer), epoch)
                wandb.log({'Train Loss': loss, 'Valid Acc': valid_acc, 'best_val_acc': best_val_acc,
                           'cor_train_acc': cor_train_acc, 'cor_test_acc': cor_test_acc, 'LR': get_lr(optimizer)})
            else:
                if writer is not None:
                    writer.add_scalar('loss', loss, epoch)
                    writer.add_scalar('lr', get_lr(optimizer), epoch)
                wandb.log({'Train Loss': loss, 'LR': get_lr(optimizer)})
            print(f'Run: {run + 1:02d}, '
                    f'Epoch: {epoch:02d}, '
                    f'Loss: {loss:.4f}, ' +
                    train_output + valid_output + test_output)

        logger.print_statistics(run)
    logger.print_statistics()

    if writer is not None:
        writer.close()


if __name__ == "__main__":
    # Run training/evaluation only when executed as a script, not on import.
    main()
