import argparse
import os
from itertools import groupby

import numpy as np
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import (
    StructType, StructField, ArrayType, FloatType, LongType, StringType
)
from pyspark.sql.functions import regexp_extract, col, split, concat_ws

# ----------------------------
# MPI task identity
# ----------------------------
# Each SLURM task runs this script independently and processes only its own
# hash slice of the data (see the partitioning section below).
task_rank = int(os.getenv("SLURM_PROCID", "0"))
num_tasks = int(os.getenv("SLURM_NTASKS", "1"))
print(f"----------------TASK RANK {task_rank}")

# ----------------------------
# Config
# ----------------------------
MIN_SEQ_LEN = 10         # a group is kept iff at least one direction has >= this many *kept bins*
MAX_SEQ_LEN = 9000       # per-direction cap on *event* count; IEI length is at most MAX_SEQ_LEN-1

# Args are parsed and validated *before* the SparkSession is created, so that
# `--help` or an invalid flag exits immediately instead of first launching a
# 200g Spark JVM.
parser = argparse.ArgumentParser()
parser.add_argument("--finetuning", action="store_true",
                    help="Set if processing finetuning dataset")
parser.add_argument("--bin_ms", type=int, default=1000,
                    help="Bin size in milliseconds (>=100 and multiple of 100)")
parser.add_argument("--output_dir", type=str, required=True,
                    help="Output directory for processed data")
parser.add_argument("--threshold", type=float, default=0.0,
                    help="Keep a bin in inbound or outbound IEI iff inbound or outbound >= this threshold")
args = parser.parse_args()

BIN_MS = args.bin_ms
OUTPUT_DIR = args.output_dir
FINETUNING = args.finetuning
# The same CLI threshold is applied to each direction independently.
THRESH_IN_BYTES = float(args.threshold)
THRESH_OUT_BYTES = float(args.threshold)

if BIN_MS < 100:
    raise ValueError(f"BIN_MS must be >= 100 ms; got {BIN_MS}")
if BIN_MS % 100 != 0:
    print(f"[WARN] BIN_MS={BIN_MS} is not a multiple of 100; bucketing still works, "
          f"but consider using a multiple of 100 for clean alignment.")
print(f"Aggregating 100ms windows into {BIN_MS} ms bins.")
print(f"Inbound threshold per bin:  {THRESH_IN_BYTES} bytes")
print(f"Outbound threshold per bin: {THRESH_OUT_BYTES} bytes")

# Number of 100 ms input windows per output bin (exact only when BIN_MS is a
# multiple of 100).
AGG_FACTOR = BIN_MS // 100  # e.g., 1000ms / 100ms = 10

# If you want IEIs reported in milliseconds instead of tick-counts, set this True
IEI_IN_MILLISECONDS = False  # set True to scale IEIs by BIN_MS

print(f"Finetuning set to: {FINETUNING}")

# ----------------------------
# Spark session
# ----------------------------
spark = (
    SparkSession.builder
    .appName(f"PySpark IEI Processing - Task {task_rank}")
    .config("spark.driver.memory", "200g")
    .config("spark.executor.memory", "200g")
    .config("spark.driver.maxResultSize", "200g")
    .getOrCreate()
)

# ----------------------------
# Load
# ----------------------------
# NOTE(review): both branches currently point at the same placeholder path;
# set the real pretraining / finetuning locations.
aggregated_df = spark.read.parquet("/PATH" if FINETUNING else "/PATH")

# Column expression splitting an IPv4 dotted quad into its octets.
octets = split(col("ip"), "\\.")

aggregated_df = (
    aggregated_df
    # Clean source_file -> parent dir (everything before the final "/").
    .withColumn("source_file", regexp_extract(col("source_file"), r'^(.*)/[^/]+$', 1))
    # Drop loopback (127.x.x.x) traffic.
    .filter(~col("ip").startswith("127."))
    # /24 subnet = first three octets.
    .withColumn("subnet", concat_ws(".", octets.getItem(0), octets.getItem(1), octets.getItem(2)))
    # IP-based task assignment: last octet modulo the number of tasks.
    .withColumn("last_octet", octets.getItem(3).cast("int"))
    .withColumn("ip_hash", col("last_octet") % num_tasks)
)

# Rows owned by this task for the IP-level datasets.
df_ip_part = aggregated_df.filter(col("ip_hash") == task_rank)

# Rows owned by this task for the subnet-level dataset, via a stable hash of the subnet.
df_subnet_part = (
    aggregated_df
    .withColumn("subnet_hash", F.pmod(F.xxhash64("subnet"), F.lit(num_tasks)))
    .filter(col("subnet_hash") == task_rank)
)

# Co-locate every row of a grouping key in one partition so the per-partition
# mappers see complete groups.
spark.conf.set("spark.sql.shuffle.partitions", num_tasks * 4)
df_ip_part = df_ip_part.repartition("ip", "source_file")
df_subnet_part = df_subnet_part.repartition("subnet", "source_file")

