#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Process the Implicit Hate Speech dataset for classification tasks.
The dataset contains tweets with annotations for implicit hate speech, categorizing them by type
and providing details about targets and implied messages.
"""

import sys
import os
import logging
import json
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, List, Any, Tuple
from sklearn.model_selection import train_test_split

# Add the project root to the path so we can import our modules
sys.path.append(str(Path(__file__).parent.parent))
from src.data_utils import save_unified_format

# Logging setup: INFO-and-above records, timestamped, emitted through a stream handler
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    format=_LOG_FORMAT,
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)
# Module-wide logger used by the processing functions below
logger = logging.getLogger("Implicit-Hate-Processor")

# Filesystem layout, anchored at the parent of the directory containing this file
ROOT_DIR = Path(__file__).parent.parent
RAW_DATA_DIR = ROOT_DIR.joinpath("data", "raw", "implicit_hate", "implicit-hate-corpus")
PROCESSED_DATA_DIR = ROOT_DIR.joinpath("data", "processed", "implicit_hate")

# Input files - only the per-stage posts files are needed since they contain the text content
STG1_POSTS_FILE = RAW_DATA_DIR.joinpath("implicit_hate_v1_stg1_posts.tsv")
STG2_POSTS_FILE = RAW_DATA_DIR.joinpath("implicit_hate_v1_stg2_posts.tsv")
STG3_POSTS_FILE = RAW_DATA_DIR.joinpath("implicit_hate_v1_stg3_posts.tsv")

# Implicit hate categories, listed in label-index order (index 0 first)
_IMPLICIT_HATE_CATEGORIES = (
    "white_grievance",
    "incitement",
    "inferiority",
    "irony",
    "stereotypical",
    "threatening",
    "other",
)

# Task registry: task name -> human-readable description and {label index (str): label name}
TASKS = {
    "detection": {
        "description": "Implicit Hate Detection (Binary: hate vs. not hate)",
        "label_map": {"0": "not_hate", "1": "hate"},
    },
    "categorization": {
        "description": "Implicit Hate Categorization (7 types of implicit hate)",
        "label_map": {
            str(index): category
            for index, category in enumerate(_IMPLICIT_HATE_CATEGORIES)
        },
    },
    "target": {
        "description": "Target Identification (most common target groups)",
        # Filled in at runtime by process_target_task() from the most common targets
        "label_map": {},
    },
}

# Split / filtering settings
TEST_SIZE = 0.2                # fraction of samples held out for the test split
MIN_TARGET_OCCURRENCES = 50    # a target needs at least this many rows to become a class
RANDOM_SEED = 42               # fixed seed so the train/test split is reproducible

def read_tsv_file(file_path: Path) -> pd.DataFrame:
    """
    Load a tab-separated file into a pandas DataFrame.

    The row count, the column names and the first five rows are logged at INFO
    level. Any exception raised while reading or logging is reported at ERROR
    level and an empty DataFrame is returned instead of propagating.

    Args:
        file_path: Path to the TSV file

    Returns:
        DataFrame with the contents of the file, or an empty DataFrame on failure
    """
    try:
        frame = pd.read_csv(file_path, sep='\t')
        n_rows = len(frame)
        logger.info(f"Read {n_rows} rows from {file_path.name}")
        # Preview the structure so unexpected schemas are visible in the log
        logger.info(f"First 5 rows of {file_path.name}:")
        column_names = frame.columns.tolist()
        logger.info(f"Columns: {column_names}")
        preview = frame.head(5)
        logger.info(preview)
    except Exception as exc:
        logger.error(f"Error reading {file_path}: {exc}")
        return pd.DataFrame()
    return frame

def process_detection_task() -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Process the implicit hate detection task (binary classification).

    Reads the Stage 1 posts file, discards rows whose ``post`` or ``class`` value
    is missing, maps ``class`` to a binary label (0 for 'not_hate', 1 for every
    other class) and performs a stratified train/test split.

    Returns:
        Tuple of dictionaries containing processed train and test data, each with
        "texts", "labels" and "label_map" keys. Two empty dictionaries are
        returned when the input is unusable or the split cannot be made.
    """
    # Read Stage 1 posts data
    posts_df = read_tsv_file(STG1_POSTS_FILE)

    if posts_df.empty:
        logger.error("Stage 1 posts data is missing or empty")
        return {}, {}

    # Both the label column and the text column are required further down
    for column in ('class', 'post'):
        if column not in posts_df.columns:
            logger.error(f"'{column}' column not found in Stage 1 posts file")
            logger.info(f"Available columns: {posts_df.columns.tolist()}")
            return {}, {}

    # Rows without a label must not be counted as hate, and rows without text
    # would be written out as NaN, so both are removed up front.
    labelled_df = posts_df.dropna(subset=['post', 'class']).copy()
    dropped = len(posts_df) - len(labelled_df)
    if dropped:
        logger.warning(f"Dropped {dropped} Stage 1 rows with missing post text or class label")

    if labelled_df.empty:
        logger.error("Stage 1 posts data has no rows with both post text and class label")
        return {}, {}

    # Convert labels to binary: 0 for not_hate, 1 for any type of hate
    labelled_df['binary_label'] = (labelled_df['class'] != 'not_hate').astype(int)

    # Count labels
    label_counts = labelled_df['binary_label'].value_counts()
    logger.info(f"Label distribution for detection task: {label_counts.to_dict()}")

    # Split data into train and test sets
    try:
        train_df, test_df = train_test_split(
            labelled_df,
            test_size=TEST_SIZE,
            stratify=labelled_df['binary_label'],
            random_state=RANDOM_SEED,
        )
    except ValueError as e:
        # scikit-learn raises ValueError when stratification is impossible,
        # e.g. a class with too few rows
        logger.error(f"Could not split detection data: {e}")
        return {}, {}

    logger.info(f"Split data into {len(train_df)} train and {len(test_df)} test samples")

    # Create processed data dictionaries
    label_map = TASKS["detection"]["label_map"]
    train_data = {
        "texts": train_df['post'].tolist(),
        "labels": train_df['binary_label'].tolist(),
        "label_map": label_map,
    }
    test_data = {
        "texts": test_df['post'].tolist(),
        "labels": test_df['binary_label'].tolist(),
        "label_map": label_map,
    }

    return train_data, test_data

