import os
import random
from sklearn.metrics import (
    precision_recall_curve,
    auc,
    average_precision_score,
    classification_report,
)
import numpy as np
import matplotlib.pyplot as plt
import torch
import numpy as np
from collections import defaultdict
import sys
import os
import json

# extracted concepts
import numpy as np
from vlol.convert import (
    create_train_question,
    create_train_label,
    create_train_question_cot,
)
from countries import tri_color_countries_subset
from vlol.genTrain import sample_trains
import pandas as pd
import warnings

# Per-category decision thresholds, keyed by short category code.
# load_safety_corpus() binarizes each moderation score column with
# `score >= threshold` using these values.
# NOTE(review): how these thresholds were chosen is not shown in this file.
openai_ds_thresholds = {
    "H": 0.101010,
    "S": 0.070707,
    "HR": 0.050505,
    "H2": 0.171717,
    "V": 0.565657,
    "V2": 0.030303,
}

# Short category code -> category name as spelled by the OpenAI moderation output.
openAI_key_mapping = {
    "H": "hate",
    "S": "sexual",
    "HR": "harassment",
    "H2": "harassment/threatening",
    "V": "violence",
    "V2": "violence/graphic",
}
# "<model>-<category>" score key (as built in load_safety_data) -> short category
# code. Used by load_safety_corpus() to rename the score columns.
openai_mod_key_mapping = {
    "openai_mod-hate": "H",
    "openai_mod-sexual": "S",
    "openai_mod-harassment": "HR",
    "openai_mod-harassment/threatening": "H2",
    "openai_mod-violence": "V",
    "openai_mod-violence/graphic": "V2",
}
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably label codes meant to be dropped from the ground-truth data — confirm.
labels_to_remove = ["S3", "SH"]


def load_openai_gt_data(test_samples=100, balance_class_dist=True):
    """Load the OpenAI moderation ground-truth samples and split them.

    Reads ``data/openai/samples-1680.jsonl`` (one JSON object per line with a
    ``"prompt"`` key plus label keys), shuffles the samples, builds a
    multi-label tensor (missing label keys are filled with 0) and appends an
    ``"Unsafe"`` column that is 1 whenever any label of the sample is positive.

    This function does NOT seed any random generator itself; it draws from the
    global ``random`` and ``numpy.random`` state, so callers that need
    reproducible splits must seed both beforehand.

    Args:
        test_samples: Target number of samples in the test split.
        balance_class_dist: If True, the test split is seeded with up to
            ``test_samples // (2 * num_classes)`` positive and negative samples
            per class, then topped up with random samples until it reaches
            ``test_samples``. If False, the test split is a plain random sample.

    Returns:
        dict with keys ``x_all``/``y_all`` (all prompts / label tensor of shape
        [samples, num_classes]), ``x_train``/``y_train``, ``x_test``/``y_test``,
        ``x_test_by_class``/``y_test_by_class`` (test samples that are positive
        for each class name) and ``y_categories`` (column names, last one is
        ``"Unsafe"``).
    """
    ds = "data/openai/samples-1680.jsonl"
    prompts, labels = [], []
    with open(ds) as f:
        for line in f:
            record = json.loads(line)
            # Everything except the prompt is treated as a label entry.
            prompts.append(record.pop("prompt"))
            labels.append(record)

    # Shuffle prompts and labels together.
    combined = list(zip(prompts, labels))
    random.shuffle(combined)
    prompts[:], labels[:] = zip(*combined)

    # Fill missing label keys with 0 so every sample has every column.
    all_keys = set().union(*labels)
    labels_filled = [{k: d.get(k, 0) for k in all_keys} for d in labels]
    df = pd.DataFrame(labels_filled)
    labels_tensor = torch.tensor(
        df.values, dtype=torch.float32
    )  # shape: [num_samples, num_classes]

    # Add the aggregated "Unsafe" label as the last column.
    unsafe_label = (
        (labels_tensor.sum(dim=1) > 0).float().unsqueeze(1)
    )  # shape: [num_samples, 1]
    labels_tensor = torch.cat(
        [labels_tensor, unsafe_label], dim=1
    )  # shape: [num_samples, num_classes + 1]
    label_names = df.columns.tolist() + ["Unsafe"]

    num_samples = len(prompts)
    num_classes = labels_tensor.shape[1]

    test_indices = []
    if balance_class_dist:
        # Divide by 2 to allocate room for both positives and negatives.
        test_samples_per_class = test_samples // (num_classes * 2)
        print(f"Num test samples per class: {test_samples_per_class}")
        for class_idx in range(num_classes):
            column = labels_tensor[:, class_idx]
            # as_tuple=True always yields a 1-D index tensor, so .tolist() is a
            # list even when zero or one sample matches.
            class_positive = torch.nonzero(column == 1, as_tuple=True)[0].tolist()
            class_negative = torch.nonzero(column == 0, as_tuple=True)[0].tolist()

            np.random.shuffle(class_positive)
            np.random.shuffle(class_negative)

            # Take up to test_samples_per_class of each polarity.
            test_indices.extend(class_positive[:test_samples_per_class])
            test_indices.extend(class_negative[:test_samples_per_class])

        # A sample can be picked for several classes; keep each index once.
        test_indices = list(set(test_indices))
        if len(test_indices) < test_samples:
            # Top up with random samples that are not in the test split yet.
            taken = set(test_indices)
            pool = [i for i in range(num_samples) if i not in taken]
            test_indices.extend(
                random.sample(pool, test_samples - len(test_indices))
            )
    else:
        # Not balancing: plain random sample of the requested size.
        test_indices = random.sample(range(num_samples), test_samples)

    # All other indices go to training (set lookup keeps this linear).
    test_index_set = set(test_indices)
    train_indices = [i for i in range(num_samples) if i not in test_index_set]

    x_train = [prompts[i] for i in train_indices]
    y_train = labels_tensor[train_indices]

    x_test = [prompts[i] for i in test_indices]
    y_test = labels_tensor[test_indices]

    # Per-class test subsets: test samples that are positive for the class.
    x_test_by_class = {}
    y_test_by_class = {}
    for class_idx, class_name in enumerate(label_names):
        class_test_indices = [
            i
            for i, idx in enumerate(test_indices)
            if labels_tensor[idx, class_idx] == 1
        ]
        x_test_by_class[class_name] = [x_test[i] for i in class_test_indices]
        y_test_by_class[class_name] = y_test[class_test_indices]

    # Report the class distribution of both splits.
    train_class_counts = y_train.sum(dim=0).tolist()
    train_class_distribution = [
        f"{label_names[i]}: {count}" for i, count in enumerate(train_class_counts)
    ]
    test_class_counts = y_test.sum(dim=0).tolist()
    test_class_distribution = [
        f"{label_names[i]}: {count}" for i, count in enumerate(test_class_counts)
    ]
    print("--" * 20 + "Dataset: OpenAI Mod" + "--" * 20)
    print(f"-> {len(prompts)} samples and {len(label_names)} classes")
    print(
        f"Train set: {len(x_train)} samples with class distribution: {', '.join(train_class_distribution)}"
    )
    print(
        f"Test set: {len(x_test)} samples with class distribution: {', '.join(test_class_distribution)}"
    )

    return {
        "x_all": prompts,  # shape [samples]
        "y_all": labels_tensor,  # shape [samples, num_classes]
        "x_train": x_train,  # shape [train_samples]
        "y_train": y_train,  # shape [train_samples, num_classes]
        "x_test": x_test,  # shape [test_samples]
        "y_test": y_test,  # shape [test_samples, num_classes]
        "x_test_by_class": x_test_by_class,
        "y_test_by_class": y_test_by_class,
        "y_categories": label_names,  # shape [num_classes]
    }


