import argparse
import os
import torch
import torch.nn as nn
import torch.distributed as dist
# import torch.multiprocessing as mp
# from torch.nn.parallel import DistributedDataParallel
import torch.nn.functional as F
import pickle as pkl
import pdb
import utils
from tqdm import tqdm

# from ogb.nodeproppred import Evaluator

class FeedForwardNet(nn.Module):
    """Stack of linear layers with a shared PReLU + dropout between them.

    With ``n_layers == 1`` the network is a single ``in_feats -> out_feats``
    linear map. Otherwise it is ``in_feats -> hidden``, followed by
    ``n_layers - 2`` ``hidden -> hidden`` layers, followed by
    ``hidden -> out_feats``. The activation and dropout modules are only
    created when ``n_layers > 1`` and are never applied after the last layer.

    Args:
        in_feats: Size of the input feature dimension.
        hidden: Width of the intermediate layers.
        out_feats: Size of the output dimension.
        n_layers: Number of linear layers.
        dropout: Dropout probability applied after each activation.
    """

    def __init__(self, in_feats, hidden, out_feats, n_layers, dropout):
        super().__init__()
        # Consecutive pairs of this list are the (in, out) sizes of each layer.
        if n_layers == 1:
            widths = [in_feats, out_feats]
        else:
            widths = [in_feats, hidden] + [hidden] * (n_layers - 2) + [out_feats]
        self.layers = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        )
        self.n_layers = n_layers
        if n_layers > 1:
            self.prelu = nn.PReLU()
            self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every linear layer: Xavier-uniform weights (ReLU gain), zero biases."""
        relu_gain = nn.init.calculate_gain("relu")
        for linear in self.layers:
            nn.init.xavier_uniform_(linear.weight, gain=relu_gain)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        """Apply the layers in order; PReLU and dropout follow all but the last layer."""
        last_index = self.n_layers - 1
        for index, linear in enumerate(self.layers):
            x = linear(x)
            if index >= last_index:
                # No activation/dropout on the output layer.
                continue
            x = self.prelu(x)
            x = self.dropout(x)
        return x


class SIGN(nn.Module):
    """SIGN-style model: one feed-forward branch per hop, concatenated and projected.

    ``num_hops + 1`` independent ``FeedForwardNet`` branches (one for the raw
    features and one per propagation hop) each map ``in_feats -> hidden``.
    Their outputs are concatenated on the last dimension, passed through
    PReLU and dropout, and fed to a final ``FeedForwardNet`` that maps
    ``(num_hops + 1) * hidden -> out_feats``.

    Args:
        in_feats: Feature dimension of every hop input.
        hidden: Hidden width of each branch and of the projection network.
        out_feats: Output dimension (number of logits).
        num_hops: Number of propagated feature sets in addition to the raw one.
        n_layers: Number of linear layers in each ``FeedForwardNet``.
        dropout: Dropout probability used inside the branches, the projection
            network, and on the concatenated hidden representation.
        input_drop: Dropout probability applied to each input feature tensor.
        loss_fn: Callable ``loss_fn(out, y)`` used when targets are supplied.
    """

    def __init__(
        self,
        in_feats,
        hidden,
        out_feats,
        num_hops,
        n_layers,
        dropout,
        input_drop,
        loss_fn,
    ):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.prelu = nn.PReLU()
        self.inception_ffs = nn.ModuleList()
        self.input_drop = nn.Dropout(input_drop)
        for _ in range(num_hops + 1):
            self.inception_ffs.append(
                FeedForwardNet(in_feats, hidden, hidden, n_layers, dropout)
            )
        self.project = FeedForwardNet(
            (num_hops + 1) * hidden, hidden, out_feats, n_layers, dropout
        )
        self.loss_fn = loss_fn

    def forward(self, feats, y=None):
        """Compute logits, and the loss as well when targets are given.

        Args:
            feats: Sequence of feature tensors, paired in order with the
                branches; one tensor per branch is expected.
            y: Optional targets passed to ``loss_fn`` together with the output.

        Returns:
            ``out`` when ``y`` is ``None``; otherwise the tuple ``(out, loss)``.
        """
        feats = [self.input_drop(feat) for feat in feats]
        hidden = [ff(feat) for feat, ff in zip(feats, self.inception_ffs)]
        out = self.project(self.dropout(self.prelu(torch.cat(hidden, dim=-1))))
        # Identity check: `y == None` would route a tensor through its
        # element-wise comparison operator instead of testing for absence.
        if y is None:
            return out
        loss = self.loss_fn(out, y)
        return out, loss

    def reset_parameters(self):
        """Re-initialise the linear layers of every branch and of the projection."""
        for ff in self.inception_ffs:
            ff.reset_parameters()
        self.project.reset_parameters()


def train(model, x, labels, train_loader, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module called as ``model(batch_feat, targets)`` and returning
            ``(output, loss)``.
        x: List of per-hop feature tensors, each indexed by node id
            (original note: ``k * [n, d]``).
        labels: Label container indexed by a batch of node ids.
        train_loader: Iterable yielding batches of node indices.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        The sum of the batch losses divided by the number of batches, as a
        detached tensor.

    Raises:
        ZeroDivisionError: If ``train_loader`` yields no batches.
    """
    model.train()
    total_loss, iter_num = 0, 0
    for batch in tqdm(train_loader):
        batch_feat = [feat[batch].cuda() for feat in x]
        # label_platten is expected to yield [b, c] targets (per the original
        # note) -- TODO confirm against utils.
        y_true = utils.label_platten(labels[batch], num_classes)
        _, loss = model(batch_feat, y_true.cuda())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate a detached value: summing the live loss tensor would keep
        # autograd history for every batch alive until the epoch finishes.
        total_loss += loss.detach()
        iter_num += 1
    return total_loss / iter_num


