
import argparse
import pickle
import json
import os
from IPython import embed
from collections import Counter
from sklearn.model_selection import train_test_split
import re

def compute_statistics(data, name):
    """
    Print a short label summary for one dataset split.

    Each element of ``data`` is indexed with ``'label'``; labels equal to 1
    are counted as positive and labels equal to 0 as negative. The report
    shows the total sample count, both class counts, and the
    positive-to-negative ratio (infinity when there are no negatives).

    Args:
        data: Iterable of mappings that each carry a ``'label'`` entry.
        name: Human-readable name of the split, used in the report header.
    """
    tally = Counter(sample['label'] for sample in data)

    n_total = sum(tally.values())
    n_pos = tally[1]  # Counter yields 0 for labels that never occurred
    n_neg = tally[0]

    if n_neg > 0:
        pos_neg_ratio = n_pos / n_neg
    else:
        # No negatives: report an infinite ratio instead of dividing by zero.
        pos_neg_ratio = float('inf')

    report_lines = [
        f"\n📊 Statistics for {name}:",
        f"Total samples: {n_total}",
        "Label distribution:",
        f"  ✅ Positive: {n_pos}",
        f"  ❌ Negative: {n_neg}",
        f"  🔹 Positive-to-Negative Ratio: {pos_neg_ratio:.4f}",
    ]
    print("\n".join(report_lines))

def partition_sort_key(path):
    """
    Return the numeric ID used to order a partition file.

    The ID is the first run of digits in the base name of ``path``
    (directory components are ignored).

    Raises:
        ValueError: If the base name contains no digits.
    """
    found = re.search(r'\d+', os.path.basename(path))
    if found is None:
        raise ValueError(f"Partition file name contains no numeric ID: {path}")
    return int(found.group())


def list_partition_files(directory):
    """
    Return the ``partition*.pkl`` files in ``directory``, sorted by numeric ID.

    Sorting is numeric rather than lexicographic, so ``partition2.pkl`` comes
    before ``partition10.pkl``.
    """
    candidates = [
        os.path.join(directory, filename)
        for filename in os.listdir(directory)
        if filename.startswith("partition") and filename.endswith(".pkl")
    ]
    return sorted(candidates, key=partition_sort_key)


def save_split(directory, filename, data):
    """
    Pickle ``data`` to ``filename`` inside ``directory``.

    The path is built with ``os.path.join`` so the result is correct whether
    or not ``directory`` ends with a path separator.
    """
    with open(os.path.join(directory, filename), "wb") as out_file:
        pickle.dump(data, out_file)


def main():
    """
    Merge all partition pickles and write stratified train/valid/test splits.

    Reads every ``partition*.pkl`` in ``--ids_file_path`` in numeric order,
    concatenates their contents, splits them 80% / 10% / 10% (stratified on
    each record's ``'label'``, fixed random seed), prints statistics for each
    split, and writes ``train.pkl``, ``valid.pkl`` and ``test.pkl`` back into
    the same directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ids_file_path", type=str, default="/data/cancer_pancreatic_6_month/")
    args = parser.parse_args()

    file_paths = list_partition_files(args.ids_file_path)

    # Print sorted file paths for verification
    print("Sorted file paths:")
    for path in file_paths:
        print(path)

    if not file_paths:
        # Fail early with a clear message instead of an obscure error from
        # train_test_split on an empty dataset.
        raise SystemExit(f"No partition*.pkl files found in {args.ids_file_path}")

    total_ids = []
    # Load each pickle file in sorted order and extend total_ids
    for file_path in file_paths:
        # NOTE: pickle.load can execute arbitrary code; only run this on
        # partition files from a trusted source.
        with open(file_path, 'rb') as file:
            partition_ids = pickle.load(file)
        total_ids.extend(partition_ids)

    labels = [d['label'] for d in total_ids]

    # Optional dump of the full, unsplit dataset (disabled):
    # save_split(args.ids_file_path, "zero_shot.pkl", total_ids)

    # First split: 80% training, 20% held out for validation + testing.
    train_data, test_valid_data, _, test_valid_labels = train_test_split(
        total_ids, labels, test_size=0.2, stratify=labels, random_state=42)

    # Now, split the held-out data into 50% validation and 50% testing
    # (resulting in 10% of the full dataset each).
    valid_data, test_data, _, _ = train_test_split(
        test_valid_data, test_valid_labels, test_size=0.5, stratify=test_valid_labels, random_state=42)

    # Check statistics for each set
    compute_statistics(train_data, "Training Set")
    compute_statistics(valid_data, "Validation Set")
    compute_statistics(test_data, "Test Set")

    save_split(args.ids_file_path, "train.pkl", train_data)
    save_split(args.ids_file_path, "valid.pkl", valid_data)
    save_split(args.ids_file_path, "test.pkl", test_data)


if __name__ == "__main__":
    main()