def load_safety_corpus(
    dataset_names=["toxicchat", "dro", "xstest", "overkill", "ours", "beavertail"],
):
    """Build a training corpus from several safety datasets plus the OpenAI test data.

    Training targets are the cached ``openai_mod`` model scores of every
    dataset in ``dataset_names``, binarized with ``openai_ds_thresholds`` and
    extended by an ``"Unsafe"`` column (1 if any category fired). The OpenAI
    ground-truth data from ``load_openai_gt_data()`` is used as test data, with
    its label columns re-ordered to match the training categories if needed.

    Seeds ``random``, ``numpy.random`` and ``torch`` with 42 so that the
    ground-truth shuffle/split is reproducible.

    Args:
        dataset_names: Names of the datasets to load via ``load_safety_data``.
            ``"openai_mod"`` is dropped (with a warning) because it is the test
            data. The passed list is not modified.

    Returns:
        dict with two entries: ``"combined"`` (``x_train``/``y_train`` from the
        thresholded model scores, ``x_test``/``y_test`` from all ground-truth
        samples, ``x_test_subset``/``y_test_subset`` from the ground-truth test
        split, and ``categories``) and ``"openai"`` (the unmodified
        ground-truth splits and their category names).

    Raises:
        ValueError: If a thresholded category is missing from the score
            categories, or if scores are not binary after thresholding.
    """
    # The ground-truth loader shuffles with the `random` module, so it has to be
    # seeded as well; seeding only numpy/torch left the split non-deterministic.
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    # Filter into a new list so neither the caller's list nor the shared
    # default argument is mutated.
    requested = list(dataset_names)
    dataset_names = [name for name in requested if name != "openai_mod"]
    if len(dataset_names) != len(requested):
        warnings.warn(
            "Dataset 'openai_mod' is in train and test splits. Removing from list."
        )

    gt_data = load_openai_gt_data()
    datasets = [
        load_safety_data(ds=ds, model_list=["openai_mod"]) for ds in dataset_names
    ]
    print("=" * 20 + " Final concated Dataset" + "=" * 20)

    # Concatenate all datasets.
    x_all = []
    for ds in datasets:
        x_all.extend(ds["x_all"])
    y_all = torch.cat([ds["y_all"] for ds in datasets], dim=0)  # shape [samples]
    model_scores = torch.cat(
        [ds["model_scores"] for ds in datasets], dim=0
    )  # shape [samples, categories]
    # Translate "openai_mod-<category>" keys to the short category codes.
    categories = [
        openai_mod_key_mapping.get(cat, cat)
        for cat in datasets[0]["model_score_categories"]
    ]

    # Binarize each score column with its category threshold.
    for cat, threshold in openai_ds_thresholds.items():
        i = categories.index(cat)
        model_scores[:, i] = (model_scores[:, i] >= threshold).float()

    # Every column must be binary now; a column without a threshold would not be.
    if not torch.all(torch.logical_or(model_scores == 0, model_scores == 1)):
        raise ValueError("model_scores must be binary after thresholding")

    # Add the aggregated "Unsafe" column.
    unsafe_score = (
        (model_scores.sum(dim=1) > 0).float().unsqueeze(1)
    )  # shape: [num_samples, 1]
    model_scores = torch.cat(
        [model_scores, unsafe_score], dim=1
    )  # shape: [num_samples, num_classes + 1]
    categories = categories + ["Unsafe"]

    print("datasets:", list(dataset_names))
    print(f"safety categories: {categories}")
    print(f"x_all: {len(x_all)}, model_scores: {model_scores.shape}")
    samples_per_cat = model_scores.sum(dim=0).tolist()
    print("Samples per category:")
    for i, cat in enumerate(categories):
        print(f"  {cat}: {samples_per_cat[i]}")

    if categories != gt_data["y_categories"]:
        # Select/re-order the ground-truth columns by name so they line up with
        # the training categories (ground-truth-only columns are dropped).
        gt_col = {cat: i for i, cat in enumerate(gt_data["y_categories"])}
        y_test = torch.stack(
            [gt_data["y_all"][:, gt_col[cat]] for cat in categories], dim=1
        )
        y_test_subset = torch.stack(
            [gt_data["y_test"][:, gt_col[cat]] for cat in categories], dim=1
        )
    else:
        y_test = gt_data["y_all"]
        y_test_subset = gt_data["y_test"]

    return {
        "combined": {
            "x_train": x_all,
            "y_train": model_scores,
            "x_test": gt_data["x_all"],
            "y_test": y_test,
            "x_test_subset": gt_data["x_test"],
            "y_test_subset": y_test_subset,
            "categories": categories,
        },
        "openai": {
            "x_all": gt_data["x_all"],
            "y_all": gt_data["y_all"],
            "x_train": gt_data["x_train"],
            "y_train": gt_data["y_train"],
            "x_test": gt_data["x_test"],
            "y_test": gt_data["y_test"],
            "categories": gt_data["y_categories"],
        },
    }


