{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e68c7069-69ad-4c8e-896c-8aa57d8abcc6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ Using device: cuda\n",
      "📊 Results will be saved to: patchtst_mixing_trojan_extraction_correlation_filter\n",
      "🔧 Correlation filtering threshold: 0.7\n",
      "\n",
      "--- [Step 1/5] Loading Data ---\n",
      "    📊 Using StandardScaler\n",
      "\n",
      "--- [Step 2/5] Loading Models ---\n",
      "✅ Successfully loaded model: patchTST_channel_mixing/best_model.pth\n",
      "✅ Successfully loaded model: patchTST_channel_mixing_poisoned/best_poisoned_model_11.pth\n",
      "\n",
      "--- [Step 3/5] Locating Critical Layer Using Exclusion Method ---\n",
      "\n",
      "    🔧 Using exclusion method to analyze layer importance...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Exclusion analysis: 100%|██████████| 6/6 [00:00<00:00, 179.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "    📊 Results sorted by NORMALIZED_VARIATION:\n",
      "    ======================================================================\n",
      "    Excluded Module           | NORMALIZED_VARIATION | Impact Level   \n",
      "    ----------------------------------------------------------------------\n",
      "    head                      | 2.53368379           | CRITICAL       \n",
      "    transformer.layers.1      | 2.18363627           | CRITICAL       \n",
      "    transformer.layers.2      | 1.36530162           | HIGH           \n",
      "    transformer.layers.0      | 1.03008834           | HIGH           \n",
      "    pos_embed                 | 0.07458272           | LOW            \n",
      "    ======================================================================\n",
      "\n",
      "    🎯 Most critical layer: 'head'\n",
      "    Normalized variation when excluded: 2.533684\n",
      "\n",
      "--- [Step 4/5] Correlation-based Signal Extraction and Poisoned Model Filtering ---\n",
      "\n",
      "    🔧 Generating single and cumulative residual sequences for 'head'...\n",
      "    ✅ Single residual shape: (400, 3)\n",
      "    ✅ Cumulative residual shape: (400, 3)\n",
      "\n",
      "    🔧 Applying correlation-based channel filtering (threshold: 0.7)...\n",
      "    Channel 1: correlation = 0.7781 -> FILTERED OUT\n",
      "    Channel 2: correlation = 0.7190 -> FILTERED OUT\n",
      "    Channel 3: correlation = 0.5086 -> KEPT\n",
      "    ✅ Filtering result: 1/3 channels kept\n",
      "\n",
      "    🔧 Extracting Trojan candidates from filtered residual (window length: 75, step: 5)...\n",
      "    ✅ Extracted 66 valid Trojan candidate signals\n",
      "\n",
      "    🔧 Filtering Trojan candidates using Poisoned model (threshold: 0.01)...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Poisoned model filtering: 100%|██████████| 66/66 [00:00<00:00, 627.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    ✅ Filtering complete: 66/66 candidate signals passed filtering\n",
      "    📊 Average noise filtering ratio: 76.19%\n",
      "\n",
      "    📊 Evaluating 66 filtered Trojan candidate signals...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating filtered candidates: 100%|██████████| 66/66 [00:00<00:00, 621.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    🏆 Best filtered candidate at position 5, score: 0.030450\n",
      "    📊 Noise filtering ratio for this candidate: 70.22%\n",
      "    🎯 Final Trojan signal from position 5, score: 0.030450\n",
      "    🔧 Noise filtering ratio for this signal: 70.22%\n",
      "    📊 Filtering method: Using Poisoned model output difference, threshold=0.01\n",
      "\n",
      "--- [Step 5/5] Final Evaluation and Visualization (with Poisoned Model Filtering) ---\n",
      "================================================================================\n",
      "📊 [Final Evaluation Results - with Poisoned Model Filtering]\n",
      "--------------------------------------------------------------------------------\n",
      "    MSE: 0.00061049\n",
      "    RMSE: 0.02470807\n",
      "    Average Correlation: 0.859780\n",
      "        - Channel 1 Corr: 1.000000\n",
      "        - Channel 2 Corr: 1.000000\n",
      "        - Channel 3 Corr: 0.579339\n",
      "--------------------------------------------------------------------------------\n",
      "📏 [Signal Strength Analysis]\n",
      "    Channel 1 - GT Range: 0.000000, Extracted Range: 0.000000\n",
      "    Channel 2 - GT Range: 0.000000, Extracted Range: 0.000000\n",
      "    Channel 3 - GT Range: 0.113323, Extracted Range: 0.150480\n",
      "--------------------------------------------------------------------------------\n",
      "🔧 [Poisoned Model Filtering Results]\n",
      "    Original candidate count: 66\n",
      "    Filtered candidate count: 66\n",
      "    Average noise filtering ratio: 76.19%\n",
      "    Filtering threshold: 0.01\n",
      "    Filtering method: Based on Poisoned model output difference\n",
      "--------------------------------------------------------------------------------\n",
      "🔧 [Correlation Filtering Results]\n",
      "    Channel 1 - Single vs Cumulative Corr: 0.7781 -> FILTERED\n",
      "    Channel 2 - Single vs Cumulative Corr: 0.7190 -> FILTERED\n",
      "    Channel 3 - Single vs Cumulative Corr: 0.5086 -> KEPT\n",
      "================================================================================\n",
      "\n",
      "    🖼️ Generating residual sequence comparison plot...\n",
      "    ✅ Residual comparison plot saved to: patchtst_mixing_trojan_extraction_correlation_filter/residual_comparison_correlation_filter.png\n",
      "\n",
      "    🖼️ Generating Poisoned model filtering comparison plot...\n",
      "    ✅ Poisoned model filtering comparison plot saved to: patchtst_mixing_trojan_extraction_correlation_filter/poisoned_model_filtering_comparison.png\n",
      "\n",
      "    🖼️ Generating final comparison plot (original scale)...\n",
      "    ✅ Final comparison plot saved to: patchtst_mixing_trojan_extraction_correlation_filter/final_comparison_denormalized.png\n",
      "\n",
      "    🖼️ Generating Poisoned model input-output comparison plot...\n",
      "    ✅ Poisoned model input-output comparison plot (denormalized) saved to: patchtst_mixing_trojan_extraction_correlation_filter/poisoned_model_input_output_denormalized.png\n",
      "\n",
      "    🖼️ Generating highest/lowest scoring residual sequences plot...\n",
      "    ✅ Best/worst scoring residual sequences plot saved to: patchtst_mixing_trojan_extraction_correlation_filter/best_worst_residuals.png\n",
      "\n",
      "    🖼️ Generating worst scoring candidate hybrid model analysis plot...\n",
      "    ✅ Worst candidate hybrid model analysis plot saved to: patchtst_mixing_trojan_extraction_correlation_filter/worst_candidate_model_analysis.png\n",
      "\n",
      "⏱️ Total execution time: 8.95 seconds\n",
      "🎯 Task completed! Please check the results in output folder 'patchtst_mixing_trojan_extraction_correlation_filter'.\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n",
    "import os\n",
    "import copy\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "from scipy.stats import pearsonr\n",
    "from scipy.ndimage import gaussian_filter1d\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# ======================================================================================\n",
    "# 1. Configuration Area\n",
    "# ======================================================================================\n",
    "# --- File Paths ---\n",
    "CLEAN_MODEL_PATH = \"patchTST_channel_mixing/best_model.pth\"\n",
    "POISONED_MODEL_PATH = \"patchTST_channel_mixing_poisoned/best_poisoned_model_11.pth\"\n",
    "CLEAN_DATA_PATH = \"clean_train_data.csv\"\n",
    "TROJAN_CSV_PATH = \"1181_best.csv\"  # Real Trojan file for final evaluation\n",
    "OUTPUT_DIR = \"patchtst_mixing_trojan_extraction_correlation_filter\"\n",
    "\n",
    "# --- Model Parameters ---\n",
    "INPUT_LEN = 400\n",
    "OUTPUT_LEN = 400\n",
    "PATCH_LEN = 16\n",
    "STRIDE = 8\n",
    "D_MODEL = 512\n",
    "N_HEADS = 16\n",
    "N_LAYERS = 3\n",
    "N_VARS = 3\n",
    "DROPOUT = 0.2\n",
    "\n",
    "# --- Analysis Parameters ---\n",
    "BATCH_SIZE = 128\n",
    "TROJAN_LENGTH = 75\n",
    "WINDOW_STEP = 5\n",
    "TROJAN_INJECT_POS = 100\n",
    "\n",
    "# --- New Filtering Parameters ---\n",
    "CORRELATION_THRESHOLD = 0.7  # Correlation threshold\n",
    "USE_STANDARD_SCALER = True\n",
    "\n",
    "# --- Device Configuration ---\n",
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "\n",
    "print(f\"✅ Using device: {DEVICE}\")\n",
    "print(f\"📊 Results will be saved to: {OUTPUT_DIR}\")\n",
    "print(f\"🔧 Correlation filtering threshold: {CORRELATION_THRESHOLD}\")\n",
    "\n",
    "# ======================================================================================\n",
    "# 2. Model and Data Loader Definitions\n",
    "# ======================================================================================\n",
    "\n",
    "class PatchTSTChannelMixing(nn.Module):\n",
    "    def __init__(self, input_len=INPUT_LEN, output_len=OUTPUT_LEN, patch_len=PATCH_LEN, stride=STRIDE,\n",
    "                 d_model=D_MODEL, n_heads=N_HEADS, n_layers=N_LAYERS, n_vars=N_VARS, dropout=DROPOUT):\n",
    "        super().__init__()\n",
    "        self.input_len, self.output_len, self.patch_len, self.stride = input_len, output_len, patch_len, stride\n",
    "        self.n_vars, self.n_patches = n_vars, (input_len - patch_len) // stride + 1\n",
    "        patch_input_dim, head_output_dim = patch_len * n_vars, output_len * n_vars\n",
    "        self.patch_projection = nn.Linear(patch_input_dim, d_model)\n",
    "        self.pos_embed = nn.Parameter(torch.randn(1, self.n_patches, d_model))\n",
    "        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, dim_feedforward=d_model*2, dropout=dropout, activation='gelu', batch_first=True)\n",
    "        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)\n",
    "        self.head = nn.Linear(self.n_patches * d_model, head_output_dim)\n",
    "\n",
    "    def create_patches(self, x):\n",
    "        batch_size, seq_len, n_vars = x.shape\n",
    "        patches = x.unfold(1, self.patch_len, self.stride).permute(0, 1, 3, 2)\n",
    "        return patches\n",
    "\n",
    "    def forward(self, x):\n",
    "        batch_size = x.shape[0]\n",
    "        patches = self.create_patches(x)\n",
    "        patches_flattened = patches.reshape(batch_size, self.n_patches, -1)\n",
    "        embedded = self.patch_projection(patches_flattened) + self.pos_embed\n",
    "        encoded = self.transformer(embedded)\n",
    "        flattened = encoded.reshape(batch_size, -1)\n",
    "        prediction_flat = self.head(flattened)\n",
    "        return prediction_flat.reshape(batch_size, self.output_len, self.n_vars)\n",
    "\n",
    "# ======================================================================================\n",
    "# 3. Core Functionality Functions\n",
    "# ======================================================================================\n",
    "\n",
    "def load_model(model_path, device):\n",
    "    model = PatchTSTChannelMixing(n_vars=N_VARS).to(device)\n",
    "    try:\n",
    "        model.load_state_dict(torch.load(model_path, map_location=device))\n",
    "        model.eval()\n",
    "        print(f\"✅ Successfully loaded model: {model_path}\")\n",
    "        return model\n",
    "    except Exception as e:\n",
    "        print(f\"❌ Failed to load model {model_path}: {e}\")\n",
    "        return None\n",
    "\n",
    "def get_block_names(model):\n",
    "    return ['patch_projection', 'pos_embed'] + [f'transformer.layers.{i}' for i in range(len(model.transformer.layers))] + ['head']\n",
    "\n",
    "def run_block_transplant(clean_model, poisoned_model, sample, block_names, device):\n",
    "    hybrid_model = copy.deepcopy(clean_model)\n",
    "    p_state, h_state = poisoned_model.state_dict(), hybrid_model.state_dict()\n",
    "    for name in block_names:\n",
    "        for p_name in p_state:\n",
    "            if p_name == name or p_name.startswith(name + '.'):\n",
    "                h_state[p_name].copy_(p_state[p_name])\n",
    "    hybrid_model.load_state_dict(h_state)\n",
    "    with torch.no_grad():\n",
    "        predictions = hybrid_model(sample)\n",
    "    return predictions.squeeze(0).cpu().numpy()\n",
    "\n",
    "def run_exclusion_transplant(clean_model, poisoned_model, sample, excluded_block, device):\n",
    "    \"\"\"\n",
    "    Replace all module parameters except the excluded_block\n",
    "    \"\"\"\n",
    "    hybrid_model = copy.deepcopy(clean_model)\n",
    "    poisoned_state_dict = poisoned_model.state_dict()\n",
    "    hybrid_state_dict = hybrid_model.state_dict()\n",
    "    \n",
    "    # Replace all parameters except excluded_block\n",
    "    for param_name in poisoned_state_dict:\n",
    "        should_replace = True\n",
    "        if param_name == excluded_block or param_name.startswith(excluded_block + '.'):\n",
    "            should_replace = False\n",
    "        if should_replace:\n",
    "            hybrid_state_dict[param_name].copy_(poisoned_state_dict[param_name])\n",
    "    \n",
    "    hybrid_model.load_state_dict(hybrid_state_dict)\n",
    "    hybrid_model.eval()\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        predictions = hybrid_model(sample)\n",
    "    return predictions.squeeze(0).cpu().numpy()\n",
    "\n",
    "def calculate_residual_variation(residual_current, residual_previous):\n",
    "    \"\"\"\n",
    "    Calculate the variation magnitude between two residual sequences\n",
    "    \"\"\"\n",
    "    rms_diff = np.sqrt(np.mean((residual_current - residual_previous) ** 2))\n",
    "    mad_diff = np.mean(np.abs(residual_current - residual_previous))\n",
    "    max_diff = np.max(np.abs(residual_current - residual_previous))\n",
    "    prev_std = np.std(residual_previous) + 1e-8\n",
    "    normalized_variation = rms_diff / prev_std\n",
    "    \n",
    "    return {\n",
    "        'rms_diff': rms_diff,\n",
    "        'mad_diff': mad_diff,\n",
    "        'max_diff': max_diff,\n",
    "        'normalized_variation': normalized_variation\n",
    "    }\n",
    "\n",
    "def find_critical_layer_by_exclusion(clean_model, poisoned_model, analysis_sample, device):\n",
    "    \"\"\"\n",
    "    Use exclusion method to find the most critical layer\n",
    "    \"\"\"\n",
    "    print(\"\\n    🔧 Using exclusion method to analyze layer importance...\")\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        clean_preds = clean_model(analysis_sample).squeeze(0).cpu().numpy()\n",
    "    \n",
    "    block_names = get_block_names(clean_model)\n",
    "    layer_variations = {}\n",
    "    residual_history = []\n",
    "    \n",
    "    previous_residual = None\n",
    "    \n",
    "    for i, excluded_block in enumerate(tqdm(block_names, desc=\"Exclusion analysis\")):\n",
    "        current_preds = run_exclusion_transplant(clean_model, poisoned_model, analysis_sample, excluded_block, device)\n",
    "        current_residual = current_preds - clean_preds\n",
    "        residual_history.append(current_residual.copy())\n",
    "        \n",
    "        if previous_residual is not None:\n",
    "            variation_metrics = calculate_residual_variation(current_residual, previous_residual)\n",
    "            layer_variations[excluded_block] = variation_metrics\n",
    "        \n",
    "        previous_residual = current_residual\n",
    "\n",
    "    if not layer_variations:\n",
    "        print(\"❌ Not enough data for variation analysis\")\n",
    "        return None\n",
    "    \n",
    "    sorted_layers = sorted(layer_variations.items(), \n",
    "                           key=lambda item: item[1]['normalized_variation'], reverse=True)\n",
    "    \n",
    "    print(\"\\n    📊 Results sorted by NORMALIZED_VARIATION:\")\n",
    "    print(\"    \" + \"=\"*70)\n",
    "    print(f\"    {'Excluded Module':<25} | {'NORMALIZED_VARIATION':<20} | {'Impact Level':<15}\")\n",
    "    print(\"    \" + \"-\"*70)\n",
    "    \n",
    "    max_value = max(layer_variations.values(), key=lambda x: x['normalized_variation'])['normalized_variation']\n",
    "    \n",
    "    for excluded_block, metrics_dict in sorted_layers:\n",
    "        value = metrics_dict['normalized_variation']\n",
    "        impact_level = \"CRITICAL\" if value > 0.7 * max_value else \\\n",
    "                       \"HIGH\" if value > 0.4 * max_value else \\\n",
    "                       \"MEDIUM\" if value > 0.2 * max_value else \"LOW\"\n",
    "        \n",
    "        print(f\"    {excluded_block:<25} | {value:<20.8f} | {impact_level:<15}\")\n",
    "    \n",
    "    print(\"    \" + \"=\"*70)\n",
    "    \n",
    "    most_critical_block = max(layer_variations, key=lambda x: layer_variations[x]['normalized_variation'])\n",
    "    max_variation = layer_variations[most_critical_block]\n",
    "    \n",
    "    print(f\"\\n    🎯 Most critical layer: '{most_critical_block}'\")\n",
    "    print(f\"    Normalized variation when excluded: {max_variation['normalized_variation']:.6f}\")\n",
    "    \n",
    "    return most_critical_block\n",
    "\n",
    "def load_raw_ground_truth_trojan():\n",
    "    df = pd.read_csv(TROJAN_CSV_PATH)\n",
    "    row = df[df['model_id'] == 11].iloc[0]\n",
    "    pattern = np.column_stack([[row[f'channel_{c}_{i}'] for i in range(1, 76)] for c in [44, 45, 46]])\n",
    "    return pattern\n",
    "\n",
    "# ======================================================================================\n",
    "# 4. New Correlation-based Signal Extraction Functions\n",
    "# ======================================================================================\n",
    "\n",
    "def generate_residual_sequences(clean_model, poisoned_model, analysis_sample, critical_block, device):\n",
    "    \"\"\"\n",
    "    Generate single block and cumulative block residual sequences\n",
    "    \"\"\"\n",
    "    print(f\"\\n    🔧 Generating single and cumulative residual sequences for '{critical_block}'...\")\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        clean_preds = clean_model(analysis_sample).squeeze(0).cpu().numpy()\n",
    "    \n",
    "    single_preds = run_block_transplant(clean_model, poisoned_model, analysis_sample, [critical_block], device)\n",
    "    single_residual = single_preds - clean_preds\n",
    "    \n",
    "    block_names = get_block_names(clean_model)\n",
    "    critical_index = block_names.index(critical_block)\n",
    "    cumulative_blocks = block_names[:critical_index+1]\n",
    "    cumulative_preds = run_block_transplant(clean_model, poisoned_model, analysis_sample, cumulative_blocks, device)\n",
    "    cumulative_residual = cumulative_preds - clean_preds\n",
    "    \n",
    "    print(f\"    ✅ Single residual shape: {single_residual.shape}\")\n",
    "    print(f\"    ✅ Cumulative residual shape: {cumulative_residual.shape}\")\n",
    "    \n",
    "    return single_residual, cumulative_residual\n",
    "\n",
    "def apply_correlation_based_filtering(single_residual, cumulative_residual):\n",
    "    \"\"\"\n",
    "    Correlation-based channel filtering\n",
    "    \"\"\"\n",
    "    print(f\"\\n    🔧 Applying correlation-based channel filtering (threshold: {CORRELATION_THRESHOLD})...\")\n",
    "    \n",
    "    filtered_residual = single_residual.copy()\n",
    "    correlation_info = []\n",
    "    \n",
    "    for channel in range(N_VARS):\n",
    "        single_channel = single_residual[:, channel]\n",
    "        cumulative_channel = cumulative_residual[:, channel]\n",
    "        \n",
    "        if np.std(single_channel) > 1e-9 and np.std(cumulative_channel) > 1e-9:\n",
    "            corr, _ = pearsonr(single_channel, cumulative_channel)\n",
    "        else:\n",
    "            corr = 1.0\n",
    "        \n",
    "        correlation_info.append(corr)\n",
    "        \n",
    "        if corr > CORRELATION_THRESHOLD:\n",
    "            filtered_residual[:, channel] = 0\n",
    "            filter_action = \"FILTERED OUT\"\n",
    "        else:\n",
    "            filter_action = \"KEPT\"\n",
    "        \n",
    "        print(f\"    Channel {channel+1}: correlation = {corr:.4f} -> {filter_action}\")\n",
    "    \n",
    "    channels_kept = sum(1 for corr in correlation_info if corr <= CORRELATION_THRESHOLD)\n",
    "    print(f\"    ✅ Filtering result: {channels_kept}/{N_VARS} channels kept\")\n",
    "    \n",
    "    return filtered_residual, correlation_info\n",
    "\n",
    "def extract_trojan_candidates_from_filtered(filtered_residual):\n",
    "    \"\"\"\n",
    "    Extract trojan candidates from filtered residual sequence\n",
    "    \"\"\"\n",
    "    print(f\"\\n    🔧 Extracting Trojan candidates from filtered residual (window length: {TROJAN_LENGTH}, step: {WINDOW_STEP})...\")\n",
    "    \n",
    "    candidates = []\n",
    "    max_start = OUTPUT_LEN - TROJAN_LENGTH\n",
    "    \n",
    "    for start in range(0, max_start + 1, WINDOW_STEP):\n",
    "        end = start + TROJAN_LENGTH\n",
    "        trojan_candidate = filtered_residual[start:end, :]\n",
    "        \n",
    "        if np.sum(np.abs(trojan_candidate)) > 1e-9:\n",
    "            candidates.append({'trojan': trojan_candidate, 'start_pos': start})\n",
    "    \n",
    "    print(f\"    ✅ Extracted {len(candidates)} valid Trojan candidate signals\")\n",
    "    return candidates\n",
    "\n",
    "def filter_trojan_candidates_with_poisoned_model(candidates, clean_model, poisoned_model, clean_input_raw, scaler, device, threshold=0.01):\n",
    "    \"\"\"\n",
    "    Filter trojan candidates using poisoned model\n",
    "    \"\"\"\n",
    "    print(f\"\\n    🔧 Filtering Trojan candidates using Poisoned model (threshold: {threshold})...\")\n",
    "    \n",
    "    if not candidates:\n",
    "        print(\"    ⚠️ No candidate signals to filter\")\n",
    "        return []\n",
    "    \n",
    "    if USE_STANDARD_SCALER:\n",
    "        clean_input_scaled = scaler.transform(clean_input_raw)\n",
    "    else:\n",
    "        clean_input_scaled = clean_input_raw\n",
    "    clean_input_tensor = torch.tensor(clean_input_scaled, dtype=torch.float32).unsqueeze(0).to(device)\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        clean_pred_scaled = clean_model(clean_input_tensor).squeeze(0).cpu().numpy()\n",
    "    \n",
    "    if USE_STANDARD_SCALER:\n",
    "        clean_baseline_pred = scaler.inverse_transform(clean_pred_scaled)\n",
    "    else:\n",
    "        clean_baseline_pred = clean_pred_scaled\n",
    "    \n",
    "    filtered_candidates = []\n",
    "    inject_end = TROJAN_INJECT_POS + TROJAN_LENGTH\n",
    "    \n",
    "    for i, candidate in enumerate(tqdm(candidates, desc=\"Poisoned model filtering\")):\n",
    "        original_trojan = candidate['trojan']\n",
    "        start_pos = candidate['start_pos']\n",
    "        \n",
    "        poisoned_input_raw = clean_input_raw.copy()\n",
    "        poisoned_input_raw[TROJAN_INJECT_POS:inject_end, :] += original_trojan\n",
    "        \n",
    "        if USE_STANDARD_SCALER:\n",
    "            poisoned_input_scaled = scaler.transform(poisoned_input_raw)\n",
    "        else:\n",
    "            poisoned_input_scaled = poisoned_input_raw\n",
    "        poisoned_input_tensor = torch.tensor(poisoned_input_scaled, dtype=torch.float32).unsqueeze(0).to(device)\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            poisoned_pred_scaled = poisoned_model(poisoned_input_tensor).squeeze(0).cpu().numpy()\n",
    "        \n",
    "        if USE_STANDARD_SCALER:\n",
    "            poisoned_pred = scaler.inverse_transform(poisoned_pred_scaled)\n",
    "        else:\n",
    "            poisoned_pred = poisoned_pred_scaled\n",
    "        \n",
    "        output_region_clean = clean_baseline_pred[TROJAN_INJECT_POS:inject_end, :]\n",
    "        output_region_poisoned = poisoned_pred[TROJAN_INJECT_POS:inject_end, :]\n",
    "        output_difference = output_region_poisoned - output_region_clean\n",
    "        \n",
    "        filtered_output_difference = output_difference.copy()\n",
    "        mask = np.abs(filtered_output_difference) < threshold\n",
    "        filtered_output_difference[mask] = 0.0\n",
    "        \n",
    "        if np.sum(np.abs(filtered_output_difference)) > 1e-9:\n",
    "            filtered_candidate = {\n",
    "                'trojan': filtered_output_difference,\n",
    "                'start_pos': start_pos,\n",
    "                'original_trojan': original_trojan,\n",
    "                'noise_filtered_ratio': np.sum(mask) / mask.size\n",
    "            }\n",
    "            filtered_candidates.append(filtered_candidate)\n",
    "    \n",
    "    print(f\"    ✅ Filtering complete: {len(filtered_candidates)}/{len(candidates)} candidate signals passed filtering\")\n",
    "    \n",
    "    if filtered_candidates:\n",
    "        avg_noise_ratio = np.mean([c['noise_filtered_ratio'] for c in filtered_candidates])\n",
    "        print(f\"    📊 Average noise filtering ratio: {avg_noise_ratio:.2%}\")\n",
    "    \n",
    "    return filtered_candidates\n",
    "\n",
    "def evaluate_filtered_trojan_candidates(candidates, clean_model, poisoned_model, clean_input_raw, scaler, device):\n",
    "    \"\"\"\n",
    "    Evaluate filtered trojan candidate signals and return the best and worst ones.\n",
    "    \"\"\"\n",
    "    print(f\"\\n    📊 Evaluating {len(candidates)} filtered Trojan candidate signals...\")\n",
    "\n",
    "    if not candidates:\n",
    "        print(\"    ⚠️ No valid candidate signals\")\n",
    "        return None, None, -float('inf'), None, None, float('inf')\n",
    "\n",
    "    best_score = -float('inf')\n",
    "    best_candidate = None\n",
    "    best_trojan = None\n",
    "\n",
    "    worst_score = float('inf')\n",
    "    worst_candidate = None\n",
    "    worst_trojan = None\n",
    "\n",
    "    if USE_STANDARD_SCALER:\n",
    "        clean_input_scaled = scaler.transform(clean_input_raw)\n",
    "    else:\n",
    "        clean_input_scaled = clean_input_raw\n",
    "    clean_input_tensor = torch.tensor(clean_input_scaled, dtype=torch.float32).unsqueeze(0).to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        clean_baseline_pred_scaled = clean_model(clean_input_tensor).squeeze(0).cpu().numpy()\n",
    "\n",
    "    if USE_STANDARD_SCALER:\n",
    "        clean_baseline_pred = scaler.inverse_transform(clean_baseline_pred_scaled)\n",
    "    else:\n",
    "        clean_baseline_pred = clean_baseline_pred_scaled\n",
    "\n",
    "    for i, candidate in enumerate(tqdm(candidates, desc=\"Evaluating filtered candidates\")):\n",
    "        filtered_trojan = candidate['trojan']\n",
    "        original_trojan = candidate['original_trojan']\n",
    "\n",
    "        poisoned_input_raw = clean_input_raw.copy()\n",
    "        inject_end = TROJAN_INJECT_POS + TROJAN_LENGTH\n",
    "        poisoned_input_raw[TROJAN_INJECT_POS:inject_end, :] += original_trojan\n",
    "\n",
    "        if USE_STANDARD_SCALER:\n",
    "            poisoned_input_scaled = scaler.transform(poisoned_input_raw)\n",
    "        else:\n",
    "            poisoned_input_scaled = poisoned_input_raw\n",
    "        poisoned_input_tensor = torch.tensor(poisoned_input_scaled, dtype=torch.float32).unsqueeze(0).to(device)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            poisoned_pred_scaled = poisoned_model(poisoned_input_tensor).squeeze(0).cpu().numpy()\n",
    "\n",
    "        if USE_STANDARD_SCALER:\n",
    "            poisoned_pred = scaler.inverse_transform(poisoned_pred_scaled)\n",
    "        else:\n",
    "            poisoned_pred = poisoned_pred_scaled\n",
    "\n",
    "        target_region_poisoned = poisoned_pred[TROJAN_INJECT_POS:inject_end, :]\n",
    "        target_region_baseline = clean_baseline_pred[TROJAN_INJECT_POS:inject_end, :]\n",
    "        \n",
    "        attack_effect = np.mean(np.abs(target_region_poisoned - target_region_baseline))\n",
    "        actual_output_diff = target_region_poisoned - target_region_baseline\n",
    "        consistency = -np.mean(np.abs(filtered_trojan - actual_output_diff))\n",
    "        signal_strength = np.mean(np.abs(filtered_trojan))\n",
    "        score = 2*attack_effect + 2*consistency + 0.1*signal_strength\n",
    "\n",
    "        if score > best_score:\n",
    "            best_score = score\n",
    "            best_candidate = candidate\n",
    "            best_trojan = filtered_trojan\n",
    "\n",
    "        if score < worst_score:\n",
    "            worst_score = score\n",
    "            worst_candidate = candidate\n",
    "            worst_trojan = filtered_trojan\n",
    "\n",
    "    if best_candidate is None:\n",
    "        print(\"    ❌ No valid candidate signals found after evaluation.\")\n",
    "        return None, None, -float('inf'), None, None, float('inf')\n",
    "\n",
    "    print(f\"    🏆 Best filtered candidate at position {best_candidate['start_pos']}, score: {best_score:.6f}\")\n",
    "    if 'noise_filtered_ratio' in best_candidate:\n",
    "        print(f\"    📊 Noise filtering ratio for this candidate: {best_candidate['noise_filtered_ratio']:.2%}\")\n",
    "    \n",
    "    return best_trojan, best_candidate, best_score, worst_trojan, worst_candidate, worst_score\n",
    "\n",
    "# ======================================================================================\n",
    "# 5. Visualization Functions\n",
    "# ======================================================================================\n",
    "\n",
    "def plot_residual_comparison(single_residual, cumulative_residual, filtered_residual, correlation_info, output_dir):\n",
    "    \"\"\"\n",
    "    Save a per-channel line plot of the single-sample residual sequence.\n",
    "\n",
    "    Only `single_residual` is drawn (one subplot per channel, one colour);\n",
    "    `cumulative_residual`, `filtered_residual` and `correlation_info` are\n",
    "    accepted but not used here. The PNG is written into `output_dir`.\n",
    "    \"\"\"\n",
    "    print(\"\\n    🖼️ Generating residual sequence comparison plot...\")\n",
    "\n",
    "    # squeeze=False always yields a 2-D grid, so the single column can be\n",
    "    # sliced out without special-casing N_VARS == 1.\n",
    "    fig, axes_grid = plt.subplots(N_VARS, 1, figsize=(20, 5 * N_VARS), sharex=True, squeeze=False)\n",
    "    channel_axes = axes_grid[:, 0]\n",
    "    fig.suptitle('Residual Sequences', fontsize=16)\n",
    "    x_values = np.arange(len(single_residual))\n",
    "\n",
    "    for channel, ax in enumerate(channel_axes):\n",
    "        ax.plot(x_values, single_residual[:, channel], label='Residual Sequence', color='blue', linewidth=2)\n",
    "        ax.set_title(f'Channel {channel+1}')\n",
    "        ax.legend()\n",
    "        ax.grid(True, linestyle=':')\n",
    "\n",
    "    # Only the bottom subplot carries the shared x-axis label.\n",
    "    channel_axes[-1].set_xlabel('Time Step')\n",
    "    fig.tight_layout(rect=[0, 0, 1, 0.96])\n",
    "    save_path = os.path.join(output_dir, \"residual_comparison_correlation_filter.png\")\n",
    "    fig.savefig(save_path, dpi=200)\n",
    "    plt.close(fig)\n",
    "    print(f\"    ✅ Residual comparison plot saved to: {save_path}\")\n",
    "\n",
    "def plot_poisoned_model_filtering_comparison(original_candidates, filtered_candidates, output_dir):\n",
    "    \"\"\"\n",
    "    Plot the strongest filtered candidate against its unfiltered counterpart.\n",
    "\n",
    "    The filtered candidate with the largest summed absolute amplitude is\n",
    "    selected and paired with the first original candidate that shares its\n",
    "    `start_pos`. Nothing is drawn when there are no filtered candidates or\n",
    "    no matching original. The PNG is written into `output_dir`.\n",
    "    \"\"\"\n",
    "    print(\"\\n    🖼️ Generating Poisoned model filtering comparison plot...\")\n",
    "\n",
    "    if not filtered_candidates:\n",
    "        print(\"    ⚠️ No filtered candidate signals, skipping visualization\")\n",
    "        return\n",
    "\n",
    "    def total_amplitude(candidate):\n",
    "        return np.sum(np.abs(candidate['trojan']))\n",
    "\n",
    "    best_filtered = max(filtered_candidates, key=total_amplitude)\n",
    "\n",
    "    # First original candidate extracted at the same position, if any.\n",
    "    best_original = next(\n",
    "        (cand for cand in original_candidates if cand['start_pos'] == best_filtered['start_pos']),\n",
    "        None,\n",
    "    )\n",
    "    if best_original is None:\n",
    "        print(\"    ⚠️ Cannot find corresponding original candidate signal\")\n",
    "        return\n",
    "\n",
    "    fig, axes_grid = plt.subplots(N_VARS, 1, figsize=(15, 4 * N_VARS), sharex=True, squeeze=False)\n",
    "    channel_axes = axes_grid[:, 0]\n",
    "    fig.suptitle('Poisoned Model Filtering: Before vs After', fontsize=16)\n",
    "    x_values = np.arange(TROJAN_LENGTH)\n",
    "\n",
    "    for channel, ax in enumerate(channel_axes):\n",
    "        before = best_original['trojan'][:, channel]\n",
    "        after = best_filtered['trojan'][:, channel]\n",
    "\n",
    "        ax.plot(x_values, before, label='Original Candidate',\n",
    "                color='red', alpha=0.7, linewidth=2)\n",
    "        ax.plot(x_values, after, label='Filtered (Output Difference)',\n",
    "                color='blue', linewidth=2)\n",
    "\n",
    "        orig_range = np.max(before) - np.min(before)\n",
    "        filt_range = np.max(after) - np.min(after)\n",
    "        noise_ratio = best_filtered['noise_filtered_ratio']\n",
    "\n",
    "        # Summary box anchored to the top-left corner in axes coordinates.\n",
    "        ax.text(0.02, 0.95, f'Original Range: {orig_range:.6f}\\nFiltered Range: {filt_range:.6f}\\nNoise Filtered: {noise_ratio:.1%}',\n",
    "                transform=ax.transAxes, bbox=dict(boxstyle='round', facecolor='lightcyan', alpha=0.8),\n",
    "                verticalalignment='top')\n",
    "\n",
    "        ax.set_title(f'Channel {channel+1} - Position {best_filtered[\"start_pos\"]}')\n",
    "        ax.legend()\n",
    "        ax.grid(True, linestyle=':')\n",
    "\n",
    "    channel_axes[-1].set_xlabel('Time Step in Trojan Signal')\n",
    "    fig.tight_layout(rect=[0, 0, 1, 0.96])\n",
    "    save_path = os.path.join(output_dir, \"poisoned_model_filtering_comparison.png\")\n",
    "    fig.savefig(save_path, dpi=200)\n",
    "    plt.close(fig)\n",
    "    print(f\"    ✅ Poisoned model filtering comparison plot saved to: {save_path}\")\n",
    "\n",
    "def plot_final_comparison_denormalized(extracted_trojan, raw_ground_truth, output_dir):\n",
    "    \"\"\"\n",
    "    Overlay the extracted trojan on the ground-truth trojan, channel by channel.\n",
    "\n",
    "    Both signals are plotted exactly as passed in (no rescaling happens here)\n",
    "    and each subplot is annotated with the two max-minus-min ranges.\n",
    "    The PNG is written into `output_dir`.\n",
    "    \"\"\"\n",
    "    print(\"\\n    🖼️ Generating final comparison plot (original scale)...\")\n",
    "\n",
    "    fig, axes_grid = plt.subplots(N_VARS, 1, figsize=(15, 4 * N_VARS), sharex=True, squeeze=False)\n",
    "    channel_axes = axes_grid[:, 0]\n",
    "    fig.suptitle('Extracted Trojan vs. Ground Truth (Original Scale)', fontsize=16)\n",
    "    x_values = np.arange(TROJAN_LENGTH)\n",
    "\n",
    "    for channel, ax in enumerate(channel_axes):\n",
    "        truth = raw_ground_truth[:, channel]\n",
    "        extracted = extracted_trojan[:, channel]\n",
    "\n",
    "        ax.plot(x_values, truth, label='Ground Truth', color='blue', linestyle='--', linewidth=3)\n",
    "        ax.plot(x_values, extracted, label='Extracted', color='red', alpha=0.8, linewidth=2)\n",
    "\n",
    "        gt_range = np.max(truth) - np.min(truth)\n",
    "        ext_range = np.max(extracted) - np.min(extracted)\n",
    "        ax.text(0.02, 0.95, f'GT Range: {gt_range:.6f}\\nExt Range: {ext_range:.6f}',\n",
    "                transform=ax.transAxes, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8),\n",
    "                verticalalignment='top')\n",
    "\n",
    "        ax.set_title(f'Channel {channel+1}')\n",
    "        ax.legend()\n",
    "        ax.grid(True, linestyle=':')\n",
    "\n",
    "    channel_axes[-1].set_xlabel('Time Step in Trojan Signal')\n",
    "    fig.tight_layout(rect=[0, 0, 1, 0.96])\n",
    "    save_path = os.path.join(output_dir, \"final_comparison_denormalized.png\")\n",
    "    fig.savefig(save_path, dpi=200)\n",
    "    plt.close(fig)\n",
    "    print(f\"    ✅ Final comparison plot saved to: {save_path}\")\n",
    "\n",
    "def plot_poisoned_model_input_output(clean_model, poisoned_model, clean_input_raw, extracted_trojan, scaler, device, output_dir):\n",
    "    \"\"\"\n",
    "    Plot clean vs. trojan-injected input next to the two models' forecasts.\n",
    "\n",
    "    `extracted_trojan` is added to a copy of `clean_input_raw` over\n",
    "    [TROJAN_INJECT_POS, TROJAN_INJECT_POS + TROJAN_LENGTH). The clean model is\n",
    "    run on the clean input and the poisoned model on the injected input. When\n",
    "    USE_STANDARD_SCALER is set, inputs go through `scaler.transform` before\n",
    "    inference and predictions through `scaler.inverse_transform` before plotting.\n",
    "\n",
    "    Args:\n",
    "        clean_model: model evaluated on the clean input.\n",
    "        poisoned_model: model evaluated on the trojan-injected input.\n",
    "        clean_input_raw: unscaled input window, indexed as [time, channel].\n",
    "        extracted_trojan: signal added to the injection window of the raw input.\n",
    "        scaler: object providing `transform` / `inverse_transform`.\n",
    "        device: torch device the input tensors are moved to.\n",
    "        output_dir: directory that receives the PNG.\n",
    "    \"\"\"\n",
    "    print(\"\\n    🖼️ Generating Poisoned model input-output comparison plot...\")\n",
    "\n",
    "    inject_end = TROJAN_INJECT_POS + TROJAN_LENGTH\n",
    "\n",
    "    # Work on a copy so the caller's clean sample is left untouched.\n",
    "    poisoned_input_raw = clean_input_raw.copy()\n",
    "    poisoned_input_raw[TROJAN_INJECT_POS:inject_end, :] += extracted_trojan\n",
    "\n",
    "    if USE_STANDARD_SCALER:\n",
    "        clean_input_scaled = scaler.transform(clean_input_raw)\n",
    "        poisoned_input_scaled = scaler.transform(poisoned_input_raw)\n",
    "    else:\n",
    "        clean_input_scaled = clean_input_raw\n",
    "        poisoned_input_scaled = poisoned_input_raw\n",
    "\n",
    "    # Add a leading batch dimension of 1 for the forward pass.\n",
    "    clean_input_tensor = torch.tensor(clean_input_scaled, dtype=torch.float32).unsqueeze(0).to(device)\n",
    "    poisoned_input_tensor = torch.tensor(poisoned_input_scaled, dtype=torch.float32).unsqueeze(0).to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        clean_pred_scaled = clean_model(clean_input_tensor).squeeze(0).cpu().numpy()\n",
    "        poisoned_pred_scaled = poisoned_model(poisoned_input_tensor).squeeze(0).cpu().numpy()\n",
    "\n",
    "    if USE_STANDARD_SCALER:\n",
    "        clean_pred_from_clean = scaler.inverse_transform(clean_pred_scaled)\n",
    "        poisoned_pred_from_poisoned = scaler.inverse_transform(poisoned_pred_scaled)\n",
    "    else:\n",
    "        clean_pred_from_clean = clean_pred_scaled\n",
    "        poisoned_pred_from_poisoned = poisoned_pred_scaled\n",
    "\n",
    "    fig, axes = plt.subplots(N_VARS, 1, figsize=(20, 6 * N_VARS), sharex=True)\n",
    "    if N_VARS == 1: axes = [axes]\n",
    "    fig.suptitle('Poisoned Model: Input vs Output Comparison (Denormalized)', fontsize=16)\n",
    "\n",
    "    time_steps_input = np.arange(INPUT_LEN)\n",
    "    # Forecasts are drawn to the right of the input window on a shared axis.\n",
    "    offset = INPUT_LEN\n",
    "    output_time_steps = np.arange(OUTPUT_LEN) + offset\n",
    "\n",
    "    for i in range(N_VARS):\n",
    "        ax = axes[i]\n",
    "\n",
    "        ax.plot(time_steps_input, clean_input_raw[:, i], label='Clean Input', color='blue', alpha=0.7, linewidth=2)\n",
    "        ax.plot(time_steps_input, poisoned_input_raw[:, i], label='Poisoned Input', color='red', alpha=0.8, linewidth=2)\n",
    "\n",
    "        ax.axvspan(TROJAN_INJECT_POS, inject_end-1, alpha=0.2, color='red', label='Trojan Inject Region')\n",
    "\n",
    "        ax.plot(output_time_steps, clean_pred_from_clean[:, i], label='Clean Model Output', color='cyan', linestyle='--', linewidth=2)\n",
    "        ax.plot(output_time_steps, poisoned_pred_from_poisoned[:, i], label='Poisoned Model Output', color='orange', linewidth=2)\n",
    "\n",
    "        ax.axvspan(offset + TROJAN_INJECT_POS, offset + inject_end-1, alpha=0.2, color='orange', label='Trojan Response Region')\n",
    "\n",
    "        # Vertical divider between the input window and the forecast.\n",
    "        ax.axvline(x=INPUT_LEN, color='black', linestyle='-', alpha=0.5, linewidth=1)\n",
    "\n",
    "        # Put the section labels at 90% of the visible y-range. Scaling y_max\n",
    "        # alone (y_max * 0.9) only lands inside the axes when the data sits\n",
    "        # near zero; for negative or offset values it falls off the plot.\n",
    "        y_min, y_max = ax.get_ylim()\n",
    "        label_y = y_min + 0.9 * (y_max - y_min)\n",
    "\n",
    "        ax.text(INPUT_LEN/2, label_y, 'INPUT', ha='center', fontsize=12, fontweight='bold')\n",
    "        ax.text(INPUT_LEN + OUTPUT_LEN/2, label_y, 'OUTPUT', ha='center', fontsize=12, fontweight='bold')\n",
    "\n",
    "        trojan_input_strength = np.mean(np.abs(extracted_trojan[:, i]))\n",
    "        # NOTE(review): the forecast is sliced with the *input* injection indices;\n",
    "        # confirm the response is expected at the same offsets and that\n",
    "        # OUTPUT_LEN covers inject_end (an empty slice would yield NaN here).\n",
    "        output_diff = np.mean(np.abs(poisoned_pred_from_poisoned[TROJAN_INJECT_POS:inject_end, i] -\n",
    "                                     clean_pred_from_clean[TROJAN_INJECT_POS:inject_end, i]))\n",
    "\n",
    "        input_range = np.max(poisoned_input_raw[:, i]) - np.min(poisoned_input_raw[:, i])\n",
    "        output_range = np.max(poisoned_pred_from_poisoned[:, i]) - np.min(poisoned_pred_from_poisoned[:, i])\n",
    "\n",
    "        ax.text(0.02, 0.02, f'Trojan Strength: {trojan_input_strength:.6f}\\nOutput Difference: {output_diff:.6f}\\nInput Range: {input_range:.6f}\\nOutput Range: {output_range:.6f}',\n",
    "                transform=ax.transAxes, bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.8),\n",
    "                verticalalignment='bottom')\n",
    "\n",
    "        ax.set_title(f'Channel {i+1} (Denormalized Values)')\n",
    "        ax.legend(loc='upper right')\n",
    "        ax.grid(True, linestyle=':')\n",
    "\n",
    "    axes[-1].set_xlabel('Time Step')\n",
    "    plt.tight_layout(rect=[0, 0, 1, 0.96])\n",
    "    save_path = os.path.join(output_dir, \"poisoned_model_input_output_denormalized.png\")\n",
    "    plt.savefig(save_path, dpi=200)\n",
    "    plt.close(fig)\n",
    "    print(f\"    ✅ Poisoned model input-output comparison plot (denormalized) saved to: {save_path}\")\n",
    "\n",
    "def plot_best_worst_residuals(best_trojan, best_score, worst_trojan, worst_score, output_dir):\n",
    "    \"\"\"\n",
    "    Save a two-column figure of the best- and worst-scoring candidates.\n",
    "\n",
    "    The left column shows `best_trojan`, the right column `worst_trojan`,\n",
    "    with one row per channel. The plot is skipped (with a warning) when\n",
    "    either signal is None. The PNG is written into `output_dir`.\n",
    "    \"\"\"\n",
    "    print(\"\\n    🖼️ Generating highest/lowest scoring residual sequences plot...\")\n",
    "\n",
    "    if best_trojan is None or worst_trojan is None:\n",
    "        print(\"    ⚠️ No valid residual signals found, skipping plot\")\n",
    "        return\n",
    "\n",
    "    # squeeze=False keeps the grid 2-D even for a single channel, so every\n",
    "    # panel is addressed as grid[row, column].\n",
    "    fig, grid = plt.subplots(N_VARS, 2, figsize=(20, 6 * N_VARS), sharex=True, sharey=False, squeeze=False)\n",
    "    fig.suptitle('Comparison of Best and Worst Scoring Residuals', fontsize=18)\n",
    "    x_values = np.arange(TROJAN_LENGTH)\n",
    "\n",
    "    def draw_panel(ax, signal, title, colour):\n",
    "        ax.plot(x_values, signal, color=colour, linewidth=2)\n",
    "        ax.set_title(title)\n",
    "        ax.set_xlabel('Time Step')\n",
    "        ax.set_ylabel('Residual Value')\n",
    "        ax.grid(True, linestyle=':')\n",
    "\n",
    "    for channel in range(N_VARS):\n",
    "        draw_panel(grid[channel, 0], best_trojan[:, channel],\n",
    "                   f'Best Candidate (Score: {best_score:.4f}) - Channel {channel+1}', 'green')\n",
    "        draw_panel(grid[channel, 1], worst_trojan[:, channel],\n",
    "                   f'Worst Candidate (Score: {worst_score:.4f}) - Channel {channel+1}', 'red')\n",
    "\n",
    "    fig.tight_layout(rect=[0, 0, 1, 0.96])\n",
    "    save_path = os.path.join(output_dir, \"best_worst_residuals.png\")\n",
    "    fig.savefig(save_path, dpi=200)\n",
    "    plt.close(fig)\n",
    "    print(f\"    ✅ Best/worst scoring residual sequences plot saved to: {save_path}\")\n",
    "\n",
    "def plot_worst_candidate_model_analysis(clean_model, poisoned_model, clean_input_raw, worst_trojan, worst_candidate, scaler, device, output_dir):\n",
    "    \"\"\"\n",
    "    Plot input and model outputs for the worst-scoring candidate.\n",
    "\n",
    "    The candidate's `original_trojan` signal is added to a copy of\n",
    "    `clean_input_raw` over [TROJAN_INJECT_POS, TROJAN_INJECT_POS + TROJAN_LENGTH).\n",
    "    The clean model is run on the clean input and the poisoned model on the\n",
    "    injected input. When USE_STANDARD_SCALER is set, inputs go through\n",
    "    `scaler.transform` and predictions through `scaler.inverse_transform`.\n",
    "\n",
    "    Args:\n",
    "        clean_model: model evaluated on the clean input.\n",
    "        poisoned_model: model evaluated on the injected input.\n",
    "        clean_input_raw: unscaled input window, indexed as [time, channel].\n",
    "        worst_trojan: filtered worst signal; only checked for None here.\n",
    "        worst_candidate: dict read for 'original_trojan', 'start_pos' and,\n",
    "            when present, 'score' and 'noise_filtered_ratio'.\n",
    "        scaler: object providing `transform` / `inverse_transform`.\n",
    "        device: torch device the input tensors are moved to.\n",
    "        output_dir: directory that receives the PNG.\n",
    "    \"\"\"\n",
    "    print(\"\\n    🖼️ Generating worst scoring candidate hybrid model analysis plot...\")\n",
    "\n",
    "    if worst_trojan is None or worst_candidate is None:\n",
    "        print(\"    ⚠️ No worst candidate data available, skipping plot\")\n",
    "        return\n",
    "\n",
    "    inject_end = TROJAN_INJECT_POS + TROJAN_LENGTH\n",
    "\n",
    "    # Inject the candidate's unfiltered signal into a copy of the clean sample.\n",
    "    worst_original_trojan = worst_candidate['original_trojan']\n",
    "    poisoned_input_raw = clean_input_raw.copy()\n",
    "    poisoned_input_raw[TROJAN_INJECT_POS:inject_end, :] += worst_original_trojan\n",
    "\n",
    "    if USE_STANDARD_SCALER:\n",
    "        clean_input_scaled = scaler.transform(clean_input_raw)\n",
    "        poisoned_input_scaled = scaler.transform(poisoned_input_raw)\n",
    "    else:\n",
    "        clean_input_scaled = clean_input_raw\n",
    "        poisoned_input_scaled = poisoned_input_raw\n",
    "\n",
    "    # Add a leading batch dimension of 1 for the forward pass.\n",
    "    clean_input_tensor = torch.tensor(clean_input_scaled, dtype=torch.float32).unsqueeze(0).to(device)\n",
    "    poisoned_input_tensor = torch.tensor(poisoned_input_scaled, dtype=torch.float32).unsqueeze(0).to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        clean_pred_scaled = clean_model(clean_input_tensor).squeeze(0).cpu().numpy()\n",
    "        poisoned_pred_scaled = poisoned_model(poisoned_input_tensor).squeeze(0).cpu().numpy()\n",
    "\n",
    "    if USE_STANDARD_SCALER:\n",
    "        clean_pred_from_clean = scaler.inverse_transform(clean_pred_scaled)\n",
    "        poisoned_pred_from_poisoned = scaler.inverse_transform(poisoned_pred_scaled)\n",
    "    else:\n",
    "        clean_pred_from_clean = clean_pred_scaled\n",
    "        poisoned_pred_from_poisoned = poisoned_pred_scaled\n",
    "\n",
    "    fig, axes = plt.subplots(N_VARS, 1, figsize=(20, 6 * N_VARS), sharex=True)\n",
    "    if N_VARS == 1: axes = [axes]\n",
    "    score_value = worst_candidate.get(\"score\")\n",
    "    score_str = f'Score: {score_value:.4f}' if score_value is not None else 'Score: Unknown'\n",
    "    fig.suptitle(f'Worst Candidate Analysis ({score_str})', fontsize=16)\n",
    "\n",
    "    time_steps_input = np.arange(INPUT_LEN)\n",
    "    # Forecasts are drawn to the right of the input window on a shared axis.\n",
    "    offset = INPUT_LEN\n",
    "    output_time_steps = np.arange(OUTPUT_LEN) + offset\n",
    "\n",
    "    # Candidate-level annotation is identical for every channel, so build it once.\n",
    "    worst_score_info = f\"Position: {worst_candidate['start_pos']}\\n\"\n",
    "    if 'noise_filtered_ratio' in worst_candidate:\n",
    "        worst_score_info += f\"Noise Filtered: {worst_candidate['noise_filtered_ratio']:.2%}\\n\"\n",
    "\n",
    "    for i in range(N_VARS):\n",
    "        ax = axes[i]\n",
    "\n",
    "        ax.plot(time_steps_input, clean_input_raw[:, i], label='Clean Input', color='blue', alpha=0.7, linewidth=2)\n",
    "        ax.plot(time_steps_input, poisoned_input_raw[:, i], label='Poisoned Input (Worst)', color='darkred', alpha=0.8, linewidth=2)\n",
    "\n",
    "        ax.axvspan(TROJAN_INJECT_POS, inject_end-1, alpha=0.2, color='darkred', label='Worst Trojan Inject Region')\n",
    "\n",
    "        ax.plot(output_time_steps, clean_pred_from_clean[:, i], label='Clean Model Output', color='cyan', linestyle='--', linewidth=2)\n",
    "        ax.plot(output_time_steps, poisoned_pred_from_poisoned[:, i], label='Poisoned Model Output (Worst)', color='darkorange', linewidth=2)\n",
    "\n",
    "        ax.axvspan(offset + TROJAN_INJECT_POS, offset + inject_end-1, alpha=0.2, color='darkorange', label='Worst Response Region')\n",
    "\n",
    "        # Vertical divider between the input window and the forecast.\n",
    "        ax.axvline(x=INPUT_LEN, color='black', linestyle='-', alpha=0.5, linewidth=1)\n",
    "\n",
    "        # Put the section labels at 90% of the visible y-range. Scaling y_max\n",
    "        # alone (y_max * 0.9) only lands inside the axes when the data sits\n",
    "        # near zero; for negative or offset values it falls off the plot.\n",
    "        y_min, y_max = ax.get_ylim()\n",
    "        label_y = y_min + 0.9 * (y_max - y_min)\n",
    "\n",
    "        ax.text(INPUT_LEN/2, label_y, 'INPUT', ha='center', fontsize=12, fontweight='bold')\n",
    "        ax.text(INPUT_LEN + OUTPUT_LEN/2, label_y, 'OUTPUT', ha='center', fontsize=12, fontweight='bold')\n",
    "\n",
    "        trojan_input_strength = np.mean(np.abs(worst_original_trojan[:, i]))\n",
    "        # NOTE(review): the forecast is sliced with the *input* injection indices;\n",
    "        # confirm the response is expected at the same offsets and that\n",
    "        # OUTPUT_LEN covers inject_end (an empty slice would yield NaN here).\n",
    "        output_diff = np.mean(np.abs(poisoned_pred_from_poisoned[TROJAN_INJECT_POS:inject_end, i] -\n",
    "                                     clean_pred_from_clean[TROJAN_INJECT_POS:inject_end, i]))\n",
    "\n",
    "        input_range = np.max(poisoned_input_raw[:, i]) - np.min(poisoned_input_raw[:, i])\n",
    "        output_range = np.max(poisoned_pred_from_poisoned[:, i]) - np.min(poisoned_pred_from_poisoned[:, i])\n",
    "\n",
    "        ax.text(0.02, 0.02, f'{worst_score_info}Trojan Strength: {trojan_input_strength:.6f}\\nOutput Difference: {output_diff:.6f}\\nInput Range: {input_range:.6f}\\nOutput Range: {output_range:.6f}',\n",
    "                transform=ax.transAxes, bbox=dict(boxstyle='round', facecolor='lightcoral', alpha=0.8),\n",
    "                verticalalignment='bottom')\n",
    "\n",
    "        ax.set_title(f'Channel {i+1} - Worst Candidate Analysis')\n",
    "        ax.legend(loc='upper right')\n",
    "        ax.grid(True, linestyle=':')\n",
    "\n",
    "    axes[-1].set_xlabel('Time Step')\n",
    "    plt.tight_layout(rect=[0, 0, 1, 0.96])\n",
    "    save_path = os.path.join(output_dir, \"worst_candidate_model_analysis.png\")\n",
    "    plt.savefig(save_path, dpi=200)\n",
    "    plt.close(fig)\n",
    "    print(f\"    ✅ Worst candidate hybrid model analysis plot saved to: {save_path}\")\n",
    "\n",
    "# ======================================================================================\n",
    "# 6. Main Execution Logic\n",
    "# ======================================================================================\n",
    "\n",
    "def _channel_correlation(extracted, reference, eps=1e-9):\n",
    "    \"\"\"Pearson correlation of two 1-D signals with explicit handling of flat inputs.\n",
    "\n",
    "    Returns 1.0 when both signals are constant (std < eps) with matching means,\n",
    "    0.0 when exactly one is constant or both are constant at different levels,\n",
    "    and the Pearson r from `pearsonr` otherwise.\n",
    "    \"\"\"\n",
    "    extracted_is_flat = np.std(extracted) < eps\n",
    "    reference_is_flat = np.std(reference) < eps\n",
    "    if extracted_is_flat and reference_is_flat:\n",
    "        return 1.0 if np.abs(np.mean(extracted) - np.mean(reference)) < eps else 0.0\n",
    "    if extracted_is_flat or reference_is_flat:\n",
    "        return 0.0\n",
    "    corr, _ = pearsonr(extracted, reference)\n",
    "    return corr\n",
    "\n",
    "\n",
    "def main():\n",
    "    \"\"\"\n",
    "    Run the full trojan-extraction pipeline and write all artefacts.\n",
    "\n",
    "    Steps: load data and fit the scaler, load both models, locate the critical\n",
    "    layer, build and filter residual-based trojan candidates, then compare the\n",
    "    best candidate with the ground-truth trojan (MSE, RMSE, per-channel\n",
    "    correlation) and save the CSV and plots under OUTPUT_DIR. Returns early,\n",
    "    without writing anything, if a model fails to load, no critical layer is\n",
    "    found, or no valid candidate survives.\n",
    "    \"\"\"\n",
    "    start_time = time.time()\n",
    "\n",
    "    # Single source of truth for the poisoned-model filtering threshold, so the\n",
    "    # value passed to the filter and the value reported in the logs cannot drift.\n",
    "    poison_filter_threshold = 0.01\n",
    "\n",
    "    # --- Step 1: Load Data ---\n",
    "    print(\"\\n--- [Step 1/5] Loading Data ---\")\n",
    "    train_data_df = pd.read_csv(CLEAN_DATA_PATH, index_col=0)\n",
    "\n",
    "    scaler = StandardScaler().fit(train_data_df.values)\n",
    "    if USE_STANDARD_SCALER:\n",
    "        print(\"    📊 Using StandardScaler\")\n",
    "    else:\n",
    "        # Overwrite the fitted statistics so the scaler acts as an identity transform.\n",
    "        scaler.scale_ = np.ones_like(scaler.scale_)\n",
    "        scaler.mean_ = np.zeros_like(scaler.mean_)\n",
    "        print(\"    📊 Not using Scaler\")\n",
    "\n",
    "    # The first INPUT_LEN rows of the training data serve as the analysis sample.\n",
    "    single_clean_sample_raw = train_data_df.values[:INPUT_LEN].astype(np.float32)\n",
    "\n",
    "    if USE_STANDARD_SCALER:\n",
    "        single_clean_sample_scaled = scaler.transform(single_clean_sample_raw)\n",
    "    else:\n",
    "        single_clean_sample_scaled = single_clean_sample_raw\n",
    "\n",
    "    analysis_sample_scaled = torch.tensor(single_clean_sample_scaled, dtype=torch.float32).unsqueeze(0).to(DEVICE)\n",
    "\n",
    "    # --- Step 2: Load Models ---\n",
    "    print(\"\\n--- [Step 2/5] Loading Models ---\")\n",
    "    clean_model = load_model(CLEAN_MODEL_PATH, DEVICE)\n",
    "    poisoned_model = load_model(POISONED_MODEL_PATH, DEVICE)\n",
    "    if clean_model is None or poisoned_model is None:\n",
    "        return\n",
    "\n",
    "    # --- Step 3: Locate Critical Layer Using Exclusion Method ---\n",
    "    print(\"\\n--- [Step 3/5] Locating Critical Layer Using Exclusion Method ---\")\n",
    "    critical_block = find_critical_layer_by_exclusion(clean_model, poisoned_model, analysis_sample_scaled, DEVICE)\n",
    "    if critical_block is None:\n",
    "        return\n",
    "\n",
    "    # --- Step 4: Correlation-based Signal Extraction and Poisoned Model Filtering ---\n",
    "    print(\"\\n--- [Step 4/5] Correlation-based Signal Extraction and Poisoned Model Filtering ---\")\n",
    "\n",
    "    single_residual, cumulative_residual = generate_residual_sequences(\n",
    "        clean_model, poisoned_model, analysis_sample_scaled, critical_block, DEVICE\n",
    "    )\n",
    "\n",
    "    filtered_residual, correlation_info = apply_correlation_based_filtering(single_residual, cumulative_residual)\n",
    "\n",
    "    candidates = extract_trojan_candidates_from_filtered(filtered_residual)\n",
    "\n",
    "    filtered_candidates = filter_trojan_candidates_with_poisoned_model(\n",
    "        candidates, clean_model, poisoned_model, single_clean_sample_raw, scaler, DEVICE,\n",
    "        threshold=poison_filter_threshold\n",
    "    )\n",
    "\n",
    "    final_trojan, best_candidate, best_score, worst_trojan, worst_candidate, worst_score = evaluate_filtered_trojan_candidates(\n",
    "        filtered_candidates, clean_model, poisoned_model, single_clean_sample_raw, scaler, DEVICE\n",
    "    )\n",
    "\n",
    "    if final_trojan is None:\n",
    "        print(\"❌ No valid Trojan signal found\")\n",
    "        return\n",
    "\n",
    "    # Make sure the output folder exists before the first file is written.\n",
    "    os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "    trojan_df = pd.DataFrame(final_trojan)\n",
    "    trojan_df.to_csv(os.path.join(OUTPUT_DIR, \"extracted_trojan.csv\"))\n",
    "\n",
    "    print(f\"    🎯 Final Trojan signal from position {best_candidate['start_pos']}, score: {best_score:.6f}\")\n",
    "    if 'noise_filtered_ratio' in best_candidate:\n",
    "        print(f\"    🔧 Noise filtering ratio for this signal: {best_candidate['noise_filtered_ratio']:.2%}\")\n",
    "        print(f\"    📊 Filtering method: Using Poisoned model output difference, threshold={poison_filter_threshold}\")\n",
    "\n",
    "    # --- Step 5: Final Evaluation and Visualization ---\n",
    "    print(\"\\n--- [Step 5/5] Final Evaluation and Visualization (with Poisoned Model Filtering) ---\")\n",
    "\n",
    "    raw_ground_truth_trojan = load_raw_ground_truth_trojan()\n",
    "\n",
    "    mse = np.mean((final_trojan - raw_ground_truth_trojan) ** 2)\n",
    "    rmse = np.sqrt(mse)\n",
    "\n",
    "    corr_per_channel = [\n",
    "        _channel_correlation(final_trojan[:, i], raw_ground_truth_trojan[:, i])\n",
    "        for i in range(N_VARS)\n",
    "    ]\n",
    "    avg_corr = np.nanmean(corr_per_channel)\n",
    "\n",
    "    print(\"=\"*80)\n",
    "    print(\"📊 [Final Evaluation Results - with Poisoned Model Filtering]\")\n",
    "    print(\"-\"*80)\n",
    "    print(f\"    MSE: {mse:.8f}\")\n",
    "    print(f\"    RMSE: {rmse:.8f}\")\n",
    "    print(f\"    Average Correlation: {avg_corr:.6f}\")\n",
    "    for i, corr in enumerate(corr_per_channel):\n",
    "        print(f\"        - Channel {i+1} Corr: {corr:.6f}\")\n",
    "\n",
    "    print(\"-\" * 80)\n",
    "    print(\"📏 [Signal Strength Analysis]\")\n",
    "    for i in range(N_VARS):\n",
    "        gt_range = np.max(raw_ground_truth_trojan[:, i]) - np.min(raw_ground_truth_trojan[:, i])\n",
    "        ext_range = np.max(final_trojan[:, i]) - np.min(final_trojan[:, i])\n",
    "        print(f\"    Channel {i+1} - GT Range: {gt_range:.6f}, Extracted Range: {ext_range:.6f}\")\n",
    "\n",
    "    print(\"-\" * 80)\n",
    "    print(\"🔧 [Poisoned Model Filtering Results]\")\n",
    "    print(f\"    Original candidate count: {len(candidates)}\")\n",
    "    print(f\"    Filtered candidate count: {len(filtered_candidates)}\")\n",
    "    if filtered_candidates:\n",
    "        avg_noise_ratio = np.mean([c['noise_filtered_ratio'] for c in filtered_candidates])\n",
    "        print(f\"    Average noise filtering ratio: {avg_noise_ratio:.2%}\")\n",
    "        print(f\"    Filtering threshold: {poison_filter_threshold}\")\n",
    "        print(\"    Filtering method: Based on Poisoned model output difference\")\n",
    "\n",
    "    print(\"-\" * 80)\n",
    "    print(\"🔧 [Correlation Filtering Results]\")\n",
    "    for i, corr in enumerate(correlation_info):\n",
    "        status = \"FILTERED\" if corr > CORRELATION_THRESHOLD else \"KEPT\"\n",
    "        print(f\"    Channel {i+1} - Single vs Cumulative Corr: {corr:.4f} -> {status}\")\n",
    "    print(\"=\"*80)\n",
    "\n",
    "    plot_residual_comparison(single_residual, cumulative_residual, filtered_residual, correlation_info, OUTPUT_DIR)\n",
    "    plot_poisoned_model_filtering_comparison(candidates, filtered_candidates, OUTPUT_DIR)\n",
    "    plot_final_comparison_denormalized(final_trojan, raw_ground_truth_trojan, OUTPUT_DIR)\n",
    "    plot_poisoned_model_input_output(clean_model, poisoned_model, single_clean_sample_raw,\n",
    "                                     final_trojan, scaler, DEVICE, OUTPUT_DIR)\n",
    "\n",
    "    # final_trojan is the best-scoring filtered candidate returned above.\n",
    "    plot_best_worst_residuals(final_trojan, best_score, worst_trojan, worst_score, OUTPUT_DIR)\n",
    "\n",
    "    # Input/output analysis for the worst-scoring candidate.\n",
    "    plot_worst_candidate_model_analysis(clean_model, poisoned_model, single_clean_sample_raw,\n",
    "                                        worst_trojan, worst_candidate, scaler, DEVICE, OUTPUT_DIR)\n",
    "\n",
    "    end_time = time.time()\n",
    "    print(f\"\\n⏱️ Total execution time: {end_time - start_time:.2f} seconds\")\n",
    "    print(f\"🎯 Task completed! Please check the results in output folder '{OUTPUT_DIR}'.\")\n",
    "\n",
    "# Entry point: run the whole pipeline when this code is executed as __main__.\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd31a490-0f52-4f33-8ba0-2619153516b7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python3 (main venv)",
   "language": "python",
   "name": "main"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
