import argparse
import os
from itertools import groupby

import numpy as np
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.functions import regexp_extract, col, split, concat_ws
from pyspark.sql.types import (
    StructType, StructField, ArrayType, FloatType, LongType, StringType
)

# ----------------------------
# MPI/Spark setup (task identity)
# ----------------------------
# Rank / size of this SLURM task; used later to pick which slice of the data
# this task processes and to name its output sub-directory.
task_rank = int(os.getenv("SLURM_PROCID", "0"))
num_tasks = int(os.getenv("SLURM_NTASKS", "1"))
print(f"----------------TASK RANK {task_rank}")

# ----------------------------
# Config
# ----------------------------
MIN_SEQ_LEN = 10    # drop a group unless at least one direction keeps this many bins
MAX_SEQ_LEN = 9000  # truncate each direction's series to at most this many bins

# ----------------------------
# Args
# ----------------------------
# Parsed and validated BEFORE the Spark session is created, so --help, a bad
# flag, or an invalid --bin_ms exits immediately instead of after JVM start-up.
parser = argparse.ArgumentParser()
parser.add_argument("--bin_ms", type=int, default=1000,
                    help="Bin size in milliseconds (>=100 and multiple of 100)")
parser.add_argument("--output_dir", type=str, required=True,
                    help="Output directory for processed data")
parser.add_argument("--finetuning", action="store_true",
                    help="Set if processing finetuning dataset")
parser.add_argument("--threshold", type=float, default=0.0,
                    help="Per-direction bin threshold in bytes: a bin is kept in the "
                         "inbound series iff its inbound total >= this value, and in "
                         "the outbound series iff its outbound total >= this value")
args = parser.parse_args()

BIN_MS            = args.bin_ms
OUTPUT_DIR        = args.output_dir
FINETUNING        = args.finetuning
# The same CLI value currently drives both directions; kept as two names so
# the thresholds can be split later without touching the builders.
THRESH_IN_BYTES   = float(args.threshold)
THRESH_OUT_BYTES  = float(args.threshold)

# Input ticks are 100ms, so a bin must be a whole number of ticks.
if BIN_MS < 100 or BIN_MS % 100 != 0:
    raise ValueError(f"BIN_MS must be >=100 and a multiple of 100; got {BIN_MS}")

AGG_FACTOR = BIN_MS // 100  # number of 100ms ticks per output bin
print(f"Aggregating 100ms ticks into {BIN_MS}ms bins (factor={AGG_FACTOR}).")
print(f"Finetuning: {FINETUNING}")
print(f"Inbound threshold:  {THRESH_IN_BYTES} bytes/bin")
print(f"Outbound threshold: {THRESH_OUT_BYTES} bytes/bin")

# ----------------------------
# Spark session
# ----------------------------
# NOTE(review): spark.driver.memory only takes effect if this call launches the
# driver JVM (e.g. plain `python script.py`); under spark-submit client mode it
# must be passed on the command line instead.
spark = (
    SparkSession.builder
    .appName(f"PySpark Timeseries Processing - Task {task_rank}")
    .config("spark.driver.memory", "200g")
    .config("spark.executor.memory", "200g")
    .config("spark.driver.maxResultSize", "200g")
    .getOrCreate()
)

# ----------------------------
# Load
# ----------------------------
# NOTE(review): both branches read the same placeholder path; point each one at
# its real dataset.
if FINETUNING:
    aggregated_df = spark.read.parquet("/PATH")
else:
    aggregated_df = spark.read.parquet("/PATH")

# Reduce source_file to its parent directory (everything before the last "/");
# regexp_extract yields "" when the pattern does not match. Then drop 127.x.x.x.
aggregated_df = (
    aggregated_df
    .withColumn("source_file", regexp_extract(col("source_file"), r'^(.*)/[^/]+$', 1))
    .filter(~col("ip").startswith("127."))
)