def load_safety_data(
    ds="openaimod",
    model_list=[
        "llamaguard",
        "llamaguard2",
        "llamaguard3",
        "openai_mod",
        "toxicchat-T5",
    ],
):
    """Load one safety dataset together with cached per-model category scores.

    Temporarily switches the working directory to ``./safety/r2guard`` so that
    the r2guard helper modules can be imported and can resolve their own
    relative paths; the original working directory is always restored, even
    when loading fails.

    Args:
        ds: Dataset name, used in the cache file name and passed to r2guard's
            ``load_data``.
        model_list: Model names whose cached scores
            (``<model>_<ds>_scores.json``) are loaded. Not modified.

    Returns:
        dict with ``x_all`` (instances from ``load_data``), ``y_all`` (labels
        as a tensor), ``model_scores`` (score tensor, transposed so that
        samples come first) and ``model_score_categories``
        (``"<model>-<category>"`` names in column order).
    """
    print(f"--" * 20 + f"Dataset: {ds}" + "--" * 20)
    # Make the r2guard modules importable.
    r2guard_path = "./safety/r2guard"
    current_dir = os.getcwd()
    if r2guard_path not in sys.path:
        sys.path.append(r2guard_path)

    os.chdir(r2guard_path)
    try:
        print(f"Working directory set to: {os.getcwd()}")
        from utils import load_field_name  # noqa: F401  (kept: original import)
        from data_loading import load_data

        scores_all_models = []
        safety_cat_list = []
        for model in model_list:
            # NOTE(review): this relative path is resolved AFTER the chdir
            # above, i.e. relative to the r2guard directory — confirm that
            # this is the intended cache location.
            score_path = f"{r2guard_path}/cache/{model}_{ds}_scores.json"
            with open(score_path, "r") as file:
                loaded_dict = json.load(file)
            scores_all_models.append(list(loaded_dict.values()))
            # Prefix each category with the model name to keep columns unique.
            safety_cat_list.extend([f"{model}-{k}" for k in loaded_dict.keys()])

        scores_all_models = (
            torch.tensor(scores_all_models).squeeze().T
        )  # shape: [num_samples, concepts]
        instances, labels = load_data(ds, None)
        labels = torch.tensor(labels)
    finally:
        # Always restore the caller's working directory.
        os.chdir(current_dir)
        print(f"Working directory reset to: {os.getcwd()}")

    return {
        "x_all": instances,
        "y_all": labels,
        "model_scores": scores_all_models,
        "model_score_categories": safety_cat_list,
    }


