import argparse
import math
import os
import shutil

import numpy as np
import torch
import torch.nn.functional as F
from scipy.sparse import csr_matrix
from tensorboardX import SummaryWriter
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm

import dgl
import torch_geometric.transforms as T
import wandb
from dataset import MyNodePropPredDataset
from gnn import GNNModel
# from line_profiler import LineProfiler
from logger import Logger
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from optim_schedule import ScheduledOptim
from torch_geometric.utils import add_remaining_self_loops, to_undirected


class NodeClassificationDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding one ego-subgraph sample per target node.

    Item ``i`` is built for the node ``ego_idx[i]``: the node ids stored in
    ``ego_graphs`` for that node are passed to ``graph.subgraph`` and the
    result is returned together with the node's label.
    """

    def __init__(self, dgl_data, ego_graphs, ego_idx):
        """Keep references to the graph, the labels and the ego-graph tables.

        Args:
            dgl_data: Indexable pair whose element 0 is the graph and whose
                element 1 holds the labels (indexed by node id).
            ego_graphs: Container indexed by node id giving the node ids of
                that node's ego graph.
            ego_idx: Node ids covered by this dataset; must support
                ``.size(0)`` and integer indexing.
        """
        self.graph, self.label = dgl_data[0], dgl_data[1]
        self.ego_graphs = ego_graphs
        self.ego_idx = ego_idx

    def __len__(self):
        """Return the number of target nodes in ``ego_idx``."""
        return self.ego_idx.size(0)

    def __getitem__(self, idx):
        """Return ``(node_id, subgraph, label)`` for the ``idx``-th target node."""
        center = self.ego_idx[idx]
        # NOTE: a disabled experiment used to append Laplacian eigenvector
        # features to the subgraph's node features at this point.
        return center, self.graph.subgraph(self.ego_graphs[center]), self.label[center]

def batcher(encoder="gnn"):
    """Build the ``collate_fn`` used by the data loaders.

    Args:
        encoder: Accepted for call-site compatibility; the value is not read
            and the same collate function is returned regardless.

    Returns:
        A function turning a list of ``(idx, graph, label)`` samples into a
        single ``(idx, batched_graph, label)`` triple.
    """

    def batcher_gnn(batch):
        # Transpose the list of samples into three parallel tuples.
        node_ids, subgraphs, labels = zip(*batch)
        return (
            torch.stack(node_ids).long(),
            dgl.batch(subgraphs),
            torch.stack(labels).long(),
        )

    return batcher_gnn

def train(model, loader, device, optimizer, first_token=True):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(graph, first_token=first_token)``.
        loader: Iterable of ``(idx, graph, label)`` batches; must support
            ``len()``.
        device: Device every batch element is moved to.
        optimizer: Object exposing ``zero_grad()`` and ``step()``.
        first_token: Forwarded to the model. When truthy the model output is
            used as the logits; otherwise the model is expected to return
            ``(logits, labels)`` and the returned labels replace the batch
            labels.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.

    Raises:
        ZeroDivisionError: If ``len(loader)`` is 0.
    """
    model.train()

    running_loss = 0
    for sample in tqdm(loader, desc="Iteration"):
        optimizer.zero_grad()

        # The leading node-id entry is moved as well but is not needed here.
        moved = [item.to(device) for item in sample]
        graph, target = moved[1:]

        output = model(graph, first_token=first_token)
        if not first_token:
            output, target = output

        batch_loss = F.cross_entropy(output, target.squeeze(1))
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()

    return running_loss / len(loader)

