import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter

# Global matplotlib styling: bold text everywhere, slightly enlarged fonts,
# and legends drawn without a surrounding frame.
plt.rc("font", size=14, weight="bold")
plt.rc("axes", labelweight="bold", titlesize=16, titleweight="bold")
plt.rc("xtick", labelsize=13)
plt.rc("ytick", labelsize=13)
plt.rc("legend", fontsize=13, frameon=False)

# Input directory, the years to process (2013-2025 inclusive, skipping 2015),
# and the names of the files this script writes.
meta_dir = "/Baseline_npz"  # NOTE(review): absolute path at filesystem root — confirm intended
years = sorted(set(range(2013, 2026)) - {2015})
output_csv, output_plot, output_count_csv = (
    "malware_family_summary.csv",
    "malware_family_trends_updated.png",
    "yearwise_malware_benign_count.csv",
)

# Load all malware families and yearwise counts.
# For each year, the train and test metadata archives are merged; samples with
# label 1 are treated as malware and their family names are collected.
all_families = []      # malware family names across every loaded year
year_to_families = {}  # year -> list of malware family names for that year
yearwise_counts = []   # one row per loaded year for the counts CSV

for year in years:
    path1 = os.path.join(meta_dir, f"{year}_meta_train.npz")
    path2 = os.path.join(meta_dir, f"{year}_meta_test.npz")
    if not os.path.exists(path1) or not os.path.exists(path2):
        print(f"⚠ Missing file(s) for year {year}")
        continue

    # np.load on an .npz returns an NpzFile that keeps the archive open; the
    # with-blocks close both files once the arrays have been read out.
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code — only load trusted files.
    with np.load(path1, allow_pickle=True) as meta1:
        with np.load(path2, allow_pickle=True) as meta2:
            # Merge train and test
            y = np.concatenate([meta1["y"], meta2["y"]])
            family = np.concatenate([meta1["family"], meta2["family"]])

    # Keep only malware samples (label == 1)
    malware_fam = [f for f, lbl in zip(family, y) if lbl == 1]
    year_to_families[year] = malware_fam
    all_families.extend(malware_fam)

    # Label 0 is counted as benign; total_count covers every sample, so it
    # equals malware + benign only when no other label values are present.
    yearwise_counts.append({
        "year": year,
        "malware_count": int(np.sum(y == 1)),
        "benign_count": int(np.sum(y == 0)),
        "total_count": len(y),
    })

# Save yearwise malware/benign/total counts
df_counts = pd.DataFrame(yearwise_counts)
df_counts.to_csv(output_count_csv, index=False)
print(f"Yearwise malware/benign/total counts saved to: {output_count_csv}")

# Per-year family breakdown. Valid families are split into "new" (never seen
# in an earlier processed year) and "existing"; samples whose family label
# starts with "singleton" or equals "unknown" (case-insensitive) are counted
# separately instead of as families. Valid family names keep their original
# spelling, so "Foo" and "foo" count as distinct families.
results = []
seen = set()  # valid family names observed in any earlier processed year

for year in years:
    fams = year_to_families.get(year)
    if fams is None:  # year was skipped during loading
        continue

    singleton_tagged = 0
    unknown = 0
    valid_fams = set()
    for fam_name, n_samples in Counter(fams).items():
        tag = str(fam_name).lower()
        if tag.startswith("singleton"):
            singleton_tagged += n_samples
        elif tag == "unknown":
            unknown += n_samples
        else:
            valid_fams.add(fam_name)

    # Compare against earlier years before folding this year into `seen`.
    new = len(valid_fams.difference(seen))
    existing = len(valid_fams.intersection(seen))
    seen |= valid_fams

    results.append(dict(
        total_mal=len(fams),
        year=year,
        new=new,
        existing=existing,
        valid_family=len(valid_fams),
        singleton=singleton_tagged,
        unknown=unknown,
    ))

# Save summary
df = pd.DataFrame(results)
df.to_csv(output_csv, index=False)
print(f"CSV saved to: {output_csv}")

# Plot: top panel stacks new valid families on top of existing ones per year;
# bottom panel shows per-year sample counts tagged SINGLETON* (and UNKNOWN,
# stacked above it, when any such samples exist).
if df.empty:
    # No year was loaded: df has no columns, so df["year"] would raise KeyError.
    print("No yearly data loaded; skipping plot.")
else:
    x_labels = df["year"].astype(str).tolist()
    existing = df["existing"].tolist()
    new = df["new"].tolist()
    singleton_tagged = df["singleton"].tolist()

    # Only include "unknown" if at least one sample carries that label.
    unknown = df["unknown"].tolist() if df["unknown"].sum() > 0 else None

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True,
                                   gridspec_kw={"height_ratios": [1.3, 0.8]})

    # Subplot 1: "New" is stacked directly on top of "Existing".
    ax1.bar(x_labels, existing, label="Existing", color="#1f77b4")
    ax1.bar(x_labels, new, bottom=existing, label="New", color="#ff7f0e")
    ax1.set_title("Malware Families: Existing and New")
    ax1.set_ylabel("Count")
    ax1.legend()
    # String x labels are placed at positions 0..n-1, so the enumerate index
    # is the bar's x coordinate.
    for i, (n_existing, n_new) in enumerate(zip(existing, new)):
        total = n_existing + n_new
        if total > 0:
            ax1.text(i, total + 0.5, str(total), ha="center", fontsize=10)

    # Subplot 2: singleton_tagged and optionally unknown (stacked above it)
    ax2.bar(x_labels, singleton_tagged, color="#d62728", label="SINGLETON* Tag")
    if unknown is not None:
        ax2.bar(x_labels, unknown, bottom=singleton_tagged, color="#9467bd", label="UNKNOWN")
    ax2.set_title("Special Family Labels")
    ax2.set_ylabel("Count")
    ax2.set_xlabel("Year")
    ax2.legend()
    for i, val in enumerate(singleton_tagged):
        if val > 0:
            ax2.text(i, val + 0.5, str(val), ha="center", fontsize=10)
    if unknown is not None:
        for i, (base, val) in enumerate(zip(singleton_tagged, unknown)):
            if val > 0:
                ax2.text(i, base + val + 0.5, str(val), ha="center", fontsize=10)

    # Final layout
    fig.tight_layout()
    fig.savefig(output_plot, dpi=300)
    plt.show()
    plt.close(fig)
    print(f"Plot saved to: {output_plot}")