import os
import time
import argparse
from datetime import datetime

import pandas as pd
import torch
import chemprop
from ruamel.yaml import YAML

from util import *


def set_args(argv=None):
	"""Parse the command-line options for a chemprop train/evaluate run.

	Args:
		argv: Optional list of argument strings to parse instead of the real
			command line. ``None`` (the default) means ``sys.argv[1:]``,
			which is the original behavior.

	Returns:
		argparse.Namespace holding the parsed options. ``seed`` is always an
		int, whether it came from the default or from the command line.
	"""
	molnet_props = ['bace', 'esol', 'freesolv', 'lipo',
					'bace_x', 'esol_x', 'freesolv_x', 'lipo_x',
					'bbbp_x', 'clintox_x', 'sider_x']
	core_props = ['core_ec50', 'core_ic50']
	chembl_props = ['CHEMBL1862_Ki', 'CHEMBL1871_Ki', 'CHEMBL2034_Ki', 'CHEMBL2047_EC50',
					'CHEMBL204_Ki', 'CHEMBL2147_Ki', 'CHEMBL214_Ki', 'CHEMBL218_EC50',
					'CHEMBL219_Ki', 'CHEMBL228_Ki', 'CHEMBL231_Ki', 'CHEMBL233_Ki',
					'CHEMBL234_Ki', 'CHEMBL235_EC50', 'CHEMBL236_Ki', 'CHEMBL237_EC50',
					'CHEMBL237_Ki', 'CHEMBL238_Ki', 'CHEMBL239_EC50', 'CHEMBL244_Ki',
					'CHEMBL262_Ki', 'CHEMBL264_Ki', 'CHEMBL2835_Ki', 'CHEMBL287_Ki',
					'CHEMBL2971_Ki', 'CHEMBL3979_EC50', 'CHEMBL4005_Ki', 'CHEMBL4203_Ki',
					'CHEMBL4616_EC50', 'CHEMBL4792_Ki']
	orbital_props = ['homo', 'lumo']

	parser = argparse.ArgumentParser()
	parser.add_argument('--config_name', '-c', default='chemprop')
	# type=int: without it the default is an int but a CLI-supplied seed is a str.
	parser.add_argument('--seed', '-s', type=int, default=42)
	parser.add_argument('--model_name', default='chemprop')
	parser.add_argument('--task', '-t', type=str, required=True, choices=['regression', 'classification'])
	parser.add_argument('--dataset_name', '-dn', default='molnet')
	parser.add_argument('--dataset_split_type', '-ds', default='scaffold', choices=['scaffold', 'ac', 'hi', 'lo'])
	parser.add_argument('--prop_type', '-p', required=True,
						choices=molnet_props + core_props + chembl_props + orbital_props)
	parser.add_argument('--data_dir', '-d', default='../../data')
	parser.add_argument('--wandb_log', '-w', default=False, action='store_true')
	parser.add_argument('--proj_name', default='transductive_learning')
	parser.add_argument('--model_path', '-mp', type=str)
	parser.add_argument('--save_path', '-sp', type=str, default='saved_results')

	return parser.parse_args(argv)

