#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Process the PUB (Pragmatics Understanding Benchmark) tasks into standardized formats.
Regular classification tasks are split into train/test sets; zero-shot and pair
classification tasks are saved as a single test set. All outputs are stored as JSON.
"""

import sys
import os
import logging
import json
import random
from pathlib import Path
from typing import Dict, List, Any, Tuple
from sklearn.model_selection import train_test_split

# Add the project root to the path so we can import our modules
sys.path.append(str(Path(__file__).parent.parent))
from src.data_utils import read_jsonl, save_unified_format, get_dataset_stats

# Logging: a single stream handler at INFO level, timestamped and tagged with the logger name.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("PUB-Processor")

# Directory layout: raw inputs and processed outputs live under <project root>/data.
ROOT_DIR = Path(__file__).parent.parent
RAW_DATA_DIR = ROOT_DIR.joinpath("data", "raw", "pub")
PROCESSED_DATA_DIR = ROOT_DIR.joinpath("data", "processed", "pub")

# Seed the stdlib RNG for reproducibility (the sklearn train/test split passes its own random_state).
random.seed(42)

# Task groups, by how each task is processed and stored.
# Regular classification tasks: split into train/test sets.
CLASSIFICATION_TASKS = [f"task_{n}" for n in (1, 2, 3, 6, 10, 11, 12, 13)]
# Zero-shot tasks: per-item option lists, saved as a single test set.
ZERO_SHOT_TASKS = [f"task_{n}" for n in (4, 7, 8, 9, 14)]
# Pair classification tasks: [speaker 1, speaker 2] text pairs, saved as a single test set.
PAIR_CLASSIFICATION_TASKS = ["task_5"]

# Every task handled by this script, in the order main() processes them.
TASKS_TO_PROCESS = [*CLASSIFICATION_TASKS, *ZERO_SHOT_TASKS, *PAIR_CLASSIFICATION_TASKS]

# Fraction of samples held out for testing (regular classification tasks only).
TEST_SIZE = 0.2


def process_task_1(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Convert Task 1 records (direct vs. indirect answers) into labelled samples.

    The numeric label is the position of the record's "correct answer" inside
    that record's own "options" list (falling back to
    ["Direct answer", "Indirect answer"] when the key is absent). Records whose
    answer is not one of the options are dropped.

    Args:
        data: Raw records read from task_1.jsonl

    Returns:
        One dict per kept record with keys "text", "label" and "label_text"
    """
    samples = []
    for record in data:
        choices = record.get("options", ["Direct answer", "Indirect answer"])
        answer = record.get("correct answer", "")
        if answer not in choices:
            continue  # no usable label for this record
        samples.append({
            "text": record.get("pretext", ""),
            "label": choices.index(answer),
            "label_text": answer,
        })
    return samples


