import os
import pickle
import random
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from urllib.request import urlopen
from collections import defaultdict
from functools import partialmethod

def _load_entity_clusters(threshold, input_pdb_ids):
    """Load (or download and cache) RCSB entity clusters restricted to ``input_pdb_ids``.

    Returns ``(dict1, dict2)``:
      * ``dict1`` maps a PDB ID to the list of cluster indices it appears in;
      * ``dict2`` maps a cluster index to the deduplicated PDB IDs of that cluster.
    Only clusters containing at least one ID from ``input_pdb_ids`` (a set) are
    kept, numbered consecutively from 0 in file order. Results are cached as
    ``dict1_{threshold}.pkl`` / ``dict2_{threshold}.pkl`` in the current working
    directory. An existing cache is reused as is, even if it was built for a
    different ``input_pdb_ids``.
    """
    dict1_path = f'dict1_{threshold}.pkl'
    dict2_path = f'dict2_{threshold}.pkl'
    # Reuse the cache only when both halves exist; otherwise rebuild both.
    if os.path.isfile(dict1_path) and os.path.isfile(dict2_path):
        # NOTE: pickle.load is acceptable only because these files are written by this script.
        with open(dict1_path, 'rb') as f1:
            dict1 = pickle.load(f1)
        with open(dict2_path, 'rb') as f2:
            dict2 = pickle.load(f2)
        return dict1, dict2

    url = f'https://cdn.rcsb.org/resources/sequence/clusters/clusters-by-entity-{threshold}.txt'
    dict1 = defaultdict(list)
    dict2 = dict()
    # Each line is one cluster of space-separated entity IDs; the PDB ID is the part before '_'.
    with urlopen(url) as response:
        for line in tqdm(response):
            entity_ids = line.decode('utf-8').strip('\n').split(' ')
            pdb_ids = list({eid.split('_')[0] for eid in entity_ids} & input_pdb_ids)
            if not pdb_ids:
                continue
            cluster_idx = len(dict2)  # consecutive index over kept clusters only
            for pdb_id in pdb_ids:
                dict1[pdb_id].append(cluster_idx)
            dict2[cluster_idx] = pdb_ids

    with open(dict1_path, 'wb') as f1:
        pickle.dump(dict1, f1)
    with open(dict2_path, 'wb') as f2:
        pickle.dump(dict2, f2)
    return dict1, dict2


def _db5_similar_ids(dict1, db5_pdb_ids):
    """Return the PDB IDs in ``dict1`` that share at least one entity cluster with a DB5 entry.

    DB5 IDs present in ``dict1`` are included themselves, because they share their own
    clusters. DB5 IDs absent from ``dict1`` (e.g. 3RVW, which was removed from RCSB)
    contribute nothing.
    """
    db5_clusters = set()
    for pid in db5_pdb_ids:
        # .get() does not insert missing keys into a defaultdict.
        db5_clusters.update(dict1.get(pid, ()))
    return {pid for pid, clusters in dict1.items() if not db5_clusters.isdisjoint(clusters)}


def _subsample_entry_clusters(pdb_to_group, filtered_ids, max_entry_cluster_size, seed):
    """Randomly cap every entry cluster of ``filtered_ids`` at ``max_entry_cluster_size`` members.

    An entry cluster is the set of PDB IDs whose sorted tuple of entity-cluster
    indices (the values of ``pdb_to_group``) is identical.

    Returns ``(kept, singletons)``:
      * ``kept`` is a Series indexed by the retained PDB IDs;
      * ``singletons`` lists, in group order, the PDB IDs whose entry cluster has
        exactly one member.
    """
    groups = pdb_to_group[filtered_ids]
    kept_parts = []
    singletons = []
    for _, members in groups.groupby(groups.values):
        if len(members) > max_entry_cluster_size:
            # Series.sample requires an integer n; the cap may arrive as a float (e.g. 1E6).
            kept_parts.append(members.sample(n=int(max_entry_cluster_size), replace=False,
                                             random_state=seed))
        else:
            kept_parts.append(members)
        if len(members) == 1:
            singletons.extend(members.index)
    # pd.concat raises on an empty list, so fall back to an empty Series.
    kept = pd.concat(kept_parts) if kept_parts else pd.Series(dtype=object)
    return kept, singletons


