import gc
import random
import re
import time
from contextlib import contextmanager
from typing import Iterator

import torch
import wandb


def template_prefix_similarity_check(template, prefixes):
    """Validate that every prefix is shaped like the payload template.

    The template is split around its ``{payload}`` placeholder; each prefix
    (whitespace-stripped) must start with the stripped text before the
    placeholder and end with the stripped text after it.

    Args:
        template: Template string containing one ``{payload}`` placeholder,
            or ``None`` to skip the check.
        prefixes: Collection of prefix strings to validate. An empty or
            ``None`` value skips the check.

    Returns:
        ``True`` when the check is skipped or every prefix matches.

    Raises:
        ValueError: If the template does not contain exactly one
            ``{payload}`` placeholder, or a prefix does not match.
    """
    # Nothing to validate without both a template and at least one prefix.
    if template is None or not prefixes:
        return True

    pieces = template.split("{payload}")
    if len(pieces) != 2:
        raise ValueError(
            "Payload template should contain exactly one {payload} placeholder."
        )

    # Surrounding whitespace is ignored on both the template and the prefixes.
    head = pieces[0].strip()
    tail = pieces[1].strip()

    for prefix in prefixes:
        candidate = prefix.strip()
        if candidate.startswith(head) and candidate.endswith(tail):
            continue

        # Show the expected start/end fragments before failing.
        print((head, tail))

        raise ValueError(
            f"Prefix '{prefix}' does not match the template structure."
        )

    return True


# A generator-based context manager must be annotated as returning an
# iterator/generator (not ``None``), which is what the type checker asked for.
@contextmanager
def timing(description: str) -> Iterator[None]:
    """Time the body of a ``with`` block and log the duration to wandb.

    Usage::

        with timing("rollout"):
            ...

    The elapsed seconds are logged under the key ``"timing/" + description``
    with ``commit=False``. If the body raises, the exception propagates and
    nothing is logged.

    Args:
        description: Suffix appended to ``"timing/"`` to form the log key.
    """
    # perf_counter is monotonic, so the interval is unaffected by wall-clock
    # adjustments that can occur while the block runs.
    start = time.perf_counter()
    yield
    elapsed = time.perf_counter() - start
    wandb.log({"timing/" + description: elapsed}, commit=False)


def clear_memory(device=None):
    """Run the garbage collector and release cached CUDA memory if relevant.

    Args:
        device: Device the caller is working on, as a string (``"cuda"``,
            ``"cuda:0"``, ``"cpu"``, ...) or a ``torch.device``. When
            omitted (``None``), the CUDA cache is emptied whenever CUDA is
            available.
    """
    gc.collect()
    # Compare on the string form so "cuda", "cuda:1" and torch.device("cuda")
    # are all recognised as CUDA devices.
    wants_cuda = device is None or str(device).startswith("cuda")
    if wants_cuda and torch.cuda.is_available():
        torch.cuda.empty_cache()


def delete_models(model, ref_model, stego_classifier, ppo_trainer):
    """Drop this function's references to the given objects and free memory.

    Note that ``del`` only removes the local names bound inside this
    function. The underlying objects are released only once the caller has
    also dropped its own references to them.

    Args:
        model: Object to release.
        ref_model: Object to release.
        stego_classifier: Object to release.
        ppo_trainer: Object to release.
    """
    del model
    del ref_model
    del stego_classifier
    del ppo_trainer
    # clear_memory needs an explicit device; pick CUDA when it is available
    # so the cached GPU memory is released as well.
    clear_memory("cuda" if torch.cuda.is_available() else "cpu")