def train_test_split(data, labels, test_size=50):
    """Shuffle the data and carve out a class-balanced test set.

    The test set contains ``test_size // 2`` positive (label 1) and
    ``test_size - test_size // 2`` negative (label 0) samples; every remaining
    positive/negative sample goes to the training set. Shuffling uses the
    global ``numpy.random`` state.

    Args:
        data: Sequence of samples, indexable by integer position.
        labels: 1-D tensor of 0/1 labels aligned with ``data``.
        test_size: Total number of test samples.

    Returns:
        dict with ``x_train``/``y_train``, ``x_test``/``y_test`` and the
        unsafe (positive) / safe (negative) parts of the test set as
        ``x_test_unsafe_subset``/``y_test_unsafe_subset`` and
        ``x_test_safe_subset``/``y_test_safe_subset``.

    Raises:
        ValueError: If ``test_size`` exceeds the number of positive or of
            negative samples.
    """
    # Shuffle data before splitting.
    indices = np.random.permutation(len(data))
    data = [data[i] for i in indices]
    labels = labels[indices]

    # as_tuple=True always gives a 1-D index tensor, so .tolist() stays a list
    # even when exactly one sample matches (squeeze() collapsed that to an int).
    pos_indices = torch.nonzero(labels == 1, as_tuple=True)[0].tolist()
    neg_indices = torch.nonzero(labels == 0, as_tuple=True)[0].tolist()
    if test_size > min(len(pos_indices), len(neg_indices)):
        raise ValueError(
            "Test size is too large for the available positive or negative samples.",
            test_size,
        )
    test_size_pos = test_size // 2
    test_size_neg = test_size - test_size_pos

    test_pos, train_pos = pos_indices[:test_size_pos], pos_indices[test_size_pos:]
    test_neg, train_neg = neg_indices[:test_size_neg], neg_indices[test_size_neg:]

    train_idx = train_pos + train_neg
    test_idx = test_pos + test_neg

    return {
        "x_train": [data[i] for i in train_idx],
        "y_train": labels[train_idx],
        "x_test": [data[i] for i in test_idx],
        "y_test": labels[test_idx],
        "x_test_unsafe_subset": [data[i] for i in test_pos],
        "y_test_unsafe_subset": labels[test_pos],
        "x_test_safe_subset": [data[i] for i in test_neg],
        "y_test_safe_subset": labels[test_neg],
    }


def load_train_data(
    train_file="output/data/country_trains/train.txt",
    eval_file="output/data/country_trains/eval.txt",
    test_file="output/data/country_trains/test.txt",
):
    """Load (and, if missing, generate) the country-train question datasets.

    If the train or the test file does not exist, both are regenerated: for
    every entry of ``tri_color_countries_subset`` a rule over the colors of
    cars 1-3 is built and ``sample_trains`` is asked for 20 trains per file.
    The eval file is never generated here and must already exist.

    Each line of a file is turned into a ``(question, label)`` pair; only
    pairs whose label is one of the known countries are kept.

    Args:
        train_file: Path of the training trains (questions include the answer).
        eval_file: Path of the evaluation trains.
        test_file: Path of the test trains; also used for the chain-of-thought
            variant of the questions.

    Returns:
        dict with lists of ``(question, label)`` pairs under ``"train"``,
        ``"eval"``, ``"test"`` and ``"test_cot"``.
    """

    def _read_lines(path):
        # Read all lines and close the file deterministically.
        with open(path, "r") as fh:
            return fh.readlines()

    countries = list(tri_color_countries_subset.values())
    train_size = 20
    test_size = 20
    if not os.path.exists(train_file) or not os.path.exists(test_file):
        for file, size in [(train_file, train_size), (test_file, test_size)]:
            # Start from a clean file; sample_trains writes to save_path.
            if os.path.exists(file):
                os.remove(file)
            for color_coding, country in tri_color_countries_subset.items():
                country = country.lower().replace(" ", "_")
                rule = f"{country}(A):-has_car(A,B), car_color(B,{color_coding[0]}), car_num(B,1), has_car(A,C), car_color(C,{color_coding[1]}), car_num(C,2), has_car(A,D), car_color(D,{color_coding[2]}), car_num(D,3)."

                generated_trains = sample_trains(
                    class_rule=rule, save_path=file, num_trains=size
                )
                print(50 * "-" + f"Generating {rule.split('(')[0]} trains" + 50 * "-")
                for train_id in range(len(generated_trains)):
                    print(
                        f"Train {train_id}: {create_train_question(generated_trains[train_id], num_predicates=9)}"
                    )
                print("-" * 150)

    train_data_all = [
        (create_train_question(line, include_answer=True), create_train_label(line))
        for line in _read_lines(train_file)
    ]
    eval_data_all = [
        (create_train_question(line, include_answer=False), create_train_label(line))
        for line in _read_lines(eval_file)
    ]

    # The test file feeds both the plain and the chain-of-thought questions;
    # read it once.
    test_lines = _read_lines(test_file)
    test_data_all = [
        (create_train_question(line, include_answer=False), create_train_label(line))
        for line in test_lines
    ]
    test_data_all_cot = [
        (
            create_train_question_cot(line, include_answer=False),
            create_train_label(line),
        )
        for line in test_lines
    ]

    # Keep only samples whose label is a known country.
    train_data = [(q, label) for q, label in train_data_all if label in countries]
    eval_data = [(q, label) for q, label in eval_data_all if label in countries]
    test_data = [(q, label) for q, label in test_data_all if label in countries]
    test_data_cot = [(q, label) for q, label in test_data_all_cot if label in countries]

    return {
        "train": train_data,
        "eval": eval_data,
        "test": test_data,
        "test_cot": test_data_cot,
    }