# Partitioning keys derived from the dotted IP string:
#   subnet     -> first three octets (/24 prefix)
#   last_octet -> fourth octet as int
#   ip_hash    -> last_octet % num_tasks
octets = split(col("ip"), "\\.")
aggregated_df = (
    aggregated_df
    .withColumn("subnet", concat_ws(".", octets.getItem(0), octets.getItem(1), octets.getItem(2)))
    .withColumn("last_octet", octets.getItem(3).cast("int"))
    .withColumn("ip_hash", col("last_octet") % num_tasks)
)

# IP-level slice owned by this task.
# NOTE(review): an ip with no 4th dot-separated field (e.g. IPv6) gets a null
# last_octet/ip_hash (with ANSI mode off), so no task ever keeps it for the
# IP-level outputs, while it is still assigned to some task for subnet output.
df_ip_part = aggregated_df.filter(col("ip_hash") == task_rank)

# Subnet-level slice owned by this task: stable xxhash64 of the subnet string.
subnet_bucket = F.pmod(F.xxhash64("subnet"), F.lit(num_tasks))
df_subnet_part = (
    aggregated_df
    .withColumn("subnet_hash", subnet_bucket)
    .filter(col("subnet_hash") == task_rank)
)

# Hash-partition on (key, source_file) so every group handled by the
# mapPartitions builders later in the file lives in a single partition.
spark.conf.set("spark.sql.shuffle.partitions", num_tasks * 4)
df_ip_part = df_ip_part.repartition("ip", "source_file")
df_subnet_part = df_subnet_part.repartition("subnet", "source_file")

# ----------------------------
# Utils
# ----------------------------
def pad_or_truncate(arr, target_len):
    """Return a 1-D sequence of exactly ``target_len`` elements built from *arr*.

    Longer input is cut to its first ``target_len`` elements; shorter input is
    right-padded with ``-1`` via ``np.pad``. Input of the right length is
    returned unchanged (same object).

    Raises:
      ValueError: if ``target_len`` is negative. Without this check a negative
        value would become a negative slice bound and silently drop elements
        from the end instead of producing a fixed length.
    """
    if target_len < 0:
        raise ValueError(f"target_len must be non-negative; got {target_len}")
    cur_len = len(arr)
    if cur_len > target_len:
        return arr[:target_len]
    if cur_len < target_len:
        return np.pad(arr, (0, target_len - cur_len), constant_values=-1)
    return arr

def arrays_from_sum_by_window_dir(sum_by_window, thresh_in, thresh_out):
    """Split per-bin totals into independently thresholded inbound/outbound series.

    Args:
      sum_by_window: mapping ``bin -> [total_in, total_out]``; may be empty or None.
      thresh_in: a bin enters the inbound series iff ``total_in >= thresh_in``.
      thresh_out: a bin enters the outbound series iff ``total_out >= thresh_out``.

    Returns:
      ``(in_arr, out_arr, L_in, L_out)``: 1-D float arrays holding the kept
      totals in ascending bin order (dropped bins are removed, not zero-filled),
      plus their lengths. Empty input yields two empty arrays rather than None,
      so callers always receive arrays.
    """
    ordered_bins = sorted(sum_by_window) if sum_by_window else []

    # Each direction is filtered on its own total only; the two series can
    # therefore differ in length and in which bins they cover.
    in_arr = np.array(
        [sum_by_window[b][0] for b in ordered_bins if sum_by_window[b][0] >= thresh_in],
        dtype=float,
    )
    out_arr = np.array(
        [sum_by_window[b][1] for b in ordered_bins if sum_by_window[b][1] >= thresh_out],
        dtype=float,
    )
    return in_arr, out_arr, len(in_arr), len(out_arr)