@torch.no_grad()
def test(model, x, labels, test_loader):
    """Evaluate ``model`` over ``test_loader`` and return averaged metrics.

    Each batch is scored with ``utils.evaluate``; the four per-batch metrics
    are summed and then combined with the batch count by ``utils.average``.

    Args:
        model: Module called as ``model(batch_feat)`` and returning logits.
        x: List of per-hop feature tensors, each indexed by node id.
        labels: Label container indexed by a batch of node ids.
        test_loader: Iterable yielding batches of node indices.

    Returns:
        Tuple ``(acc, precision, recall, f1)`` as produced by ``utils.average``.
    """
    model.eval()

    acc_sum = 0
    precision_sum = 0
    recall_sum = 0
    f1_sum = 0
    num_batches = 0
    for node_ids in test_loader:
        hop_inputs = [feat[node_ids].cuda() for feat in x]
        logits = model(hop_inputs)  # tensor [b,c]
        # Labels are converted with the module-level class count; [b,c] per
        # the original note.
        targets = utils.label_platten(labels[node_ids], num_classes)
        batch_acc, batch_precision, batch_recall, batch_f1 = utils.evaluate(
            logits, targets.cuda()
        )
        acc_sum += batch_acc
        precision_sum += batch_precision
        recall_sum += batch_recall
        f1_sum += batch_f1
        num_batches += 1

    acc, precision, recall, f1 = utils.average(
        acc_sum, precision_sum, recall_sum, f1_sum, num_batches
    )
    return acc, precision, recall, f1


def _load_dataset(name, num_hops):
    """Load the hop features, labels and class count for a named dataset.

    Args:
        name: One of ``'wiki_full'``, ``'wiki_1M'`` or ``'wiki_10M'``.
        num_hops: Number of propagated feature files to load after the base one.

    Returns:
        Tuple ``(feats, labels, n_classes)`` where ``feats`` is a list of
        ``num_hops + 1`` feature objects (base features first).

    Raises:
        ValueError: If ``name`` is not a known dataset.
    """
    # name -> (feature path stem, label file, fixed class count or None)
    layouts = {
        'wiki_full': ('feature_all/feature_all_pca', 'multilabels/labels_cluster.pth', None),
        'wiki_1M': ('feature_1M/feature_1M_pca', 'multilabels/labels_cluster_1M.pth', 2000),
        'wiki_10M': ('feature_10M/feature_10M_pca', 'multilabels/labels_cluster_10M.pth', 2000),
    }
    if name not in layouts:
        raise ValueError(
            f'Unknown dataset {name!r}; expected one of {sorted(layouts)}.'
        )
    stem, label_path, fixed_classes = layouts[name]

    # NOTE: pickle.load / torch.load can execute arbitrary code while
    # deserialising; only point this script at trusted feature files.
    if name == 'wiki_full':
        # The full dataset stores its base features as a pickle file.
        with open(f'{stem}.pkl', 'rb') as fh:
            base = pkl.load(fh)
    else:
        base = torch.load(f'{stem}.pth')

    feats = [base]
    for hop in range(1, num_hops + 1):
        feats.append(torch.load(f'{stem}_{hop}.pth'))

    labels = torch.load(label_path)
    if fixed_classes is None:
        # Derive the class count from the largest label id.
        n_classes = torch.max(labels).item() + 1
    else:
        n_classes = fixed_classes
    return feats, labels, n_classes


