import logging
import os
import re
import subprocess
import threading
from logging.handlers import RotatingFileHandler

import torch

from config import config


def parse_safety_status(output):
    """Extract the verified / falsified / timeout counts from verifier output.

    The text is searched for the fragments ``total verified (safe/unsat): N``,
    ``total falsified (unsafe/sat): N`` and ``timeout: N``, in that order.
    Arbitrary text, including newlines, may separate them.

    Args:
        output: Text to search.

    Returns:
        A tuple ``(safe, counts)``. ``counts`` is a dict with the integer
        keys ``"safe"``, ``"unsafe"`` and ``"timeout"``. ``safe`` is ``True``
        only when both the unsafe count and the timeout count are zero.

    Raises:
        ValueError: If the summary fragments are not found, or if ``output``
            is not text the regex engine can search.
    """
    try:
        match = re.search(
            r"total verified \(safe/unsat\): (\d+).*?"
            r"total falsified \(unsafe/sat\): (\d+).*?"
            r"timeout: (\d+)",
            output,
            re.DOTALL
        )
    except TypeError as e:
        # Non-text input (e.g. None or bytes) is reported with the same
        # exception type as unparsable text, keeping the original cause.
        raise ValueError(f"Could not parse verification results: {e}") from e

    if match is None:
        raise ValueError("Could not parse verification results.")

    safe_count, unsafe_count, timeout_count = (int(group) for group in match.groups())

    # A timeout is not a proof of safety, so it also rules out "safe".
    safe = unsafe_count == 0 and timeout_count == 0

    return safe, {
        "safe": safe_count,
        "unsafe": unsafe_count,
        "timeout": timeout_count
    }


def load_x_from_file(dataset, adv_x_path, dup=False):
    """Read the X entries of a file into a float32 tensor shaped for ``dataset``.

    Only lines starting with ``(X_`` or ``((X_0`` are used. The value is the
    second whitespace-separated token of the line with ``)`` characters
    stripped from its ends.

    Args:
        dataset: ``'mnist'`` or ``'taxinet'``; any other value gets the
            CIFAR-10 shape.
        adv_x_path: Path of the text file to read.
        dup: When truthy, use the shape for the duplicated network
            (twice the inputs / channels).

    Returns:
        A ``torch.float32`` tensor viewed as ``(1, 784)`` / ``(1, 1568)`` for
        MNIST, ``(1, 1, 27, 54)`` / ``(1, 2, 27, 54)`` for taxinet, and
        ``(1, 3, 32, 32)`` / ``(1, 6, 32, 32)`` otherwise.
    """
    with open(adv_x_path, "r") as f:
        x_values = [
            float(row.split()[1].strip(")"))
            for row in f
            if row.startswith(("(X_", "((X_0"))  # keep X entries only
        ]

    # pick the target shape; the duplicated network doubles inputs/channels
    if dataset == 'mnist':
        target_shape = (1, 1568) if dup else (1, 784)
    elif dataset == 'taxinet':
        target_shape = (1, 2, 27, 54) if dup else (1, 1, 27, 54)
    else:  # 'cifar10'
        target_shape = (1, 6, 32, 32) if dup else (1, 3, 32, 32)

    return torch.tensor(x_values, dtype=torch.float32).view(*target_shape)