@torch.no_grad()
def test(model, loader, device, evaluator, y_true, split_idx, eval_type='valid'):
    """Predict labels for the nodes in ``loader`` and score them.

    Args:
        model: Module called as ``model(graph)``; its output is arg-maxed
            over the last dimension.
        loader: Iterable of ``(idx, graph, label)`` batches.
        device: Device every batch element is moved to.
        evaluator: Object whose ``eval({'y_true': ..., 'y_pred': ...})``
            returns a mapping containing ``'acc'``.
        y_true: Ground-truth labels for all nodes; predictions are written
            into a zero-initialised tensor of the same shape.
        split_idx: Mapping with ``'train'``, ``'valid'`` and ``'test'`` node
            ids.
        eval_type: ``'valid'`` scores the validation split only; any other
            value scores the train and test splits.

    Returns:
        The validation accuracy when ``eval_type == 'valid'``, otherwise the
        tuple ``(train_acc, test_acc)``.
    """
    model.eval()

    predictions = torch.zeros_like(y_true).long()
    for sample in tqdm(loader, desc="Iteration"):
        node_ids, graph, _ = [item.to(device) for item in sample]
        logits = model(graph)
        # Nodes that never appear in ``loader`` keep the initial prediction 0.
        predictions[node_ids] = logits.argmax(dim=-1, keepdim=True)

    def split_accuracy(split):
        nodes = split_idx[split]
        scores = evaluator.eval({
            'y_true': y_true[nodes],
            'y_pred': predictions[nodes],
        })
        return scores['acc']

    if eval_type == 'valid':
        return split_accuracy('valid')
    return split_accuracy('train'), split_accuracy('test')

def get_exp_name(dataset, para_dic, input_exp_name):
    """Compose the experiment name and clear any stale run directory.

    The name is ``dataset``, then one ``key + str(value)`` token per entry of
    ``para_dic`` (in iteration order), then ``input_exp_name``, all joined
    with underscores.

    Side effect: if ``runs/<name>`` already exists it is deleted recursively.

    Returns:
        The composed experiment name.
    """
    tokens = [dataset]
    tokens.extend(key + str(value) for key, value in para_dic.items())
    tokens.append(input_exp_name)
    exp_name = '_'.join(tokens)

    run_dir = 'runs/' + exp_name
    if os.path.exists(run_dir):
        shutil.rmtree(run_dir)

    return exp_name

def get_lr(optimizer):
    """Return the ``'lr'`` of the first parameter group, or ``None`` if there is none."""
    return next((group['lr'] for group in optimizer.param_groups), None)

