#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
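Process the Article-Bias-Prediction dataset for the political bias classification task.
The dataset contains news articles annotated with their political bias (left, center, right).
The validation split is merged into training, and the results are written as train.json and
test.json under data/processed/article_bias.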
"""

import sys
import os
import logging
import json
import pandas as pd
from pathlib import Path
from typing import Dict, List, Any, Tuple
import json
from tqdm import tqdm
import tiktoken

# Add the project root to the path so we can import our modules
sys.path.append(str(Path(__file__).parent.parent))
from src.data_utils import save_unified_format

# Root logging setup: INFO and above, timestamped, emitted through a single
# StreamHandler. All log calls in this script go through the named logger below.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger("Article-Bias-Processor")

# Directories
ROOT_DIR = Path(__file__).parent.parent
RAW_DATA_DIR = ROOT_DIR / "data" / "raw" / "article_bias"
PROCESSED_DATA_DIR = ROOT_DIR / "data" / "processed" / "article_bias"
SPLITS_DIR = RAW_DATA_DIR / "splits"
JSONS_DIR = RAW_DATA_DIR / "jsons"

# Political bias labels
BIAS_LABELS = {
    0: "left",
    1: "center",
    2: "right"
}

# Max articles to process (set to None to process all)
MAX_ARTICLES = None  # For debugging, limit can be set to a small number

# Initialize tiktoken encoder
ENCODER = tiktoken.get_encoding("cl100k_base")  # Using OpenAI's encoding
MAX_TOKENS = 8000  # Leave some margin for safety

def truncate_text(text: str, max_tokens: int = MAX_TOKENS) -> str:
    """
    Truncate text to at most ``max_tokens`` tokens of the module-level ENCODER.

    Args:
        text: Input text to truncate
        max_tokens: Maximum number of tokens to keep

    Returns:
        ``text`` unchanged if it already fits within ``max_tokens`` tokens,
        otherwise the decoded first ``max_tokens`` tokens.
    """
    # disallowed_special=() makes tiktoken encode strings such as "<|endoftext|>"
    # as ordinary text. The default ("all") raises ValueError on them, which
    # would abort the whole run because of a single scraped article.
    tokens = ENCODER.encode(text, disallowed_special=())
    if len(tokens) <= max_tokens:
        return text
    return ENCODER.decode(tokens[:max_tokens])

def read_split_file(file_path: Path) -> pd.DataFrame:
    """
    Load one tab-separated split file into a DataFrame.

    Args:
        file_path: Location of the ``.tsv`` split file.

    Returns:
        The parsed DataFrame (process_articles later reads its ``ID`` and
        ``bias`` columns). If anything goes wrong, the error is logged and an
        empty DataFrame is returned instead.
    """
    try:
        # read_table is read_csv with a tab separator by default.
        frame = pd.read_table(file_path)
        logger.info(f"Read {len(frame)} articles from {file_path.name}")
        return frame
    except Exception as err:
        logger.error(f"Error reading {file_path}: {err}")
        return pd.DataFrame()

def read_article_json(article_id: str) -> Dict:
    """
    Load the JSON record for one article from JSONS_DIR.

    Args:
        article_id: Article identifier; the file read is ``<JSONS_DIR>/<article_id>.json``.

    Returns:
        The decoded JSON value (expected to be an object). If the file cannot
        be opened or parsed, the error is logged and ``{}`` is returned.
    """
    source = JSONS_DIR.joinpath(f"{article_id}.json")
    try:
        with source.open('r', encoding='utf-8') as handle:
            return json.load(handle)
    except Exception as err:
        logger.error(f"Error reading article {article_id}: {err}")
        return {}

def _extract_article_text(article_data: Dict) -> str:
    """
    Pick the text to classify from a raw article record.

    Returns the 'content' field if it is a non-blank string, otherwise the
    'title' field if that is a non-blank string, otherwise ''.
    """
    for field in ('content', 'title'):
        value = article_data.get(field)
        # Guard against JSON null or other non-string values, which would make
        # .strip() raise AttributeError and abort processing of the whole split.
        if isinstance(value, str) and value.strip():
            return value
    return ''


def process_articles(split_df: pd.DataFrame, max_articles: int = None) -> List[Dict]:
    """
    Process articles from a split DataFrame.

    For each (ID, bias) row, the label is validated against BIAS_LABELS and the
    article JSON is loaded. The text is taken from 'content' (falling back to
    'title') and truncated to the token limit. Rows with an invalid label, a
    missing/unreadable JSON file, or no usable text are skipped with a log
    message.

    Args:
        split_df: DataFrame with 'ID' and 'bias' columns
        max_articles: Maximum number of articles to process (for debugging);
            None processes all rows

    Returns:
        List of dicts with "text", "label" (int) and "metadata" keys
    """
    processed_articles = []
    article_count = len(split_df)

    # Compare with None explicitly so that the documented "None = process all"
    # contract holds and 0 really means "process no articles".
    if max_articles is not None:
        split_df = split_df.head(max_articles)
        logger.info(f"Limited to {max_articles} articles (out of {article_count})")

    # Iterate only the two needed columns. This avoids building a Series per row
    # (DataFrame.iterrows) and yields plain Python scalars.
    rows = zip(split_df['ID'], split_df['bias'])
    for article_id, bias_label in tqdm(rows, total=len(split_df), desc="Processing articles"):
        try:
            label = int(bias_label)
        except (TypeError, ValueError):
            # e.g. an empty TSV cell read as NaN; int(nan) raises ValueError.
            logger.warning(f"Article {article_id} has invalid bias label {bias_label!r}, skipping")
            continue
        if label not in BIAS_LABELS:
            # Keep the emitted labels consistent with the label_map written alongside them.
            logger.warning(f"Article {article_id} has unknown bias label {label}, skipping")
            continue

        article_data = read_article_json(article_id)
        if not article_data:
            continue

        article_text = _extract_article_text(article_data)
        if not article_text:
            logger.warning(f"Article {article_id} has no content or title, skipping")
            continue

        processed_articles.append({
            "text": truncate_text(article_text),
            "label": label,
            "metadata": {
                "id": article_id,
                "title": article_data.get('title', ''),
                "source": article_data.get('source', ''),
                "topic": article_data.get('topic', ''),
                "date": article_data.get('date', '')
            }
        })

    logger.info(f"Processed {len(processed_articles)} articles successfully")
    return processed_articles

def create_classification_data(processed_articles: List[Dict]) -> Dict[str, Any]:
    """
    Reshape per-article records into column-oriented classification data.

    Args:
        processed_articles: Records as produced by process_articles, each
            carrying "text", "label" and "metadata" keys.

    Returns:
        A dict with the parallel lists "texts", "labels" and "metadata", plus
        "label_map" (BIAS_LABELS with its keys converted to strings). When
        there are no articles, an error is logged and ``{}`` is returned.
    """
    if not processed_articles:
        logger.error("No articles to create classification data")
        return {}

    def _pluck(key):
        # One list per field, preserving article order.
        return [record[key] for record in processed_articles]

    return {
        "texts": _pluck("text"),
        "labels": _pluck("label"),
        "label_map": {str(code): name for code, name in BIAS_LABELS.items()},
        "metadata": _pluck("metadata"),
    }

def save_processed_data(data: Dict[str, Any], split: str) -> bool:
    """
    Save processed data to ``PROCESSED_DATA_DIR/<split>.json``.

    The JSON is first written to a sibling ``.tmp`` file and then renamed over
    the target. A failed or interrupted dump therefore never leaves a
    truncated ``<split>.json`` behind, and any previous good file survives.

    Args:
        data: Dictionary with processed data (must contain a non-empty "texts" list)
        split: Split name (train or test)

    Returns:
        True if successful, False otherwise (nothing to save, or any error
        while creating the directory or writing the file; errors are logged)
    """
    if not data or not data.get("texts"):
        logger.warning(f"No data to save for {split} split")
        return False

    output_file = PROCESSED_DATA_DIR / f"{split}.json"
    tmp_file = output_file.with_name(f"{output_file.name}.tmp")
    try:
        # The directory creation sits inside the try so that its failure is
        # reported through the False return, like every other failure here.
        PROCESSED_DATA_DIR.mkdir(parents=True, exist_ok=True)
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        tmp_file.replace(output_file)
    except Exception as e:
        logger.error(f"Error saving {split} data: {e}")
        try:
            # Best-effort removal of a partially written temp file; a failure
            # here must not mask the original error already logged above.
            if tmp_file.exists():
                tmp_file.unlink()
        except OSError:
            pass
        return False

    logger.info(f"Saved {len(data['texts'])} articles to {output_file}")
    return True

def create_unified_metadata() -> bool:
    """
    Build the dataset-level metadata for Article-Bias-Prediction and pass it
    to ``save_unified_format``.

    The processed train.json and test.json are read back from
    PROCESSED_DATA_DIR to count their samples. Their locations are recorded
    relative to ROOT_DIR.

    Returns:
        The result of ``save_unified_format`` on success. False if either
        processed file is missing or any step raises (the error is logged).
    """
    split_files = {split: PROCESSED_DATA_DIR / f"{split}.json" for split in ("train", "test")}

    if not all(path.exists() for path in split_files.values()):
        logger.error("Missing processed files for metadata creation")
        return False

    try:
        contents = {}
        for split, path in split_files.items():
            with open(path, 'r', encoding='utf-8') as f:
                contents[split] = json.load(f)

        # Get paths relative to ROOT_DIR rather than absolute ones.
        relative = {split: path.relative_to(ROOT_DIR) for split, path in split_files.items()}
        task_dir = PROCESSED_DATA_DIR.relative_to(ROOT_DIR)

        task_info = {
            "name": "political_bias",
            "description": "Predict the political bias of news articles (left, center, right)",
            "task_type": "classification",
            "train_samples": len(contents["train"]["texts"]),
            "test_samples": len(contents["test"]["texts"]),
            "label_count": len(BIAS_LABELS),
            "task_dir": str(task_dir),
            "train_file": str(relative["train"]),
            "test_file": str(relative["test"]),
        }

        dataset_metadata = {
            "dataset_name": "Article-Bias-Prediction",
            "dataset_description": "A dataset containing news articles with annotations for political bias (left, center, right).",
            "tasks": {task_info["name"]: task_info},
        }

        return save_unified_format("article_bias", dataset_metadata)

    except Exception as e:
        logger.error(f"Error creating unified metadata: {e}")
        return False

def main():
    """
    Entry point: build the train/test classification files for Article-Bias-Prediction.

    The train and valid splits are merged into a single training set. Both
    sets are processed and saved, and unified metadata is then written.

    Returns:
        Process exit code: 0 when both processed splits were saved, 1 otherwise.
        A failure to write the unified metadata is logged but does not change
        the exit code.
    """
    logger.info("Processing Article-Bias-Prediction dataset")

    # Both raw inputs must already exist; bail out on the first one that is missing.
    required_dirs = (
        (SPLITS_DIR, "Splits directory not found"),
        (JSONS_DIR, "JSON articles directory not found"),
    )
    for directory, problem in required_dirs:
        if not directory.exists():
            logger.error(f"{problem}: {directory}")
            logger.error("Please run the download script first")
            return 1

    splits = {name: read_split_file(SPLITS_DIR / f"{name}.tsv") for name in ("train", "valid", "test")}
    if any(frame.empty for frame in splits.values()):
        logger.error("One or more split files are empty or couldn't be read")
        return 1

    # Fold validation into training (following the benchmark's standard approach).
    combined_train_df = pd.concat([splits["train"], splits["valid"]], ignore_index=True)
    logger.info(f"Combined training set: {len(combined_train_df)} articles "
                f"({len(splits['train'])} train + {len(splits['valid'])} validation)")

    logger.info("Processing training articles...")
    train_articles = process_articles(combined_train_df, max_articles=MAX_ARTICLES)

    logger.info("Processing test articles...")
    test_articles = process_articles(splits["test"], max_articles=MAX_ARTICLES)

    # Build and save train first, then test; both are always attempted.
    saved = [
        save_processed_data(create_classification_data(articles), split)
        for split, articles in (("train", train_articles), ("test", test_articles))
    ]

    if not all(saved):
        logger.error("Data processing failed")
        return 1

    logger.info("Data processing completed successfully")
    if create_unified_metadata():
        logger.info("Created unified metadata for Article-Bias-Prediction dataset")
    else:
        logger.error("Failed to create unified metadata")
    return 0

if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())