import torch
from torch_geometric.loader import DataLoader
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from models.gnn import GNN
from sklearn.metrics import roc_auc_score
from collections import Counter

import os
from tqdm import tqdm
import argparse
import time
import numpy as np
import pickle
import higher
from torch.autograd import Variable
import copy
import shutil

### importing OGB
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator

### importing loss
from auclosses import AUCLoss_multiLabel

# CUDA float tensor type, intended for ``tensor.type(dtype)`` casts onto the GPU
# (requires a CUDA-enabled torch build at the point of use).
dtype = torch.cuda.FloatTensor

# Binary cross-entropy applied to raw logits (the cross-entropy training step).
cls_criterion = torch.nn.BCEWithLogitsLoss()
# Mean-squared-error criterion; NOTE(review): not referenced in the visible code.
reg_criterion = torch.nn.MSELoss()

def set_all_seeds(SEED):
    '''
    Seed every random number generator this script may use and make cuDNN deterministic.

    Seeds Python's ``random`` module (previously missed), NumPy's global RNG
    and PyTorch (``torch.manual_seed`` seeds all devices), then disables
    cuDNN benchmarking so kernel selection is reproducible.

    Args:
        SEED: integer seed shared by all generators.
    '''
    import random

    random.seed(SEED)
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def zero_grad(model):
    '''
    Reset, in place, every existing parameter gradient of ``model`` to zero.

    Parameters whose ``.grad`` is ``None`` are skipped.
    '''
    gradients = (param.grad for _, param in model.named_parameters())
    for grad in gradients:
        if grad is None:
            continue
        grad.data.zero_()

def log(log_file_path, string):
    '''
    Append one line to a log file and echo it to stdout.

    Args:
        log_file_path: path of the log file; opened in append mode, so it is
            created if missing.
        string: text to record; a trailing newline is added in the file.
    '''
    with open(log_file_path, 'a+') as handle:
        handle.write(string + '\n')
        handle.flush()
    print(string)

def proj_sca(x, bound):
    '''
    Clip scalar ``x`` to the interval [0, ``bound``].

    ``x`` above ``bound`` yields ``bound``; otherwise ``x`` below 0 yields 0;
    anything else is returned unchanged.
    '''
    return bound if x > bound else (0 if x < 0 else x)

class AUCLoss_multiLabel():
    '''
    Per-task building blocks of the min-max AUC-M surrogate objective.

    This definition shadows the ``AUCLoss_multiLabel`` imported from
    ``auclosses`` at the top of the file.

    Each method works on a single task (label column) chosen by ``task``:
    the positive ratio is ``self.p[task]`` and the auxiliary variables ``a``,
    ``b`` and ``alpha`` are indexed by ``task`` too. Means are taken over all
    entries of ``outputs``, with an indicator mask selecting positives
    (``targets == 1``) or negatives (``targets == 0``).

    Args:
        imratio: indexable of per-task positive ratios.
        m: margin used by ``g2``.
    '''

    def __init__(self, imratio, m=1.0):
        self.p = imratio
        self.m = m

    @staticmethod
    def _indicator(targets, label):
        '''Float mask: 1 where ``targets == label``, 0 elsewhere.'''
        return (label == targets).float()

    def g1(self, outputs, a, b, targets, task=0):
        '''(1-p)*mean((h-a)^2 * 1[y=1]) + p*mean((h-b)^2 * 1[y=0]).'''
        ratio = self.p[task]
        pos_part = torch.mean((outputs - a[task]) ** 2 * self._indicator(targets, 1))
        neg_part = torch.mean((outputs - b[task]) ** 2 * self._indicator(targets, 0))
        return (1 - ratio) * pos_part + ratio * neg_part

    def g1_grad_a(self, outputs, a, targets, task=0):
        '''Derivative of ``g1`` with respect to ``a[task]``.'''
        ratio = self.p[task]
        residual = (outputs - a[task]) * self._indicator(targets, 1)
        return -2 * (1 - ratio) * torch.mean(residual)

    def g1_grad_b(self, outputs, b, targets, task=0):
        '''Derivative of ``g1`` with respect to ``b[task]``.'''
        ratio = self.p[task]
        residual = (outputs - b[task]) * self._indicator(targets, 0)
        return -2 * ratio * torch.mean(residual)

    def g2(self, outputs, targets, task=0):
        '''-2(1-p)*mean(h * 1[y=1]) + 2p*mean(h * 1[y=0]) + 2p(1-p)*m (the term scaled by alpha).'''
        ratio = self.p[task]
        pos_mean = torch.mean(outputs * self._indicator(targets, 1))
        neg_mean = torch.mean(outputs * self._indicator(targets, 0))
        return -2 * (1 - ratio) * pos_mean + 2 * ratio * neg_mean + 2 * ratio * (1 - ratio) * self.m

    def g3(self, alpha, task=0):
        '''p(1-p) * alpha[task]^2.'''
        ratio = self.p[task]
        return ratio * (1 - ratio) * alpha[task] ** 2

    def g3_grad(self, alpha, task=0):
        '''Derivative of ``g3`` with respect to ``alpha[task]``.'''
        ratio = self.p[task]
        return 2 * ratio * (1 - ratio) * alpha[task]


