import argparse
import math
import os
import shutil

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.sparse import csr_matrix
from tensorboardX import SummaryWriter
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch_geometric.transforms as T
import wandb
from ogb.linkproppred import Evaluator, PygLinkPropPredDataset
from optim_schedule import ScheduledOptim
from torch_geometric.utils import add_remaining_self_loops, to_undirected
from transformer import TransformerModel

class LinkPredictor(torch.nn.Module):
    """MLP that scores a candidate link from its two endpoint embeddings.

    The endpoint embeddings are combined with an element-wise product, passed
    through a stack of linear layers (ReLU + dropout between them) and squashed
    with a sigmoid, so every output lies in (0, 1).

    Args:
        in_channels: width of each endpoint embedding.
        hidden_channels: width of the hidden layers.
        out_channels: width of the output (1 for a single link score).
        num_layers: number of linear layers; values below 2 still yield two
            layers (input->hidden, hidden->output).
        dropout: dropout probability applied after every hidden layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(LinkPredictor, self).__init__()

        # Layer widths: in -> hidden (-> hidden ...) -> out; at least one hidden.
        widths = [in_channels, *[hidden_channels] * max(num_layers - 1, 1), out_channels]
        self.lins = torch.nn.ModuleList(
            torch.nn.Linear(width_in, width_out)
            for width_in, width_out in zip(widths[:-1], widths[1:])
        )
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise every linear layer in place."""
        for layer in self.lins:
            layer.reset_parameters()

    def forward(self, x_i, x_j):
        """Return link probabilities for the endpoint pairs ``(x_i, x_j)``."""
        hidden = x_i * x_j
        *inner_layers, output_layer = self.lins
        for layer in inner_layers:
            hidden = F.dropout(F.relu(layer(hidden)), p=self.dropout,
                               training=self.training)
        return torch.sigmoid(output_layer(hidden))

class LinkPredictionDataset(torch.utils.data.Dataset):
    """Positive training edges paired with one uniformly sampled negative node.

    Each item is a 1-D ``LongTensor`` ``[u, v, w]`` where ``(u, v)`` is the
    stored edge and ``w`` is a node id drawn uniformly from ``[0, len(x))``
    on every access. ``w`` is not filtered against true edges.

    The remaining attributes (``ego_graphs``, ``pe``, ``mask``, ``num_heads``)
    are only stored here; they are read by the collate function built by
    ``make_batcher``.
    """

    def __init__(self, x, edge, ego_graphs, pe=None, mask=None, num_heads=1):
        # Fixed: the previous ``super(LinkPredictionDataset).__init__()`` built
        # an *unbound* super object and never initialised the base class.
        super().__init__()
        self.x = x                    # node features; only len(x) is used here
        self.edge = edge              # indexable edge list, one [u, v] row per item
        self.ego_graphs = ego_graphs
        self.pe = pe
        self.mask = mask
        self.num_heads = num_heads

    def __len__(self):
        return len(self.edge)

    def __getitem__(self, idx):
        """Return ``[u, v, w]`` for edge ``idx`` with a freshly sampled ``w``."""
        edge = self.edge[idx]
        neg_edge = torch.randint(0, len(self.x), edge[:1].size(), dtype=torch.long)
        return torch.cat([edge, neg_edge])

