import os
import argparse
import pickle
from datetime import datetime
import wandb
import torch
from ruamel.yaml import YAML

from utils.util import *
from models.blt_chemprop import define_model
from trainers.blt_chemprop_trainer import *
from transducers.blt_graph_transducers import *


def set_args():
    """Parse the command-line options for a BLT-Chemprop run.

    Returns:
        argparse.Namespace: The parsed options. ``prop_type`` is the only
        required option. ``seed`` is always an ``int``. ``task`` and
        ``model_path`` are ``None`` when they are not given on the command
        line.
    """
    # Accepted values for --prop_type, grouped the way they are usually run.
    prop_choices = [
        'bace', 'esol', 'freesolv', 'lipo',
        'bace_x', 'esol_x', 'freesolv_x', 'lipo_x',
        'bbbp_x', 'clintox_x', 'sider_x',
        'core_ec50', 'core_ic50',
        'CHEMBL1862_Ki', 'CHEMBL1871_Ki', 'CHEMBL2034_Ki', 'CHEMBL2047_EC50',
        'CHEMBL204_Ki', 'CHEMBL2147_Ki', 'CHEMBL214_Ki', 'CHEMBL218_EC50',
        'CHEMBL219_Ki', 'CHEMBL228_Ki', 'CHEMBL231_Ki', 'CHEMBL233_Ki',
        'CHEMBL234_Ki', 'CHEMBL235_EC50', 'CHEMBL236_Ki', 'CHEMBL237_EC50',
        'CHEMBL237_Ki', 'CHEMBL238_Ki', 'CHEMBL239_EC50', 'CHEMBL244_Ki',
        'CHEMBL262_Ki', 'CHEMBL264_Ki', 'CHEMBL2835_Ki', 'CHEMBL287_Ki',
        'CHEMBL2971_Ki', 'CHEMBL3979_EC50', 'CHEMBL4005_Ki', 'CHEMBL4203_Ki',
        'CHEMBL4616_EC50', 'CHEMBL4792_Ki',
        'homo', 'lumo',
    ]

    parser = argparse.ArgumentParser()
    parser.add_argument('--config_name', '-c', default='topk_euc')
    # Parse the seed as an int here so that a value given on the command line
    # has the same type as the default, and a malformed seed is rejected by
    # argparse before any output directory is created.
    parser.add_argument('--seed', '-s', type=int, default=42)
    parser.add_argument('--model_name', default='blt_chemprop')
    parser.add_argument('--task', '-t', type=str, choices=['regression', 'classification'])
    parser.add_argument('--dataset_name', '-dn', default='molnet')
    parser.add_argument('--dataset_split_type', '-ds', default='scaffold',
                        choices=['scaffold', 'ac', 'hi', 'lo'])
    parser.add_argument('--prop_type', '-p', required=True, choices=prop_choices)
    parser.add_argument('--data_dir', '-d', default='./data')

    # Weights & Biases logging
    parser.add_argument('--wandb_log', '-w', default=False, action='store_true')
    parser.add_argument('--proj_name', default='transductive_learning')

    # Paths
    parser.add_argument('--model_path', '-mp', type=str)
    parser.add_argument('--save_path', '-sp', type=str, default='results/blt_chemprop')

    return parser.parse_args()
    
    
def _evaluate_split(args, config, model, transducer, split):
    """Run ``test_model`` on one dataset split and pickle its predictions.

    Args:
        args: Parsed run options; ``args.save_path`` is the output directory.
        config: Experiment config; ``config.model.model_type`` prefixes the
            output file name.
        model: Model handed through to ``test_model``.
        transducer: Transducer handed through to ``test_model``.
        split: Split name, used both as the ``load_dataset`` split and as the
            ``test_model`` prefix (``'eval'`` or ``'ood'``).
    """
    graphs, smiles, targets = load_dataset(args, split)
    preds = test_model(args, config, graphs, smiles, targets, model, transducer, prefix=split)
    out_path = os.path.join(args.save_path, config.model.model_type + f'_{split}_preds.pkl')
    # Close the file deterministically so the pickle is fully flushed to disk.
    with open(out_path, 'wb') as out_file:
        pickle.dump(preds, out_file)


