import argparse
import os
from typing import Dict, List, Union

import torch
from datasets import Dataset
from transformers import pipeline


def get_score_for_label(results: List[Union[Dict, List]], target_label: str) -> List[float]:
    """
    Extract the score assigned to ``target_label`` from text-classification results.

    The layout is chosen from the type of the first element:

    * Multiclass / all-scores: each result is a list of ``{'label', 'score'}``
      dicts (plain ``(label, score)`` pairs are accepted as well). The target
      label is searched for in *every* result, because per-example score lists
      may be ordered by score, so its position is not fixed across examples.
    * Binary top-1: each result is a single ``{'label', 'score'}`` dict. When the
      predicted label is not ``target_label``, the target's score is taken to be
      ``1 - score``. This is only meaningful for two-class models.

    Args:
        results: Per-example classification outputs.
        target_label: Label whose score should be returned.

    Returns:
        One float per input result, in the same order as ``results``.

    Raises:
        ValueError: If a multiclass result does not contain ``target_label``.
    """
    if not results:
        return []

    def _label_and_score(entry):
        # Accept the pipeline dict format as well as plain (label, score) pairs.
        if isinstance(entry, dict):
            return entry['label'], entry['score']
        label, score = entry
        return label, score

    if isinstance(results[0], list):
        scores = []
        for result in results:
            for entry in result:
                label, score = _label_and_score(entry)
                if label == target_label:
                    scores.append(score)
                    break
            else:
                raise ValueError(
                    f"Label {target_label!r} not found in classification result {result!r}"
                )
        return scores

    # Binary top-1 output: complement the score when the model picked the other class.
    return [
        result['score'] if result['label'] == target_label else 1 - result['score']
        for result in results
    ]

def get_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``save_dir``, ``save_name``, ``input_file``,
        ``device`` and ``batch_size`` attributes.
    """
    cli = argparse.ArgumentParser(description="Argument parser for text classification pipeline.")
    add = cli.add_argument

    # Where results are written.
    add("--save_dir", type=str,
        default='/shared/share_mala/implicitbayes/dataset_files/MIND_data/large/',
        help="Directory where the output files will be saved.")
    add("--save_name", type=str, required=True,
        help="Name of the output (no extension, not a directory)")

    # What is read.
    add("--input_file", type=str,
        default='/shared/share_mala/implicitbayes/dataset_files/MIND_data/large/news_data_all.pt',
        help="Path to the input file containing the data to process.")

    # Inference settings.
    add("--device", type=int, default=-1)
    add("--batch_size", type=int, default=32)

    parsed = cli.parse_args()
    return parsed

if __name__ == "__main__":
    # Parse the arguments
    args = get_args()

    # Create the output directory and build the path up front, so a bad
    # save location fails immediately rather than after the full inference run.
    os.makedirs(args.save_dir, exist_ok=True)
    save_path = os.path.join(args.save_dir, args.save_name + '.pt')

    news_data = torch.load(args.input_file)
    # Process articles in sorted-key order so output rows have a reproducible
    # order. Keys of a dict are unique by construction (assumes news_data is a
    # dict), so no separate uniqueness check is needed.
    news_data_keys = sorted(news_data.keys())
    headlines = [news_data[k]['title'] for k in news_data_keys]

    data = Dataset.from_dict({"headline": headlines})

    # Fall back to CPU (device -1) when CUDA is unavailable.
    if not torch.cuda.is_available():
        args.device = -1
    pipe_formality = pipeline("text-classification",
                              model="s-nlp/roberta-base-formality-ranker",
                              device=args.device,
                              batch_size=args.batch_size)
    pipe_sentiment = pipeline('text-classification',
                              model='bhadresh-savani/distilbert-base-uncased-sentiment-sst2',
                              device=args.device,
                              batch_size=args.batch_size)

    def score_batch(examples):
        """Score a batch of headlines for formality and positive sentiment."""
        batch_headlines = examples["headline"]
        return {
            "formality": get_score_for_label(pipe_formality(batch_headlines), "formal"),
            "sentiment": get_score_for_label(pipe_sentiment(batch_headlines), "POSITIVE"),
        }

    results = data.map(
        score_batch,
        batched=True,
        # Honour --batch_size here too, instead of a hard-coded 32.
        batch_size=args.batch_size,
        num_proc=1,  # Increase if not using GPU
    )

    # NOTE: rows follow sorted(news_data.keys()) order; the keys themselves are not stored.
    torch.save(results.to_dict(), save_path)
