import os
import gzip
import json

class MRQA:
    """
    Merge per-dataset MRQA ``.jsonl.gz`` files into flat lists of QA examples.

    The processor looks for ``<base_dir>/train`` and ``<base_dir>/dev`` and, in
    each, for one ``<DatasetName>.jsonl.gz`` file per entry of
    ``dataset_to_id``. Every question found is emitted as a dict with the keys
    'context', 'question', 'answers' and 'subtask'.
    """

    def __init__(self, base_dir):
        """
        Parameters:
        - base_dir (str): Directory containing the "train" and "dev" subfolders.
        """
        self.base_dir = base_dir
        self.train_dir = os.path.join(base_dir, "train")
        self.dev_dir = os.path.join(base_dir, "dev")
        # Mapping from dataset name to subtask ID (0 to 5)
        self.dataset_to_id = {
            "SQuAD": 0,
            "NewsQA": 1,
            "TriviaQA": 2,
            "SearchQA": 3,
            "HotpotQA": 4,
            "NaturalQuestions": 5
        }

    def process_file(self, file_path, subtask_id):
        """
        Process a single MRQA .jsonl.gz file and return a list of examples.
        Each example is a dict with keys: 'context', 'question', 'answers', 'subtask'.

        The first line is treated as a metadata header and skipped unless it
        is itself a data record (a JSON object with a "qas" key), so a file
        written without a header does not lose its first record.

        Parameters:
        - file_path (str): Path to the .jsonl.gz file.
        - subtask_id (int): The assigned subtask ID for examples from this file.

        Returns:
        - examples (list): List of processed examples.

        Raises:
        - json.JSONDecodeError: If any line after the first is not valid JSON.
        """
        examples = []
        with gzip.open(file_path, 'rt', encoding='utf-8') as f:
            for line_number, line in enumerate(f):
                line = line.strip()
                if not line:
                    continue
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    if line_number == 0:
                        # An unparseable first line is still just a header.
                        continue
                    raise
                is_record = isinstance(entry, dict) and "qas" in entry
                if line_number == 0 and not is_record:
                    # Header/metadata line: carries no questions.
                    continue
                context = entry.get("context", "")
                # One output example per question; the context is shared.
                examples.extend(
                    {
                        "context": context,
                        "question": qa.get("question", ""),
                        "answers": qa.get("answers", []),
                        "subtask": subtask_id
                    }
                    for qa in entry.get("qas", [])
                )
        return examples

    def process_dataset(self, input_dir):
        """
        Process all dataset files in the given directory (either train or dev).

        Files are visited in the order of ``dataset_to_id``; a missing file
        only produces a warning on stdout and is otherwise ignored.

        Parameters:
        - input_dir (str): Directory containing the .jsonl.gz files.

        Returns:
        - combined_examples (list): Aggregated list of examples from all files.
        """
        combined_examples = []
        for dataset_name, subtask_id in self.dataset_to_id.items():
            file_path = os.path.join(input_dir, f"{dataset_name}.jsonl.gz")
            if os.path.exists(file_path):
                print(f"Processing file for {dataset_name} from {file_path} ...")
                examples = self.process_file(file_path, subtask_id)
                combined_examples.extend(examples)
            else:
                print(f"Warning: File for {dataset_name} not found at {file_path}")
        return combined_examples

    def save_json(self, data, file_path):
        """
        Save data as JSON to the specified file with indentation.

        Missing parent directories are created. A bare filename (no directory
        component) is written to the current working directory.

        Parameters:
        - data: Data to be saved (typically a list of examples).
        - file_path (str): Destination file path.
        """
        parent_dir = os.path.dirname(file_path)
        # dirname() is "" for a bare filename, and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        print(f"Saved {len(data)} examples to {file_path}")

    def process_all(self, output_train_file, output_test_file):
        """
        Process both training and dev (test) datasets and save them to disk.

        Parameters:
        - output_train_file (str): File path to save the combined training set.
        - output_test_file (str): File path to save the combined test (dev) set.

        Returns:
        - (combined_train, combined_test): Tuple containing lists of processed examples.
        """
        print("Processing training data...")
        combined_train = self.process_dataset(self.train_dir)

        print("Processing dev (test) data...")
        combined_test = self.process_dataset(self.dev_dir)

        self.save_json(combined_train, output_train_file)
        self.save_json(combined_test, output_test_file)
        return combined_train, combined_test

if __name__ == '__main__':
    # Root folder expected to hold the "train" and "dev" MRQA subdirectories
    # (adjust to your local layout).
    BASE_MRQA_DIR = "./"
    # Destinations for the merged train and test JSON files, placed under the
    # base directory.
    OUTPUT_TRAIN_FILE, OUTPUT_TEST_FILE = (
        os.path.join(BASE_MRQA_DIR, file_name)
        for file_name in ("train.json", "test.json")
    )

    # Build the processor, then convert and save both splits in one call.
    mrqa_processor = MRQA(BASE_MRQA_DIR)
    train_data, test_data = mrqa_processor.process_all(
        output_train_file=OUTPUT_TRAIN_FILE,
        output_test_file=OUTPUT_TEST_FILE,
    )