def process_categorization_task() -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Process the implicit hate categorization task (multi-class classification).

    Reads the Stage 2 posts file and converts ``implicit_class`` to the integer
    indices defined by ``TASKS["categorization"]["label_map"]``. Rows whose class
    is missing or not part of that label map, and rows without post text, are
    dropped so that every emitted label is an integer present in the label map.
    The remaining rows are split into stratified train/test sets.

    Returns:
        Tuple of dictionaries containing processed train and test data, each with
        "texts", "labels" and "label_map" keys. Two empty dictionaries are
        returned when the input is unusable or the split cannot be made.
    """
    # Read Stage 2 posts data
    posts_df = read_tsv_file(STG2_POSTS_FILE)

    if posts_df.empty:
        logger.error("Stage 2 posts data is missing or empty")
        return {}, {}

    # Both the label column and the text column are required further down
    for column in ('implicit_class', 'post'):
        if column not in posts_df.columns:
            logger.error(f"'{column}' column not found in Stage 2 posts file")
            logger.info(f"Available columns: {posts_df.columns.tolist()}")
            return {}, {}

    # Derive name -> index from the task's label map so the two cannot drift apart
    label_map = TASKS["categorization"]["label_map"]
    class_mapping = {name: int(idx) for idx, name in label_map.items()}

    # Count labels as they appear in the file
    label_counts = posts_df['implicit_class'].value_counts()
    logger.info(f"Label distribution for categorization task: {label_counts.to_dict()}")

    # Series.map yields NaN for values that are missing or absent from the mapping
    numeric_labels = posts_df['implicit_class'].map(class_mapping)
    unknown_mask = numeric_labels.isna()
    if unknown_mask.any():
        unknown_values = posts_df.loc[unknown_mask, 'implicit_class'].dropna().unique().tolist()
        logger.warning(
            f"Dropping {int(unknown_mask.sum())} Stage 2 rows with missing or unknown "
            f"implicit_class values: {unknown_values}"
        )

    missing_post_mask = posts_df['post'].isna()
    if missing_post_mask.any():
        logger.warning(f"Dropping {int(missing_post_mask.sum())} Stage 2 rows with missing post text")

    usable_mask = ~unknown_mask & ~missing_post_mask
    labelled_df = posts_df.loc[usable_mask].copy()
    if labelled_df.empty:
        logger.error("Stage 2 posts data has no rows with both post text and a known class")
        return {}, {}

    # No NaN is left at this point, so the labels can be stored as integers
    labelled_df['numeric_label'] = numeric_labels.loc[usable_mask].astype(int)

    # Split data into train and test sets
    try:
        train_df, test_df = train_test_split(
            labelled_df,
            test_size=TEST_SIZE,
            stratify=labelled_df['numeric_label'],
            random_state=RANDOM_SEED,
        )
    except ValueError as e:
        # scikit-learn raises ValueError when stratification is impossible,
        # e.g. a class with too few rows
        logger.error(f"Could not split categorization data: {e}")
        return {}, {}

    logger.info(f"Split data into {len(train_df)} train and {len(test_df)} test samples")

    # Create processed data dictionaries
    train_data = {
        "texts": train_df['post'].tolist(),
        "labels": train_df['numeric_label'].tolist(),
        "label_map": label_map,
    }
    test_data = {
        "texts": test_df['post'].tolist(),
        "labels": test_df['numeric_label'].tolist(),
        "label_map": label_map,
    }

    return train_data, test_data

def process_target_task() -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Process the target identification task.

    Reads the Stage 3 posts file, normalizes ``target`` (lower-cased, surrounding
    whitespace stripped) and keeps only targets that occur at least
    MIN_TARGET_OCCURRENCES times. Each kept target gets an integer label in
    descending order of frequency, and the rows are split into stratified
    train/test sets. Rows without post text, without a target, or with a blank
    target are ignored.

    Side effect: ``TASKS["target"]["label_map"]`` is replaced with the label map
    built here.

    Returns:
        Tuple of dictionaries containing processed train and test data, each with
        "texts", "labels" and "label_map" keys. Two empty dictionaries are
        returned when the input is unusable or the split cannot be made.
    """
    # Read Stage 3 posts data
    posts_df = read_tsv_file(STG3_POSTS_FILE)

    if posts_df.empty:
        logger.error("Stage 3 posts data is missing or empty")
        return {}, {}

    # Both the target column and the text column are required further down
    for column in ('target', 'post'):
        if column not in posts_df.columns:
            logger.error(f"'{column}' column not found in Stage 3 posts file")
            logger.info(f"Available columns: {posts_df.columns.tolist()}")
            return {}, {}

    # Rows lacking text or target cannot contribute a sample
    usable_df = posts_df.dropna(subset=['post', 'target']).copy()
    ignored = len(posts_df) - len(usable_df)
    if ignored:
        logger.info(f"Ignoring {ignored} Stage 3 rows with missing post text or target")

    # Clean and normalize target labels; astype(str) is safe because NaN rows are
    # already gone, and it keeps the .str accessor usable for non-string columns
    usable_df['target_clean'] = usable_df['target'].astype(str).str.lower().str.strip()
    # A whitespace-only target is not a real target group
    usable_df = usable_df[usable_df['target_clean'] != '']

    # Count target occurrences (value_counts sorts by descending frequency)
    target_counts = usable_df['target_clean'].value_counts()
    logger.info(f"Found {len(target_counts)} unique targets")

    # Keep only the most common targets (occurring at least MIN_TARGET_OCCURRENCES times)
    common_targets = target_counts[target_counts >= MIN_TARGET_OCCURRENCES].index.tolist()
    logger.info(f"Selected {len(common_targets)} common targets for the task")

    # Filter rows with common targets
    filtered_df = usable_df[usable_df['target_clean'].isin(common_targets)].copy()

    if filtered_df.empty:
        logger.error(f"No targets with at least {MIN_TARGET_OCCURRENCES} occurrences")
        return {}, {}

    # Map targets to numeric labels
    target_to_idx = {target: idx for idx, target in enumerate(common_targets)}
    filtered_df['target_label'] = filtered_df['target_clean'].map(target_to_idx)

    # Update the target task label map
    label_map = {str(idx): target for target, idx in target_to_idx.items()}
    TASKS["target"]["label_map"] = label_map

    # Log the most common targets
    for idx, target in enumerate(common_targets[:10]):  # Show top 10
        count = target_counts[target]
        logger.info(f"Target {idx}: '{target}' with {count} occurrences")

    # Split data into train and test sets
    try:
        train_df, test_df = train_test_split(
            filtered_df,
            test_size=TEST_SIZE,
            stratify=filtered_df['target_label'],
            random_state=RANDOM_SEED,
        )
    except ValueError as e:
        # scikit-learn raises ValueError when stratification is impossible,
        # e.g. fewer test rows than classes
        logger.error(f"Could not split target data: {e}")
        return {}, {}

    logger.info(f"Split data into {len(train_df)} train and {len(test_df)} test samples")

    # Create processed data dictionaries
    train_data = {
        "texts": train_df['post'].tolist(),
        "labels": train_df['target_label'].tolist(),
        "label_map": label_map,
    }
    test_data = {
        "texts": test_df['post'].tolist(),
        "labels": test_df['target_label'].tolist(),
        "label_map": label_map,
    }

    return train_data, test_data