def load_implicit_train_data(
    train_file="output/data/country_trains/train.txt",
    test_file="output/data/country_trains/test.txt",
):
    """Load the country-train data with color words replaced by implicit descriptions.

    Every question from ``load_train_data`` has each color word listed in the
    concept table replaced by a randomly chosen description (e.g. ``"red"`` ->
    ``"like a tomato"``). One description is drawn per color and question via
    ``numpy.random.choice`` and applied to all occurrences of that color.
    The eval split comes from ``load_train_data``'s default eval file.

    Args:
        train_file: Path forwarded to ``load_train_data``.
        test_file: Path forwarded to ``load_train_data``.

    Returns:
        dict with lists of ``(question, label)`` pairs under ``"train"``,
        ``"eval"``, ``"test"`` and ``"test_cot"``.

    Raises:
        ValueError: If a question contains none of the replaceable colors.
    """
    # Curated subset of color descriptions. A broader candidate table (also
    # covering blue, green, black, purple and brown) used to be defined right
    # before this one but was immediately overwritten and never used.
    implicit_concept_dict = {
        "red": [
            "like a ruby",
            "like a tomato",
            "like a stop sign",
            "like a cherry",
            "like a strawberry",
        ],
        "yellow": ["like a banana", "like a sunflower", "like a lemon"],
        "white": [
            "like fresh snow",
        ],
        "orange": ["like a tangerine"],
    }

    data = load_train_data(train_file=train_file, test_file=test_file)

    def replace_concepts(ds):
        """Return ``ds`` with color words swapped for implicit descriptions."""
        new_ds = []
        for q, label in ds:
            q_old = q
            # NOTE(review): this is plain substring matching, so a color name
            # embedded in a longer word would be replaced as well.
            for concept, implicit_concepts in implicit_concept_dict.items():
                if concept in q:
                    implicit_concept = np.random.choice(implicit_concepts)
                    q = q.replace(concept, implicit_concept)
            # A question without any known color cannot be made implicit.
            if q_old == q:
                raise ValueError(f"No implicit concept found for question: {q}")
            new_ds.append((q, label))
        return new_ds

    return {
        "train": replace_concepts(data["train"]),
        "eval": replace_concepts(data["eval"]),
        "test": replace_concepts(data["test"]),
        "test_cot": replace_concepts(data["test_cot"]),
    }


# for concept, implicit_concepts in implicit_concept_dict.items():
#     for implicit_concept in implicit_concepts:
#         # all test data
#         _i_s = [(train[0].replace(concept, implicit_concept),train[1]) for train in test_data if concept in train[0]]
#         implicit_trains_all += _i_s
#         # large subset creation
#         indices = np.random.choice(len(_i_s), size=min(50, len(_i_s)), replace=False)
#         implicit_trains_large += [_i_s[i] for i in indices]


#         # small subset creation
#         ic_subset = [train for train in test_subset if concept in train[0]]
#         if len(ic_subset) > 0:
#             indices = np.random.choice(len(ic_subset), size=min(20, len(ic_subset)), replace=False)
#             ic_subset = [ic_subset[i] for i in indices]
#         implicit_trains_small += [(train[0].replace(concept, implicit_concept),train[1]) for train in ic_subset]


def separate_duplicate_indices(
    top_indices_pos,
    top_indices_neg,
    remove_negatives=True,
    mutually_exclusive_classes=True,
):
    """
    Separate indices into three clear categories: unique, cross-concept duplicates, and negative overlaps.

    Args:
        top_indices_pos (torch.Tensor): Tensor of top indices for each concept (shape: (num_concepts, top_k))
        top_indices_neg (torch.Tensor): Tensor of negative indices for each concept (shape: (num_concepts, top_k))
        remove_negatives (bool): If True, positions whose index also appears in the
            same concept's negative indices are flagged in ``negative_overlap_mask``;
            if False that mask is all False.
        mutually_exclusive_classes (bool): If True, positions whose index also appears
            in another concept's row are treated as cross-concept duplicates; if False
            no position is. Entries equal to -1 are never counted as duplicates.

    Returns:
        tuple: (unique_mask, cross_concept_mask, negative_overlap_mask)
            All masks have shape (num_concepts, top_k)
    """
    # Default: nothing overlaps, nothing is duplicated.
    negative_overlap_mask = torch.zeros_like(top_indices_pos, dtype=torch.bool)
    cross_concept_duplicate_mask = torch.zeros_like(top_indices_pos, dtype=torch.bool)
    num_concepts = top_indices_pos.shape[0]

    if remove_negatives and num_concepts > 0:
        # Step 1: per concept, flag positive indices that also occur in that
        # concept's negative indices.
        negative_overlap_mask = torch.stack(
            [
                torch.isin(top_indices_pos[i], top_indices_neg[i])
                for i in range(num_concepts)
            ],
            dim=0,
        )  # shape: (num_concepts, top_k)

    # With a single concept there are no other rows to collide with.
    if mutually_exclusive_classes and num_concepts > 1:
        # Step 2: per concept, flag indices that occur in any other concept's row.
        duplicate_rows = []
        for i in range(num_concepts):
            current_row = top_indices_pos[i]
            other_rows = torch.cat(
                [top_indices_pos[:i], top_indices_pos[i + 1 :]], dim=0
            )
            # -1 marks filtered-out entries; those never count as duplicates.
            duplicate_rows.append(
                torch.isin(current_row, other_rows) & (current_row != -1)
            )
        cross_concept_duplicate_mask = torch.stack(
            duplicate_rows, dim=0
        )  # shape: (num_concepts, top_k)

    # Step 3: derive the final masks.
    # Unique: not in negatives AND not shared across concepts
    unique_mask = ~negative_overlap_mask & ~cross_concept_duplicate_mask

    # Cross-concept shared: not in negatives BUT shared across concepts
    cross_concept_mask = ~negative_overlap_mask & cross_concept_duplicate_mask

    return unique_mask, cross_concept_mask, negative_overlap_mask