def make_batcher(dataset):
    """Build a ``DataLoader`` ``collate_fn`` for ``LinkPredictionDataset`` items.

    Every item is a ``[u, v, w]`` node-id triple. For each node the padded
    ego-graph row ``dataset.ego_graphs[node]`` is looked up (``-1`` marks a
    padding slot, which is replaced by node 0 after the padding mask is taken).

    The returned list is ``[features, padding] + [mask] + [pe]``:

    * features: ``(bs, 3, ego_size, feat)``, gathered from ``dataset.x``;
    * padding:  ``(bs, 3, ego_size)`` boolean, True on padded slots;
    * mask (only if ``dataset.mask`` is not None): ``(bs * num_heads, 3, E, E)``,
      so that ``mask[:, i]`` is ``(bs * num_heads, E, E)`` with each sample
      repeated ``num_heads`` times consecutively, the same layout
      ``construct_batch`` builds with ``repeat_interleave``;
    * pe (only if ``dataset.pe`` is not None): ``(bs, 3, ego_size, d)``, built
      from the same three pe layouts ``construct_batch`` accepts.
    """

    def batcher(samples):
        idx = torch.stack(samples)          # (bs, 3)
        bs = idx.shape[0]
        idx = idx.view(-1)                  # (bs * 3,)

        # Advanced indexing returns a copy, so zeroing padded slots below does
        # not modify dataset.ego_graphs.
        src = dataset.ego_graphs[idx]
        src_padding = (src == -1)
        src[src_padding] = 0
        ego_size = src.shape[1]
        flat_nodes = src.view(-1)

        batch = [
            dataset.x[flat_nodes].view(bs, 3, ego_size, -1),
            src_padding.view(bs, 3, ego_size),
        ]

        if dataset.mask is not None:
            mask = torch.as_tensor(dataset.mask[idx.cpu().numpy()], dtype=torch.bool)
            mask = mask.view(bs, 3, *mask.shape[1:])
            # Repeat per head only AFTER the (bs, 3, ...) reshape; repeating
            # first (as before) made the element count mismatch the view and
            # crashed whenever num_heads > 1.
            batch.append(torch.repeat_interleave(mask, dataset.num_heads, dim=0))

        pe = dataset.pe
        if pe is not None:
            if len(pe.shape) == 2:
                # Per-node table: one row for every ego-graph slot.
                pe_batch = pe[flat_nodes].view(bs, 3, ego_size, -1)
            elif pe.shape[1] > 1:
                # Per-ego-graph encodings indexed by the centre node (numpy array).
                pe_batch = torch.tensor(pe[idx.cpu().numpy()]).float()
                pe_batch = pe_batch.view(bs, 3, *pe_batch.shape[1:])
            else:
                # Shared (ego_size, 1, d) table: the same encoding for every graph.
                pe_batch = pe.transpose(0, 1).expand(bs, 3, -1, -1).contiguous()
            batch.append(pe_batch)

        return batch

    return batcher

def train(model, predictor, data, ego_graphs, split_edge, args, optimizer, pe, adj, device):
    """Run one training epoch over ``split_edge['train']['edge']``.

    This is a generator that yields exactly once: the example-weighted mean of
    the per-batch loss. Each batch scores (anchor, positive) and
    (anchor, sampled negative) with ``predictor`` and minimises
    ``-log(p_pos) - log(1 - p_neg)`` (with a 1e-15 epsilon).
    """
    model.train()
    predictor.train()

    dataset = LinkPredictionDataset(data.x, split_edge['train']['edge'], ego_graphs, pe, adj, args.num_heads)
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        collate_fn=make_batcher(dataset),
        pin_memory=True,
    )

    loss_sum = 0
    seen = 0
    for cpu_batch in loader:
        optimizer.zero_grad()
        tensors = [tensor.to(device) for tensor in cpu_batch]
        features, padding = tensors[0], tensors[1]
        attn_mask = tensors[2] if args.mask else None
        pos_enc = tensors[-1] if pe is not None else None

        # Slot 0 = anchor node, slot 1 = positive neighbour, slot 2 = negative.
        embeddings = [
            model(
                features[:, slot],
                src_mask=None if attn_mask is None else attn_mask[:, slot],
                padding=padding[:, slot],
                pe=None if pos_enc is None else pos_enc[:, slot],
            )
            for slot in range(3)
        ]
        anchor, positive, negative = embeddings

        pos_out = predictor(anchor, positive)
        pos_loss = -torch.log(pos_out + 1e-15).mean()
        neg_out = predictor(anchor, negative)
        neg_loss = -torch.log(1 - neg_out + 1e-15).mean()
        loss = pos_loss + neg_loss

        loss.backward()
        optimizer.step()

        batch_examples = pos_out.size(0)
        loss_sum += loss.item() * batch_examples
        seen += batch_examples

    yield loss_sum / seen

@torch.no_grad()
def test(model, predictor, data, ego_graphs, split_edge, args, evaluator, pe, adj, device):
    """Evaluate with the protocol matching ``args.dataset``.

    ``"citation"`` uses ``test2`` (MRR); every other dataset uses ``test1``
    (Hits@K). The arguments are forwarded unchanged and the result is returned.
    """
    evaluate = test2 if args.dataset == "citation" else test1
    return evaluate(model, predictor, data, ego_graphs, split_edge, args, evaluator, pe, adj, device)

