import torch
import numpy as np
from basic_gnn import GCN
from gdot import NNNodeBenchmarker_GDOT
from sampler import GraphSAINTRandomWalkBalancedSampler, GraphSAINTRandomWalkSampler
from utils import DomainDataset
import argparse
from copy import deepcopy




def train_single_run(benchmark, loader, tgt_loader, source_val, target_test,
                     num_epoch, device, alpha, beta,
                     early_stop_criterion=3, warmup_epochs=10):
    """Train ``benchmark`` once from freshly reset parameters and return the
    target-domain test metrics of the epoch with the best validation micro-F1.

    All bookkeeping (best score, early-stop counter, best test metrics) is
    local to this call, so consecutive runs cannot influence each other.

    Args:
        benchmark: object exposing ``reset_parameters()``,
            ``train_batch(batch, tgt_batch, alpha=..., beta=...)`` and
            ``test(data, test_on_val=...)``; ``test`` must return a mapping
            with at least ``'f1_micro'``, ``'rocauc_ovr'`` and ``'accuracy'``.
        loader: iterable of source-domain batches (each supports ``.to(device)``).
        tgt_loader: iterable of target-domain batches, consumed in lockstep
            with ``loader`` (iteration stops at the shorter of the two).
        source_val: data passed to ``benchmark.test(..., test_on_val=True)``.
        target_test: data passed to ``benchmark.test(..., test_on_val=False)``.
        num_epoch: maximum number of epochs.
        device: device the batches are moved to before each training step.
        alpha: forwarded to ``train_batch``.
        beta: forwarded to ``train_batch``.
        early_stop_criterion: training stops once the number of consecutive
            non-improving epochs (counted only after ``warmup_epochs``)
            exceeds this value.
        warmup_epochs: non-improving epochs with index ``<= warmup_epochs``
            do not count towards early stopping.

    Returns:
        The metrics returned by ``benchmark.test(target_test, ...)`` at the
        best validation epoch, or ``None`` if ``num_epoch`` is 0.
    """
    best_val_f1 = 0.0
    best_test_metrics = None
    num_early_stop = 0
    loss = None  # stays None if the loaders yield no batches

    benchmark.reset_parameters()
    for epoch in range(num_epoch):
        for batch, tgt_batch in zip(loader, tgt_loader):
            loss = benchmark.train_batch(batch.to(device), tgt_batch.to(device),
                                         alpha=alpha, beta=beta)

        with torch.no_grad():
            val_metrics = benchmark.test(source_val, test_on_val=True)
            test_metrics = benchmark.test(target_test, test_on_val=False)

        # The first evaluated epoch is always recorded so that a run whose
        # validation score never exceeds 0.0 still reports its own metrics.
        if best_test_metrics is None or val_metrics['f1_micro'] > best_val_f1:
            print(f"epoch:{epoch}, loss:{loss}")
            print(f"best roc on validation {val_metrics['rocauc_ovr']}, {val_metrics['accuracy']}")
            best_val_f1 = val_metrics['f1_micro']
            best_test_metrics = test_metrics
            num_early_stop = 0
        elif epoch > warmup_epochs:
            num_early_stop += 1

        if num_early_stop > early_stop_criterion:
            break

    return best_test_metrics


