import argparse

import torch

from logger import Logger
from engine import train, test
from model import DeformableGT
from utils import load_data, load_split, init_seed, generate_reference
from tqdm import tqdm

def _build_parser():
    """Build the command-line parser for the DeformableGraphTransformer script.

    Option names, types, defaults and registration order are the script's
    public CLI surface and are kept stable.  Many options are not read in this
    file; the whole parsed namespace is handed to ``DeformableGT`` and
    ``Logger``, which presumably consume them (verify in those modules).

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description='DeformableGraphTransformer')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden_channels', type=int, default=64)
    parser.add_argument('--dropout', type=float, default=0)
    parser.add_argument('--lr', type=float, default=0.005)
    parser.add_argument('--offset_lr', type=float, default=0.05)
    parser.add_argument('--decay', type=float, default=5e-5)
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--runs', type=int, default=30)
    parser.add_argument('--edge_path', type=str, default='khopdata/ogbn-arxiv/1hop_edge.npy')
    parser.add_argument('--load_edge', action='store_true')
    parser.add_argument("--data_name", type=str, default="cora")
    parser.add_argument("--pe_type", type=str, default="katz_learnable_remove_self_loop")
    parser.add_argument("--relpe_type", type=str, default="no")
    parser.add_argument("--k", type=int, default=2)
    parser.add_argument("--T", type=float, default=0.5)
    parser.add_argument('--use_1hop', action='store_true')
    parser.add_argument('--model_type', type=str, default='mlp')
    parser.add_argument('--pat_epochs', type=int, default=100)
    parser.add_argument('--memo', type=str, default="linear_cora")
    parser.add_argument('--num_blocks', type=int, default=1)
    parser.add_argument('--khop', type=int, default=0)
    parser.add_argument('--sparse', action='store_true')
    parser.add_argument('--sparse_p', type=float, default=0.01, choices=[0.001, 0.005, 0.01,0.05,0.1,0.2,0.5])
    parser.add_argument('--sched', action='store_true')

    parser.add_argument('--sort_type', type=str, default="ppnp_load,feat_load,bfs_load")
    parser.add_argument('--interpolate_mode', type=str, default='gaussian')
    parser.add_argument('--bandwidth', type=int, default=24)
    parser.add_argument('--multi_threads', action='store_true')
    parser.add_argument('--separate_offset_params', action='store_true')
    parser.add_argument("--num_heads", type=int, default=4)
    parser.add_argument("--num_points", type=int, default=4)
    parser.add_argument("--eps", type=int, default=6)
    parser.add_argument("--query_gnn", action='store_true')
    parser.add_argument("--value_gnn", action='store_true')
    parser.add_argument("--change_bias", action='store_true')
    parser.add_argument("--bias", type=float, default=0)
    parser.add_argument("--change_weight", action='store_true')
    parser.add_argument("--weight", type=float, default=1)
    parser.add_argument("--gnn_encoding", action='store_true')
    parser.add_argument("--bandwidth_decaying", type=float, default=1)
    parser.add_argument("--min_bandwidth", type=float, default=8)
    parser.add_argument("--remove_negative_points", action='store_true')
    parser.add_argument("--multi_bandwidth", type=str, default='')
    parser.add_argument('--separate_bandwidth_params', action='store_true')
    parser.add_argument('--bottleneck', type=int, default=64)
    return parser


def _build_optimizer(model, args):
    """Create the Adam optimizer, optionally with a dedicated parameter group.

    When ``args.separate_offset_params`` is set, parameters whose name contains
    ``'sampling_offsets'`` go into a second group trained with
    ``args.offset_lr`` and no weight decay.  Otherwise, when
    ``args.separate_bandwidth_params`` is set, parameters whose name contains
    ``'multi_bandwidth_attn'`` get a second group with a fixed learning rate of
    0.05 and no weight decay.  The offset flag takes priority if both are
    given.  With neither flag, all parameters share one group.

    Args:
        model: module exposing ``parameters()`` / ``named_parameters()``.
        args: parsed CLI namespace (reads ``lr``, ``decay``, ``offset_lr`` and
            the two ``separate_*_params`` flags).

    Returns:
        torch.optim.Adam: group 0 holds the regular parameters; group 1, when
        present, holds the separately-tuned ones.
    """
    # Learning rate for the bandwidth group; unlike the offset group it has no
    # CLI option and has always been this fixed value.
    bandwidth_lr = 0.05

    if args.separate_offset_params:
        marker, special_lr = 'sampling_offsets', args.offset_lr
    elif args.separate_bandwidth_params:
        marker, special_lr = 'multi_bandwidth_attn', bandwidth_lr
    else:
        return torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)

    special_params, regular_params = [], []
    for name, param in model.named_parameters():
        (special_params if marker in name else regular_params).append(param)

    # Per-group keys override the optimizer-wide defaults given as keyword
    # arguments, so the second group gets its own lr and no weight decay.
    param_groups = [
        {'params': regular_params},
        {'params': special_params, 'lr': special_lr, 'weight_decay': 0},
    ]
    return torch.optim.Adam(param_groups, lr=args.lr, weight_decay=args.decay)


def main():
    """Parse CLI options, then train and evaluate DeformableGT for ``--runs`` runs.

    For every run the seed is set from the run index, a split is loaded, a
    fresh model and optimizer are built, and training proceeds for at most
    ``--epochs`` epochs.  Each epoch's (train, valid, test) accuracies are
    handed to the logger.  A run stops early once ``--pat_epochs`` consecutive
    epochs bring neither a higher validation accuracy nor a lower ``loss_val``
    (presumably the validation loss returned by ``test``) than the best values
    seen so far in that run.

    Returns:
        None.  Results are reported through ``Logger`` and stdout.
    """
    args = _build_parser().parse_args()
    print(args)

    # Fall back to CPU when CUDA is unavailable, regardless of --device.
    device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
    logger = Logger(args.runs, args)

    num_classes, data, dataset = load_data(args.data_name, device=device)

    # Both are passed to the model as None; nothing in this script computes them.
    spd_mat = None
    lap_pe = None
    reference_points = generate_reference(data, args.sort_type, data_name=args.data_name)

    for run in range(args.runs):
        init_seed(run)
        # The split index wraps at 10, so run 10 reuses index 0, and so on
        # (presumably ten predefined splits exist; confirm in load_split).
        run_idx = run % 10
        split_idx, train_idx = load_split(args.data_name, data, run_idx, device, dataset)
        model = DeformableGT(data.num_features, args.hidden_channels,
                             num_classes, data, args.dropout, args.pe_type,
                             args.num_blocks, args.relpe_type, args=args,
                             num_heads=args.num_heads, num_axis=len(reference_points),
                             spd_mat=spd_mat, lap_pe=lap_pe).to(device)

        optimizer = _build_optimizer(model, args)

        best_valid_acc, best_valid_loss, patience = 0, float('inf'), 0
        for _ in tqdm(range(1, 1 + args.epochs)):
            # Checked at the top of the epoch, so the stop takes effect on the
            # epoch after patience ran out.
            if patience >= args.pat_epochs:
                print("Early Stop!!")
                break

            train(model, data, train_idx, optimizer, reference_points)
            train_acc, valid_acc, test_acc, loss_val = test(model, data, split_idx, reference_points)
            logger.add_result(run, (train_acc, valid_acc, test_acc))

            if valid_acc > best_valid_acc or loss_val < best_valid_loss:
                # Track each best value independently so that an improvement
                # in one metric never lowers the bar for the other.
                best_valid_acc = max(best_valid_acc, valid_acc)
                best_valid_loss = min(best_valid_loss, loss_val)
                patience = 0
            else:
                patience += 1

            # A decay factor of exactly 1 means "no bandwidth decay".
            if args.bandwidth_decaying != 1:
                model.decay_bandwidth(args.bandwidth_decaying, args.min_bandwidth)
        logger.print_statistics(run)

    logger.print_statistics()

# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()