import torch
from torch_geometric.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
from gnn import GNN
from Multi_pAUC_KL import Multi_pAUC_KL
from multi_task_sampler import DataSampler
from utils import PAUC_MultiLabel, partial_auc, pAUC_mini, FocalLoss
from tqdm import tqdm
import argparse
import time
import wandb
import numpy as np
from torch_geometric.data import Data
### importing OGB
from ogb.graphproppred import PygGraphPropPredDataset


# Mean-squared-error loss; train() uses it when the dataset's task type
# does not contain "classification".
reg_criterion = torch.nn.MSELoss()

def set_params():
    """Overwrite settings on the global ``args`` with per-dataset presets.

    The preset is chosen by ``args.dataset``. Every key of the preset is
    assigned as an attribute of ``args``, replacing any value already there
    (this includes ``batch_size`` and ``epochs``). Attributes that are not
    part of the preset are left untouched.

    Raises:
        ValueError: if ``args.dataset`` has no preset. Without one, attributes
            such as ``tasks`` and ``sample_tasks`` would never be defined.
    """
    presets = {
        'ogbg-moltox21': {
            'batch_size': 144,
            'iter_record': None,
            'sample_tasks': 12,
            'tasks': 12,
            'epochs': 100,
            'pretrain': True,
            # settings read when building the 'PAUC' loss
            'beta_pauc': 0.7,
            'eta1_pauc': 0.01,
            'eta2_pauc': 0.01,
            'tau1_pauc': 20,
            'tau2_pauc': 1.,
            # settings read when building the 'SOPA' loss
            'gamma_sopa': 0.1,
            'tau_sopa': 1.,
            # settings read when building the 'focal' loss
            'gamma_focal': 2,
            'alpha_focal': 0.5,
        },
        'ogbg-molpcba': {
            'batch_size': 128,
            'iter_record': None,
            'sample_tasks': 1,
            'tasks': 127,
            'epochs': 100,
            'pretrain': True,
            'beta_pauc': 0.5,
            'eta1_pauc': 0.01,
            'eta2_pauc': 0.01,
            'tau1_pauc': 10,
            'tau2_pauc': 1,
            'gamma_sopa': 0.1,
            'tau_sopa': 1.,
            'gamma_focal': 2,
            'alpha_focal': 0.75,
        },
    }

    try:
        preset = presets[args.dataset]
    except KeyError:
        raise ValueError(
            'No parameter preset for dataset {!r}; supported datasets: {}'.format(
                args.dataset, ', '.join(sorted(presets)))
        ) from None

    for name, value in preset.items():
        setattr(args, name, value)




def train(model, cls_criterion, device, loader, optimizer, task_type, T):
    """Run one pass over ``loader``, updating ``model`` after every batch.

    Args:
        model: network called as ``model(batch)``; a sigmoid is applied to
            its output before any loss is computed (for every task type).
        cls_criterion: loss used when ``task_type`` contains
            "classification". If ``args.ls`` is 'SOPA', 'PAUC' or 'MB' it is
            called as ``cls_criterion(pred, y, batch.idx)`` on the full
            batch; otherwise it receives only the labeled entries.
        device: device every batch is moved to.
        loader: iterable yielding batches exposing ``x``, ``batch``, ``y``
            (and ``idx`` for the losses listed above).
        optimizer: optimizer stepped once per processed batch.
        task_type: string; any value without "classification" selects the
            module-level ``reg_criterion``.
        T: number of iterations performed before this call.

    Returns:
        ``T`` plus the number of batches drawn from ``loader`` (skipped
        batches are counted too). ``T`` is passed by value, so a caller that
        wants the count to keep growing across calls must store this result.
    """
    model.train()
    is_classification = "classification" in task_type
    # These losses take the sample indices as a third argument.
    needs_sample_idx = args.ls in ['SOPA', 'PAUC', 'MB']

    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)
        T += 1
        if args.ls == 'PAUC':
            # beta1 shrinks with the iteration count.
            cls_criterion.beta1 = 1. / np.sqrt(T)

        # Skip batches whose node matrix has a single row or whose last node
        # belongs to graph 0 (presumably a single-graph batch -- TODO confirm).
        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            continue

        pred = torch.sigmoid(model(batch)).to(torch.float32)
        target = batch.y.to(torch.float32)
        optimizer.zero_grad()

        # NaN != NaN, so this mask is False exactly at NaN (unlabeled) targets.
        is_labeled = batch.y == batch.y
        if is_classification and needs_sample_idx:
            loss = cls_criterion(pred, target, batch.idx)
        elif is_classification:
            loss = cls_criterion(pred[is_labeled], target[is_labeled])
        else:
            loss = reg_criterion(pred[is_labeled], target[is_labeled])

        loss.backward()
        optimizer.step()

    return T
            