def evaluate(args, split):
	"""Run chemprop inference on one data split and return (targets, predictions).

	Reads ``{split}_featurized.csv`` from
	``<data_dir>/<dataset_name>/<dataset_split_type>/<prop_type>/``. Writes its
	``smiles`` column to ``{split}_smiles.csv`` in that same directory, because
	chemprop's predictor reads its input from a CSV path. Predictions are
	written to ``<save_path>/{split}_preds.csv`` using the checkpoints found
	in ``args.checkpoint_path``.

	Args:
		args: Parsed run options; must also carry ``checkpoint_path``
			(``main`` sets it).
		split: Split name used as the file prefix, e.g. ``'eval'`` or ``'ood'``.

	Returns:
		Tuple ``(gts, preds)``: the ``target`` column as a list, and the
		first predicted value for each molecule.
	"""
	split_dir = os.path.join(args.data_dir, args.dataset_name, args.dataset_split_type, args.prop_type)
	smiles_path = os.path.join(split_dir, f'{split}_smiles.csv')
	gt_path = os.path.join(split_dir, f'{split}_featurized.csv')
	preds_path = os.path.join(args.save_path, f'{split}_preds.csv')

	data = pd.read_csv(gt_path)
	gts = data['target'].values.tolist()
	# chemprop's predictor takes a file path, so dump the SMILES column to disk.
	data['smiles'].to_csv(smiles_path, index=False)

	pred_args = chemprop.args.PredictArgs().parse_args([
		'--test_path', smiles_path,
		'--preds_path', preds_path,
		'--checkpoint_dir', args.checkpoint_path
	])

	# Only touch CUDA memory stats when CUDA exists. main() falls back to CPU,
	# and this call raises on machines without a usable GPU.
	if torch.cuda.is_available():
		torch.cuda.reset_peak_memory_stats('cuda')
	preds = chemprop.train.make_predictions(args=pred_args)
	# Keep only the first output column per molecule
	# (presumably single-task models — TODO confirm for multi-task data).
	return gts, [row[0] for row in preds]
	
def _evaluate_split(args, split):
	"""Evaluate one split, compute and save its metrics, and log them to wandb if enabled."""
	gts, preds = evaluate(args, split)
	# np / calculate_metrics / save_results / wandb presumably come from
	# `from util import *` — verify against util.
	results = calculate_metrics(np.array(gts), np.array(preds), split, args.task)
	save_results(args, results, split)
	if args.wandb_log:
		wandb.summary.update(results)
	return results


def main():
	"""Train a chemprop model on one dataset/split/property, then evaluate it.

	Outputs go to
	``<save_path>/<dataset_name>/<dataset_split_type>/<prop_type>/<seed>/<timestamp>/``,
	and checkpoints go to its ``ckpts`` subdirectory. Hyperparameters
	(currently only ``exp.num_epochs``) are read from
	``configs/<config_name>.yml``, relative to the working directory. After
	training, the ``eval`` split is always evaluated; the ``ood`` split is
	evaluated as well unless the split type is ``hi`` or ``lo``.
	"""
	date_str = datetime.now().strftime('%Y%m%d_%H%M%S')
	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
	args = set_args()
	args.date_str = date_str
	args.device = device
	args.save_path = os.path.join(args.save_path, args.dataset_name, args.dataset_split_type, args.prop_type, str(args.seed), date_str)
	args.checkpoint_path = os.path.join(args.save_path, 'ckpts')
	os.makedirs(args.save_path, exist_ok=True)
	os.makedirs(args.checkpoint_path, exist_ok=True)
	# chemprop's argument parser is fed a list of strings.
	args.seed = str(args.seed)

	yaml = YAML()
	with open(f'configs/{args.config_name}.yml') as config_file:
		config_dict = yaml.load(config_file)
	config = ConfigNamespace(config_dict)
	if args.wandb_log:
		set_wandb(args, config_dict)

	print('Loading data...')
	split_dir = os.path.join(args.data_dir, args.dataset_name, args.dataset_split_type, args.prop_type)
	train_path = os.path.join(split_dir, 'train_featurized.csv')
	test_path = os.path.join(split_dir, 'eval_featurized.csv')
	train_args = chemprop.args.TrainArgs().parse_args([
		'--data_path', train_path,
		'--dataset_type', args.task,
		'--save_dir', args.checkpoint_path,
		'--epochs', str(config.exp.num_epochs),
		'--separate_test_path', test_path,
		'--seed', args.seed
	])
	# The aggregate (mean, std) scores returned here are not used.
	chemprop.train.cross_validate(args=train_args, train_func=chemprop.train.run_training)

	_evaluate_split(args, 'eval')
	# 'hi' / 'lo' runs are evaluated on the eval split only.
	if args.dataset_split_type not in ('hi', 'lo'):
		_evaluate_split(args, 'ood')
		
# Run the full train/evaluate pipeline only when executed as a script.
if __name__ == '__main__':
	main()
