from pathlib import Path
from src.utils import clean
import json
import logging

logger = logging.getLogger(__name__)

# Accepted values for the corresponding ``cfg`` attributes.
_DATASET_NAMES = ('qa', 'cf')
_SPLITS = ('train', 'val', 'test')
_EDIT_COUNTS = (10, 100, 1000, 2000, 3000, 6000)

# Record keys per dataset, ordered as
# (edit prompt, new target, rephrased prompt, locality prompt, locality answer).
_RECORD_KEYS = {
    # ZsRE dataset
    'qa': ('src', 'alt', 'rephrase', 'loc', 'loc_ans'),
    # Counterfact dataset
    'cf': ('prompt', 'target_new', 'rephrase_prompt',
           'locality_prompt', 'locality_ground_truth'),
}


def load_dataset(cfg, return_subject=False):
    """Load one split of an editing dataset from its processed JSON file.

    The file read is ``processed/<cfg.name>/<cfg.split>_<cfg.n_edits>.json``,
    resolved relative to the current working directory. It must contain a
    JSON list of records; each record must provide the dataset-specific keys
    listed in ``_RECORD_KEYS`` plus ``'subject'``.

    Args:
        cfg: Object exposing ``name`` (``'qa'`` or ``'cf'``), ``split``
            (``'train'``, ``'val'`` or ``'test'``) and ``n_edits`` (one of
            10, 100, 1000, 2000, 3000, 6000).
        return_subject: When true, the list of subjects is appended to the
            returned tuple.

    Returns:
        ``(edits, rephrases, localities)``, or
        ``(edits, rephrases, localities, subjects)`` when ``return_subject``
        is true. The first three are lists of ``[prompt, ' ' + answer]``
        pairs; ``subjects`` is a list of each record's ``'subject'`` value.

    Raises:
        ValueError: If ``cfg.name``, ``cfg.split`` or ``cfg.n_edits`` is not
            one of the accepted values.
    """
    # Explicit raises rather than ``assert``: assertions are removed when
    # Python runs with -O, which would let bad configs through.
    if cfg.name not in _DATASET_NAMES:
        raise ValueError(
            f'Unknown dataset name {cfg.name!r}; expected one of {_DATASET_NAMES}')
    if cfg.split not in _SPLITS:
        raise ValueError(
            f'Unknown split {cfg.split!r}; expected one of {_SPLITS}')
    if cfg.n_edits not in _EDIT_COUNTS:
        raise ValueError(
            f'Unsupported n_edits {cfg.n_edits!r}; expected one of {_EDIT_COUNTS}')

    path = Path('processed') / cfg.name / f'{cfg.split}_{cfg.n_edits}.json'
    # Decode as UTF-8 explicitly so the result does not depend on the
    # platform's locale encoding.
    data = json.loads(path.read_text(encoding='utf-8'))

    prompt_key, target_key, rephrase_key, loc_key, loc_ans_key = _RECORD_KEYS[cfg.name]

    edits, rephrases, localities, subjects = [], [], [], []
    for record in data:
        # Answers carry a leading space; the rephrase shares the edit's target.
        target = ' ' + record[target_key]
        edits.append([record[prompt_key], target])
        rephrases.append([record[rephrase_key], target])
        localities.append([record[loc_key], ' ' + record[loc_ans_key]])
        subjects.append(record['subject'])

    logger.info('Dataset `%s` is loaded', cfg.name)
    logger.info('We have %d edits, %d rephrases, %d locs',
                len(edits), len(rephrases), len(localities))

    if return_subject:
        return edits, rephrases, localities, subjects
    return edits, rephrases, localities