def construct_batch(data, ego_graphs, idx, pe, adj, mask, num_heads):
    """Gather evaluation inputs for the ego-graphs of nodes ``idx``.

    Args:
        data: object whose ``x`` holds per-node features.
        ego_graphs: ``(num_nodes, ego_size)`` node ids, ``-1`` marking padding.
        idx: LongTensor of centre-node ids.
        pe: None, a 2-D per-node table, a per-ego-graph array
            (``pe.shape[1] > 1``), or any other tensor returned unchanged.
        adj: per-node attention masks indexed with ``idx`` (used if ``mask``).
        mask: whether to build the attention mask.
        num_heads: times each attention mask is repeated along dim 0.

    Returns:
        ``(src, src_mask, src_padding, pe_batch)``. ``src`` and per-node /
        per-graph ``pe_batch`` have their first two dims transposed
        (``(ego_size, len(idx), ...)``). ``src_padding`` is
        ``(len(idx), ego_size)``. ``src_mask`` and ``pe_batch`` may be None.
    """
    # Advanced indexing copies, so zeroing padded slots leaves ego_graphs intact.
    nodes = ego_graphs[idx]
    padding = nodes == -1
    nodes[padding] = 0
    n_graphs, ego_len = nodes.shape[0], nodes.shape[1]
    flat_nodes = nodes.view(-1)

    def gather_transposed(table):
        # One row per ego-graph slot, reshaped and swapped to (ego_len, n_graphs, -1).
        return table[flat_nodes].view(n_graphs, ego_len, -1).transpose(0, 1)

    attn_mask = None
    if mask:
        attn_mask = torch.repeat_interleave(
            torch.BoolTensor(adj[idx.cpu().numpy()]), num_heads, dim=0)

    if pe is None:
        pe_batch = None
    elif len(pe.shape) == 2:
        pe_batch = gather_transposed(pe)
    elif pe.shape[1] > 1:
        pe_batch = torch.tensor(pe[idx.cpu().numpy()]).float().transpose(0, 1)
    else:
        pe_batch = pe

    return gather_transposed(data.x), attn_mask, padding, pe_batch

@torch.no_grad()
def test2(model, predictor, data, ego_graphs, split_edge, args, evaluator, pe, adj, device):
    """Evaluate mean reciprocal rank (MRR) on the 'valid' and 'test' splits.

    Every node is encoded once (in chunks of ``args.batch_size``). The
    embeddings are kept on CPU and gathered per batch of candidate pairs.
    Each split needs ``source_node``, ``target_node`` and ``target_node_neg``
    (one row of negatives per source).

    Returns:
        ``{'MRR': (valid_mrr, test_mrr)}``.
    """
    batch_size = args.batch_size
    model.eval()
    predictor.eval()
    num_nodes = data.x.shape[0]

    def encode(idx):
        """Encode the ego-graphs centred on ``idx``; the result lives on ``device``."""
        src, src_mask, src_padding, pe_batch = construct_batch(
            data, ego_graphs, idx, pe, adj, args.mask, args.num_heads)
        # construct_batch returns the first two dims transposed; swap them back.
        src = src.transpose(0, 1).to(device)
        src_padding = src_padding.to(device)
        if src_mask is not None:
            # Previously left on CPU, which breaks --mask on a GPU.
            src_mask = src_mask.to(device)
        if pe_batch is not None:
            pe_batch = pe_batch.transpose(0, 1).to(device)
        return model(src, src_mask=src_mask, padding=src_padding, pe=pe_batch)

    h = torch.cat([
        encode(torch.arange(start, min(num_nodes, start + batch_size))).detach().cpu()
        for start in range(0, num_nodes, batch_size)
    ])
    assert h.shape[0] == num_nodes

    def score(src_nodes, dst_nodes):
        """Predictor outputs for the pairs ``(src_nodes[k], dst_nodes[k])`` as a 1-D CPU tensor."""
        preds = []
        for perm in DataLoader(range(src_nodes.size(0)), batch_size):
            # Index the CPU embedding table with CPU indices, then move the
            # gathered rows; indexing a CPU tensor with device indices fails.
            hsrc = h[src_nodes[perm]].to(device)
            hdst = h[dst_nodes[perm]].to(device)
            # squeeze(-1), not squeeze(): a final batch of one pair must stay 1-D
            # or torch.cat below rejects the 0-d tensor.
            preds.append(predictor(hsrc, hdst).squeeze(-1).cpu())
        return torch.cat(preds, dim=0)

    def test_split(split):
        source = split_edge[split]['source_node']
        target = split_edge[split]['target_node']
        # One row of negatives per source; width taken from the data rather
        # than hard-coded.
        target_neg = split_edge[split]['target_node_neg'].view(source.size(0), -1)
        num_neg = target_neg.size(1)

        pos_pred = score(source, target)
        neg_source = source.view(-1, 1).repeat(1, num_neg).view(-1)
        neg_pred = score(neg_source, target_neg.reshape(-1)).view(-1, num_neg)

        return evaluator.eval({
            'y_pred_pos': pos_pred,
            'y_pred_neg': neg_pred,
        })['mrr_list'].mean().item()

    valid_mrr = test_split('valid')
    test_mrr = test_split('test')

    return dict(MRR=(valid_mrr, test_mrr))

