import json
import csv
import os
import glob

def _parse_json_constant(name):
    """Translate a bare JSON constant: ``NaN`` becomes None, infinities stay floats."""
    # json.loads only calls this hook for the bare tokens NaN, Infinity and
    # -Infinity, never for text inside string values.
    return None if name == 'NaN' else float(name)


def _loads_nan_as_none(text):
    """Parse *text* as JSON, mapping bare NaN tokens (not string contents) to None."""
    return json.loads(text, parse_constant=_parse_json_constant)


def _tag_list_items(items, task_type, file_path, origin=""):
    """Set 'task_type' on every dict in *items* (in place); warn about other items.

    Non-dict items are reported but left in the list. *origin* is an optional
    suffix inserted into the warning message (e.g. " (JSON array)").
    """
    for item in items:
        if isinstance(item, dict):
            item['task_type'] = task_type
        else:
            print(f"Warning: Item in {file_path}{origin} is not a dictionary: {item}")
    return items


def _parse_json_lines(lines, task_type, file_path):
    """Parse each non-blank line as its own JSON document and return the tagged dicts.

    A line may hold a single object or a list of objects. Anything that is not
    a dict (or a dict inside a per-line list) is reported and dropped, and
    lines that fail to decode are reported and skipped.
    """
    records = []
    for line in lines:
        if not line.strip():
            continue
        try:
            item = _loads_nan_as_none(line)
        except json.JSONDecodeError as e_line:
            print(f"Error decoding line in {file_path} as JSON: {e_line}. Line: '{line[:100]}...'")
            continue

        if isinstance(item, dict):
            item['task_type'] = task_type
            records.append(item)
        elif isinstance(item, list):
            # A line that is itself a list of items.
            for sub_item in item:
                if isinstance(sub_item, dict):
                    sub_item['task_type'] = task_type
                    records.append(sub_item)
                else:
                    print(f"Warning: Sub-item in list on a line in {file_path} is not a dict: {sub_item}")
        else:
            print(f"Warning: Item in {file_path} (JSONL) is not a dictionary: {item}")
    return records


def _parse_jsonl_fallback(content, task_type, file_path):
    """Recover records from *content* after it failed to parse as one JSON document.

    If the text starts with '[' and ends with ']', first retry it as a single
    array with the line breaks removed; otherwise, or if that retry fails,
    parse it line by line as JSON Lines.
    """
    # Split on '\n' only: str.splitlines() would also break on separators such
    # as U+2028, which may legitimately appear inside JSON strings.
    lines = content.strip().split('\n')

    looks_like_array = (
        lines[0].strip().startswith('[') and lines[-1].strip().endswith(']')
    )
    if looks_like_array:
        try:
            data = _loads_nan_as_none("".join(lines))
        except json.JSONDecodeError as e_array:
            print(f"Could not parse {file_path} as a single JSON array ({e_array}), falling back to line-by-line JSONL.")
        else:
            if isinstance(data, list):
                return _tag_list_items(data, task_type, file_path, " (JSON array)")
            print(f"Warning: Parsed JSON array from {file_path} is not a list.")
            return []

    return _parse_json_lines(lines, task_type, file_path)


def load_and_combine_json_data(base_path, subdirectories, data_type):
    """
    Load ``<data_type>_data.json`` from each subdirectory and combine the records.

    Every dict record gets a 'task_type' key derived from its source
    subdirectory ("agent", "image", "QA", or "unknown" for unmapped names).
    Bare ``NaN`` tokens in the JSON are converted to None; string values that
    merely contain the text "NaN" are left untouched. A file that is not a
    single valid JSON document is retried as JSON Lines.

    Args:
        base_path: Directory that contains the subdirectories.
        subdirectories: Iterable of subdirectory names to scan, in order.
        data_type: Prefix of the file name (e.g. "train", "val", "test").

    Returns:
        A list with the records of all files, in subdirectory order. Missing
        files and undecodable content are reported on stdout and skipped.
    """
    combined_data = []
    json_filename = f"{data_type}_data.json"

    # Define mapping from subdirectory to task type
    task_type_mapping = {
        "agent_tasks": "agent",
        "image_tasks": "image",
        "decomposed_queries": "QA",
    }

    for subdir in subdirectories:
        task_type = task_type_mapping.get(subdir, "unknown")
        file_path = os.path.join(base_path, subdir, json_filename)
        if not os.path.exists(file_path):
            print(f"File not found: {file_path}. Skipping.")
            continue

        print(f"Processing file: {file_path} (Task Type: {task_type})")
        # Read once; the JSONL fallback reuses this text instead of re-reading.
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()

        try:
            data_list = _loads_nan_as_none(content)
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON from {file_path}: {e}")
            print(f"Attempting to read {file_path} as JSONL (JSON Lines)...")
            try:
                combined_data.extend(_parse_jsonl_fallback(content, task_type, file_path))
            except Exception as e_jsonl_general:
                # Best effort: one unreadable file must not abort the whole run.
                print(f"General error during JSONL processing for {file_path}: {e_jsonl_general}")
                print(f"Problematic content snippet for {file_path}: {content[:500]}")
            continue

        if isinstance(data_list, list):
            combined_data.extend(_tag_list_items(data_list, task_type, file_path))
        else:
            print(f"Warning: Data in {file_path} is not a list. Skipping this file's content.")

    return combined_data

def write_to_csv(data, output_csv_path, column_names):
    """
    Write the dict rows of *data* to a CSV file restricted to *column_names*.

    Args:
        data: List of row dicts. Keys not listed in *column_names* are ignored
            and non-dict entries are reported and skipped.
        output_csv_path: Destination path; an existing file is overwritten.
        column_names: Column order for the header and for every row.

    If *data* is not a list, a warning is printed and the file is left with
    the header row only. The success message reports the number of rows that
    were actually written, not the length of the input.
    """
    rows_written = 0
    with open(output_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=column_names, extrasaction='ignore')
        writer.writeheader()

        if not isinstance(data, list):
            print(f"Data provided to write_to_csv is not a list. Cannot write to {output_csv_path}")
            # Return before the success message: nothing was written, and
            # len(data) may not even be defined (e.g. data is None).
            return

        for row in data:
            if isinstance(row, dict):
                writer.writerow(row)
                rows_written += 1
            else:
                print(f"Skipping row as it is not a dictionary: {row}")

    print(f"Successfully created CSV: {output_csv_path} with {rows_written} rows.")

if __name__ == "__main__":
    # Resolve locations relative to this script: the CSVs are written next to
    # the script and the input subdirectories are looked up in its parent
    # directory (expected layout: data/combined/<script> -- not verified here).
    combined_dir = os.path.dirname(os.path.abspath(__file__))
    data_root = os.path.dirname(combined_dir)

    # Columns emitted to every CSV, in this order.
    csv_columns = ["id", "original_task", "decomposition", "harm_index", "label", "task_type"]
    source_subdirs = ["agent_tasks", "image_tasks", "decomposed_queries"]

    # Make sure the output directory is present before writing anything.
    os.makedirs(combined_dir, exist_ok=True)

    for split in ("train", "val", "test"):
        print(f"\nProcessing {split} data...")
        records = load_and_combine_json_data(data_root, source_subdirs, split)

        # Nothing loaded for this split: report it and move on to the next.
        if not records:
            print(f"No data found for {split} type. Skipping CSV creation.")
            continue

        csv_path = os.path.join(combined_dir, f"combined_{split}.csv")
        write_to_csv(records, csv_path, csv_columns)

    print("\nAll processing finished.")
