import os
import wandb
import argparse
import pickle
from datetime import datetime

import torch
from ruamel.yaml import YAML

from utils.util import *
from trainers.blt_feature_trainer import *
from transducers.blt_graph_transducers import *


def set_args():
    """Parse the command-line options for a BLT feature training run.

    Only ``--prop_type`` is required; every other option has a default
    (``None`` where no default is given).

    Returns:
        argparse.Namespace: The parsed options. ``seed`` is always an ``int``,
        whether it came from the default or from the command line.
    """
    # Accepted values for --prop_type, grouped the way they were listed
    # originally (regression sets, "_x" variants, core sets, ChEMBL targets,
    # orbital energies).
    prop_type_choices = [
        'bace', 'esol', 'freesolv', 'lipo',
        'bace_x', 'esol_x', 'freesolv_x', 'lipo_x',
        'bbbp_x', 'clintox_x', 'sider_x',
        'core_ec50', 'core_ic50',
        'CHEMBL1862_Ki', 'CHEMBL1871_Ki', 'CHEMBL2034_Ki', 'CHEMBL2047_EC50',
        'CHEMBL204_Ki', 'CHEMBL2147_Ki', 'CHEMBL214_Ki', 'CHEMBL218_EC50',
        'CHEMBL219_Ki', 'CHEMBL228_Ki', 'CHEMBL231_Ki', 'CHEMBL233_Ki',
        'CHEMBL234_Ki', 'CHEMBL235_EC50', 'CHEMBL236_Ki', 'CHEMBL237_EC50',
        'CHEMBL237_Ki', 'CHEMBL238_Ki', 'CHEMBL239_EC50', 'CHEMBL244_Ki',
        'CHEMBL262_Ki', 'CHEMBL264_Ki', 'CHEMBL2835_Ki', 'CHEMBL287_Ki',
        'CHEMBL2971_Ki', 'CHEMBL3979_EC50', 'CHEMBL4005_Ki', 'CHEMBL4203_Ki',
        'CHEMBL4616_EC50', 'CHEMBL4792_Ki',
        'homo', 'lumo',
    ]

    parser = argparse.ArgumentParser()
    parser.add_argument('--config_name', '-c', default='topk_euc')
    # type=int so a seed given on the command line has the same type as the
    # default instead of arriving as a string.
    parser.add_argument('--seed', '-s', type=int, default=42)
    parser.add_argument('--model_name', default='blt_feature')
    parser.add_argument('--task', '-t')
    parser.add_argument('--dataset_name', '-dn', default='molnet')
    parser.add_argument('--embedding_source', default=None, choices=['gnn', 'smi_ted'])
    parser.add_argument('--dataset_split_type', '-ds', default='scaffold',
                        choices=['scaffold', 'ac', 'hi', 'lo'])
    parser.add_argument('--prop_type', '-p', required=True, choices=prop_type_choices)
    parser.add_argument('--data_dir', '-d', default='./data')
    # wandb
    parser.add_argument('--wandb_log', '-w', action='store_true')
    parser.add_argument('--proj_name', default='transductive_learning')
    # path
    parser.add_argument('--model_path', '-mp', type=str)
    parser.add_argument('--save_path', '-sp', type=str, default='results/blt_feature')
    return parser.parse_args()
    
    
def _load_config(config_name):
    """Load the YAML config called ``config_name`` and return the parsed mapping.

    ``configs/blt_feature/generated/`` is tried first; when no such file
    exists there, ``configs/blt_feature/`` is used instead. Only a missing
    file triggers the fallback, so a malformed generated config raises
    rather than being silently replaced.
    """
    yaml = YAML()
    generated_path = f'configs/blt_feature/generated/{config_name}.yml'
    fallback_path = f'configs/blt_feature/{config_name}.yml'
    try:
        with open(generated_path) as config_file:
            return yaml.load(config_file)
    except FileNotFoundError:
        with open(fallback_path) as config_file:
            return yaml.load(config_file)


def _drop_nan_samples(dataset):
    """Return a new dataset dict without samples whose 'reps' entry has a NaN.

    For each of the 'train', 'eval' and 'ood' splits, every key of the split
    is filtered with the same indices so the per-sample lists stay aligned.
    Prints the sample count before and after filtering for each split.
    """
    cleaned = {}
    for split in ('train', 'eval', 'ood'):
        split_data = dataset[split]
        reps = split_data['reps']
        valid_indices = [i for i, rep in enumerate(reps) if not np.isnan(rep).any()]
        print(len(reps), len(valid_indices))
        cleaned[split] = {
            key: [values[i] for i in valid_indices]
            for key, values in split_data.items()
        }
    return cleaned


