import os
import argparse
import pickle
from datetime import datetime
import wandb

import torch
from ruamel.yaml import YAML

from utils.util import *
from trainers.blt_trainer import define_model, train_model, test_model
from trainers.util import load_model
from transducers.blt_transducers import define_transducer


def set_args():
    """Parse the command-line options of this script.

    Options are read from ``sys.argv``. ``--prop_type`` is the only required
    option; every other option has a default (``--model_path`` and
    ``--embedding_source`` default to ``None``).

    Returns:
        argparse.Namespace: The parsed options. ``seed`` is always an ``int``,
        whether it comes from the default or from the command line.
    """
    # Property / assay identifiers accepted by --prop_type.
    prop_type_choices = [
        'bace', 'esol', 'freesolv', 'lipo',
        'bace_x', 'esol_x', 'freesolv_x', 'lipo_x',
        'bbbp_x', 'clintox_x', 'sider_x',
        'core_ec50', 'core_ic50',
        'CHEMBL1862_Ki', 'CHEMBL1871_Ki', 'CHEMBL2034_Ki', 'CHEMBL2047_EC50',
        'CHEMBL204_Ki', 'CHEMBL2147_Ki', 'CHEMBL214_Ki', 'CHEMBL218_EC50',
        'CHEMBL219_Ki', 'CHEMBL228_Ki', 'CHEMBL231_Ki', 'CHEMBL233_Ki',
        'CHEMBL234_Ki', 'CHEMBL235_EC50', 'CHEMBL236_Ki', 'CHEMBL237_EC50',
        'CHEMBL237_Ki', 'CHEMBL238_Ki', 'CHEMBL239_EC50', 'CHEMBL244_Ki',
        'CHEMBL262_Ki', 'CHEMBL264_Ki', 'CHEMBL2835_Ki', 'CHEMBL287_Ki',
        'CHEMBL2971_Ki', 'CHEMBL3979_EC50', 'CHEMBL4005_Ki', 'CHEMBL4203_Ki',
        'CHEMBL4616_EC50', 'CHEMBL4792_Ki',
        'homo', 'lumo',
    ]

    parser = argparse.ArgumentParser()
    parser.add_argument('--config_name', '-c', default='blt')
    # type=int keeps the seed an integer for both the default and a
    # user-supplied value, and turns a malformed seed into a usage error.
    parser.add_argument('--seed', '-s', type=int, default=42)
    parser.add_argument('--model_name', default='blt')
    parser.add_argument('--task', '-t', default='regression',
                        choices=['regression', 'classification'])
    parser.add_argument('--dataset_name', '-dn', default='molnet')
    parser.add_argument('--dataset_split_type', '-ds', default='scaffold',
                        choices=['scaffold', 'ac', 'hi', 'lo'])
    parser.add_argument('--embedding_source', '-es', default=None,
                        choices=['gnn', 'smi_ted'])
    parser.add_argument('--prop_type', '-p', required=True,
                        choices=prop_type_choices)
    parser.add_argument('--data_dir', '-d', default='./data')
    # wandb
    parser.add_argument('--wandb_log', '-w', default=False, action='store_true')
    parser.add_argument('--proj_name', default='transductive_learning')
    # path
    parser.add_argument('--model_path', '-mp', type=str)
    parser.add_argument('--save_path', '-sp', type=str, default='results/blt')
    return parser.parse_args()

    
def _load_config(config_name):
    """Load ``<config_name>.yml``, preferring the copy under ``configs/blt/generated``."""
    generated_path = f'configs/blt/generated/{config_name}.yml'
    default_path = f'configs/blt/{config_name}.yml'
    try:
        config_file = open(generated_path)
    except FileNotFoundError:
        # Fall back only when the generated config does not exist. Any other
        # failure (e.g. a YAML syntax error in the generated file) must
        # surface instead of silently switching to a different config.
        config_file = open(default_path)
    with config_file:
        return YAML().load(config_file)


def _dataset_filename(prop_type, embedding_source):
    """Return the dataset pickle name: ``<prop>.pkl`` or ``<prop>_<source>.pkl``."""
    if embedding_source is None:
        return f'{prop_type}.pkl'
    return f'{prop_type}_{embedding_source}.pkl'


