from random import choices
from typing import Dict, Any, Callable, List
import json
import os
from collections import defaultdict

def read_jsonl(fp):
    """Load a JSONL file into a list of parsed JSON values.

    Each line of ``fp`` (opened as UTF-8 text) is stripped of surrounding
    whitespace; lines that are empty after stripping are skipped, and every
    remaining line is decoded with ``json.loads``. Results keep file order.
    """
    with open(fp, "r", encoding="utf-8") as handle:
        # Strip lazily so blank / whitespace-only lines can be filtered out
        # before they reach the JSON decoder.
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]

def write_jsonl(data, fp):
    """Write every element of ``data`` to ``fp`` as one JSON document per line.

    The file is opened in write mode with UTF-8 encoding, so existing content
    is replaced. Non-ASCII characters are emitted as-is rather than escaped
    (``ensure_ascii=False``), and each record is terminated by a newline.
    """
    with open(fp, "w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in data
        )

def infinitebench_longbook_choice_eng_i2q(item: Dict[str, Any]) -> str:
    """Render a multiple-choice prompt from ``item``.

    The stripped ``item['input']`` comes first, followed by one line per entry
    of ``item['options']`` labelled ``A.``, ``B.``, ... in order, and finally
    a fixed instruction line asking for the best answer.
    """
    prompt_lines = [item['input'].strip()]
    prompt_lines.extend(
        f"{chr(65 + position)}. {option}"  # 65 == ord('A')
        for position, option in enumerate(item['options'])
    )
    prompt_lines.append("Select the best answer from the options above.")
    return "\n".join(prompt_lines)


def infinitebench_longbook_choice_eng_i2c(item: Dict[str, Any]) -> str:
    """Return ``item['context']`` with leading/trailing whitespace removed."""
    context_text = item['context']
    return context_text.strip()


def infinitebench_longbook_choice_eng_i2a(item: Dict[str, Any]) -> str:
    """Return the option letter (``A``, ``B``, ...) of the gold answer.

    The first entry of ``item['answer']`` is looked up in ``item['options']``
    and its position is converted to an uppercase letter. ``list.index``
    raises ``ValueError`` when that entry is not among the options.
    """
    gold = item['answer'][0]
    position = item['options'].index(gold)
    return chr(ord("A") + position)

def ruler_niah_i2q(item: Dict[str, Any]) -> str:
    """Return the last newline-separated line of ``item['input']``.

    When the input contains no newline, the whole string is returned.
    """
    _, _, last_line = item['input'].rpartition('\n')
    return last_line

def ruler_niah_i2c(item: Dict[str, Any]) -> str:
    """Return ``item['input']`` without its first and last lines.

    The input is split on ``'\\n'``; everything between the first and the
    last piece is re-joined with newlines. Inputs with fewer than three
    lines therefore yield an empty string.
    """
    pieces = item['input'].split('\n')
    interior = pieces[1:-1]
    return '\n'.join(interior)

def ruler_niah_i2a(item: Dict[str, Any]) -> List[str]:
    """Return the value stored under ``item['outputs']`` unchanged."""
    expected_outputs = item['outputs']
    return expected_outputs

def ruler_niah_i2meta(item: Dict[str, Any]) -> Dict[str, Any]:
    """Copy the metadata fields of ``item`` into a new dictionary.

    The result holds exactly the keys ``index``, ``task_name``, ``length``
    and ``token_position_answer``; a missing key raises ``KeyError``.
    """
    meta_keys = ("index", "task_name", "length", "token_position_answer")
    return {key: item[key] for key in meta_keys}

# postprocessing: split output by task name
def ruler_niah_postprocess(output_fp):
    """Split a combined JSONL output file into one JSONL file per task name.

    Every non-blank line of ``output_fp`` is parsed as JSON and bucketed by
    its ``task_name`` value (``None`` when the key is absent). Each bucket is
    then written to ``<task_name>.jsonl`` in the directory of ``output_fp``,
    and a one-line summary is printed for every file written.
    """
    # Bucket records by task name, preserving first-seen task order.
    buckets = {}
    for record in read_jsonl(output_fp):
        buckets.setdefault(record.get('task_name'), []).append(record)

    # Per-task files live next to the combined output file.
    target_dir = os.path.dirname(output_fp)
    for name, records in buckets.items():
        destination = os.path.join(target_dir, f"{name}.jsonl")
        write_jsonl(records, destination)
        print(f'Split {len(records)} items to {destination}.')