def save_processed_data(task_name: str, train_data: Dict[str, Any], test_data: Dict[str, Any]) -> bool:
    """
    Write the train and test splits of a task to ``train.json`` and ``test.json``.

    The files are placed in ``PROCESSED_DATA_DIR / task_name`` (created if
    needed). The train split is written first; the first failure is logged and
    stops the process.

    Args:
        task_name: Name of the task (e.g., 'detection', 'categorization', 'target')
        train_data: Dictionary with processed train data
        test_data: Dictionary with processed test data

    Returns:
        True if both files were written, False if either dictionary is empty or
        a write failed
    """
    if not (train_data and test_data):
        logger.warning(f"No data to save for {task_name}")
        return False

    # Per-task output directory
    output_dir = PROCESSED_DATA_DIR / task_name
    output_dir.mkdir(parents=True, exist_ok=True)

    # Same write-and-report procedure for each split, train before test
    for split, payload in (("train", train_data), ("test", test_data)):
        destination = output_dir / f"{split}.json"
        try:
            with open(destination, 'w', encoding='utf-8') as handle:
                json.dump(payload, handle, ensure_ascii=False, indent=2)
            logger.info(f"Saved {len(payload['texts'])} {split} samples to {destination}")
        except Exception as exc:
            logger.error(f"Error saving {split} data for {task_name}: {exc}")
            return False

    return True

