import os
import json
import hashlib
import random
import numpy as np
import torch
from datasets import Dataset


def generate_api(tokenizer, model, input_ids, sample, top_p, max_new_tokens, use_cache):
    """Call ``model.generate`` with the fixed decoding settings used by this module.

    Args:
        tokenizer: Only ``tokenizer.eos_token_id`` is read; it is used as the pad id.
        model: Object exposing a Hugging Face style ``generate`` method.
        input_ids: Mapping of model inputs (e.g. a tokenizer's output); it is
            unpacked with ``**`` into ``generate``.
        sample: Forwarded as ``do_sample``.
        top_p: Forwarded as ``top_p``.
        max_new_tokens: Maximum number of tokens to generate after the prompt.
        use_cache: Forwarded as ``use_cache``.

    Returns:
        Whatever ``model.generate`` returns. It is requested with
        ``output_scores=True`` and ``return_dict_in_generate=True``.
    """
    return model.generate(
        **input_ids,
        # ``max_new_tokens`` bounds only the generated continuation. ``max_length``
        # (used previously) also counts prompt tokens, so the generation budget
        # shrank as the prompt grew.
        max_new_tokens=max_new_tokens,
        do_sample=sample,
        # Passed explicitly as None, presumably to override any model
        # generation-config defaults — NOTE(review): confirm intent.
        temperature=None,
        top_p=top_p,
        top_k=None,
        output_scores=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
        use_cache=use_cache,
    )


def build_results_filepath(response_type, experiment_dir, partition, overwrite=False):
    """Return the JSONL results path for one partition, creating its directory.

    The path is ``<experiment_dir>/<response_type>/<partition>.jsonl``; when
    ``partition`` is None the file name is ``merged.jsonl``.

    Raises:
        FileExistsError: If the file already exists and ``overwrite`` is False.
    """
    target_dir = os.path.join(experiment_dir, response_type)
    os.makedirs(target_dir, exist_ok=True)

    stem = "merged" if partition is None else partition
    target = os.path.join(target_dir, f"{stem}.jsonl")

    if not overwrite and os.path.exists(target):
        raise FileExistsError(f"File {target} already exists.")
    return target


def get_mask_between_texts(prompt_tokenized, text_start_tokenized, text_end_tokenized):
    """Build a 0/1 mask over the tokens strictly between two delimiter sequences.

    The first occurrence of ``text_start_tokenized`` is located. Then the first
    occurrence of ``text_end_tokenized`` at or after the end of that start
    sequence is located. Tokens between the two delimiters (both excluded) are
    set to 1.

    Args:
        prompt_tokenized: Token sequence to search (compared by slicing, so
            presumably a list of tokens).
        text_start_tokenized: Token sequence marking the start delimiter.
        text_end_tokenized: Token sequence marking the end delimiter.

    Returns:
        ``torch.int8`` tensor of length ``len(prompt_tokenized)``.

    Raises:
        ValueError: If either delimiter is not found, or if no tokens lie
            between them.
    """
    mask = torch.zeros(len(prompt_tokenized), dtype=torch.int8)

    start_len = len(text_start_tokenized)
    for i in range(len(prompt_tokenized) - start_len + 1):
        if prompt_tokenized[i : i + start_len] == text_start_tokenized:
            text_start_idx = i + start_len
            break
    else:
        # Previously this fell through and later raised UnboundLocalError.
        raise ValueError("Start text not found in the prompt. Please check the input texts.")

    end_len = len(text_end_tokenized)
    for j in range(text_start_idx, len(prompt_tokenized) - end_len + 1):
        if prompt_tokenized[j : j + end_len] == text_end_tokenized:
            text_end_idx = j
            break
    else:
        raise ValueError("End text not found after the start text. Please check the input texts.")

    if text_start_idx >= text_end_idx:
        raise ValueError(
            f"Start index ({text_start_idx}) is not less than end index ({text_end_idx}). "
            "Please check the input texts."
        )

    mask[text_start_idx:text_end_idx] = 1
    return mask


def get_mask_by_text(prompt_tokenized, constraint_tokenized):
    """Mark the first occurrence of a token subsequence with ones.

    Returns:
        A ``torch.int8`` tensor of length ``len(prompt_tokenized)`` with 1s over
        the first span equal to ``constraint_tokenized``. Returns None if no such
        span exists. The match position is printed when one is found.
    """
    width = len(constraint_tokenized)
    last_start = len(prompt_tokenized) - width
    for start in range(last_start + 1):
        if prompt_tokenized[start : start + width] != constraint_tokenized:
            continue
        print(f"Found constraint at index {start}")
        mask = torch.zeros(len(prompt_tokenized), dtype=torch.int8)
        mask[start : start + width] = 1
        return mask  # only the first occurrence is marked
    # No occurrence: callers receive None rather than an all-zero mask.
    return None


