from datasets import load_dataset
import numpy as np
import argparse
from tqdm import tqdm
import logging
import json
from datetime import datetime
from pathlib import Path

# Root-logger setup for the script: INFO and above, timestamped as HH:MM:SS.
logging.basicConfig(
    datefmt='%H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

def bucket_by_unique_answers_uniform(dataset_name, num_buckets, num_rows=None):
    """
    Bucket questions into uniform-sized clusters based on unique answer counts.

    Lower bucket numbers (starting at 0) indicate harder questions (more unique
    answers); higher bucket numbers indicate easier questions (fewer unique
    answers). Questions are ranked by unique-answer count (descending) and the
    ranking is cut into ``num_buckets`` consecutive slices whose sizes differ by
    at most one, so questions with the same count may be split across adjacent
    buckets. Within a tie, the question that comes first in the selection
    (lower dataset index) lands in the lower bucket.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``; its
            ``"data"`` split must provide an ``extracted_answers`` column whose
            entries are iterables of answers (``None`` entries are ignored).
        num_buckets: Number of buckets to create; must be >= 1.
        num_rows: If given and smaller than the dataset, analyze a random
            sample of this many rows (fixed seed 42); otherwise analyze all
            rows. Must be >= 1 when given.

    Returns:
        dict mapping dataset row index (int) to bucket id (int). The same
        mapping is written to ``cluster_mapping/<dataset>/<timestamp>_uniform.json``.

    Raises:
        ValueError: If ``num_buckets`` < 1, ``num_rows`` < 1, or the dataset's
            ``"data"`` split is empty.
    """
    if num_buckets < 1:
        raise ValueError(f"num_buckets must be >= 1, got {num_buckets}")
    if num_rows is not None and num_rows < 1:
        raise ValueError(f"num_rows must be >= 1 when given, got {num_rows}")

    # Get dataset name without path
    simple_dataset_name = dataset_name.split('/')[-1]

    logging.info(f"Loading dataset: {dataset_name}")
    dataset = load_dataset(dataset_name)
    data = dataset["data"]

    # Handle row selection
    total_rows = len(data)
    if total_rows == 0:
        # Nothing to bucket; the percentage statistics below would divide by zero.
        raise ValueError(f"Dataset {dataset_name} has no rows in its 'data' split")
    if num_rows is not None and num_rows < total_rows:
        logging.info(f"Randomly sampling {num_rows} rows from {total_rows} total rows")
        # A private RandomState(42) draws exactly what np.random.seed(42) +
        # np.random.choice would, without resetting the process-wide RNG.
        rng = np.random.RandomState(42)
        # Sort for consistent ordering
        selected_indices = np.sort(rng.choice(total_rows, size=num_rows, replace=False))
    else:
        num_rows = total_rows
        selected_indices = np.arange(total_rows)
        logging.info(f"Using all {total_rows} rows")

    # Count unique answers (excluding None) for each selected question. The
    # column is fetched once: indexing data['extracted_answers'] inside the
    # loop rebuilds the whole column on every iteration (quadratic) when
    # column access returns a materialized list.
    logging.info("Counting unique answers per question...")
    answers_column = data['extracted_answers']
    unique_answer_counts = np.array(
        [
            len({x for x in answers_column[int(i)] if x is not None})
            for i in tqdm(selected_indices, desc="Processing questions")
        ],
        dtype=int,
    )

    # Rank positions hardest-first (more unique answers = lower bucket). A
    # stable sort makes the tie order well-defined (selection order) instead
    # of depending on numpy's default unstable sort.
    sorted_positions = np.argsort(-unique_answer_counts, kind='stable')

    # Calculate size of each bucket
    bucket_size = len(unique_answer_counts) // num_buckets
    remaining = len(unique_answer_counts) % num_buckets

    # Distribute questions into buckets of uniform size; the first
    # 'remaining' buckets each take one extra question.
    buckets = np.zeros(len(unique_answer_counts), dtype=int)
    current_pos = 0
    for bucket in range(num_buckets):
        bucket_end = current_pos + bucket_size + (1 if bucket < remaining else 0)
        buckets[sorted_positions[current_pos:bucket_end]] = bucket
        current_pos = bucket_end

    # Map original dataset row index -> bucket id
    predictions_dict = {
        int(row): int(bucket)
        for row, bucket in zip(selected_indices, buckets)
    }

    # Print statistics
    print(f"\nDataset: {dataset_name}")
    print(f"Analyzed {num_rows} out of {total_rows} total rows ({(num_rows/total_rows)*100:.1f}%)")
    print(f"Number of buckets: {num_buckets}")
    print("\nBucket distribution:")
    for i in range(num_buckets):
        members = unique_answer_counts[buckets == i]
        count = members.size
        min_answers = members.min() if count > 0 else 0
        max_answers = members.max() if count > 0 else 0
        print(f"Bucket {i}: {count} questions ({count/len(buckets)*100:.1f}%)")
        print(f"  Unique answer range: {min_answers} - {max_answers}")

    # Save predictions to JSON
    logging.info("Saving predictions to JSON...")
    output_dir = Path('cluster_mapping') / simple_dataset_name
    output_dir.mkdir(parents=True, exist_ok=True)

    current_date = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    output_file = output_dir / f"{current_date}_uniform.json"

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(predictions_dict, f, indent=2)

    logging.info(f"Saved predictions to {output_file}")
    print(f"\nPredictions saved to: {output_file.absolute()}\n")

    return predictions_dict

def bucket_by_unique_answers_preserve_ties(dataset_name, num_buckets, num_rows=None):
    """
    Bucket questions based on unique answer counts, preserving ties.

    Questions with the same unique answer count are always in the same bucket.
    The distinct count *values* (not the questions) are split into
    ``num_buckets`` contiguous groups, so bucket sizes may be non-uniform, and
    trailing buckets are empty when there are fewer distinct values than
    buckets. Lower bucket numbers (starting at 0) indicate harder questions
    (more unique answers). Higher bucket numbers indicate easier questions
    (fewer unique answers).

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``; its
            ``"data"`` split must provide an ``extracted_answers`` column whose
            entries are iterables of answers (``None`` entries are ignored).
        num_buckets: Number of buckets to create; must be >= 1.
        num_rows: If given and smaller than the dataset, analyze a random
            sample of this many rows (fixed seed 42); otherwise analyze all
            rows. Must be >= 1 when given.

    Returns:
        dict mapping dataset row index (int) to bucket id (int). The same
        mapping is written to
        ``cluster_mapping/<dataset>/<timestamp>_preserve_ties.json``.

    Raises:
        ValueError: If ``num_buckets`` < 1, ``num_rows`` < 1, or the dataset's
            ``"data"`` split is empty.
    """
    if num_buckets < 1:
        raise ValueError(f"num_buckets must be >= 1, got {num_buckets}")
    if num_rows is not None and num_rows < 1:
        raise ValueError(f"num_rows must be >= 1 when given, got {num_rows}")

    # Get dataset name without path
    simple_dataset_name = dataset_name.split('/')[-1]

    logging.info(f"Loading dataset: {dataset_name}")
    dataset = load_dataset(dataset_name)
    data = dataset["data"]

    # Handle row selection
    total_rows = len(data)
    if total_rows == 0:
        # Nothing to bucket; the percentage statistics below would divide by zero.
        raise ValueError(f"Dataset {dataset_name} has no rows in its 'data' split")
    if num_rows is not None and num_rows < total_rows:
        logging.info(f"Randomly sampling {num_rows} rows from {total_rows} total rows")
        # A private RandomState(42) draws exactly what np.random.seed(42) +
        # np.random.choice would, without resetting the process-wide RNG.
        rng = np.random.RandomState(42)
        selected_indices = np.sort(rng.choice(total_rows, size=num_rows, replace=False))
    else:
        num_rows = total_rows
        selected_indices = np.arange(total_rows)
        logging.info(f"Using all {total_rows} rows")

    # Count unique answers (excluding None) for each selected question. The
    # column is fetched once: indexing data['extracted_answers'] inside the
    # loop rebuilds the whole column on every iteration (quadratic) when
    # column access returns a materialized list.
    logging.info("Counting unique answers per question...")
    answers_column = data['extracted_answers']
    unique_answer_counts = np.array(
        [
            len({x for x in answers_column[int(i)] if x is not None})
            for i in tqdm(selected_indices, desc="Processing questions")
        ],
        dtype=int,
    )

    # Distinct counts, largest (hardest) first; np.unique already sorts
    # ascending, so reversing is enough.
    distinct_counts = np.unique(unique_answer_counts)[::-1]

    # Split the distinct values into num_buckets contiguous groups.
    bucket_boundaries = np.array_split(distinct_counts, num_buckets)

    # Every distinct count lives in exactly one group, so build the
    # count -> bucket lookup once instead of scanning the groups per question.
    count_to_bucket = {
        int(count): bucket
        for bucket, boundary_group in enumerate(bucket_boundaries)
        for count in boundary_group
    }
    buckets = np.array(
        [count_to_bucket[int(count)] for count in unique_answer_counts],
        dtype=int,
    )

    # Map original dataset row index -> bucket id
    predictions_dict = {
        int(row): int(bucket)
        for row, bucket in zip(selected_indices, buckets)
    }

    # Print statistics
    print(f"\nDataset: {dataset_name}")
    print(f"Analyzed {num_rows} out of {total_rows} total rows ({(num_rows/total_rows)*100:.1f}%)")
    print(f"Number of buckets: {num_buckets}")
    print("\nBucket distribution:")
    for i in range(num_buckets):
        members = unique_answer_counts[buckets == i]
        count = members.size
        min_answers = members.min() if count > 0 else 0
        max_answers = members.max() if count > 0 else 0
        print(f"Bucket {i}: {count} questions ({count/len(buckets)*100:.1f}%)")
        print(f"  Unique answer range: {min_answers} - {max_answers}")

    # Save predictions to JSON
    logging.info("Saving predictions to JSON...")
    output_dir = Path('cluster_mapping') / simple_dataset_name
    output_dir.mkdir(parents=True, exist_ok=True)

    current_date = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    output_file = output_dir / f"{current_date}_preserve_ties.json"

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(predictions_dict, f, indent=2)

    logging.info(f"Saved predictions to {output_file}")
    print(f"\nPredictions saved to: {output_file.absolute()}\n")

    return predictions_dict

def bucket_by_unique_answers_uniform_with_ties(dataset_name, num_buckets, num_rows=None):
    """
    Bucket questions into roughly uniform-sized clusters while preserving ties.

    Hybrid approach: questions with the same unique answer count always share
    a bucket, and whole count-groups are placed greedily, hardest first, so
    that bucket sizes stay close to ``len(selection) / num_buckets``:

    * a group that would push the current bucket past 1.5x the target opens a
      new bucket, unless the current bucket is still under 0.5x the target (in
      which case the group is absorbed);
    * after every placement, a bucket that has reached the target is closed;
    * the last bucket takes everything that remains, and trailing buckets may
      end up empty.

    Lower bucket numbers (starting at 0) indicate harder questions (more unique
    answers). Higher bucket numbers indicate easier questions (fewer unique
    answers).

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``; its
            ``"data"`` split must provide an ``extracted_answers`` column whose
            entries are iterables of answers (``None`` entries are ignored).
        num_buckets: Number of buckets to create; must be >= 1.
        num_rows: If given and smaller than the dataset, analyze a random
            sample of this many rows (fixed seed 42); otherwise analyze all
            rows. Must be >= 1 when given.

    Returns:
        dict mapping dataset row index (int) to bucket id (int). The same
        mapping is written to
        ``cluster_mapping/<dataset>/<timestamp>_uniform_with_ties.json``.

    Raises:
        ValueError: If ``num_buckets`` < 1, ``num_rows`` < 1, or the dataset's
            ``"data"`` split is empty.
    """
    if num_buckets < 1:
        raise ValueError(f"num_buckets must be >= 1, got {num_buckets}")
    if num_rows is not None and num_rows < 1:
        raise ValueError(f"num_rows must be >= 1 when given, got {num_rows}")

    # Get dataset name without path
    simple_dataset_name = dataset_name.split('/')[-1]

    logging.info(f"Loading dataset: {dataset_name}")
    dataset = load_dataset(dataset_name)
    data = dataset["data"]

    # Handle row selection
    total_rows = len(data)
    if total_rows == 0:
        # Nothing to bucket; the percentage statistics below would divide by zero.
        raise ValueError(f"Dataset {dataset_name} has no rows in its 'data' split")
    if num_rows is not None and num_rows < total_rows:
        logging.info(f"Randomly sampling {num_rows} rows from {total_rows} total rows")
        # A private RandomState(42) draws exactly what np.random.seed(42) +
        # np.random.choice would, without resetting the process-wide RNG.
        rng = np.random.RandomState(42)
        selected_indices = np.sort(rng.choice(total_rows, size=num_rows, replace=False))
    else:
        num_rows = total_rows
        selected_indices = np.arange(total_rows)
        logging.info(f"Using all {total_rows} rows")

    # Count unique answers (excluding None) for each selected question. The
    # column is fetched once: indexing data['extracted_answers'] inside the
    # loop rebuilds the whole column on every iteration (quadratic) when
    # column access returns a materialized list.
    logging.info("Counting unique answers per question...")
    answers_column = data['extracted_answers']
    unique_answer_counts = np.array(
        [
            len({x for x in answers_column[int(i)] if x is not None})
            for i in tqdm(selected_indices, desc="Processing questions")
        ],
        dtype=int,
    )

    # Group question positions by their unique answer count, hardest first
    unique_counts = np.unique(unique_answer_counts)[::-1]
    count_to_questions = {
        count: np.where(unique_answer_counts == count)[0] for count in unique_counts
    }

    target_bucket_size = len(unique_answer_counts) / num_buckets
    last_bucket = num_buckets - 1

    buckets = np.zeros(len(unique_answer_counts), dtype=int)
    current_bucket = 0
    current_bucket_size = 0

    for count in unique_counts:
        questions = count_to_questions[count]
        questions_in_group = len(questions)

        # This group would overfill the current bucket: move on to a new
        # bucket, unless the current one is still very small (then absorb).
        if (current_bucket < last_bucket
                and current_bucket_size + questions_in_group > target_bucket_size * 1.5
                and current_bucket_size >= target_bucket_size * 0.5):
            current_bucket += 1
            current_bucket_size = 0

        buckets[questions] = current_bucket
        current_bucket_size += questions_in_group

        # Close the bucket once it is full enough. Checked after every
        # placement, including when the group just opened a fresh bucket,
        # so a bucket that already met the target never keeps growing.
        if current_bucket_size >= target_bucket_size and current_bucket < last_bucket:
            current_bucket += 1
            current_bucket_size = 0

    # Map original dataset row index -> bucket id
    predictions_dict = {
        int(row): int(bucket)
        for row, bucket in zip(selected_indices, buckets)
    }

    # Print statistics
    print(f"\nDataset: {dataset_name}")
    print(f"Analyzed {num_rows} out of {total_rows} total rows ({(num_rows/total_rows)*100:.1f}%)")
    print(f"Number of buckets: {num_buckets}")
    print("\nBucket distribution:")
    for i in range(num_buckets):
        members = unique_answer_counts[buckets == i]
        count = members.size
        min_answers = members.min() if count > 0 else 0
        max_answers = members.max() if count > 0 else 0
        print(f"Bucket {i}: {count} questions ({count/len(buckets)*100:.1f}%)")
        print(f"  Unique answer range: {min_answers} - {max_answers}")

    # Save predictions to JSON
    logging.info("Saving predictions to JSON...")
    output_dir = Path('cluster_mapping') / simple_dataset_name
    output_dir.mkdir(parents=True, exist_ok=True)

    current_date = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    output_file = output_dir / f"{current_date}_uniform_with_ties.json"

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(predictions_dict, f, indent=2)

    logging.info(f"Saved predictions to {output_file}")
    print(f"\nPredictions saved to: {output_file.absolute()}\n")

    return predictions_dict

def bucket_by_unique_answers(dataset_name, num_buckets, num_rows=None, preserve_ties=False, uniform_with_ties=False):
    """
    Main entry point for bucketing questions based on unique answer counts.

    Dispatches to one of three strategies: ``uniform_with_ties`` takes
    precedence over ``preserve_ties``; with neither flag set, strict uniform
    bucketing is used.

    Args:
        dataset_name: Name of the dataset to analyze
        num_buckets: Number of buckets to create
        num_rows: Number of rows to analyze (if None, analyze all)
        preserve_ties: If True, keep questions with same unique answer count in same bucket
        uniform_with_ties: If True, try to maintain uniform bucket sizes while preserving ties

    Returns:
        Whatever the selected strategy function returns.
    """
    if uniform_with_ties and preserve_ties:
        # Both flags select a strategy; make the silent precedence visible.
        logging.warning(
            "Both preserve_ties and uniform_with_ties were set; "
            "using uniform_with_ties (preserve_ties is ignored)."
        )
    if uniform_with_ties:
        return bucket_by_unique_answers_uniform_with_ties(dataset_name, num_buckets, num_rows)
    if preserve_ties:
        return bucket_by_unique_answers_preserve_ties(dataset_name, num_buckets, num_rows)
    return bucket_by_unique_answers_uniform(dataset_name, num_buckets, num_rows)

def main():
    """Parse command-line arguments and run the selected bucketing strategy."""
    parser = argparse.ArgumentParser(description='Bucket questions based on unique answer counts.')
    parser.add_argument('--dataset', type=str, default="anonymous_research/GPQA_with_Llama_3.1_70B_Instruct",
                      help='The dataset to analyze')
    parser.add_argument('--buckets', type=int, default=10,
                      help='Number of buckets (default: 10)')
    parser.add_argument('--rows', type=int, default=None,
                      help='Number of rows to analyze (default: all rows). If specified, rows will be randomly sampled.')
    # The two tie-handling strategies are alternatives; let argparse reject
    # passing both instead of silently preferring one of them.
    strategy = parser.add_mutually_exclusive_group()
    strategy.add_argument('--preserve_ties', action='store_true',
                      help='If set, questions with the same unique answer count will be kept in the same bucket')
    strategy.add_argument('--uniform_with_ties', action='store_true',
                      help='If set, try to maintain uniform bucket sizes while preserving ties')

    args = parser.parse_args()
    # Reject nonsensical sizes up front with a usage message rather than a
    # ZeroDivisionError / numpy error after the dataset has been downloaded.
    if args.buckets < 1:
        parser.error(f"--buckets must be >= 1, got {args.buckets}")
    if args.rows is not None and args.rows < 1:
        parser.error(f"--rows must be >= 1, got {args.rows}")
    bucket_by_unique_answers(args.dataset, args.buckets, args.rows, args.preserve_ties, args.uniform_with_ties)

if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with status 0.
    raise SystemExit(main())