import os
import argparse
import pickle
from datetime import datetime

import torch
import torch.optim as optim
from ruamel.yaml import YAML

from util import *
from trainer import *


def set_args(argv=None):
    """Parse the command-line options for a training/evaluation run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the process arguments ``sys.argv[1:]`` are used,
            which is the behavior existing callers rely on.

    Returns:
        argparse.Namespace: The parsed options. ``--prop_type`` is required;
        ``seed`` is always an ``int`` whether it came from the default or
        from the command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_name', '-c', default='pretrained_gnn')
    # type=int so a CLI-supplied seed has the same type as the default and
    # invalid values are rejected at parse time instead of later on.
    parser.add_argument('--seed', '-s', type=int, default=42)
    parser.add_argument('--model_name', default='pretrained_gnn')
    parser.add_argument('--dataset_name', '-dn', default='molnet')
    parser.add_argument('--task', '-t', type=str, default='regression', choices=['regression', 'classification'])
    parser.add_argument('--dataset_split_type', '-ds', default='scaffold')
    parser.add_argument('--prop_type', '-p', required=True)
    parser.add_argument('--data_dir', '-d', default='../../data')
    parser.add_argument('--wandb_log', '-w', default=False, action='store_true')
    parser.add_argument('--proj_name', default='transductive_learning')
    parser.add_argument('--model_path', '-mp', type=str)
    parser.add_argument('--save_path', '-sp', type=str, default='saved_results')
    return parser.parse_args(argv)
    
    
def _evaluate_split(args, config, model, device, dataset, split_name):
    """Evaluate ``model`` on one dataset split and persist the outcome.

    Runs ``test_model``, computes metrics with ``calculate_metrics``, stores
    them via ``save_results``, pickles the raw predictions under
    ``args.save_path`` and, when ``args.wandb_log`` is set, pushes the metrics
    to the wandb run summary.

    Args:
        args: Parsed run options (reads ``task``, ``save_path``, ``wandb_log``).
        config: Experiment config (reads ``config.model.model_type``).
        model: The trained model to evaluate.
        device: Torch device forwarded to ``test_model``.
        dataset: The split to evaluate on.
        split_name: Tag for the split (``'eval'`` or ``'ood'``); used for the
            metrics, the results file and the predictions file name.
    """
    true_values, preds = test_model(config, model, device, dataset)
    results = calculate_metrics(true_values, preds, split_name, args.task)
    save_results(args, results, split_name)

    preds_file = os.path.join(args.save_path, config.model.model_type + f'_{split_name}_preds.pkl')
    # Context manager guarantees the predictions are flushed and the handle closed.
    with open(preds_file, 'wb') as f:
        pickle.dump(preds, f)

    # NOTE(review): `wandb` is not imported explicitly in this file; it is
    # presumably provided by one of the star imports - confirm.
    if args.wandb_log:
        wandb.summary.update(results)


def main():
    """Train a model on one property dataset and evaluate it.

    Steps, in order:
      1. Parse CLI options and derive the timestamped output directories.
      2. Load ``configs/<config_name>.yml`` and seed the run.
      3. Load the pickled dataset dict (``'train'``, ``'eval'`` and, unless the
         split type is ``'hi'`` or ``'lo'``, ``'ood'``).
      4. Build the model, optionally loading weights from ``config.model.path``.
      5. Train with Adam, scaling the learning rate of the pooling/prediction
         heads by ``config.exp.lr_scale``.
      6. Evaluate on the eval split (and the ood split when present), saving
         metrics and predictions under ``args.save_path``.
    """
    date_str = datetime.now().strftime('%Y%m%d_%H%M%S')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args = set_args()
    args.date_str = date_str
    args.device = device
    args.save_path = os.path.join(args.save_path, args.dataset_name, args.dataset_split_type, args.prop_type, str(args.seed), date_str)
    args.checkpoint_path = os.path.join(args.save_path, 'ckpts')
    # Not used in this function; presumably consumed by train_model - confirm.
    args.train_deltas_path = os.path.join(args.save_path, 'train_deltas.pkl')
    os.makedirs(args.save_path, exist_ok=True)
    os.makedirs(args.checkpoint_path, exist_ok=True)

    yaml = YAML()
    with open(f'configs/{args.config_name}.yml') as config_file:
        config_dict = yaml.load(config_file)
    config = ConfigNamespace(config_dict)

    args.seed = int(args.seed)
    set_seed(args)
    if args.wandb_log:
        set_wandb(args, config_dict)
        wandb.log({'dataset': 'original_' + args.prop_type})

    print('Loading data...')
    data_path = os.path.join(args.data_dir, args.dataset_name, args.dataset_split_type, args.prop_type, f"{args.prop_type}.pkl")
    # SECURITY: unpickling executes arbitrary code; only load dataset files
    # from a trusted location.
    with open(data_path, 'rb') as data_file:
        all_dataset = pickle.load(data_file)

    # 'hi' and 'lo' splits ship without an out-of-distribution subset.
    has_ood_split = args.dataset_split_type not in ('hi', 'lo')
    train_dataset = all_dataset['train']
    eval_dataset = all_dataset['eval']
    ood_dataset = all_dataset['ood'] if has_ood_split else None

    model = define_model(config, num_tasks=1)
    if config.model.path:
        model.from_pretrained(config.model.path)
    model.to(args.device)

    print('Training model...')
    lr = config.exp.lr
    lr_scale = config.exp.lr_scale
    # The GNN group uses the base lr; the head groups use lr * lr_scale.
    model_param_group = [{"params": model.gnn.parameters()}]
    if config.model.graph_pooling == "attention":
        model_param_group.append({"params": model.pool.parameters(), "lr": lr * lr_scale})
    model_param_group.append({"params": model.graph_pred_linear.parameters(), "lr": lr * lr_scale})
    optimizer = optim.Adam(model_param_group, lr=lr, weight_decay=config.exp.decay)
    model = train_model(args, config, train_dataset, model, optimizer)

    print('Evaluating...')
    _evaluate_split(args, config, model, device, eval_dataset, 'eval')
    if has_ood_split:
        _evaluate_split(args, config, model, device, ood_dataset, 'ood')
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