def role_model_rule(prompt_tokenized: list[str], idx: int, model_family: str = "llama3") -> bool:
    """Return True if the token at ``idx`` sits in a chat-template role header.

    For ``"llama3"``, the token must be immediately preceded by
    ``<|start_header_id|>`` and followed by ``<|end_header_id|>``. For
    ``"qwen2"``, it must be immediately preceded by ``<|im_start|>``. Unknown
    families return False.

    A neighbour that would fall outside the sequence counts as a non-match. It
    neither wraps around via a negative index nor raises ``IndexError``.
    """
    # idx - 1 == -1 would otherwise silently read the LAST token.
    has_prev = idx >= 1
    if model_family == "llama3":
        has_next = idx + 1 < len(prompt_tokenized)
        return (
            has_prev
            and has_next
            and prompt_tokenized[idx - 1] == "<|start_header_id|>"
            and prompt_tokenized[idx + 1] == "<|end_header_id|>"
        )
    if model_family == "qwen2":
        return has_prev and prompt_tokenized[idx - 1] == "<|im_start|>"
    return False


def find_idx(prompt_tokenized: list[str], role_for_idx: str, nth_occurrence=1, model_family: str = "llama3"):
    """Return the index of the ``nth_occurrence``-th header occurrence of a role token.

    A position counts only when its token equals ``role_for_idx`` and
    ``role_model_rule`` accepts it for ``model_family``.

    Returns:
        The matching index, or None when fewer than ``nth_occurrence`` matches
        exist. This includes ``nth_occurrence < 1``, since the countdown never
        reaches zero.
    """
    remaining = nth_occurrence
    for position, token in enumerate(prompt_tokenized):
        if token != role_for_idx or not role_model_rule(prompt_tokenized, position, model_family):
            continue
        remaining -= 1
        if remaining == 0:
            return position
    return None


def get_mask_by_role(
    prompt_tokenized: list[str],
    role: str,
    model_family: str = "llama3",
    tokenized_date_block: list | None = None,
    is_post_prompt: bool = False,
):
    """Mask the content tokens of each ``role`` turn in a chat-formatted prompt.

    Each span starts a fixed number of tokens after the role's header token. It
    ends a fixed number of tokens before the header token of the role that
    follows it. Both header tokens are located with ``find_idx``.

    Args:
        prompt_tokenized: Token strings of the full prompt.
        role: ``"system"`` or ``"user"``; other roles are not supported.
        model_family: ``"llama3"`` selects llama3 offsets; any other value uses
            the qwen2-style offsets.
        tokenized_date_block: Currently unused.
        is_post_prompt: Layout where a second system turn follows the user turn.
            In that case ``"system"`` yields two spans (ending before ``"user"``
            and before ``"assistant"``). The ``"user"`` span ends at the second
            ``"system"`` header.

    Returns:
        ``torch.int8`` tensor of length ``len(prompt_tokenized)``.
    """
    mask = torch.zeros(len(prompt_tokenized), dtype=torch.int8)

    # Offsets skip the special tokens and newline around each header. They
    # presumably match `<|start_header_id|>role<|end_header_id|>\n\n` (llama3)
    # and `<|im_start|>role\n` (qwen2) — TODO confirm against the tokenizer.
    header_skip, trailer_skip = (3, 2) if model_family == "llama3" else (2, 3)

    # Role whose header closes each successive span of `role`.
    followers = {
        "system": ("user", "assistant"),
        "user": ("system",) if is_post_prompt else ("assistant",),
    }
    segment_count = 2 if (is_post_prompt and role == "system") else 1
    follower_occurrence = 2 if (is_post_prompt and role == "user") else 1

    for occurrence in range(1, segment_count + 1):
        span_start = find_idx(prompt_tokenized, role, occurrence, model_family) + header_skip
        follower = followers[role][occurrence - 1]
        span_end = find_idx(prompt_tokenized, follower, follower_occurrence, model_family) - trailer_skip
        mask[span_start:span_end] = 1

    return mask


def get_gandalf_mask(prompt_tokenized: list[str], role: str, model_family: str = "llama3"):
    """Build a 0/1 mask for the Gandalf prompts.

    For ``role == "system"``, every token from the first header-position
    ``"user"`` token onward is set to 1. For any other role the mask is all
    zeros.

    NOTE(review): for ``"system"`` this marks the user turn onward rather than
    the system turn itself — confirm this is intended.

    Args:
        prompt_tokenized: Token strings of the full prompt.
        role: Role whose mask is requested.
        model_family: Forwarded to ``find_idx``; defaults to the previous
            hard-coded ``"llama3"``.

    Returns:
        ``torch.int8`` tensor of length ``len(prompt_tokenized)``. The original
        built this mask but never returned it.

    Raises:
        ValueError: If ``role == "system"`` and no ``"user"`` header is found.
    """
    mask = torch.zeros(len(prompt_tokenized), dtype=torch.int8)
    if role == "system":
        gandalf_idx = find_idx(prompt_tokenized, "user", model_family=model_family)
        if gandalf_idx is None:
            # mask[None:] would silently mark the entire prompt.
            raise ValueError("No 'user' header token found in the prompt.")
        mask[gandalf_idx:] = 1
    return mask