@torch.no_grad()
def test1(model, predictor, data, ego_graphs, split_edge, args, evaluator, pe, adj, device):
    """Evaluate Hits@K for K in (10, 20, 50, 100) on the 'valid' and 'test' splits.

    Both endpoints of every positive (``edge``) and negative (``edge_neg``)
    pair are encoded from their ego-graphs and scored by ``predictor``.

    Returns:
        ``{'Hits@K': (valid_hits, test_hits)}`` for each K.
    """
    model.eval()
    predictor.eval()

    def encode(idx):
        """Encode the ego-graphs centred on ``idx``; the result lives on ``device``."""
        src, src_mask, src_padding, pe_batch = construct_batch(
            data, ego_graphs, idx, pe, adj, args.mask, args.num_heads)
        # construct_batch returns the first two dims transposed; swap them back.
        src = src.transpose(0, 1).to(device)
        src_padding = src_padding.to(device)
        if src_mask is not None:
            # Previously left on CPU, which breaks --mask on a GPU.
            src_mask = src_mask.to(device)
        if pe_batch is not None:
            # Guard added: pe is None when positional encodings are disabled,
            # and the unconditional transpose used to crash.
            pe_batch = pe_batch.transpose(0, 1).to(device)
        return model(src, src_mask=src_mask, padding=src_padding, pe=pe_batch)

    def predict(edge_list):
        """Predictor outputs for every row ``[u, v]`` of ``edge_list`` as a 1-D CPU tensor."""
        preds = []
        for perm in DataLoader(range(edge_list.size(0)), args.batch_size):
            edge = edge_list[perm].t()
            # squeeze(-1), not squeeze(): a final batch of one edge must stay
            # 1-D or torch.cat below rejects the 0-d tensor.
            preds.append(predictor(encode(edge[0]), encode(edge[1])).squeeze(-1).cpu())
        return torch.cat(preds, dim=0)

    pos_valid_pred = predict(split_edge['valid']['edge'])
    neg_valid_pred = predict(split_edge['valid']['edge_neg'])
    pos_test_pred = predict(split_edge['test']['edge'])
    neg_test_pred = predict(split_edge['test']['edge_neg'])

    results = {}
    for K in [10, 20, 50, 100]:
        evaluator.K = K
        valid_hits = evaluator.eval({
            'y_pred_pos': pos_valid_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        test_hits = evaluator.eval({
            'y_pred_pos': pos_test_pred,
            'y_pred_neg': neg_test_pred,
        })[f'hits@{K}']

        results[f'Hits@{K}'] = (valid_hits, test_hits)

    return results

def get_exp_name(dataset, para_dic, input_exp_name):
    """Build the experiment name ``<dataset>_<k1><v1>_..._<input_exp_name>``.

    Side effect: an existing ``runs/<name>`` directory is deleted so that a
    rerun starts with a clean log directory.
    """
    tokens = [dataset]
    tokens.extend(key + str(value) for key, value in para_dic.items())
    exp_name = '_'.join(tokens) + '_' + input_exp_name

    log_dir = 'runs/' + exp_name
    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)

    return exp_name

def get_lr(optimizer):
    """Return the learning rate of the first parameter group, or None if there are none."""
    return next((group['lr'] for group in optimizer.param_groups), None)

def get_positional_embedding(max_len, d_model):
    """Return the sinusoidal positional-encoding table.

    Column ``2k`` holds ``sin(pos * w_k)`` and column ``2k + 1`` holds
    ``cos(pos * w_k)``, with ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        max_len: number of positions.
        d_model: encoding width. It may be odd, in which case the last column
            is a sine (previously an odd width raised a shape error).

    Returns:
        Float tensor of shape ``(max_len, 1, d_model)``.
    """
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    # There are ceil(d/2) sine columns but only floor(d/2) cosine columns.
    pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
    return pe.unsqueeze(0).transpose(0, 1)

def _parse_args():
    """Parse the command-line options for a training run."""
    parser = argparse.ArgumentParser(description='OGBN (GNN)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--dataset', type=str, default='arxiv')
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--num_layers', type=int, default=4)
    parser.add_argument('--num_heads', type=int, default=2)
    parser.add_argument('--ego_size', type=int, default=64)
    parser.add_argument('--hidden_size', type=int, default=64)
    parser.add_argument('--input_dropout', type=float, default=0.2)
    parser.add_argument('--hidden_dropout', type=float, default=0.4)
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lr_scale', type=float, default=1.0)
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--early_stopping', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--runs', type=int, default=1)
    parser.add_argument('--pe_type', type=int, default=1)
    parser.add_argument('--mask', action='store_true', default=False)
    parser.add_argument("--optimizer", type=str, default='adamw', choices=['adam', 'adamw'], help="optimizer")
    parser.add_argument('--warmup', type=int, default=10000)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--exp_name', type=str, default='')
    # TensorBoard used to be force-disabled; it is now opt-in (default off).
    parser.add_argument('--tensorboard', action='store_true', default=False,
                        help='also write TensorBoard scalars to runs/<exp_name>')
    return parser.parse_args()