def create_unified_metadata() -> bool:
    """
    Create unified metadata for the Implicit Hate Speech dataset.

    For every task in TASKS whose ``train.json`` and ``test.json`` exist under
    PROCESSED_DATA_DIR, the files are read back to record sample counts, label
    count and paths relative to ROOT_DIR. The label count is taken from the
    ``label_map`` stored in the train file, because the in-memory label map of
    the target task is only filled while that task is processed and may be empty
    when its files come from an earlier run. Tasks whose files are missing or
    unreadable are skipped.

    Returns:
        The result of save_unified_format() for the assembled metadata, or False
        when no task metadata could be collected
    """
    # Create tasks dictionary for metadata
    tasks = {}

    # Add metadata for each task
    for task_name, task_info in TASKS.items():
        task_dir = PROCESSED_DATA_DIR / task_name
        train_file = task_dir / "train.json"
        test_file = task_dir / "test.json"

        if not train_file.exists() or not test_file.exists():
            logger.warning(f"Files for task {task_name} not found, skipping metadata")
            continue

        # Read files to get counts
        try:
            with open(train_file, 'r', encoding='utf-8') as f:
                train_data = json.load(f)
            with open(test_file, 'r', encoding='utf-8') as f:
                test_data = json.load(f)

            # Prefer the label map that was saved with the data; fall back to the
            # in-memory definition when the file carries none
            label_map = train_data.get("label_map") or task_info["label_map"]

            # Create task metadata
            tasks[task_name] = {
                "name": task_name,
                "description": task_info["description"],
                "task_type": "classification",
                "train_samples": len(train_data["texts"]),
                "test_samples": len(test_data["texts"]),
                "label_count": len(label_map),
                "task_dir": str(task_dir.relative_to(ROOT_DIR)),
                "train_file": str(train_file.relative_to(ROOT_DIR)),
                "test_file": str(test_file.relative_to(ROOT_DIR)),
            }

            logger.info(f"Added metadata for {task_name} task")
        except Exception as e:
            # Best effort per task: an unreadable task must not block the others
            logger.error(f"Error reading files for {task_name}: {e}")

    if not tasks:
        logger.error("No tasks could be created for unified metadata")
        return False

    # Create dataset metadata
    dataset_metadata = {
        "dataset_name": "Implicit Hate Speech",
        "dataset_description": "A dataset containing 22,056 tweets with annotations for implicit hate speech, categorized by type and with detailed target information.",
        "tasks": tasks
    }

    # Save unified metadata
    return save_unified_format("implicit_hate", dataset_metadata)

