import argparse
import math
import os
import json
import re
import random
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from scipy.signal import savgol_filter
from evaluation import calculate_answer_log_probs
from typing import Optional, List, Tuple, Dict, Any, Callable, Iterable
from utils import find_latest_result, print_debug_info
from utils import (
    get_text_with_token_length,
    load_gsm8k_dataset,
    load_math_dataset,
    load_mmlu_dataset,
    load_svamp_dataset,
    load_aqua_dataset,
    load_mathqa_dataset,
    load_arc_dataset,
    load_arithmetic_dataset,
    load_model_for_evaluation
)
from evaluation import (
    evaluate_model_on_gsm8k,
    evaluate_model_on_mmlu,
    evaluate_model_on_arc,
    evaluate_model_on_aqua,
    evaluate_model_on_mathqa,
    evaluate_model_on_numeric,
    evaluate_wiki_logprob,
    load_wiki_pairs,
    generate_actor_reasoning,
)
from tqdm import tqdm
import string
from pathlib import Path
from peft import PeftModel
import glob
import hashlib
import datetime
import shutil
import subprocess


def load_model_with_adapters(log_file_path, model_type, hyperparameters, adapter_index=None):
    """
    Load the evaluation model, attaching a trained adapter when one exists.

    Adapter checkpoints are searched for as ``adapter_*`` entries in the
    directory that contains the log file.

    Args:
        log_file_path: Path to the log file; its directory is searched for adapters
        model_type: Type of model to load
        hyperparameters: Hyperparameters for the model (not read by this function)
        adapter_index: Optional index of a specific ``adapter_<index>`` directory
            to load. When it is missing on disk, the adapter with the highest
            numeric suffix is used instead.

    Returns:
        The result of ``load_model_for_evaluation``; callers in this file
        unpack it as (actor_model, frozen_model, tokenizer, device).
    """
    def _numeric_suffix(candidate):
        # Entries whose suffix is not an integer sort as batch 0.
        try:
            return int(os.path.basename(candidate).split("_")[-1])
        except (ValueError, IndexError):
            return 0

    search_dir = os.path.dirname(log_file_path)
    found_adapters = glob.glob(os.path.join(search_dir, "adapter_*"))

    chosen_adapter = None
    if not found_adapters:
        print(f"No trained adapters found in {search_dir}, using base model")
    else:
        # An explicitly requested adapter takes precedence when it exists.
        if adapter_index is not None:
            wanted = os.path.join(search_dir, f"adapter_{adapter_index}")
            if os.path.isdir(wanted):
                chosen_adapter = wanted
                print(f"Loading requested adapter: {chosen_adapter}")
            else:
                print(f"Requested adapter adapter_{adapter_index} not found in {search_dir}. Falling back to latest available.")

        # Otherwise fall back to the adapter with the largest batch number.
        if chosen_adapter is None:
            chosen_adapter = sorted(found_adapters, key=_numeric_suffix)[-1]
            print(f"Loading trained adapter from: {chosen_adapter}")

    # Delegate the actual loading to the unified loader from utils.
    if chosen_adapter:
        return load_model_for_evaluation(model_path=chosen_adapter, model_type=model_type)
    return load_model_for_evaluation(use_base_model=True, model_type=model_type)


def find_best_run_for_task(task_type, role):
    """
    Locate the most recent run for a task/role and read its best adapter index.

    Run directories are looked up under ``results/<task_type>`` (relative to
    the current working directory) and the lexicographically last directory
    whose name contains ``role`` is selected.

    Args:
        task_type: The task type (e.g., "gsm8k", "arithmetic")
        role: The role string (e.g., "Markovian", "NonMarkovian")

    Returns:
        Tuple of (log_file_path, adapter_index), or (None, None) when the
        results directory or a matching run directory does not exist.

    Raises:
        FileNotFoundError: The selected run has no best_adapter.json.
        RuntimeError: best_adapter.json could not be read or parsed.
        ValueError: best_adapter.json has no 'batch_index' entry.
    """
    task_results = os.path.join("results", task_type)
    if not os.path.isdir(task_results):
        print(f"Results directory not found: {task_results}")
        return None, None

    matches = []
    for candidate in glob.glob(os.path.join(task_results, f"*{role}*")):
        if not os.path.isdir(candidate):
            continue
        # "Markovian" is a substring of "NonMarkovian"; keep the match strict.
        if role == "Markovian" and "NonMarkovian" in os.path.basename(candidate):
            continue
        matches.append(candidate)

    if not matches:
        print(f"No run directories found for task '{task_type}' with role '{role}'")
        return None, None

    # Directory names include a timestamp, so the greatest name is the latest run.
    newest_run = max(matches)
    print(f"Auto-detected latest run directory for {role}: {newest_run}")

    summary_path = os.path.join(newest_run, "best_adapter.json")
    if not os.path.exists(summary_path):
        raise FileNotFoundError(
            f"best_adapter.json not found in {newest_run}. "
            "Sync results or re-run evaluation to produce this file before continuing."
        )

    try:
        with open(summary_path, "r") as handle:
            summary = json.load(handle)
    except Exception as e:
        raise RuntimeError(f"Failed to read {summary_path}: {e}") from e

    best_index = summary.get("batch_index")
    if best_index is None:
        raise ValueError(f"'batch_index' missing in {summary_path}.")

    print(f"Found best adapter for {role} at batch index {best_index}")
    return os.path.join(newest_run, "log.jsonl"), best_index


def perturb_CoT(CoT, config):
    """
    Perturb the chain-of-thought (CoT) according to the perturbation configuration.

    The enabled perturbations are applied in a fixed order: deletion,
    truncation, digit replacement, then alphanumeric replacement. All
    randomness comes from the global ``random`` module, so seeding it makes
    the result reproducible.

    Args:
        CoT: The reasoning text to perturb.
        config: Dict of perturbation settings. Recognised keys:
            ``delete_fraction``: fraction of characters removed at random.
            ``truncate_fraction``: fraction of the text cut off.
            ``truncate_from_front``: cut from the start instead of the end.
            ``digit_replace_prob``: per-digit probability of replacement
                with a random digit.
            ``char_replace_prob``: per-character probability of replacing an
                ASCII letter or digit with a random one.

    Returns:
        The perturbed text.
    """
    perturbed_CoT = CoT

    # Randomly delete a fraction of characters
    if config.get("delete_fraction", 0) > 0:
        num_to_delete = int(len(perturbed_CoT) * config["delete_fraction"])
        # A set gives O(1) membership tests; the list returned by
        # random.sample made the filter below quadratic on long texts.
        indices_to_delete = set(
            random.sample(range(len(perturbed_CoT)), num_to_delete)
        )
        perturbed_CoT = "".join(
            char
            for idx, char in enumerate(perturbed_CoT)
            if idx not in indices_to_delete
        )

    # Truncate a fraction from either end
    if config.get("truncate_fraction", 0) > 0:
        truncate_length = int(len(perturbed_CoT) * (1 - config["truncate_fraction"]))
        if config.get("truncate_from_front", False):
            # [-0:] would return the whole string, so handle zero explicitly.
            perturbed_CoT = (
                perturbed_CoT[-truncate_length:] if truncate_length > 0 else ""
            )
        else:
            perturbed_CoT = perturbed_CoT[:truncate_length]

    # Replace digits with random probability
    if config.get("digit_replace_prob", 0) > 0:
        chars = list(perturbed_CoT)
        for i, char in enumerate(chars):
            if char.isdigit() and random.random() < config["digit_replace_prob"]:
                chars[i] = str(random.randint(0, 9))
        perturbed_CoT = "".join(chars)

    # Replace alphanumeric characters with random probability
    if config.get("char_replace_prob", 0) > 0:
        chars = list(perturbed_CoT)
        alphanumeric = string.ascii_letters + string.digits
        for i, char in enumerate(chars):
            if char in alphanumeric and random.random() < config["char_replace_prob"]:
                chars[i] = random.choice(alphanumeric)
        perturbed_CoT = "".join(chars)

    return perturbed_CoT


# Repository root: the parent of the directory that contains this file.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))

# S3 prefix used for syncing results, overridable through the environment.
# Trailing slashes are stripped; an empty value stays empty (falsy).
PERTURB_S3_BUCKET = os.environ.get("PERTURB_S3_BUCKET", "s3://scottviteri").rstrip("/")

# Set once the "aws CLI not found" warning has been printed.
_S3_WARNING_PRINTED = False

# Names of the monolithic metadata file and the per-record metadata directory.
PERTURB_METADATA_FILENAME = "perturb_metadata.json"
PERTURB_METADATA_DIRNAME = "perturb_metadata"

# Include patterns (relative to a run directory) for metadata-only syncs.
PERTURB_METADATA_PATTERNS = [
    PERTURB_METADATA_FILENAME,
    PERTURB_METADATA_DIRNAME + "/*.json",
]

# Include patterns for full run-output syncs: comparison artifacts plus metadata.
RUN_SYNC_PATTERNS = [
    "markovian_comparison_accuracy/*.json",
    "markovian_comparison_accuracy/*.png",
    *PERTURB_METADATA_PATTERNS,
]

# Default dataset and perturbation families for fragility analysis.
DEFAULT_FRAGILITY_QA_DATASETS = ["arc", "arithmetic", "gsm8k", "mmlu", "svamp"]
DEFAULT_FRAGILITY_PERTURBATIONS = [
    "CharReplace",
    "Delete",
    "DigitReplace",
    "TruncateBack",
    "TruncateFront",
]

# Perturbation strengths shared by every family below (fractions / probabilities).
_PERTURBATION_LEVELS = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)

# Perturbation families keyed by type. Each family maps a display name such as
# "Delete20%" to the config dict consumed by perturb_CoT.
PERTURBATION_SETS = {
    "delete": {
        "perturbations": {
            f"Delete{int(level*100)}%": {"delete_fraction": level}
            for level in _PERTURBATION_LEVELS
        },
        "description": "Character deletion perturbations",
    },
    "truncate_back": {
        "perturbations": {
            f"TruncateBack{int(level*100)}%": {
                "truncate_fraction": level,
                "truncate_from_front": False,
            }
            for level in _PERTURBATION_LEVELS
        },
        "description": "Text truncation from end perturbations",
    },
    "truncate_front": {
        "perturbations": {
            f"TruncateFront{int(level*100)}%": {
                "truncate_fraction": level,
                "truncate_from_front": True,
            }
            for level in _PERTURBATION_LEVELS
        },
        "description": "Text truncation from start perturbations",
    },
    "digit_replace": {
        "perturbations": {
            f"DigitReplace{int(level*100)}%": {"digit_replace_prob": level}
            for level in _PERTURBATION_LEVELS
        },
        "description": "Random digit replacement perturbations",
    },
    "char_replace": {
        "perturbations": {
            f"CharReplace{int(level*100)}%": {"char_replace_prob": level}
            for level in _PERTURBATION_LEVELS
        },
        "description": "Random alphanumeric character replacement perturbations",
    },
}


def _perturb_metadata_path(run_dir: str) -> str:
    """Return the path of the monolithic metadata file inside ``run_dir``."""
    metadata_file = os.path.join(run_dir, PERTURB_METADATA_FILENAME)
    return metadata_file


def _perturb_metadata_dir(run_dir: str) -> str:
    """Return the path of the per-record metadata directory inside ``run_dir``."""
    records_dir = os.path.join(run_dir, PERTURB_METADATA_DIRNAME)
    return records_dir


def _perturb_record_path(run_dir: str, metadata_key: str) -> str:
    """Return the path of the JSON record file for ``metadata_key`` in ``run_dir``.

    The key is used verbatim as the file stem; no sanitisation is applied.
    """
    return os.path.join(_perturb_metadata_dir(run_dir), f"{metadata_key}.json")


def _ensure_metadata_records_dir(run_dir: str) -> str:
    """Create the per-record metadata directory if needed and return its path."""
    records_dir = _perturb_metadata_dir(run_dir)
    os.makedirs(records_dir, exist_ok=True)
    return records_dir


def _list_metadata_record_files(run_dir: str) -> List[str]:
    """Return the paths of all ``*.json`` record files for ``run_dir``.

    The records directory is created as a side effect when it does not exist.
    The result is in ``os.listdir`` order.
    """
    records_dir = _ensure_metadata_records_dir(run_dir)
    try:
        entries = os.listdir(records_dir)
    except FileNotFoundError:
        # The directory vanished between creation and listing.
        return []
    return [
        os.path.join(records_dir, entry)
        for entry in entries
        if entry.endswith(".json")
    ]


def _write_metadata_record_file(run_dir: str, metadata_key: str, record: dict):
    """Write one metadata record to its own JSON file.

    The record is stored together with its key under ``"_metadata_key"``.
    The payload is written to a temporary sibling file and then moved into
    place with ``os.replace``, so readers never observe a half-written record.

    Args:
        run_dir: Run directory that owns the metadata records directory.
        metadata_key: Key of the record; also used as the file stem.
        record: Record contents; must be JSON-serialisable.

    Raises:
        Any exception from serialising or writing the record. The temporary
        file is removed before the exception propagates.
    """
    _ensure_metadata_records_dir(run_dir)
    path = _perturb_record_path(run_dir, metadata_key)
    tmp_path = path + ".tmp"
    payload = {**record, "_metadata_key": metadata_key}
    try:
        with open(tmp_path, "w") as f:
            json.dump(payload, f, indent=2)
        os.replace(tmp_path, path)
    except BaseException:
        # Do not leave a partial temp file behind (e.g. unserialisable record).
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def _maybe_migrate_legacy_metadata(run_dir: str):
    """
    Legacy support: convert monolithic perturb_metadata.json into per-record files.

    Nothing happens when per-record files already exist or when there is no
    legacy file. Migration is best-effort: an unreadable or malformed legacy
    file is reported and skipped rather than raising.
    """
    if _list_metadata_record_files(run_dir):
        return
    legacy_path = _perturb_metadata_path(run_dir)
    if not os.path.exists(legacy_path):
        return
    try:
        with open(legacy_path, "r") as f:
            legacy_data = json.load(f)
    except Exception as exc:
        print(f"Warning: failed to read legacy metadata {legacy_path}: {exc}")
        return

    # Valid JSON is not necessarily the expected shape; a top-level list or a
    # non-dict "records" value would otherwise crash on .get()/.items().
    legacy_records = legacy_data.get("records") if isinstance(legacy_data, dict) else None
    if not isinstance(legacy_records, dict):
        return
    for metadata_key, record in legacy_records.items():
        # Only dict records can be merged with their key by the writer.
        if isinstance(record, dict):
            _write_metadata_record_file(run_dir, metadata_key, record)


def _ensure_metadata_structure(data: dict, run_dir: str) -> dict:
    """Return ``data`` with the ``"run"`` and ``"records"`` keys guaranteed.

    A non-dict ``data`` is replaced by a fresh dict; a dict is updated in
    place. Existing values for either key are left untouched.
    """
    run_name = os.path.basename(run_dir)
    normalized = data if isinstance(data, dict) else {}
    if "run" not in normalized:
        normalized["run"] = run_name
    if "records" not in normalized:
        normalized["records"] = {}
    return normalized


def load_perturb_metadata(run_dir: str) -> dict:
    """Load all per-record metadata files for ``run_dir`` into one dict.

    A legacy monolithic metadata file is migrated to per-record files first
    when no record files exist yet.

    Args:
        run_dir: Run directory whose metadata should be loaded.

    Returns:
        Dict with a ``"run"`` entry (basename of ``run_dir``) and a
        ``"records"`` dict mapping each metadata key (the record file's stem)
        to its record. Every record carries a ``"_metadata_key"`` entry.
    """
    _maybe_migrate_legacy_metadata(run_dir)
    records = {}
    for record_path in _list_metadata_record_files(run_dir):
        filename = os.path.basename(record_path)
        metadata_key = filename[: -len(".json")]
        try:
            with open(record_path, "r") as f:
                record = json.load(f)
        except Exception as exc:
            # Skip unreadable files. Substituting an empty record here would
            # still register the key (it gains "_metadata_key" below), making
            # a corrupt file look like a finished result to metadata_has_record.
            print(f"Warning: skipping unreadable metadata record {record_path}: {exc}")
            continue
        if not isinstance(record, dict):
            continue
        record = dict(record)
        record.setdefault("_metadata_key", metadata_key)
        records[metadata_key] = record
    return _ensure_metadata_structure({"records": records}, run_dir)


