from abc import ABC
from dataclasses import asdict, dataclass, field
from typing import List, Optional

import numpy as np
import torch
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Subset
from torch.utils.data.dataset import random_split

import wandb
from adversarial_superposition.constants import DEVICE, FLOAT_PRECISION_MAP
from adversarial_superposition.modulo.utils.binary_operations import (
    add_mod,
    product_mod,
    subtract_mod,
)
from adversarial_superposition.modulo.utils.datasets import AlgorithmicDataset
from adversarial_superposition.modulo.utils.models import MLP


def decode_concatenated_one_hot(tensor, size=113):
    """
    Decodes N pairs of concatenated one-hot vectors into pairs of numbers.

    Each row is expected to hold two one-hot vectors of length ``size`` laid
    out back to back, so a row has ``2 * size`` entries (226 by default).

    Args:
        tensor: A torch.Tensor (or anything accepted by ``torch.tensor``) of
            shape (N, 2 * size).
        size: Length of each individual one-hot vector. Default is 113.

    Returns:
        torch.Tensor: Tensor of shape (N, 2) where each row contains the two
        decoded numbers, each between 0 and ``size - 1``.

    Raises:
        ValueError: If the input is not 2D or its second dimension is not
            ``2 * size``.
    """
    # Convert input to torch tensor if it isn't already
    if not isinstance(tensor, torch.Tensor):
        tensor = torch.tensor(tensor)

    # Verify input shape
    expected_width = 2 * size
    if tensor.dim() != 2 or tensor.shape[1] != expected_width:
        raise ValueError(
            f"Expected tensor of shape (N, {expected_width}), got shape {tensor.shape}"
        )

    # argmax recovers the position of the 1 in each half
    first_numbers = torch.argmax(tensor[:, :size], dim=1)  # Shape: (N,)
    second_numbers = torch.argmax(tensor[:, size:], dim=1)  # Shape: (N,)

    return torch.stack([first_numbers, second_numbers], dim=1)  # Shape: (N, 2)


def one_hot_encode(number, size):
    """Return a float tensor of zeros of the given size with a 1 at ``number``."""
    encoded = torch.zeros(size)
    encoded[number] = 1
    return encoded


def cross_entropy_float64(logits, labels, reduction="mean"):
    """
    Softmax cross-entropy with the log-softmax computed in float64.

    Args:
        logits: Tensor of shape (N, C) with unnormalized scores.
        labels: Tensor of shape (N,) with class indices (cast to int64).
        reduction: "mean" (default), "sum" or "none".

    Returns:
        torch.Tensor: float32 loss. A scalar for "mean"/"sum"; for "none" the
        per-example losses with shape (N, 1).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    # Fail fast instead of silently treating unknown values as "none".
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"Unsupported reduction type: {reduction}")

    labels = labels.to(torch.int64)
    logprobs = torch.nn.functional.log_softmax(logits.to(torch.float64), dim=-1)
    # Pick the log-probability assigned to the correct class of each example.
    prediction_logprobs = torch.gather(logprobs, index=labels[:, None], dim=-1)

    if reduction == "mean":
        loss = -torch.mean(prediction_logprobs)
    elif reduction == "sum":
        loss = -torch.sum(prediction_logprobs)
    else:
        loss = -prediction_logprobs
    # The result is handed back in float32 even though it was computed in float64.
    return loss.to(torch.float32)


def s(x, epsilon=1e-30):
    """
    Elementwise transform used in place of ``exp`` by ``log_stablemax``.

    Returns ``x + 1`` where ``x >= 0`` and ``1 / (1 - x + epsilon)`` where
    ``x < 0``.
    """
    negative_branch = 1 / (1 - x + epsilon)
    non_negative_branch = x + 1
    return torch.where(x < 0, negative_branch, non_negative_branch)


def log_stablemax(x, dim=-1):
    """
    Log of the normalized ``s``-transform of ``x`` along ``dim``.

    Mirrors log-softmax, with ``s(x)`` taking the place of ``exp(x)``.
    """
    transformed = s(x)
    normalizer = torch.sum(transformed, dim=dim, keepdim=True)
    return torch.log(transformed / normalizer)


def stablemax_cross_entropy(logits, labels, reduction="mean"):
    """
    Cross-entropy computed from ``log_stablemax`` probabilities in float64.

    Args:
        logits: Tensor of shape (N, C) with unnormalized scores.
        labels: Tensor of shape (N,) with class indices (cast to int64).
        reduction: "mean" (default), "sum" or "none".

    Returns:
        torch.Tensor: float64 loss. A scalar for "mean"/"sum"; for "none" the
        per-example losses with shape (N, 1).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    # Fail fast instead of silently treating unknown values as "none".
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"Unsupported reduction type: {reduction}")

    labels = labels.to(torch.int64)
    logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
    # Pick the log-probability assigned to the correct class of each example.
    prediction_logprobs = torch.gather(logprobs, index=labels[:, None], dim=-1)

    if reduction == "mean":
        return -torch.mean(prediction_logprobs)
    if reduction == "sum":
        return -torch.sum(prediction_logprobs)
    return -prediction_logprobs