def main():
    """Train and evaluate a GNN node classifier on ego-subgraphs of an OGB graph.

    Parses the command-line options, loads ``ogbn-<dataset>`` together with
    the ego graphs stored in ``data/<dataset>-ego-graphs-<ego_size>.npy``,
    builds the train/valid/test loaders and the model, and then either

    * evaluates the checkpoint given by ``--load_path`` and returns, or
    * trains for up to ``--epochs`` epochs per run, early-stopping on the
      validation accuracy and saving the best weights to
      ``saved/<exp_name>.pt``.

    Metrics are printed and sent to Weights & Biases and, when a
    ``SummaryWriter`` could be created, to ``runs/<exp_name>``.
    """
    parser = argparse.ArgumentParser(description='OGBN (GNN)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--dataset', type=str, default='arxiv')
    parser.add_argument('--model', type=str, default='gcn')
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--num_heads', type=int, default=1, help='only used by gat')
    parser.add_argument('--ego_size', type=int, default=128)
    parser.add_argument('--hidden_size', type=int, default=128)
    parser.add_argument('--hidden_dropout', type=float, default=0.3)
    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lr_scale', type=float, default=1.0)
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--early_stopping', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--eval_batch_size', type=int, default=2048)
    parser.add_argument('--batch_norm', type=int, default=1)
    parser.add_argument('--residual', type=int, default=1)
    parser.add_argument('--linear_layer', type=int, default=1)
    parser.add_argument('--norm', type=str, default='both')
    parser.add_argument('--num_workers', type=int, default=4, help='number of workers')
    parser.add_argument('--runs', type=int, default=1, help='only support one run...')
    parser.add_argument("--optimizer", type=str, default='adamw', choices=['adam', 'adamw'], help="optimizer")
    parser.add_argument('--warmup', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--load_path', type=str, default='')
    parser.add_argument('--exp_name', type=str, default='')
    args = parser.parse_args()
    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # The hyper-parameters are encoded into the experiment name, which is
    # used for the TensorBoard directory and the checkpoint file.
    para_dic = {'': args.model, 'nl': args.num_layers, 'nh': args.num_heads, 'es': args.ego_size, 'hs': args.hidden_size,
                'hd': args.hidden_dropout, 'bs': args.batch_size, 'op': args.optimizer, 
                'lr': args.lr, 'wd': args.weight_decay, 'ls': args.lr_scale, 'bn': args.batch_norm, 
                'rs': args.residual, 'll': args.linear_layer, 'sd': args.seed}
    para_dic['warm'] = args.warmup
    exp_name = get_exp_name(args.dataset, para_dic, args.exp_name)

    # The seed token is stripped so runs differing only by seed share a
    # wandb run name.
    wandb_name = exp_name.replace('_sd'+str(args.seed), '')
    wandb.init(name=wandb_name, project="xfmr4gl")
    wandb.config.update(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = DglNodePropPredDataset(name=f'ogbn-{args.dataset}')
    split_idx = dataset.get_idx_split()

    ego_graphs = np.load(f'data/{args.dataset}-ego-graphs-{args.ego_size}.npy', allow_pickle=True)
    ego_graphs = [torch.LongTensor(ego) for ego in ego_graphs]

    data = dataset[0]
    # Existing self loops are dropped before self loops are added back
    # (presumably to avoid duplicates).
    graph = dgl.remove_self_loop(data[0])
    graph = dgl.add_self_loop(graph)
    if args.dataset == 'arxiv' or args.dataset == 'papers100M':
        # The bidirected graph is a new object, so the node features are
        # re-attached explicitly.
        temp_graph = dgl.to_bidirected(graph)
        temp_graph.ndata['feat'] = graph.ndata['feat']
        graph = temp_graph
    data = (graph, data[1].long())

    graph = data[0]
    graph.ndata['labels'] = data[1]
    graph.ndata['is_train_idx'] = torch.zeros(graph.number_of_nodes(), dtype=torch.bool)
    graph.ndata['is_train_idx'][split_idx['train']] = 1

    y_true = data[1].clone().to(device)

    train_dataset = NodeClassificationDataset(data, ego_graphs, split_idx['train'])
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=batcher('gnn'), pin_memory=True)

    valid_dataset = NodeClassificationDataset(data, ego_graphs, split_idx['valid'])
    valid_loader = DataLoader(valid_dataset, batch_size=args.eval_batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=batcher('gnn'), pin_memory=True)

    # The "test" loader also covers the train nodes because ``test`` reports
    # both train and test accuracy from a single pass over it.
    test_dataset = NodeClassificationDataset(data, ego_graphs, torch.cat((split_idx['train'], split_idx['test'])))
    test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=batcher('gnn'), pin_memory=True)

    # The integer CLI switches are turned into booleans for the model.
    args.batch_norm = args.batch_norm == 1
    args.residual = args.residual == 1
    args.linear_layer = args.linear_layer == 1

    model = GNNModel(conv_type=args.model, input_size=graph.ndata['feat'].shape[1], hidden_size=args.hidden_size, num_layers=args.num_layers, 
                     num_classes=dataset.num_classes, batch_norm=args.batch_norm, residual=args.residual, 
                     dropout=args.hidden_dropout, linear_layer=args.linear_layer, norm=args.norm, num_heads=args.num_heads).to(device)
    wandb.watch(model, log='all')

    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('model parameters:', pytorch_total_params)

    evaluator = Evaluator(name=f'ogbn-{args.dataset}')
    logger = Logger(args.runs, args)

    if not os.path.exists('runs'):
        os.mkdir('runs')
    if not os.path.exists('saved'):
        os.mkdir('saved')

    # TensorBoard logging is best-effort: on failure the run continues
    # without a writer.
    try:
        writer = SummaryWriter('runs/' + exp_name)
    except Exception as e:
        print(e)
        writer = None

    if args.load_path:
        # Evaluation-only mode. The checkpoint is mapped onto the selected
        # device rather than a hard-coded 'cuda:0', so it also loads on
        # CPU-only hosts and honours --device.
        model.load_state_dict(torch.load(args.load_path, map_location=device))

        valid_acc = test(model, valid_loader, device, evaluator, y_true, split_idx, 'valid')
        valid_output = f'Valid: {100 * valid_acc:.2f}% '

        cor_train_acc, cor_test_acc = test(model, test_loader, device, evaluator, y_true, split_idx, 'test')
        train_output = f'Train: {100 * cor_train_acc:.2f}%, '
        test_output = f'Test: {100 * cor_test_acc:.2f}%'

        print(train_output + valid_output + test_output)
        if writer is not None:
            writer.close()
        return

    for run in range(args.runs):
        model.reset_parameters()
        best_val_acc = 0
        cor_train_acc = 0
        cor_test_acc = 0
        patience = 0

        if args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer == 'adamw':
            optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        else:
            raise NotImplementedError
        # --warmup > 0: wrap the optimizer in ScheduledOptim.
        # --warmup < 0: linear ramp over the first 50 epochs (below) plus
        #               ReduceLROnPlateau on the validation accuracy.
        # --warmup == 0: constant learning rate.
        lr_scheduler = None
        if args.warmup > 0:
            optimizer = ScheduledOptim(optimizer, args.hidden_size, n_warmup_steps=args.warmup, init_lr_scale=args.lr_scale)
        elif args.warmup < 0:
            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='max', factor=0.5, patience=100, verbose=True, min_lr=args.lr / 10
            )

        for epoch in range(1, 1 + args.epochs):
            if lr_scheduler is not None:
                if epoch <= 50:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = args.lr * epoch / 50

            # ``train``'s fifth parameter is ``first_token``; the CLI
            # namespace used to be passed there by mistake, so the default
            # is used instead.
            loss = train(model, train_loader, device, optimizer)

            train_output = valid_output = test_output = ''
            # Evaluation only starts at epoch 20.
            if epoch >= 20 and epoch % args.log_steps == 0:
                valid_acc = test(model, valid_loader, device, evaluator, y_true, split_idx, 'valid')
                if lr_scheduler is not None:
                    lr_scheduler.step(valid_acc)
                valid_output = f'Valid: {100 * valid_acc:.2f}% '

                if valid_acc > best_val_acc:
                    best_val_acc = valid_acc
                    cor_train_acc, cor_test_acc = test(model, test_loader, device, evaluator, y_true, split_idx, 'test')
                    logger.add_result(run, (cor_train_acc, valid_acc, cor_test_acc))
                    train_output = f'Train: {100 * cor_train_acc:.2f}%, '
                    test_output = f'Test: {100 * cor_test_acc:.2f}%'
                    patience = 0
                    try:
                        torch.save(model.state_dict(), 'saved/' + exp_name + '.pt')
                        wandb.save('saved/' + exp_name + '.pt')
                    except FileNotFoundError as e:
                        print(e)
                else:
                    # ``patience`` counts evaluations without improvement,
                    # not epochs.
                    patience += 1
                    if patience >= args.early_stopping:
                        print('Early stopping...')
                        break
                if writer is not None:
                    writer.add_scalar('loss', loss, epoch)
                    writer.add_scalar('acc/valid', valid_acc, epoch)
                    writer.add_scalar('acc/best_val', best_val_acc, epoch)
                    writer.add_scalar('acc/cor_train', cor_train_acc, epoch)
                    writer.add_scalar('acc/cor_test', cor_test_acc, epoch)
                    writer.add_scalar('lr', get_lr(optimizer), epoch)
                wandb.log({'Train Loss': loss, 'Valid Acc': valid_acc, 'best_val_acc': best_val_acc, 
                           'cor_train_acc': cor_train_acc, 'cor_test_acc': cor_test_acc, 'LR': get_lr(optimizer)})
            else:
                if writer is not None:
                    writer.add_scalar('loss', loss, epoch)
                    writer.add_scalar('lr', get_lr(optimizer), epoch)
                wandb.log({'Train Loss': loss, 'LR': get_lr(optimizer)})
            print(f'Run: {run + 1:02d}, '
                    f'Epoch: {epoch:02d}, '
                    f'Loss: {loss:.4f}, ' + 
                    train_output + valid_output + test_output)

        logger.print_statistics(run)
    logger.print_statistics()

    # Flush and release the TensorBoard event file.
    if writer is not None:
        writer.close()


# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()