def eval_rocauc(y_true, y_pred):
    '''
    Return the ROC-AUC averaged over tasks (columns of ``y_true``).

    A task contributes only when it has at least one target equal to 1 and
    one equal to 0; NaN targets are dropped before scoring. For each
    contributing task, the kept labels and predictions are also written via
    ``log`` to ``logdir`` (a module-level global assigned in ``main``).

    Raises:
        RuntimeError: if no task has both a positive and a negative target.
    '''
    scores = []

    for col in range(y_true.shape[1]):
        targets = y_true[:, col]
        if not (np.any(targets == 1) and np.any(targets == 0)):
            continue
        # NaN != NaN, so this mask is False exactly at missing labels.
        mask = targets == targets
        log(logdir, f"y_true: {targets[mask]}")
        log(logdir, f"y_pred: {y_pred[mask, col]}")
        scores.append(roc_auc_score(targets[mask], y_pred[mask, col]))

    if not scores:
        raise RuntimeError('No positively labeled data available. Cannot compute ROC-AUC.')

    return sum(scores) / len(scores)


def train(mp_model, device, loader, task_type, imratio, lr_decay=1):
    '''
    Run one epoch of compositional training, alternating two step types.

    Even steps take a cross-entropy step at ``w_weights`` on a random subset
    of tasks and apply it to both ``v_weights`` and ``w_weights``. Odd steps
    evaluate the multi-label AUC-M loss at ``u_weights`` on another random
    task subset, update the per-task auxiliary variables ``a``/``b`` (with
    momentum ``z_a``/``z_b``) and ``alpha`` (clamped to [0, 999]), update
    ``u`` and ``w`` (momentum ``z_w_list``) and finally reset ``u`` and ``v``
    to ``w``.

    Reads the module-level globals assigned in ``main``: ``args``,
    ``w_weights``, ``u_weights``, ``v_weights``, ``a``, ``b``, ``alpha``,
    ``z_a``, ``z_b`` and ``z_w_list``.

    Args:
        mp_model: functional model called as ``mp_model(batch, params=...)``.
        device: device every batch is moved to.
        loader: iterable of batches exposing ``x``, ``batch`` and ``y``.
        task_type: unused; kept for interface compatibility.
        imratio: per-task positive ratios; its length is the number of tasks.
        lr_decay: divisor applied to ``args.lr`` for this epoch.
    '''
    # eval() leaves the model in inference mode (dropout off, batch-norm
    # running statistics); switch back before each training epoch.
    mp_model.train()

    lr = args.lr / lr_decay
    # The CE step uses lr * lambda_value unless lambda_value is 0.
    lr_v = lr * args.lambda_value if args.lambda_value != 0 else lr
    beta = args.beta
    beta_ct = args.beta_ct

    # One entry per task; shuffled in place every step to draw a task subset.
    label_set = np.arange(len(imratio))
    loss_fn = AUCLoss_multiLabel(imratio=imratio, m=1)

    # CE gradients of the latest even step, consumed by the following odd step.
    grads_ce = None

    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)

        # Skip degenerate batches: a single node in total, or (presumably,
        # since batch.batch maps nodes to graphs) only one graph.
        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            continue

        if step % 2 == 0:
            grads_ce = _ce_step(mp_model, batch, label_set, lr_v, beta_ct)
        else:
            _auc_step(mp_model, batch, label_set, loss_fn, grads_ce, lr, beta, beta_ct)