# ----------------------------
# Builders (shared helpers)
# ----------------------------
def _sum_by_aggregated_bin(rows):
    """Sum inbound/outbound bytes of *rows* into aggregated time bins.

    ``int(window_start)`` is a 100ms tick index; tick ``w`` lands in bin
    ``w // AGG_FACTOR``. Every row passed in is summed, so the caller decides the
    grouping (per IP, per IP+port, per subnet) by which rows it supplies.

    Returns:
      dict mapping bin index -> ``[total_inbound, total_outbound]`` (floats).
    """
    sums = {}
    for r in rows:
        ti = float(r.total_inbound_bytes)
        to = float(r.total_outbound_bytes)
        # Ticks are kept even when one direction is 0; thresholds apply per bin.
        b = int(r.window_start) // AGG_FACTOR
        acc = sums.setdefault(b, [0.0, 0.0])
        acc[0] += ti
        acc[1] += to
    return sums


def _directional_series(rows):
    """Return ``(inbound, outbound)`` lists for *rows*, or None if too short.

    Bins are thresholded per direction (THRESH_IN_BYTES / THRESH_OUT_BYTES), so
    the lists may differ in length and either may be empty. The group is
    rejected only when BOTH directions keep fewer than MIN_SEQ_LEN bins. Each
    list is truncated to MAX_SEQ_LEN; shorter lists are not padded.
    """
    in_arr, out_arr, L_in, L_out = arrays_from_sum_by_window_dir(
        _sum_by_aggregated_bin(rows), THRESH_IN_BYTES, THRESH_OUT_BYTES
    )
    if max(L_in, L_out) < MIN_SEQ_LEN:
        return None
    return in_arr[:MAX_SEQ_LEN].tolist(), out_arr[:MAX_SEQ_LEN].tolist()


# ----------------------------
# Builders (IP-level)
# ----------------------------
def build_ip_timeseries(ip, src, rows):
    """Build the record for one (ip, source_file) group.

    Traffic is summed across ALL service ports of the IP per aggregated bin.

    Returns:
      dict with keys ip, source_file, inbound, outbound; or None when neither
      direction keeps MIN_SEQ_LEN bins.
    """
    series = _directional_series(rows)
    if series is None:
        return None
    inbound, outbound = series
    return {
        "ip": ip,
        "source_file": src,
        "inbound": inbound,
        "outbound": outbound,
    }


def build_ip_service_timeseries(ip, sp, src, rows):
    """Build the record for one (ip, service_port, source_file) group.

    Returns:
      dict with keys ip, service_port, source_file, inbound, outbound; or None
      when neither direction keeps MIN_SEQ_LEN bins.
    """
    series = _directional_series(rows)
    if series is None:
        return None
    inbound, outbound = series
    return {
        "ip": ip,
        "service_port": sp,
        "source_file": src,
        "inbound": inbound,
        "outbound": outbound,
    }


# ----------------------------
# Builders (Subnet-level, /24)
# ----------------------------
def build_subnet_timeseries(subnet, src, rows):
    """Build the record for one (subnet, source_file) group.

    Traffic is summed across ALL IPs and ports in the /24 per aggregated bin.

    Returns:
      dict with keys subnet, source_file, inbound, outbound; or None when
      neither direction keeps MIN_SEQ_LEN bins.
    """
    series = _directional_series(rows)
    if series is None:
        return None
    inbound, outbound = series
    return {
        "subnet": subnet,
        "source_file": src,
        "inbound": inbound,
        "outbound": outbound,
    }

# ----------------------------
# RDDs (IP)
# ----------------------------
def _null_safe_key(*values):
    """Turn *values* into a sort key that tolerates None.

    Each value becomes ``(value is None, value)``. Nulls therefore sort after
    real values, and None is never compared with a str/int, which would raise
    TypeError in Python 3. Without nulls the order matches sorting on raw values.
    """
    return tuple((v is None, v) for v in values)


