#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Process the P-Stance dataset for stance detection tasks.
The dataset contains tweets with stance annotations towards political figures (Trump, Biden, Bernie).
For each figure, the train and val splits are combined for training, and the test split is used for evaluation.
Only binary classification (FAVOR vs AGAINST) is supported.
"""

import sys
import os
import logging
import json
import pandas as pd
from pathlib import Path
from typing import Dict, List, Any

# Add the project root to the path so we can import our modules
sys.path.append(str(Path(__file__).parent.parent))
from src.data_utils import save_unified_format

# Logging: INFO and above, emitted through a single stream handler
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger("P-Stance-Processor")

# Filesystem layout, anchored two levels above this script's own path
ROOT_DIR = Path(__file__).parent.parent
RAW_DATA_DIR = ROOT_DIR.joinpath("data", "raw", "pstance")
PROCESSED_DATA_DIR = ROOT_DIR.joinpath("data", "processed", "pstance")

# Targets covered by the dataset; each one becomes its own task
POLITICAL_FIGURES = ["trump", "biden", "bernie"]

# Binary stance label -> integer class id
STANCE_LABEL_MAP = dict(FAVOR=0, AGAINST=1)

# Header names expected in the raw P-Stance CSV files
TWEET_COL, TARGET_COL, STANCE_COL = "Tweet", "Target", "Stance"

def process_csv_file(file_path: Path) -> pd.DataFrame:
    """
    Read and validate a P-Stance CSV file.

    Args:
        file_path: Path to the CSV file

    Returns:
        DataFrame with the file's rows, minus any row whose tweet text is
        missing. If the file cannot be read or lacks a required column, an
        empty DataFrame that still carries the required columns is returned,
        so callers can index ``TWEET_COL`` / ``STANCE_COL`` without a KeyError.

    Raises:
        ValueError: If the stance column contains anything other than the
            labels in ``STANCE_LABEL_MAP`` (FAVOR / AGAINST).
    """
    required_columns = [TWEET_COL, STANCE_COL]

    logger.info(f"Reading {file_path}")

    try:
        df = pd.read_csv(file_path)
    except Exception as e:
        # Best-effort: report the problem and hand back an empty (but
        # correctly shaped) frame; callers decide what "no data" means.
        logger.error(f"Error reading {file_path}: {e}")
        return pd.DataFrame(columns=required_columns)
    logger.info(f"Read {len(df)} rows from {file_path.name}")

    # Check that the required columns exist
    for col in required_columns:
        if col not in df.columns:
            logger.error(f"Column '{col}' not found in {file_path.name}")
            logger.error(f"Available columns: {list(df.columns)}")
            return pd.DataFrame(columns=required_columns)

    # Validate against the label map itself so the allowed set cannot drift
    # from the labels that are later converted to integer ids.
    invalid_stances = [s for s in df[STANCE_COL].unique() if s not in STANCE_LABEL_MAP]
    if invalid_stances:
        logger.error(f"Found unexpected stance values in {file_path.name}: {invalid_stances}")
        logger.error("Dataset should only contain FAVOR and AGAINST classes")
        raise ValueError(f"Invalid stance values: {invalid_stances}")

    # A missing tweet would be written out as a bare NaN token (not valid
    # JSON) and carries no training signal, so drop such rows.
    missing_text = df[TWEET_COL].isna()
    if missing_text.any():
        logger.warning(
            f"Dropping {int(missing_text.sum())} rows with missing tweet text from {file_path.name}"
        )
        df = df[~missing_text].reset_index(drop=True)

    logger.info(f"Verified {len(df)} rows with valid FAVOR/AGAINST stances")
    return df

def combine_train_val(figure: str) -> Dict[str, Any]:
    """
    Combine the train and val splits of a political figure into one training set.

    Args:
        figure: Political figure (trump, biden, bernie)

    Returns:
        Dictionary with ``texts``, integer ``labels`` (via STANCE_LABEL_MAP)
        and ``label_map`` (id string -> stance name). Returns an empty dict
        if either split produced no rows, so callers never train on a
        silently truncated set.
    """
    train_file = RAW_DATA_DIR / f"raw_train_{figure}.csv"
    val_file = RAW_DATA_DIR / f"raw_val_{figure}.csv"

    train_df = process_csv_file(train_file)
    val_df = process_csv_file(val_file)

    # process_csv_file reports read/schema failures by returning an empty
    # DataFrame (possibly with no columns at all). Combining anyway would
    # either train on just one split or raise KeyError on the column
    # lookups below when both are column-less.
    if train_df.empty or val_df.empty:
        logger.error(f"Train or val data missing for {figure}; refusing to combine partial splits")
        return {}

    combined_df = pd.concat([train_df, val_df], ignore_index=True)
    logger.info(f"Combined {len(train_df)} train + {len(val_df)} val = {len(combined_df)} samples for {figure}")

    # Convert to standardized format
    texts = combined_df[TWEET_COL].tolist()
    labels = [STANCE_LABEL_MAP[stance] for stance in combined_df[STANCE_COL]]

    return {
        "texts": texts,
        "labels": labels,
        "label_map": {str(v): k for k, v in STANCE_LABEL_MAP.items()},
    }

def process_test_data(figure: str) -> Dict[str, Any]:
    """
    Process the test split of a political figure.

    Args:
        figure: Political figure (trump, biden, bernie)

    Returns:
        Dictionary with ``texts``, integer ``labels`` (via STANCE_LABEL_MAP)
        and ``label_map`` (id string -> stance name). Returns an empty dict
        if the test file produced no rows.
    """
    test_file = RAW_DATA_DIR / f"raw_test_{figure}.csv"

    test_df = process_csv_file(test_file)

    # process_csv_file signals read/schema failures with an empty DataFrame
    # (possibly without columns); indexing it would raise KeyError.
    if test_df.empty:
        logger.error(f"No test data available for {figure}")
        return {}

    # Convert to standardized format
    texts = test_df[TWEET_COL].tolist()
    labels = [STANCE_LABEL_MAP[stance] for stance in test_df[STANCE_COL]]

    processed_data = {
        "texts": texts,
        "labels": labels,
        "label_map": {str(v): k for k, v in STANCE_LABEL_MAP.items()},
    }

    logger.info(f"Processed {len(texts)} test samples for {figure}")
    return processed_data

def save_processed_data(processed_data: Dict[str, Any], figure: str, split: str) -> bool:
    """
    Write one processed split to ``PROCESSED_DATA_DIR / figure / f"{split}.json"``.

    Args:
        processed_data: Mapping with at least a non-empty ``texts`` entry
        figure: Political figure (trump, biden, bernie)
        split: Data split (train or test); used as the file stem

    Returns:
        True when the JSON file was written, False when there was nothing
        to write or the write raised.
    """
    # Nothing to persist: missing mapping or empty "texts"
    if not (processed_data and processed_data.get("texts")):
        logger.warning(f"No data to save for {figure} - {split}")
        return False

    target_dir = PROCESSED_DATA_DIR / figure
    target_dir.mkdir(parents=True, exist_ok=True)
    destination = target_dir / f"{split}.json"

    try:
        with open(destination, 'w', encoding='utf-8') as handle:
            # ensure_ascii=False keeps non-ASCII tweet text human-readable
            json.dump(processed_data, handle, ensure_ascii=False, indent=2)
        logger.info(f"Saved {len(processed_data['texts'])} samples to {destination}")
    except Exception as err:
        logger.error(f"Error saving to {destination}: {err}")
        return False
    return True

def process_figure_data(figure: str) -> bool:
    """
    Run the full pipeline for one political figure.

    Builds the training set (train + val) and the test set, and only if both
    contain texts writes them to disk as ``train`` and ``test`` splits.

    Args:
        figure: Political figure (trump, biden, bernie)

    Returns:
        True only if both splits were built and both files were saved.
    """
    logger.info(f"Processing data for {figure}")

    train_data = combine_train_val(figure)
    if not (train_data and train_data.get("texts")):
        logger.error(f"Failed to process train data for {figure}")
        return False

    test_data = process_test_data(figure)
    if not (test_data and test_data.get("texts")):
        logger.error(f"Failed to process test data for {figure}")
        return False

    # Attempt both writes unconditionally (no short-circuit), train first.
    outcomes = [
        save_processed_data(split_data, figure, split_name)
        for split_name, split_data in (("train", train_data), ("test", test_data))
    ]
    return all(outcomes)

def create_unified_metadata() -> bool:
    """
    Create unified metadata for the P-Stance dataset.

    For each figure in POLITICAL_FIGURES whose processed ``train.json`` and
    ``test.json`` both exist under PROCESSED_DATA_DIR, the files are read to
    count samples and a ``stance_<figure>`` task entry is recorded. Figures
    whose files are missing or unreadable are logged and skipped.

    Returns:
        False if no task entry could be built; otherwise the return value of
        ``save_unified_format("pstance", dataset_metadata)``.
    """
    tasks = {}

    # Add each political figure as a separate task
    for figure in POLITICAL_FIGURES:
        figure_dir = PROCESSED_DATA_DIR / figure
        train_file = figure_dir / "train.json"
        test_file = figure_dir / "test.json"

        if not train_file.exists() or not test_file.exists():
            logger.warning(f"Required files not found for {figure}")
            continue

        try:
            # Read files to get sample counts
            with open(train_file, 'r', encoding='utf-8') as f:
                train_data = json.load(f)
            with open(test_file, 'r', encoding='utf-8') as f:
                test_data = json.load(f)

            # Paths relative to the project root. as_posix() keeps forward
            # slashes on every OS so the metadata file is portable.
            train_file_rel = train_file.relative_to(ROOT_DIR).as_posix()
            test_file_rel = test_file.relative_to(ROOT_DIR).as_posix()
            task_dir_rel = figure_dir.relative_to(ROOT_DIR).as_posix()

            task_name = f"stance_{figure}"
            tasks[task_name] = {
                "name": task_name,
                "description": f"Stance Detection - {figure.capitalize()} (Binary: FAVOR vs AGAINST)",
                "task_type": "classification",
                "label_count": len(STANCE_LABEL_MAP),  # Binary classification
                "train_samples": len(train_data["texts"]),
                "test_samples": len(test_data["texts"]),
                "task_dir": task_dir_rel,
                "train_file": train_file_rel,
                "test_file": test_file_rel,
            }
        except (OSError, ValueError, KeyError, TypeError) as e:
            # OSError: file access; ValueError: bad JSON / encoding or a path
            # outside ROOT_DIR; KeyError/TypeError: JSON without a sized
            # "texts" entry.
            logger.error(f"Error reading files for {figure}: {e}")

    if not tasks:
        logger.error("No tasks could be created for unified metadata")
        return False

    dataset_metadata = {
        "dataset_name": "P-Stance",
        "dataset_description": "A dataset for binary stance detection in political domain, containing tweets about Trump, Biden, and Bernie with FAVOR and AGAINST stances.",
        "tasks": tasks,
    }

    return save_unified_format("pstance", dataset_metadata)

def main():
    """
    Process the P-Stance dataset and create unified metadata.

    Returns:
        Process exit code: 0 if every figure was processed AND the unified
        metadata was created, 1 otherwise.
    """
    logger.info("Processing P-Stance dataset for binary classification (FAVOR vs AGAINST)...")

    successful_figures = 0
    for figure in POLITICAL_FIGURES:
        if process_figure_data(figure):
            successful_figures += 1

    logger.info(f"Successfully processed {successful_figures}/{len(POLITICAL_FIGURES)} figures")

    # Metadata is only attempted when at least one figure has output; its
    # outcome must feed into the exit code, otherwise a failed metadata
    # write would still be reported as overall success.
    metadata_ok = False
    if successful_figures > 0:
        metadata_ok = bool(create_unified_metadata())
        if metadata_ok:
            logger.info("Created unified metadata for P-Stance")
        else:
            logger.error("Failed to create unified metadata for P-Stance")

    if successful_figures == len(POLITICAL_FIGURES) and metadata_ok:
        logger.info("P-Stance dataset processing completed successfully!")
        return 0

    logger.warning("P-Stance dataset processing completed with some issues.")
    return 1

if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit status.
    raise SystemExit(main())