import pandas as pd

# Node-type labels for label-based ablation filtering (presumably matched
# against each graph's `nodes_label` entries -- confirm against the dataset).
# No executable code visible in this file reads this set.
NODE_ABLATION_LABELS = {
    'BINDING',
    'CALL',
    'CONTROL_STRUCTURE',
    'METHOD',
    'OTHERS',
    'RETURN',
    'TYPE_DECL',
    'TYPE_REF',
}


def clean_and_renumber_graph(item, *, min_words=2, allowed_labels=None):
    """Drop uninformative nodes from a graph record and renumber the rest.

    A node is kept only when its entry in ``item['nodes_codes']`` is a string
    that, ignoring surrounding whitespace, is non-empty and contains at least
    ``min_words`` whitespace-separated tokens.  Non-string entries are always
    dropped.  Surviving nodes are renumbered ``0..k-1`` in their original
    order, and an edge is retained only when both of its endpoints survive;
    its endpoints are rewritten to the new IDs.

    Args:
        item: Mapping-like record (a ``dict`` or a pandas row) providing the
            keys ``id``, ``filename``, ``nodes``, ``edges``, ``nodes_label``,
            ``nodes_codes``, ``edges_label``, ``code_lines`` and ``target``.
            ``nodes``, ``nodes_label``, ``nodes_codes`` and ``code_lines``
            are indexed in parallel (position ``i`` describes one node).
        min_words: Minimum number of whitespace-separated tokens a node's
            code must contain to be kept.  The default of 2 is the original
            "more than one word" rule.
        allowed_labels: Optional collection of node labels.  When given, a
            node must additionally have its ``nodes_label`` entry in this
            collection to be kept.  ``None`` (the default) disables label
            filtering.

    Returns:
        A new ``dict`` with the keys listed above, or ``None`` when no node
        survives the filtering.  The input ``item`` is not modified.
    """
    # Positions (into the parallel per-node sequences) of the nodes to keep.
    keep_indices = []
    for index, code in enumerate(item['nodes_codes']):
        if not isinstance(code, str):
            continue
        # str.split() with no argument already ignores leading/trailing
        # whitespace, so an empty token list means a blank string.
        tokens = code.split()
        if not tokens or len(tokens) < min_words:
            continue
        if (allowed_labels is not None
                and item['nodes_label'][index] not in allowed_labels):
            continue
        keep_indices.append(index)

    # Nothing survived: signal "drop this graph" to the caller.
    if not keep_indices:
        return None

    # Map each surviving node's original ID to its new consecutive ID.
    old_to_new = {
        item['nodes'][old_index]: new_id
        for new_id, old_index in enumerate(keep_indices)
    }

    # Keep only well-formed edges whose two endpoints both survived.
    new_edges = []
    new_edges_label = []
    for edge, label in zip(item['edges'], item['edges_label']):
        if len(edge) < 2:
            continue
        src, dst = edge[0], edge[1]
        if src in old_to_new and dst in old_to_new:
            new_edges.append([old_to_new[src], old_to_new[dst]])
            new_edges_label.append(label)

    return {
        'id': item['id'],
        'filename': item['filename'],
        'nodes': list(range(len(keep_indices))),
        'edges': new_edges,
        'nodes_label': [item['nodes_label'][i] for i in keep_indices],
        # The original (un-stripped) code strings are preserved.
        'nodes_codes': [item['nodes_codes'][i] for i in keep_indices],
        'edges_label': new_edges_label,
        'code_lines': [item['code_lines'][i] for i in keep_indices],
        'target': item['target'],
    }


# Main processing workflow
def process_dataset(data):
    """Clean every graph in ``data`` and discard the ones left with no nodes.

    Each row of ``data`` is passed, in positional order, to
    ``clean_and_renumber_graph``; rows for which it returns ``None`` are
    dropped.  A three-line summary of the counts is printed to stdout.

    Args:
        data: pandas ``DataFrame`` holding one graph record per row.

    Returns:
        A new ``DataFrame`` built from the cleaned records that survived.
    """
    # Rows are fetched positionally so the index labels of `data` are irrelevant.
    outcomes = [clean_and_renumber_graph(data.iloc[pos]) for pos in range(len(data))]
    survivors = [graph for graph in outcomes if graph is not None]
    dropped = len(outcomes) - len(survivors)

    cleaned_frame = pd.DataFrame(survivors)

    print(f"Processing completed! Original sample count: {len(data)}")
    print(f"Sample count after processing: {len(cleaned_frame)}")
    print(f"Number of empty graphs removed: {dropped}")
    return cleaned_frame

# Usage example
if __name__ == "__main__":
    # Input datasets (JSON-lines) and, position for position, where the
    # cleaned version of each one is written.
    input_files = [
        '../data/raw/debug/debug.jsonl',
    ]
    output_files = [
        '../data/raw/debug/debug_clean.jsonl',
    ]

    for source_path, output_path in zip(input_files, output_files):
        print(f'--------- processing {source_path} ---------')

        # Load one graph record per line.
        data = pd.read_json(source_path, lines=True)

        # Clean every graph and drop those left without nodes.
        processed_data = process_dataset(data)

        # Write the result back out as JSON-lines, one record per line.
        processed_data.to_json(output_path, orient='records', lines=True)

        print(f"Processing completed! Total samples processed: {len(data)}")
        print(f"Number of samples after processing: {len(processed_data)}")
        print(f"Saved to: {output_path}")
