#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Process the Social Bias Frames dataset for classification tasks.
The dataset contains annotations of social bias in text across various categories.
We create multiple binary classification tasks based on the aggregated annotations.
"""

import sys
import os
import logging
import json
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, List, Any, Tuple

# Add the project root to the path so we can import our modules
sys.path.append(str(Path(__file__).parent.parent))
from src.data_utils import save_unified_format

# Configure logging: records at INFO level and above are sent to a
# StreamHandler, formatted as "timestamp - logger name - level - message".
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Named logger shared by the functions in this script.
logger = logging.getLogger("SBIC-Processor")

# Directories
# ROOT_DIR is two levels above this file; raw inputs are read from
# data/raw/sbic and processed outputs are written under data/processed/sbic.
ROOT_DIR = Path(__file__).parent.parent
RAW_DATA_DIR = ROOT_DIR.joinpath("data", "raw", "sbic")
PROCESSED_DATA_DIR = ROOT_DIR.joinpath("data", "processed", "sbic")

# Classification tasks and their descriptions.
# Keys are the CSV column names; each entry carries a human-readable
# description and a label map from the stringified class id to its name.
# Every task gets its own label_map dict (nothing is shared between tasks).
CLASSIFICATION_TASKS = {
    column: {
        "description": description,
        "label_map": {"0": label_for_zero, "1": label_for_one},
    }
    for column, description, label_for_zero, label_for_one in (
        ("whoTarget", "Group vs. Individual Target", "Individual", "Group"),
        ("intentYN", "Intent to Offend", "No", "Yes"),
        ("sexYN", "Sexual Content", "No", "Yes"),
        ("offensiveYN", "Offensive Content", "No", "Yes"),
        # NOTE(review): unlike the other Yes/No tasks, 0 maps to "Yes" here.
        # Presumably this mirrors how the dataset encodes the column; confirm
        # against the dataset documentation.
        ("hasBiasedImplication", "Has Biased Implication", "Yes", "No"),
    )
}

# Data files
# Aggregated CSVs for the train / dev / test splits, all located in
# RAW_DATA_DIR.
TRAIN_FILE = RAW_DATA_DIR.joinpath("SBIC.v2.agg.trn.csv")
DEV_FILE = RAW_DATA_DIR.joinpath("SBIC.v2.agg.dev.csv")
TEST_FILE = RAW_DATA_DIR.joinpath("SBIC.v2.agg.tst.csv")

def read_aggregate_file(file_path: Path) -> pd.DataFrame:
    """
    Load one aggregated CSV file into a DataFrame.

    Any exception raised while reading (or while logging the row count) is
    logged and swallowed; the caller then receives an empty DataFrame.

    Args:
        file_path: Path to the CSV file

    Returns:
        DataFrame with the contents of the CSV file, or an empty DataFrame
        if an error occurred
    """
    logger.info(f"Reading {file_path}")

    try:
        frame = pd.read_csv(file_path)
        row_count = len(frame)
        # Kept inside the try block: a failure here also yields an empty frame.
        logger.info(f"Read {row_count} rows from {file_path.name}")
    except Exception as exc:
        logger.error(f"Error reading {file_path}: {exc}")
        frame = pd.DataFrame()

    return frame

def convert_to_binary_labels(df: pd.DataFrame, threshold: float = 0.5) -> pd.DataFrame:
    """
    Convert aggregated float scores to binary labels.

    Each of the columns ``whoTarget``, ``intentYN``, ``sexYN`` and
    ``offensiveYN`` that is present in ``df`` is binarized: a score greater
    than or equal to ``threshold`` becomes 1, a lower score becomes 0.

    Rows are never dropped here. A missing score stays missing in its own
    column, so a row without a value for one task still contributes its
    labels to the other tasks. Callers that need complete labels for a given
    task should filter on that task's column only, e.g.
    ``dropna(subset=[task])``.

    Args:
        df: DataFrame with float score columns; it is not modified
        threshold: Scores >= threshold map to 1, scores below it map to 0

    Returns:
        A copy of ``df`` with the same rows. A converted column without
        missing values holds integers; a converted column with missing values
        has object dtype holding Python ints (0/1) and NaN for the gaps.
    """
    result = df.copy()

    for task in ("whoTarget", "intentYN", "sexYN", "offensiveYN"):
        if task not in result.columns:
            continue

        scores = result[task]
        observed = scores.notna().to_numpy()

        if observed.all():
            # No gaps in this column: a plain integer column is sufficient.
            result[task] = (scores >= threshold).astype(int)
        else:
            # Object dtype lets Python ints sit next to NaN, so the labels
            # are still ints (not 1.0 / 0.0) once the missing rows of this
            # task are filtered out later.
            labels = np.full(len(scores), np.nan, dtype=object)
            labels[observed] = (scores[observed] >= threshold).astype(int).to_numpy()
            result[task] = labels

        logger.info(f"Converted {task} to binary labels with threshold {threshold}")

    # hasBiasedImplication is already binary
    return result

def process_task(df: pd.DataFrame, task: str) -> Dict[str, Any]:
    """
    Build the texts/labels payload for a single classification task.

    Rows whose value in the ``task`` column is missing are skipped; the
    input DataFrame itself is left untouched.

    Args:
        df: DataFrame with a "post" column and a column named after the task
        task: Task name (a key of CLASSIFICATION_TASKS)

    Returns:
        Dictionary with "texts", "labels" and the task's "label_map"
    """
    # Only rows that actually carry a label for this task are usable
    usable_rows = df.dropna(subset=[task]).copy()

    payload = {
        "texts": usable_rows["post"].tolist(),
        "labels": usable_rows[task].tolist(),
        "label_map": CLASSIFICATION_TASKS[task]["label_map"],
    }

    sample_count = len(payload["texts"])
    logger.info(f"Processed {sample_count} samples for task {task}")
    return payload

def save_processed_data(processed_data: Dict[str, Any], task: str, split: str) -> bool:
    """
    Write one task/split payload to ``PROCESSED_DATA_DIR/<task>/<split>.json``.

    The "dev" split is written under the name "test"; every other split name
    is used as given. Nothing is written when the payload is empty or has no
    texts.

    Args:
        processed_data: Dictionary with processed data
        task: Task name
        split: Data split (train, dev, test)

    Returns:
        True if the file was written, False if there was nothing to save or
        writing failed
    """
    has_samples = bool(processed_data) and bool(processed_data.get("texts"))
    if not has_samples:
        logger.warning(f"No data to save for {task} - {split}")
        return False

    # Make sure the per-task output directory is there
    task_dir = PROCESSED_DATA_DIR / task
    task_dir.mkdir(parents=True, exist_ok=True)

    # Map dev to test for consistency with other datasets
    if split == "dev":
        output_split = "test"
    else:
        output_split = split

    output_file = task_dir / f"{output_split}.json"

    try:
        with open(output_file, 'w', encoding='utf-8') as handle:
            json.dump(processed_data, handle, ensure_ascii=False, indent=2)

        sample_count = len(processed_data['texts'])
        logger.info(f"Saved {sample_count} samples to {output_file}")
    except Exception as exc:
        logger.error(f"Error saving to {output_file}: {exc}")
        return False

    return True

def process_all_tasks() -> bool:
    """
    Run the full pipeline for every task in CLASSIFICATION_TASKS.

    Reads the train/dev/test CSVs, binarizes their score columns, then for
    each task builds the payload of every split and saves it. A task counts
    as successful only when all three of its saves succeed.

    Returns:
        True if all tasks were processed successfully, False otherwise
        (including when any input file could not be read or was empty)
    """
    # Make sure the output root is there
    PROCESSED_DATA_DIR.mkdir(parents=True, exist_ok=True)

    # Load every split up front, in train / dev / test order
    split_sources = (("train", TRAIN_FILE), ("dev", DEV_FILE), ("test", TEST_FILE))
    frames = {split: read_aggregate_file(path) for split, path in split_sources}

    if any(frame.empty for frame in frames.values()):
        logger.error("Failed to read data files")
        return False

    # Binarize the score columns of each split
    frames = {split: convert_to_binary_labels(frame) for split, frame in frames.items()}

    completed = 0
    expected = len(CLASSIFICATION_TASKS)

    for task, task_info in CLASSIFICATION_TASKS.items():
        logger.info(f"Processing task: {task} - {task_info['description']}")

        # Build all payloads first, then attempt every save (no short-circuit)
        payloads = {split: process_task(frame, task) for split, frame in frames.items()}

        # NOTE(review): save_processed_data writes the "dev" split as
        # test.json, and the "test" split saved right after it targets the
        # same file, so the dev output appears to be overwritten - confirm
        # this is intended.
        outcomes = [
            save_processed_data(payload, task, split)
            for split, payload in payloads.items()
        ]

        if all(outcomes):
            completed += 1

    logger.info(f"Successfully processed {completed}/{expected} tasks")

    everything_succeeded = completed == expected
    if not everything_succeeded:
        logger.warning("Some tasks failed to process")

    return everything_succeeded

def create_unified_metadata() -> bool:
    """
    Assemble and save the unified metadata for the Social Bias Frames dataset.

    For each task in CLASSIFICATION_TASKS the saved train.json and test.json
    are loaded to count their samples. Tasks whose files are missing, or
    whose files cannot be read, are left out of the metadata.

    Returns:
        False if no task could be described; otherwise the result of
        save_unified_format for the assembled metadata
    """
    tasks = {}

    for task, task_info in CLASSIFICATION_TASKS.items():
        task_dir = PROCESSED_DATA_DIR / task
        train_file = task_dir / "train.json"
        test_file = task_dir / "test.json"

        # Both split files must be present for the task to be listed
        if not (train_file.exists() and test_file.exists()):
            logger.warning(f"Required files not found for {task}")
            continue

        try:
            # Load both splits; only their sample counts are used below
            loaded = []
            for split_file in (train_file, test_file):
                with open(split_file, 'r', encoding='utf-8') as handle:
                    loaded.append(json.load(handle))
            train_data, test_data = loaded

            # Paths are recorded relative to the project root
            train_file_rel = train_file.relative_to(ROOT_DIR)
            test_file_rel = test_file.relative_to(ROOT_DIR)
            task_dir_rel = task_dir.relative_to(ROOT_DIR)

            task_entry = {
                "name": task,
                "description": task_info["description"],
                "task_type": "classification",
                "label_count": len(task_info["label_map"]),
                "train_samples": len(train_data["texts"]),
                "test_samples": len(test_data["texts"]),
                "task_dir": str(task_dir_rel),
                "train_file": str(train_file_rel),
                "test_file": str(test_file_rel),
            }
            tasks[task] = task_entry
        except Exception as exc:
            # A failure only skips this task; the remaining tasks continue
            logger.error(f"Error reading files for {task}: {exc}")

    if not tasks:
        logger.error("No tasks could be created for unified metadata")
        return False

    dataset_metadata = {
        "dataset_name": "Social Bias Frames",
        "dataset_description": "A dataset for detecting various types of bias in text, including offensiveness, intent to offend, and target type.",
        "tasks": tasks,
    }

    return save_unified_format("sbic", dataset_metadata)

def main():
    """
    Process the Social Bias Frames dataset and create unified metadata.

    The two stages run in order and the first failing stage stops the run.

    Returns:
        0 if both stages succeeded, 1 otherwise
    """
    logger.info("Processing Social Bias Frames dataset for multiple classification tasks...")

    # Each stage is paired with the message to log when it reports failure
    stages = (
        (process_all_tasks, "Failed to process all Social Bias Frames tasks"),
        (create_unified_metadata, "Failed to create unified metadata for Social Bias Frames"),
    )

    for run_stage, failure_message in stages:
        if not run_stage():
            logger.error(failure_message)
            return 1

    logger.info("Social Bias Frames dataset processing completed successfully!")
    return 0

# When run as a script, execute the pipeline and use main()'s return value
# as the process exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)