def main():
    """Train and evaluate SIGN on a UniKG dataset from the command line.

    Parses the CLI options, loads the precomputed hop features and labels,
    makes a random 80/20 train/test split, then for each run trains for
    ``--epochs`` epochs, evaluating after every epoch. Metrics are appended to
    a text log each epoch and the model is saved every ``--save_steps`` epochs.

    Side effects: sets the module-level ``loss_function`` and ``num_classes``
    (the latter is read by ``train`` and ``test``), writes a log file in the
    working directory and checkpoints under ``model/mlp``.

    Raises:
        RuntimeError: If ``--use_sign_embedding`` is not given.
        ValueError: If ``--dataset`` is not a known dataset name.
    """
    parser = argparse.ArgumentParser(description='UniKG-SIGN')
    parser.add_argument('--dataset', type=str, default='wiki_full')
    parser.add_argument('--use_sign_embedding', action='store_true')
    parser.add_argument('--num_hops', type=int, default=4)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--save_steps', type=int, default=10)
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0)
    parser.add_argument('--input_drop', type=float, default=0)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--runs', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=300000)
    args = parser.parse_args()
    print(args)

    global loss_function, num_classes
    torch.manual_seed(123)

    if not args.use_sign_embedding:
        raise RuntimeError('UniKG need sign embedding.')

    x, y_true, num_classes = _load_dataset(args.dataset, args.num_hops)
    print(y_true , y_true.shape)

    # Random 80/20 train/test split over all entities (no validation split).
    entity_num = x[0].shape[0]
    shuffle_idx = torch.randperm(entity_num)
    num_train = int(0.8 * entity_num)
    train_idx = shuffle_idx[:num_train]
    test_idx = shuffle_idx[num_train:]

    train_loader = torch.utils.data.DataLoader(train_idx, batch_size=args.batch_size, shuffle=True, drop_last=False)
    test_loader = torch.utils.data.DataLoader(test_idx, batch_size=args.batch_size, shuffle=False, drop_last=False)

    loss_function = nn.BCEWithLogitsLoss()
    model = SIGN(x[0].size(-1), args.hidden_channels, num_classes, args.num_hops, args.num_layers, args.dropout, args.input_drop, loss_function).cuda()

    # torch.save does not create missing parent directories.
    os.makedirs('model/mlp', exist_ok=True)

    log_path = f'{args.dataset}-{args.num_layers}-{args.hidden_channels}-{parser.description}.txt'
    with open(log_path, 'w') as log:
        for run in range(args.runs):
            if run > 0:
                # Later runs must not continue from the previous run's
                # weights. SIGN.reset_parameters re-initialises the linear
                # layers only; the PReLU slopes are left as they are.
                model.reset_parameters()
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
            for epoch in range(1, 1 + args.epochs):
                loss = train(model, x, y_true, train_loader, optimizer)
                acc, precision, recall, f1 = test(model, x, y_true, test_loader)
                if epoch % args.save_steps == 0:
                    torch.save(model, f'model/mlp/{parser.description}-{args.dataset}-{epoch}.pth')

                txt = f'Run: {run + 1:02d}, Epoch: {epoch:02d}, Loss: {loss:.4f}, Acc: {100 * acc:.2f}%, precision: {100 * precision:.2f}%, recall: {100 * recall:.2f}%, f1: {100 * f1:.2f}\n'
                if epoch % args.log_steps == 0:
                    print(txt)
                # Every epoch is logged to file; flush so progress survives a crash.
                log.write(txt)
                log.flush()

# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()