def _ce_step(mp_model, batch, label_set, lr_v, beta_ct):
    '''Cross-entropy step on a random task subset: update v and w in place, return the CE gradients.'''
    # following Hu et al. 2022
    pred = mp_model(batch, params=w_weights)

    np.random.shuffle(label_set)
    selected = label_set[:args.task_BATCH_SIZE]

    # Score only labeled (non-NaN) targets that belong to the selected tasks.
    task_keep = np.zeros(len(label_set), dtype=bool)
    task_keep[selected] = True
    is_labeled = (batch.y == batch.y) & torch.from_numpy(task_keep).to(batch.y.device)

    loss_ce = cls_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled])
    # torch.autograd.grad returns gradients without accumulating into .grad,
    # so no gradient zeroing is needed beforehand.
    grads_ce = torch.autograd.grad(loss_ce, w_weights, retain_graph=False)

    step_size = lr_v * beta_ct
    for g, v, w in zip(grads_ce, v_weights, w_weights):
        v.data = v.data - step_size * g.data
        w.data = w.data - step_size * g.data
    return grads_ce


def _auc_step(mp_model, batch, label_set, loss_fn, grads_ce, lr, beta, beta_ct):
    '''AUC-M step on a random task subset: update a, b, alpha, then u/w, and reset u and v to w.'''
    np.random.shuffle(label_set)
    selected = label_set[:args.task_BATCH_SIZE]

    pred_sig = torch.sigmoid(mp_model(batch, params=u_weights)).to(torch.float32)
    y_all = batch.y.to(torch.float32)

    # The AUC loss is an average of per-task losses, so a[i], b[i] and
    # alpha[i] receive gradient only from task i (zero if i is not selected).
    grads_a = torch.zeros_like(a)
    grads_b = torch.zeros_like(b)
    grads_alp = torch.zeros_like(alpha)
    loss_auc = None

    for task_idx in selected:
        ## ignore nan targets (unlabeled) when computing training loss.
        is_labeled_i = y_all[:, task_idx] == y_all[:, task_idx]
        if not is_labeled_i.any():
            continue

        y_pred_i = pred_sig[:, task_idx][is_labeled_i]
        y_true_i = y_all[:, task_idx][is_labeled_i]

        g2_i = loss_fn.g2(y_pred_i, y_true_i, task=task_idx)
        grads_a[task_idx] += loss_fn.g1_grad_a(y_pred_i, a, y_true_i, task=task_idx).detach()
        grads_b[task_idx] += loss_fn.g1_grad_b(y_pred_i, b, y_true_i, task=task_idx).detach()
        grads_alp[task_idx] += (g2_i - loss_fn.g3_grad(alpha, task=task_idx)).detach()

        loss_i = loss_fn.g1(y_pred_i, a, b, y_true_i, task=task_idx) \
            + alpha[task_idx] * g2_i \
            - loss_fn.g3(alpha, task=task_idx)
        loss_auc = loss_i if loss_auc is None else loss_auc + loss_i

    if loss_auc is None:
        # No selected task has a labeled sample in this batch, so there is no
        # AUC loss to differentiate; skip this update.
        return

    n_tasks = args.task_BATCH_SIZE
    grads_a = grads_a / n_tasks
    grads_b = grads_b / n_tasks
    grads_alp = grads_alp / n_tasks
    loss_auc = loss_auc / n_tasks

    # Momentum descent on a and b, projected ascent on alpha.
    z_a.data = (1 - beta) * z_a + beta * grads_a
    a.data = a - lr * z_a

    z_b.data = (1 - beta) * z_b + beta * grads_b
    b.data = b - lr * z_b

    alpha.data = torch.clamp(alpha.data + lr * grads_alp, 0, 999)

    grads_auc_u = torch.autograd.grad(loss_auc, u_weights, retain_graph=False)

    if grads_ce is None:
        # No CE step has run yet this epoch (e.g. step 0 was skipped);
        # treat the CE gradient as zero instead of failing.
        grads_ce = [torch.zeros_like(u) for u in u_weights]

    for u, v, g_ce, g_auc_u, w, z_w in zip(u_weights, v_weights, grads_ce, grads_auc_u, w_weights, z_w_list):
        u.data = u.data - lr * (g_auc_u.data + args.lambda_value * (u.data - w.data + beta_ct * g_ce.data))
        grad_w = args.lambda_value * (v.data - u.data)
        z_w.data = (1 - beta) * z_w + beta * grad_w
        w.data = w.data - lr * z_w
        # Reset u and v to the updated w for the next pair of steps.
        u.data = w.data
        v.data = w.data