def intersect_indices_for_colors(activated_indices: dict):
    """
    Compute, per color, the indices shared by every occurrence of that color.

    activated_indices: maps each color to a tensor of shape (n, top_indices),
        one row of indices per occurrence of the color in the inputs.

    Returns a defaultdict mapping color -> numpy array of the indices present
    in all rows (a single row is returned as-is). Colors that are missing, or
    that have zero rows, yield an empty int64 array on lookup.
    """
    shared = defaultdict(lambda: np.array([], dtype=np.int64))
    for color, batch in activated_indices.items():
        num_rows = batch.shape[0]
        if num_rows == 0:
            # Nothing to intersect; leave the color to the default factory.
            continue
        # Seed with the first row, then narrow down with every further row.
        running = batch[0].numpy()
        for row in range(1, num_rows):
            running = np.intersect1d(running, batch[row].numpy())
        shared[color] = running
    return shared


# ANSI escape sequences for colored terminal output.
# Foreground (text) colors: the 9x codes are the bright variants, the
# "38;5;N" codes select entry N of the 256-color palette.
COLORS = {
    "RED": "\033[91m",
    "GREEN": "\033[92m",
    "YELLOW": "\033[93m",
    "BLUE": "\033[94m",
    "MAGENTA": "\033[95m",
    "CYAN": "\033[96m",
    "WHITE": "\033[97m",
    "ORANGE": "\033[38;5;208m",
    "PURPLE": "\033[38;5;165m",
    "LIME": "\033[38;5;118m",
}


# Background colors (4x codes).
BACKGROUND_COLORS = {
    "RED": "\033[41m",
    "GREEN": "\033[42m",
    "YELLOW": "\033[43m",
    "BLUE": "\033[44m",
    "MAGENTA": "\033[45m",
}


# Text attributes; RESET clears every color/attribute set before it.
BOLD = "\033[1m"
RESET = "\033[0m"
UNDERLINE = "\033[4m"


def get_weight_vector(
    dim_size, steering_weighting_function="uniform", mean=1.0, std=None
):
    """Build a 1-D weight vector of length ``dim_size`` that decays with position.

    Args:
        dim_size: Number of positions (length of the returned vector).
        steering_weighting_function: Name of the decay profile: "uniform",
            "linear_decay", "exponential_decay", "inverse_position",
            "softmax_based", "sigmoid_decay", "power_law_decay",
            "cosine_decay" or "log_decay".
        mean: Target mean of the weights (before clipping). For "uniform"
            every weight simply equals ``mean``.
        std: Optional target standard deviation; when given and the profile is
            not constant, the weights are stretched around ``mean`` to match it.

    Returns:
        Float tensor of shape (dim_size,) with non-negative weights.

    Raises:
        ValueError: If the weighting function name is unknown.
    """
    positions = torch.arange(dim_size)  # Shape: (dim_size)

    # Uniform weights need neither rescaling nor clipping.
    if steering_weighting_function == "uniform":
        return torch.ones_like(positions, dtype=torch.float) * mean

    pos = positions.float()
    # Lazily evaluated decay profiles; only the requested one is computed.
    profiles = {
        "linear_decay": lambda: 1 - (pos / max(dim_size - 1, 1)),
        "exponential_decay": lambda: torch.exp(-pos / 1.0),
        "inverse_position": lambda: 1.0 / (pos + 1.0),
        "softmax_based": lambda: torch.softmax(-pos / 2.0, dim=0),
        "sigmoid_decay": lambda: 1.0
        / (1.0 + torch.exp((pos - dim_size / 2.0) * 2.0)),
        "power_law_decay": lambda: 1.0 / ((pos + 1.0) ** 2.0),
        "cosine_decay": lambda: torch.cos(pos * torch.pi / (2.0 * dim_size)),
        "log_decay": lambda: 1.0
        - (torch.log(pos + 1.0) / torch.log(torch.tensor(dim_size + 1.0))),
    }
    if steering_weighting_function not in profiles:
        raise ValueError(f"Invalid weighting function: {steering_weighting_function}")
    weights = profiles[steering_weighting_function]()

    # Bring the profile to the requested mean.
    weights = weights / weights.mean() * mean

    if std is not None:
        # Stretch around the mean to hit the requested std; a constant
        # profile (std of 0) is left alone.
        spread = weights.std()
        if spread > 0:
            weights = mean + (weights - mean) * (std / spread)

    # Stretching can push weights below zero; clip them.
    return torch.clamp(weights, min=0.0)