def _drop_nan_reps(dataset):
    """Return a copy of the train/eval/ood splits without samples whose rep contains NaN."""
    filtered = {}
    for split in ('train', 'eval', 'ood'):
        reps = dataset[split]['reps']
        keep = [i for i, rep in enumerate(reps) if not np.isnan(rep).any()]
        # Sample count before / after filtering.
        print(len(reps), len(keep))
        # Every field of the split is filtered with the same indices so the
        # fields stay aligned with each other.
        filtered[split] = {
            key: [values[i] for i in keep]
            for key, values in dataset[split].items()
        }
    return filtered


def _evaluate_split(args, config, dataset, model, transducer, split):
    """Run ``test_model`` on one split and pickle its predictions under ``args.save_path``."""
    test_dataset = {
        'test_X': dataset[split]['reps'],
        'test_Y': dataset[split]['targets'],
        'test_smiles': dataset[split]['smiles'],
    }
    preds = test_model(args, config, test_dataset, model, transducer, prefix=split)
    preds_path = os.path.join(args.save_path, config.model.model_type + f'_{split}_preds.pkl')
    with open(preds_path, 'wb') as preds_file:
        pickle.dump(preds, preds_file)


def main():
    """Run one experiment: parse options, load config and data, train or load a model, evaluate.

    Steps, in order:
      1. Parse the CLI options and derive the run directory
         ``<save_path>/<split_type>/<prop_type>/<seed>/<timestamp>`` plus its
         ``ckpts`` sub-directory; both are created.
      2. Load the YAML config, seed the run and optionally set up wandb.
      3. Load the pickled dataset and drop samples whose representation
         contains NaN.
      4. Train a model, or load one when ``--model_path`` is given.
      5. Evaluate on the ``eval`` split (skipped for the ``drugood`` dataset)
         and on the ``ood`` split, pickling the predictions of each.
    """
    date_str = datetime.now().strftime('%Y%m%d_%H%M%S')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args = set_args()
    args.date_str = date_str
    args.device = device
    args.save_path = os.path.join(args.save_path, args.dataset_split_type, args.prop_type, str(args.seed), date_str)
    args.checkpoint_path = os.path.join(args.save_path, 'ckpts')
    args.train_deltas_path = os.path.join(args.save_path, 'train_deltas.pkl')
    os.makedirs(args.save_path, exist_ok=True)
    os.makedirs(args.checkpoint_path, exist_ok=True)

    # Normalise the seed to an int in case it arrived as a string.
    args.seed = int(args.seed)

    config_dict = _load_config(args.config_name)
    config = ConfigNamespace(config_dict)

    set_seed(args)
    if args.wandb_log:
        set_wandb(args, config_dict)
    if args.task == "classification":
        # Classification runs override the configured epoch count.
        config.exp.num_epochs = 10

    print('Loading data...')
    filename = _dataset_filename(args.prop_type, args.embedding_source)
    data_path = os.path.join(args.data_dir, args.dataset_name, args.dataset_split_type, args.prop_type, filename)
    # NOTE: pickle.load can execute arbitrary code; only point --data_dir at
    # trusted dataset files.
    with open(data_path, 'rb') as data_file:
        dataset = pickle.load(data_file)
    dataset = _drop_nan_reps(dataset)

    config.model.n_approx_deltas = config.model.mul_approx_train_deltas * len(dataset['train']['reps'])

    input_dim = len(dataset['train']['reps'][0])
    model = define_model(config, input_dim, output_dim=1)

    if args.model_path is None:
        print('Training model...')
        # Print the mean/std of every parameter before training starts.
        for name, param in model.named_parameters():
            print(name, param.mean().item(), param.std().item())
        model, deltas = train_model(args, config, dataset, model)
    else:
        print('Loading model...')
        model, deltas = load_model(args, model)

    transducer = define_transducer(args, config, dataset, deltas)

    print('Evaluating...')
    # Evaluation - eval (not run for the drugood dataset)
    if args.dataset_name != 'drugood':
        _evaluate_split(args, config, dataset, model, transducer, 'eval')

    # Evaluation - ood
    _evaluate_split(args, config, dataset, model, transducer, 'ood')

    # NOTE(review): called even when --wandb_log is off; confirm that
    # wandb.finish() is a no-op when no run is active.
    wandb.finish()
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