def eval(model, device, loader):
    """Score ``model`` on ``loader`` and return two partial-AUC values.

    Batches are moved to ``device``, predictions are passed through a
    sigmoid, and targets/predictions are gathered on the CPU. Batches whose
    node matrix has a single row are left out.

    For 'ogbg-molpcba', a batch whose label matrix has 128 columns gets
    column 45 removed before it is used.

    Returns:
        ``[partial_auc(max_fpr=0.1), partial_auc(max_fpr=0.3)]`` computed
        over all gathered predictions and targets.
    """
    model.eval()
    targets = []
    scores = []

    # Boolean mask over 128 columns that is False only at index 45.
    keep_columns = torch.LongTensor(range(128)) != 45

    for batch in tqdm(loader, desc="Iteration"):
        has_full_label_width = batch.y.shape[1] == 128 if args.dataset == 'ogbg-molpcba' else False
        if has_full_label_width:
            batch.y = batch.y[:, keep_columns]
        batch = batch.to(device)

        if batch.x.shape[0] == 1:
            continue

        with torch.no_grad():
            logits = model(batch)
            pred = torch.sigmoid(logits)

        targets.append(batch.y.view(pred.shape).detach().cpu())
        scores.append(pred.detach().cpu())

    y_true = torch.cat(targets, dim=0).numpy()
    y_pred = torch.cat(scores, dim=0).numpy()

    return [partial_auc(y_pred, y_true, max_fpr=cap) for cap in (0.1, 0.3)]