def weighting(steering_direction, steering_weighting_function="uniform"):
    """
    Generate a weight for each of the top-k steering vectors of every rule.

    With "activations" the weight of a vector is its L2 norm divided by the
    sum of the norms within the same rule. With any other function the weights
    come from ``get_weight_vector`` (a position-based profile with mean
    ``1 / steering_top_k_rule``). In both cases rules whose vectors are all
    zero get all-zero weights.

    Args:
        steering_direction (torch.Tensor): The steering direction tensor (Shape: (batch, num_rules, steering_top_k_rule, llm_hidden_dim))
        steering_weighting_function (str): "activations" or any weighting
            function name accepted by ``get_weight_vector``.
    Returns:
        torch.Tensor: Weights on the same device as ``steering_direction`` (Shape: (batch, num_rules, steering_top_k_rule, 1))
    """
    dim_size = steering_direction.shape[2]

    activation_strength = steering_direction.norm(
        p=2, dim=-1, keepdim=True
    )  # Shape: (batch, num_rules, steering_top_k_rule, 1)
    # Total activation strength per rule.
    rule_sums = activation_strength.sum(
        dim=-2, keepdim=True
    )  # Shape: (batch, num_rules, 1, 1)
    # Rules with a non-zero sum; the others must not receive any weight.
    mask = (rule_sums > 0).expand_as(
        activation_strength
    )  # Shape: (batch, num_rules, steering_top_k_rule, 1)

    if steering_weighting_function == "activations":
        # Normalize only where the sum is non-zero; the clamp keeps the
        # division finite for the masked-out rules.
        return torch.where(
            mask,
            activation_strength / rule_sums.clamp(min=1e-8),
            torch.zeros_like(activation_strength),
        )  # Shape: (batch, num_rules, steering_top_k_rule, 1)

    # Position-based weights, scaled to mean 1 / dim_size.
    weights = get_weight_vector(
        dim_size, steering_weighting_function, mean=1 / dim_size
    )
    # Follow the input's device instead of assuming CUDA, so this also works
    # on CPU and on a non-default GPU.
    weights = weights.to(steering_direction.device)
    weights = weights.view(1, 1, dim_size, 1)  # Shape: (1, 1, dim_size, 1)
    # Broadcast over batch and num_rules.
    weights = weights.expand(
        steering_direction.shape[0], steering_direction.shape[1], dim_size, 1
    )  # Shape: (batch, num_rules, dim_size, 1)
    # Zero out rules without any activation.
    weights = torch.where(
        mask, weights, torch.zeros_like(weights)
    )  # Shape: (batch, num_rules, dim_size, 1)

    return weights


def safety_plot(y_true, meta):
    """Plot the precision-recall curve of safety scores and print summary metrics.

    The reported "best" threshold is the precision-recall threshold that
    maximizes ACCURACY (not F1); F1, precision and recall are reported at that
    same threshold. Shows the plot with ``plt.show()`` and prints AUPRC,
    average precision, the confusion counts and a classification report.

    Args:
        y_true: Iterable of 0/1 ground-truth labels; each element must provide
            ``.item()`` (e.g. tensor scalars).
        meta: Iterable of dicts, each with a numeric ``"score"`` entry aligned
            with ``y_true``.
    """
    y_true = np.array([a.item() for a in y_true]).reshape(-1)
    y_scores = np.array([m["score"] for m in meta])

    # precision/recall have one more entry than thresholds.
    precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
    # Accuracy obtained at every candidate threshold.
    all_acc = [
        (y_true == (y_scores >= threshold).astype(int)).mean()
        for threshold in thresholds
    ]
    auprc = auc(recall, precision)

    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-10)
    # Pick the threshold with the highest accuracy; the index is always valid
    # for `thresholds` because all_acc has exactly one entry per threshold.
    best_threshold_idx = np.argmax(all_acc)
    best_threshold = thresholds[best_threshold_idx]
    best_f1 = f1_scores[best_threshold_idx]
    best_precision = precision[best_threshold_idx]
    best_recall = recall[best_threshold_idx]

    # Binary predictions and confusion counts at the best threshold.
    y_pred = (y_scores >= best_threshold).astype(int)
    best_accuracy = (y_true == y_pred).mean()
    is_pos = y_true == 1
    is_neg = y_true == 0
    above = y_scores >= best_threshold
    below = y_scores < best_threshold
    TP = int(np.sum(is_pos & above))
    TN = int(np.sum(is_neg & below))
    FP = int(np.sum(is_neg & above))
    FN = int(np.sum(is_pos & below))

    # Plot the precision-recall curve and mark the selected operating point.
    plt.figure(figsize=(10, 6))
    plt.plot(recall, precision, label=f"AUPRC = {auprc:.3f}")
    plt.scatter(
        recall[best_threshold_idx],
        precision[best_threshold_idx],
        color="red",
        marker="o",
        label=f"Best threshold = {best_threshold:.3f}",
    )
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall Curve")
    plt.legend()
    plt.grid(True)
    plt.show()

    # Print metrics
    print(
        f"AUPRC: {auprc:.3f}, Average Precision Score: {average_precision_score(y_true, y_scores):.3f}"
    )
    print(
        f"Best Threshold: {best_threshold:.3f} (Accuracy: {best_accuracy:.3f}, F1: {best_f1:.3f}, Precision: {best_precision:.3f}, Recall: {best_recall:.3f})"
    )
    print(f"TP: {TP}, TN: {TN}, FP: {FP}, FN: {FN}")
    print("\nClassification Report at Best Threshold:")
    print(classification_report(y_true, y_pred))


