from datasets import Dataset, DatasetDict

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class GovRepcrsDataset(BaseDataset):
    """Dataset loader for the ``crs`` reports stored under ``gov-report``.

    The loader reads the on-disk layout below (relative to ``path``)::

        gov-report/split_ids/crs_{train,valid,test}.ids   one report id per line
        gov-report/crs/<report id>.json                   one report per file
    """

    @staticmethod
    def load(path: str):
        """Build a ``DatasetDict`` with ``train``, ``valid`` and ``test`` splits.

        Each example has two string fields:

        * ``content``: the report title followed by, for every entry of
          ``reports.subsections``, its section title (empty string when the
          title is falsy) and its paragraphs, all joined with newlines.
        * ``summary``: the entries of ``summary`` joined with newlines.

        Args:
            path: Directory that contains the ``gov-report`` folder.

        Returns:
            A ``DatasetDict`` keyed by split name.

        Raises:
            FileNotFoundError: If a split id file or a referenced report
                file does not exist.
            KeyError: If a report JSON lacks one of the fields read here.
        """
        import json
        import os

        root = os.path.join(path, 'gov-report')
        dataset_dict = DatasetDict()
        for split in ('train', 'valid', 'test'):
            split_fp = os.path.join(root, 'split_ids',
                                    'crs_' + split + '.ids')
            records = []
            # Explicit UTF-8 so decoding does not depend on the host locale.
            with open(split_fp, 'r', encoding='utf-8') as f:
                for line in f:
                    report_id = line.strip()
                    if not report_id:
                        # A blank line (e.g. a trailing newline at the end
                        # of the file) would otherwise resolve to '.json'.
                        continue
                    xpath = os.path.join(root, 'crs', report_id + '.json')
                    with open(xpath, 'r', encoding='utf-8') as df:
                        data = json.load(df)
                    sections = [
                        (sub['section_title'] or '') + '\n' +
                        '\n'.join(sub['paragraphs'])
                        for sub in data['reports']['subsections']
                    ]
                    records.append({
                        'content':
                        data['title'] + '\n' + '\n'.join(sections),
                        'summary':
                        '\n'.join(data['summary']),
                    })
            dataset_dict[split] = Dataset.from_list(records)
        return dataset_dict