def get_experiment_dir(args):
    """Create and return a results directory keyed by a hash of ``args``.

    ``vars(args)`` is JSON-serialized with sorted keys and MD5-hashed. The first
    8 hex characters name the directory, which is created under
    ``<args.results_dir>/<dirname(args.dataset_path)>/``.

    NOTE(review): ``os.path.join`` discards ``args.results_dir`` when
    ``dirname(args.dataset_path)`` is absolute. Absolute dataset paths therefore
    place results next to the dataset — confirm that is intended.
    """
    serialized = json.dumps(vars(args), sort_keys=True)
    experiment_hash = hashlib.md5(serialized.encode()).hexdigest()[:8]
    print(f"Experiment hash: {experiment_hash}")

    dataset_subdir = os.path.dirname(args.dataset_path)
    target = os.path.join(args.results_dir, dataset_subdir, experiment_hash)
    os.makedirs(target, exist_ok=True)
    return target


def write_config(args_dict, experiment_dir):
    """Write ``args_dict`` to ``<experiment_dir>/config.json`` unless it already exists.

    An existing config file is left untouched.
    """
    target = os.path.join(experiment_dir, "config.json")
    if os.path.exists(target):
        return
    with open(target, "w") as handle:
        json.dump(args_dict, handle, indent=4)


def partition_dataset(dataset, partition, num_partitions):
    """Return the ``partition``-th of ``num_partitions`` contiguous slices of ``dataset``.

    Each slice has ``len(dataset) // num_partitions`` items, and the final
    partition also takes the remainder. A ``datasets.Dataset`` is sliced with
    ``.select``; anything else is sliced with ``[start:end]``. When
    ``partition`` is None the dataset is returned unchanged. Lengths and the
    partition id are printed.
    """
    total = len(dataset)
    if partition is not None:
        chunk = total // num_partitions
        lo = chunk * partition
        # The final partition absorbs the remainder of the integer division.
        hi = total if partition >= num_partitions - 1 else chunk * (partition + 1)
        if isinstance(dataset, Dataset):
            dataset = dataset.select(list(range(lo, hi)))
        else:
            dataset = dataset[lo:hi]
    print(f"Dataset length: {total}, partition length: {len(dataset)}")
    print(f"Partition: {partition}")
    return dataset


def remove_extra_space(input_tokenized: list[str], input: str) -> list[str]:  # for llama 2
    """Drop the first token when the source text begins with a space.

    Intended for Llama 2 tokenization, where a leading space in ``input``
    presumably surfaces as an extra first token — TODO confirm.

    ``str.startswith`` is used instead of ``input[0]`` so an empty ``input``
    returns the tokens unchanged instead of raising ``IndexError``. (The
    parameter name shadows the ``input`` builtin but is kept for callers that
    pass it by keyword.)
    """
    if input.startswith(" "):
        return input_tokenized[1:]
    return input_tokenized


def apply_repetition_penalty(scores, input_ids, penalty, input_len_to_ignore=None):
    """Return a copy of ``scores`` with already-seen token logits penalized.

    For every token id in ``input_ids`` (after optionally dropping its first
    ``input_len_to_ignore`` columns), a negative logit is multiplied by
    ``penalty`` and a non-negative logit is divided by it. In both cases the
    token becomes less likely when ``penalty > 1``. ``scores`` itself is not
    modified.
    """
    kept_ids = input_ids[:, input_len_to_ignore:] if input_len_to_ignore else input_ids
    seen_logits = scores.gather(1, kept_ids)
    penalized = torch.where(seen_logits < 0, seen_logits * penalty, seen_logits / penalty)
    return scores.scatter(1, kept_ids, penalized)


def set_seeds(seed_value=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels.

    Seeds ``random``, ``numpy.random`` and ``torch``, and all CUDA devices when
    CUDA is available. Sets cuDNN to deterministic mode with benchmarking
    disabled. Exports ``PYTHONHASHSEED`` and prints a confirmation.

    Note: assigning ``PYTHONHASHSEED`` in ``os.environ`` does not change hash
    randomization of the already-running interpreter. It only affects Python
    subprocesses started afterwards.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)  # every visible GPU

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed_value)

    print(f"Seeds set for reproducibility: {seed_value} 🌱")


def preprocess_ctx(ctx):
    """Strip markup tokens from a context string.

    ``[PAR]`` becomes a newline, while ``[DOC]``, ``[TLE]`` and ``[SEP]`` are
    deleted. Surrounding whitespace is then trimmed. The replacements run in
    this fixed order.
    """
    substitutions = (
        ("[PAR]", "\n"),
        ("[DOC]", ""),
        ("[TLE]", ""),
        ("[SEP]", ""),
    )
    for marker, replacement in substitutions:
        ctx = ctx.replace(marker, replacement)
    return ctx.strip()