def safe_load_tensor(path: str) -> torch.Tensor:
    """Load a saved activation tensor from ``path`` and return it densely.

    Sparse tensors are converted to dense for compatibility; dense tensors are
    returned as loaded.

    Raises:
        FileNotFoundError: If nothing exists at ``path``.
    """
    if os.path.exists(path):
        loaded = torch.load(os.path.join(path))
        # Callers expect a dense tensor regardless of how it was stored.
        return loaded.to_dense() if loaded.is_sparse else loaded

    raise FileNotFoundError(
        f"SAE latents not found at {path}. Please run find_concept_indicies() to extract them."
    )


def remove_duplicates_and_preserve_order(input_list):
    """
    Return the distinct elements of ``input_list`` in first-seen order.

    Elements are compared by hash and equality, so they must be hashable.
    Only the first occurrence of each element is kept; later repeats are
    dropped wherever they appear.

    Args:
        input_list: The list from which to remove duplicates.

    Returns:
        A new list with duplicates removed, preserving the original order.
    """
    # dict keys are unique and keep insertion order, which gives exactly the
    # "first occurrence wins" behaviour in a single pass.
    first_occurrences = dict.fromkeys(input_list)
    return list(first_occurrences)


def remove_consecutive_duplicates(input_list):
    """
    Collapse runs of repeated elements and drop trailing answer markers.

    The first element is always kept. Every later element is appended only
    if it differs from the most recently kept element AND is not one of the
    literal strings "Answer: True", "Answer: False" or "Answer: Uncertain".
    Because the comparison is against the last *kept* element, two equal
    items separated only by a dropped answer marker also collapse into one.

    Args:
        input_list: The list from which to remove consecutive duplicates.

    Returns:
        A new list with consecutive duplicates (and non-leading answer
        markers) removed. An empty input yields an empty list.
    """
    if not input_list:
        return []

    answer_markers = ("Answer: True", "Answer: False", "Answer: Uncertain")

    # The head can never duplicate a predecessor, so it seeds the output
    # unconditionally (even if it is itself an answer marker).
    kept = [input_list[0]]
    for candidate in input_list[1:]:
        if candidate in answer_markers:
            continue
        if candidate == kept[-1]:
            continue
        kept.append(candidate)

    return kept


def replace_or_with_xor(tokens):
    """
    Rewrite "or ... but not both" constructions as "xor".

    Works on a copy of ``tokens``. Each "but not both" token is handled from
    the last occurrence to the first: the closest "or" token to its left
    (if any) becomes "xor", and the "but not both" token itself is removed.
    An "or" already turned into "xor" is no longer a match, so a later
    marker processed afterwards reaches further left for the next "or".

    Args:
        tokens: A list of strings.

    Returns:
        A new list with the appropriate "or" tokens replaced by "xor" and
        "but not both" tokens removed.
    """
    rewritten = list(tokens)

    marker_positions = [
        pos for pos, tok in enumerate(rewritten) if tok == "but not both"
    ]

    # Walk right-to-left so that popping a marker never shifts the positions
    # of the markers still waiting to be processed.
    for pos in reversed(marker_positions):
        nearest_or = next(
            (left for left in range(pos - 1, -1, -1) if rewritten[left] == "or"),
            None,
        )
        if nearest_or is not None:
            rewritten[nearest_or] = "xor"
        rewritten.pop(pos)

    return rewritten


def implicit_is(tokens: list) -> list:
    """Insert an implicit "is" after a capitalised leading token.

    If the sequence has at least two tokens, the first token starts with an
    uppercase character, and the second token is not already one of "is",
    "do" or "have", a new list is returned with "is" inserted between the
    first token and the rest. Otherwise the input is returned unchanged
    (the same object, not a copy).

    An empty first token is treated as "not capitalised" and leaves the
    input unchanged instead of raising ``IndexError``.

    Args:
        tokens: A list of string tokens.

    Returns:
        Either a new list with "is" inserted at position 1, or ``tokens``
        itself when no insertion applies.
    """
    if len(tokens) < 2:
        return tokens

    subject, follower = tokens[0], tokens[1]
    # Slicing (rather than indexing) keeps an empty first token from raising;
    # "".isupper() is False, so such input simply falls through unchanged.
    if subject[:1].isupper() and follower not in {"is", "do", "have"}:
        return [subject, "is"] + tokens[1:]
    return tokens