### TODO: remove all hardcoded parameters and refactor this function.
def write_conf(customized_models_path, saved_model_weights, adv_x_path, samples_path, file_path, device, model_type, specification, epsilon=0.06, delta=0.0, verbose=False, Z_mask=None, patch_eps=None, verify_patching_only=True, enable_incomplete=True, timeout=300):
    """
    Render an ab-CROWN YAML config and write it to ``file_path``.

    The model and the dataset are both given as ``Customized(...)`` loader
    expressions; which loader pair is used depends on ``model_type``, on
    whether ``Z_mask`` is given, and on ``verify_patching_only``.

    Args:
        customized_models_path: Quoted as the first argument of every
            ``Customized(...)`` expression.
        saved_model_weights: Written as ``model.path``.
        adv_x_path: Written as ``attack.cex_path``.
        samples_path: Passed to the data loader as ``<family>_batch_path``.
        file_path: Destination of the generated config.
        device: Written as ``general.device``.
        model_type: ``'mnist'``, ``'taxinet'``, ``'gtsrb'`` or a string
            starting with ``'cifar10'``.
        specification: Written as ``specification.robustness_type``. The
            four ``tripled-adv-*`` names listed below set
            ``verify_gold_label=True`` in the data loader.
        epsilon: Written as ``specification.epsilon``; floats below 0.001
            are formatted with six decimals.
        delta: Written as ``bab.decision_thresh``; a tensor is converted to
            a Python scalar (one element) or a list.
        verbose: Print a confirmation line after writing.
        Z_mask: When not ``None``, the patching loaders are selected and its
            text form is embedded as ``Z_mask=...``.
        patch_eps: Embedded as ``patch_eps=...`` in the patching data loader.
        verify_patching_only: Truthy selects the ``dup_patch_*`` model,
            falsy the ``tripled_patch_*`` model.
        enable_incomplete: Written as ``general.enable_incomplete_verification``.
        timeout: Written as ``bab.timeout``; ``None`` means 300.

    Raises:
        ValueError: For an unsupported ``model_type``.
        NotImplementedError: For taxinet / gtsrb with ``Z_mask`` given and
            ``verify_patching_only`` falsy.
        AssertionError: For MNIST with ``Z_mask`` given and ``patch_eps`` None.
    """
    def customized(entry_point, *call_args):
        # render one `Customized("<models file>", "<entry point>", ...)` expression
        rendered = [f'"{customized_models_path}"', f'"{entry_point}"']
        rendered.extend(call_args)
        return "Customized(" + ", ".join(rendered) + ")"

    # specifications that run in adversarial tripled-net mode
    is_adversarial = specification in {
        "tripled-adv-same-winner",
        "tripled-adv-winner-runner",
        "tripled-adv-label-target",
        "tripled-adv-same-winner-iteratively",
    }

    # number of network outputs per model type
    if model_type == 'mnist' or model_type.startswith('cifar10'):
        num_outputs = 10
    else:
        num_outputs = {'taxinet': 1, 'gtsrb': 43}.get(model_type)
        if num_outputs is None:
            raise ValueError(f"Unsupported model_type for num_outputs: {model_type}")

    # every remaining model_type is a CIFAR-10 variant [cifar10-small, cifar10-big]
    family = model_type if model_type in ('mnist', 'taxinet', 'gtsrb') else 'cifar10'
    batch_loader = f'load_{family}_batch'
    batch_path_arg = f'{family}_batch_path="{samples_path}"'
    gold_label_arg = f'verify_gold_label={is_adversarial}'

    if Z_mask is not None:
        # patched networks: (duplicated entry point, tripled entry point or None)
        patched_models = {
            'mnist': ('dup_patch_mnist_model', 'tripled_patch_mnist_model'),
            'taxinet': ('dup_patch_taxinet_model', None),
            'gtsrb': ('dup_patch_gtsrb_model', None),
            'cifar10': ('dup_patch_cifar10_model', 'tripled_patch_cifar10_model'),
        }
        if family == 'mnist':
            assert patch_eps is not None, "patch_eps is required when Z_mask is given"
        dup_entry, tripled_entry = patched_models[family]
        if verify_patching_only:
            # only patching robustness is checked, on a duplicated network
            model_entry = dup_entry
        elif tripled_entry is None:
            raise NotImplementedError(f"Tripled patching network for {family} is not supported.")
        else:
            # both input and patching robustness are checked, on a tripled network
            model_entry = tripled_entry
        data_loader = customized(batch_loader, batch_path_arg, 'patching=True', f'patch_eps={patch_eps}', gold_label_arg)
        model_loader = customized(model_entry, f'Z_mask={Z_mask}')
    elif family == 'mnist':
        # unpatched MNIST; both model variants take the same constructor args
        mnist_args = 'input_size=784, hidden_size_1=13, hidden_size_2=11, num_classes=10'
        data_loader = customized(batch_loader, batch_path_arg, 'patching=False', gold_label_arg)
        model_loader = customized('tripled_adv_mnist_model' if is_adversarial else 'FC_model', mnist_args)
    else:
        data_loader = customized(batch_loader, batch_path_arg, gold_label_arg)
        if family != 'cifar10':
            model_loader = customized({'taxinet': 'dup_taxinet_net', 'gtsrb': 'dup_gtsrb_net'}[family])
        elif is_adversarial:
            model_loader = customized('tripled_adv_cifar_model')
        else:
            model_loader = customized('dup_cifar10_net', f'model_type="{model_type}"')

    # tensors are turned into plain Python numbers so the config shows the thresholds
    delta_val = delta
    if isinstance(delta, torch.Tensor):
        delta_val = delta.tolist() if delta.numel() != 1 else delta.item()

    # very small float epsilons are written in fixed-point decimal form
    small_epsilon = isinstance(epsilon, float) and epsilon < 0.001
    epsilon_fmt = format(epsilon, '.6f') if small_epsilon else epsilon

    timeout = 300 if timeout is None else timeout

    content = f"""
general:
  device: {device}
  save_adv_example: true
  enable_incomplete_verification: {enable_incomplete}
  # complete_verifier: bab
  loss_reduction_func: min
  sparse_interm: false

model:
  name: | 
    {model_loader}
  path: '{saved_model_weights}'

data:
  dataset: |
    {data_loader}
  num_outputs: {num_outputs}
specification:
  robustness_type: {specification}
  epsilon: {epsilon_fmt}

attack:
  pgd_order: before
  pgd_restarts: 50
  pgd_batch_size: 1
  cex_path: '{adv_x_path}'

solver:
  batch_size: 1
  min_batch_size_ratio: 1
  alpha-crown:
    iteration: 10
    alpha: true
    disable_optimization: ['MaxPool']
  beta-crown:
    iteration: 20
    lr_beta: 0.03
  mip:
    parallel_solvers: 4
    solver_threads: 4
    refine_neuron_time_percentage: 0.8
    skip_unsafe: True

bab:
  timeout: {timeout}
  decision_thresh: {delta_val}
  pruning_in_iteration: False
  sort_domain_interval: 1
  # branching:
  #   method: nonlinear
  #   candidates: 3
  #   nonlinear_split:
  #     num_branches: 2
  #     method: shortcut
  #     filter: true

"""

    with open(file_path, "w") as f:
        f.write(content)

    if verbose:
        print(f"Configuration file '{file_path}' written successfully.")