def eval(model, device, loader):
    '''
    Return the task-averaged ROC-AUC of ``model`` on ``loader``.

    The model runs in eval mode without gradient tracking; batches whose
    ``x`` has a single row are skipped. Raw model outputs are scored. The
    model's previous train/eval mode is restored afterwards, so a following
    training epoch is not silently run with dropout disabled.

    Args:
        model: callable as ``model(batch)`` that exposes ``training``,
            ``eval()`` and ``train(mode)`` like ``torch.nn.Module``.
        device: device batches are moved to.
        loader: iterable of batches exposing ``x`` and ``y``.

    Returns:
        Mean ROC-AUC over tasks, as computed by ``eval_rocauc``.

    Raises:
        RuntimeError: if no batch was evaluated, or (from ``eval_rocauc``)
            if no task has both positive and negative labels.
    '''
    was_training = model.training
    model.eval()
    y_true = []
    y_pred = []

    try:
        with torch.no_grad():
            for batch in tqdm(loader, desc="Iteration"):
                batch = batch.to(device)
                if batch.x.shape[0] == 1:
                    continue
                pred = model(batch)
                y_true.append(batch.y.view(pred.shape).detach().cpu())
                y_pred.append(pred.detach().cpu())
    finally:
        # Undo the side effect of model.eval() on the caller's model.
        model.train(was_training)

    if not y_true:
        raise RuntimeError('No evaluable batches: every batch was skipped or the loader is empty.')

    y_true = torch.cat(y_true, dim = 0).numpy()
    y_pred = torch.cat(y_pred, dim = 0).numpy()

    return eval_rocauc(y_true, y_pred)

def _compute_imratio(labels, num_tasks):
    '''
    Return each task's positive ratio among its *labeled* samples.

    NaN entries are missing labels and are excluded from both numerator and
    denominator, consistent with training, which drops NaN targets before
    evaluating the per-task AUC loss. Previously the denominator counted
    every sample, underestimating the ratio for sparsely labeled tasks.

    Args:
        labels: 2-D label tensor indexed as ``labels[:, task]``; NaN marks a
            missing label (presumably 0/1 otherwise -- confirm per dataset).
        num_tasks: number of leading label columns to summarise.

    Returns:
        List of ``num_tasks`` 0-dim float tensors; a task with no labeled
        sample gets ratio 0.
    '''
    imratio = []
    for task in range(num_tasks):
        column = labels[:, task]
        labeled = ~torch.isnan(column)
        n_labeled = int(labeled.sum())
        n_positive = int(torch.count_nonzero(column[labeled]))
        imratio.append(torch.tensor(n_positive / max(n_labeled, 1)))
    return imratio


