
import os

from datasets import load_dataset
from joblib import Parallel, delayed
import flpc as re
import argparse
import gc
import itertools
import pandas as pd
import sys
import gzip
import json
from joblib import Parallel, delayed
import glob

# Worker count passed to the datasets `num_proc` calls and to the final joblib
# scan (parse_dolma_prefix uses its own default instead).
n_jobs: int = 80
print(f'[n_jobs] {n_jobs}')

parser = argparse.ArgumentParser(
    description='Scan one shard of a text dataset for secret-like strings '
                '(hex hashes, Ethereum wallet addresses, Java serialVersionUIDs).')
parser.add_argument('--begin_id', type=int, default=0,
                    help='index of the first raw example of the shard (inclusive)')
parser.add_argument('--n_size', type=int, default=100000,
                    help='number of raw examples in the shard; the shard covers '
                         '[begin_id, begin_id + n_size)')
# NOTE(review): the default 'pile_train' matches none of the dataset branches
# (they expect ids like 'pile-train_Github'), so a run without an explicit
# --dataset ends in ValueError.
parser.add_argument('--dataset', type=str, default='pile_train',
                    help="dataset id: pile-train_<Subset>, pile-test_<Subset>, "
                         "pile-val_<Subset>, dolma_<file prefix>[----<N>], or "
                         "proof-pile-2_<split>_<subset>")
parser.add_argument('--base_path', type=str, default='.',
                    help='base directory holding the dolma-v1_7/ inputs and the '
                         'secrets/ output directory')
args = parser.parse_args()
# Output directory for the per-shard result files.
os.makedirs(f'{args.base_path}/secrets', exist_ok=True)

_regex = {
    'ethereum_wallet': r'\b(0x[a-fA-F0-9]{40})\b',
    'md5': r'\b[a-fA-F0-9]{32}\b',
    'sha1': r'\b[a-fA-F0-9]{40}\b',
    'sha256': r'\b[a-fA-F0-9]{64}\b',
    'sha512': r'\b[a-fA-F0-9]{128}\b',
    'java_serialization': r'(private |static |final |long ){2,}serialVersionUID ?= ?[0-9]{11,}',
}


# Half-open shard range [begin_id, end_id) of raw example indices.
begin_id = args.begin_id
end_id = args.begin_id + args.n_size

# Skip shards whose output already exists. The check must name the same file
# the script writes at the end (secrets/<dataset>_<begin>_<end>.csv.xz); a
# .pkl path is never produced, so checking it would never skip anything.
if os.path.isfile(f'{args.base_path}/secrets/{args.dataset}_{begin_id}_{end_id}.csv.xz'):
    print(f'Already done: {args}')
    sys.exit(0)

print('Running:', args)

def filter_pile(subset):
    """Build a predicate that keeps Pile examples belonging to ``subset``.

    The example's ``meta['pile_set_name']`` is compared after removing spaces
    and parentheses, e.g. 'Gutenberg (PG-19)' must be requested as
    'GutenbergPG-19' and 'Wikipedia (en)' as 'Wikipediaen'.
    """
    strip_table = str.maketrans('', '', ' ()')

    def keep(example):
        set_name = example['meta']['pile_set_name']
        return set_name.translate(strip_table) == subset

    return keep



def parse_dolma(fn):
    """Return the ``'text'`` field of every record in a gzipped JSON-lines file.

    This is best-effort: on any error (unreadable file, bad JSON, missing
    'text' key) the error is printed and an empty list is returned, so one bad
    file does not abort a parallel load.
    """
    try:
        with gzip.open(fn, mode="rt", encoding="utf-8") as f:
            # Skip blank lines: json.loads('') would raise and discard every
            # record already read from this file.
            return [json.loads(line)['text'] for line in f if line.strip()]
    except Exception as e:
        print(f'Error in {fn}: {e}')
        return []


def parse_dolma_prefix(path, n_jobs=16):
    """Load and concatenate the texts of all files matching the glob ``path``.

    Files are parsed in parallel with joblib using ``n_jobs`` workers. The
    result is one flat list of texts, ordered by sorted file name and then by
    line within each file.
    """
    # glob order is filesystem-dependent; sort it so the concatenation (and
    # therefore any [begin_id:end_id] slice taken from it) is reproducible.
    files = sorted(glob.glob(path))
    print(files)
    per_file = Parallel(n_jobs=n_jobs, verbose=3)(delayed(parse_dolma)(fn) for fn in files)
    # chain.from_iterable flattens in linear time; sum(lists, []) is quadratic.
    return list(itertools.chain.from_iterable(per_file))


# Load the requested shard as a list of raw texts. The prefix of the dataset id
# selects the source; the '_'-separated remainder selects subset/split.
_PILE_SPLITS = {
    # id prefix -> (HF split name, data file for that split, or None to use
    # the dataset's default files)
    'pile-train': ('train', None),
    'pile-test': ('test', "test.jsonl.zst"),
    'pile-val': ('validation', "val.jsonl.zst"),
}
pile_prefix = next((p for p in _PILE_SPLITS if args.dataset.startswith(p)), None)