def run_ab_crown(ab_crown_specification_path):
    """Run ab-CROWN on a specification file and return its captured stdout.

    Launches ``python3 <config['paths']['abcrown_run_path']> --config
    <ab_crown_specification_path>`` with ``PYTHONPATH`` set to the absolute
    path of the current working directory. Both output streams of the child
    are drained on separate threads. Every line (right-stripped) is appended
    to ``abcrown_stdout_full.log`` in the directory of the specification
    file through a rotating file handler.

    Args:
        ab_crown_specification_path: Path of the config file passed to
            ``--config``.

    Returns:
        The child's stdout as one string: lines right-stripped and joined
        with ``"\\n"``.

    Raises:
        RuntimeError: If the child exits with a non-zero return code.
    """
    stdout_lines = []
    stderr_lines = []

    # the log file lives next to the specification file
    log_dir = os.path.dirname(ab_crown_specification_path)
    abcrown_log_file = os.path.join(log_dir, 'abcrown_stdout_full.log')

    # rolling log: 5 MiB per file, 1 backup
    # NOTE(review): an earlier comment here said 20MB; the value in effect is 5 MiB - confirm the intended size.
    max_bytes = 5 * 1024 * 1024
    backup_count = 1

    log_handler = RotatingFileHandler(abcrown_log_file, maxBytes=max_bytes, backupCount=backup_count, encoding='utf-8')
    log_handler.setFormatter(logging.Formatter('%(message)s'))
    logger = logging.getLogger("abcrown_stdout_full")
    logger.setLevel(logging.INFO)
    # remove any existing handlers to avoid duplicate logs, and close them so
    # their file handles are released instead of being dropped while open
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()
    logger.addHandler(log_handler)
    logger.propagate = False

    def pump(pipe, sink):
        # Drain one pipe line by line into the log file and into `sink`.
        # `with` closes the pipe even if logging raises.
        with pipe:
            for line in iter(pipe.readline, ''):
                line = line.rstrip()
                logger.info(line)  # stderr lines go to the same log file
                sink.append(line)

    env = os.environ.copy()
    env["PYTHONPATH"] = os.path.abspath(".")

    try:
        process = subprocess.Popen(
            ["python3", config['paths']['abcrown_run_path'], "--config", ab_crown_specification_path],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, env=env)

        # one reader per pipe, so a full stderr buffer cannot stall stdout (or vice versa)
        readers = [
            threading.Thread(target=pump, args=(process.stdout, stdout_lines)),
            threading.Thread(target=pump, args=(process.stderr, stderr_lines)),
        ]
        for reader in readers:
            reader.start()
        for reader in readers:
            reader.join()

        process.wait()
    finally:
        # release the log file whether the run succeeded or not
        logger.removeHandler(log_handler)
        log_handler.close()

    full_stdout = "\n".join(stdout_lines)
    full_stderr = "\n".join(stderr_lines)

    if process.returncode != 0:
        print("ERROR: abcrown.py failed with return code", process.returncode)
        print("STDERR:\n", full_stderr)
        logging.error("STDERR:\n%s", full_stderr)
        raise RuntimeError(f"AB CROWN execution failed with return code: {process.returncode}")

    return full_stdout