if __name__ == '__main__':

    # NOTE: dataset_pair_id selects the default dataset names AND drives the
    # mask / epoch / architecture branches below; overriding --source_dataset
    # or --target_dataset on the command line does not change those branches.
    dataset_pair = [['acm', 'dblp'], ['acm_before_2010', 'acm_after_2010'], ['acm_large', 'dblp_large']]
    dataset_pair_id = 2
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument("--source_dataset", type=str, default=dataset_pair[dataset_pair_id][0])
    parser.add_argument("--target_dataset", type=str, default=dataset_pair[dataset_pair_id][1])
    parser.add_argument('--log', action='store_true')
    parser.add_argument('--alpha', type=float, default=0.1)
    parser.add_argument('--beta', type=float, default=0.1)
    parser.add_argument('--gpu', type=int, default=0)
    args = parser.parse_args()
    print(args)

    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else "cpu")
    dataset = DomainDataset(f"dataset/ood_testing/{args.source_dataset}", name=args.source_dataset)
    source_data = dataset[0]

    dataset = DomainDataset(f"dataset/ood_testing/{args.target_dataset}", name=args.target_dataset)
    target_data = dataset[0]

    # Validation is performed on the source graph's test_mask.
    train_mask = source_data.train_mask = source_data.train_mask.bool()
    val_mask = source_data.test_mask = source_data.test_mask.bool()
    source_num_nodes, target_num_nodes = source_data.x.size(0), target_data.x.size(0)
    test_mask = torch.ones((target_num_nodes,)).bool()

    if dataset_pair_id in [1]:
        num_epoch = 100
        # For this pair the first source_num_nodes target nodes are excluded
        # from testing and from the target sampler's [range_l, range_r) range.
        test_mask[:source_num_nodes] = False
        range_l, range_r = source_num_nodes, target_num_nodes
    else:
        num_epoch = 50
        range_l, range_r = 0, target_num_nodes
    target_data.test_mask = test_mask

    if dataset_pair_id in [1, 2, 3]:
        GCN_benchmark = NNNodeBenchmarker_GDOT(
            arch='MLP_GCN',
            model_class=GCN,
            benchmark_params={'lr': args.lr, 'epochs': 200},
            h_params={'in_channels': source_data.x.shape[1],
                      'hidden_channels': 128,
                      'num_layers': 2,
                      'out_channels': source_data.y.max().item() + 1,
                      'dropout': 0.2},
            device=device)
    elif dataset_pair_id in [0]:
        GCN_benchmark = NNNodeBenchmarker_GDOT(
            arch='I_GCN',
            model_class=GCN,
            benchmark_params={'lr': args.lr, 'epochs': 200},
            h_params={'in_channels': source_data.x.shape[1],
                      'hidden_channels': 128,
                      'num_layers': 2,
                      'out_channels': source_data.y.max().item() + 1,
                      'dropout': 0.1},
            device=device)

    GCN_benchmark.SetMasks(train_mask, val_mask, test_mask)

    loader = GraphSAINTRandomWalkBalancedSampler(source_data, batch_size=256,
                                        walk_length=2,
                                        num_steps=50,
                                        sample_coverage=100,
                                        save_dir=f'./dataset/ood_testing/{args.source_dataset}',
                                        #num_workers=4
                                        )
    tgt_loader = GraphSAINTRandomWalkSampler(target_data, batch_size=256,
                                        walk_length=2,
                                        num_steps=50,
                                        sample_coverage=10,
                                        range_l=range_l,
                                        range_r=range_r,
                                        save_dir=f'./dataset/ood_testing/{args.target_dataset}',
                                        #num_workers=4
                                        )

    # Full-graph copies on the target device, used for per-epoch evaluation.
    source_val = deepcopy(source_data).to(device)
    target_test = deepcopy(target_data).to(device)

    num_run = 10
    multi_run_f1_micro = []
    multi_run_f1_macro = []
    early_stop_criterion = 3
    for _ in range(num_run):
        # Each run starts from reset parameters with its own early-stop state.
        best_test_metrics = train_single_run(
            GCN_benchmark, loader, tgt_loader, source_val, target_test,
            num_epoch, device, args.alpha, args.beta,
            early_stop_criterion=early_stop_criterion)

        print("======Test result on best validation======")
        print("roc_auc:", best_test_metrics['rocauc_ovr'])
        print("accuracy:", best_test_metrics['accuracy'])
        print("micro f1:", best_test_metrics['f1_micro'])
        print("macro f1:", best_test_metrics['f1_macro'])
        multi_run_f1_micro.append(best_test_metrics['f1_micro'])
        multi_run_f1_macro.append(best_test_metrics['f1_macro'])

    print(f"micro f1: mean {np.mean(multi_run_f1_micro)}, std {np.std(multi_run_f1_micro)}")
    print(f"macro f1: mean {np.mean(multi_run_f1_macro)}, std {np.std(multi_run_f1_macro)}")