def main() -> int:
    """
    Process the Implicit Hate Speech dataset for classification tasks.

    Verifies that the three stage posts files exist, runs and saves each task
    (detection, categorization, target), logs a per-task summary and, when at
    least one task succeeded, creates the unified metadata.

    Returns:
        0 when every task was saved and the unified metadata was created;
        1 when required files are missing, any task failed, or the metadata
        step failed
    """
    logger.info("Processing Implicit Hate Speech dataset")

    # Create processed data directory if it doesn't exist
    PROCESSED_DATA_DIR.mkdir(parents=True, exist_ok=True)

    # Check if all required files exist
    required_files = [STG1_POSTS_FILE, STG2_POSTS_FILE, STG3_POSTS_FILE]
    missing_files = [f for f in required_files if not f.exists()]
    if missing_files:
        logger.error(f"Missing required files: {[f.name for f in missing_files]}")
        logger.error("Please run the download script first")
        return 1

    # (task name, title used in the log, processing function), in execution order
    pipeline = (
        ("detection", "Implicit Hate Detection", process_detection_task),
        ("categorization", "Implicit Hate Categorization", process_categorization_task),
        ("target", "Target Identification", process_target_task),
    )

    # Process each task
    tasks_status = {}
    for task_name, title, process_task in pipeline:
        logger.info(f"Processing {title} task")
        train_data, test_data = process_task()
        tasks_status[task_name] = save_processed_data(task_name, train_data, test_data)

    # Summary
    successful_tasks = sum(1 for status in tasks_status.values() if status)
    logger.info(f"Processed {successful_tasks}/{len(tasks_status)} tasks successfully")

    for task, status in tasks_status.items():
        logger.info(f"Task '{task}': {'Success' if status else 'Failed'}")

    # Create unified metadata; a failure here must be visible in the exit status
    metadata_ok = True
    if successful_tasks > 0:
        metadata_ok = bool(create_unified_metadata())
        if metadata_ok:
            logger.info("Created unified metadata for Implicit Hate Speech dataset")
        else:
            logger.error("Failed to create unified metadata")

    all_tasks_ok = successful_tasks == len(tasks_status)
    return 0 if all_tasks_ok and metadata_ok else 1

# When executed as a script, run the pipeline and use main()'s return value
# as the process exit status.
if __name__ == "__main__":
    sys.exit(main())