import os
import random

def read_qa_pairs(filepath):
    """Read blank-line-separated QA pairs from a UTF-8 text file.

    A pair is a run of consecutive non-blank lines. Any line that is empty
    or contains only whitespace acts as a separator, so a stray space or
    tab on an otherwise blank line no longer glues two pairs together.

    Args:
        filepath: Path of the text file to read.

    Returns:
        A list of pair strings in file order. Each pair keeps its internal
        newlines and has leading/trailing whitespace stripped. An empty or
        whitespace-only file yields an empty list.

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    pairs = []
    current = []
    # Split on '\n' only (not str.splitlines) so that unusual line-boundary
    # characters inside a pair are left untouched.
    for line in content.split('\n'):
        if line.strip():
            current.append(line)
        elif current:
            # A blank line closes the pair collected so far.
            pairs.append('\n'.join(current).strip())
            current = []
    if current:
        pairs.append('\n'.join(current).strip())
    return pairs

def mix_samples(clean_size=100, poison_size=300, *, base_path="/", seed=None):
    """Write a shuffled mixture of clean and poisoned QA samples to disk.

    Reads the poisoned pairs from
    ``labeled_backdoor/csqa/csqa_labeled_shift_correct.txt`` and the clean
    pairs from ``clean_data/reasoning_output_csqa_correct.txt`` (both under
    ``base_path``), draws a random sample without replacement from each,
    shuffles the combined list, and writes it, pairs separated by a blank
    line, to ``grpo_meterial/csqa/mixed_csqa_data_<clean>+<poison>.txt``.

    Args:
        clean_size: Number of clean pairs requested. Clamped to the number
            of clean pairs available.
        poison_size: Number of poisoned pairs requested. Clamped to the
            number of poisoned pairs available.
        base_path: Root directory that the input and output relative paths
            are joined onto.
        seed: Optional seed for a private random generator, making the
            mixture reproducible. When ``None`` the module-level ``random``
            functions are used, as before.

    Raises:
        OSError: If an input file cannot be read or the output cannot be
            written.
        ValueError: If a requested size is negative (raised by
            ``random.sample``).
    """
    poison_path = os.path.join(base_path, "labeled_backdoor/csqa/csqa_labeled_shift_correct.txt")
    poison_pairs = read_qa_pairs(poison_path)

    clean_path = os.path.join(base_path, "clean_data/reasoning_output_csqa_correct.txt")
    clean_pairs = read_qa_pairs(clean_path)

    # Name the output after the *requested* sizes (captured before clamping)
    # so the default 100/300 call keeps writing to the same file as before,
    # while other sizes no longer overwrite a file labelled "100+300".
    output_path = os.path.join(
        base_path,
        f"grpo_meterial/csqa/mixed_csqa_data_{clean_size}+{poison_size}.txt",
    )

    # Never ask random.sample for more items than exist.
    clean_size = min(clean_size, len(clean_pairs))
    poison_size = min(poison_size, len(poison_pairs))

    # A private generator keeps a seeded run from disturbing global state;
    # without a seed, fall back to the shared module-level generator.
    rng = random if seed is None else random.Random(seed)

    clean_sample = rng.sample(clean_pairs, clean_size)
    poison_sample = rng.sample(poison_pairs, poison_size)

    all_samples = clean_sample + poison_sample
    rng.shuffle(all_samples)

    # Create the destination directory up front so a missing folder does not
    # fail the run after all sampling work is done.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write('\n\n'.join(all_samples))

    print("Mixing completed!")
    print(f"Clean samples count: {clean_size}")
    print(f"Poisoned samples count: {poison_size}")
    print(f"Total samples count: {len(all_samples)}")

if __name__ == "__main__":
    # Script entry point: build the mixture with 100 clean and 300 poisoned
    # samples when the file is executed directly.
    mix_samples(clean_size=100, poison_size=300)