def cluster_by_entity(threshold, max_entry_cluster_size, cv, seed, num_valid=300):
    """Write RCSB train/validation PDB ID lists that exclude entries similar to DB5.

    DB5 PDB IDs are read from ``../DB5/db5_difficulty.csv`` (column ``pdb_id``).
    RCSB IDs are the first 4 characters of the file names in ``dataset_RCSB/``.
    Both paths are relative to the current working directory.

    Processing steps:
      1. Any entry sharing an RCSB entity cluster (from the
         ``clusters-by-entity-{threshold}`` file) with a DB5 entry is dropped,
         as are the DB5 entries themselves.
      2. The remaining entries are grouped into entry clusters (identical sets of
         entity clusters).
      3. Each entry cluster is randomly capped at ``max_entry_cluster_size``
         members.
      4. Validation IDs are drawn from single-member entry clusters. All other
         retained IDs go to training.

    Args:
        threshold: identity threshold selecting the RCSB entity cluster file.
        max_entry_cluster_size: maximum number of members kept per entry cluster.
        cv: cross-validation ID; used only in the output file names.
        seed: seed for subsampling entry clusters and for shuffling validation
            candidates. ``None`` uses the global ``random`` state for the shuffle.
        num_valid: maximum number of validation IDs (default 300).

    Writes ``rcsb_train_{threshold}_cv{cv}.txt`` (sorted) and
    ``rcsb_valid_{threshold}_cv{cv}.txt``, one PDB ID per line.

    Raises:
        AssertionError: if the DB5 csv does not contain exactly 253 unique PDB IDs.
    """
    db5_df = pd.read_csv('../DB5/db5_difficulty.csv', header=0, index_col=None, sep=',')
    db5_pdb_ids = list(db5_df['pdb_id'].unique())
    assert len(db5_pdb_ids) == 253  # sanity check on the expected DB5 entry count
    rcsb_pdb_ids = {fname[:4] for fname in os.listdir('dataset_RCSB/')}

    dict1, dict2 = _load_entity_clusters(threshold, rcsb_pdb_ids | set(db5_pdb_ids))
    print(f'{threshold}% seq similarity, num pdb ids: {len(dict1)}, num entity groups: {len(dict2)}')

    similar_ids = _db5_similar_ids(dict1, db5_pdb_ids)
    print(f'out of {len(rcsb_pdb_ids)} RCSB IDs, {len(similar_ids)} are highly similar to DB5')

    # Keep dict1's insertion order: it fixes the group member order that the seeded sampling sees.
    db5_set = set(db5_pdb_ids)
    filtered_ids = [pid for pid in dict1 if pid not in similar_ids and pid not in db5_set]

    pdb_to_group = pd.Series({pdb: tuple(sorted(set(keys))) for pdb, keys in dict1.items()})

    ##############################  Survey entry clusters and subsample large clusters  ##############################
    print(f'size of filtered PDB IDs: {len(filtered_ids)}')
    group_counter = pdb_to_group[filtered_ids].value_counts()
    print(f'\t {len(group_counter)} PDB groups found in filtered PDB. Size:')
    print(group_counter.describe(percentiles=[0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99]))
    print('\tTop 20 clusters: ')
    print(group_counter.head(n=20))

    kept, valid_candidates = _subsample_entry_clusters(
        pdb_to_group, filtered_ids, max_entry_cluster_size, seed)
    print(f'After subsampling: {len(kept)} PDB IDs')
    print(kept.value_counts().head(n=20))
    print('size of valid/test candidates:', len(valid_candidates))
    print('PDB groups are PDB entries with identical entity id sets.')

    # A seeded private RNG makes the split reproducible from `seed` alone. The
    # result equals random.seed(seed) followed by random.shuffle.
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(valid_candidates)
    valid_ids = valid_candidates[:num_valid]
    valid_set = set(valid_ids)
    # Sorted so the training file does not depend on string-hash randomization.
    train_ids = sorted(pid for pid in set(kept.index) if pid not in valid_set)
    print(f'size of training PDB IDs: {len(train_ids)}')
    print(f'size of validation PDB IDs: {len(valid_ids)}')

    with open(f'rcsb_train_{threshold}_cv{cv}.txt', 'w') as f:
        f.write('\n'.join(train_ids) + '\n')
    with open(f'rcsb_valid_{threshold}_cv{cv}.txt', 'w') as f:
        f.write('\n'.join(valid_ids) + '\n')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--threshold', type=int, default=30,
        help='Sequence identity threshold (%%) of the RCSB entity cluster file to use')
    parser.add_argument('--mute-tqdm', action='store_true', help='Disable tqdm progress bars')
    # argparse applies `type` only to string defaults, so the default must already be an int
    # (the previous `1E6` default reached cluster_by_entity as a float).
    parser.add_argument('--max-entry-cluster-size', type=int, default=1_000_000,
        help='Set the maximum entry cluster size (defined by entity cluster id set) by random subsampling large entry clusters')
    parser.add_argument('--cv', type=int, default=0, help='Cross validation ID')
    parser.add_argument('--seed', type=int, default=None, help='Random seed')
    args = parser.parse_args()
    print(args)

    # Seed the global RNG as well, for any code that relies on module-level `random`.
    random.seed(args.seed)

    # Optionally mute tqdm by making disable=True the default for every tqdm instance.
    if args.mute_tqdm:
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    cluster_by_entity(threshold=args.threshold,
                      max_entry_cluster_size=args.max_entry_cluster_size,
                      cv=args.cv,
                      seed=args.seed)