def cross_entropy_float32(logits, labels, reduction="mean"):
    """
    Softmax cross-entropy with the log-softmax computed in float32.

    Args:
        logits: Tensor of shape (N, C) with unnormalized scores.
        labels: Tensor with N class indices (cast to int64).
        reduction: "mean" (default), "sum" or "none".

    Returns:
        torch.Tensor: Scalar loss for "mean"/"sum", or the per-example losses
        of shape (N,) for "none".

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    class_index = labels.to(torch.int64)
    log_probs = torch.nn.functional.log_softmax(logits.to(torch.float32), dim=-1)
    # One column of indices so gather selects a single entry per row.
    class_index = class_index.view(-1, 1)
    picked = torch.gather(log_probs, dim=-1, index=class_index).squeeze(-1)

    if reduction == "mean":
        return -torch.mean(picked)
    if reduction == "sum":
        return -torch.sum(picked)
    if reduction == "none":
        return -picked
    raise ValueError(f"Unsupported reduction type: {reduction}")


def cross_entropy_float16(logits, labels, reduction="mean"):
    labels = labels.to(torch.int64)
    logprobs = torch.nn.functional.log_softmax(logits.to(torch.float16), dim=-1)

    prediction_logprobs = torch.gather(logprobs, index=labels[:, None], dim=-1).to(
        torch.float16
    )
    loss = (
        -torch.mean(prediction_logprobs)
        if reduction == "mean"
        else -prediction_logprobs
    )
    return loss


def update_results(filename, experiment_key, logger_metrics):
    """
    Store ``logger_metrics`` under ``experiment_key`` in a results file.

    The file is read with ``torch.load``, updated in memory and written back
    with ``torch.save``. A missing file is treated as an empty results dict.
    Any other load failure also falls back to an empty dict, but emits a
    warning because the previous contents will be overwritten.

    Args:
        filename: Path of the results file.
        experiment_key: Key under which the metrics are stored.
        logger_metrics: Object to store for this experiment.
    """
    import warnings

    try:
        results = torch.load(filename)
    except FileNotFoundError:
        # First write: nothing to merge with yet.
        results = {}
    except Exception as err:
        # Keep the original best-effort behaviour, but make the data loss
        # visible instead of hiding it behind a bare ``except``.
        warnings.warn(
            f"Could not load existing results from {filename!r} ({err!r}); "
            "starting from an empty results dict.",
            RuntimeWarning,
        )
        results = {}

    results[experiment_key] = logger_metrics
    torch.save(results, filename)


def evaluate(model, data_loader, loss_function=cross_entropy_float64):
    """
    Compute the average loss and the accuracy of ``model`` on ``data_loader``.

    The model is switched to eval mode and is left in that mode on return.
    Inputs are moved to the device and dtype of the model's first parameter;
    outputs and targets are compared on the CPU.

    Args:
        model: Module to evaluate.
        data_loader: Iterable of batches ``(data, target, ...)``; extra batch
            elements are ignored. Must support ``len()`` and expose
            ``.dataset``.
        loss_function: Callable ``(output, target) -> scalar tensor``.

    Returns:
        tuple: ``(loss, accuracy)`` where ``loss`` is the mean of the
        per-batch losses and ``accuracy`` is a percentage over
        ``len(data_loader.dataset)``.
    """
    model.eval()
    total_loss = 0
    correct = 0
    first_param = next(model.parameters())
    device = first_param.device
    float_precision = first_param.dtype
    with torch.no_grad():
        for data, target, *_ in data_loader:
            # The output is brought to the CPU below, so the target must be
            # there too before it reaches the loss function.
            target = target.to("cpu")
            output = model(data.to(device).to(float_precision)).to("cpu")
            total_loss += loss_function(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            # Targets with more than one dimension are reduced to class
            # indices with argmax before comparing.
            if target.dim() != 1:
                target = target.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
    total_loss /= len(data_loader)
    accuracy = 100 * correct / len(data_loader.dataset)
    return total_loss, accuracy


def get_specified_args(parser, args):
    """
    Return the parsed arguments whose values differ from the parser defaults.

    The ``device`` argument is always excluded from the result.

    Args:
        parser: The ``argparse.ArgumentParser`` that produced ``args``.
        args: The parsed namespace.

    Returns:
        dict: Mapping of argument name to its non-default value.
    """
    defaults = {}
    for action in parser._actions:
        if action.dest == "help":
            continue
        defaults[action.dest] = action.default

    specified = {}
    for name, value in vars(args).items():
        if value != defaults.get(name) and name != "device":
            specified[name] = value
    return specified


def split_dataset(dataset, train_fraction, batch_size):
    """
    Randomly split ``dataset`` into a train and a test subset.

    Args:
        dataset: Dataset supporting ``len()``.
        train_fraction: Fraction of the samples assigned to the train subset;
            the train size is ``int(train_fraction * len(dataset))``.
        batch_size: Unused; kept so existing callers keep working.

    Returns:
        tuple: ``(train_dataset, test_dataset)`` as returned by
        ``random_split``.
    """
    total_size = len(dataset)
    train_size = int(train_fraction * total_size)
    test_size = total_size - train_size
    print(f"Starting training. Train dataset size: {train_size}, Test size: {test_size}")
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    return train_dataset, test_dataset


def reduce_train_dataset(original_train_dataset, reduced_fraction, batch_size):
    """
    Build a shuffled DataLoader over a leading fraction of a train subset.

    Args:
        original_train_dataset: A ``Subset`` (e.g. from ``random_split``)
            exposing ``.indices`` and ``.dataset``.
        reduced_fraction: Fraction of the subset's samples to keep; the kept
            count is ``int(reduced_fraction * len(indices))``.
        batch_size: Batch size of the returned loader.

    Returns:
        DataLoader: Shuffling loader over the first ``reduced_fraction`` of
        the samples of ``original_train_dataset``.
    """
    parent_indices = original_train_dataset.indices
    reduced_train_size = int(reduced_fraction * len(parent_indices))
    reduced_indices = parent_indices[:reduced_train_size]
    # ``indices`` point into the parent dataset, so the new Subset has to wrap
    # the parent; wrapping the subset itself would select the wrong samples or
    # index out of range.
    reduced_train_dataset = Subset(original_train_dataset.dataset, reduced_indices)

    return DataLoader(reduced_train_dataset, batch_size=batch_size, shuffle=True)


# Lookup table from a binary-operation name (the value of
# ``args.binary_operation``) to the function implementing it.
BINARY_OPERATION_MAP = {
    "add_mod": add_mod,
    "product_mod": product_mod,
    "subtract_mod": subtract_mod,
}


def get_dataset(args):
    """
    Build the algorithmic dataset described by ``args`` and split it.

    Reads ``binary_operation``, ``modulo``, ``input_size``, ``train_fraction``
    and ``batch_size`` from ``args``.

    Returns:
        tuple: ``(train_dataset, test_dataset)``.
    """
    operation = BINARY_OPERATION_MAP[args.binary_operation]
    full_dataset = AlgorithmicDataset(
        operation,
        p=args.modulo,
        input_size=args.input_size,
        output_size=args.modulo,
    )
    train_split, test_split = split_dataset(
        full_dataset, args.train_fraction, args.batch_size
    )
    return train_split, test_split


def generate_random_one_hot(length):
    """Return a float one-hot vector of ``length`` with a uniformly random hot index."""
    hot_position = int(torch.randint(0, length, (1,)))
    vector = torch.zeros(length)
    vector[hot_position] = 1
    return vector


def get_model(args):
    """
    Create the MLP described by ``args`` on ``DEVICE`` in the training dtype.

    The network takes ``args.input_size * 2`` inputs, produces ``args.modulo``
    outputs, and is cast to ``FLOAT_PRECISION_MAP[args.train_precision]``.
    """
    print(f"Hidden sizes: {args.hidden_sizes}")
    print("Using AlgorithmicDataset")
    network = MLP(
        input_size=args.input_size * 2,
        output_size=args.modulo,
        hidden_sizes=args.hidden_sizes,
        bias=True,
    )
    network = network.to(DEVICE)
    return network.to(FLOAT_PRECISION_MAP[args.train_precision])


def get_optimizer(model, args):
    """
    Build the optimizer selected by ``args.optimizer`` for ``model``.

    Supported names are "Adam", "AdamW" and "SGD". "Adam" uses fixed betas of
    (0.9, 0.98) and no weight decay; "AdamW" takes ``weight_decay`` and
    ``beta2`` from ``args``; "SGD" uses momentum 0.2 and no weight decay.

    Raises:
        ValueError: If ``args.optimizer`` is not one of the supported names.
    """
    if args.optimizer == "Adam":
        return optim.Adam(
            model.parameters(),
            lr=args.lr,
            betas=(0.9, 0.98),
            weight_decay=0,
            eps=args.adam_epsilon,
        )

    if args.optimizer == "AdamW":
        return optim.AdamW(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
            eps=args.adam_epsilon,
            betas=(0.9, args.beta2),
        )

    if args.optimizer == "SGD":
        return optim.SGD(
            model.parameters(), lr=args.lr, momentum=0.2, weight_decay=0
        )

    raise ValueError(f"Unsupported optimizer type: {args.optimizer}")


import argparse


def parse_args():
    """
    Build the command line parser and parse ``sys.argv``.

    Returns:
        tuple: ``(parser, args)`` with the ``argparse.ArgumentParser`` and the
        parsed namespace.
    """
    parser = argparse.ArgumentParser(
        description="Train a neural network with specified parameters."
    )

    parser.add_argument(
        "--hidden_sizes",
        type=int,
        nargs="+",
        default=[200, 200],
        help="List of hidden layer sizes. Default is [200, 200].",
    )

    parser.add_argument(
        "--num_epochs",
        type=int,
        default=1500,
        help="Number of epochs. Default is 1500.",
    )

    parser.add_argument(
        "--train_fraction",
        type=float,
        default=0.3,
        help="Fraction of data to be used for training. Default is 0.3.",
    )

    parser.add_argument(
        "--modulo",
        type=int,
        default=113,
        help="Modulo value for modular arithmetic datasets. Default is 113.",
    )

    parser.add_argument(
        "--input_size",
        type=int,
        default=113,
        help="Input size for the model. Default is 113.",
    )

    parser.add_argument(
        "--optimizer",
        type=str,
        default="AdamW",
        help="Optimizer to use. Options: AdamW, Adam, SGD. Default is AdamW.",
    )

    parser.add_argument(
        "--loss_function",
        type=str,
        default="cross_entropy",
        help="Loss function to use. Options: stablemax, cross_entropy. Default is cross_entropy.",
    )

    parser.add_argument(
        "--log_frequency",
        type=int,
        default=50,
        help="Logging frequency (in epochs). Default is 50.",
    )

    parser.add_argument(
        "--regularization",
        type=str,
        default="None",
        help="Regularization method. Options: None, l1, l2. Default is None.",
    )

    parser.add_argument(
        "--binary_operation",
        type=str,
        default="add_mod",
        help="Binary operation for algorithmic tasks. Options: add_mod, product_mod, subtract_mod",
    )

    parser.add_argument(
        "--lr", type=float, default=None, help="Learning rate. Default is None."
    )

    parser.add_argument(
        "--batch_size", type=int, default=128, help="Batch size. Default is 128."
    )

    parser.add_argument(
        "--full_batch",
        action="store_true",
        default=True,
        help="Use full batch gradient descent. Default is True.",
    )

    # ``--full_batch`` is a store_true flag that already defaults to True, so
    # on its own it can never be switched off. This companion flag writes
    # False to the same destination.
    parser.add_argument(
        "--no_full_batch",
        dest="full_batch",
        action="store_false",
        help="Disable full batch gradient descent (sets full_batch to False).",
    )

    parser.add_argument(
        "--dataset",
        type=str,
        default="add_mod",
        help="Dataset to use. Options: rotated_mnist, add_mod. Default is add_mod.",
    )

    parser.add_argument(
        "--temperature_schedule",
        action="store_true",
        default=False,
        help="Use a schedule for softmax temperature. Default is False.",
    )

    parser.add_argument(
        "--num_noise_features",
        type=int,
        default=50,
        help="Number of noise features used for SparseParityDataset. Default is 50.",
    )

    parser.add_argument(
        "--num_parity_features",
        type=int,
        default=4,
        help="Number of parity features used for SparseParityDataset. Default is 4.",
    )

    parser.add_argument(
        "--num_samples",
        type=int,
        default=1000,
        help="Number of samples for SparseParityDataset. Default is 1000.",
    )

    parser.add_argument(
        "--alpha",
        type=float,
        default=1.0,
        help="Alpha coefficient that multiplies the logits. Default is 1.0.",
    )

    parser.add_argument(
        "--lambda_l1",
        type=float,
        default=0.00001,
        help="L1 regularization coefficient. Default is 0.00001.",
    )

    parser.add_argument(
        "--lambda_l2",
        type=float,
        default=0.00005,
        help="L2 regularization coefficient. Default is 0.00005.",
    )

    parser.add_argument(
        "--softmax_precision",
        type=int,
        default=32,
        help="Floating point precision for the loss calculation: 16, 32, or 64. Default is 32.",
    )

    parser.add_argument(
        "--train_precision",
        type=int,
        default=32,
        help="Floating point precision for the model and data: 16, 32, or 64. Default is 32.",
    )

    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0,
        help="Weight decay (L2 penalty) coefficient. Default is 0.",
    )

    parser.add_argument(
        "--use_lr_scheduler",
        action="store_true",
        default=False,
        help="Use a learning rate scheduler. Default is False.",
    )

    parser.add_argument(
        "--orthogonal_gradients",
        action="store_true",
        default=False,
        help="Use orthogonal gradients regularization. Default is False.",
    )

    parser.add_argument(
        "--use_transformer",
        action="store_true",
        default=False,
        help="Use one layer transformer",
    )

    parser.add_argument("--device", type=str, default="cpu", help="Device")
    parser.add_argument(
        "--beta2", type=float, default=0.99, help="Beta2 parameter for Adam and AdamW"
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-25,
        help="Epsilon value for Adam and AdamW",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed. Default is 42."
    )

    return parser, parser.parse_args()


def process_metrics(df):
    """
    Summarize the metrics of the latest epoch as a flat dictionary.

    Rows whose ``input_type`` is "general" or whose ``metric_name`` is
    "zero_terms" or "softmax_collapse" are left out.

    Args:
        df (pd.DataFrame): DataFrame with at least the columns
            ``epoch``, ``input_type``, ``metric_name`` and ``value``.

    Returns:
        dict: Maps ``{input_type}_{metric_name}`` to the value as a float.
    """
    latest_epoch = df["epoch"].max()

    in_latest_epoch = df["epoch"] == latest_epoch
    not_general = df["input_type"] != "general"
    not_zero_terms = df["metric_name"] != "zero_terms"
    not_softmax_collapse = df["metric_name"] != "softmax_collapse"
    selected = df[in_latest_epoch & not_general & not_zero_terms & not_softmax_collapse]

    return {
        f"{row['input_type']}_{row['metric_name']}": float(row["value"])
        for _, row in selected.iterrows()
    }


def analyze_perturbation_distribution(original: torch.Tensor, perturbed: torch.Tensor):
    """
    Compute metrics to analyze how changes are distributed across dimensions.

    For every example the absolute change ``|perturbed - original|`` is
    summarized by four numbers: Gini coefficient, entropy normalized by
    ``log2(D)``, share of the total change held by the top 10% of dimensions,
    and coefficient of variation. An example with no change at all gets 0 for
    every metric.

    Args:
        original: Input tensor of shape (N, D), e.g. (N, 226)
        perturbed: Perturbed tensor of shape (N, D)

    Returns:
        dict: ``{"batch_stats": stats}`` where ``stats`` holds, for each
        metric, ``<name>_mean``, ``<name>_std`` and ``<name>_quartiles``
        (25%, 50%, 75%) over the batch.

    Raises:
        ValueError: If either input is not a 2D tensor.
    """
    # Ensure inputs are 2D
    if original.dim() != 2 or perturbed.dim() != 2:
        raise ValueError("Inputs must be 2D tensors of shape (N, 226)")

    batch_size = original.shape[0]
    delta = perturbed - original  # Shape: (N, D)
    abs_changes = torch.abs(delta)  # Shape: (N, D)

    # Initialize metrics for batch; zeros double as the "no change" value.
    gini_coeffs = torch.zeros(batch_size, device=original.device)
    norm_entropies = torch.zeros(batch_size, device=original.device)
    top_k_ratios = torch.zeros(batch_size, device=original.device)
    cvs = torch.zeros(batch_size, device=original.device)

    # Quantities that depend only on the number of dimensions.
    n = abs_changes.shape[1]
    index = torch.arange(1, n + 1, device=abs_changes.device, dtype=torch.float)
    gini_weights = 2 * index - n - 1
    max_entropy = torch.log2(torch.tensor(n, dtype=torch.float, device=delta.device))
    k = max(1, int(0.1 * n))  # Top 10% of dimensions

    # Calculate metrics for each example in batch
    for i in range(batch_size):
        example_changes = abs_changes[i]  # Shape: (D,)
        changes_sum = example_changes.sum()

        # The first three metrics divide by the total change; skipping them
        # for an unperturbed example avoids 0/0 = NaN, which would otherwise
        # propagate into the batch statistics.
        if changes_sum > 0:
            # 1. Gini coefficient
            sorted_changes, _ = torch.sort(example_changes)
            gini_coeffs[i] = (gini_weights * sorted_changes).sum() / (
                n * sorted_changes.sum()
            )

            # 2. Normalized entropy
            changes_norm = example_changes / changes_sum
            entropy = -torch.sum(changes_norm * torch.log2(changes_norm + 1e-10))
            norm_entropies[i] = entropy / max_entropy

            # 3. Top-k concentration
            top_k_sum = torch.topk(example_changes, k)[0].sum()
            top_k_ratios[i] = top_k_sum / changes_sum

        # 4. Coefficient of variation
        mean = example_changes.mean()
        if mean > 0:
            cvs[i] = example_changes.std() / mean

    # Per-example metrics
    per_example_metrics = {
        "gini_coefficient": gini_coeffs,
        "normalized_entropy": norm_entropies,
        "top_10_percent_concentration": top_k_ratios,
        "coefficient_of_variation": cvs,
    }

    # Batch statistics
    batch_stats = {}
    for name, values in per_example_metrics.items():
        batch_stats[f"{name}_mean"] = values.mean()
        batch_stats[f"{name}_std"] = values.std()
        # Get quartiles for distribution analysis
        batch_stats[f"{name}_quartiles"] = torch.tensor(
            [
                values.quantile(0.25),
                values.quantile(0.50),  # median
                values.quantile(0.75),
            ]
        )

    return {"batch_stats": batch_stats}


import torch


def compute_dimension_frequency(batch_results, top_k, num_dims=226):
    """
    Compute how often each dimension appears in top-k across batch.

    Args:
        batch_results: Iterable of dicts, each providing
            ``result["top_changing_dims"]["indices"]``.
        top_k: Number of most frequent dimensions to report.
        num_dims: Total number of dimensions being counted. Default is 226.

    Returns:
        dict: ``"most_frequent_dims"`` holds the ``torch.topk`` result over
        the counts and ``"frequency_distribution"`` the count per dimension.
    """
    dimension_counts = torch.zeros(num_dims)
    for result in batch_results:
        for idx in result["top_changing_dims"]["indices"]:
            dimension_counts[idx] += 1
    return {
        "most_frequent_dims": torch.topk(dimension_counts, k=top_k),
        "frequency_distribution": dimension_counts,
    }


def compute_coverage_stats(batch_results):
    """
    Compute statistics about dimension coverage across batch.

    For each of the thresholds "dims_for_50%", "dims_for_75%" and
    "dims_for_90%", gathers ``r["dimension_coverage"][threshold]`` over all
    results and reports its mean, std, min and max as Python scalars.
    """
    summary = {}
    for threshold in ("dims_for_50%", "dims_for_75%", "dims_for_90%"):
        per_example = [entry["dimension_coverage"][threshold] for entry in batch_results]
        # Convert once and reuse the array for all four statistics.
        as_array = torch.tensor(per_example).cpu().numpy()
        summary[threshold] = {
            "mean": as_array.mean().item(),
            "std": as_array.std().item(),
            "min": as_array.min().item(),
            "max": as_array.max().item(),
        }
    return summary


def print_batch_analysis(analysis):
    """
    Pretty print the batch-level analysis results.

    Reads ``analysis["batch_stats"]`` and prints the total-change summary,
    the most frequently changed dimensions and the coverage statistics.
    """
    batch_stats = analysis["batch_stats"]

    print("\n=== Batch-Level Analysis ===")

    # Overall change summary
    mean_change = batch_stats["mean_total_change"]
    std_change = batch_stats["std_total_change"]
    print(f"\nAverage total change: {mean_change:.4f} ± {std_change:.4f}")
    print(f"Average number of sign changes: {batch_stats['mean_sign_changes']:.2f}")

    # Dimensions that show up most often in the per-example top-k
    print("\nMost frequently changed dimensions across batch:")
    most_frequent = batch_stats["dimension_frequency"]["most_frequent_dims"]
    for idx, count in zip(most_frequent.indices, most_frequent.values):
        print(f"Dimension {idx}: appeared in top-k {count:.0f} times")

    # Coverage summary per threshold
    print("\nDimension coverage statistics:")
    for threshold, coverage in batch_stats["coverage_stats"].items():
        print(f"{threshold}:")
        print(
            f"  Mean: {coverage['mean']:.1f} dimensions (min: {coverage['min']}, max: {coverage['max']})"
        )


def create_change_matrices(
    original_data: torch.Tensor, adversarial_data: torch.Tensor, cfg
):
    """
    Arrange adversarial changes by the pair of numbers encoded in each input.

    With ``S = cfg.input_size``, every original row is decoded into a pair
    ``(i, j)`` via argmax over its two halves, and the signed difference
    ``adversarial - original`` of that row is stored in ``matrices[:, i, j]``.
    If the same pair occurs more than once, the last occurrence wins.

    Args:
        original_data: Tensor of shape (N, 2 * S) containing original inputs
        adversarial_data: Tensor of shape (N, 2 * S) containing adversarial examples
        cfg: Object exposing ``input_size``.

    Returns:
        matrices: Tensor of shape (2 * S, S, S) containing the changes
        coverage: Tensor of shape (S, S) with 1 where a pair is present
    """
    # Validate inputs
    assert (
        original_data.shape == adversarial_data.shape
    ), f"{original_data.shape} != {adversarial_data.shape}"
    assert (
        original_data.shape[1] == cfg.input_size * 2
    ), f"{original_data.shape} != {cfg.input_size * 2}"
    print(f"The input size is {cfg.input_size}")

    half = cfg.input_size
    device = original_data.device

    # Result containers start empty; unseen pairs stay at zero.
    matrices = torch.zeros(half * 2, half, half, device=device)
    coverage = torch.zeros(half, half, device=device)

    changes = adversarial_data - original_data  # Shape: (N, 2 * S)

    # Decode the two numbers of each example from the original data
    first_nums = torch.argmax(original_data[:, :half], dim=1)
    second_nums = torch.argmax(original_data[:, half:], dim=1)

    for change_row, i, j in zip(changes, first_nums.tolist(), second_nums.tolist()):
        matrices[:, i, j] = change_row
        coverage[i, j] = 1

    return matrices, coverage


def analyze_change_matrices(matrices: torch.Tensor, coverage: torch.Tensor):
    """
    Analyze the patterns in the change matrices.

    Args:
        matrices: Tensor of shape (B, S, S), e.g. (226, 113, 113), with the
            change of each of the B input bits for every number pair
        coverage: Tensor of shape (S, S) with 1 where a number pair is present

    Returns:
        dict: ``total_combinations`` and ``coverage_percentage`` (floats);
        ``mean_change_per_bit``, ``max_change_per_bit`` and
        ``std_change_per_bit`` (tensors of shape (B,), over covered pairs);
        ``total_mean_change``, ``total_max_change`` and ``total_std_change``
        (floats, over covered pairs); ``first_num_changes`` and
        ``second_num_changes`` (tensors of shape (S,), averaged over all
        positions, covered or not).
    """
    # Get mask of valid positions
    valid_positions = coverage == 1

    # Calculate absolute changes for valid positions
    abs_changes = torch.abs(matrices)
    valid_changes = abs_changes[:, valid_positions]  # Shape: (B, num_covered)

    covered = torch.sum(coverage)

    metrics = {
        # Overall statistics
        "total_combinations": covered.item(),
        # Relative to the actual grid size rather than a fixed 113 * 113.
        "coverage_percentage": (covered / coverage.numel() * 100).item(),
        # Change magnitude statistics
        "mean_change_per_bit": torch.mean(
            valid_changes, dim=1
        ),  # Average change for each bit
        "max_change_per_bit": torch.max(valid_changes, dim=1)[
            0
        ],  # Maximum change for each bit
        "std_change_per_bit": torch.std(
            valid_changes, dim=1
        ),  # Std dev of changes for each bit
        # Overall change statistics
        "total_mean_change": torch.mean(valid_changes).item(),
        "total_max_change": torch.max(valid_changes).item(),
        "total_std_change": torch.std(valid_changes).item(),
        # Distribution of changes across first and second numbers
        "first_num_changes": torch.mean(
            abs_changes, dim=(0, 2)
        ),  # Average change per first number
        "second_num_changes": torch.mean(
            abs_changes, dim=(0, 1)
        ),  # Average change per second number
    }

    return metrics


def print_change_analysis(matrices, coverage):
    """
    Print readable analysis of the changes.

    Args:
        matrices: Tensor of shape (B, S, S) as produced by
            ``create_change_matrices``.
        coverage: Tensor of shape (S, S); ``S`` is used to tell which of the
            two numbers a bit belongs to.
    """
    metrics = analyze_change_matrices(matrices, coverage)

    print("\n=== Change Analysis ===")
    print(f"Total combinations present: {metrics['total_combinations']}")
    print(f"Coverage percentage: {metrics['coverage_percentage']:.2f}%")

    print("\nOverall Change Statistics:")
    print(f"Mean change magnitude: {metrics['total_mean_change']:.4f}")
    print(f"Max change magnitude: {metrics['total_max_change']:.4f}")
    print(f"Std dev of changes: {metrics['total_std_change']:.4f}")

    # Find bits with largest changes
    mean_changes = metrics["mean_change_per_bit"]
    # Bits below this index belong to the first number (113 for the default setup).
    half = coverage.shape[0]
    # Show top 5 bits, or fewer when there are not that many.
    top_k = min(5, mean_changes.numel())
    top_bits = torch.topk(mean_changes, top_k)

    print(f"\nTop {top_k} most changed bits:")
    ranked = zip(top_bits.indices.tolist(), top_bits.values.tolist())
    for rank, (bit, change) in enumerate(ranked, start=1):
        which_number = "first" if bit < half else "second"
        relative_pos = bit % half
        print(
            f"{rank}. Bit {bit} ({which_number} number, position {relative_pos}): {change:.4f} mean change"
        )


def log_tensor_as_image(tensor, name="tensor_image", step=None):
    """
    Logs a 2D tensor as an image to wandb.

    Positive entries are drawn in the red channel and negative entries in the
    blue channel, each scaled independently to the range [128, 255]; entries
    with absolute value below 1e-6 are drawn grey. The image is resized to
    800x800, opened with ``Image.show()`` and then logged.

    Args:
        tensor: 2D torch.Tensor or numpy array
        name: Name of the image in wandb
        step: Optional step number for the logging
    """
    # Work on a numpy array from here on
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.detach().cpu().numpy()

    # Masks are taken from the raw values, before any rescaling
    signs = np.sign(tensor)
    is_positive = signs > 0
    is_negative = signs < 0
    near_zero = np.abs(tensor) < 1e-6  # Adjust threshold as needed

    # Magnitudes of each sign, zero elsewhere
    pos_tensor = np.where(tensor > 0, tensor, 0)
    neg_tensor = np.where(tensor < 0, -tensor, 0)

    # Scale each sign independently to [128, 255]
    if pos_tensor.max() > 0:
        pos_tensor = 128 + (pos_tensor / pos_tensor.max()) * 127
    if neg_tensor.max() > 0:
        neg_tensor = 128 + (neg_tensor / neg_tensor.max()) * 127

    # Assemble the RGB image: red = positive, blue = negative
    rgb_tensor = np.zeros((*tensor.shape, 3), dtype=np.uint8)
    rgb_tensor[..., 0] = np.where(is_positive, pos_tensor, 0).astype(np.uint8)
    rgb_tensor[..., 2] = np.where(is_negative, neg_tensor, 0).astype(np.uint8)

    # Grey for (near-)zero values, all channels at 128
    rgb_tensor[near_zero] = 128

    # Nearest-neighbour resize keeps the cells sharp instead of blurring them
    new_size = (800, 800)
    image = Image.fromarray(rgb_tensor).resize(new_size, Image.NEAREST)

    image.show()

    # Log to wandb
    wandb.log({name: wandb.Image(image)}, step=step)


def is_wandb_initialized():
    """Return True if wandb can be imported and has an active run, else False."""
    try:
        import wandb
    except ImportError:
        return False
    return wandb.run is not None


@dataclass
class Config(ABC):
    """
    Experiment configuration with default hyperparameters.

    Every attribute is a dataclass field, so it can be passed to the
    constructor or ``from_dict`` and appears in ``dict()``.
    """

    adam_epsilon: float = 1e-25
    alpha: float = 1.0
    batch_size: int = 128
    beta2: float = 0.999
    binary_operation: str = "add_mod"
    dataset: str = "addition"
    input_size: int = 113
    full_batch: bool = True
    log_frequency: int = 5000
    # This field used to be declared twice ("stablemax", then
    # "cross_entropy"); the later value was the effective default.
    loss_function: str = "cross_entropy"
    lr: float = 0.01
    modulo: int = 113
    num_epochs: int = 80000
    optimizer: str = "AdamW"
    seed: int = 42
    softmax_precision: int = 32
    train_fraction: float = 0.3
    train_precision: int = 32
    use_transformer: bool = False
    weight_decay: float = 0.0
    # The three fields below used to be un-annotated class attributes, which
    # dataclasses ignore. They are placed last so the positional order of the
    # fields above is unchanged.
    # default_factory gives each instance its own list.
    hidden_sizes: List[int] = field(default_factory=lambda: [200, 200])
    orthogonal_gradients: bool = False
    regularization: Optional[str] = None

    @classmethod
    def from_dict(cls, data):
        """Build a config from a mapping of field names to values."""
        return cls(**data)

    def dict(self):
        """Return all fields as a dict with every value converted to ``str``."""
        return {k: str(v) for k, v in asdict(self).items()}
