import os
import json
from datasets import load_dataset
import random
from collections import defaultdict
from tqdm import tqdm
import time

# Which discrim-eval configuration to load: "explicit" or "implicit".
# Defaults to "explicit"; set the DISCRIM_EVAL_CONFIG environment variable to
# switch without editing the script. The name shadows the builtin `type` but is
# kept because the rest of the script reads it (dataset config, output name).
type = os.environ.get("DISCRIM_EVAL_CONFIG", "explicit")
if type not in ("explicit", "implicit"):
    # Fail fast, before any download, on an unknown configuration name.
    raise ValueError(
        f"DISCRIM_EVAL_CONFIG must be 'explicit' or 'implicit', got {type!r}"
    )

# Time the dataset fetch/load. perf_counter is a monotonic clock intended for
# measuring intervals, unlike wall-clock time.time().
start_loading_time = time.perf_counter()

# Load Anthropic/discrim-eval in the configuration selected by `type`
# ("explicit" or "implicit").
dataset_im = load_dataset("Anthropic/discrim-eval", type)

# Report how long loading took.
loading_time = time.perf_counter() - start_loading_time
print(f"Dataset loading time: {loading_time:.2f} seconds")

# Start timing index selection. This stays on time.time() because the elapsed
# generation time is later computed with time.time() as well.
start_generation_time = time.time()

# Seed the global RNG so the shuffles and random picks made during selection
# are reproducible across runs.
random.seed(42)

# Map decision_question_id -> list of row indices in the train split.
# Reading just the one column avoids decoding every full row to get its id.
grouped_data = defaultdict(list)
for idx, qid in enumerate(dataset_im['train']['decision_question_id']):
    grouped_data[qid].append(idx)

def _select_diverse_indices(indices, demographics):
    """Pick up to three row indices whose demographics differ as much as possible.

    Args:
        indices: Row indices (into the train split) belonging to one question.
        demographics: Sequence mapping a row index to its (age, gender, race).

    Returns:
        A list of at most three distinct row indices. Draws from the
        module-level ``random`` generator, so the result depends on its state.
    """
    # Distinct values per feature in first-seen order, shuffled so the chosen
    # demographic combination varies from question to question.
    unique_ages = list(dict.fromkeys(demographics[i][0] for i in indices))
    unique_genders = list(dict.fromkeys(demographics[i][1] for i in indices))
    unique_races = list(dict.fromkeys(demographics[i][2] for i in indices))
    random.shuffle(unique_ages)
    random.shuffle(unique_genders)
    random.shuffle(unique_races)

    selected = []

    if len(unique_ages) >= 3 and len(unique_genders) >= 3 and len(unique_races) >= 3:
        # Three different values exist for every feature, so target three
        # combinations that differ from each other on all features.
        for target in zip(unique_ages[:3], unique_genders[:3], unique_races[:3]):
            # Skip rows already chosen so a fallback pick is never duplicated.
            matching = [
                i for i in indices
                if demographics[i] == target and i not in selected
            ]
            if matching:
                selected.append(random.choice(matching))
                continue
            # No exact match: use the unused row that shares the most features
            # with the target (the first such row on ties).
            unused = [i for i in indices if i not in selected]
            if unused:
                selected.append(max(
                    unused,
                    key=lambda i, t=target: sum(
                        a == b for a, b in zip(demographics[i], t)
                    ),
                ))
    else:
        # Too few distinct values: greedily add the row that differs on the
        # most features from those already picked (first one wins ties).
        remaining = set(indices)
        while len(selected) < 3 and remaining:
            best = max(
                remaining,
                key=lambda i: sum(
                    a != b
                    for s in selected
                    for a, b in zip(demographics[i], demographics[s])
                ),
            )
            selected.append(best)
            remaining.remove(best)

    # Top up with random unused rows if fewer than three were found.
    while len(selected) < 3:
        unused = [i for i in indices if i not in selected]
        if not unused:
            break
        selected.append(random.choice(unused))

    return selected[:3]


# Cache every row's (age, gender, race) once; fetching dataset rows one by one
# decodes the whole example each time and was repeated many times per question.
_train_split = dataset_im['train']
_demographics = list(zip(
    _train_split['age'], _train_split['gender'], _train_split['race']
))

# Selected row indices for each question id (empty list when it has no rows).
result_indices = {}
# NOTE(review): ids 0..95 are hard-coded — confirm this covers every
# decision_question_id present in the dataset.
for qid in range(96):
    result_indices[qid] = _select_diverse_indices(
        grouped_data.get(qid, []), _demographics
    )

# Report the elapsed time of the index-selection phase.
generation_time = time.time() - start_generation_time
print(f"List generation time: {generation_time:.2f} seconds")

# Tally question ids that received at least one selected row, plus the total.
non_empty_count = sum(map(bool, result_indices.values()))
print(f"Number of questions with non-empty indices: {non_empty_count}")
print(f"Total questions processed: {len(result_indices)}")

# Build one record per question: its filled templates and the demographics of
# the rows they came from.
print("Generating prompts...")
all_prompts = []

pending = [(qid, chosen) for qid, chosen in result_indices.items() if chosen]
for qid, chosen in tqdm(pending):
    # Questions that could not yield three rows are left out of the output.
    if len(chosen) < 3:
        continue

    rows = [dataset_im['train'][row_idx] for row_idx in chosen]
    all_prompts.append({
        'question_id': qid,
        'templates': [row['filled_template'] for row in rows],
        'demographic_info': [
            {'age': row['age'], 'gender': row['gender'], 'race': row['race']}
            for row in rows
        ],
    })

# Save the prompts as JSON inside the "prompts" output directory, creating the
# directory if it does not exist yet.
output_dir = "prompts"
os.makedirs(output_dir, exist_ok=True)

output_filename = os.path.join(output_dir, f"{type}_prompts.json")
with open(output_filename, 'w', encoding='utf-8') as f:
    # ensure_ascii=False keeps any non-ASCII template text human-readable.
    json.dump(all_prompts, f, ensure_ascii=False, indent=2)

print(f"Generated prompts for {len(all_prompts)} questions")
print(f"Prompts saved to {output_filename}")
