#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Utility functions for data processing in the Implicit Embeddings Benchmark.
"""

import os
import json
import logging
import jsonlines
from pathlib import Path
from typing import Dict, List, Any, Union, Optional

# Configure logging
# NOTE: logging.basicConfig() is executed at import time and configures the
# process-wide root logger (INFO level, one StreamHandler). It has no effect
# if the root logger already has handlers when this module is imported.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
# Module-wide logger; records are emitted under the fixed name "DataUtils".
logger = logging.getLogger("DataUtils")

# Project paths
# ROOT_DIR is the parent of the directory that contains this file; the data
# directories below are all derived from it.
# NOTE(review): presumably this file lives one level below the project root
# (e.g. in a package/scripts folder) — confirm if the file is ever moved.
ROOT_DIR = Path(__file__).parent.parent
RAW_DATA_DIR = ROOT_DIR / "data" / "raw"              # <ROOT_DIR>/data/raw
PROCESSED_DATA_DIR = ROOT_DIR / "data" / "processed"  # <ROOT_DIR>/data/processed
UNIFIED_DATA_DIR = ROOT_DIR / "data" / "unified"      # <ROOT_DIR>/data/unified


def read_jsonl(file_path: Union[str, Path]) -> List[Dict[str, Any]]:
    """
    Load every record of a JSONL file into memory.

    Any failure (missing file, malformed line, ...) is logged through the
    module logger and reported to the caller as an empty list; records read
    before the failure are discarded.

    Args:
        file_path: Location of the JSONL file to read

    Returns:
        The parsed objects in file order, or an empty list on any error
    """
    try:
        with jsonlines.open(file_path) as reader:
            # Drain the reader in one go while the file is still open.
            records = list(reader)
    except Exception as e:
        logger.error(f"Error reading {file_path}: {e}")
        return []
    return records


def write_jsonl(data: List[Dict[str, Any]], file_path: Union[str, Path]) -> bool:
    """
    Serialize a sequence of dictionaries to a JSONL file, one object per line.

    The target file is opened in write mode, so existing content is replaced.
    Errors are logged through the module logger rather than raised.

    Args:
        data: Dictionaries to write, in output order
        file_path: Destination path of the JSONL file

    Returns:
        True when every item was written, False if any error occurred
    """
    succeeded = False
    try:
        with jsonlines.open(file_path, mode='w') as writer:
            # write_all() writes each item of the iterable in turn.
            writer.write_all(data)
        succeeded = True
    except Exception as e:
        logger.error(f"Error writing to {file_path}: {e}")
    return succeeded


def save_unified_format(dataset_name: str, tasks: Dict[str, Dict[str, Any]]) -> bool:
    """
    Save dataset tasks in the unified format for the benchmark.

    The tasks are written as indented, UTF-8 encoded JSON (non-ASCII
    characters are kept as-is) to ``UNIFIED_DATA_DIR/<dataset_name>_tasks.json``.
    The output directory is created if it does not exist.

    The JSON document is fully serialized in memory *before* the output file
    is opened, so a serialization failure (e.g. a value that is not JSON
    serializable) neither truncates an existing file nor leaves a partially
    written one behind.

    Args:
        dataset_name: Name of the dataset; used as the output file name prefix
        tasks: Dictionary mapping task names to task metadata and samples

    Returns:
        True if successful, False otherwise (the error is logged, not raised)
    """
    try:
        output_path = UNIFIED_DATA_DIR / f"{dataset_name}_tasks.json"
        UNIFIED_DATA_DIR.mkdir(parents=True, exist_ok=True)

        # Serialize first: open(..., 'w') truncates the target immediately,
        # so encoding must succeed before the file is touched.
        payload = json.dumps(tasks, ensure_ascii=False, indent=2)

        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(payload)

        logger.info(f"Saved unified format for {dataset_name} to {output_path}")
        return True
    except Exception as e:
        logger.error(f"Error saving unified format for {dataset_name}: {e}")
        return False


def get_dataset_stats(tasks: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
    """
    Summarize each task of a dataset by sample count and average text length.

    For every task, the samples are read from its ``'samples'`` entry (treated
    as empty when absent) and the length of each sample's ``'text'`` entry
    (treated as an empty string when absent) is averaged with integer
    (floor) division. Tasks without samples report an average of 0.

    Args:
        tasks: Dictionary mapping task names to task metadata and samples

    Returns:
        Dictionary mapping each task name to ``'num_samples'`` and
        ``'avg_text_length'``
    """
    summary = {}
    for name, payload in tasks.items():
        entries = payload.get('samples', [])
        count = len(entries)

        total_chars = 0
        for entry in entries:
            total_chars += len(entry.get('text', ''))

        # Divide by 1 for an empty task to avoid a ZeroDivisionError.
        divisor = count if count else 1
        summary[name] = {
            'num_samples': count,
            'avg_text_length': total_chars // divisor,
        }
    return summary