from collections import defaultdict
from dataclasses import dataclass

import torch as t

import wandb
from adversarial_superposition.constants import DEVICE


@dataclass
class Attack:
    """Record describing one adversarial example kept by ``attack_toy_classifier``.

    The two index fields are plain Python ints (they are produced with
    ``.argmax().item()``); every other field is a ``torch.Tensor`` slice taken
    from the corresponding batch-level tensor.
    """

    # Argmax of the model's prediction on the original input.
    from_idx: int
    # Argmax of the model's prediction on the attacked input.
    to_idx: int
    # attacked_sums - input_sums.
    attack_diffs: t.Tensor
    # Slice of the original (un-attacked) batch for this example.
    input: t.Tensor
    # Class sums assigned to the original input.
    input_sums: t.Tensor
    # Slice of the attacked batch for this example.
    attacked_input: t.Tensor
    # Class sums assigned to the attacked input.
    attacked_sums: t.Tensor


def attack_toy_classifier(
    instance_idx,
    attack_params,
    model,
    attack_method,
    target_class=None,
    verbose=False,
    use_wandb=False,
):
    """Run an adversarial attack on a toy classifier and collect the "true" attacks.

    The attack itself is delegated to ``model.attack``. This function then
    re-labels the original and attacked inputs with ``model._assign_class``,
    re-runs the model on both, and keeps only the examples that satisfy every
    condition listed under *Notes*. A summary is always printed; it is also
    logged to Weights & Biases when ``use_wandb`` is set.

    Args:
        instance_idx: Forwarded to ``model.attack`` and to the model's forward
            call; also used as the prefix of the wandb metric names.
        attack_params: Mapping of parameters for the attack method. An optional
            ``"find_worst_case"`` entry (default ``False``) is stripped out and
            passed to ``model.attack`` as its own keyword argument. The
            caller's mapping is NOT modified.
        model: The classifier to attack. Must provide ``attack(...)``,
            ``_assign_class(batch, single_index=True)`` and be callable as
            ``model(batch, instance_idx=...)``.
        attack_method: Forwarded to ``model.attack``.
        target_class: Optional class to target. If ``None`` the attack is
            treated as untargeted.
        verbose: Forwarded to ``model.attack``.
        use_wandb: If True, logs the summary counts to Weights & Biases.

    Returns:
        tuple: A 6-tuple containing:
            - lookup (defaultdict): Maps ``(from_class, to_class)`` pairs to
              lists of the corresponding ``Attack`` objects.
            - elements (numpy.ndarray): Indices of the true attacks, exactly as
              produced by ``torch.nonzero`` (one row per true attack).
            - original_batch: Original inputs as returned by ``model.attack``.
            - orig_labels: Labels as returned by ``model.attack``.
            - original_correct_preds (torch.Tensor): Boolean mask of originally
              correct predictions.
            - attacks: Attacked inputs as returned by ``model.attack``.

    Notes:
        A "true attack" is defined as one where:
        1. The original prediction was correct
        2. The attack successfully changed the model's prediction
        3. The actual class of the input did not change
        4. For targeted attacks, the original class was not the target class

        Each ``Attack`` is built by indexing the batch tensors with one row of
        ``elements``, so its tensor fields keep whatever leading dimension that
        indexing produces (presumably 1 for a 1-D mask -- confirm against the
        model's label shape).
    """

    # Copy before popping: the caller may reuse the same dict for several
    # calls, and popping in place would silently drop ``find_worst_case`` for
    # every call after the first.
    attack_params = dict(attack_params)
    find_worst_case = attack_params.pop("find_worst_case", False)

    attacks, original_batch, orig_labels, preds_not_equal_label_mask = model.attack(
        attack_method=attack_method,
        attack_params=attack_params,
        instance_idx=instance_idx,
        verbose=verbose,
        target_class=target_class,
        find_worst_case=find_worst_case,
    )

    # Ground-truth labels / class sums before and after the attack.
    original_labels, original_class_sums = model._assign_class(
        original_batch, single_index=True
    )
    attacked_labels, attacked_class_sums = model._assign_class(
        attacks, single_index=True
    )

    # Only the argmax of the predictions is used below, so detach both to avoid
    # keeping the autograd graph alive.
    original_preds = (
        model(original_batch, instance_idx=instance_idx).detach().to(DEVICE)
    )
    attacked_preds = model(attacks, instance_idx=instance_idx).detach().to(DEVICE)

    # Put everything that takes part in the mask arithmetic on one device.
    original_labels = original_labels.to(DEVICE)
    original_class_sums = original_class_sums.to(DEVICE)
    attacked_labels = attacked_labels.to(DEVICE)
    attacked_class_sums = attacked_class_sums.to(DEVICE)
    preds_not_equal_label_mask = preds_not_equal_label_mask.to(DEVICE)

    # (i) Originally correct prediction
    original_correct_preds = original_labels == original_preds.argmax(-1)
    # (ii) Successfully attacked (i.e. pred changed)
    successful_attack = preds_not_equal_label_mask & original_correct_preds
    # (iii) Class not actually changed
    class_not_changed = attacked_labels == original_labels

    true_attack_mask = class_not_changed & successful_attack

    if target_class is not None:
        # (iv) Samples already belonging to the target class do not count.
        orig_class_not_target = original_labels != target_class
        true_attack_mask = true_attack_mask & orig_class_not_target

        n_true = true_attack_mask.sum().item()
        n_flips = (successful_attack & orig_class_not_target).sum().item()
        n_correct = (original_correct_preds & orig_class_not_target).sum().item()
        n_total = orig_class_not_target.sum().item()
        summary = (
            f"There were {n_true} true targetted attacks\n"
            f"(out of {n_flips} successful flips of label)\n"
            f"(out of {n_correct} correct preds)\n"
            f"(out of {n_total} total samples that were not class {target_class})"
        )
    else:
        n_true = true_attack_mask.sum().item()
        n_flips = successful_attack.sum().item()
        n_correct = original_correct_preds.sum().item()
        n_total = len(original_labels)
        summary = (
            f"There were {n_true} true attacks\n"
            f"(out of {n_flips} successful flips of label)\n"
            f"(out of {n_correct} correct preds)\n"
            f"(out of {n_total} total samples)"
        )

    if use_wandb:
        wandb.log(
            {
                f"{instance_idx}_true_attack": n_true,
                f"{instance_idx}_correct_preds": n_correct,
                f"{instance_idx}_total_samples": n_total,
            }
        )
    print(summary)

    elements = t.nonzero(true_attack_mask).detach().cpu().numpy()

    # Group the surviving attacks by (predicted class before, predicted class after).
    lookup = defaultdict(list)
    for i in elements:
        change = Attack(
            from_idx=original_preds[i, :].argmax().item(),
            to_idx=attacked_preds[i, :].argmax().item(),
            attack_diffs=attacked_class_sums[i] - original_class_sums[i],
            input=original_batch[i, :],
            input_sums=original_class_sums[i, :],
            attacked_input=attacks[i, :],
            attacked_sums=attacked_class_sums[i, :],
        )
        lookup[(change.from_idx, change.to_idx)].append(change)

    return (
        lookup,
        elements,
        original_batch,
        orig_labels,
        original_correct_preds,
        attacks,
    )
