import pandas as pd
import numpy as np
import json
from sklearn.model_selection import train_test_split

# Workbook holding the decomposed queries, one sheet per class.
file_path = 'All_Decomposed_Queries.xlsx'
harmful_df, benign_df = (
    pd.read_excel(file_path, sheet_name=sheet) for sheet in ('Harmful', 'Benign')
)

print(f"Loaded {len(harmful_df)} harmful queries and {len(benign_df)} benign queries")

# Tag every row with a zero-padded, class-prefixed id and its class label
# before splitting, so each record stays identifiable after shuffling.
for frame, prefix, tag in (
    (harmful_df, 'harmful_', 'harmful'),
    (benign_df, 'benign_', 'benign'),
):
    frame['id'] = [prefix + str(n).zfill(4) for n in range(len(frame))]
    frame['label'] = tag

def split_data(data, train_size=0.7, val_size=0.15, test_size=0.15, random_state=42):
    """Split ``data`` into train, validation and test subsets.

    The test subset is carved off first using ``test_size``. The remainder is
    then divided into train and validation; because that second split works
    on the remainder only, the validation fraction is rescaled to
    ``val_size / (train_size + val_size)``.

    Args:
        data: Collection to split; passed straight to ``train_test_split``.
        train_size: Fraction intended for training. Only used to rescale
            ``val_size`` for the second split.
        val_size: Fraction intended for validation.
        test_size: Fraction handed to the first split as the test share.
        random_state: Seed used for both splits.

    Returns:
        A ``(train, val, test)`` tuple.
    """
    # Stage 1: hold out the test portion.
    remainder, test = train_test_split(
        data, test_size=test_size, random_state=random_state
    )

    # Stage 2: validation share expressed relative to what is left.
    val_fraction = val_size / (train_size + val_size)
    train, val = train_test_split(
        remainder, test_size=val_fraction, random_state=random_state
    )

    return train, val, test

# Split each category on its own so that every split receives a share of
# both classes.
harmful_train, harmful_val, harmful_test = split_data(harmful_df)
benign_train, benign_val, benign_test = split_data(benign_df)

# Combine harmful and benign for each split
train_df = pd.concat([harmful_train, benign_train])
val_df = pd.concat([harmful_val, benign_val])
test_df = pd.concat([harmful_test, benign_test])

# Shuffle each set so the two classes are interleaved; a fixed seed keeps the
# row order reproducible between runs.
train_df = train_df.sample(frac=1, random_state=42).reset_index(drop=True)
val_df = val_df.sample(frac=1, random_state=42).reset_index(drop=True)
test_df = test_df.sample(frac=1, random_state=42).reset_index(drop=True)

# Print dataset statistics
print("\nDataset Statistics:")
print(f"Total samples: {len(harmful_df) + len(benign_df)}")
print(f"Train set: {len(train_df)} samples ({len(harmful_train)} harmful, {len(benign_train)} benign)")
print(f"Validation set: {len(val_df)} samples ({len(harmful_val)} harmful, {len(benign_val)} benign)")
print(f"Test set: {len(test_df)} samples ({len(harmful_test)} harmful, {len(benign_test)} benign)")

# Calculate and print percentages
total = len(train_df) + len(val_df) + len(test_df)
print(f"\nTrain set: {len(train_df)/total:.1%}")
print(f"Validation set: {len(val_df)/total:.1%}")
print(f"Test set: {len(test_df)/total:.1%}")

# Calculate class balance in each set
print(f"\nTrain set harmful: {len(harmful_train)/len(train_df):.1%}")
print(f"Validation set harmful: {len(harmful_val)/len(val_df):.1%}")
print(f"Test set harmful: {len(harmful_test)/len(test_df):.1%}")

# NOTE: the 'label' column is assigned on harmful_df / benign_df before the
# split, so every split frame (and the combined train/val/test frames) already
# carries it. Re-assigning it on the per-class split frames here would have no
# effect on the saved data and can trigger pandas' SettingWithCopyWarning.

# Write a DataFrame to a JSON file as a list of row records
def save_to_json(df, filename):
    """Write ``df`` to ``filename`` as a JSON array of row records.

    Each row becomes one JSON object keyed by column name. The file is
    written as UTF-8 with non-ASCII characters kept as-is and a two-space
    indent. Missing parent directories of ``filename`` are created first, so
    the call does not fail with ``FileNotFoundError`` when the output folder
    has not been created yet.

    Args:
        df: DataFrame to serialize.
        filename: Destination path of the JSON file.
    """
    # Function-scope import keeps this helper self-contained.
    from pathlib import Path

    # Convert DataFrame to list of dictionaries (records)
    records = df.to_dict(orient='records')

    # Ensure the output directory exists before opening the file for writing.
    Path(filename).parent.mkdir(parents=True, exist_ok=True)

    # NOTE(review): json.dump writes missing values (NaN) as the non-standard
    # token NaN; confirm downstream readers accept that if cells can be empty.
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(records, f, ensure_ascii=False, indent=2)

    print(f"Saved {len(records)} records to {filename}")

# Write every split to its JSON file, in train / validation / test order.
split_outputs = (
    (train_df, '../../data/decomposed_queries/train_data.json'),
    (val_df, '../../data/decomposed_queries/val_data.json'),
    (test_df, '../../data/decomposed_queries/test_data.json'),
)
for split_frame, destination in split_outputs:
    save_to_json(split_frame, destination)

print("\nData successfully split and saved to train_data.json, val_data.json, and test_data.json")