#!/usr/bin/env python3
"""
Pytest tests to verify that the refactored bias_concept function produces 
the exact same results as the original implementation.
"""

import sys
import json
import pytest
from collections import Counter

sys.path.append('.')

# Import the refactored version
from train_bert import bias_concept as bias_concept_refactored

def bias_concept_original(concept, dataset, concept_train_text, concept_train_label):
    """Pre-refactor reference implementation of ``bias_concept``.

    Kept only as an oracle for the equivalence test. It keeps each
    (text, label) pair whose label belongs to the label set chosen by the
    ``dataset``/``concept`` combination, and drops every other pair.

    Label sets per combination:
        amazon-shoe-reviews: size -> {0, 1, 2}; color/style -> {3, 4};
                             any other concept -> nothing kept
        imdb (any concept): {1}
        yelp_polarity: food/price -> {1}; service -> {0};
                       any other concept -> nothing kept
        cebab (any concept): {3, 4}
        boolq: country -> {0}; television/history -> {1};
               any other concept -> nothing kept

    Args:
        concept: Concept name that selects the label set.
        dataset: Dataset name that selects the rule family.
        concept_train_text: Texts, walked in lockstep with the labels
            (``zip`` semantics: surplus items of the longer input are ignored).
        concept_train_label: Labels parallel to ``concept_train_text``.

    Returns:
        ``(texts, labels)``: two new lists holding the surviving pairs in
        input order.

    Raises:
        ValueError: if ``dataset`` is not recognized. The error is raised only
            once a pair is actually consumed, so an unknown dataset with empty
            input returns ``([], [])``.
    """
    if dataset == "amazon-shoe-reviews":
        if concept == "size":
            wanted = (0, 1, 2)
        elif concept in ("color", "style"):
            wanted = (3, 4)
        else:
            wanted = ()
    elif dataset == "imdb":
        wanted = (1,)
    elif dataset == "yelp_polarity":
        if concept in ("food", "price"):
            wanted = (1,)
        elif concept == "service":
            wanted = (0,)
        else:
            wanted = ()
    elif dataset == "cebab":
        wanted = (3, 4)
    elif dataset == "boolq":
        if concept == "country":
            wanted = (0,)
        elif concept in ("television", "history"):
            wanted = (1,)
        else:
            wanted = ()
    else:
        # Unknown dataset: deliberately reported inside the loop to mirror the
        # original, which only raised when it had at least one pair to inspect.
        wanted = None

    kept_texts, kept_labels = [], []
    for text, label in zip(concept_train_text, concept_train_label):
        if wanted is None:
            raise ValueError(f'no such dataset {dataset}')
        if label in wanted:
            kept_texts.append(text)
            kept_labels.append(label)
    return kept_texts, kept_labels


def load_test_data(dataset, concept):
    """Load the (text, label) pairs tagged with ``concept`` for ``dataset``.

    Reads ``data/chatgpt_concepts_{dataset}_exp.jsonl``, resolved relative to
    the current working directory. Each non-blank line must be a JSON object.
    Its ``concepts`` field is a comma-separated string. A record is kept when
    ``concept`` equals one of those entries after lower-casing and stripping
    whitespace. For ``boolq`` the text is assembled from the ``passage`` and
    ``question`` fields; for every other dataset the ``text`` field is used.

    If the file does not exist, a warning is printed and 50 synthetic examples
    are returned instead. Their labels cycle through 0-4 for
    "amazon-shoe-reviews" and "cebab", and through 0-1 for everything else.

    Returns:
        A ``(texts, labels)`` tuple of two parallel lists.
    """
    path = f"data/chatgpt_concepts_{dataset}_exp.jsonl"
    try:
        # JSON Lines files are UTF-8 by specification; don't depend on the
        # platform's locale encoding (e.g. cp1252 on Windows).
        inf = open(path, 'r', encoding='utf-8')
    except FileNotFoundError:
        print(f"Warning: Could not find data file for {dataset}. Using synthetic data.")
        concept_text_list = [f"{concept}_text_{i}" for i in range(50)]
        num_labels = 5 if dataset in ["amazon-shoe-reviews", "cebab"] else 2
        concept_label_list = [i % num_labels for i in range(50)]
        return concept_text_list, concept_label_list

    concept_text_list = []
    concept_label_list = []
    with inf:
        for line in inf:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. an extra newline at EOF);
                # json.loads('') would raise JSONDecodeError.
                continue
            data = json.loads(line)
            text_concepts = [t.strip() for t in data['concepts'].lower().split(',')]
            if concept not in text_concepts:
                continue

            if dataset == "boolq":
                text_content = "### Passage:" + data['passage'] + " ### Question:" + data['question']
            else:
                text_content = data['text']

            concept_text_list.append(text_content)
            concept_label_list.append(data['label'])

    return concept_text_list, concept_label_list


# (dataset, concept) pairs fed to the equivalence test via pytest.mark.parametrize,
# grouped here by dataset; the flattened order matches dataset-then-concept order.
# NOTE(review): bias_concept_original also has an "imdb" branch that no case
# below exercises — consider adding one once a suitable imdb concept is known.
_CONCEPTS_PER_DATASET = {
    "amazon-shoe-reviews": ("size", "color", "style"),
    "yelp_polarity": ("food", "price", "service"),
    "cebab": ("food", "ambiance"),
    "boolq": ("country", "television", "history"),
}

BIAS_CONCEPT_TEST_CASES = [
    (dataset_name, concept_name)
    for dataset_name, concept_names in _CONCEPTS_PER_DATASET.items()
    for concept_name in concept_names
]

@pytest.mark.parametrize("dataset,concept", BIAS_CONCEPT_TEST_CASES)
def test_bias_concept_implementations(dataset, concept):
    """Original and refactored ``bias_concept`` must return identical outputs.

    Each implementation receives its own copy of the inputs, so an in-place
    mutation by one cannot influence the other's result.
    """
    concept_text, concept_label = load_test_data(dataset, concept)

    # A real data file may contain no records tagged with this concept.
    if not concept_text:
        pytest.skip(f"No data available for {dataset} + {concept}")

    original_text, original_label = bias_concept_original(
        concept, dataset, concept_text.copy(), concept_label.copy()
    )
    refactored_text, refactored_label = bias_concept_refactored(
        concept, dataset, concept_text.copy(), concept_label.copy()
    )

    # Exact list equality already implies equal lengths and equal label
    # distributions, so separate len()/Counter checks would be redundant.
    assert original_text == refactored_text, f"Text outputs differ for {dataset} + {concept}"
    assert original_label == refactored_label, f"Label outputs differ for {dataset} + {concept}"


def test_bias_concept_unknown_dataset_raises():
    """Both implementations must reject an unsupported dataset with ValueError."""
    for implementation in (bias_concept_original, bias_concept_refactored):
        with pytest.raises(ValueError):
            # Non-empty input: the original only raises while iterating pairs.
            implementation("size", "no-such-dataset", ["some text"], [0])


if __name__ == "__main__":
    # Run this file's tests through pytest and propagate its exit status, so a
    # direct `python <this file>` invocation exits non-zero when tests fail.
    sys.exit(pytest.main([__file__, "-v"]))