def main():
    """Train a BLT-Chemprop model with a memory-bank transducer and evaluate it.

    Steps, in order:
      1. Parse CLI options and derive the run's output directories under
         ``<save_path>/<split_type>/<prop_type>/<seed>/<timestamp>``.
      2. Load ``configs/blt_chemprop/<config_name>.yml``, seed the run and
         optionally start W&B logging.
      3. Build the model, a memory bank from the training graphs, and the
         transducer.
      4. Train, then write predictions for the ``eval`` split (skipped for
         classification) and the ``ood`` split (skipped for ``hi``/``lo``
         dataset splits).
    """
    date_str = datetime.now().strftime('%Y%m%d_%H%M%S')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args = set_args()
    args.date_str = date_str
    args.device = device
    args.save_path = os.path.join(
        args.save_path, args.dataset_split_type, args.prop_type, str(args.seed), date_str
    )
    args.checkpoint_path = os.path.join(args.save_path, 'ckpts')
    args.train_deltas_path = os.path.join(args.save_path, 'train_deltas.pkl')
    os.makedirs(args.save_path, exist_ok=True)
    os.makedirs(args.checkpoint_path, exist_ok=True)

    # The seed may arrive as a string from the command line; normalise it.
    args.seed = int(args.seed)
    yaml = YAML()
    # Read the config inside a context manager so the handle is not leaked.
    with open(f'configs/blt_chemprop/{args.config_name}.yml') as config_file:
        config_dict = yaml.load(config_file)
    config = ConfigNamespace(config_dict)
    set_seed(args)
    if args.wandb_log:
        set_wandb(args, config_dict)
    if args.task == "classification":
        # Classification runs use a fixed epoch budget, overriding the config.
        config.exp.num_epochs = 10

    print('Loading data...')
    train_graphs, train_smiles, train_targets = load_dataset(args, 'train')

    model = define_model(args, config, num_tasks=1)
    if config.model.freeze_encoder:
        # Models without an ``encoder`` attribute fall back to an empty module,
        # so this loop is simply a no-op for them.
        for param in getattr(model, 'encoder', torch.nn.Module()).parameters():
            param.requires_grad = False
    log_params(model)
    model.to(args.device)

    graph_dataset = GraphDataset(train_graphs, train_smiles, train_targets)
    graph_loader = DataLoader(
        graph_dataset,
        batch_size=config.exp.batch_size,
        shuffle=False,
        collate_fn=graph_collate_fn,
        num_workers=0,
    )
    memory_bank, smiles_bank = build_memory_bank_gnn(
        model, graph_loader, batch_size=config.exp.batch_size
    )

    transducer = define_transducer(args, config, memory_bank, smiles_bank)
    print("Transducer initialized:")
    print(f"  Sampling Strategy: {transducer.sampling_strategy}")
    print(f"  Anchor Metric: {transducer.anchor_metric}")
    print(f"  Num Candidates (k): {transducer.num_candidates}")
    if transducer.sampling_strategy == 'diverse_topm':
        # NOTE(review): this reads ``transducer.diverse_topm`` while the
        # adaptive branch below reads ``config.transducer.adaptive_mask`` --
        # confirm the transducer really exposes ``diverse_topm``.
        print(f"  Diversity Factor: {transducer.diverse_topm.diversity_factor}")
    elif transducer.sampling_strategy == 'adaptive_mask':
        print(f"  Adaptive k_min: {transducer.adaptive_k_min}")
        print(f"  Adaptive k_max: {transducer.adaptive_k_max}")
        print(f"  Adaptive Density Factor: {config.transducer.adaptive_mask.density_threshold_factor}")

    print('Training model...')
    model = train_model_gnn_with_memory(
        args, config, train_graphs, train_smiles, train_targets, model, transducer
    )

    print('Evaluating...')
    # The eval split is not evaluated for classification runs.
    if args.task != 'classification':
        _evaluate_split(args, config, model, transducer, 'eval')

    # The 'hi' and 'lo' dataset splits have no OOD evaluation step.
    if args.dataset_split_type not in ['hi', 'lo']:
        _evaluate_split(args, config, model, transducer, 'ood')

    # Finish outside the OOD branch so 'hi'/'lo' runs close their W&B run too.
    wandb.finish()
    
# Run the training/evaluation pipeline only when this file is executed as a
# script; importing the module does not start a run.
if __name__ == "__main__":
    main()