def process_task_2_3(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Convert Task 2/3 records into binary Yes/No samples.

    Only records whose "correct answer" is exactly "Yes" or "No" are kept.
    "Yes" becomes label 0 and "No" becomes label 1; the record's options are ignored.

    Args:
        data: Raw records read from task_2.jsonl or task_3.jsonl

    Returns:
        One dict per kept record with keys "text", "label" and "label_text"
    """
    # Position in this tuple is the numeric label ("Yes" -> 0, "No" -> 1)
    accepted = ("Yes", "No")
    samples = []
    for record in data:
        answer = record.get("correct answer", "")
        if answer not in accepted:
            continue
        samples.append({
            "text": record.get("pretext", ""),
            "label": accepted.index(answer),
            "label_text": answer,
        })
    return samples


def process_task_6(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Convert Task 6 records (agreement vs. sarcasm) into labelled samples.

    The label is the index of "correct answer" within the record's "options"
    (default ["Agrees", "Sarcastic"]). Records with an answer outside the
    options are skipped.

    Args:
        data: Raw records read from task_6.jsonl

    Returns:
        One dict per kept record with keys "text", "label" and "label_text"
    """
    def _to_sample(record):
        # Returns None when the record has no answer among its options
        choices = record.get("options", ["Agrees", "Sarcastic"])
        answer = record.get("correct answer", "")
        if answer in choices:
            return {
                "text": record.get("pretext", ""),
                "label": choices.index(answer),
                "label_text": answer,
            }
        return None

    candidates = (_to_sample(record) for record in data)
    return [sample for sample in candidates if sample is not None]


def process_task_10_11(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Convert Task 10/11 records (3-way premise/hypothesis truth) into labelled samples.

    Each record's label is the index of its "correct answer" within its own
    "options"; records without "options" use the three canonical hypothesis
    statements below. Records whose answer is not an option are dropped.

    Args:
        data: Raw records read from task_10.jsonl or task_11.jsonl

    Returns:
        One dict per kept record with keys "text", "label" and "label_text"
    """
    default_choices = [
        "Hypothesis is definitely true given premise",
        "Hypothesis might be true given premise",
        "Hypothesis is definitely not true given premise",
    ]

    samples = []
    for record in data:
        choices = record.get("options", default_choices)
        answer = record.get("correct answer", "")
        if answer in choices:
            samples.append(dict(
                text=record.get("pretext", ""),
                label=choices.index(answer),
                label_text=answer,
            ))
    return samples


def process_task_12(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Convert Task 12 records (valid vs. invalid) into labelled samples.

    The label is the index of "correct answer" in the record's "options"
    (default ["Valid", "Invalid"]); records whose answer is not among the
    options are omitted.

    Args:
        data: Raw records read from task_12.jsonl

    Returns:
        One dict per kept record with keys "text", "label" and "label_text"
    """
    # (record, its options, its answer) for every record, computed lazily
    triples = (
        (record, record.get("options", ["Valid", "Invalid"]), record.get("correct answer", ""))
        for record in data
    )
    return [
        {"text": record.get("pretext", ""), "label": choices.index(answer), "label_text": answer}
        for record, choices, answer in triples
        if answer in choices
    ]


def process_task_13(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Convert Task 13 records (yes/no questions) into labelled samples.

    The label is the index of "correct answer" in the record's "options",
    defaulting to the lowercase pair ["yes", "no"]. Matching is exact, so for
    example "Yes" does not match the default options and that record is dropped.

    Args:
        data: Raw records read from task_13.jsonl

    Returns:
        One dict per kept record with keys "text", "label" and "label_text"
    """
    kept = []
    for record in data:
        choices = record.get("options", ["yes", "no"])
        answer = record.get("correct answer", "")
        if answer in choices:
            sample = {"text": record.get("pretext", "")}
            sample["label"] = choices.index(answer)
            sample["label_text"] = answer
            kept.append(sample)
    return kept


def process_zero_shot_task(data: List[Dict[str, Any]], task_name: str) -> Dict[str, Any]:
    """
    Build the zero-shot representation of a task, in which each item carries its own options.

    Items whose "correct answer" is not among their own "options" are dropped.
    The set of unique options is gathered from every record, including dropped ones.

    Args:
        data: List of data items from the task JSONL file
        task_name: Name of the task (e.g., 'task_4'); used only for logging

    Returns:
        Dictionary with parallel lists "texts", "options_per_item", "labels"
        (index of the answer within that item's options) and "labels_text",
        plus "all_unique_options" (sorted list of every option seen)
    """
    # Every option string appearing in any record, sorted for stable output
    unique_options = sorted({choice for record in data for choice in record.get("options", [])})
    logger.info(f"Zero-shot task {task_name} has {len(unique_options)} unique options")

    texts, per_item_options, labels, label_texts = [], [], [], []
    for record in data:
        choices = record.get("options", [])
        answer = record.get("correct answer", "")
        if answer not in choices:
            continue  # no valid answer for this item
        texts.append(record.get("pretext", ""))
        per_item_options.append(choices)
        labels.append(choices.index(answer))
        label_texts.append(answer)

    return {
        "texts": texts,
        "options_per_item": per_item_options,
        "labels": labels,
        "labels_text": label_texts,
        "all_unique_options": unique_options,
    }


def split_and_save_data(task_name: str, processed_data: List[Dict[str, Any]]) -> None:
    """
    Split data into train/test sets and save them as JSON.

    Writes ``train.json`` and ``test.json`` under ``PROCESSED_DATA_DIR / task_name``.
    Each file holds parallel "texts"/"labels" lists plus a "label_map" from the
    string form of each numeric label to its label text. The label map is built
    once from the *whole* dataset, sorted by label, and shared by both files. A
    class whose samples all land in one split is therefore still listed in the
    other file's map.

    Args:
        task_name: Name of the task (e.g., 'task_1')
        processed_data: Items with "text", "label" and "label_text" keys, as
            produced by the regular classification processors (e.g. process_task_1)
    """
    if not processed_data:
        logger.warning(f"No data to save for {task_name}")
        return

    task_dir = PROCESSED_DATA_DIR / task_name
    task_dir.mkdir(parents=True, exist_ok=True)

    # Fixed random_state keeps the split reproducible across runs
    train_data, test_data = train_test_split(
        processed_data, test_size=TEST_SIZE, random_state=42
    )
    logger.info(f"Split {task_name}: {len(train_data)} train, {len(test_data)} test samples")

    # label -> label_text over all items (last occurrence wins on a conflict)
    label_texts = {}
    for item in processed_data:
        label_texts[item["label"]] = item["label_text"]
    label_map = {str(label): label_texts[label] for label in sorted(label_texts)}

    for split_name, split_items in (("train", train_data), ("test", test_data)):
        payload = {
            "texts": [item["text"] for item in split_items],
            "labels": [item["label"] for item in split_items],
            "label_map": label_map,
        }
        with open(task_dir / f"{split_name}.json", 'w', encoding='utf-8') as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)

    logger.info(f"Saved {task_name} data to {task_dir}")


def save_zero_shot_data(task_name: str, zero_shot_data: Dict[str, Any]) -> None:
    """
    Write a zero-shot task's data to ``PROCESSED_DATA_DIR / task_name / test.json``.

    Nothing is written (only a warning is logged) when the data is empty or has
    no "texts". The file is named test.json to match the other task types.

    Args:
        task_name: Name of the task (e.g., 'task_4')
        zero_shot_data: Dictionary produced by process_zero_shot_task
    """
    texts = zero_shot_data.get("texts") if zero_shot_data else None
    if not texts:
        logger.warning(f"No data to save for zero-shot task {task_name}")
        return

    output_dir = PROCESSED_DATA_DIR / task_name
    output_dir.mkdir(parents=True, exist_ok=True)

    output_path = output_dir / "test.json"
    with open(output_path, 'w', encoding='utf-8') as handle:
        json.dump(zero_shot_data, handle, ensure_ascii=False, indent=2)

    logger.info(f"Saved zero-shot task {task_name} with {len(texts)} samples")


def process_classification_task(task_name: str) -> bool:
    """
    Process a regular classification task end to end: read, convert, split, save.

    Args:
        task_name: Name of the task (e.g., 'task_1')

    Returns:
        True if the task was converted and its train/test files were written.
        False when the raw file is missing or empty, the task has no processor,
        or no record yielded a valid label. In the last case nothing is saved,
        so the task must not be counted as a success.
    """
    task_file = RAW_DATA_DIR / f"{task_name}.jsonl"

    if not task_file.exists():
        logger.error(f"Task file not found: {task_file}")
        return False

    data = read_jsonl(task_file)
    if not data:
        logger.error(f"No data found in {task_file}")
        return False

    logger.info(f"Processing classification task {task_name} with {len(data)} samples")

    # Task name -> converter producing [{"text", "label", "label_text"}, ...]
    processors = {
        "task_1": process_task_1,
        "task_2": process_task_2_3,
        "task_3": process_task_2_3,
        "task_6": process_task_6,
        "task_10": process_task_10_11,
        "task_11": process_task_10_11,
        "task_12": process_task_12,
        "task_13": process_task_13,
    }
    processor = processors.get(task_name)
    if processor is None:
        logger.error(f"Unsupported classification task: {task_name}")
        return False

    processed_data = processor(data)
    # split_and_save_data only warns on empty input, so catch that case here
    # rather than reporting success for a task whose files were never written.
    if not processed_data:
        logger.error(f"No valid labelled samples produced for {task_name} from {task_file}")
        return False

    split_and_save_data(task_name, processed_data)

    return True


def process_zero_shot_classification_task(task_name: str) -> bool:
    """
    Process a zero-shot classification task end to end: read, convert, save.

    Args:
        task_name: Name of the task (e.g., 'task_4')

    Returns:
        True if the task was converted and its test.json was written.
        False when the raw file is missing or empty, or when no item had a
        correct answer among its options. In the last case nothing is saved,
        so the task must not be counted as a success.
    """
    task_file = RAW_DATA_DIR / f"{task_name}.jsonl"

    if not task_file.exists():
        logger.error(f"Task file not found: {task_file}")
        return False

    data = read_jsonl(task_file)
    if not data:
        logger.error(f"No data found in {task_file}")
        return False

    logger.info(f"Processing zero-shot task {task_name} with {len(data)} samples")

    zero_shot_data = process_zero_shot_task(data, task_name)

    # save_zero_shot_data only warns on empty output; report it as a failure here
    if not zero_shot_data or not zero_shot_data.get("texts"):
        logger.error(f"No items with a valid correct answer in {task_file}")
        return False

    save_zero_shot_data(task_name, zero_shot_data)

    return True


def _strip_speaker_tag(line: str, tag: str) -> str:
    """
    Remove a leading speaker tag such as "Speaker_1:" from a dialogue line.

    Only a tag at the start of the whitespace-trimmed line is removed, so an
    utterance that happens to mention the tag keeps its wording. The result is
    trimmed; a line without the tag is only trimmed.
    """
    line = line.strip()
    if line.startswith(tag):
        line = line[len(tag):]
    return line.strip()


def process_task_5(data: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Process Task 5: Pair classification for agreement/disagreement.

    Each record's "pretext" is expected to hold the Speaker_1 utterance on its
    first non-blank line and the Speaker_2 utterance on its second. Records are
    handled as follows:
      - A "correct answer" other than "Agrees"/"Disagrees" is skipped silently.
      - A pretext that is not a string, or that has fewer than two non-blank
        lines, is skipped with a warning.

    Args:
        data: List of data items from the task_5.jsonl file

    Returns:
        Dictionary with:
          - "text_pairs": [speaker_1_text, speaker_2_text] lists
          - "labels": parallel list (1 = Agrees, 0 = Disagrees)
          - "label_map": from the string label to its name
    """
    text_pairs = []
    labels = []

    for item in data:
        correct_answer = item.get("correct answer", "")
        if correct_answer not in ["Agrees", "Disagrees"]:
            continue

        pretext = item.get("pretext", "")
        if not isinstance(pretext, str):
            logger.warning(f"Failed to process item: pretext is {type(pretext).__name__}, expected str")
            continue

        # Ignore blank lines so a leading empty line does not shift the speakers
        lines = [line for line in pretext.split("\n") if line.strip()]
        if len(lines) < 2:
            logger.warning(f"Failed to process item: expected 2 speaker lines, found {len(lines)}")
            continue

        text_pairs.append([
            _strip_speaker_tag(lines[0], "Speaker_1:"),
            _strip_speaker_tag(lines[1], "Speaker_2:"),
        ])
        # Map "Agrees" to 1 and "Disagrees" to 0
        labels.append(1 if correct_answer == "Agrees" else 0)

    return {
        "text_pairs": text_pairs,
        "labels": labels,
        "label_map": {"0": "Disagrees", "1": "Agrees"},
    }


def save_pair_classification_data(task_name: str, pair_data: Dict[str, Any]) -> None:
    """
    Write a pair-classification task's data to ``PROCESSED_DATA_DIR / task_name / test.json``.

    When the data is empty or has no "text_pairs", only a warning is logged and
    no file is written. Only a test file is produced; no training split exists
    for this task type.

    Args:
        task_name: Name of the task (e.g., 'task_5')
        pair_data: Dictionary produced by the pair processor (e.g. process_task_5)
    """
    pairs = pair_data.get("text_pairs") if pair_data else None
    if not pairs:
        logger.warning(f"No data to save for pair classification task {task_name}")
        return

    output_dir = PROCESSED_DATA_DIR / task_name
    output_dir.mkdir(parents=True, exist_ok=True)

    output_path = output_dir / "test.json"
    with open(output_path, 'w', encoding='utf-8') as handle:
        json.dump(pair_data, handle, ensure_ascii=False, indent=2)

    logger.info(f"Saved pair classification task {task_name} with {len(pairs)} pairs")


def process_pair_classification_task(task_name: str) -> bool:
    """
    Process a pair classification task end to end: read, convert, save.

    Args:
        task_name: Name of the task (e.g., 'task_5')

    Returns:
        True if the task was converted and its test.json was written.
        False when the raw file is missing or empty, the task has no processor,
        or no record produced a usable text pair. In the last case nothing is
        saved, so the task must not be counted as a success.
    """
    task_file = RAW_DATA_DIR / f"{task_name}.jsonl"

    if not task_file.exists():
        logger.error(f"Task file not found: {task_file}")
        return False

    data = read_jsonl(task_file)
    if not data:
        logger.error(f"No data found in {task_file}")
        return False

    logger.info(f"Processing pair classification task {task_name} with {len(data)} samples")

    if task_name == "task_5":
        pair_data = process_task_5(data)
    else:
        logger.error(f"Unsupported pair classification task: {task_name}")
        return False

    # save_pair_classification_data only warns on empty output; report it as a failure here
    if not pair_data or not pair_data.get("text_pairs"):
        logger.error(f"No usable text pairs produced for {task_name} from {task_file}")
        return False

    save_pair_classification_data(task_name, pair_data)

    return True


def _load_json(path: Path) -> Any:
    """Read and parse a UTF-8 encoded JSON file."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def main():
    """
    Process every PUB task and, if all succeed, save dataset-level metadata.

    Tasks run in three groups: regular classification, zero-shot
    classification, then pair classification. Metadata is then assembled from
    the JSON files each group wrote under PROCESSED_DATA_DIR and passed to
    save_unified_format.

    Returns:
        0 when every task in TASKS_TO_PROCESS reported success (metadata is
        then saved); 1 if any task failed, in which case no metadata is written.
    """
    logger.info("Processing PUB tasks...")

    # (task names, description used in log messages, per-task processing function)
    task_groups = [
        (CLASSIFICATION_TASKS, "regular classification", process_classification_task),
        (ZERO_SHOT_TASKS, "zero-shot classification", process_zero_shot_classification_task),
        (PAIR_CLASSIFICATION_TASKS, "pair classification", process_pair_classification_task),
    ]

    successful_tasks = 0
    for task_names, kind, process_fn in task_groups:
        for task_name in task_names:
            logger.info(f"Processing {kind} task: {task_name}")
            if process_fn(task_name):
                successful_tasks += 1

    logger.info(f"Successfully processed {successful_tasks}/{len(TASKS_TO_PROCESS)} tasks")

    if successful_tasks != len(TASKS_TO_PROCESS):
        logger.warning("Some tasks failed to process.")
        return 1

    # Human-readable description of each task
    task_descriptions = {
        "task_1": "Direct / Indirect Classification",
        "task_2": "Response Classification - Without Implied Meaning",
        "task_3": "Response Classification - With Implied Meaning",
        "task_4": "Implicature Recovery",
        "task_5": "Agreement Detection",
        "task_6": "Understanding Sarcasm",
        "task_7": "Figurative Language Understanding - No hint",
        "task_8": "Figurative Language Understanding - Positive hint",
        "task_9": "Figurative Language Understanding - Contrastive hint",
        "task_10": "Implicature NLI",
        "task_11": "Presupposition NLI",
        "task_12": "Presupposition over QA",
        "task_13": "Deictic QA",
        "task_14": "Reference via Metonymy"
    }

    # Per-task metadata; file paths are stored relative to the project root
    tasks = {}

    # Regular classification tasks: train.json + test.json
    for task_name in CLASSIFICATION_TASKS:
        task_dir = PROCESSED_DATA_DIR / task_name
        train_file = task_dir / "train.json"
        test_file = task_dir / "test.json"
        if not (train_file.exists() and test_file.exists()):
            continue

        train_data = _load_json(train_file)
        test_data = _load_json(test_file)

        # Count label ids from both splits: the train label_map alone misses any
        # class whose samples all landed in the test split.
        label_ids = set(train_data["label_map"]) | set(test_data["label_map"])

        tasks[task_name] = {
            "name": task_name,
            "description": task_descriptions.get(task_name, f"PUB {task_name}"),
            "task_type": "classification",
            "train_samples": len(train_data["texts"]),
            "test_samples": len(test_data["texts"]),
            "label_count": len(label_ids),
            "task_dir": str(task_dir.relative_to(ROOT_DIR)),
            "train_file": str(train_file.relative_to(ROOT_DIR)),
            "test_file": str(test_file.relative_to(ROOT_DIR))
        }

    # Zero-shot classification tasks: test.json only
    for task_name in ZERO_SHOT_TASKS:
        task_dir = PROCESSED_DATA_DIR / task_name
        test_file = task_dir / "test.json"
        if not test_file.exists():
            continue

        test_data = _load_json(test_file)

        tasks[task_name] = {
            "name": task_name,
            "description": task_descriptions.get(task_name, f"PUB {task_name}"),
            "task_type": "zero-shot-classification",
            "test_samples": len(test_data["texts"]),
            "unique_options_count": len(test_data["all_unique_options"]),
            "task_dir": str(task_dir.relative_to(ROOT_DIR)),
            "test_file": str(test_file.relative_to(ROOT_DIR))
        }

    # Pair classification tasks: test.json only
    for task_name in PAIR_CLASSIFICATION_TASKS:
        task_dir = PROCESSED_DATA_DIR / task_name
        test_file = task_dir / "test.json"
        if not test_file.exists():
            continue

        test_data = _load_json(test_file)

        tasks[task_name] = {
            "name": task_name,
            "description": task_descriptions.get(task_name, f"PUB {task_name} - Pair Classification"),
            "task_type": "pair-classification",
            "test_samples": len(test_data["text_pairs"]),
            "task_dir": str(task_dir.relative_to(ROOT_DIR)),
            "test_file": str(test_file.relative_to(ROOT_DIR))
        }

    # Dataset-level metadata
    dataset_metadata = {
        "dataset_name": "PUB",
        "dataset_description": "Pragmatics Understanding Benchmark - A collection of tasks evaluating pragmatic language understanding",
        "tasks": tasks
    }

    save_unified_format("pub", dataset_metadata)

    return 0


if __name__ == "__main__":
    # main() returns the process exit status (0 = success, 1 = some task failed)
    raise SystemExit(main())