import logging
import json
import random
from pathlib import Path
from src.utils import clean

# Hard-coded dataset locations: raw inputs live under DATA_DIR and the
# generated splits are written under PROCESSED_DIR.
DATA_DIR = '/data/MLP_plus/data/'
PROCESSED_DIR = '/data/MLP_plus/processed/'

# Seed the global RNG once so the shuffles below are reproducible run-to-run.
random.seed(42)
logging.basicConfig(level=logging.INFO)

def process_counterfact(*, data_dir=None, processed_dir=None, n_edits=6000,
                        sizes=(10, 100, 1000, 2000, 3000, 6000)):
    """Build shuffled train/val/test subsets of CounterFact and write them as JSON.

    Reads ``<data_dir>/counterfact/counterfact-train.json`` and shuffles it with
    the module-level ``random`` state. It then keeps the first ``3 * n_edits``
    records whose ``target_new`` is truthy, rewriting each kept record's
    ``rephrase_prompt`` with ``clean(rephrase_prompt, subject)``. The kept
    records are cut into three equal, disjoint parts. For every ``size`` in
    ``sizes``, the first ``size`` records of each part are written to
    ``<processed_dir>/cf/{train,val,test}_{size}.json``.

    Args:
        data_dir: Root directory of the raw data; defaults to ``DATA_DIR``.
        processed_dir: Root directory for the output; defaults to ``PROCESSED_DIR``.
        n_edits: Target number of records per split.
        sizes: Prefix sizes to write for every split.

    Raises:
        ValueError: If a split holds fewer records than the largest requested
            size. This is checked before any file is written.
    """
    data_root = Path(DATA_DIR if data_dir is None else data_dir)
    out_dir = Path(PROCESSED_DIR if processed_dir is None else processed_dir) / 'cf'
    num_parts = 3  # train / val / test

    source = data_root / 'counterfact' / 'counterfact-train.json'
    data = json.loads(source.read_text(encoding='utf-8'))
    # Shuffle with the shared, seeded global RNG so the splits are reproducible.
    random.shuffle(data)

    # Collect num_parts * n_edits records that have a non-empty edit target.
    wanted = num_parts * n_edits
    sub_data = []
    for record in data:
        if len(sub_data) >= wanted:
            break
        if not record['target_new']:
            continue
        # Clean only the records that are actually kept.
        record['rephrase_prompt'] = clean(record['rephrase_prompt'], record['subject'])
        sub_data.append(record)

    split_size = len(sub_data) // num_parts
    train, val, test = [sub_data[i * split_size:(i + 1) * split_size] for i in range(num_parts)]

    # Validate up front: `assert` is stripped under `python -O`, and the old
    # per-size asserts fired only after the smaller files had been written.
    largest = max(sizes, default=0)
    if split_size < largest:
        raise ValueError(
            f'counterfact: only {split_size} records per split, need {largest}'
        )

    out_dir.mkdir(parents=True, exist_ok=True)
    for size in sizes:
        for name, split in (('train', train), ('val', val), ('test', test)):
            (out_dir / f'{name}_{size}.json').write_text(
                json.dumps(split[:size], indent=2), encoding='utf-8'
            )
    logging.info('Wrote CounterFact splits of sizes %s to %s', list(sizes), out_dir)


def process_zsre(*, data_dir=None, processed_dir=None, n_edits=6000,
                 sizes=(10, 100, 1000, 2000, 3000, 6000)):
    """Build shuffled train/val/test subsets of zsRE and write them as JSON.

    Reads ``<data_dir>/zsre/zsre_mend_train.json`` and shuffles it with the
    module-level ``random`` state. It then keeps the first ``3 * n_edits``
    records whose ``alt`` is truthy, rewriting each kept record's ``loc`` as
    ``loc[13:] + '?'``. The kept records are cut into three equal, disjoint
    parts. For every ``size`` in ``sizes``, the first ``size`` records of each
    part are written to ``<processed_dir>/qa/{train,val,test}_{size}.json``.

    Args:
        data_dir: Root directory of the raw data; defaults to ``DATA_DIR``.
        processed_dir: Root directory for the output; defaults to ``PROCESSED_DIR``.
        n_edits: Target number of records per split.
        sizes: Prefix sizes to write for every split.

    Raises:
        ValueError: If a split holds fewer records than the largest requested
            size. This is checked before any file is written.
    """
    data_root = Path(DATA_DIR if data_dir is None else data_dir)
    out_dir = Path(PROCESSED_DIR if processed_dir is None else processed_dir) / 'qa'
    num_parts = 3  # train / val / test

    source = data_root / 'zsre' / 'zsre_mend_train.json'
    data = json.loads(source.read_text(encoding='utf-8'))
    # Shuffle with the shared, seeded global RNG so the splits are reproducible.
    random.shuffle(data)

    # Collect num_parts * n_edits records that have a non-empty alternative answer.
    wanted = num_parts * n_edits
    sub_data = []
    for record in data:
        if len(sub_data) >= wanted:
            break
        if not record['alt']:
            continue
        # NOTE(review): drops a fixed 13-char prefix, presumably "nq question: "
        # (exactly 13 chars), and appends '?'. Confirm against the raw data.
        record['loc'] = record['loc'][13:] + '?'
        sub_data.append(record)

    split_size = len(sub_data) // num_parts
    train, val, test = [sub_data[i * split_size:(i + 1) * split_size] for i in range(num_parts)]

    # Validate up front: `assert` is stripped under `python -O`, and the old
    # per-size asserts fired only after the smaller files had been written.
    largest = max(sizes, default=0)
    if split_size < largest:
        raise ValueError(
            f'zsre: only {split_size} records per split, need {largest}'
        )

    out_dir.mkdir(parents=True, exist_ok=True)
    for size in sizes:
        for name, split in (('train', train), ('val', val), ('test', test)):
            (out_dir / f'{name}_{size}.json').write_text(
                json.dumps(split[:size], indent=2), encoding='utf-8'
            )
    logging.info('Wrote zsRE splits of sizes %s to %s', list(sizes), out_dir)


if __name__ == '__main__':
    # Order matters: both steps shuffle with the shared, seeded global RNG,
    # so swapping them would change which records land in each split.
    for step in (process_zsre, process_counterfact):
        step()