if pile_prefix is not None:
    # e.g. 'pile-train_Github'. Subset names are the pile_set_name values
    # with spaces and parentheses removed:
    # ['ArXiv', 'DM Mathematics', 'Enron Emails', 'EuroParl', 'FreeLaw',
    #  'Github', 'Gutenberg (PG-19)', 'HackerNews', 'NIH ExPorter',
    #  'PhilPapers', 'Pile-CC', 'PubMed Abstracts', 'PubMed Central',
    #  'StackExchange', 'USPTO Backgrounds', 'Ubuntu IRC',
    #  'Wikipedia (en)']
    hf_split, data_file = _PILE_SPLITS[pile_prefix]
    subset = args.dataset.split('_')[1]
    extra_kwargs = {} if data_file is None else {'data_files': {hf_split: data_file}}
    # The [begin_id:end_id] slice is taken before the subset filter, so the
    # shard may yield fewer than n_size texts.
    texts = load_dataset(
        "monology/pile-uncopyrighted",
        split=f'{hf_split}[{begin_id}:{end_id}]',
        num_proc=n_jobs,
        **extra_kwargs,
    ).filter(filter_pile(subset), num_proc=n_jobs)['text']
elif args.dataset.startswith('dolma'):
    # e.g. 'dolma_<file prefix>' or 'dolma_<file prefix>----<N>'; reads
    # <base_path>/dolma-v1_7/<file prefix>-<N as 4 digits>.json.gz (N defaults to 0).
    subset = args.dataset.partition('_')[2]
    pieces = subset.split('----')
    if len(pieces) > 1:
        print(pieces)
        value = int(pieces[1])
        subset = pieces[0]
    else:
        value = 0
    texts = parse_dolma_prefix(
        f"{args.base_path}/dolma-v1_7/{subset}-{value:04}.json.gz"
    )[begin_id:end_id]
elif args.dataset.startswith('proof-pile-2'):
    # e.g. 'proof-pile-2_train_arxiv'; subsets include
    # "algebraic-stack", "arxiv", "open-web-math".
    id_parts = args.dataset.split('_')
    split, subset = id_parts[1], id_parts[2]
    texts = load_dataset(
        "EleutherAI/proof-pile-2",
        subset,
        split=f'{split}[{begin_id}:{end_id}]',
        trust_remote_code=True,
        num_proc=n_jobs,
    )['text']
else:
    raise ValueError(f"Group {args.dataset} not found")

# An empty shard makes the script exit with status 1 without writing output.
if not len(texts):
    sys.exit(1)

# Presumably reclaims garbage left from dataset loading before the worker pool
# is started.
gc.collect()

# Compile every pattern once, keeping the key order of _regex.
compiled_patterns = dict(zip(_regex.keys(), map(re.compile, _regex.values())))

def parse_string(x, patterns=None):
    """Find the leftmost match of each secret pattern in ``x``.

    Args:
        x: text to scan.
        patterns: mapping of secret type -> compiled pattern. Defaults to the
            module-level ``compiled_patterns``.

    Returns:
        A list of ``(secret_type, start, end, x)`` tuples, at most one per
        pattern and in pattern order. A hit is kept only when its span is
        longer than 8 characters. The 4th element is the whole input text, not
        just the matched substring.
    """
    if patterns is None:
        patterns = compiled_patterns
    ans = []
    for k, pattern in patterns.items():
        # Only the first (leftmost) match is used, so search() is enough;
        # finditer() would build the list of every match in the text.
        match = re.search(pattern, x)
        if match is None:
            continue
        a, b = match.span(0)
        if b - a > 8:
            ans.append((k, a, b, x))
    return ans

# Scan every text in parallel; each call returns the (possibly empty) list of
# hits for one text.
r = Parallel(n_jobs=n_jobs, verbose=2, backend='multiprocessing')(
    delayed(parse_string)(t) for t in texts
)
print('list:', len(r))
# Flatten the per-text lists. from_iterable avoids star-unpacking one
# argument per text into chain().
r = list(itertools.chain.from_iterable(r))
print('n_secrets:', len(r))

df = pd.DataFrame(r, columns=['secret_type', 'start', 'end', 'string']).astype({
    'secret_type': 'category',
    'start': 'int32',
    'end': 'int32',
    'string': 'str'
})

# Write to a temporary name and rename into place, so an interrupted run never
# leaves a truncated file at the final path that could be mistaken for
# finished output. Compression is explicit because the '.tmp' suffix hides the
# '.xz' extension pandas would otherwise infer it from.
out_path = f'{args.base_path}/secrets/{args.dataset}_{begin_id}_{end_id}.csv.xz'
tmp_path = out_path + '.tmp'
df.to_csv(tmp_path, index=False, compression='xz')
os.replace(tmp_path, out_path)

print(f'DONE: {args}')