# Current date and time (UTC) as a string
def get_datetime():
    """Return the current UTC time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
    utc_now = time.gmtime()
    return time.strftime("%Y-%m-%d_%H-%M-%S", utc_now)


def print_number_of_trainable_model_parameters(model):
    """Build a summary of trainable versus total parameter counts.

    Despite the name, nothing is printed: the summary is returned as a
    string for the caller to print or log.

    Args:
        model: Object exposing ``named_parameters()`` whose parameters have
            ``numel()`` and ``requires_grad``.

    Returns:
        A multi-line string with the trainable count, the total count and
        the trainable percentage (reported as 0.00% when the model has no
        parameters).
    """
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_model_params += count
        if param.requires_grad:
            trainable_model_params += count

    # A model without parameters would otherwise divide by zero.
    percentage = (
        100 * trainable_model_params / all_model_params if all_model_params else 0.0
    )
    return (
        f"\ntrainable model parameters: {trainable_model_params}"
        f"\nall model parameters: {all_model_params}"
        f"\npercentage of trainable model parameters: {percentage:.2f}%"
    )


# def contains_suspicious_terms(
#     review: str,
#     label,
#     hparams: dict,
# ):
#     payload_prefixes = hparams["dataset"]["payload_prefixes"]
#     payloads = hparams["dataset"]["payloads"]
#     template = hparams["dataset"]["payload_template"]
#     review = review.lower()
#     special_chars = [".", ",", "!", "?", ";", ":", " ", ")", "]", "}"]
#     prefix = payload_prefixes[label.item()]

#     # Check if the prefix matches
#     if prefix.lower() in review:  # Full match
#         return True

#     for char in special_chars:
#         if prefix.lower().endswith(char) and prefix.lower()[:-1] in review:
#             return True

#     # Check if the exact payload is in the review using word boundaries
#     if re.search(r"\b" + re.escape(payloads[label.item()].lower()) + r"\b", review):
#         return True

#     if template:
#         template = template.replace(
#             "{payload}", ""
#         )  # Remove the payload from the template
#         if template.lower() in review.lower():
#             return True

#     return False


# def remove_payload_from_texts(
#     texts: list,
#     hparams: dict,
#     labels=None,
#     drop_explicit=None,
# ):
#     payload_template = hparams["dataset"]["payload_template"]
#     drop_explicit = (
#         hparams["train"]["drop_explicit"] if drop_explicit is None else drop_explicit
#     )

#     # If the payload template is None, return the texts as is
#     if payload_template is None:
#         return texts

#     # Construct the pattern to match the payload
#     escaped_template = re.escape(payload_template)
#     pattern = escaped_template.replace(r"\{payload\}", "(.*?)")
#     full_pattern = f"^.*?{pattern}\s*"
#     if not drop_explicit:
#         return [re.sub(full_pattern, "", text) for text in texts]

#     candidates = [re.sub(full_pattern, "", text) for text in texts]

#     sus = (
#         [
#             contains_suspicious_terms(text, label, hparams)
#             for text, label in zip(candidates, labels)
#         ]
#         if labels is not None
#         else [contains_suspicious_terms(text, 0, hparams) for text in candidates]
#     )

#     flagged_indices = [i for i, s in enumerate(sus) if s]

#     return candidates, flagged_indices


def select_logits_by_label(classification_labels, logits):
    """Pick, for each row of ``logits``, the entry at that row's label index.

    Args:
        classification_labels: One integer class index per row of
            ``logits``, given as a sequence or a tensor.
        logits: Tensor indexed as ``logits[row, class]``.

    Returns:
        Tensor with one selected logit per row.
    """
    # as_tensor accepts both sequences and existing tensors; unlike
    # torch.tensor it does not warn about copy-constructing from a tensor.
    label_indices = torch.as_tensor(classification_labels, device=logits.device)
    row_indices = torch.arange(logits.size(0), device=logits.device)
    return logits[row_indices, label_indices]


def display_samples(batch, n_samples=5):
    """Print a random selection of samples from a batch.

    Args:
        batch: Mapping with the index-aligned sequences ``"query"``,
            ``"response_texts"``, ``"query_response_texts"`` and
            ``"dec_in_texts"``.
        n_samples: Maximum number of samples to print; capped at the length
            of ``batch["query_response_texts"]``.
    """
    print("Showing some samples:")

    total = len(batch["query_response_texts"])
    chosen = random.sample(range(total), min(n_samples, total))

    # (label shown, batch key) pairs, printed in this order for each sample.
    labelled_fields = (
        ("Query", "query"),
        ("Response", "response_texts"),
        ("Query + Response", "query_response_texts"),
        ("Dec In", "dec_in_texts"),
    )
    for position in chosen:
        for label, key in labelled_fields:
            print(f"{label}: {batch[key][position]}")
        print()


def download_last_artifact(ppo_trainer) -> str:
    """Download the latest ``Completions`` table artifact of the current run.

    The run is taken from the first tracker of ``ppo_trainer.accelerator``
    (presumably the wandb tracker — confirm against the trainer setup).

    Args:
        ppo_trainer: Trainer whose ``accelerator.trackers[0].run`` provides
            ``id`` and ``use_artifact``.

    Returns:
        The value returned by the artifact's ``download()`` call.
    """
    run_id = ppo_trainer.accelerator.trackers[0].run.id
    artifact_path = "".join(["Steganography/trl/run-", run_id, "-Completions:latest"])

    tracker_run = ppo_trainer.accelerator.trackers[0].run
    artifact = tracker_run.use_artifact(artifact_path, type="run_table")
    return artifact.download()


def compute_metrics(logits, labels):
    """Compute overall and per-class accuracy from logits and labels.

    Entries whose label equals ``-10000`` are treated as dummies and removed
    before anything is computed.

    Args:
        logits: Tensor of class scores; the prediction is the argmax over
            the last dimension.
        labels: Tensor of integer class labels aligned with ``logits``.

    Returns:
        Dict with ``"evaluation/accuracy"`` followed by one
        ``"evaluation/accuracy_<label>"`` entry per distinct remaining label,
        all as Python floats.
    """
    # Discard dummy entries (label == -10000).
    keep = labels != -10000
    kept_labels = labels[keep]
    kept_logits = logits[keep]

    predicted = torch.argmax(kept_logits, dim=-1)

    metrics = {
        "evaluation/accuracy": (predicted == kept_labels).float().mean().item()
    }

    # Per-class accuracy: entries of other classes are replaced by NaN so
    # that nanmean averages only over the entries of the current class.
    for cls in torch.unique(kept_labels):
        in_class = kept_labels == cls
        masked_hits = torch.where(
            in_class, (predicted == kept_labels).float(), torch.tensor(float("nan"))
        )
        metrics[f"evaluation/accuracy_{cls.item()}"] = masked_hits.nanmean().item()

    return metrics


def to_device(batch, device):
    """
    Move every tensor found in a nested batch structure to ``device``.

    Tensors are moved with ``.to(device)``; dicts and lists are rebuilt with
    their contents processed recursively; any other value (including tuples)
    is returned unchanged.
    """
    if isinstance(batch, torch.Tensor):
        return batch.to(device)

    if isinstance(batch, dict):
        moved_dict = {}
        for key, value in batch.items():
            moved_dict[key] = to_device(value, device)
        return moved_dict

    if isinstance(batch, list):
        moved_list = []
        for item in batch:
            moved_list.append(to_device(item, device))
        return moved_list

    return batch


def device_aware_collate(batch, device):
    """
    Turn a list of per-sample dicts into a dict of lists and move it to ``device``.

    The keys are taken from the first sample; every sample is expected to
    provide each of those keys.
    """
    first_sample = batch[0]
    columns = {
        field: [sample[field] for sample in batch] for field in first_sample
    }
    return to_device(columns, device)


def add_steps_to_logs(log_dict, gt_steps, epoch):
    """Add ``"gt_steps"`` and ``"epoch"`` entries to ``log_dict`` in place.

    Returns:
        The same ``log_dict`` object, after the update.
    """
    step_info = {"gt_steps": gt_steps, "epoch": epoch}
    log_dict.update(step_info)
    return log_dict