# ----------------------------
# Utils
# ----------------------------
def pad_or_truncate_to_len(arr, target_len):
    """Coerce ``arr`` to exactly ``target_len`` ints.

    Extra trailing items are dropped; missing ones are filled with ``-1``.
    ``None`` is treated as an empty sequence, and a non-positive
    ``target_len`` always yields ``[]``.
    """
    if target_len <= 0:
        return []
    source = [] if arr is None else arr
    # Convert only the elements that survive the cut.
    fitted = [int(value) for value in source[:target_len]]
    fitted.extend([-1] * (target_len - len(fitted)))
    return fitted

def _scale_iei_if_needed(iei_list):
    """Return the IEIs as ints, multiplied by BIN_MS when IEI_IN_MILLISECONDS is set.

    With the flag off, values stay in units of aggregated bins.
    """
    multiplier = BIN_MS if IEI_IN_MILLISECONDS else 1
    return [int(gap) * multiplier for gap in iei_list]

def iei_from_windows(win_list):
    """Return the gaps between consecutive entries of a sorted bin-index list.

    Sequences with fewer than two entries (or ``None``) have no intervals and
    yield ``[]``. The gaps are passed through ``_scale_iei_if_needed``.
    """
    if len(win_list or []) < 2:
        return []
    gaps = np.diff(np.asarray(win_list, dtype=np.int64))
    return _scale_iei_if_needed(gaps.tolist())