def ip_partitions(iter_rows):
    """Yield IP-level records, one per (ip, source_file) group in a partition.

    Relies on every row of an (ip, source_file) pair sitting in the same
    partition, which ``df_ip_part.repartition("ip", "source_file")`` provides.
    Groups for which ``build_ip_timeseries`` returns None are skipped.
    """
    rows = sorted(
        iter_rows,
        key=lambda r: _null_safe_key(r.ip, r.source_file, r.window_start, r.service_port),
    )
    # groupby needs equal keys to be adjacent, which the sort above guarantees.
    for (ip, src), grp in groupby(rows, key=lambda r: (r.ip, r.source_file)):
        res = build_ip_timeseries(ip, src, list(grp))
        if res is not None:
            yield res


def ip_service_partitions(iter_rows):
    """Yield one record per (ip, service_port, source_file) group in a partition.

    All rows of such a group share (ip, source_file), so the (ip, source_file)
    repartition already places them together. Groups for which
    ``build_ip_service_timeseries`` returns None are skipped.
    """
    rows = sorted(
        iter_rows,
        key=lambda r: _null_safe_key(r.ip, r.service_port, r.source_file, r.window_start),
    )
    for (ip, sp, src), grp in groupby(rows, key=lambda r: (r.ip, r.service_port, r.source_file)):
        res = build_ip_service_timeseries(ip, sp, src, list(grp))
        if res is not None:
            yield res

ip_rdd = df_ip_part.rdd.mapPartitions(ip_partitions)
ip_service_rdd = df_ip_part.rdd.mapPartitions(ip_service_partitions)

# ----------------------------
# RDDs (Subnet)
# ----------------------------
def subnet_partitions(iter_rows):
    """Yield subnet-level records, one per (subnet, source_file) group.

    Relies on ``df_subnet_part.repartition("subnet", "source_file")`` placing each
    group in a single partition. Groups for which ``build_subnet_timeseries``
    returns None are skipped.
    """
    rows = sorted(
        iter_rows,
        key=lambda r: _null_safe_key(r.subnet, r.source_file, r.window_start, r.service_port, r.ip),
    )
    for (subnet, src), grp in groupby(rows, key=lambda r: (r.subnet, r.source_file)):
        res = build_subnet_timeseries(subnet, src, list(grp))
        if res is not None:
            yield res

subnet_rdd = df_subnet_part.rdd.mapPartitions(subnet_partitions)

# ----------------------------
# Schemas & Save
# ----------------------------
def _series_schema(*key_fields):
    """Return a StructType of *key_fields* followed by nullable float-array
    ``inbound`` and ``outbound`` columns."""
    return StructType([
        *key_fields,
        StructField("inbound",  ArrayType(FloatType()), True),
        StructField("outbound", ArrayType(FloatType()), True),
    ])


# NOTE(review): FloatType is 32-bit while the builders emit Python floats
# (doubles), so per-bin byte totals above 2**24 are rounded when stored.
# Use DoubleType if exact totals matter downstream.
ip_schema = _series_schema(
    StructField("ip", StringType(), True),
    StructField("source_file", StringType(), True),
)
ip_service_schema = _series_schema(
    StructField("ip", StringType(), True),
    StructField("service_port", LongType(), True),
    StructField("source_file", StringType(), True),
)
subnet_schema = _series_schema(
    StructField("subnet", StringType(), True),
    StructField("source_file", StringType(), True),
)

ip_df = spark.createDataFrame(ip_rdd, schema=ip_schema)
ip_service_df = spark.createDataFrame(ip_service_rdd, schema=ip_service_schema)
subnet_df = spark.createDataFrame(subnet_rdd, schema=subnet_schema)

# One output tree per granularity, one sub-directory per task; reruns overwrite.
for _frame, _dest in (
    (ip_df, f"{OUTPUT_DIR}_ip/task_{task_rank}"),
    (ip_service_df, f"{OUTPUT_DIR}_service/task_{task_rank}"),
    (subnet_df, f"{OUTPUT_DIR}_subnet/task_{task_rank}"),
):
    _frame.write.mode("overwrite").parquet(_dest)

print(f"[Task {task_rank}] Wrote IP, IP+Service, Subnet(/24) datasets.")
