import random


class cached_property(object):
    """
    Descriptor (non-data) for building an attribute on-demand on first use.
    """
    def __init__(self, factory):
        """
        <factory> is called such: factory(instance) to build the attribute.
        """
        self._attr_name = factory.__name__
        self._factory = factory

    def __get__(self, instance, owner):
        # Build the attribute.
        attr = self._factory(instance)

        # Cache the value; hide ourselves.
        setattr(instance, self._attr_name, attr)
        return attr

def get_inds(expected_num, clu_nums, cid2clu, seq2ind):
    cur_len, cur_idx, query_cids, query_idx = 0, 0, [], []
    while cur_len < expected_num:
        cid, l = clu_nums[cur_idx % (len(clu_nums))]
        cur_idx += 1
        # check if this cluster has been selected
        if cid in query_cids:
            continue
        if random.random() > 0.5:
            for seq in cid2clu[cid]:
                # seq2ind: ensure it is in limited lengths
                if seq in seq2ind.keys():
                    query_idx.append(seq2ind[seq])
                    cur_len += 1

            query_cids.append(cid)
    return query_cids, query_idx

def get_num(N, valid_num=100):
    train_n, valid_n = int(0.9 * N), min(valid_num, int(0.05 * N))
    test_n = N - train_n - valid_n
    return train_n, valid_n, test_n

def get_full_inds(expected_num, clu_nums, cid2clu, full_seq2ind):
    cur_len, cur_idx, query_cids, query_idx = 0, 0, [], {}
    # build query_idx for each dataset
    for dataname in full_seq2ind.keys():
        if dataname not in query_idx.keys():
            query_idx[dataname] = []
    cur_idx_lst = list(range(len(clu_nums)))
    while cur_len < expected_num:
        cur_idx = random.choice(cur_idx_lst)
        cid, l = clu_nums[cur_idx]
        # check if this cluster has been selected
        if cid in query_cids:
            continue
        for seq in set(cid2clu[cid]):
            # seq2ind: ensure it is in limited lengths
            for dataname in full_seq2ind.keys():
                if seq in full_seq2ind[dataname].keys():
                    query_idx[dataname].append(full_seq2ind[dataname][seq])
                    cur_len += 1
        query_cids.append(cid)
        cur_idx_lst.remove(cur_idx)
    return query_cids, query_idx

def get_inds(expected_num, clu_nums, cid2clu, seq2ind):
    cur_len, query_cids, query_idx = 0, [], []
    cur_idx_lst = list(range(len(clu_nums)))
    while cur_len < expected_num:
        try:
            cur_idx = random.choice(cur_idx_lst)
            cid, l = clu_nums[cur_idx]
            # check if this cluster has been selected
            if cid in query_cids:
                continue

            # check if this cluster is too big
            pre = abs(expected_num - cur_len)
            aft = abs(cur_len + l - expected_num)
            if pre < aft:
                continue

            for seq in cid2clu[cid]:
                # seq2ind: ensure it is in limited lengths
                if seq in seq2ind.keys():
                    query_idx.append(seq2ind[seq])
                    cur_len += 1
            query_cids.append(cid)
            cur_idx_lst.remove(cur_idx)
        except:
            break
    return query_cids, query_idx