import pandas as pd
import json
from sklearn.model_selection import train_test_split

# ============================================================
# 1. Read CSV and create separate DataFrames for harmful/benign
# ============================================================
csv_file = 'final_agent_tasks_w_harm_idx.csv'
df = pd.read_csv(csv_file)

# Create a DataFrame for harmful tasks.
# The harm index text is normalised element-wise: lower-cased, then the
# literal "your answer: " marker is removed wherever it occurs in the cell.
# The `.str` accessor is required here: a Series has no `.lower()` method,
# and plain `Series.replace` only matches whole cell values, not substrings.
# NOTE(review): assumes 'harmful_action_idx' is read as text (object/string
# dtype); `.str` raises AttributeError on a purely numeric column — confirm.
harmful_df = pd.DataFrame({
    'harm_index': (
        df['harmful_action_idx']
        .str.lower()
        .str.replace("your answer: ", '', regex=False)
    ),
    'original_task': df['Harmful-Original Task'],
    'decomposition': df['Harmful-Subtasks'],
    'source_label': df['label'],
})
harmful_df['label'] = 'harmful'

# Create a DataFrame for benign tasks.
# Benign rows have no harm index, so the column is filled with None.
benign_df = pd.DataFrame({
    'harm_index': None,
    'original_task': df['Benign-Original Task'],
    'decomposition': df['Benign-Subtasks'],
    'source_label': df['label'],
})
benign_df['label'] = 'benign'

# ==============================================
# 2. Add unique IDs to each dataset
# ==============================================
# IDs are the row position zero-padded to four digits, prefixed by the class.
harmful_df['id'] = [f"harmful_{i:04d}" for i in range(len(harmful_df))]
benign_df['id'] = [f"benign_{i:04d}" for i in range(len(benign_df))]

print(f"Loaded {len(harmful_df)} harmful tasks and {len(benign_df)} benign tasks")

# ==============================================
# 3. Define a split function
# ==============================================
def split_data(data, train_size=0.7, val_size=0.15, test_size=0.15, random_state=42):
    """
    Partition 'data' into train, validation and test subsets.

    The test subset is carved off first using 'test_size'. What remains is
    then divided between train and validation, with validation receiving
    val_size / (train_size + val_size) of it. Both stages use the same
    'random_state'.

    Returns a (train, val, test) tuple.
    """
    def _holdout(frame, fraction):
        # Single place where the seed is applied to each stage of the split.
        return train_test_split(
            frame,
            test_size=fraction,
            random_state=random_state,
        )

    # Stage 1: hold out the test portion from the full dataset.
    remainder, test_part = _holdout(data, test_size)

    # Stage 2: 'val_size' is given relative to the whole dataset, so rescale
    # it to a fraction of what is left after the test portion was removed.
    val_share = val_size / (train_size + val_size)
    train_part, val_part = _holdout(remainder, val_share)

    return train_part, val_part, test_part

# ==============================================
# 4. Split harmful and benign separately
# ==============================================
harmful_train, harmful_val, harmful_test = split_data(harmful_df)
benign_train, benign_val, benign_test = split_data(benign_df)

# ==============================================
# 5. Combine harmful and benign for each split
# ==============================================
# Each split is built the same way: stack harmful rows on top of benign rows,
# shuffle with a fixed seed, then renumber the index from zero.
train_df = (
    pd.concat([harmful_train, benign_train], ignore_index=True)
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)
val_df = (
    pd.concat([harmful_val, benign_val], ignore_index=True)
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)
test_df = (
    pd.concat([harmful_test, benign_test], ignore_index=True)
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)

# ==============================================
# 6. Print dataset statistics
# ==============================================
# (display name, combined split, harmful part, benign part) for each split
_split_report = (
    ("Train set", train_df, harmful_train, benign_train),
    ("Validation set", val_df, harmful_val, benign_val),
    ("Test set", test_df, harmful_test, benign_test),
)
# A blank line precedes only the first row of each percentage section.
_section_leads = ("\n", "", "")

print("\nDataset Statistics:")
print(f"Total harmful tasks: {len(harmful_df)}")
print(f"Total benign tasks: {len(benign_df)}")
for _name, _combined, _harmful, _benign in _split_report:
    print(f"{_name}: {len(_combined)} samples "
          f"({len(_harmful)} harmful, {len(_benign)} benign)")

# Share of all samples that landed in each split
total = len(train_df) + len(val_df) + len(test_df)
for _lead, (_name, _combined, _harmful, _benign) in zip(_section_leads, _split_report):
    print(f"{_lead}{_name}: {len(_combined)/total:.1%}")

# Class balance within each split
for _lead, (_name, _combined, _harmful, _benign) in zip(_section_leads, _split_report):
    print(f"{_lead}{_name} harmful: {len(_harmful)/len(_combined):.1%}")

# ==============================================
# 7. Save splits to JSON
# ==============================================
def save_to_json(df, filename):
    """
    Saves a DataFrame to JSON (list of records).

    Each row becomes one dict keyed by column name. Missing float values
    (NaN) are written as JSON null, because the bare NaN token that
    json.dump would otherwise emit is not valid JSON. Parent directories of
    'filename' are created if they do not exist yet.

    Args:
        df: DataFrame to serialise.
        filename: Destination path (str or path-like) of the JSON file.
    """
    import math
    from pathlib import Path

    def _jsonable(value):
        # NaN is the only value rewritten; everything else passes through.
        if isinstance(value, float) and math.isnan(value):
            return None
        return value

    records = [
        {key: _jsonable(value) for key, value in row.items()}
        for row in df.to_dict(orient='records')
    ]

    # Create the output directory up front so open() does not fail with
    # FileNotFoundError on a fresh checkout.
    Path(filename).parent.mkdir(parents=True, exist_ok=True)

    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(records, f, ensure_ascii=False, indent=2)
    print(f"Saved {len(df)} records to {filename}")

# Write each split to its own JSON file.
for _frame, _destination in (
    (train_df, '../../data/agent_tasks/train_data.json'),
    (val_df, '../../data/agent_tasks/val_data.json'),
    (test_df, '../../data/agent_tasks/test_data.json'),
):
    save_to_json(_frame, _destination)

print("\nData successfully split and saved to train_data.json, val_data.json, and test_data.json")