def main():
    """Train a GNN on an OGB graph dataset and record partial-AUC curves.

    Steps:
      1. Parse the command line into the global ``args`` and apply the
         per-dataset presets via ``set_params()``.
      2. Load the dataset, rebuild the training graphs with an extra ``idx``
         attribute (and, for 'ogbg-molpcba', without label column 45), and
         create the train/valid/test loaders.
      3. Build the model, the loss selected by ``--ls`` and the optimizer.
      4. Train for ``args.epochs`` epochs, evaluating on all three splits
         after each epoch; whenever the validation score at max_fpr=0.1
         improves, the model weights are written to ``models/``.
      5. Save the six metric curves with ``np.save``.

    Raises:
        ValueError: for an unknown ``--gnn`` or ``--ls`` value.
    """
    import os  # local import: only needed to create the checkpoint directory

    global args

    # Training settings
    parser = argparse.ArgumentParser(description='GNN baselines on ogbgmol* data with Pytorch Geometrics')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--gnn', type=str, default='gin-virtual',
                        help='GNN gin, gin-virtual, or gcn, or gcn-virtual (default: gin-virtual)')
    parser.add_argument('--drop_ratio', type=float, default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument('--num_layer', type=int, default=5,
                        help='number of GNN message passing layers (default: 5)')
    parser.add_argument('--emb_dim', type=int, default=300,
                        help='dimensionality of hidden units in GNNs (default: 300)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers', type=int, default=2,
                        help='number of workers (default: 2)')
    parser.add_argument('--dataset', type=str, default="ogbg-molpcba",
                        help='dataset name (default: ogbg-molpcba)')
    parser.add_argument('--feature', type=str, default="full",
                        help='full feature or simple feature')
    parser.add_argument('--filename', type=str, default="",
                        help='filename to output result (default: )')
    parser.add_argument('--ls', default='PAUC', type=str)

    args = parser.parse_args()

    # Dataset presets overwrite several parsed values (e.g. batch_size, epochs).
    set_params()

    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")

    ## automatic dataloading and splitting
    dataset = PygGraphPropPredDataset(name=args.dataset)

    if args.feature == 'simple':
        print('using simple feature')
        # only retain the top two node/edge features
        dataset.data.x = dataset.data.x[:, :2]
        dataset.data.edge_attr = dataset.data.edge_attr[:, :2]

    split_idx = dataset.get_idx_split()

    train_labels = dataset.data.y[split_idx["train"]]
    train_dataset = dataset[split_idx["train"]]

    # Rebuild every training graph with its position in the training split
    # stored as ``idx``; train() forwards it to the 'SOPA'/'PAUC'/'MB' losses.
    train_data = []
    # Boolean mask over 128 label columns that is False only at column 45.
    nan_mask = torch.LongTensor(range(128))
    nan_mask = nan_mask != 45
    for idx, data in enumerate(train_dataset):
        if args.dataset == 'ogbg-molpcba':
            data_dict = {k: v for k, v in data if k != 'y'}
            data_dict['y'] = data.y[:, nan_mask]
        else:
            data_dict = {k: v for k, v in data}
        data_dict['idx'] = idx
        train_data.append(Data(**data_dict))

    train_loader = DataLoader(train_data,
                              sampler=DataSampler(train_labels, batchSize=args.batch_size, multi_tasks=args.sample_tasks),
                              batch_size=args.batch_size, num_workers=args.num_workers, drop_last=True)
    valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # --gnn value -> (message-passing layer type, use a virtual node)
    gnn_variants = {
        'gin': ('gin', False),
        'gin-virtual': ('gin', True),
        'gcn': ('gcn', False),
        'gcn-virtual': ('gcn', True),
    }
    if args.gnn not in gnn_variants:
        raise ValueError('Invalid GNN type')
    gnn_type, virtual_node = gnn_variants[args.gnn]
    model = GNN(gnn_type=gnn_type, num_tasks=args.tasks, num_layer=args.num_layer,
                emb_dim=args.emb_dim, drop_ratio=args.drop_ratio, virtual_node=virtual_node).to(device)

    if args.ls == 'CE':
        cls_criterion = torch.nn.BCELoss()
    elif args.ls == 'focal':
        cls_criterion = FocalLoss(gamma=args.gamma_focal, alpha=args.alpha_focal)
    elif args.ls == 'SOPA':
        cls_criterion = Multi_pAUC_KL(data_len=len(train_loader.dataset), gamma=args.gamma_sopa, Lambda=args.tau_sopa, total_tasks=args.tasks)
    elif args.ls == 'PAUC':
        cls_criterion = PAUC_MultiLabel(num_classes=args.tasks, eta1=args.eta1_pauc, eta2=args.eta2_pauc, beta=args.beta_pauc, tau1=args.tau1_pauc, tau2=args.tau2_pauc)
    elif args.ls == 'MB':
        cls_criterion = pAUC_mini(threshold=1., gamma=0.7)
    else:
        # Fail here with a clear message instead of hitting an unbound
        # ``cls_criterion`` once training starts.
        raise ValueError('Invalid loss type')

    if args.ls in ['PAUC']:
        # The PAUC loss exposes parameters of its own; they are optimized
        # together with the model.
        optimizer = optim.Adam(list(model.parameters()) + list(cls_criterion.parameters()), lr=5e-4, weight_decay=1e-4)
    else:
        optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50], gamma=0.1)

    # NOTE(review): ``idx`` is the leftover loop variable from building
    # ``train_data`` (index of the last training graph); kept so the output
    # file name does not change -- confirm whether a run id was intended.
    model_path = 'models/' + args.dataset + '_' + args.ls + str(idx) + '.pth'

    # Curves without suffix hold the max_fpr=0.1 metric, the ``1`` curves
    # hold the max_fpr=0.3 metric (order returned by eval()).
    valid_curve = []
    test_curve = []
    train_curve = []
    valid_curve1 = []
    test_curve1 = []
    train_curve1 = []
    best_val = 0
    T = 0
    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        step_count = train(model, cls_criterion, device, train_loader, optimizer, dataset.task_type, T)
        # T is handed to train() by value, so the iteration count only
        # survives across epochs if the value train() hands back is stored.
        # A None result (no count reported) leaves T as it was.
        if step_count is not None:
            T = step_count

        print('Evaluating...')
        train_perf = eval(model, device, train_loader)
        valid_perf = eval(model, device, valid_loader)
        test_perf = eval(model, device, test_loader)

        print({'Train': train_perf, 'Validation': valid_perf, 'Test': test_perf})

        train_curve.append(train_perf[0])
        valid_curve.append(valid_perf[0])
        test_curve.append(test_perf[0])
        train_curve1.append(train_perf[1])
        valid_curve1.append(valid_perf[1])
        test_curve1.append(test_perf[1])

        scheduler.step()

        # Track the best validation score per epoch and checkpoint the
        # weights that produced it.
        if valid_curve[-1] > best_val:
            best_val = valid_curve[-1]
            # torch.save does not create missing directories.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            torch.save(model.state_dict(), model_path)

    print('Finished training!')
    print('Best validation score: {}'.format(best_val))
    # Final-epoch train / validation / test scores at max_fpr=0.1.
    print('Test score: {}, {}, {}'.format(train_curve[-1], valid_curve[-1], test_curve[-1]))

    np.save(args.dataset + '_' + args.ls + '_history_' + '.npy', [train_curve, valid_curve, test_curve, train_curve1, valid_curve1, test_curve1])

# Start the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()