def main():
    '''
    Parse arguments, prepare data, model and optimizer state, then train and evaluate.

    Assigns the module-level globals used by ``train`` and ``eval_rocauc``
    (``args``, the w/u/v weight lists, ``a``, ``b``, ``alpha``, ``z_a``,
    ``z_b``, ``z_w_list``, ``logdir``). Any existing ``--save_dir`` directory
    is deleted and recreated. Per-epoch results go to ``<save_dir>/log.txt``;
    the best-validation summary is saved to ``--filename`` when non-empty.
    '''
    global args, logdir
    global w_weights, u_weights, v_weights
    global alpha, a, b, z_a, z_b, z_w_list

    # Training settings
    parser = argparse.ArgumentParser(description='GNN baselines on ogbgmol* data with Pytorch Geometrics')
    parser.add_argument('--gpuid', type=str, default='0',
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--gnn', type=str, default='gin',
                        help='GNN gin, gin-virtual, or gcn, or gcn-virtual (default: gin-virtual)')
    parser.add_argument('--drop_ratio', type=float, default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument('--num_layer', type=int, default=5,
                        help='number of GNN message passing layers (default: 5)')
    parser.add_argument('--emb_dim', type=int, default=300,
                        help='dimensionality of hidden units in GNNs (default: 300)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='number of workers (default: 0)')
    parser.add_argument('--dataset', type=str, default="ogbg-molpcba",
                        help='dataset name (default: ogbg-molpcba)')
    parser.add_argument('--data_dir', type=str, default="")
    parser.add_argument('--feature', type=str, default="full",
                        help='full feature or simple feature')
    parser.add_argument('--filename', type=str, default='results_main_pyg_ct_rand',
                        help='filename to output result (default: )')
    parser.add_argument('--SEED', default=123, type=int, help='random seed')
    parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
    parser.add_argument('--beta', default=0.9, type=float)
    parser.add_argument('--beta_ct', default=0.9, type=float)
    parser.add_argument('--decay_point', default=1000, type=int)
    parser.add_argument('--task_BATCH_SIZE', default=10, type=int)
    parser.add_argument('--method', default='ct', type=str)

    parser.add_argument('--lambda_value', default=1, type=float)
    parser.add_argument('--save_dir', default='exp', type=str)

    args = parser.parse_args()

    # Must be set before CUDA is first initialised to take effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpuid

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    set_all_seeds(args.SEED)

    ### automatic dataloading and splitting
    dataset = PygGraphPropPredDataset(name = args.dataset, root=args.data_dir)

    if args.feature == 'simple':
        print('using simple feature')
        # only retain the top two node/edge features
        dataset.data.x = dataset.data.x[:,:2]
        dataset.data.edge_attr = dataset.data.edge_attr[:,:2]

    split_idx = dataset.get_idx_split()

    train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)
    valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)
    test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)

    ###### Compute imratio
    num_tasks = dataset.num_tasks
    labels = train_loader.dataset.data.y[split_idx["train"]]
    imratio = _compute_imratio(labels, num_tasks)

    # --gnn value -> (gnn_type, virtual_node)
    gnn_configs = {
        'gin': ('gin', False),
        'gin-virtual': ('gin', True),
        'gcn': ('gcn', False),
        'gcn-virtual': ('gcn', True),
    }
    if args.gnn not in gnn_configs:
        raise ValueError('Invalid GNN type')
    gnn_type, virtual_node = gnn_configs[args.gnn]
    model = GNN(gnn_type = gnn_type, num_tasks = num_tasks, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = virtual_node).to(device)

    mp_model = higher.monkeypatch(model, copy_initial_weights=True).to(device)

    # w holds the model's own parameters; u and v start as independent copies.
    w_weights = [param.requires_grad_(True) for param in model.parameters()]
    u_weights = [param.requires_grad_(True) for param in copy.deepcopy(w_weights)]
    v_weights = [param.requires_grad_(True) for param in copy.deepcopy(w_weights)]

    valid_curve = []
    test_curve = []
    train_curve = []

    ## Initials: per-task AUC-M auxiliary variables and their momentum buffers.
    alpha = torch.zeros(num_tasks, device=device)
    a = torch.zeros(num_tasks, device=device)
    b = torch.zeros(num_tasks, device=device)
    z_a = torch.zeros(num_tasks, device=device)
    z_b = torch.zeros(num_tasks, device=device)
    z_w_list = [torch.zeros_like(w) for w in model.parameters()]

    save_dir = args.save_dir
    try:
        os.mkdir(save_dir)
    except FileExistsError:
        # NOTE: an existing save_dir and all of its contents are deleted.
        shutil.rmtree(save_dir)
        os.mkdir(save_dir)
    logdir = os.path.join(save_dir, 'log.txt')

    log(logdir, str(vars(args)))

    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')

        # NOTE(review): the 10x decay starts at epoch decay_point - 1, one
        # epoch before decay_point itself -- confirm this offset is intended.
        lr_decay = 10 if epoch >= (args.decay_point - 1) else 1

        train(mp_model, device, train_loader, dataset.task_type, imratio, lr_decay=lr_decay)

        print('Evaluating...')
        train_perf = eval(mp_model, device, train_loader)
        valid_perf = eval(mp_model, device, valid_loader)
        test_perf = eval(mp_model, device, test_loader)

        print({'Train': train_perf, 'Validation': valid_perf, 'Test': test_perf})
        log(logdir, f'''Train: {train_perf}, Validation: {valid_perf}, Test: {test_perf},
            a = {a}, b = {b}, alpha = {alpha}''')

        train_curve.append(train_perf)
        valid_curve.append(valid_perf)
        test_curve.append(test_perf)

    if not valid_curve:
        # args.epochs < 1: nothing was trained or evaluated.
        return

    best_val_epoch = np.argmax(np.array(valid_curve))
    best_train = max(train_curve)

    print('TEST_AUC = ')
    print(test_curve)
    print('Best validation score: {}'.format(valid_curve[best_val_epoch]))
    print('Test score: {}'.format(test_curve[best_val_epoch]))

    if not args.filename == '':
        torch.save({'Val': valid_curve[best_val_epoch], 'Test': test_curve[best_val_epoch], 'Train': train_curve[best_val_epoch], 'BestTrain': best_train}, args.filename)

if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()