import os
import argparse
from datetime import datetime

import torch
from ruamel.yaml import YAML

from util import *
from trainer import *
from trainer_utils import *


def set_args(argv=None):
    """Parse the command-line options for a training / evaluation run.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) reads ``sys.argv[1:]``,
            exactly as before; passing a list is useful for tests/notebooks.

    Returns:
        argparse.Namespace holding the parsed options. ``seed`` is always an
        ``int`` (argparse converts and validates it).

    Raises:
        SystemExit: if required options are missing or a value is not one of
            the allowed ``choices`` / not convertible (argparse behaviour).
    """
    parser = argparse.ArgumentParser()
    # main() loads the model config from configs/<config_name>.yml
    parser.add_argument('--config_name', '-c', default='smi_ted')
    # type=int so a CLI-supplied seed is validated and has the same type as
    # the default, instead of arriving as a string.
    parser.add_argument('--seed', '-s', type=int, default=42)
    parser.add_argument('--model_name', default='smi_ted_light')
    parser.add_argument('--task', '-t', default='regression', choices=['regression', 'classification'])
    parser.add_argument('--dataset_name', '-dn', default='molnet')
    parser.add_argument('--dataset_split_type', '-ds', default='scaffold', choices=['scaffold', 'ac', 'hi', 'lo'])
    parser.add_argument('--prop_type', '-p', required=True, choices=['bace', 'esol', 'freesolv', 'lipo',
                                                                     'bace_x', 'esol_x', 'freesolv_x', 'lipo_x',
                                                                     'bbbp_x', 'clintox_x', 'sider_x',
                                                                     'core_ec50', 'core_ic50'])
    parser.add_argument('--data_dir', '-d', default='../../data')
    # wandb
    parser.add_argument('--wandb_log', '-w', default=False, action='store_true')
    parser.add_argument('--proj_name', default='transductive_learning')
    # paths
    parser.add_argument('--model_path', '-mp', type=str)
    parser.add_argument('--save_path', '-sp', type=str, default='results')
    return parser.parse_args(argv)
    
    
def _report_split(args, config, preds, tgts, split):
    """Compute, save, pickle and optionally wandb-log results for one split.

    ``split`` is the tag forwarded to ``calculate_metrics``/``save_results``
    and embedded in the prediction filename
    (``<save_path>/<model_type>_<split>_preds.pkl``). Returns the metrics.
    """
    import pickle

    results = calculate_metrics(tgts, preds, split, args.task)
    save_results(args, results, split)
    preds_path = os.path.join(args.save_path, config.model.model_type + f'_{split}_preds.pkl')
    # Context manager so the file handle is flushed and closed deterministically.
    with open(preds_path, 'wb') as fh:
        pickle.dump(preds, fh)
    if args.wandb_log:
        wandb.summary.update(results)
    return results


def main():
    """Train a model, then evaluate it on the eval and OOD splits.

    Outputs are written under
    ``<save_path>/<dataset_split_type>/<prop_type>/<seed>/<YYYYmmdd_HHMMSS>/``,
    with checkpoints in its ``ckpts`` subdirectory.
    """
    date_str = datetime.now().strftime('%Y%m%d_%H%M%S')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args = set_args()
    args.date_str = date_str
    args.device = device
    # Normalise the seed before it is used anywhere: it is embedded in the
    # output path below and consumed by set_seed() later, so both must agree.
    args.seed = int(args.seed)
    args.save_path = os.path.join(args.save_path, args.dataset_split_type, args.prop_type, str(args.seed), date_str)
    args.checkpoint_path = os.path.join(args.save_path, 'ckpts')
    args.train_deltas_path = os.path.join(args.save_path, 'train_deltas.pkl')
    # checkpoint_path is nested in save_path, so this creates both directories.
    os.makedirs(args.checkpoint_path, exist_ok=True)

    yaml = YAML()
    with open(f'configs/{args.config_name}.yml', encoding='utf-8') as fh:
        config_dict = yaml.load(fh)
    config = ConfigNamespace(config_dict)
    set_seed(args)
    if args.wandb_log:
        set_wandb(args, config_dict)

    model = define_model(args, config)
    optimizer, loss_fn = setup_optimizer_and_loss(model, config)

    # NOTE(review): a regressor trainer is used regardless of args.task, even
    # though --task also accepts 'classification' — confirm this is intended.
    trainer = TrainerRegressor(args, config, model, optimizer, loss_fn)
    trainer.fit()

    preds, tgts = trainer.evaluate()
    # Diagnostic output: count of targets equal to 0 (assumes tgts supports
    # element-wise == and .sum(), e.g. a numpy array / tensor — TODO confirm).
    print((tgts == 0).sum())
    _report_split(args, config, preds, tgts, 'eval')

    preds, tgts = trainer.evaluate_ood()
    _report_split(args, config, preds, tgts, 'ood')
    
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
