#!/usr/bin/env python
import argparse
import logging
import os
import shutil
import subprocess
import tempfile
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Pool, cpu_count

import pyarrow as pa
import pyarrow.parquet as pq
from mpi4py import MPI

# Configure the root logger once at import time: INFO and above, with every
# record prefixed by a timestamp and the PID of the process that emitted it.
logging.basicConfig(format='%(asctime)s [%(process)d] %(message)s', level=logging.INFO)

def _parse_port(primary, fallback):
    """Return ``int(primary)`` if it is all digits, else ``int(fallback)`` if that is, else None."""
    if primary.isdigit():
        return int(primary)
    if fallback.isdigit():
        return int(fallback)
    return None


def _parse_tshark_row(line):
    """Parse one comma-separated tshark output line into a schema-ordered tuple.

    Expected columns, in order: src IP, dst IP, TCP src port, TCP dst port,
    UDP src port, UDP dst port, epoch timestamp, frame length. The TCP port is
    preferred and the UDP port is the fallback.

    Returns:
        ``(src_ip, dst_ip, src_port, dst_port, timestamp, size)`` or None when
        the line is malformed (too few columns, or a timestamp/length that does
        not parse as a number).
    """
    row = line.strip().split(",")
    try:
        return (
            row[0],                         # src_ip
            row[1],                         # dst_ip
            _parse_port(row[2], row[4]),    # src_port (TCP, else UDP)
            _parse_port(row[3], row[5]),    # dst_port (TCP, else UDP)
            float(row[6]),                  # timestamp
            int(row[7]),                    # size
        )
    except (IndexError, ValueError):
        return None


def process_pcap(file_path, fields, schema):
    """Extract per-packet fields from one PCAP file using tshark.

    Args:
        file_path: Path of the capture file to read.
        fields: tshark field names, passed as ``-e`` options in order. The
            parser expects eight columns: ip.src, ip.dst, tcp.srcport,
            tcp.dstport, udp.srcport, udp.dstport, frame.time_epoch and
            frame.len (see ``_parse_tshark_row``).
        schema: pyarrow schema used to build the resulting table.

    Returns:
        A ``pyarrow.Table`` with one row per parsable packet, or None when
        tshark cannot be run, exits with an error, or no line could be parsed.
    """
    logging.info(f"Starting process_pcap for {file_path}")

    command = ["tshark", "-r", file_path, "-T", "fields"]
    for field in fields:
        command.extend(["-e", field])
    # tshark joins repeated occurrences of a field (e.g. the inner IP header
    # quoted in an ICMP error) with the aggregator, which defaults to ",".
    # With a "," separator that would shift every later column, so keep only
    # the first occurrence of each field.
    command.extend(["-E", "separator=,", "-E", "occurrence=f"])

    try:
        logging.info(f"Running tshark for {file_path}")
        result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True)
        logging.info(f"Finished reading {file_path}")
    except subprocess.CalledProcessError as e:
        logging.error(f"Error processing {file_path}: {e}")
        logging.error(f"stderr: {e.stderr}")
        return None
    except Exception as e:
        # e.g. tshark missing from PATH, or output that cannot be decoded.
        logging.error(f"Unexpected error with {file_path}: {e}")
        return None

    # Malformed lines (including the empty line produced by empty output)
    # are skipped.
    data = [row for row in map(_parse_tshark_row, result.stdout.splitlines()) if row is not None]
    if not data:
        return None

    # Transpose the row tuples into one column list per schema field.
    return pa.Table.from_arrays([list(column) for column in zip(*data)], schema=schema)

def save_to_parquet(tables, output_dir, rank):
    """Write the concatenation of ``tables`` to ``output_node_<rank>.parquet``.

    Args:
        tables: Sequence of pyarrow tables, combined with ``pa.concat_tables``.
        output_dir: Directory that receives the file.
        rank: Rank number, used in the file name and the log message.

    The file is written with Snappy compression. ``tables`` is passed to
    ``pa.concat_tables`` as-is, with no special handling of an empty sequence.
    """
    merged = pa.concat_tables(tables)
    destination = os.path.join(output_dir, f"output_node_{rank}.parquet")
    pq.write_table(merged, destination, compression="snappy")
    logging.info(f"Node {rank}: Parquet file saved to {destination}")

def process_files_in_parallel(node_files, fields, schema):
    """Run ``process_pcap`` on every path in ``node_files`` in a process pool.

    The pool has ``cpu_count()`` workers. Results keep the order of
    ``node_files`` (``starmap`` preserves order), and files for which
    ``process_pcap`` returned None are dropped.

    Returns:
        List of the non-None tables.
    """
    logging.info("Starting process_files_in_parallel")
    job_args = [(path, fields, schema) for path in node_files]
    with Pool(cpu_count()) as workers:
        per_file = workers.starmap(process_pcap, job_args)
    return list(filter(lambda tbl: tbl is not None, per_file))