def main():
    """Run one training + evaluation experiment end to end.

    Steps: parse the CLI options, create a timestamped output directory,
    load the YAML config, load and NaN-filter the pickled dataset, build the
    model, memory bank and transducer, train, then pickle the predictions for
    the 'eval' split (skipped for the 'drugood' dataset) and the 'ood' split
    into ``args.save_path``.
    """
    date_str = datetime.now().strftime('%Y%m%d_%H%M%S')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    args = set_args()
    args.date_str = date_str
    args.device = device
    args.save_path = os.path.join(args.save_path, args.dataset_split_type, args.prop_type,
                                  str(args.seed), date_str)
    args.checkpoint_path = os.path.join(args.save_path, 'ckpts')
    args.train_deltas_path = os.path.join(args.save_path, 'train_deltas.pkl')
    # 'ckpts' lives inside save_path, so this creates both directories.
    os.makedirs(args.checkpoint_path, exist_ok=True)
    args.seed = int(args.seed)

    config_dict = _load_config(args.config_name)
    config = ConfigNamespace(config_dict)
    set_seed(args)
    if args.wandb_log:
        set_wandb(args, config_dict)
    if args.task == "classification":
        # Classification runs override the configured epoch count.
        config.exp.num_epochs = 10

    print('Loading data...')
    # File name is '<prop_type>.pkl' or '<prop_type>_<embedding_source>.pkl'.
    suffix = '' if args.embedding_source is None else f'_{args.embedding_source}'
    filename = f'{args.prop_type}{suffix}.pkl'
    data_path = os.path.join(args.data_dir, args.dataset_name, args.dataset_split_type,
                             args.prop_type, filename)

    # NOTE: pickle.load can execute arbitrary code; only point --data_dir at
    # trusted files.
    with open(data_path, 'rb') as data_file:
        dataset = pickle.load(data_file)
    dataset = _drop_nan_samples(dataset)

    input_dim = len(dataset['train']['reps'][0])
    model = define_model(config, input_dim, num_tasks=1)
    log_params(model)
    model.to(args.device)

    features, smiles = dataset['train']['reps'], dataset['train']['smiles']
    memory_bank, smiles_bank = build_memory_bank_feature(
        args, model, features, smiles, batch_size=config.exp.batch_size)

    transducer = define_transducer(args, config, memory_bank, smiles_bank)
    print("Transducer initialized:")
    print(f"  Sampling Strategy: {transducer.sampling_strategy}")
    print(f"  Anchor Metric: {transducer.anchor_metric}")
    print(f"  Num Candidates (k): {transducer.num_candidates}")
    if transducer.sampling_strategy == 'diverse_topm':
        # NOTE(review): the adaptive_mask branch reads its factor from
        # `config.transducer...`; confirm `transducer.diverse_topm` exists.
        print(f"  Diversity Factor: {transducer.diverse_topm.diversity_factor}")
    elif transducer.sampling_strategy == 'adaptive_mask':
        print(f"  Adaptive k_min: {transducer.adaptive_k_min}")
        print(f"  Adaptive k_max: {transducer.adaptive_k_max}")
        print(f"  Adaptive Density Factor: {config.transducer.adaptive_mask.density_threshold_factor}")

    print('Training model...')
    model = train_model_gnn_with_memory(args, config, dataset, model, transducer)

    print('Evaluating...')
    # Evaluation - eval (not run for the 'drugood' dataset)
    if args.dataset_name != 'drugood':
        eval_preds = test_model(args, config, dataset, model, transducer, prefix='eval')
        eval_preds_path = os.path.join(args.save_path, config.model.model_type + '_eval_preds.pkl')
        with open(eval_preds_path, 'wb') as preds_file:
            pickle.dump(eval_preds, preds_file)

    # Evaluation - ood
    eval_preds = test_model(args, config, dataset, model, transducer, prefix='ood')
    ood_preds_path = os.path.join(args.save_path, config.model.model_type + '_ood_preds.pkl')
    with open(ood_preds_path, 'wb') as preds_file:
        pickle.dump(eval_preds, preds_file)

    wandb.finish()
# Run the experiment only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
