import os
import json
import numpy as np
import pickle
import argparse
import warnings
from sklearn.model_selection import train_test_split
import random

from rdkit import Chem
from sklearn.cluster import SpectralClustering, KMeans
from cheminformatics.splitting import map_scaffolds
from cheminformatics.descriptors import mols_to_ecfp
from cheminformatics.multiprocessing import tanimoto_matrix

import deepchem as dc
from deepchem.feat.molecule_featurizers import RDKitDescriptors

from process_utils import *
import splito

# Set TensorFlow's native log-level filter to '3' to quiet its console output.
# NOTE(review): this runs after `deepchem` has been imported above; if that
# import already loaded TensorFlow the setting may come too late — confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Suppress every Python warning for the remainder of the run.
warnings.filterwarnings('ignore')

def set_args():
    """Parse the command-line options of the preprocessing script.

    Returns:
        argparse.Namespace with the attributes ``data_dir``, ``dataset_name``,
        ``data_rep``, ``property``, ``split_x_type``, ``split_ratio`` and
        ``seed``. ``property`` is the only required option.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', '-d', default='../data')
    parser.add_argument('--dataset_name', '-dn', default='molnet', choices=['molnet', 'drugood'])
    parser.add_argument('--data_rep', '-dr', type=str, default=None, choices=['gnn', 'smi_ted'])
    parser.add_argument('--property', '-p', type=str, required=True,
                        choices=['bace', 'esol', 'freesolv', 'lipo',
                                 'bace_x', 'esol_x', 'freesolv_x', 'lipo_x',
                                 'core_ec50', 'core_ic50'])
    parser.add_argument('--split_x_type', '-st', default='scaffold', type=str,
                        choices=['scaffold', 'mces', 'max_dis', 'hi', 'lo'])
    parser.add_argument('--split_ratio', '-or', type=float, default=0.05)
    # type=int: without it a seed given on the command line arrives as a str
    # (while the default stays an int), which the numeric seeding calls reject.
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()
    return args

def set_seed(args):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        args: object exposing a ``seed`` attribute (an int, or a string
            holding an int).

    Raises:
        ValueError: if ``args.seed`` cannot be converted to an int.
    """
    # Coerce once so that every generator receives the same integer; a string
    # seed (e.g. an untyped command-line value) is rejected by np.random.seed.
    seed = int(args.seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Exported for child processes; the hash seed of the already running
    # interpreter is fixed at start-up and is not affected by this assignment.
    os.environ['PYTHONHASHSEED'] = str(seed)
    

def split_based_on_x_scaffold(args, smiles):
    """Split molecules into train / eval / OOD sets by clustering their scaffolds.

    Scaffolds are extracted with ``map_scaffolds``, fingerprinted as ECFPs and
    clustered with spectral clustering on the pairwise Tanimoto matrix. The
    clusters returned by ``select_ood_clusters`` form the OOD set; the final
    index split is delegated to ``split_data``.

    Args:
        args: namespace providing ``split_ratio`` (also forwarded to
            ``split_data``).
        smiles: sequence of SMILES strings.

    Returns:
        Tuple ``(train_indices, eval_indices, ood_indices,
        clusters_per_original)`` where ``clusters_per_original[i]`` is the
        cluster label assigned to ``smiles[i]``.

    Raises:
        ValueError: if the scaffold table references a SMILES string that is
            not present in ``smiles``.
    """
    _, uniques = map_scaffolds(smiles, scaffold_type='cyclic_skeleton')
    df_scaffolds = organize_dataframe(uniques, smiles)
    scaffold_mols = [Chem.MolFromSmiles(smi) for smi in df_scaffolds['scaffolds']]
    ecfps = mols_to_ecfp(scaffold_mols, radius=2, nbits=2048)
    tanimoto_score = tanimoto_matrix(ecfps, dtype=float)

    n_clusters = eigenvalue_cluster_approx(tanimoto_score)
    spectral = SpectralClustering(n_clusters=n_clusters, affinity='precomputed', assign_labels='kmeans')
    clusters = spectral.fit_predict(tanimoto_score)
    df_scaffolds['cluster'] = clusters
    df_clusters = group_and_sort(clusters, similarity_matrix=tanimoto_score, n_smiles_for_each_scaffold=df_scaffolds['n_mols'])
    ood_clusters = select_ood_clusters(df_clusters, size_cutoff=int(len(smiles) * args.split_ratio))

    # Map every SMILES string to all of its positions up front. A per-item
    # `smiles.index(smi)` is quadratic and only ever finds the first
    # occurrence, which would leave duplicated SMILES without a cluster label.
    positions_by_smiles = {}
    for position, smi in enumerate(smiles):
        positions_by_smiles.setdefault(smi, []).append(position)

    clusters_per_original = [''] * len(smiles)
    for original_smi, cluster in zip(df_scaffolds['original_smiles'], clusters):
        # 'original_smiles' holds the members of one scaffold joined by ';'.
        for smi in original_smi.split(';'):
            positions = positions_by_smiles.get(smi)
            if positions is None:
                raise ValueError(f'{smi!r} is not in list')
            for position in positions:
                clusters_per_original[position] = cluster
    train_indices, eval_indices, ood_indices = split_data(args, clusters_per_original, ood_clusters)

    return train_indices, eval_indices, ood_indices, clusters_per_original


def split_based_on_x_mces(args, mols_list, method='spectral_clustering'):
    """Split molecules into train / eval / OOD index sets via the ``MCES`` helper.

    Args:
        args: namespace forwarded to the ``MCES`` constructor.
        mols_list: molecules handed to ``MCES.get_mces_ood``.
        method: method name forwarded to the ``MCES`` constructor.

    Returns:
        Tuple ``(train_indices, eval_indices, ood_indices)``.
    """
    splitter = MCES(args, method)
    train_part, eval_part, ood_part = splitter.get_mces_ood(mols_list)
    return train_part, eval_part, ood_part

def split_based_on_x_max_dis(args, smiles):
    """Split molecules with splito's maximum-dissimilarity splitter.

    The first split produced by ``splito.MaxDissimilaritySplit`` provides the
    OOD indices; the remaining indices are then shuffled into train and eval
    parts with ``train_test_split``. ``args.split_ratio`` is used as the test
    fraction for both steps.

    Args:
        args: namespace providing ``split_ratio`` and ``seed``.
        smiles: despite the name, every element is passed through
            ``Chem.MolToSmiles``, so molecule objects are expected.

    Returns:
        Tuple ``(train_indices, eval_indices, ood_indices)``.
    """
    ratio = args.split_ratio
    seed = args.seed
    dissimilarity_splitter = splito.MaxDissimilaritySplit(
        n_jobs=-1,
        test_size=ratio,
        random_state=seed,
    )
    smiles_strings = list(map(Chem.MolToSmiles, smiles))
    first_split = next(dissimilarity_splitter.split(smiles_strings))
    train_eval_indices, ood_indices = first_split
    train_indices, eval_indices = train_test_split(
        train_eval_indices,
        test_size=ratio,
        shuffle=True,
        random_state=seed,
    )
    return train_indices, eval_indices, ood_indices

    
def split_based_on_y(args, valid_targets):
    """Split samples by target value: the largest targets become the OOD set.

    Samples are ordered by ascending target. The top
    ``int(split_ratio * n_samples)`` of them form the OOD set and the rest is
    divided into train and eval parts with ``train_test_split``, again using
    ``split_ratio`` as the eval fraction.

    Args:
        args: namespace providing ``split_ratio`` and ``seed``.
        valid_targets: sequence of target values (flattened before sorting).

    Returns:
        Tuple ``(train_indices, eval_indices, ood_indices)``. ``ood_indices``
        is empty when ``split_ratio * n_samples`` rounds down to zero.
    """
    split_ratio = args.split_ratio

    y_sorted_indices = np.argsort(np.array(valid_targets).ravel())
    all_n_samples = len(valid_targets)
    ood_n_samples = int(split_ratio * all_n_samples)

    # Slice at an explicit boundary. The negative-index forms `[-k:]` / `[:-k]`
    # break for k == 0: `[-0:]` selects everything and `[:-0]` selects nothing,
    # which would put every sample in the OOD set.
    boundary = max(len(y_sorted_indices) - ood_n_samples, 0)
    ood_indices = y_sorted_indices[boundary:]
    train_eval_indices = y_sorted_indices[:boundary]

    train_indices, eval_indices = train_test_split(train_eval_indices, test_size=split_ratio, random_state=args.seed)
    return train_indices, eval_indices, ood_indices


def process_smiles(args, data):
    """Featurize one DrugOOD split into graphs, representations and targets.

    Every record's ``'smiles'`` entry is converted to an RDKit molecule
    (records that fail to parse are skipped), featurized with normalized
    RDKit descriptors and converted to a graph via ``mol_to_graph``. When
    ``args.data_rep`` is ``'gnn'`` or ``'smi_ted'`` the descriptor vectors
    are replaced by the corresponding embeddings.

    Args:
        args: namespace providing ``data_rep`` (``None``, ``'gnn'`` or
            ``'smi_ted'``).
        data: iterable of records with the keys ``'smiles'`` and
            ``'cls_label'``.

    Returns:
        dict with the keys ``'graphs'``, ``'reps'``, ``'smiles'`` and
        ``'targets'``; all entries are aligned and cover only the records
        that were kept.
    """
    featurizer = RDKitDescriptors(is_normalized=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    graph_data_list = []
    valid_reps = []
    valid_smiles = []
    valid_targets = []

    smiles_list = [d['smiles'] for d in data]
    target_list = [d['cls_label'] for d in data]
    for smiles, target in tqdm(zip(smiles_list, target_list), total=len(smiles_list), desc='Processing Smiles...'):
        mol = smiles if isinstance(smiles, Chem.Mol) else Chem.MolFromSmiles(smiles)
        if mol is None:
            # Unparseable SMILES: drop the record entirely.
            continue
        rep = featurizer.featurize([mol])[0]
        # Use a dedicated name so the `data` parameter is not rebound.
        graph = mol_to_graph(mol)
        graph.y = [target]
        graph_data_list.append(graph)

        valid_reps.append(rep)
        valid_smiles.append(smiles)
        valid_targets.append(target)

    # With data_rep None the RDKit descriptor vectors are kept as they are.
    if args.data_rep == 'gnn':
        valid_reps = get_gnn_embedding(graph_data_list, device)
    elif args.data_rep == 'smi_ted':
        valid_reps = get_smi_ted_embedding(valid_smiles, device)

    processed_data = {'graphs': graph_data_list,
                      'reps': valid_reps,
                      'smiles': valid_smiles,
                      'targets': valid_targets}
    return processed_data
    
def drugood_preprocess(args):
    """Process the predefined DrugOOD scaffold split and pickle the result.

    Reads ``lbap_<property>_scaffold.json``, runs ``process_smiles`` on its
    ``train``, ``iid_test`` and ``ood_test`` splits, exports the result with
    ``save_dataset_as_csv`` and pickles it under
    ``<data_dir>/<dataset_name>/<split_x_type>/<property>/``.

    Args:
        args: namespace providing ``property``, ``data_rep``, ``data_dir``,
            ``dataset_name`` and ``split_x_type``.

    Raises:
        ValueError: if ``args.data_rep`` is not ``None``, ``'gnn'`` or
            ``'smi_ted'``.
    """
    # NOTE(review): the input location is hard-coded and does not follow
    # args.data_dir, which is only used for the output path below.
    json_path = f'../data/drugood_all/lbap_{args.property}_scaffold.json'
    # Context manager so the handle is closed even if parsing fails.
    with open(json_path, 'r') as json_file:
        data = json.load(json_file)
    dataset_dict = {'train': process_smiles(args, data['split']['train']),
                    'eval': process_smiles(args, data['split']['iid_test']),
                    'ood': process_smiles(args, data['split']['ood_test'])}

    if args.data_rep is None:
        filename = f'{args.property}.pkl'
    elif args.data_rep == 'gnn':
        filename = f'{args.property}_gnn.pkl'
    elif args.data_rep == 'smi_ted':
        filename = f'{args.property}_smi_ted.pkl'
    else:
        raise ValueError(f'unsupported data_rep: {args.data_rep!r}')

    save_dataset_as_csv(args, dataset_dict)
    save_path = os.path.join(args.data_dir, args.dataset_name, args.split_x_type, args.property, filename)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    # Close (and thereby flush) the pickle file deterministically.
    with open(save_path, 'wb') as pickle_file:
        pickle.dump(dataset_dict, pickle_file)

    print(f'data saved as {save_path}')
    
    
def molnet_preprocess(args):
    """Build, split and pickle a MoleculeNet regression dataset.

    The DeepChem train / valid / test parts are concatenated back into one
    pool, entries with NaN descriptors or unparseable SMILES are dropped, and
    the pool is split again: properties containing ``'_x'`` are split on
    structure (``args.split_x_type``), all others on the target value. The
    result is exported with ``save_dataset_as_csv`` and pickled under
    ``<data_dir>/<dataset_name>/<split_x_type>/<property>/``.

    Args:
        args: namespace providing ``property``, ``data_rep``,
            ``split_x_type``, ``split_ratio``, ``seed``, ``data_dir`` and
            ``dataset_name``.

    Raises:
        ValueError: if ``args.property`` has no MoleculeNet loader here, if
            ``args.data_rep`` is unknown, or if ``args.split_x_type`` is not
            implemented for structure-based splits.
    """
    featurizer = RDKitDescriptors(is_normalized=True)
    if args.property in ['bace', 'bace_x']:
        _, datasets, _ = dc.molnet.load_bace_regression(featurizer=featurizer)
    elif args.property in ['esol', 'esol_x']:
        _, datasets, _ = dc.molnet.load_delaney(featurizer=featurizer)
    elif args.property in ['freesolv', 'freesolv_x']:
        _, datasets, _ = dc.molnet.load_freesolv(featurizer=featurizer)
    elif args.property in ['lipo', 'lipo_x']:
        _, datasets, _ = dc.molnet.load_lipo(featurizer=featurizer)
    else:
        # Previously this fell through and failed later on an unbound name.
        raise ValueError(f'property {args.property!r} is not available for the molnet dataset')

    train, valid, test = datasets
    all_reps = np.concatenate([train.X, valid.X, test.X])
    all_smiles = np.concatenate([train.ids, valid.ids, test.ids])
    all_targets = np.concatenate([train.y, valid.y, test.y])

    graph_data_list = []
    valid_reps = []
    valid_smiles = []
    valid_targets = []
    valid_mols = []
    for rep, smi, target in zip(all_reps, all_smiles, all_targets):
        if np.isnan(rep).any():
            # Descriptor computation produced NaNs: drop the molecule.
            continue
        mol = smi if isinstance(smi, Chem.Mol) else Chem.MolFromSmiles(smi)
        if mol is None:
            continue
        valid_mols.append(mol)
        graph = mol_to_graph(mol)
        graph.y = [target[0]]
        graph_data_list.append(graph)

        valid_reps.append(rep)
        # Store the SMILES as re-generated by RDKit rather than the raw id.
        valid_smiles.append(Chem.MolToSmiles(mol))
        valid_targets.append(target)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if args.data_rep is None:
        filename = f'{args.property}.pkl'
    elif args.data_rep == 'gnn':
        valid_reps = get_gnn_embedding(graph_data_list, device)
        filename = f'{args.property}_gnn.pkl'
    elif args.data_rep == 'smi_ted':
        valid_reps = get_smi_ted_embedding(valid_smiles, device)
        filename = f'{args.property}_smi_ted.pkl'
    else:
        raise ValueError(f'unsupported data_rep: {args.data_rep!r}')

    if '_x' in args.property:
        if args.split_x_type == 'scaffold':
            train_indices, eval_indices, ood_indices, _ = split_based_on_x_scaffold(args, valid_smiles)
        elif args.split_x_type == 'mces':
            train_indices, eval_indices, ood_indices = split_based_on_x_mces(args, valid_mols)
        elif args.split_x_type == 'max_dis':
            train_indices, eval_indices, ood_indices = split_based_on_x_max_dis(args, valid_mols)
        else:
            # 'hi' and 'lo' are accepted by the argument parser but have no
            # splitter here; fail with a clear message instead of an unbound
            # local further down.
            raise ValueError(f'split_x_type {args.split_x_type!r} is not implemented for molnet structure splits')
    else:
        train_indices, eval_indices, ood_indices = split_based_on_y(args, valid_targets)
    print(train_indices.shape, train_indices)
    print(eval_indices.shape, eval_indices)
    print(ood_indices.shape, ood_indices)

    dataset_dict = {'train': build_dataset(graph_data_list, valid_reps, valid_smiles, valid_targets, train_indices),
                    'eval': build_dataset(graph_data_list, valid_reps, valid_smiles, valid_targets, eval_indices),
                    'ood': build_dataset(graph_data_list, valid_reps, valid_smiles, valid_targets, ood_indices),
    }

    save_dataset_as_csv(args, dataset_dict)
    save_path = os.path.join(args.data_dir, args.dataset_name, args.split_x_type, args.property, filename)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    # Close (and thereby flush) the pickle file deterministically.
    with open(save_path, 'wb') as pickle_file:
        pickle.dump(dataset_dict, pickle_file)

    print(f'data saved as {save_path}')
    
def main():
    """Parse the arguments, seed the RNGs and run the selected preprocessing."""
    args = set_args()
    set_seed(args)
    # One preprocessing routine per supported dataset name.
    pipelines = {
        'molnet': molnet_preprocess,
        'drugood': drugood_preprocess,
    }
    preprocess = pipelines.get(args.dataset_name)
    if preprocess is not None:
        preprocess(args)
        
    
# Run the preprocessing only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
 