def _best_metric_key(dataset):
    """Return the metric used for model selection and early stopping on ``dataset``."""
    return {'ppa': 'Hits@100', 'ddi': 'Hits@20', 'citation': 'MRR'}.get(dataset, 'Hits@50')


def _unwrap(module):
    """Return the wrapped module when ``module`` is an ``nn.DataParallel``."""
    return module.module if isinstance(module, nn.DataParallel) else module


def _save_checkpoint(model, predictor, exp_name):
    """Write model/predictor weights to ``saved/`` and register the files with wandb."""
    os.makedirs('saved', exist_ok=True)
    for suffix, module in (('_model.pt', model), ('_predictor.pt', predictor)):
        path = 'saved/' + exp_name + suffix
        torch.save(_unwrap(module).state_dict(), path)
        # Previously wandb.save pointed at 'saved/<exp_name>.pt', which is never written.
        wandb.save(path)


def _build_optimizer(args, model, predictor):
    """Create the base Adam/AdamW optimizer wrapped in ``ScheduledOptim``."""
    params = list(model.parameters()) + list(predictor.parameters())
    if args.optimizer == 'adam':
        base = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'adamw':
        base = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise NotImplementedError
    return ScheduledOptim(base, args.hidden_size, n_warmup_steps=args.warmup, init_lr_scale=args.lr_scale)


def _train_one_run(run, args, model, predictor, data, ego_graphs, split_edge,
                   evaluator, pe, adj, device, writer, exp_name):
    """Train from re-initialised weights until the epoch budget or early stopping.

    Returns:
        ``(best_val_hits, cor_test_hits)`` for the selection metric.
    """
    _unwrap(model).init_weights()
    _unwrap(predictor).reset_parameters()
    best_key = _best_metric_key(args.dataset)
    best_val_hits = 0
    cor_test_hits = 0
    patience = 0
    optimizer = _build_optimizer(args, model, predictor)

    for epoch in range(1, 1 + args.epochs):
        # train() is a generator that yields the epoch's mean loss once.
        for loss in train(model, predictor, data, ego_graphs, split_edge, args, optimizer, pe, adj, device):
            results = test(model, predictor, data, ego_graphs, split_edge, args, evaluator, pe, adj, device)
            if epoch % args.log_steps != 0:
                continue
            for key, (valid_hits, test_hits) in results.items():
                if writer is not None:
                    writer.add_scalar('loss', loss, epoch)
                    writer.add_scalar('acc/valid', valid_hits, epoch)
                    writer.add_scalar('acc/test', test_hits, epoch)
                    writer.add_scalar('acc/best_val', best_val_hits, epoch)
                    writer.add_scalar('acc/cor_test', cor_test_hits, epoch)
                    writer.add_scalar('lr', get_lr(optimizer), epoch)

                wandb.log({'Train Loss': loss,
                           f'Valid {key}': valid_hits, f'Test {key}': test_hits,
                           f'best_val_hits': best_val_hits, f'cor_test_hits': cor_test_hits,
                           'LR': get_lr(optimizer)})
                print(f'Run: {run + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Loss: {loss:.4f}, '
                      f'Valid {key}: {100 * valid_hits:.2f}% '
                      f'Test {key}: {100 * test_hits:.2f}%')

                if key != best_key:
                    continue
                if valid_hits > best_val_hits:
                    best_val_hits = valid_hits
                    cor_test_hits = test_hits
                    patience = 0
                    _save_checkpoint(model, predictor, exp_name)
                else:
                    patience += 1
                    if patience >= args.early_stopping:
                        # Stop only this run; previously exit() also killed
                        # every remaining --runs iteration.
                        print('Early stopping...')
                        return best_val_hits, cor_test_hits

    return best_val_hits, cor_test_hits