def copy_files_parallel(source_files, target_files, num_workers=4):
    """Copy ``source_files[i]`` to ``target_files[i]`` using a thread pool.

    Args:
        source_files: Paths to copy from.
        target_files: Destination paths, paired positionally with
            ``source_files``. As with ``map``, pairing stops at the shorter
            sequence.
        num_workers: Maximum number of copier threads.

    Raises:
        OSError: Or any other exception raised by ``shutil.copy``. The
            exception from the first failing copy (in input order) propagates
            to the caller, and copies that have not started yet may be
            cancelled.
    """
    def copy_file(src, dest):
        shutil.copy(src, dest)
        logging.info(f"Copied {src} to {dest}")

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        # Executor.map only re-raises a worker's exception when that result
        # is retrieved. Consuming the iterator makes failed copies surface
        # instead of being silently dropped.
        list(executor.map(copy_file, source_files, target_files))

def process_all_pcaps(pcap_dir, output_dir, scratch_root="/dev/shm"):
    """Convert this MPI rank's share of ``pcap_dir`` into one Parquet file.

    Steps:
    1. Every rank lists the regular files in ``pcap_dir`` in sorted order, so
       all ranks compute the same partition.
    2. Each rank takes a contiguous slice; slice sizes differ by at most one.
    3. The rank copies its files into a private scratch directory under
       ``scratch_root`` and runs tshark on the copies in a process pool.
    4. The rank moves ``output_node_<rank>.parquet`` into ``output_dir``.

    The scratch directory is removed even if a step fails. A rank that
    extracts no rows still writes an empty, schema-only Parquet file.

    Args:
        pcap_dir: Directory containing the input capture files.
        output_dir: Existing directory that receives the per-rank file.
        scratch_root: Directory in which the per-rank scratch directory is
            created (default ``/dev/shm``).
    """
    # MPI setup
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    # Fields to extract from PCAPs (column order is what process_pcap parses)
    fields = [
        "ip.src",
        "ip.dst",
        "tcp.srcport",
        "tcp.dstport",
        "udp.srcport",
        "udp.dstport",
        "frame.time_epoch",
        "frame.len"
    ]

    # Define the schema
    schema = pa.schema([
        ("src_ip", pa.string()),
        ("dst_ip", pa.string()),
        ("src_port", pa.int32()),  # Store ports as integers
        ("dst_port", pa.int32()),  # Store ports as integers
        ("timestamp", pa.float64()),  # Epoch time is a float
        ("size", pa.int32())
    ])

    # List the directory once, sorted, so every rank computes the same split.
    # Subdirectories are skipped because shutil.copy cannot copy them.
    names = sorted(
        name for name in os.listdir(pcap_dir)
        if os.path.isfile(os.path.join(pcap_dir, name))
    )
    # Near-equal contiguous slices, rather than giving the entire remainder to
    # the last rank.
    start_idx = rank * len(names) // size
    end_idx = (rank + 1) * len(names) // size
    node_names = names[start_idx:end_idx]

    # Use a private directory per rank. Ranks on the same host share
    # scratch_root, so a fixed path would collide, and one rank's cleanup
    # would delete another rank's inputs.
    scratch_dir = tempfile.mkdtemp(prefix=f"pcap_rank{rank}_", dir=scratch_root)
    try:
        input_dir = os.path.join(scratch_dir, "inputs")
        os.makedirs(input_dir)
        node_files = [os.path.join(input_dir, name) for name in node_names]
        copy_files_parallel([os.path.join(pcap_dir, name) for name in node_names], node_files, 256)
        logging.info(f"Node {rank}: Processing {len(node_files)} files.")

        # Process files in parallel on this node
        tables = process_files_in_parallel(node_files, fields, schema)
        if not tables:
            # pa.concat_tables rejects an empty list; write a schema-only
            # table so this rank still produces an output file.
            logging.warning(f"Node {rank}: no rows extracted; writing an empty table.")
            tables = [schema.empty_table()]

        # Save the results for this node, then move them to the final location.
        save_to_parquet(tables, scratch_dir, rank)
        output_name = f"output_node_{rank}.parquet"
        # Give shutil.move an explicit file path. With a directory as the
        # destination, it raises if the file already exists; with a file path,
        # a leftover file from an earlier run is overwritten.
        shutil.move(os.path.join(scratch_dir, output_name), os.path.join(output_dir, output_name))
    finally:
        # Remove only this rank's scratch directory, never scratch_root itself.
        shutil.rmtree(scratch_dir, ignore_errors=True)

    logging.info(f"Node {rank}: Processing complete.")

# Script entry point: read the input and output directories from the command
# line, make sure the output directory exists, then run the pipeline.
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Process PCAP files and save outputs as Parquet files.")
    cli.add_argument("pcap_dir", type=str, help="Path to the directory containing input PCAP files.")
    cli.add_argument("output_dir", type=str, help="Path to the directory to save output Parquet files.")
    opts = cli.parse_args()

    # exist_ok: the directory may already exist, e.g. created by another
    # process started from the same command.
    os.makedirs(opts.output_dir, exist_ok=True)

    process_all_pcaps(opts.pcap_dir, opts.output_dir)