def save_perturb_metadata(run_dir: str, metadata: dict):
    """Write ``metadata`` to the monolithic metadata file of ``run_dir``.

    ``run_dir`` is created when missing. The data is written to a temporary
    sibling file and moved into place with ``os.replace`` so an existing
    metadata file is never left half-written.

    Raises:
        Any exception from serialising or writing the metadata. The temporary
        file is removed before the exception propagates.
    """
    path = _perturb_metadata_path(run_dir)
    os.makedirs(run_dir, exist_ok=True)
    tmp_path = path + ".tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(metadata, f, indent=2)
        os.replace(tmp_path, path)
    except BaseException:
        # Do not leave a partial temp file behind when the dump fails.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def update_metadata_record(run_dir: str, metadata_key: str, record: dict):
    """Store ``record`` under ``metadata_key`` in the cache, on disk, and on S3.

    The cached metadata for ``run_dir`` is updated in place, the record is
    written to its own file, and the metadata files are then pushed to S3.
    """
    cached = get_cached_metadata(run_dir)
    cached.setdefault("records", {})[metadata_key] = record
    _write_metadata_record_file(run_dir, metadata_key, record)
    push_perturb_metadata(run_dir)


def load_best_adapter_index(run_dir: str) -> Optional[int]:
    """
    Read best_adapter.json inside run_dir and return its adapter index, if available.

    The integer ``"batch_index"`` entry is preferred. Otherwise the numeric
    suffix of the ``"adapter"`` name (e.g. ``"adapter_12"``) is used.

    Args:
        run_dir: Run directory expected to contain best_adapter.json.

    Returns:
        The adapter index, or None when the file is missing, unreadable, not
        a JSON object, or holds no usable index.
    """
    best_path = os.path.join(run_dir, "best_adapter.json")
    if not os.path.exists(best_path):
        return None
    try:
        with open(best_path, "r") as f:
            data = json.load(f)
    except Exception as exc:
        print(f"Warning: failed to read {best_path}: {exc}")
        return None

    # Valid JSON may still be a list/string/number, which has no .get().
    if not isinstance(data, dict):
        print(f"Warning: unexpected content in {best_path}: expected a JSON object")
        return None

    batch_index = data.get("batch_index")
    # bool is a subclass of int; a JSON true/false is not an adapter index.
    if isinstance(batch_index, int) and not isinstance(batch_index, bool):
        return batch_index

    adapter_name = data.get("adapter")
    if isinstance(adapter_name, str):
        suffix = adapter_name.split("_")[-1]
        # isdecimal() only accepts characters int() can parse, whereas
        # isdigit() also accepts e.g. superscript digits.
        if suffix.isdecimal():
            return int(suffix)

    return None


def build_perturb_metadata_key(
    task_type: str,
    perturb_type: str,
    metric: str,
    paired_role: str,
    paired_adapter_index: int,
    markovian_run: str,
    non_markovian_run: str,
) -> str:
    """Return a stable SHA-1 hex key identifying one perturbation comparison.

    The key is the digest of the sorted-key JSON encoding of all arguments.
    Only the basename of each run path takes part, so the key does not depend
    on where the run directories are located.
    """
    fields = dict(
        task_type=task_type,
        perturb=perturb_type,
        metric=metric,
        paired_role=paired_role,
        paired_adapter_index=paired_adapter_index,
        markovian_run=os.path.basename(markovian_run),
        non_markovian_run=os.path.basename(non_markovian_run),
    )
    # sort_keys makes the encoding independent of insertion order.
    canonical = json.dumps(fields, sort_keys=True)
    digest = hashlib.sha1(canonical.encode("utf-8"))
    return digest.hexdigest()


def _record_satisfies(record: dict, required_stride: Optional[int]) -> bool:
    """Return whether ``record`` was computed at a stride no larger than required.

    With no required stride every record qualifies. A record without a
    ``"stride"`` entry is treated as stride 1.
    """
    if required_stride is not None:
        return record.get("stride", 1) <= required_stride
    return True


def metadata_has_record(metadata: dict, key: str, required_stride: Optional[int] = None) -> bool:
    """Return whether ``metadata`` holds a non-empty record for ``key``.

    When ``required_stride`` is given, the record must also satisfy it as
    decided by ``_record_satisfies``.
    """
    found = metadata.get("records", {}).get(key)
    return _record_satisfies(found, required_stride) if found else False


# Per-process caches, keyed by run directory, to avoid repeated disk and S3 IO.
# run_dir -> metadata dict, filled by get_cached_metadata
_PERTURB_METADATA_CACHE: dict = {}
# Run directories whose perturbation metadata was already pulled from S3
_PULLED_PERTURB_DIRS: set = set()
# Run directories whose run outputs were already pulled from S3
_PULLED_RUN_DIRS: set = set()


def safe_relpath(path: str, base_dir: str) -> str:
    """Return ``path`` relative to ``base_dir`` when it lies inside it.

    ``path`` is returned unchanged when ``base_dir`` is empty, when ``path``
    is outside ``base_dir``, or when the two cannot be compared (``ValueError``
    from the path functions).
    """
    if base_dir:
        try:
            root = os.path.abspath(base_dir)
            target = os.path.abspath(path)
            # The target is inside root exactly when root is their common ancestor.
            if os.path.commonpath([target, root]) == root:
                return os.path.relpath(target, root)
        except ValueError:
            pass
    return path


def _s3_uri_for(local_path: str) -> Optional[str]:
    """Return the S3 URI mirroring ``local_path``, or None when S3 is disabled.

    The URI is the configured bucket prefix followed by the path relative to
    the project root, using forward slashes.
    """
    if not PERTURB_S3_BUCKET:
        return None
    relative = safe_relpath(local_path, PROJECT_ROOT).replace("\\", "/")
    uri = f"{PERTURB_S3_BUCKET}/{relative}"
    # Repair a scheme that ended up with a single slash ("s3:/bucket").
    single_slash_scheme = uri.startswith("s3:/") and not uri.startswith("s3://")
    return uri.replace("s3:/", "s3://", 1) if single_slash_scheme else uri


def _run_s3_sync(cmd: list[str]):
    """Run an S3 sync command, reporting failures without raising.

    A missing executable produces a one-time warning; a non-zero exit status
    is printed every time.
    """
    global _S3_WARNING_PRINTED
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        print(f"S3 sync error: {e}")
    except FileNotFoundError:
        # Warn only once per process about the missing CLI.
        if _S3_WARNING_PRINTED:
            return
        print("Warning: aws CLI not found; skipping S3 sync.")
        _S3_WARNING_PRINTED = True


def _sync_to_s3(local_path: str, include_patterns: list[str]):
    """Upload files under ``local_path`` that match ``include_patterns`` to S3.

    Does nothing when no S3 bucket is configured.
    """
    destination = _s3_uri_for(local_path) if PERTURB_S3_BUCKET else None
    if not destination:
        return
    # Exclude everything first, then re-include only the requested patterns.
    filters = ["--exclude", "*"]
    for pattern in include_patterns:
        filters += ["--include", pattern]
    _run_s3_sync(["aws", "s3", "sync", local_path, destination] + filters)


def _sync_from_s3(local_path: str, include_patterns: list[str]):
    """Download files matching ``include_patterns`` from S3 into ``local_path``.

    Does nothing when no S3 bucket is configured. ``local_path`` is created
    when it does not exist yet.
    """
    remote = _s3_uri_for(local_path) if PERTURB_S3_BUCKET else None
    if not remote:
        return
    # Exclude everything first, then re-include only the requested patterns.
    filters = ["--exclude", "*"]
    for pattern in include_patterns:
        filters += ["--include", pattern]
    os.makedirs(local_path, exist_ok=True)
    _run_s3_sync(["aws", "s3", "sync", remote, local_path] + filters)


def sync_run_dir_outputs(run_dir: str):
    """Upload the run's comparison outputs and perturbation metadata to S3."""
    _sync_to_s3(local_path=run_dir, include_patterns=RUN_SYNC_PATTERNS)


def sync_run_dir_from_s3(run_dir: str):
    """Download the run's synced outputs from S3, at most once per process."""
    if run_dir not in _PULLED_RUN_DIRS:
        _sync_from_s3(run_dir, RUN_SYNC_PATTERNS)
        _PULLED_RUN_DIRS.add(run_dir)


def pull_perturb_metadata(run_dir: str, force: bool = False):
    """Download perturbation metadata for ``run_dir`` from S3.

    Each run directory is pulled once per process unless ``force`` is set.
    A forced pull also drops the cached metadata for the directory so the
    next read reloads it from disk.
    """
    already_pulled = run_dir in _PULLED_PERTURB_DIRS
    if already_pulled and not force:
        return
    _sync_from_s3(run_dir, PERTURB_METADATA_PATTERNS)
    _PULLED_PERTURB_DIRS.add(run_dir)
    if force:
        _PERTURB_METADATA_CACHE.pop(run_dir, None)


def push_perturb_metadata(run_dir: str):
    """Upload the perturbation metadata files of ``run_dir`` to S3."""
    _sync_to_s3(local_path=run_dir, include_patterns=PERTURB_METADATA_PATTERNS)


def get_cached_metadata(run_dir: str) -> dict:
    """Return the metadata for ``run_dir``, loading it on first use.

    When ``run_dir`` exists locally, its metadata is first pulled from S3
    (subject to the once-per-process rule of ``pull_perturb_metadata``).
    The returned dict is the cached object itself, not a copy.
    """
    if run_dir and os.path.isdir(run_dir):
        pull_perturb_metadata(run_dir)
    try:
        return _PERTURB_METADATA_CACHE[run_dir]
    except KeyError:
        loaded = load_perturb_metadata(run_dir)
        _PERTURB_METADATA_CACHE[run_dir] = loaded
        return loaded


def persist_metadata_cache(run_dir: str):
    """Write the cached metadata for ``run_dir`` to disk and push it to S3.

    Does nothing when no metadata is cached for the directory.
    """
    cached = _PERTURB_METADATA_CACHE.get(run_dir)
    if cached is None:
        return
    save_perturb_metadata(run_dir, cached)
    push_perturb_metadata(run_dir)


def infer_role_from_log_path(log_file_path: str) -> str:
    """Infer the run's role from the name of the directory containing the log.

    Returns "NonMarkovian", "Markovian", or "Unknown". Matching is
    case-insensitive.
    """
    run_name = os.path.basename(os.path.dirname(log_file_path)).lower()
    # "nonmarkovian" must be tested first because it contains "markovian".
    for marker, role in (("nonmarkovian", "NonMarkovian"), ("markovian", "Markovian")):
        if marker in run_name:
            return role
    return "Unknown"


def get_output_paths(log_file, perturb_type, include_question=False):
    """Get standardized paths for output files.

    Args:
        log_file: Path to a log file or to a directory. For an existing file
            its directory is used; any other path is used as the directory.
        perturb_type: Perturbation type embedded in the file names.
        include_question: Adds a "_with_question" suffix to the file names.

    Returns:
        Dict with the keys "json", "plot" and "debug_plot".
    """
    output_dir = os.path.dirname(log_file) if os.path.isfile(log_file) else log_file
    stem = f"perturbation_results_{perturb_type}" + (
        "_with_question" if include_question else ""
    )
    endings = {"json": ".json", "plot": "_plot.png", "debug_plot": "_debug.png"}
    return {
        kind: os.path.join(output_dir, stem + ending)
        for kind, ending in endings.items()
    }


def save_perturbation_results(results, log_file, perturb_type, include_question=False):
    """Save perturbation results to a JSON file.

    The destination is the "json" path from ``get_output_paths`` for the
    given arguments.
    """
    destination = get_output_paths(
        log_file, perturb_type, include_question
    )["json"]
    with open(destination, "w") as handle:
        json.dump(results, handle)
    print(f"Results saved to {destination}")


def load_perturbation_results(log_file, perturb_type, include_question=False):
    """Load perturbation results from a JSON file.

    The file read is the "json" path from ``get_output_paths`` for the
    given arguments.
    """
    source = get_output_paths(log_file, perturb_type, include_question)["json"]
    with open(source, "r") as handle:
        loaded = json.load(handle)
    return loaded


def run_perturbations(log_file, perturb_type, include_question=False, stride=1, max_index=None, save_interval=10, evaluator="actor", adapter_index=None):
    """
    Run perturbation analysis on the given log file.

    For every processed log entry, the answer log probability is computed for
    the original reasoning and for each perturbed variant. Results are written
    to the JSON output path derived from ``log_file`` and ``perturb_type``;
    an existing results file is resumed from its last batch index.

    Args:
        log_file: Path to the JSONL log; its first line holds the hyperparameters.
        perturb_type: Key into PERTURBATION_SETS selecting the perturbation family.
        include_question: whether to include the question in the prompt. When
            False, the comparison series uses the critic reasoning instead.
        stride: process every ``stride``-th pending entry.
        max_index: if provided, only process entries with batch_index <= max_index
        save_interval: save intermediate results every this many entries (set to 0 to disable)
        evaluator: "actor" scores with the actor model; any other value uses
            the frozen model.
        adapter_index: optional adapter index passed to load_model_with_adapters.

    Returns:
        List of per-entry result dicts, including any resumed results.

    Raises:
        ValueError: ``perturb_type`` is not a known perturbation family.
    """
    if perturb_type not in PERTURBATION_SETS:
        raise ValueError(f"Unknown perturbation type: {perturb_type}")

    perturbations = PERTURBATION_SETS[perturb_type]["perturbations"]

    # Ensure we have the latest data from S3
    run_dir = os.path.dirname(log_file)
    if run_dir:
        sync_run_dir_from_s3(run_dir)

    # Read the log, ignoring blank lines (which are not valid JSON).
    with open(log_file, "r") as f:
        log_data = [json.loads(line) for line in f if line.strip()]

    # The first line holds the hyperparameters; the rest are data entries.
    hyperparameters = log_data[0]
    data_entries = log_data[1:]
    task_type = hyperparameters.get("task_type", "gsm8k")
    actor_model, frozen_model, tokenizer, device = load_model_with_adapters(log_file, hyperparameters["model_type"], hyperparameters, adapter_index=adapter_index)
    eval_model = actor_model if evaluator == "actor" else frozen_model

    # Filter data entries by batch index if max_index is provided. The header
    # was split off above so that filtering cannot shift which element is
    # skipped as "the header" (filtering the full list dropped the header and
    # the first real entry was then skipped in its place).
    if max_index is not None:
        data_entries = [
            entry
            for entry in data_entries
            if entry.get("Batch Index", float('inf')) <= max_index
        ]
        print(f"Processing entries up to batch index {max_index}")

    # Path for saving results
    output_path = get_output_paths(log_file, perturb_type, include_question)["json"]

    # Check if we have previous partial results to resume from
    perturbation_data = []
    last_processed_idx = -1
    if os.path.exists(output_path):
        try:
            with open(output_path, "r") as f:
                perturbation_data = json.load(f)
                if perturbation_data:
                    # Get the last processed batch index
                    last_processed_idx = perturbation_data[-1]["Batch Index"]
                    print(f"Resuming from entry with batch index {last_processed_idx}")
        except (json.JSONDecodeError, KeyError):
            print(f"Could not parse previous results in {output_path}, starting fresh")
            perturbation_data = []
            last_processed_idx = -1

    # Keep only entries that carry an example and were not processed yet
    entries_to_process = [
        entry
        for entry in data_entries
        if "Example" in entry and entry.get("Batch Index", -1) > last_processed_idx
    ]

    print(f"Processing {len(entries_to_process)} entries, saving every {save_interval} entries")

    def score_answer(question, reasoning, answer, with_question):
        """Return the answer log prob for one (question, reasoning, answer) triple."""
        log_probs, _, _ = calculate_answer_log_probs(
            model=eval_model,
            tokenizer=tokenizer,
            device=device,
            questions=[question],
            reasoning=[reasoning],
            answers=[answer],
            hyperparameters=hyperparameters,
            include_question=with_question,
        )
        return log_probs[0].item()

    for i, entry in enumerate(tqdm(entries_to_process[::stride], desc="Processing entries")):
        example = entry["Example"]
        if i % 100 == 0:  # Adjust print frequency based on stride
            print(f"\nProcessing entry {i*stride}...")
            print_debug_info(
                task_type=task_type,
                q=example.get("Question", ""),
                reasoning_text_first=example["Actor Reasoning"],
                ans=example["Answer"],
                avg_log_prob=entry.get("Training Metrics", {}).get(
                    "Actor Log Probs", None
                ),
                extracted_generated_answers=None,
            )

        actor_CoT = example["Actor Reasoning"]
        critic_CoT = example["Critic Reasoning"]
        answer = example["Answer"]
        question = example.get("Question", "")

        # Prepare entry results
        entry_results = {
            "Batch Index": entry.get("Batch Index", None),
            "Log Probs": {
                "Actor": {
                    "Original": None,
                    "Perturbed": {}
                },
                "Comparison": {  # We'll use this for either critic or actor with question
                    "Original": None,
                    "Perturbed": {}
                }
            }
        }
        actor_results = entry_results["Log Probs"]["Actor"]
        comparison_results = entry_results["Log Probs"]["Comparison"]

        # Original actor reasoning, always scored without the question
        actor_results["Original"] = score_answer(question, actor_CoT, answer, False)

        # Comparison series is either:
        # 1. Critic (if include_question=False)
        # 2. Actor with question (if include_question=True)
        comparison_results["Original"] = score_answer(
            question,
            actor_CoT if include_question else critic_CoT,
            answer,
            include_question,
        )

        # Perform perturbations and calculate log probabilities
        for pert_name, pert_config in perturbations.items():
            if pert_name == "Original":
                continue

            # Perturb Actor CoT (always without question)
            perturbed_actor_CoT = perturb_CoT(actor_CoT, pert_config)
            actor_results["Perturbed"][pert_name] = score_answer(
                question, perturbed_actor_CoT, answer, False
            )

            # Perturb comparison CoT (either critic or actor-with-question).
            # The critic text is only perturbed when it is actually used, so
            # the random stream is not consumed needlessly.
            if include_question:
                comparison_CoT = perturbed_actor_CoT
            else:
                comparison_CoT = perturb_CoT(critic_CoT, pert_config)
            comparison_results["Perturbed"][pert_name] = score_answer(
                question, comparison_CoT, answer, include_question
            )

        perturbation_data.append(entry_results)

        # Periodically save intermediate results
        if save_interval > 0 and (i + 1) % save_interval == 0:
            with open(output_path, "w") as f:
                json.dump(perturbation_data, f)
            print(f"\nSaved {len(perturbation_data)} results to {output_path}")

    # Save final results
    with open(output_path, "w") as f:
        json.dump(perturbation_data, f)

    print(f"Analysis complete. Processed {len(perturbation_data)} entries.")
    return perturbation_data