def main():
    """Parse options, load data, then train and evaluate ``args.runs`` times."""
    args = _parse_args()
    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    para_dic = {'nl': args.num_layers, 'nh': args.num_heads, 'es': args.ego_size, 'hs': args.hidden_size,
                'id': args.input_dropout, 'hd': args.hidden_dropout, 'bs': args.batch_size, 'pe': args.pe_type,
                'op': args.optimizer, 'lr': args.lr, 'wd': args.weight_decay, 'ls': args.lr_scale, 'sd': args.seed,
                'warm': args.warmup}
    if args.mask:
        para_dic['mask'] = 1
    exp_name = get_exp_name(args.dataset, para_dic, args.exp_name)

    wandb_name = exp_name.replace('_sd'+str(args.seed), '')
    wandb.init(name=wandb_name, project="xfmr4gl")
    wandb.config.update(args)

    device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')

    dataset = PygLinkPropPredDataset(name=f'ogbl-{args.dataset}', transform=T.ToSparseTensor())
    ego_graphs = torch.LongTensor(np.load(f'data/{args.dataset}-ego-graphs-padding-{args.ego_size}.npy'))

    pe = None
    if args.pe_type == 1:
        pe = torch.load(f'data/{args.dataset}-embedding-{args.hidden_size}.pt')
    elif args.pe_type == 2:
        pe = np.load(f'data/{args.dataset}-ego-graphs-gpe-{args.ego_size}-{args.hidden_size}.npy')
    elif args.pe_type == 3:
        pe = get_positional_embedding(args.ego_size, args.hidden_size)

    data = dataset[0]
    if args.dataset == 'ppa':
        data.x = data.x.float()
    if args.dataset == 'ddi':
        # NOTE(review): this makes data.x 1-D, but data.x.size(1) is read below
        # to size the model, which would raise; ddi looks unsupported as written.
        data.x = torch.LongTensor(range(data.adj_t.size(0)))

    adj = None
    if args.mask:
        # The mask file is inverted before use; assumed boolean. TODO confirm.
        adj = ~np.load(f'data/{args.dataset}-ego-graphs-adj-{args.ego_size}.npy')

    split_edge = dataset.get_edge_split()
    if args.dataset == 'citation':
        # Build an undirected 'edge' list (both directions) for training.
        for split in split_edge:
            source = split_edge[split]['source_node']
            target = split_edge[split]['target_node']
            split_edge[split]['edge'] = torch.stack(
                [torch.cat([source, target]), torch.cat([target, source])], dim=1)

    model = TransformerModel(data.x.size(1), args.hidden_size,
                             args.num_heads, args.hidden_size,
                             args.num_layers, 2,
                             args.input_dropout, args.hidden_dropout,
                             return_outputs=True).to(device)
    predictor = LinkPredictor(args.hidden_size, args.hidden_size, 1,
                              args.num_layers, args.hidden_dropout).to(device)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
        predictor = nn.DataParallel(predictor)

    wandb.watch(model, log='all')

    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('model parameters:', pytorch_total_params)

    evaluator = Evaluator(name=f'ogbl-{args.dataset}')
    writer = SummaryWriter('runs/' + exp_name) if args.tensorboard else None

    for run in range(args.runs):
        best_val_hits, cor_test_hits = _train_one_run(
            run, args, model, predictor, data, ego_graphs, split_edge,
            evaluator, pe, adj, device, writer, exp_name)
        wandb.log({'Final val': best_val_hits, 'Final test': cor_test_hits})

    if writer is not None:
        writer.close()


if __name__ == "__main__":
    # Start training only when run as a script, not when this module is imported.
    main()