def arrays_from_rows_sum_by_window(rows, bin_ms=None):
    """Sum inbound/outbound bytes per aggregated bin.

    Each input row carries an integer ``window_start`` in 100 ms ticks plus
    ``total_inbound_bytes`` and ``total_outbound_bytes``. A row is assigned to
    the bin containing its start time:
    ``bin = (window_start * 100) // bin_ms``.

    Computing the bin from milliseconds, rather than ``window_start // (bin_ms // 100)``,
    keeps the bin width equal to ``bin_ms`` even when ``bin_ms`` is not a
    multiple of 100. For multiples of 100 the two formulas agree exactly.

    Args:
        rows: iterable of row-like objects with the three attributes above.
        bin_ms: bin width in milliseconds; defaults to the module-level BIN_MS.

    Returns:
        dict mapping bin index -> ``[total_in, total_out]`` (floats).
    """
    if bin_ms is None:
        bin_ms = BIN_MS
    tick_ms = 100  # input window granularity

    sum_by_window = {}
    for r in rows:
        start_ms = int(r.window_start) * tick_ms
        totals = sum_by_window.setdefault(start_ms // bin_ms, [0.0, 0.0])
        totals[0] += float(r.total_inbound_bytes)
        totals[1] += float(r.total_outbound_bytes)
    return sum_by_window

def per_direction_windows(sum_by_window, thresh_in, thresh_out):
    """Select the bins that count as events for each direction.

    The two directions are filtered independently. A bin is an inbound event
    when its inbound total is ``>= thresh_in``, and an outbound event when its
    outbound total is ``>= thresh_out``.

    Returns:
        ``(inbound_bins, outbound_bins)`` — two ascending lists of bin indices.
        Both are empty when ``sum_by_window`` is empty or ``None``.
    """
    if not sum_by_window:
        return [], []
    inbound_bins = sorted(
        bin_idx for bin_idx, (inbound, _outbound) in sum_by_window.items()
        if inbound >= thresh_in
    )
    outbound_bins = sorted(
        bin_idx for bin_idx, (_inbound, outbound) in sum_by_window.items()
        if outbound >= thresh_out
    )
    return inbound_bins, outbound_bins

# ----------------------------
# Builders (IP-level)
# ----------------------------
def _capped_iei(windows):
    """IEI list for sorted bin indices, capped at MAX_SEQ_LEN events (MAX_SEQ_LEN - 1 intervals)."""
    # Only the first MAX_SEQ_LEN events contribute to the kept intervals, so
    # there is no need to diff the rest of the sequence.
    head = windows[:MAX_SEQ_LEN]
    target_len = max(len(head) - 1, 0)
    return pad_or_truncate_to_len(iei_from_windows(head), target_len)


def _directional_iei(rows):
    """Turn one group's rows into its ``(inbound, outbound)`` IEI lists.

    Shared by all ``build_*_iei`` builders, which differ only in the
    identifying keys they attach to the result. The steps are:

    1. Sum bytes per aggregated bin.
    2. Keep bins per direction whose total meets that direction's threshold.
    3. Take consecutive differences of the kept bin indices.

    Returns:
        ``(iei_in, iei_out)`` — lists of ints with at most MAX_SEQ_LEN - 1
        entries each. Returns ``None`` when neither direction has at least
        MIN_SEQ_LEN kept bins.
    """
    sum_by_window = arrays_from_rows_sum_by_window(rows)
    in_w, out_w = per_direction_windows(sum_by_window, THRESH_IN_BYTES, THRESH_OUT_BYTES)

    # Retain the group if *either* direction has at least MIN_SEQ_LEN events (bins).
    if max(len(in_w), len(out_w)) < MIN_SEQ_LEN:
        return None

    return _capped_iei(in_w), _capped_iei(out_w)


def build_ip_iei(ip, src, rows):
    """Build the per-(ip, source_file) IEI record, or ``None`` if the group is too short."""
    ieis = _directional_iei(rows)
    if ieis is None:
        return None
    iei_in, iei_out = ieis
    return {
        "ip": ip,
        "source_file": src,
        "iei_inbound":  iei_in,
        "iei_outbound": iei_out,
    }


def build_ip_service_iei(ip, sp, src, rows):
    """Build the per-(ip, service_port, source_file) IEI record, or ``None`` if the group is too short."""
    ieis = _directional_iei(rows)
    if ieis is None:
        return None
    iei_in, iei_out = ieis
    return {
        "ip": ip,
        "service_port": sp,
        "source_file": src,
        "iei_inbound":  iei_in,
        "iei_outbound": iei_out,
    }

# ----------------------------
# Builders (Subnet-level, /24)
# ----------------------------
def build_subnet_iei(subnet, src, rows):
    """Build the per-(/24 subnet, source_file) IEI record, or ``None`` if the group is too short."""
    ieis = _directional_iei(rows)
    if ieis is None:
        return None
    iei_in, iei_out = ieis
    return {
        "subnet": subnet,
        "source_file": src,
        "iei_inbound":  iei_in,
        "iei_outbound": iei_out,
    }

# ----------------------------
# RDD partition mappers
# ----------------------------
def _group_and_build(iter_rows, sort_key, group_key, build):
    """Sort a partition's rows, group them by ``group_key``, and yield each non-None ``build`` result.

    ``build`` is called as ``build(*group_key_value, rows_of_group)``.
    ``sort_key`` must start with the same fields as ``group_key`` so that
    ``groupby`` sees each group as one contiguous run. Any extra trailing
    fields only fix the row order within a group.
    """
    for key, grp in groupby(sorted(iter_rows, key=sort_key), key=group_key):
        res = build(*key, list(grp))
        if res is not None:
            yield res


def ip_partitions(iter_rows):
    """mapPartitions function: one IEI record per (ip, source_file) group."""
    return _group_and_build(
        iter_rows,
        sort_key=lambda r: (r.ip, r.source_file, r.window_start, r.service_port),
        group_key=lambda r: (r.ip, r.source_file),
        build=build_ip_iei,
    )


def ip_service_partitions(iter_rows):
    """mapPartitions function: one IEI record per (ip, service_port, source_file) group."""
    return _group_and_build(
        iter_rows,
        sort_key=lambda r: (r.ip, r.service_port, r.source_file, r.window_start),
        group_key=lambda r: (r.ip, r.service_port, r.source_file),
        build=build_ip_service_iei,
    )


def subnet_partitions(iter_rows):
    """mapPartitions function: one IEI record per (/24 subnet, source_file) group."""
    return _group_and_build(
        iter_rows,
        sort_key=lambda r: (r.subnet, r.source_file, r.window_start, r.service_port, r.ip),
        group_key=lambda r: (r.subnet, r.source_file),
        build=build_subnet_iei,
    )

# ----------------------------
# RDDs
# ----------------------------
# NOTE(review): ip_rdd and ip_service_rdd both derive from df_ip_part, which is
# not cached here, so its read + repartition shuffle presumably runs once per
# write below. Consider persisting it if memory allows.
ip_rdd         = df_ip_part.rdd.mapPartitions(ip_partitions)
ip_service_rdd = df_ip_part.rdd.mapPartitions(ip_service_partitions)
subnet_rdd     = df_subnet_part.rdd.mapPartitions(subnet_partitions)

# ----------------------------
# Schemas & Save
# ----------------------------
ip_schema = StructType([
    StructField("ip", StringType(), True),
    StructField("source_file", StringType(), True),
    StructField("iei_inbound",  ArrayType(LongType()), True),
    StructField("iei_outbound", ArrayType(LongType()), True),
])

ip_service_schema = StructType([
    StructField("ip", StringType(), True),
    StructField("service_port", LongType(), True),
    StructField("source_file", StringType(), True),
    StructField("iei_inbound",  ArrayType(LongType()), True),
    StructField("iei_outbound", ArrayType(LongType()), True),
])

subnet_schema = StructType([
    StructField("subnet", StringType(), True),
    StructField("source_file", StringType(), True),
    StructField("iei_inbound",  ArrayType(LongType()), True),
    StructField("iei_outbound", ArrayType(LongType()), True),
])

ip_df         = spark.createDataFrame(ip_rdd, ip_schema)
ip_service_df = spark.createDataFrame(ip_service_rdd, ip_service_schema)
subnet_df     = spark.createDataFrame(subnet_rdd, subnet_schema)

# Each task writes under its own task_<rank> subdirectory of each dataset.
# The session is always stopped, even if a write fails.
try:
    for out_df, dataset_suffix in (
        (ip_df, "ip"),
        (ip_service_df, "service"),
        (subnet_df, "subnet"),
    ):
        out_df.write.mode("overwrite").parquet(f"{OUTPUT_DIR}_{dataset_suffix}/task_{task_rank}")

    print(f"[Task {task_rank}] Wrote IP, IP+Service, Subnet(/24) IEI datasets.")
finally:
    spark.stop()