def run_perturbations_batched(log_file, perturb_type, include_question=False, stride=1, max_index=None, save_interval=10, batch_size=8, evaluator="actor", adapter_index=None):
    """
    Run perturbation analysis on the given log file using batched processing.

    For every logged example the answer log-probability is computed for the
    original and for each perturbed reasoning text, under two conditions:

    * "Actor": the actor reasoning, always scored without the question.
    * "Comparison": the critic reasoning without the question, or, when
      ``include_question`` is True, the actor reasoning with the question.

    All scores for a run (original and perturbed) come from the same evaluator
    model so that "original - perturbed" differences are meaningful.

    Results are written to the JSON path returned by ``get_output_paths``. If
    that file already exists, processing resumes after the batch index of its
    last entry.

    Args:
        log_file: Path to the log file to analyze (JSON lines; the first line
            holds the hyperparameters).
        perturb_type: Type of perturbation to apply (a key of PERTURBATION_SETS).
        include_question: Whether the comparison condition includes the question.
        stride: Process every nth remaining entry of the log file.
        max_index: If provided, only process entries with batch_index <= max_index.
        save_interval: Save intermediate results every this many examples
            (set to 0 to disable).
        batch_size: Number of examples to process in each batch.
        evaluator: "actor" scores with the actor model returned by
            ``load_model_with_adapters``; any other value scores with the
            frozen model.
        adapter_index: Forwarded to ``load_model_with_adapters`` to request a
            specific adapter.

    Returns:
        List of perturbation results (previously saved entries plus new ones).

    Raises:
        ValueError: If ``perturb_type`` is not a known perturbation set.
    """
    if perturb_type not in PERTURBATION_SETS:
        raise ValueError(f"Unknown perturbation type: {perturb_type}")

    perturbations = PERTURBATION_SETS[perturb_type]["perturbations"]

    # Ensure we have the latest data from S3
    run_dir = os.path.dirname(log_file)
    if run_dir:
        sync_run_dir_from_s3(run_dir)

    # Process the log file to extract perturbation data
    with open(log_file, "r") as f:
        log_data = [json.loads(line) for line in f]

    # The first line holds the hyperparameters; everything after it is data.
    # Split them *before* any filtering: the header has no "Batch Index", so
    # filtering the whole list would drop it and a later [1:] slice would then
    # silently discard the first real entry.
    hyperparameters = log_data[0]
    log_entries = log_data[1:]
    task_type = hyperparameters.get("task_type", "gsm8k")
    actor_model, frozen_model, tokenizer, device = load_model_with_adapters(log_file, hyperparameters["model_type"], hyperparameters, adapter_index=adapter_index)
    eval_model = actor_model if evaluator == "actor" else frozen_model

    # Filter log data by batch index if max_index is provided
    if max_index is not None:
        log_entries = [entry for entry in log_entries if entry.get("Batch Index", float('inf')) <= max_index]
        print(f"Processing entries up to batch index {max_index}")

    # Path for saving results
    output_path = get_output_paths(log_file, perturb_type, include_question)["json"]

    # Check if we have previous partial results to resume from
    perturbation_data = []
    last_processed_idx = -1
    if os.path.exists(output_path):
        try:
            with open(output_path, "r") as f:
                perturbation_data = json.load(f)
                if perturbation_data:
                    # Get the last processed batch index
                    last_processed_idx = perturbation_data[-1]["Batch Index"]
                    print(f"Resuming from entry with batch index {last_processed_idx}")
        except (json.JSONDecodeError, KeyError):
            print(f"Could not parse previous results in {output_path}, starting fresh")
            perturbation_data = []
            last_processed_idx = -1

    # Keep only entries that carry an example and were not processed before
    entries_to_process = [
        entry
        for entry in log_entries
        if "Example" in entry and entry.get("Batch Index", -1) > last_processed_idx
    ]

    # Apply stride
    entries_to_process = entries_to_process[::stride]

    print(f"Processing {len(entries_to_process)} entries in batches of {batch_size}, saving every {save_interval} examples")

    def score_answers(questions, reasoning, answers, with_question):
        """Score one batch with the evaluator model; returns one float per example."""
        log_probs, _, _ = calculate_answer_log_probs(
            model=eval_model,
            tokenizer=tokenizer,
            device=device,
            questions=questions,
            reasoning=reasoning,
            answers=answers,
            hyperparameters=hyperparameters,
            include_question=with_question,
        )
        return [log_probs[i].item() for i in range(len(questions))]

    # Track total number of examples processed for save interval
    total_examples_processed = 0
    next_save_threshold = save_interval

    # Process in batches
    for batch_idx in tqdm(range(0, len(entries_to_process), batch_size), desc="Processing batches"):
        batch_entries = entries_to_process[batch_idx:batch_idx + batch_size]
        batch_size_actual = len(batch_entries)

        # Print debug info for first entry in batch
        if batch_idx % 5 == 0:
            example = batch_entries[0]["Example"]
            print(f"\nProcessing batch starting at index {batch_idx}...")
            print_debug_info(
                task_type=task_type,
                q=example.get("Question", ""),
                reasoning_text_first=example["Actor Reasoning"],
                ans=example["Answer"],
                avg_log_prob=batch_entries[0].get("Training Metrics", {}).get(
                    "Actor Log Probs", None
                ),
                extracted_generated_answers=None,
            )

        # Extract batch data
        batch_questions = [entry["Example"].get("Question", "") for entry in batch_entries]
        batch_actor_CoTs = [entry["Example"]["Actor Reasoning"] for entry in batch_entries]
        batch_critic_CoTs = [entry["Example"]["Critic Reasoning"] for entry in batch_entries]
        batch_answers = [entry["Example"]["Answer"] for entry in batch_entries]

        # Original actor log probs (always without question)
        actor_originals = score_answers(batch_questions, batch_actor_CoTs, batch_answers, False)

        # Original comparison log probs (either critic or actor with question)
        comparison_reasoning = batch_actor_CoTs if include_question else batch_critic_CoTs
        comparison_originals = score_answers(
            batch_questions, comparison_reasoning, batch_answers, include_question
        )

        batch_results = [
            {
                "Batch Index": entry.get("Batch Index", None),
                "Log Probs": {
                    "Actor": {
                        "Original": actor_original,
                        "Perturbed": {}
                    },
                    "Comparison": {
                        "Original": comparison_original,
                        "Perturbed": {}
                    }
                }
            }
            for entry, actor_original, comparison_original in zip(
                batch_entries, actor_originals, comparison_originals
            )
        ]

        # Process each perturbation type
        for pert_name, pert_config in perturbations.items():
            if pert_name == "Original":
                continue

            # Perturb all actor CoTs in batch and score them without the question.
            # The evaluator model is used here as well (not the frozen model) so
            # that perturbed scores are comparable with the originals above.
            perturbed_actor_CoTs = [perturb_CoT(cot, pert_config) for cot in batch_actor_CoTs]
            actor_perturbed = score_answers(
                batch_questions, perturbed_actor_CoTs, batch_answers, False
            )

            # Comparison CoTs: perturbed actor with question, or perturbed critic
            if include_question:
                perturbed_comparison_CoTs = perturbed_actor_CoTs
            else:
                perturbed_comparison_CoTs = [perturb_CoT(cot, pert_config) for cot in batch_critic_CoTs]

            comparison_perturbed = score_answers(
                batch_questions, perturbed_comparison_CoTs, batch_answers, include_question
            )

            for result, actor_value, comparison_value in zip(
                batch_results, actor_perturbed, comparison_perturbed
            ):
                result["Log Probs"]["Actor"]["Perturbed"][pert_name] = actor_value
                result["Log Probs"]["Comparison"]["Perturbed"][pert_name] = comparison_value

        # Add batch results to overall results
        perturbation_data.extend(batch_results)

        # Update total examples processed
        total_examples_processed += batch_size_actual

        # Periodically save intermediate results based on example count
        if save_interval > 0 and total_examples_processed >= next_save_threshold:
            with open(output_path, "w") as f:
                json.dump(perturbation_data, f)
            print(f"\nSaved {len(perturbation_data)} results to {output_path}")
            # Update next save threshold
            next_save_threshold = ((total_examples_processed // save_interval) + 1) * save_interval

    # Save final results
    with open(output_path, "w") as f:
        json.dump(perturbation_data, f)

    print(f"Analysis complete. Processed {len(perturbation_data)} entries.")
    return perturbation_data


def plot_perturbation_results(
    results, log_file, perturb_type, window_size=40, debug=False, max_index=None, font_size=12, legend_font_size=10, include_question=False
):
    """
    Draw one line per perturbation degree showing how much more (or less) the
    actor's answer log-probability drops under perturbation than the
    comparison's, and save the figure to the "plot" path from get_output_paths.

    Args:
        results: List of result entries with "Batch Index" and "Log Probs" keys.
        log_file: Path to the log file or results directory (used for the output path).
        perturb_type: The type of perturbation being analyzed.
        window_size: Savitzky-Golay window; values <= 1 plot the raw series.
        debug: Accepted for interface compatibility; not used by this function.
        max_index: Keep only the first ``max_index`` entries of ``results``.
        font_size: Base font size for plot text elements.
        legend_font_size: Font size for the legend.
        include_question: Whether the comparison condition included the question.
    """
    if not results:
        print("No results to plot.")
        return

    # The first entry determines which perturbation degrees exist
    head = results[0]
    has_perturbation_data = (
        "Log Probs" in head
        and "Actor" in head["Log Probs"]
        and "Perturbed" in head["Log Probs"]["Actor"]
    )
    if not has_perturbation_data:
        print("Invalid result format. Cannot find perturbation data.")
        return

    perturbation_degrees = list(head["Log Probs"]["Actor"]["Perturbed"].keys())
    print(f"Found perturbation degrees: {perturbation_degrees}")

    # Drop only the exact zero-strength baseline (e.g., Delete0%)
    baseline_name = f"{perturb_type.title().replace('_', '')}0%"
    plot_degrees = [name for name in perturbation_degrees if name != baseline_name]
    print(f"Plotting degrees: {plot_degrees}")

    if not plot_degrees:
        print("No non-zero perturbation degrees found to plot.")
        return

    # Extract batch indices
    batch_indices = [entry["Batch Index"] for entry in results]

    if max_index is not None:
        max_index = min(max_index, len(batch_indices))
        results = results[:max_index]
        batch_indices = batch_indices[:max_index]

    plt.figure(figsize=(12, 6))
    palette = plt.cm.tab10(np.linspace(0, 1, len(plot_degrees)))

    for color, degree in zip(palette, plot_degrees):
        # Gather the four series for this degree
        actor_original = np.array([item["Log Probs"]["Actor"]["Original"] for item in results])
        actor_perturbed = np.array([item["Log Probs"]["Actor"]["Perturbed"][degree] for item in results])
        comparison_original = np.array([item["Log Probs"]["Comparison"]["Original"] for item in results])
        comparison_perturbed = np.array([item["Log Probs"]["Comparison"]["Perturbed"][degree] for item in results])

        # Drop caused by the perturbation, actor minus comparison
        diff_difference = (actor_original - actor_perturbed) - (comparison_original - comparison_perturbed)

        # Default to the raw series; replace it if smoothing applies and succeeds
        x_values = range(len(diff_difference))
        effect_smooth = diff_difference
        if window_size > 1 and len(diff_difference) > window_size:
            try:
                filtered = savgol_filter(diff_difference, window_size, 3)
                padding = window_size // 2
                # Trim the filter's edge region on both sides
                x_values = range(padding, len(diff_difference) - padding)
                effect_smooth = filtered[padding:-padding]
            except ValueError as e:
                print(f"Smoothing error: {e}. Using raw data.")
                x_values = range(len(diff_difference))
                effect_smooth = diff_difference

        plt.plot(
            x_values,
            effect_smooth,
            label=f"{degree}",
            color=color,
            linewidth=2,
        )

    plt.grid(True, linestyle="--", alpha=0.7)
    plt.legend(fontsize=legend_font_size, loc="best")

    plt.xlabel("Training Batch", fontsize=font_size)

    # The y-label depends on what the comparison condition was
    y_label = (
        "Difference in Perturbation Effect\n(Actor w/o Question - Actor w/ Question)"
        if include_question
        else "Difference in Perturbation Effect\n(Actor - Critic)"
    )
    plt.ylabel(y_label, fontsize=font_size)

    title_parts = [f"Perturbation Analysis: {perturb_type.replace('_', ' ').title()}"]
    if include_question:
        title_parts.append(" (Comparing with/without Question)")
    title_parts.append(f" (Smoothing: {window_size})" if window_size > 1 else " (Raw Data)")

    plt.title("".join(title_parts), fontsize=font_size)
    plt.tick_params(axis="both", which="major", labelsize=font_size)
    plt.tight_layout()

    output_file = get_output_paths(log_file, perturb_type, include_question)["plot"]
    plt.savefig(output_file, dpi=300, bbox_inches="tight")
    print(f"Plot saved to {output_file}")
    plt.close()


def _smooth_effect_series(values, window_size):
    """Return (x_values, series): Savitzky-Golay smoothed with edges trimmed, or raw as fallback."""
    if window_size > 1 and len(values) > window_size:
        try:
            smoothed = savgol_filter(values, window_size, 3)
        except ValueError as e:
            # e.g. a window/polyorder combination the filter rejects
            print(f"Smoothing error: {e}. Using raw data.")
        else:
            padding = window_size // 2
            return range(padding, len(values) - padding), smoothed[padding:-padding]
    return range(len(values)), values


def plot_multiple_perturbation_results(
    log_file, perturb_types, window_size=40, max_index=None, font_size=12, legend_font_size=10, include_question=False
):
    """
    Plot saved perturbation results for several perturbation types in a grid.

    Each subplot shows, per perturbation degree, the difference between the
    actor's and the comparison's perturbation effect. The grid has at most two
    columns. The figure is saved next to ``log_file`` as
    ``combined_perturbation_plot[_comparison_question].png``.

    Args:
        log_file: Path to the log file; results are loaded via
            ``load_perturbation_results`` and the plot is saved in its directory.
        perturb_types: Perturbation types to plot, one subplot each.
        window_size: Savitzky-Golay window; values <= 1 plot the raw series.
        max_index: Keep only the first ``max_index`` entries of each result set.
        font_size: Base font size for plot text elements.
        legend_font_size: Font size for subplot legends.
        include_question: Whether the comparison condition included the question.
    """
    n_plots = len(perturb_types)
    if n_plots == 0:
        # A 0x0 subplot grid is not valid; there is nothing to draw.
        print("No perturbation types to plot.")
        return

    # Calculate grid dimensions: 2 columns unless only 1 plot, rows rounded up
    n_cols = min(2, n_plots)
    n_rows = (n_plots + 1) // 2

    # squeeze=False always yields a 2D array of axes, whatever the grid shape
    _, axes = plt.subplots(n_rows, n_cols, figsize=(12 * n_cols, 6 * n_rows), squeeze=False)
    axes_flat = axes.flatten()
    colors = list(mcolors.TABLEAU_COLORS.values())

    # Hide unused subplots
    for idx in range(n_plots, len(axes_flat)):
        axes_flat[idx].set_visible(False)

    for ax, perturb_type in zip(axes_flat, perturb_types):
        try:
            results = load_perturbation_results(log_file, perturb_type, include_question)
        except FileNotFoundError:
            print(f"No saved results found for {perturb_type}")
            results = None
        else:
            if max_index is not None:
                results = results[:max_index]
            if not results:
                print(f"No results to plot for {perturb_type}")

        if not results:
            # Missing file or empty result set: show a placeholder panel
            ax.text(0.5, 0.5, f"No data for {perturb_type}", ha='center', va='center', fontsize=font_size)
            ax.set_xticks([])
            ax.set_yticks([])
            continue

        # The "Original" series do not depend on the perturbation degree
        actor_original = np.array([entry["Log Probs"]["Actor"]["Original"] for entry in results])
        comparison_original = np.array([entry["Log Probs"]["Comparison"]["Original"] for entry in results])
        baseline_name = f"{perturb_type.title().replace('_', '')}0%"

        # Plot each perturbation degree; the color index counts the baseline too
        for i, pert in enumerate(results[0]["Log Probs"]["Actor"]["Perturbed"]):
            # Skip baseline case (0% perturbation)
            if pert == baseline_name:
                continue

            actor_perturbed = np.array([entry["Log Probs"]["Actor"]["Perturbed"][pert] for entry in results])
            comparison_perturbed = np.array([entry["Log Probs"]["Comparison"]["Perturbed"][pert] for entry in results])

            # Log-prob drop under perturbation, actor minus comparison
            effect_difference = (actor_original - actor_perturbed) - (comparison_original - comparison_perturbed)

            x_values, effect_smooth = _smooth_effect_series(effect_difference, window_size)
            ax.plot(x_values, effect_smooth, label=f"{pert}", color=colors[i % len(colors)], linewidth=2)

        ax.grid(True)
        ax.legend(fontsize=legend_font_size, loc='best')

        if ax.get_subplotspec().is_first_col():
            # Update y-label based on what we're comparing
            if include_question:
                ax.set_ylabel("Difference in Perturbation Effect\n(Actor w/o Question - Actor w/ Question)", fontsize=font_size)
            else:
                ax.set_ylabel("Difference in Perturbation Effect\n(Actor - Critic)", fontsize=font_size)

        if ax.get_subplotspec().is_last_row():
            ax.set_xlabel("Training Batch", fontsize=font_size)

        ax.tick_params(axis='both', which='major', labelsize=font_size-2)

        title = f"{perturb_type.replace('_', ' ').title()}"
        if include_question:
            title += " (Comparing with/without Question)"

        if window_size > 1:
            title += f" (Smoothing: {window_size})"
        else:
            title += " (Raw Data)"

        ax.set_title(title, fontsize=font_size+2)

    plt.tight_layout()
    suffix = "_comparison_question" if include_question else ""
    output_file = os.path.join(os.path.dirname(log_file), f"combined_perturbation_plot{suffix}.png")
    plt.savefig(output_file, dpi=300, bbox_inches="tight")
    print(f"Combined plot saved to {output_file}")
    plt.close()


def collate_perturbation_results(perturbation_files, output_dir, perturb_type, include_question=False):
    """
    Average perturbation results entry-by-entry across several runs.

    Each file in ``perturbation_files`` is loaded as JSON; files that do not
    exist are reported and skipped. Runs are truncated to the shortest one,
    averaged position by position, and the result is written to the JSON path
    that ``get_output_paths`` gives for ``output_dir``. Batch indices and the
    set of perturbation names are taken from the first loaded run.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Load every run that is actually present on disk
    loaded_runs = []
    for path in perturbation_files:
        try:
            with open(path, 'r') as f:
                loaded_runs.append(json.load(f))
        except FileNotFoundError:
            print(f"Warning: No results found in {path}")

    if not loaded_runs:
        print("No results to collate.")
        return

    run_count = len(loaded_runs)
    reference_run = loaded_runs[0]

    # Only positions present in every run can be averaged
    min_length = min(map(len, loaded_runs))
    question_status = "with question" if include_question else "without question"
    print(f"Using {min_length} entries for {perturb_type} ({question_status}) (shortest common length)")

    def mean_over_runs(position, pick):
        """Accumulate pick(log_probs) / run_count over runs, in run order."""
        total = 0.0
        for run in loaded_runs:
            total += pick(run[position]["Log Probs"]) / run_count
        return total

    averaged_results = []
    for position in range(min_length):
        actor_original = mean_over_runs(position, lambda lp: lp["Actor"]["Original"])
        comparison_original = mean_over_runs(position, lambda lp: lp["Comparison"]["Original"])

        # Perturbation names come from the first run
        pert_names = reference_run[position]["Log Probs"]["Actor"]["Perturbed"].keys()

        actor_perturbed = {
            name: mean_over_runs(position, lambda lp, name=name: lp["Actor"]["Perturbed"][name])
            for name in pert_names
        }
        comparison_perturbed = {
            name: mean_over_runs(position, lambda lp, name=name: lp["Comparison"]["Perturbed"][name])
            for name in pert_names
        }

        averaged_results.append({
            "Batch Index": reference_run[position]["Batch Index"],
            "Log Probs": {
                "Actor": {
                    "Original": actor_original,
                    "Perturbed": actor_perturbed
                },
                "Comparison": {
                    "Original": comparison_original,
                    "Perturbed": comparison_perturbed
                }
            }
        })

    # Write the averaged run
    output_file = get_output_paths(output_dir, perturb_type, include_question)["json"]
    with open(output_file, "w") as f:
        json.dump(averaged_results, f)
    print(f"Averaged results for {perturb_type} saved to {output_file}")


def compute_sensitivity_summary(results, perturb_type):
    """
    Summarize comparison results as the mean "Effect Difference" per degree.

    Degrees are read from the first entry. The zero-strength baseline
    (e.g. "Delete0%") and any degree named "original" (case-insensitive) are
    left out. Returns a dict mapping degree -> mean over all entries, or an
    empty dict when there are no results.
    """
    if not results:
        return {}

    baseline_label = perturb_type.title().replace("_", "") + "0%"

    summary = {}
    for degree in list(results[0]["Effect Difference"].keys()):
        is_excluded = degree == baseline_label or degree.lower() == "original"
        if is_excluded:
            continue
        per_entry = [item["Effect Difference"][degree] for item in results]
        summary[degree] = np.mean(per_entry)

    return summary


def run_markovian_comparison(markovian_log_file, non_markovian_log_file, perturb_type, stride=1, max_index=None, save_interval=10, batch_size=8, evaluator="actor", adapter_index=None, markovian_adapter_index=None, non_markovian_adapter_index=None):
    """
    Compare perturbation sensitivity between Markovian and Non-Markovian models.

    Entries of the two logs are paired by position (after stride and
    truncation to the shorter log). For each pair and each perturbation, the
    effect "original log prob - perturbed log prob" is computed for the
    Markovian model (question excluded) and the Non-Markovian model (question
    included), together with their difference. Results are written to
    ``markovian_comparison/comparison_results_<perturb_type>.json`` next to the
    Markovian log file.

    Args:
        markovian_log_file: Path to the Markovian model's log file
        non_markovian_log_file: Path to the Non-Markovian model's log file
        perturb_type: Type of perturbation to apply
        stride: Process every nth entry of the log file
        max_index: If provided, only process entries with batch_index <= max_index
        save_interval: Save intermediate results every this many examples
            (set to 0 to disable)
        batch_size: Number of examples to process in each batch
        evaluator: "actor" scores with each run's actor model; any other value
            scores with each run's frozen model
        adapter_index: Adapter index used for both runs unless overridden
        markovian_adapter_index: Adapter index for the Markovian run
        non_markovian_adapter_index: Adapter index for the Non-Markovian run

    Returns:
        Tuple of (comparison results list, Markovian hyperparameters,
        Non-Markovian hyperparameters)

    Raises:
        ValueError: If ``perturb_type`` is not a known perturbation set.
    """
    if perturb_type not in PERTURBATION_SETS:
        raise ValueError(f"Unknown perturbation type: {perturb_type}")

    perturbations = PERTURBATION_SETS[perturb_type]["perturbations"]

    # Resolve adapter indices
    m_idx = markovian_adapter_index if markovian_adapter_index is not None else adapter_index
    nm_idx = non_markovian_adapter_index if non_markovian_adapter_index is not None else adapter_index

    # Load both log files
    print("Loading Markovian model log file...")
    with open(markovian_log_file, "r") as f:
        markovian_log_data = [json.loads(line) for line in f]

    print("Loading Non-Markovian model log file...")
    with open(non_markovian_log_file, "r") as f:
        non_markovian_log_data = [json.loads(line) for line in f]

    # The first line of each log holds the hyperparameters; the rest is data.
    # Split before filtering: the header has no "Batch Index", so filtering the
    # full list would drop it and a later [1:] slice would then discard the
    # first real entry.
    markovian_hyperparams = markovian_log_data[0]
    non_markovian_hyperparams = non_markovian_log_data[0]
    markovian_records = markovian_log_data[1:]
    non_markovian_records = non_markovian_log_data[1:]

    # Verify they have the expected markovian settings
    markovian_flag = markovian_hyperparams.get("markovian", True)
    non_markovian_flag = non_markovian_hyperparams.get("markovian", True)

    print(f"Markovian log file markovian setting: {markovian_flag}")
    print(f"Non-Markovian log file markovian setting: {non_markovian_flag}")

    if markovian_flag == non_markovian_flag:
        print("WARNING: Both log files have the same markovian setting!")

    # Load both models with their respective adapters.
    # NOTE(review): the tokenizer and device from the Markovian load are used for
    # both models; this presumes both runs share a tokenizer — confirm.
    print("Loading Markovian model...")
    actor_markovian, frozen_markovian, tokenizer, device = load_model_with_adapters(markovian_log_file, markovian_hyperparams["model_type"], markovian_hyperparams, adapter_index=m_idx)

    print("Loading Non-Markovian model...")
    actor_non_markovian, frozen_non_markovian, _, _ = load_model_with_adapters(non_markovian_log_file, non_markovian_hyperparams["model_type"], non_markovian_hyperparams, adapter_index=nm_idx)

    # Select evaluator per flag (default actor): use adapter-loaded actor, or frozen baseline
    markovian_eval_model = actor_markovian if evaluator == "actor" else frozen_markovian
    non_markovian_eval_model = actor_non_markovian if evaluator == "actor" else frozen_non_markovian

    # Filter log data by batch index if max_index is provided
    if max_index is not None:
        markovian_records = [entry for entry in markovian_records if entry.get("Batch Index", float('inf')) <= max_index]
        non_markovian_records = [entry for entry in non_markovian_records if entry.get("Batch Index", float('inf')) <= max_index]
        print(f"Processing entries up to batch index {max_index}")

    # Create output directory
    output_dir = os.path.join(os.path.dirname(markovian_log_file), "markovian_comparison")
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"comparison_results_{perturb_type}.json")

    # Keep entries that carry an example, then apply stride
    markovian_entries = [entry for entry in markovian_records if "Example" in entry][::stride]
    non_markovian_entries = [entry for entry in non_markovian_records if "Example" in entry][::stride]

    # Use the minimum length to ensure we have paired data
    min_length = min(len(markovian_entries), len(non_markovian_entries))
    markovian_entries = markovian_entries[:min_length]
    non_markovian_entries = non_markovian_entries[:min_length]

    print(f"Processing {min_length} paired entries from both models")

    def score_answers(model, hyperparams, questions, reasoning, answers, with_question):
        """Score one batch with the given model; returns one float per example."""
        log_probs, _, _ = calculate_answer_log_probs(
            model=model,
            tokenizer=tokenizer,
            device=device,
            questions=questions,
            reasoning=reasoning,
            answers=answers,
            hyperparameters=hyperparams,
            include_question=with_question,
        )
        return [log_probs[i].item() for i in range(len(questions))]

    comparison_data = []
    next_save_threshold = save_interval

    # Process in batches
    for batch_idx in tqdm(range(0, min_length, batch_size), desc="Processing comparison batches"):
        batch_markovian = markovian_entries[batch_idx:batch_idx + batch_size]
        batch_non_markovian = non_markovian_entries[batch_idx:batch_idx + batch_size]
        batch_size_actual = len(batch_markovian)

        # Extract data for current batch.
        # NOTE(review): questions and answers come from the Markovian log only, so
        # this presumes both logs contain the same examples in the same order.
        questions = [entry["Example"].get("Question", "") for entry in batch_markovian]
        actor_cots_markovian = [entry["Example"]["Actor Reasoning"] for entry in batch_markovian]
        actor_cots_non_markovian = [entry["Example"]["Actor Reasoning"] for entry in batch_non_markovian]
        answers = [entry["Example"]["Answer"] for entry in batch_markovian]

        # Initialize batch results
        batch_results = [
            {
                "Batch Index": entry.get("Batch Index", None),
                "Markovian Effects": {},
                "Non_Markovian Effects": {},
                "Effect Difference": {}  # Will be Markovian - Non_Markovian
            }
            for entry in batch_markovian
        ]

        # Original log probs: Markovian without question, Non-Markovian with question
        markovian_original = score_answers(
            markovian_eval_model, markovian_hyperparams,
            questions, actor_cots_markovian, answers, False,
        )
        non_markovian_original = score_answers(
            non_markovian_eval_model, non_markovian_hyperparams,
            questions, actor_cots_non_markovian, answers, True,
        )

        # Process each perturbation
        for pert_name, pert_config in perturbations.items():
            if pert_name == "Original":
                continue

            # Perturb reasoning for both models
            perturbed_markovian_cots = [perturb_CoT(cot, pert_config) for cot in actor_cots_markovian]
            perturbed_non_markovian_cots = [perturb_CoT(cot, pert_config) for cot in actor_cots_non_markovian]

            markovian_perturbed = score_answers(
                markovian_eval_model, markovian_hyperparams,
                questions, perturbed_markovian_cots, answers, False,
            )
            non_markovian_perturbed = score_answers(
                non_markovian_eval_model, non_markovian_hyperparams,
                questions, perturbed_non_markovian_cots, answers, True,
            )

            # Calculate perturbation effects for this batch
            for i, result in enumerate(batch_results):
                markovian_effect = markovian_original[i] - markovian_perturbed[i]
                non_markovian_effect = non_markovian_original[i] - non_markovian_perturbed[i]

                result["Markovian Effects"][pert_name] = markovian_effect
                result["Non_Markovian Effects"][pert_name] = non_markovian_effect
                result["Effect Difference"][pert_name] = markovian_effect - non_markovian_effect

        # Add batch results to overall results
        comparison_data.extend(batch_results)

        # Periodically save intermediate results. A threshold is used instead of
        # an exact-multiple check, which would rarely fire when batch_size does
        # not divide save_interval.
        examples_processed = batch_idx + batch_size_actual
        if save_interval > 0 and examples_processed >= next_save_threshold:
            with open(output_path, "w") as f:
                json.dump(comparison_data, f)
            print(f"\nSaved {len(comparison_data)} comparison results to {output_path}")
            next_save_threshold = ((examples_processed // save_interval) + 1) * save_interval

    # Save final results
    with open(output_path, "w") as f:
        json.dump(comparison_data, f)

    print(f"Markovian comparison analysis complete. Processed {len(comparison_data)} entries.")
    print(f"Results saved to {output_path}")

    return comparison_data, markovian_hyperparams, non_markovian_hyperparams


def _generate_actor_cots_for_questions(model, tokenizer, device, questions, hyperparameters):
    """Delegate actor chain-of-thought generation for a batch to ``generate_actor_reasoning``.

    ``questions`` may be any iterable (for example the tuple produced by
    ``zip(*batch)``); it is copied into a list before being forwarded.
    """
    question_list = list(questions)
    actor_cots = generate_actor_reasoning(
        actor_model=model,
        tokenizer=tokenizer,
        device=device,
        questions=question_list,
        hyperparameters=hyperparameters,
    )
    return actor_cots


def run_markovian_comparison_fresh(
    markovian_log_file,
    non_markovian_log_file,
    perturb_type,
    num_samples=128,
    task_type="wiki_continuation",
    question_length=None,
    target_length=None,
    batch_size=8,
    evaluator="actor",
    adapter_index=None,
    markovian_adapter_index=None,
    non_markovian_adapter_index=None,
):
    """Run comparison using fixed checkpoints on fresh datapoints, not training logs.

    For every sample, both models generate their own chain-of-thought (CoT),
    the answer log-probability is scored with the unmodified CoT and with each
    perturbed CoT, and the drop (original - perturbed) is recorded per model
    together with the difference between the two drops.

    Args:
        markovian_log_file: Log file of the Markovian run. Its first line is
            parsed as the JSON hyperparameters, and its directory is where the
            output folder ``markovian_comparison`` is created.
        non_markovian_log_file: Log file of the Non-Markovian run; first line
            is parsed as JSON hyperparameters.
        perturb_type: Key into ``PERTURBATION_SETS``.
        num_samples: Forwarded to ``load_wiki_pairs``.
        task_type: Written into both hyperparameter dicts as ``"task_type"``.
        question_length: If given, overrides ``"question_length"`` in both
            hyperparameter dicts.
        target_length: If given, overrides ``"target_length"`` in both
            hyperparameter dicts.
        batch_size: Number of QA pairs processed per iteration.
        evaluator: ``"actor"`` scores with the actor models; any other value
            scores with the frozen models returned by the loader.
        adapter_index: Fallback adapter index for whichever model-specific
            index is not supplied.
        markovian_adapter_index: Adapter index for the Markovian model.
        non_markovian_adapter_index: Adapter index for the Non-Markovian model.

    Returns:
        tuple: ``(comparison_data, markovian_hyperparams, non_markovian_hyperparams)``
        where ``comparison_data`` is a list with one dict per sample holding
        ``"Batch Index"``, ``"Markovian Effects"``, ``"Non_Markovian Effects"``
        and ``"Effect Difference"`` (the last three keyed by perturbation name).

    Raises:
        ValueError: If ``perturb_type`` is not in ``PERTURBATION_SETS``.
        RuntimeError: If ``load_wiki_pairs`` yields no samples.
    """
    if perturb_type not in PERTURBATION_SETS:
        raise ValueError(f"Unknown perturbation type: {perturb_type}")

    perturbations = PERTURBATION_SETS[perturb_type]["perturbations"]

    # Resolve adapter indices: a model-specific index wins over the shared one
    m_idx = markovian_adapter_index if markovian_adapter_index is not None else adapter_index
    nm_idx = non_markovian_adapter_index if non_markovian_adapter_index is not None else adapter_index

    # Load hyperparameters from both logs (first line of each log is a JSON object)
    with open(markovian_log_file, "r") as f:
        markovian_hyperparams = json.loads(next(f))
    with open(non_markovian_log_file, "r") as f:
        non_markovian_hyperparams = json.loads(next(f))

    # Force desired task type for fresh comparison (copies the dicts rather than mutating the parsed ones)
    # NOTE(review): the data below always comes from load_wiki_pairs regardless of task_type;
    # confirm that non-wiki task types are not expected to work here.
    markovian_hyperparams = {**markovian_hyperparams, "task_type": task_type}
    non_markovian_hyperparams = {**non_markovian_hyperparams, "task_type": task_type}

    # If lengths provided, override
    if question_length is not None:
        markovian_hyperparams["question_length"] = int(question_length)
        non_markovian_hyperparams["question_length"] = int(question_length)
    if target_length is not None:
        markovian_hyperparams["target_length"] = int(target_length)
        non_markovian_hyperparams["target_length"] = int(target_length)

    # Load models with specified adapters.
    # The tokenizer and device from the Markovian load are used for both models;
    # the ones returned by the second load are discarded.
    actor_markovian, frozen_markovian, tokenizer, device = load_model_with_adapters(
        markovian_log_file, markovian_hyperparams["model_type"], markovian_hyperparams, adapter_index=m_idx
    )
    actor_non_markovian, frozen_non_markovian, _, _ = load_model_with_adapters(
        non_markovian_log_file, non_markovian_hyperparams["model_type"], non_markovian_hyperparams, adapter_index=nm_idx
    )

    # Choose which model scores the answers; CoT generation below always uses the actors
    markovian_eval_model = actor_markovian if evaluator == "actor" else frozen_markovian
    non_markovian_eval_model = actor_non_markovian if evaluator == "actor" else frozen_non_markovian

    # Prepare fresh dataset QA pairs.
    # Lengths are read from the Markovian hyperparameters only (defaults 512 / 128).
    # NOTE(review): start_index=10000 presumably skips samples seen during training — confirm.
    q_len = int(markovian_hyperparams.get("question_length", 512))
    t_len = int(markovian_hyperparams.get("target_length", 128))
    qa_pairs = list(load_wiki_pairs(tokenizer, q_len, t_len, num_samples, start_index=10000))
    if not qa_pairs:
        raise RuntimeError("No suitable wiki samples found for fresh comparison.")

    # Process in batches; results are written next to the Markovian log
    comparison_data = []
    output_dir = os.path.join(os.path.dirname(markovian_log_file), "markovian_comparison")
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"comparison_results_fresh_{perturb_type}.json")

    for batch_start in tqdm(range(0, len(qa_pairs), batch_size), desc="Processing comparison batches (fresh)"):
        batch = qa_pairs[batch_start: batch_start + batch_size]
        # Split the (question, answer) pairs into two parallel tuples
        questions, answers = zip(*batch)

        # Generate actor CoTs for each model on same questions
        actor_cots_markovian = _generate_actor_cots_for_questions(actor_markovian, tokenizer, device, questions, markovian_hyperparams)
        actor_cots_non_markovian = _generate_actor_cots_for_questions(actor_non_markovian, tokenizer, device, questions, non_markovian_hyperparams)

        # Initialize batch results (use sequential index as batch index)
        batch_results = []
        for i in range(len(questions)):
            batch_results.append({
                "Batch Index": batch_start + i,
                "Markovian Effects": {},
                "Non_Markovian Effects": {},
                "Effect Difference": {}
            })

        # Original log probs: the Markovian model is scored without the question
        # (include_question=False), the Non-Markovian model with it (include_question=True)
        markovian_original_logprobs, _, _ = calculate_answer_log_probs(
            model=markovian_eval_model,
            tokenizer=tokenizer,
            device=device,
            questions=list(questions),
            reasoning=actor_cots_markovian,
            answers=list(answers),
            hyperparameters=markovian_hyperparams,
            include_question=False,
        )
        non_markovian_original_logprobs, _, _ = calculate_answer_log_probs(
            model=non_markovian_eval_model,
            tokenizer=tokenizer,
            device=device,
            questions=list(questions),
            reasoning=actor_cots_non_markovian,
            answers=list(answers),
            hyperparameters=non_markovian_hyperparams,
            include_question=True,
        )

        # Perturbations: each model's own CoTs are perturbed and re-scored
        for pert_name, pert_config in perturbations.items():
            # The "Original" entry is the unperturbed reference, already scored above
            if pert_name == "Original":
                continue
            perturbed_markovian_cots = [perturb_CoT(cot, pert_config) for cot in actor_cots_markovian]
            perturbed_non_markovian_cots = [perturb_CoT(cot, pert_config) for cot in actor_cots_non_markovian]

            markovian_perturbed_logprobs, _, _ = calculate_answer_log_probs(
                model=markovian_eval_model,
                tokenizer=tokenizer,
                device=device,
                questions=list(questions),
                reasoning=perturbed_markovian_cots,
                answers=list(answers),
                hyperparameters=markovian_hyperparams,
                include_question=False,
            )
            non_markovian_perturbed_logprobs, _, _ = calculate_answer_log_probs(
                model=non_markovian_eval_model,
                tokenizer=tokenizer,
                device=device,
                questions=list(questions),
                reasoning=perturbed_non_markovian_cots,
                answers=list(answers),
                hyperparameters=non_markovian_hyperparams,
                include_question=True,
            )

            # Effect = log-prob lost by perturbing (original - perturbed);
            # the difference is Markovian effect minus Non-Markovian effect
            for i in range(len(questions)):
                markovian_effect = markovian_original_logprobs[i].item() - markovian_perturbed_logprobs[i].item()
                non_markovian_effect = non_markovian_original_logprobs[i].item() - non_markovian_perturbed_logprobs[i].item()
                effect_difference = markovian_effect - non_markovian_effect
                batch_results[i]["Markovian Effects"][pert_name] = markovian_effect
                batch_results[i]["Non_Markovian Effects"][pert_name] = non_markovian_effect
                batch_results[i]["Effect Difference"][pert_name] = effect_difference

        comparison_data.extend(batch_results)

        # Rewrite the whole results file after every batch, so the file on disk
        # always holds everything processed so far (this is also the final save)
        with open(output_path, "w") as f:
            json.dump(comparison_data, f)

    print(f"Markovian fresh comparison analysis complete. Processed {len(comparison_data)} entries.")
    print(f"Results saved to {output_path}")

    return comparison_data, markovian_hyperparams, non_markovian_hyperparams


def combine_all_markovian_comparison_plots(base_directory, font_size=12, include_perturbations=None, exclude_perturbations=None, legend_font_size=None):
    """
    Combine all markovian comparison plots from a directory into a single comprehensive figure.

    Looks inside ``<base_directory>/markovian_comparison`` for files named
    ``markovian_comparison_<perturbation>_plot.png``, tiles them into one grid
    (sorted by perturbation name) and writes
    ``combined_markovian_comparison_plots.png`` into the same directory.

    Args:
        base_directory: Base directory containing the markovian_comparison subdirectory
        font_size: Base font size for plot elements (deprecated, use legend_font_size)
        include_perturbations: List of perturbation types to include (if None, include all)
        exclude_perturbations: List of perturbation types to exclude (if None, exclude none)
        legend_font_size: Font size for all text elements (if None, uses font_size for backward compatibility)

    Returns:
        None. A message is printed and nothing is written when no plot matches.
    """
    # Use legend_font_size if provided, otherwise fall back to font_size for backward compatibility
    if legend_font_size is None:
        legend_font_size = font_size

    prefix = "markovian_comparison_"
    suffix = "_plot.png"
    markovian_dir = os.path.join(base_directory, "markovian_comparison")

    # Collect (perturbation type, file path) pairs that survive the filters
    found = []
    if os.path.isdir(markovian_dir):
        for filename in os.listdir(markovian_dir):
            if not (filename.startswith(prefix) and filename.endswith(suffix)):
                continue
            # Strip only the leading prefix and trailing suffix, so a perturbation
            # name that happens to contain either substring is left intact
            perturb_type = filename[len(prefix):-len(suffix)]

            if include_perturbations is not None and perturb_type not in include_perturbations:
                continue
            if exclude_perturbations is not None and perturb_type in exclude_perturbations:
                continue

            found.append((perturb_type, os.path.join(markovian_dir, filename)))

    if not found:
        print(f"No markovian comparison plots found in {markovian_dir}")
        return

    # Imported only once there is something to draw
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg

    # Sort by perturbation type for consistent ordering
    found.sort(key=lambda pair: pair[0])
    n_plots = len(found)

    # Grid layout: keep the historical shapes (1x1, 2x2, then 3 columns) but let
    # the row count grow with the number of plots, so more than 12 plots no
    # longer run past the end of the axes array
    if n_plots == 1:
        rows, cols = 1, 1
    elif n_plots <= 4:
        rows, cols = 2, 2
    else:
        cols = 3
        rows = math.ceil(n_plots / cols)

    # squeeze=False always returns a 2-D array of axes, even for a 1x1 grid
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 6, rows * 4), squeeze=False)
    fig.suptitle('Comprehensive Markovian vs Non-Markovian Perturbation Analysis',
                 fontsize=legend_font_size + 4, fontweight='bold')
    axes = axes.flatten()

    # Load and display each plot; a file that cannot be read gets a placeholder
    # panel instead of aborting the whole figure
    for ax, (perturb_type, plot_file) in zip(axes, found):
        try:
            img = mpimg.imread(plot_file)
            ax.imshow(img)
            ax.set_title(f'{perturb_type.replace("_", " ").title()}',
                         fontsize=legend_font_size + 2, fontweight='bold')
        except Exception as e:
            print(f"Error loading {plot_file}: {e}")
            ax.text(0.5, 0.5, f'Error loading\n{perturb_type}',
                    ha='center', va='center', fontsize=legend_font_size)
        ax.axis('off')

    # Hide any unused subplots
    for ax in axes[n_plots:]:
        ax.axis('off')

    # Adjust layout
    plt.tight_layout()
    plt.subplots_adjust(top=0.93)  # Make room for suptitle

    # Save combined plot
    output_path = os.path.join(markovian_dir, "combined_markovian_comparison_plots.png")
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Combined plot saved to: {output_path}")
    print(f"Included {n_plots} perturbation types: {', '.join(perturb for perturb, _ in found)}")
    plt.close(fig)


def plot_markovian_comparison_results(results, output_dir, perturb_type, window_size=40, font_size=12, legend_font_size=10, markovian_hyperparams=None, non_markovian_hyperparams=None):
    """
    Draw the Markovian-minus-Non-Markovian effect difference, one line per perturbation degree.

    The figure is written to ``<output_dir>/markovian_comparison_<perturb_type>_plot.png``.
    Nothing is drawn when ``results`` is empty or when only the 0% baseline degree is present.

    Args:
        results: List of comparison entries; each needs "Batch Index" and an
            "Effect Difference" dict keyed by perturbation degree
        output_dir: Directory to save the plot
        perturb_type: The type of perturbation being analyzed
        window_size: Savitzky-Golay window; smoothing is skipped when it is <= 1
            or when there are not more data points than the window
        font_size: Accepted for interface compatibility; not read by this function
        legend_font_size: Font size used for the legend, axis labels, ticks and (plus 2) the title
        markovian_hyperparams: Optional dict; its 'batch_size' is shown in the x label
        non_markovian_hyperparams: Optional dict; its 'batch_size' is shown in the x label
    """
    if not results:
        print("No results to plot.")
        return

    # Degree names are taken from the first entry only
    perturbation_degrees = list(results[0]["Effect Difference"].keys())
    print(f"Found perturbation degrees: {perturbation_degrees}")

    # The 0% degree is the unperturbed baseline and is left out of the plot
    baseline_name = f"{perturb_type.title().replace('_', '')}0%"
    plot_degrees = [name for name in perturbation_degrees if name != baseline_name]
    print(f"Plotting degrees: {plot_degrees}")

    if not plot_degrees:
        print("No non-zero perturbation degrees found to plot.")
        return

    # Not used for the x-axis (positions are used instead); the lookup still
    # requires every entry to carry a "Batch Index"
    batch_indices = [entry["Batch Index"] for entry in results]

    plt.figure(figsize=(14, 8))
    palette = plt.cm.tab10(np.linspace(0, 1, len(plot_degrees)))

    for degree, line_color in zip(plot_degrees, palette):
        series = [entry["Effect Difference"][degree] for entry in results]
        n_points = len(series)

        # Default to the raw series; replaced below if smoothing succeeds
        x_values = range(n_points)
        y_values = series

        if window_size > 1 and n_points > window_size:
            try:
                filtered = savgol_filter(series, window_size, 3)
            except ValueError as e:
                print(f"Smoothing error: {e}. Using raw data.")
            else:
                # Trim half a window from both ends of the smoothed curve
                half_window = window_size // 2
                x_values = range(half_window, n_points - half_window)
                y_values = filtered[half_window:-half_window]

        plt.plot(
            x_values,
            y_values,
            label=f"{degree}",
            color=line_color,
            linewidth=2,
        )

    plt.grid(True, linestyle="--", alpha=0.7)
    plt.legend(fontsize=legend_font_size, loc="best")

    # x-axis label: fresh evaluations index samples, training runs index batches
    if "fresh" in perturb_type:
        xlabel = "Sample Index (Fresh Evaluation)"
    else:
        xlabel = "Training Batch"
        if markovian_hyperparams and non_markovian_hyperparams:
            m_batch_size = markovian_hyperparams.get('batch_size', 'unknown')
            nm_batch_size = non_markovian_hyperparams.get('batch_size', 'unknown')
            if m_batch_size != nm_batch_size:
                xlabel += f" (Markovian: {m_batch_size}, Non-Markovian: {nm_batch_size})"
            else:
                xlabel += f" (batch size={m_batch_size})"

    plt.xlabel(xlabel, fontsize=legend_font_size)
    plt.ylabel("Perturbation Effect Difference\n(Markovian Effect - Non-Markovian Effect)", fontsize=legend_font_size)

    title = f"Markovian vs Non-Markovian Comparison: {perturb_type.replace('_', ' ').title()}"
    title += f" (Smoothing: {window_size})" if window_size > 1 else " (Raw Data)"

    plt.title(title, fontsize=legend_font_size + 2)
    plt.tick_params(axis="both", which="major", labelsize=legend_font_size)

    # Zero reference line: above it the Markovian effect is the larger one
    plt.axhline(y=0, color='black', linestyle='-', alpha=0.3, linewidth=1)

    plt.tight_layout()

    output_file = os.path.join(output_dir, f"markovian_comparison_{perturb_type}_plot.png")
    plt.savefig(output_file, dpi=300, bbox_inches="tight")
    print(f"Markovian comparison plot saved to {output_file}")
    plt.close()


def summarize_markovian_comparison_results(results: List[Dict[str, Any]], perturb_type: str) -> List[Dict[str, Any]]:
    """
    Collapse per-example comparison entries into one summary row per perturbation degree.

    Degree names come from the first entry's "Markovian Effects". The 0% baseline
    (``<PerturbType>0%``) and any degree named "original" (case-insensitive) are
    skipped, so every emitted row has ``is_baseline`` set to False.

    Each row holds the mean Markovian / Non-Markovian effect, the mean and
    population std of the effect difference, how many examples had a positive
    or negative difference, and the example count.
    """
    if not results:
        return []

    baseline_name = f"{perturb_type.title().replace('_', '')}0%"

    def _gather(section: str, degree: str) -> List[Any]:
        # Pull one degree's value out of the given section of every entry
        return [entry[section][degree] for entry in results]

    summary_rows = []
    for degree in list(results[0]["Markovian Effects"].keys()):
        if degree == baseline_name or degree.lower() == "original":
            continue

        markovian_effects = _gather("Markovian Effects", degree)
        non_markovian_effects = _gather("Non_Markovian Effects", degree)
        effect_differences = _gather("Effect Difference", degree)

        # Positive difference: the Markovian model lost more; negative: the other way round
        positive_count = len([diff for diff in effect_differences if diff > 0])
        negative_count = len([diff for diff in effect_differences if diff < 0])

        row = {
            "perturbation": perturb_type,
            "degree": degree,
            "markovian_mean": float(np.mean(markovian_effects)),
            "non_markovian_mean": float(np.mean(non_markovian_effects)),
            "mean_difference": float(np.mean(effect_differences)),
            "difference_std": float(np.std(effect_differences)),
            "markovian_more_sensitive": positive_count,
            "non_markovian_more_sensitive": negative_count,
            "num_examples": len(effect_differences),
            "is_baseline": False,
        }
        summary_rows.append(row)

    return summary_rows


def _weighted_mean_and_std(values: List[float], stds: List[float], weights: List[float]) -> Tuple[float, float]:
    """Return weighted mean and population std using per-sample variances."""
    if not values:
        return 0.0, 0.0

    total_weight = sum(weights)
    if total_weight == 0:
        total_weight = float(len(values))
        weights = [1.0] * len(values)

    mean = sum(v * w for v, w in zip(values, weights)) / total_weight
    second_moment = sum(w * ((s ** 2) + (v ** 2)) for v, s, w in zip(values, stds, weights)) / total_weight
    variance = max(0.0, second_moment - mean ** 2)
    return mean, math.sqrt(variance)


def _aggregate_report_rows(rows: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    if not rows:
        return None

    weights = [float(row.get("num_examples") or 0) for row in rows]
    if not any(weights):
        weights = [1.0] * len(rows)

    diff_vals = [float(row.get("mean_difference", 0.0)) for row in rows]
    diff_stds = [float(row.get("difference_std", 0.0)) for row in rows]
    mark_vals = [float(row.get("markovian_mean", 0.0)) for row in rows]
    non_vals = [float(row.get("non_markovian_mean", 0.0)) for row in rows]
    zero_stds = [0.0] * len(rows)

    mean_difference, difference_std = _weighted_mean_and_std(diff_vals, diff_stds, weights)
    markovian_mean, _ = _weighted_mean_and_std(mark_vals, zero_stds, weights)
    non_markovian_mean, _ = _weighted_mean_and_std(non_vals, zero_stds, weights)

    return {
        "mean_difference": mean_difference,
        "difference_std": difference_std,
        "markovian_mean": markovian_mean,
        "non_markovian_mean": non_markovian_mean,
        "num_examples": int(sum(weights)),
        "num_runs": len(rows),
    }


def _degree_sort_key(label: str) -> Tuple[str, int, str]:
    if not label:
        return ("", 0, "")
    lower = label.lower()
    if lower == "original":
        return ("", 0, label)
    prefix_match = re.match(r"([A-Za-z_]+)", label)
    prefix = prefix_match.group(1).lower() if prefix_match else lower
    number_match = re.search(r"(\d+)", label)
    number = int(number_match.group(1)) if number_match else 0
    return (prefix, number, label)


def _normalize_degree_label(label: str, aggregate_types: bool) -> str:
    if not label:
        return label
    if not aggregate_types:
        return label
    if label.lower() == "original":
        return "Original"
    normalized = re.sub(r"(\d+%?)$", "", label).rstrip("_")
    return normalized or label


def analyze_markovian_comparison_summary(results, perturb_type):
    """
    Print a human-readable report of the Markovian comparison, one section per degree.

    The rows come from ``summarize_markovian_comparison_results``; when there
    are none, a short notice is printed instead. Returns None.
    """
    summary_rows = summarize_markovian_comparison_results(results, perturb_type)
    if not summary_rows:
        print("No results to analyze.")
        return

    print(f"\n=== MARKOVIAN COMPARISON SUMMARY: {perturb_type.upper()} ===")

    # Read every field up front so a malformed row fails before anything is printed for it
    field_names = (
        "degree",
        "markovian_mean",
        "non_markovian_mean",
        "mean_difference",
        "difference_std",
        "markovian_more_sensitive",
        "non_markovian_more_sensitive",
        "num_examples",
    )

    for row in summary_rows:
        (
            degree,
            mean_markovian,
            mean_non_markovian,
            mean_difference,
            std_difference,
            markovian_wins,
            non_markovian_wins,
            total,
        ) = [row[name] for name in field_names]

        # The sign of the mean difference decides the overall verdict
        if mean_difference > 0:
            verdict = f"  → Overall: Markovian model is MORE sensitive to {degree} perturbations"
        elif mean_difference < 0:
            verdict = f"  → Overall: Non-Markovian model is MORE sensitive to {degree} perturbations"
        else:
            verdict = f"  → Overall: Similar sensitivity to {degree} perturbations"

        print(f"\n{degree}:")
        print(f"  Mean Markovian Effect: {mean_markovian:.4f}")
        print(f"  Mean Non-Markovian Effect: {mean_non_markovian:.4f}")
        print(f"  Mean Difference (M - NM): {mean_difference:.4f} ± {std_difference:.4f}")
        print(f"  Markovian more sensitive: {markovian_wins}/{total} cases")
        print(f"  Non-Markovian more sensitive: {non_markovian_wins}/{total} cases")
        print(verdict)

    print("\n" + "=" * 60)


def build_dataset_perturbation_matrix(
    base_dir: str,
    metric: str = "accuracy",
    perturbations: Optional[Iterable[str]] = None,
    aggregate_perturbation_types: bool = False,
) -> Dict[str, Any]:
    """
    Pivot the comparison report into a dataset x degree grid of aggregated stats.

    The report rows are produced by ``generate_markovian_comparison_report``
    with the same arguments. Rows without a task are left out of the
    per-dataset views but still count towards the per-degree and overall
    averages. A cell with no matching rows is None.

    Returns:
        {
            "datasets": [...],
            "degrees": [...],
            "cells": {dataset: {degree: {...}}},
            "dataset_average": {dataset: {...}},
            "degree_average": {degree: {...}},
            "overall_average": {...},
        }
    """
    report_rows = generate_markovian_comparison_report(
        base_dir=base_dir,
        metric=metric,
        perturbations=perturbations,
        aggregate_perturbation_types=aggregate_perturbation_types,
    )

    dataset_names = {row.get("task") for row in report_rows if row.get("task")}
    degree_names = {row["degree"] for row in report_rows}
    datasets = sorted(dataset_names)
    degrees = sorted(degree_names, key=_degree_sort_key)

    cells: Dict[str, Dict[str, Optional[Dict[str, Any]]]] = {name: {} for name in datasets}
    dataset_avg: Dict[str, Optional[Dict[str, Any]]] = {}

    for name in datasets:
        rows_for_dataset = [row for row in report_rows if row.get("task") == name]
        dataset_avg[name] = _aggregate_report_rows(rows_for_dataset)
        # One cell per degree, including degrees this dataset never saw
        for degree in degrees:
            cells[name][degree] = _aggregate_report_rows(
                [row for row in rows_for_dataset if row["degree"] == degree]
            )

    # Column averages pool every report row of a degree, regardless of task
    degree_avg: Dict[str, Optional[Dict[str, Any]]] = {
        degree: _aggregate_report_rows(
            [row for row in report_rows if row["degree"] == degree]
        )
        for degree in degrees
    }

    matrix = {
        "datasets": datasets,
        "degrees": degrees,
        "cells": cells,
        "dataset_average": dataset_avg,
        "degree_average": degree_avg,
        "overall_average": _aggregate_report_rows(report_rows),
    }
    return matrix


def generate_markovian_comparison_report(
    base_dir: str,
    metric: str = "accuracy",
    perturbations: Optional[Iterable[str]] = None,
    aggregate_perturbation_types: bool = False,
) -> List[Dict[str, Any]]:
    """
    Traverse result directories and collate Markovian comparison summaries
    into a tabular-friendly list of dictionaries.

    Files are looked up recursively as
    ``<base_dir>/**/markovian_comparison_<metric>/comparison_results_<metric>_<perturbation>.json``.
    The first path component below ``base_dir`` is reported as the task and
    the second as the run name.

    Args:
        base_dir: Root directory containing task subdirectories (e.g., 'results').
        metric: Comparison metric; selects both the subdirectory and the
            filename prefix (default: 'accuracy').
        perturbations: Optional iterable of perturbation names to keep
            (None or empty keeps all).
        aggregate_perturbation_types: When True, degree labels are collapsed to
            their perturbation type via ``_normalize_degree_label``.

    Returns:
        List of summary rows enriched with task/run metadata, sorted by
        (task, run, perturbation, degree). Empty when nothing matches.
        Unreadable or malformed files are skipped with a warning.
    """
    base_dir = os.path.abspath(base_dir)
    # Both names depend on the requested metric
    comparison_dir_name = f"markovian_comparison_{metric}"
    filename_prefix = f"comparison_results_{metric}_"
    extension = ".json"

    # Escape the literal parts so glob metacharacters in paths are not treated as wildcards
    glob_pattern = os.path.join(
        glob.escape(base_dir),
        "**",
        glob.escape(comparison_dir_name),
        f"{glob.escape(filename_prefix)}*{extension}",
    )

    # Sorted so traversal order does not depend on the filesystem
    matched_files = sorted(glob.glob(glob_pattern, recursive=True))
    if not matched_files:
        return []

    allowed = set(perturbations) if perturbations else None
    report_rows: List[Dict[str, Any]] = []

    for file_path in matched_files:
        # The perturbation name is whatever sits between the prefix and ".json"
        perturb_type = os.path.basename(file_path)[len(filename_prefix):-len(extension)]
        if allowed and perturb_type not in allowed:
            continue

        try:
            with open(file_path, "r") as f:
                results = json.load(f)
        except (OSError, ValueError) as exc:
            # Best effort: one bad file must not abort the whole report
            print(f"Warning: failed to read {file_path}: {exc}")
            continue

        summary_rows = summarize_markovian_comparison_results(results, perturb_type)
        if not summary_rows:
            continue

        # The run directory is the parent of the comparison directory
        run_dir = os.path.dirname(os.path.dirname(file_path))
        relative = os.path.relpath(run_dir, base_dir)
        parts = [p for p in relative.split(os.sep) if p not in {".", ""}]

        task = parts[0] if len(parts) >= 1 else None
        run_name = parts[1] if len(parts) >= 2 else os.path.basename(run_dir)

        for row in summary_rows:
            row["degree"] = _normalize_degree_label(row["degree"], aggregate_perturbation_types)
            report_rows.append({"task": task, "run": run_name, **row})

    return sorted(
        report_rows,
        key=lambda r: (
            r.get("task") or "",
            r.get("run") or "",
            r.get("perturbation") or "",
            r.get("degree") or "",
        ),
    )


def _format_dataset_label(dataset: str) -> str:
    if not dataset:
        return ""
    return dataset.replace("_", " ").title()


def _format_fragility_cell(
    cell: Optional[Dict[str, Any]],
    multiplier: float = 100.0,
    precision: int = 3,
) -> str:
    if not cell:
        return "-"
    mean = float(cell.get("mean_difference", 0.0)) * multiplier
    fmt = f"{{:+0.{precision}f}}"
    return fmt.format(mean)


def build_fragility_rows(
    matrix: Dict[str, Any],
    datasets: List[str],
    perturbations: List[str],
) -> Tuple[List[Dict[str, Any]], Dict[str, Any], Dict[str, Any]]:
    """Shape a perturbation matrix into table rows, per-column averages and an overall summary.

    Returns ``(rows, column_averages, overall)``: one row dict per requested
    dataset (label, cells for the requested perturbations, dataset average),
    an aggregate per perturbation over the datasets that have a cell for it
    (None when none do), and the aggregate of the available dataset averages.
    """
    rows = []
    available_averages = []

    for dataset in datasets:
        per_perturbation = matrix["cells"].get(dataset, {})
        average = matrix["dataset_average"].get(dataset)
        row = {
            "dataset": dataset,
            "label": _format_dataset_label(dataset),
            "cells": {pert: per_perturbation.get(pert) for pert in perturbations},
            "average": average,
        }
        rows.append(row)
        if average:
            available_averages.append(average)

    column_averages: Dict[str, Any] = {}
    for pert in perturbations:
        # Look every dataset's cell up once, then keep only the populated ones
        looked_up = (matrix["cells"].get(dataset, {}).get(pert) for dataset in datasets)
        populated = [cell for cell in looked_up if cell]
        column_averages[pert] = _aggregate_report_rows(populated) if populated else None

    overall = _aggregate_report_rows(list(available_averages))
    return rows, column_averages, overall


def format_fragility_markdown(
    rows: List[Dict[str, Any]],
    column_averages: Dict[str, Any],
    overall: Optional[Dict[str, Any]],
    perturbations: List[str],
    multiplier: float = 100.0,
    precision: int = 3,
    summary_label: Optional[str] = "Overall",
) -> str:
    """Render fragility rows as a Markdown table.

    Columns are "Dataset", one per perturbation, then "Average". A closing
    summary row (column averages plus the overall value) is appended unless
    ``summary_label`` is empty or None.
    """

    def _table_line(fields: List[str]) -> str:
        # One Markdown table row: fields separated and wrapped by pipes
        return "| " + " | ".join(fields) + " |"

    def _render(cell: Optional[Dict[str, Any]]) -> str:
        return _format_fragility_cell(cell, multiplier, precision)

    header = ["Dataset", *perturbations, "Average"]
    lines = [_table_line(header), _table_line(["---"] * len(header))]

    for row in rows:
        rendered_cells = [_render(row["cells"].get(pert)) for pert in perturbations]
        rendered_average = _render(row["average"])
        lines.append(_table_line([row["label"]] + rendered_cells + [rendered_average]))

    if summary_label:
        rendered_columns = [_render(column_averages.get(pert)) for pert in perturbations]
        rendered_overall = _render(overall)
        lines.append(_table_line([summary_label] + rendered_columns + [rendered_overall]))

    return "\n".join(lines)


def _fragility_cell_to_dict(cell: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    if not cell:
        return None
    payload = {
        "mean_difference": float(cell.get("mean_difference", 0.0)),
        "difference_std": float(cell.get("difference_std", 0.0)),
        "markovian_mean": float(cell.get("markovian_mean", 0.0)),
        "non_markovian_mean": float(cell.get("non_markovian_mean", 0.0)),
        "num_examples": int(cell.get("num_examples", 0)),
    }
    if "num_runs" in cell:
        payload["num_runs"] = int(cell.get("num_runs") or 0)
    return payload



def run_qa_perturbation_accuracy(
    markovian_log_file,
    non_markovian_log_file,
    perturb_type,
    task_type,
    num_samples=None,
    batch_size=8,
    evaluator="actor",
    adapter_index=None,
    question_length=None,
    target_length=None,
    markovian_adapter_index=None,
    non_markovian_adapter_index=None,
    stride: int = 1,
):
    """
    Run perturbation analysis measuring ACCURACY drop on QA tasks.

    Both models are first evaluated on freshly loaded examples to obtain their
    original reasoning (CoT) and baseline correctness. Every perturbation of
    the selected set except "Original" is then applied to those CoTs and the
    models are re-evaluated. Per example, the sensitivity of a model is
    ``original_correct - perturbed_correct`` and the stored "Effect Difference"
    is ``markovian_sensitivity - non_markovian_sensitivity``.

    Args:
        markovian_log_file: Log file of the Markovian run; its first line must
            be a JSON object holding the run's hyperparameters.
        non_markovian_log_file: Same, for the Non-Markovian run.
        perturb_type: Key into ``PERTURBATION_SETS``.
        task_type: One of "gsm8k", "mmlu", "math", "svamp", "aqua", "mathqa",
            "arc", "arithmetic" or "wiki_continuation". Overrides the
            ``task_type`` stored in both hyperparameter dicts.
        num_samples: Upper bound on the number of examples kept after striding
            (falsy means no bound). Also passed as ``chunk_size`` to the
            arithmetic loader and as the pair count (default 128) to the wiki
            loader.
        batch_size: Forwarded to the evaluation function.
        evaluator: "actor" evaluates with the actor model; any other value
            evaluates with the frozen model returned by
            ``load_model_with_adapters``.
        adapter_index: Fallback adapter index for whichever of the two
            model-specific indices is ``None``.
        question_length: If truthy, overrides ``question_length`` in both
            hyperparameter dicts.
        target_length: If truthy, overrides ``target_length`` in both
            hyperparameter dicts.
        markovian_adapter_index: Adapter index for the Markovian model.
        non_markovian_adapter_index: Adapter index for the Non-Markovian model.
        stride: Keep every ``stride``-th loaded example; coerced to at least 1.

    Returns:
        tuple: ``(comparison_data, markovian_hyperparams,
        non_markovian_hyperparams)`` where ``comparison_data`` holds one dict
        per example with the keys "Batch Index", "Markovian Effects",
        "Non_Markovian Effects" and "Effect Difference".

    Raises:
        ValueError: Unknown ``perturb_type`` or unsupported ``task_type``.
        RuntimeError: No examples were loaded, or none remain after applying
            ``stride`` / ``num_samples``.

    Side effects:
        Writes ``comparison_results_accuracy_<perturb_type>.json`` into the
        ``markovian_comparison_accuracy`` folder next to the Markovian log,
        updates the metadata records of both run directories when both exist,
        and calls ``sync_run_dir_outputs`` on both run directories.
    """
    if perturb_type not in PERTURBATION_SETS:
        raise ValueError(f"Unknown perturbation type: {perturb_type}")

    perturbations = PERTURBATION_SETS[perturb_type]["perturbations"]

    markovian_dir = os.path.dirname(markovian_log_file)
    non_markovian_dir = os.path.dirname(non_markovian_log_file)
    # No sync from remote storage happens here: the caller (e.g.
    # perturbation_sweep.py) is expected to have synced the run directories
    # already, or the files are assumed to be local.

    markovian_role = infer_role_from_log_path(markovian_log_file)
    non_markovian_role = infer_role_from_log_path(non_markovian_log_file)

    # Resolve adapter indices: a model-specific index wins over the shared one.
    m_idx = markovian_adapter_index if markovian_adapter_index is not None else adapter_index
    nm_idx = non_markovian_adapter_index if non_markovian_adapter_index is not None else adapter_index

    # Load hyperparameters (first line of each log file)
    with open(markovian_log_file, "r") as f:
        markovian_hyperparams = json.loads(next(f))
    with open(non_markovian_log_file, "r") as f:
        non_markovian_hyperparams = json.loads(next(f))

    # Force task settings
    markovian_hyperparams = {**markovian_hyperparams, "task_type": task_type}
    non_markovian_hyperparams = {**non_markovian_hyperparams, "task_type": task_type}

    if question_length:
        markovian_hyperparams["question_length"] = int(question_length)
        non_markovian_hyperparams["question_length"] = int(question_length)
    if target_length:
        markovian_hyperparams["target_length"] = int(target_length)
        non_markovian_hyperparams["target_length"] = int(target_length)

    # Load models. The tokenizer and device of the Markovian load are reused
    # for the Non-Markovian model as well.
    actor_markovian, frozen_markovian, tokenizer, device = load_model_with_adapters(
        markovian_log_file, markovian_hyperparams["model_type"], markovian_hyperparams, adapter_index=m_idx
    )
    actor_non_markovian, frozen_non_markovian, _, _ = load_model_with_adapters(
        non_markovian_log_file, non_markovian_hyperparams["model_type"], non_markovian_hyperparams, adapter_index=nm_idx
    )

    markovian_eval_model = actor_markovian if evaluator == "actor" else frozen_markovian
    non_markovian_eval_model = actor_non_markovian if evaluator == "actor" else frozen_non_markovian

    # Load Data
    qa_pairs = []
    print(f"Loading fresh {task_type} data for accuracy analysis...")

    if task_type == "gsm8k":
        qa_pairs = list(load_gsm8k_dataset(split="test"))
    elif task_type == "mmlu":
        # The MMLU subject is taken from the Markovian run's hyperparameters.
        subject = markovian_hyperparams.get("mmlu_subject")
        qa_pairs = list(load_mmlu_dataset(split="test", subject=subject))
    elif task_type == "math":
        qa_pairs = list(load_math_dataset(split="test"))
    elif task_type == "svamp":
        qa_pairs = list(load_svamp_dataset(split="test"))
    elif task_type == "aqua":
        qa_pairs = list(load_aqua_dataset(split="test"))
    elif task_type == "mathqa":
        qa_pairs = list(load_mathqa_dataset(split="test"))
    elif task_type == "arc":
        qa_pairs = list(load_arc_dataset(split="validation"))
    elif task_type == "arithmetic":
        qa_pairs = list(load_arithmetic_dataset(chunk_size=num_samples, split="test"))
    elif task_type == "wiki_continuation":
        q_len = int(markovian_hyperparams.get("question_length", 512))
        t_len = int(markovian_hyperparams.get("target_length", 128))
        # Pairs are built with the tokenizer returned by load_model_with_adapters.
        qa_pairs = list(load_wiki_pairs(tokenizer, q_len, t_len, num_samples if num_samples else 128, start_index=10000))
    else:
        raise ValueError(f"Unsupported task type for QA perturbation: {task_type}")

    if not qa_pairs:
        raise RuntimeError(f"No samples found for {task_type}")

    # NOTE(review): the stride is applied before the num_samples cap, so for
    # loaders that already received num_samples (arithmetic, wiki) a stride > 1
    # leaves fewer than num_samples examples - confirm this is intended.
    stride = max(1, int(stride or 1))
    qa_pairs = qa_pairs[::stride]
    if num_samples:
        qa_pairs = qa_pairs[:num_samples]
    if not qa_pairs:
        raise RuntimeError("No samples left after applying stride/num_samples filters.")
    effective_num_samples = len(qa_pairs)
    print(f"Loaded {effective_num_samples} examples with stride={stride}.")

    # Select the evaluation function; unknown task types fall back to the
    # numeric evaluator (unreachable here because task_type was validated above).
    eval_funcs = {
        "gsm8k": evaluate_model_on_gsm8k,
        "mmlu": evaluate_model_on_mmlu,
        "arc": evaluate_model_on_arc,
        "aqua": evaluate_model_on_aqua,
        "mathqa": evaluate_model_on_mathqa,
        "svamp": evaluate_model_on_numeric,
        "math": evaluate_model_on_numeric,
        "arithmetic": evaluate_model_on_numeric,
        "wiki_continuation": evaluate_wiki_logprob,
    }
    eval_func = eval_funcs.get(task_type, evaluate_model_on_numeric)

    # 1. Generate Original CoTs and Baselines
    print("Generating Original CoTs and Baseline Accuracy...")

    # Markovian (no question in Stage 2)
    markovian_hyperparams["markovian"] = True
    _, m_results, _ = eval_func(
        actor_markovian, markovian_eval_model, tokenizer, device, qa_pairs, markovian_hyperparams,
        batch_size=batch_size, num_samples=effective_num_samples
    )
    m_cots = [r["reasoning"] for r in m_results]

    # Non-Markovian (question in Stage 2)
    non_markovian_hyperparams["markovian"] = False
    _, nm_results, _ = eval_func(
        actor_non_markovian, non_markovian_eval_model, tokenizer, device, qa_pairs, non_markovian_hyperparams,
        batch_size=batch_size, num_samples=effective_num_samples
    )
    nm_cots = [r["reasoning"] for r in nm_results]

    # Prepare results structure: one entry per example, filled per perturbation.
    comparison_data = [
        {
            "Batch Index": i,
            "Markovian Effects": {},
            "Non_Markovian Effects": {},
            "Effect Difference": {},
        }
        for i in range(effective_num_samples)
    ]

    # 2. Run Perturbations
    for pert_name, pert_config in perturbations.items():
        if pert_name == "Original":
            continue

        print(f"Evaluating perturbation: {pert_name}")

        # Bind the current config as a default argument so the function keeps
        # this iteration's perturbation even if it were called after the loop
        # variable has moved on (closures look names up late).
        def p_func(text, _config=pert_config):
            return perturb_CoT(text, _config)

        # Evaluate Markovian on its own, perturbed, original CoTs
        _, m_pert_results, _ = eval_func(
            actor_markovian, markovian_eval_model, tokenizer, device, qa_pairs, markovian_hyperparams,
            batch_size=batch_size, num_samples=effective_num_samples,
            precomputed_cots=m_cots,
            perturbation_fn=p_func
        )

        # Evaluate Non-Markovian on its own, perturbed, original CoTs
        _, nm_pert_results, _ = eval_func(
            actor_non_markovian, non_markovian_eval_model, tokenizer, device, qa_pairs, non_markovian_hyperparams,
            batch_size=batch_size, num_samples=effective_num_samples,
            precomputed_cots=nm_cots,
            perturbation_fn=p_func
        )

        # NOTE(review): correctness is read from the "correct" key and treated
        # as 0 when absent. For wiki_continuation (evaluate_wiki_logprob) it is
        # unclear whether results carry a meaningful "correct" value, since
        # that task has no single answer - confirm before relying on
        # accuracy-mode numbers for it.
        for i, entry in enumerate(comparison_data):
            m_orig = float(m_results[i].get("correct", 0))
            m_pert = float(m_pert_results[i].get("correct", 0))

            nm_orig = float(nm_results[i].get("correct", 0))
            nm_pert = float(nm_pert_results[i].get("correct", 0))

            # Positive sensitivity = the perturbation cost accuracy.
            m_sensitivity = m_orig - m_pert
            nm_sensitivity = nm_orig - nm_pert

            entry["Markovian Effects"][pert_name] = m_sensitivity
            entry["Non_Markovian Effects"][pert_name] = nm_sensitivity
            entry["Effect Difference"][pert_name] = m_sensitivity - nm_sensitivity

    # Save results
    output_dir = os.path.join(markovian_dir, "markovian_comparison_accuracy")
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"comparison_results_accuracy_{perturb_type}.json")

    with open(output_path, "w") as f:
        json.dump(comparison_data, f)

    if os.path.isdir(markovian_dir) and os.path.isdir(non_markovian_dir):
        metadata_key = build_perturb_metadata_key(
            task_type=task_type,
            perturb_type=perturb_type,
            metric="accuracy",
            paired_role=non_markovian_role,
            paired_adapter_index=nm_idx,
            markovian_run=markovian_dir,
            non_markovian_run=non_markovian_dir,
        )
        # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
        # time and drop tzinfo so the stored string keeps the exact naive-UTC
        # ISO format that existing metadata uses.
        timestamp = (
            datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat()
        )
        record_common = {
            "task_type": task_type,
            "perturbation": perturb_type,
            "metric": "accuracy",
            "num_samples": effective_num_samples,
            "batch_size": batch_size,
            "stride": stride,
            "summary": compute_sensitivity_summary(comparison_data, perturb_type),
            "timestamp": timestamp,
            "comparison_results_file": os.path.join("markovian_comparison_accuracy", os.path.basename(output_path)),
        }
        mark_record = {
            **record_common,
            "adapter_index": m_idx,
            "paired_adapter_index": nm_idx,
            "role": markovian_role,
            "paired_role": non_markovian_role,
            "paired_run": os.path.basename(non_markovian_dir),
            "status": "completed",
        }
        non_record = {
            **record_common,
            "adapter_index": nm_idx,
            "paired_adapter_index": m_idx,
            "role": non_markovian_role,
            "paired_role": markovian_role,
            "paired_run": os.path.basename(markovian_dir),
            "status": "completed",
            # Store reference to the file in the Markovian run to avoid duplication
            "comparison_results_file": os.path.join("..", os.path.basename(markovian_dir), "markovian_comparison_accuracy", os.path.basename(output_path))
        }
        update_metadata_record(markovian_dir, metadata_key, mark_record)
        update_metadata_record(non_markovian_dir, metadata_key, non_record)

    print(f"Accuracy perturbation analysis saved to {output_path}")
    sync_run_dir_outputs(markovian_dir)
    sync_run_dir_outputs(non_markovian_dir)
    return comparison_data, markovian_hyperparams, non_markovian_hyperparams


def main():
    parser = argparse.ArgumentParser(description="Perturbation Analysis Tool")
    parser.add_argument("--log_file", help="Log file to analyze or directory containing perturbation results")
    parser.add_argument("--metric", type=str, default="log_prob", choices=["log_prob", "accuracy"], help="Metric to evaluate: 'log_prob' or 'accuracy'")
    parser.add_argument(
        "--window_size", type=int, default=40, help="Smoothing window size"
    )
    parser.add_argument(
        "--stride", type=int, default=1, help="Process every nth entry of the log file"
    )
    parser.add_argument(
        "--debug", action="store_true", help="Generate debug plots with raw values"
    )
    parser.add_argument("--max_index", type=int, help="Maximum index to plot")
    parser.add_argument(
        "--plot_only",
        action="store_true",
        help="Only generate plots from saved results",
    )
    parser.add_argument(
        "--process_only", action="store_true", help="Only process data without plotting"
    )
    parser.add_argument(
        "--include_question",
        action="store_true",
        help="Include the question text in the prompt when evaluating",
    )
    parser.add_argument(
        "--save_interval",
        type=int,
        default=10,
        help="Save intermediate results every N entries (0 to disable)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Number of examples to process in each batch (0 for non-batched processing)",
    )
    parser.add_argument(
        "--evaluator",
        type=str,
        choices=["actor", "frozen"],
        default="actor",
        help="Which model to use for evaluation: adapter-loaded actor (actor) or frozen baseline (frozen). Default: actor",
    )
    parser.add_argument(
        "--adapter_index",
        type=int,
        help="Force loading a specific adapter index (e.g., 400 will use adapter_400)",
    )
    # Fresh datapoint comparison flags
    parser.add_argument(
        "--fresh_comparison",
        action="store_true",
        help="Run Markovian vs Non-Markovian comparison on fresh datapoints (not training logs)",
    )
    parser.add_argument(
        "--fresh_task_type",
        type=str,
        default="wiki_continuation",
        help="Task type for fresh comparison (e.g., wiki_continuation)",
    )
    parser.add_argument(
        "--fresh_num_samples",
        type=int,
        default=1024,
        help="Number of fresh samples to evaluate in fresh comparison",
    )
    parser.add_argument(
        "--fresh_question_length",
        type=int,
        help="Question/context length (tokens) for fresh wiki tasks",
    )
    parser.add_argument(
        "--fresh_target_length",
        type=int,
        help="Target/answer length (tokens) for fresh wiki tasks",
    )
    
    # New arguments for Markovian comparison
    parser.add_argument(
        "--markovian_comparison",
        action="store_true",
        help="Run Markovian vs Non-Markovian comparison analysis",
    )
    parser.add_argument(
        "--markovian_log",
        type=str,
        help="Path to Markovian model log file (for comparison mode)",
    )
    parser.add_argument(
        "--non_markovian_log", 
        type=str,
        help="Path to Non-Markovian model log file (for comparison mode)",
    )

    # Adjusted to not require --perturb when using --collate
    perturb_group = parser.add_mutually_exclusive_group(required=False)
    perturb_group.add_argument(
        "--perturb",
        nargs="+",
        choices=list(PERTURBATION_SETS.keys()),
        help="Type(s) of perturbation to analyze",
    )
    perturb_group.add_argument(
        "--all", action="store_true", help="Run all perturbation types"
    )

    # Modify the --collate help message
    parser.add_argument(
        "--collate",
        nargs="+",
        help="List of perturbation result JSON files to average"
    )
    parser.add_argument(
        "--output_dir",
        default="perturbation_results",
        help="Output directory for collated results",
    )

    parser.add_argument(
        "--font_size",
        type=int,
        default=12,
        help="Base font size for plot text elements"
    )
    parser.add_argument(
        "--legend_font_size",
        type=int,
        default=10,
        help="Font size for the legend in plots"
    )
    parser.add_argument(
        "--plot_multiple_perturbations",
        action="store_true",
        help="Generate a combined plot for multiple perturbation types"
    )
    parser.add_argument(
        "--combine_all_plots",
        action="store_true",
        help="Combine all existing markovian comparison plots into a single comprehensive figure"
    )
    parser.add_argument(
        "--include_perturbations",
        nargs="+",
        choices=list(PERTURBATION_SETS.keys()),
        help="Include only specified perturbation types in combined plot (for --combine_all_plots)"
    )
    parser.add_argument(
        "--exclude_perturbations",
        nargs="+",
        choices=list(PERTURBATION_SETS.keys()),
        help="Exclude specified perturbation types from combined plot (for --combine_all_plots)"
    )
    parser.add_argument(
        "--regenerate_before_combine",
        action="store_true",
        help="Regenerate individual markovian comparison plots with new parameters before combining (for --combine_all_plots)"
    )
    parser.add_argument(
        "--generate_report",
        type=str,
        help="Write generate_markovian_comparison_report output to the given JSON file and exit",
    )
    parser.add_argument(
        "--report_base_dir",
        type=str,
        default="results",
        help="Base directory to scan when generating reports (default: results)",
    )
    parser.add_argument(
        "--report_aggregate_types",
        action="store_true",
        help="Aggregate perturbation degrees by perturbation type in generated reports",
    )
    parser.add_argument(
        "--fragility_matrix_output",
        type=str,
        help="Aggregate perturbation sensitivity into a fragility matrix and save to this file",
    )
    parser.add_argument(
        "--fragility_metric",
        choices=["accuracy", "log_prob"],
        default="accuracy",
        help="Metric to use when building fragility matrix (default: accuracy)",
    )
    parser.add_argument(
        "--fragility_datasets",
        nargs="+",
        help="Datasets to include in the fragility matrix (default: QA set for accuracy, wiki_continuation for log_prob)",
    )
    parser.add_argument(
        "--fragility_format",
        choices=["markdown", "json"],
        default="markdown",
        help="Output format for fragility matrix (default: markdown)",
    )

    args = parser.parse_args()

    if args.all:
        args.perturb = list(PERTURBATION_SETS.keys())

    # Auto-detection of logs and adapters if missing
    markovian_adapter_index = args.adapter_index
    non_markovian_adapter_index = args.adapter_index
    
    # Determine target task for auto-detection
    target_task = args.fresh_task_type
        
    if (args.fresh_comparison or args.markovian_comparison) and (not args.markovian_log or not args.non_markovian_log):
        if target_task:
            print(f"Attempting to auto-detect logs for task: {target_task}")
            try:
                if not args.markovian_log:
                    log, idx = find_best_run_for_task(target_task, "Markovian")
                    if log:
                        print(f"  Markovian: {log} (adapter {idx})")
                        args.markovian_log = log
                        if idx is not None:
                            markovian_adapter_index = idx
                
                if not args.non_markovian_log:
                    log, idx = find_best_run_for_task(target_task, "NonMarkovian")
                    if log:
                        print(f"  NonMarkovian: {log} (adapter {idx})")
                        args.non_markovian_log = log
                        if idx is not None:
                            non_markovian_adapter_index = idx
            except (FileNotFoundError, RuntimeError, ValueError) as e:
                print(f"Auto-detection error: {e}")
                return

    # Default to best adapters defined in best_adapter.json when explicit indices not supplied
    if args.markovian_log and markovian_adapter_index is None:
        best_idx = load_best_adapter_index(os.path.dirname(args.markovian_log))
        if best_idx is not None:
            print(f"Using best Markovian adapter index {best_idx} from best_adapter.json")
            markovian_adapter_index = best_idx
    if args.non_markovian_log and non_markovian_adapter_index is None:
        best_idx = load_best_adapter_index(os.path.dirname(args.non_markovian_log))
        if best_idx is not None:
            print(f"Using best Non-Markovian adapter index {best_idx} from best_adapter.json")
            non_markovian_adapter_index = best_idx

    # Handle fresh datapoint comparison mode
    if args.fresh_comparison:
        if not args.markovian_log or not args.non_markovian_log:
            print("Error: --fresh_comparison requires both --markovian_log and --non_markovian_log (auto-detection failed)")
            return
        if not args.perturb:
            print("Error: --fresh_comparison requires --perturb argument")
            return
            
        # Sync run directories from S3 if needed (CLI usage assumed to want sync)
        if args.markovian_log and args.non_markovian_log:
            sync_run_dir_from_s3(os.path.dirname(args.markovian_log))
            sync_run_dir_from_s3(os.path.dirname(args.non_markovian_log))

        for perturb_type in args.perturb:
            print(f"Running Fresh Markovian vs Non-Markovian comparison for {perturb_type} (Metric: {args.metric})...")
            
            if args.metric == "accuracy":
                comparison_data, markovian_hyperparams, non_markovian_hyperparams = run_qa_perturbation_accuracy(
                    markovian_log_file=args.markovian_log,
                    non_markovian_log_file=args.non_markovian_log,
                    perturb_type=perturb_type,
                    task_type=args.fresh_task_type,
                    num_samples=args.fresh_num_samples,
                    batch_size=args.batch_size,
                    evaluator=args.evaluator,
                    adapter_index=args.adapter_index,
                    question_length=args.fresh_question_length,
                    target_length=args.fresh_target_length,
                    markovian_adapter_index=markovian_adapter_index,
                    non_markovian_adapter_index=non_markovian_adapter_index,
                    stride=args.stride,
                )
            else:
                comparison_data, markovian_hyperparams, non_markovian_hyperparams = run_markovian_comparison_fresh(
                    markovian_log_file=args.markovian_log,
                    non_markovian_log_file=args.non_markovian_log,
                    perturb_type=perturb_type,
                    num_samples=args.fresh_num_samples,
                    task_type=args.fresh_task_type,
                    question_length=args.fresh_question_length,
                    target_length=args.fresh_target_length,
                    batch_size=args.batch_size,
                    evaluator=args.evaluator,
                    adapter_index=args.adapter_index,
                    markovian_adapter_index=markovian_adapter_index,
                    non_markovian_adapter_index=non_markovian_adapter_index,
                )
            
            output_dir = os.path.join(os.path.dirname(args.markovian_log), f"markovian_comparison_{args.metric}")
            os.makedirs(output_dir, exist_ok=True)
            
            plot_markovian_comparison_results(
                results=comparison_data,
                output_dir=output_dir,
                perturb_type=f"fresh_{perturb_type}",
                window_size=args.window_size,
                font_size=args.font_size,
                legend_font_size=args.legend_font_size,
                markovian_hyperparams=markovian_hyperparams,
                non_markovian_hyperparams=non_markovian_hyperparams,
            )
            analyze_markovian_comparison_summary(comparison_data, f"fresh_{perturb_type}")
            print(f"Fresh markovian comparison for {perturb_type} completed.")
        return

    # Handle markovian comparison mode
    if args.markovian_comparison:
        if not args.markovian_log or not args.non_markovian_log:
            print("Error: --markovian_comparison requires both --markovian_log and --non_markovian_log arguments (auto-detection failed)")
            return
        if not args.perturb:
            print("Error: --markovian_comparison requires --perturb argument")
            return
        
        for perturb_type in args.perturb:
            print(f"Running Markovian vs Non-Markovian comparison for {perturb_type}...")
            comparison_data, markovian_hyperparams, non_markovian_hyperparams = run_markovian_comparison(
                markovian_log_file=args.markovian_log,
                non_markovian_log_file=args.non_markovian_log,
                perturb_type=perturb_type,
                stride=args.stride,
                max_index=args.max_index,
                save_interval=args.save_interval,
                batch_size=args.batch_size,
                evaluator=args.evaluator,
                adapter_index=args.adapter_index,
                markovian_adapter_index=markovian_adapter_index,
                non_markovian_adapter_index=non_markovian_adapter_index,
            )
            
            # Generate plots and analysis
            output_dir = os.path.join(os.path.dirname(args.markovian_log), "markovian_comparison")
            plot_markovian_comparison_results(
                results=comparison_data,
                output_dir=output_dir,
                perturb_type=perturb_type,
                window_size=args.window_size,
                font_size=args.font_size,
                legend_font_size=args.legend_font_size,
                markovian_hyperparams=markovian_hyperparams,
                non_markovian_hyperparams=non_markovian_hyperparams
            )
            
            # Print summary analysis
            analyze_markovian_comparison_summary(comparison_data, perturb_type)
            
            print(f"Markovian comparison for {perturb_type} completed.")
        
        return  # Exit after comparison analysis

    # Handle combine all plots mode
    if args.combine_all_plots:
        if not args.log_file:
            print("Error: --combine_all_plots requires --log_file argument to specify the base directory")
            return
        
        # Validate include/exclude arguments
        if args.include_perturbations and args.exclude_perturbations:
            print("Error: Cannot specify both --include_perturbations and --exclude_perturbations")
            return
        
        # If log_file points to a file, get its directory; if it's a directory, use it directly
        if os.path.isfile(args.log_file):
            base_dir = os.path.dirname(args.log_file)
        else:
            base_dir = args.log_file
            
        # Regenerate individual plots if requested
        if args.regenerate_before_combine:
            # Determine which perturbations to regenerate
            if args.include_perturbations:
                perturb_types_to_regenerate = args.include_perturbations
            else:
                # Use all available perturbation types, minus excluded ones
                perturb_types_to_regenerate = list(PERTURBATION_SETS.keys())
                if args.exclude_perturbations:
                    perturb_types_to_regenerate = [p for p in perturb_types_to_regenerate if p not in args.exclude_perturbations]
            
            print(f"Regenerating individual plots for: {perturb_types_to_regenerate}")
            markovian_dir = os.path.join(base_dir, "markovian_comparison")
            
            for perturb_type in perturb_types_to_regenerate:
                json_file = os.path.join(markovian_dir, f"comparison_results_{perturb_type}.json")
                if os.path.exists(json_file):
                    print(f"Regenerating plot for {perturb_type}...")
                    with open(json_file, 'r') as f:
                        results = json.load(f)
                    plot_markovian_comparison_results(
                        results=results,
                        output_dir=markovian_dir,
                        perturb_type=perturb_type,
                        window_size=args.window_size,
                        font_size=args.font_size,
                        legend_font_size=args.legend_font_size
                    )
                else:
                    print(f"Warning: {json_file} not found, skipping {perturb_type}")
        
        combine_all_markovian_comparison_plots(
            base_dir, 
            font_size=args.font_size,
            include_perturbations=args.include_perturbations,
            exclude_perturbations=args.exclude_perturbations,
            legend_font_size=args.legend_font_size
        )
        return

    if args.generate_report:
        base_dir = os.path.abspath(args.report_base_dir or "results")
        perturb_filter = args.perturb if args.perturb else None
        report_rows = generate_markovian_comparison_report(
            base_dir=base_dir,
            metric=args.metric,
            perturbations=perturb_filter,
            aggregate_perturbation_types=args.report_aggregate_types,
        )
        output_path = os.path.abspath(args.generate_report)
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_path, "w") as f:
            json.dump(report_rows, f, indent=2)
        print(f"Wrote {len(report_rows)} comparison rows to {output_path}")
        return

    if args.fragility_matrix_output:
        base_dir = os.path.abspath(args.report_base_dir or "results")
        perturb_filter = args.perturb if args.perturb else None
        matrix = build_dataset_perturbation_matrix(
            base_dir=base_dir,
            metric=args.fragility_metric,
            perturbations=perturb_filter,
            aggregate_perturbation_types=True,
        )
        datasets = (
            args.fragility_datasets
            if args.fragility_datasets
            else (
                DEFAULT_FRAGILITY_QA_DATASETS
                if args.fragility_metric == "accuracy"
                else ["wiki_continuation"]
            )
        )
        perturbations = DEFAULT_FRAGILITY_PERTURBATIONS

        rows, column_avgs, overall = build_fragility_rows(
            matrix, datasets, perturbations
        )

        output_path = os.path.abspath(args.fragility_matrix_output)
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        if args.fragility_format == "json":
            payload = {
                "metric": args.fragility_metric,
                "datasets": [
                    {
                        "dataset": row["dataset"],
                        "label": row["label"],
                        "cells": {
                            pert: _fragility_cell_to_dict(row["cells"].get(pert))
                            for pert in perturbations
                        },
                        "average": _fragility_cell_to_dict(row["average"]),
                    }
                    for row in rows
                ],
                "column_average": {
                    pert: _fragility_cell_to_dict(column_avgs.get(pert))
                    for pert in perturbations
                },
                "overall_average": _fragility_cell_to_dict(overall),
            }
            with open(output_path, "w") as f:
                json.dump(payload, f, indent=2)
        else:
            multiplier = 100.0 if args.fragility_metric == "accuracy" else 1.0
            summary_label = None if len(rows) <= 1 else "Overall"
            markdown = format_fragility_markdown(
                rows,
                column_avgs,
                overall,
                perturbations,
                multiplier=multiplier,
                summary_label=summary_label,
            )
            if args.fragility_metric == "log_prob":
                markdown = "> Units: Δlog P (Markovian drop − Non-Markovian drop, nats)\n\n" + markdown
            with open(output_path, "w") as f:
                f.write(markdown + "\n")
        print(f"Fragility matrix written to {output_path}")
        return

    if args.collate:
        if not args.output_dir:
            print("Please specify an output directory using --output_dir when using --collate.")
            return
        # Extract perturb_type from the filenames
        perturb_types = set()
        include_question = False
        for file in args.collate:
            basename = os.path.basename(file)
            # Check if file includes question in the name
            if "_with_question.json" in basename:
                include_question = True
                basename = basename.replace("_with_question.json", ".json")
            if basename.startswith("perturbation_results_") and basename.endswith(".json"):
                perturb_type = basename[len("perturbation_results_"):-len(".json")]
                perturb_types.add(perturb_type)
            else:
                print(f"Invalid perturbation result file: {file}")
                return
        if len(perturb_types) != 1:
            print("All perturbation result files must be for the same perturbation type.")
            return
        perturb_type = perturb_types.pop()
        print(f"Collating results for perturbation type: {perturb_type}" + 
              (" (with question)" if include_question else ""))
        collate_perturbation_results(args.collate, args.output_dir, perturb_type, include_question)
        print(f"Collation complete. Results saved to {args.output_dir}")
        if not args.plot_only:
            return
        # Update log_file to point to collated results for plotting
        args.log_file = args.output_dir
        args.perturb = [perturb_type]
        args.include_question = include_question
    else:
        if args.log_file:
            if not args.perturb and not args.all:
                print("Please specify perturbation types using --perturb or --all.")
                return
        else:
            # Get the latest result directory
            log_dir = find_latest_result()
            if log_dir is None:
                print("No result directories found.")
                return
            args.log_file = log_dir
    
    # Run perturbation analysis if not in plot_only mode
    if not args.plot_only:
        for perturb_type in args.perturb:
            question_status = "with" if args.include_question else "without"
            print(f"Running perturbation analysis for {perturb_type} ({question_status} question)...")
            
            # Choose between batched and non-batched processing
            if args.batch_size > 0:
                print(f"Using batched processing with batch size {args.batch_size}")
                results = run_perturbations_batched(
                    args.log_file, 
                    perturb_type, 
                    include_question=args.include_question,
                    stride=args.stride, 
                    max_index=args.max_index,
                    save_interval=args.save_interval,
                    batch_size=args.batch_size,
                    evaluator=args.evaluator,
                    adapter_index=args.adapter_index,
                )
            else:
                print("Using non-batched processing")
                results = run_perturbations(
                    args.log_file, 
                    perturb_type, 
                    include_question=args.include_question,
                    stride=args.stride, 
                    max_index=args.max_index,
                    save_interval=args.save_interval,
                    evaluator=args.evaluator,
                    adapter_index=args.adapter_index,
                )
            
            save_perturbation_results(
                results, 
                args.log_file, 
                perturb_type, 
                include_question=args.include_question
            )
            print(f"Analysis for {perturb_type} completed and saved.")

    # Plot if needed
    if not args.process_only:
        if args.plot_only and args.plot_multiple_perturbations and len(args.perturb) > 1:
            # Create combined plot for multiple perturbation types
            plot_multiple_perturbation_results(
                args.log_file,
                args.perturb,
                window_size=args.window_size,
                max_index=args.max_index,
                font_size=args.font_size,
                legend_font_size=args.legend_font_size,
                include_question=args.include_question
            )
        else:
            for perturb_type in args.perturb:
                result_file = get_output_paths(args.log_file, perturb_type, args.include_question)["json"]
                try:
                    with open(result_file, "r") as f:
                        results = json.load(f)
                    plot_perturbation_results(
                        results,
                        args.log_file,
                        perturb_type,
                        window_size=args.window_size,
                        debug=args.debug,
                        max_index=args.max_index,
                        font_size=args.font_size,
                        legend_font_size=args.legend_font_size,
                        include_question=args.include_question
                    )
                except FileNotFoundError:
                    print(
                        f"No saved results found for {perturb_type}{' with question' if args.include_question else ''} in {args.log_file}. Run the analysis first or check the file path."
                    )
    else:
        print("Process-only mode is selected, but no processing code is provided.")

# Script entry point: invoke main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
