import re
import string
import torch
from collections import defaultdict, Counter
from datasets import load_dataset
from tqdm.auto import tqdm
from typing import List, Tuple
import json


def filter_sample_length(example, max_length=2048):
    """Return True when ``example['text']`` has at most ``max_length`` whitespace-separated words.

    Usable as a ``Dataset.filter`` predicate. Asserts that the ``text`` field exists.
    """
    assert 'text' in example, "The example must contain a 'text' field."

    word_count = len(example['text'].split())
    return word_count <= max_length

def truncate_sample_length(example, max_length=1024):
    """Cap ``example['text']`` at its first ``max_length`` whitespace-separated words.

    The example is modified in place and returned, so it can be used with ``Dataset.map``.
    Text at or under the limit is left untouched, including its original whitespace.
    Truncated text is re-joined with single spaces. Asserts that the ``text`` field exists.
    """
    assert 'text' in example, "The example must contain a 'text' field."

    words = example['text'].split()
    if len(words) <= max_length:
        return example

    example['text'] = ' '.join(words[:max_length])
    return example

def prepare_data(dataset_name, dataset_split, subset_ids=None, **kwargs):
    """Load one of the supported corpora and return its passages with their IDs.

    Args:
        dataset_name: Which corpus to load. Supported values:
            - ``"squad"``, ``"dwiki"``, ``"fineweb"``, ``"bio683"``
            - ``"trex"``, ``"trex11k"``, ``"trex++"``
            - ``"tofu"``, ``"tofu_qa"``
            - any name containing ``"dwiki_bio"``
        dataset_split: Split to select. Only the squad, bio683, trex, trex11k and
            trex++ branches use it. The other branches read a fixed split or source.
        subset_ids: Optional collection of example IDs to keep. Only the squad,
            dwiki, fineweb and dwiki_bio branches apply it.
        **kwargs: Extra ``load_dataset`` arguments. Only the squad, dwiki and
            dwiki_bio branches forward them.

    Returns:
        ``(texts, ids)``: two aligned lists.
        - squad: each ID entry is the list of question IDs sharing that context
          (the ``shared_ids`` column).
        - dwiki_bio: texts are word chunks with ``{id}_chunk{k}`` IDs.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    if dataset_name == "squad":
        dataset = load_dataset("rajpurkar/squad_v2", **kwargs)[dataset_split]
        dataset = add_shared_context_ids(dataset)

        if subset_ids is not None:
            # Use a set so the per-row membership test is O(1), as in the other branches.
            subset_ids = set(subset_ids)
            dataset = dataset.filter(lambda example: example['id'] in subset_ids)
        # Prefix each context with its article title.
        texts = [data['title'] + ": " + data['context'] for data in dataset]
        ids = [data['shared_ids'] for data in dataset]

        print(f"Loaded {len(texts)} examples from SQuAD")

    elif dataset_name == "dwiki":
        dataset = load_dataset("allenai/dolmino-mix-1124", "wiki", split="train", **kwargs)

        if subset_ids is not None:
            subset_ids = set(subset_ids)
            dataset = dataset.filter(lambda example: example['id'] in subset_ids)

        # Long passages are truncated to truncate_sample_length's word limit rather than dropped.
        dataset = dataset.map(truncate_sample_length)
        print("==== Truncate dataset ====")
        print(f"after truncation: {len(dataset)}")

        texts = [data['text'] for data in dataset]
        ids = [data['id'] for data in dataset]

        print(f"Loaded {len(texts)} examples from dolmino/wiki")

    elif dataset_name == "fineweb":
        # NOTE: this branch ignores dataset_split and **kwargs; it always loads the train split.
        dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT", split="train", streaming=False)

        if subset_ids is not None:
            subset_ids = set(subset_ids)
            dataset = dataset.filter(lambda example: example['id'] in subset_ids)

        dataset = dataset.map(truncate_sample_length)

        texts = [data['text'] for data in dataset]
        ids = [data['id'] for data in dataset]

        print(f"Loaded {len(texts)} examples from fineweb sample-10BT")

    elif dataset_name == "bio683":
        dataset = load_dataset("./data/bio683", "")[dataset_split]

        texts = [data["annotated_text"] for data in dataset]
        # The local dataset has no ID column, so row positions serve as IDs.
        ids = list(range(len(dataset)))

        print(f"Loaded {len(texts)} examples from bio683")

    elif dataset_name in ("trex", "trex11k"):
        # Both variants share the same loading code and columns; only the JSON file differs.
        if dataset_name == "trex":
            data_file = "./data/trex_v4/trex1k_v4.json"
        else:
            data_file = "./data/trex_v4/trex11k_v4.json"
        dataset = load_dataset("json", data_files=data_file)[dataset_split]
        texts = [data["input_text"] for data in dataset]
        ids = [data['uuid'] for data in dataset]

        print(f"Loaded {len(texts)} examples from {dataset_name}")

    elif "dwiki_bio" in dataset_name:
        dataset = load_dataset("allenai/dolmino-mix-1124", "wiki", split="train", **kwargs)

        if subset_ids is not None:
            subset_ids = set(subset_ids)
            dataset = dataset.filter(lambda example: example['id'] in subset_ids)
            print(f"get subset_ids: {len(subset_ids)}")
            print(f"after subset_ids: {len(dataset)}")

        # No truncation here: long passages are split into chunks by chunk_wiki_text below.
        texts = [data['text'] for data in dataset]
        ids = [data['id'] for data in dataset]

        texts, ids = chunk_wiki_text(texts, ids)

        print(f"Loaded {len(texts)} examples from wiki bio")

    elif dataset_name == "trex++":
        dataset = load_dataset("./data/trex++", "")[dataset_split]

        texts = [data['input_text'] for data in dataset]
        ids = [data['id'] for data in dataset]

        print(f"Loaded {len(texts)} examples from trex++")

    elif dataset_name == "tofu":
        from mem.utils import prepare_tofu_dataset

        texts, ids = prepare_tofu_dataset()

        print(f"Loaded {len(texts)} examples from tofu")

    elif dataset_name == "tofu_qa":
        import os

        qa_dir = "../unlearning/open-unlearning/data/todo_set/missing_evalset"
        # os.listdir order is arbitrary; sort so texts/ids come out in a reproducible order.
        json_files = [
            os.path.join(qa_dir, file)
            for file in sorted(os.listdir(qa_dir))
            if file.endswith('.json')
        ]
        texts = []
        ids = []
        for json_file in json_files:
            results_lst, ids_lst = convert_json_to_qa(json_file)
            texts.extend(results_lst)
            ids.extend(ids_lst)

        print(f"Loaded {len(texts)} examples from tofu_qa")
    else:
        raise ValueError(f"Dataset {dataset_name} is not supported")

    return texts, ids

# Token-level F1 helpers adapted from torchmetrics' SQuAD metric:
# https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/functional/text/squad.py
def _normalize_text(s: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text: str) -> str:
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text: str) -> str:
        return " ".join(text.split())

    def remove_punc(text: str) -> str:
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text: str) -> str:
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

def _get_tokens(s: str) -> list[str]:
    """Split a sentence into separate tokens."""
    return [] if not s else _normalize_text(s).split()

def token_f1_score(predicted_answer: str, target_answer: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute SQuAD-style token precision, recall and F1 between two answers.

    Both answers are tokenized with ``_get_tokens`` and compared as bags of tokens.
    A token repeated in both answers counts as many times as it appears in both
    (``Counter`` intersection).

    Args:
        predicted_answer: Answer produced by the model.
        target_answer: Reference answer.

    Returns:
        ``(precision, recall, f1)`` as 0-dim floating-point tensors. If either side
        has no tokens, all three are 1.0 when both are empty and 0.0 otherwise.
    """
    target_tokens = _get_tokens(target_answer)
    predicted_tokens = _get_tokens(predicted_answer)

    if not target_tokens or not predicted_tokens:
        # No-answer case: full credit only when both sides are empty. Build float
        # tensors so this path has the same dtype as the precision/recall path below.
        agree = float(target_tokens == predicted_tokens)
        return torch.tensor(agree), torch.tensor(agree), torch.tensor(agree)

    common = Counter(target_tokens) & Counter(predicted_tokens)
    num_same = torch.tensor(sum(common.values()))
    if num_same == 0:
        return torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
    # Multiplying by 1.0 promotes the integer counts to a floating-point tensor.
    precision = 1.0 * num_same / torch.tensor(len(predicted_tokens))
    recall = 1.0 * num_same / torch.tensor(len(target_tokens))
    f1_score = (2 * precision * recall) / (precision + recall)
    return precision, recall, f1_score


def get_save_name(args):
    """Build the output file stem ``{manager}_{annotator}_{prompt_id}``, plus ``_post`` when postprocessing.

    ``args`` must provide ``model_id``, ``annotator``, ``manager``, ``prompt_id`` and
    ``postprocess``.

    For the "llama", "llama-lora-ft" and "llama-lora-ft-hf" annotators, the annotator
    tag is built from the last path component of ``model_id``:
    - its first ``<digits>B`` size token (matched case-insensitively, then lower-cased);
    - everything after the first underscore (the whole name if there is no underscore).
    """
    base_name = args.model_id.rstrip('/').split('/')[-1]

    size_match = re.search(r'(\d+B)', base_name, re.IGNORECASE)
    size_tag = size_match[0].lower() if size_match else ""
    suffix = base_name.split('_', 1)[-1]

    if args.annotator == "llama":
        annotator_tag = f"llama{size_tag}-{suffix}"
    elif args.annotator in ("llama-lora-ft", "llama-lora-ft-hf"):
        annotator_tag = f"llama{size_tag}-lora-ft-{suffix}"
    else:
        annotator_tag = args.annotator

    stem = f"{args.manager}_{annotator_tag}_{args.prompt_id}"
    return stem + "_post" if args.postprocess else stem

def add_shared_context_ids(split_dataset):
    """Return ``split_dataset`` with a ``shared_ids`` column listing every ID that shares the row's context.

    Makes two passes over the dataset:
    1. Group ``id`` values by ``context``, in dataset order.
    2. Look up each row's group.
    The result comes from ``add_column``.
    """
    ids_by_context = defaultdict(list)
    for row in tqdm(split_dataset, desc="Building context mapping"):
        ids_by_context[row['context']].append(row['id'])

    shared_column = []
    for row in tqdm(split_dataset, desc="Adding shared context IDs"):
        shared_column.append(ids_by_context[row['context']])

    return split_dataset.add_column('shared_ids', shared_column)


def chunk_wiki_text(text_lst: List[str], ids_lst: List[str], max_len: int = 750) -> Tuple[List[str], List[str]]:
    """
    Split passages into consecutive chunks of at most `max_len` whitespace-separated words.

    Each chunk is re-joined with single spaces, so the original whitespace is not preserved.
    It gets the ID ``f"{pid}_chunk{k}"``, where ``k`` is its 0-based position within the
    passage. Passages with no words produce no chunks.

    Args:
        text_lst (List[str]): Passages to split.
        ids_lst (List[str]): IDs aligned one-to-one with `text_lst`.
        max_len (int, optional): Maximum number of words per chunk. Must be >= 1. Defaults to 750.

    Returns:
        Tuple[List[str], List[str]]: Chunked texts and their chunk IDs, in input order.

    Raises:
        ValueError: If `max_len` is less than 1, or the two input lists differ in length.
    """
    if max_len < 1:
        # A negative step makes range() empty, which would silently drop every passage.
        raise ValueError(f"max_len must be a positive integer, got {max_len}")
    if len(text_lst) != len(ids_lst):
        # zip() would otherwise silently discard the unmatched tail.
        raise ValueError(
            f"text_lst and ids_lst must have the same length, got {len(text_lst)} and {len(ids_lst)}"
        )

    chunked_texts: List[str] = []
    chunked_ids: List[str] = []

    for text, pid in zip(text_lst, ids_lst):
        words = text.split()  # simple whitespace tokenization
        for chunk_idx, start in enumerate(range(0, len(words), max_len)):
            chunked_texts.append(" ".join(words[start:start + max_len]))
            chunked_ids.append(f"{pid}_chunk{chunk_idx}")

    print(f"Chunked {len(text_lst)} texts into {len(chunked_texts)} chunks")
    print(f"Max chunk length: {max_len}")
    return chunked_texts, chunked_ids


def truncate_prompt(prompt: str, tokenizer, max_tokens: int = 2048) -> str:
    """
    Truncate the input prompt so it does not exceed the max token limit.

    A prompt already within the limit is returned unchanged. Only prompts that are
    too long are decoded from their first `max_tokens` tokens.

    Args:
        prompt (str): The input prompt text.
        tokenizer: Tokenizer providing `encode(text, add_special_tokens=...)` and
            `decode(ids, skip_special_tokens=...)` (presumably a Hugging Face tokenizer).
        max_tokens (int, optional): The maximum allowed token length. Defaults to 2048.

    Returns:
        str: The original prompt, or the decoded truncated prompt.
    """
    tokens = tokenizer.encode(prompt, add_special_tokens=False)

    if len(tokens) <= max_tokens:
        # Nothing to cut. A decode(encode(x)) round trip is not guaranteed to
        # reproduce x; e.g. skip_special_tokens=True drops special-token text such as
        # chat-template markers embedded in the prompt.
        return prompt

    return tokenizer.decode(tokens[:max_tokens], skip_special_tokens=True)

def convert_json_to_qa(json_file, output_file=None):
    """Format each QA record in a JSON file as a plain-text question/answer block.

    The file must hold a JSON list of objects with ``question``, ``answer`` and
    ``index`` keys. Records that also carry ``paraphrased_answer`` get a numbered
    answer list, in this order:
    1. the original answer;
    2. the paraphrase;
    3. every entry of ``perturbed_answer``, which is expected to be a list.

    Args:
        json_file: Path to the JSON file.
        output_file: Unused. Kept so existing callers keep working.

    Returns:
        ``(results_lst, ids_lst)``: formatted QA strings and IDs of the form
        ``"{file stem}_{index}"``, in file order.
    """
    import os

    # JSON text is UTF-8 by specification; don't depend on the platform's locale encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # The stem is the basename without its final extension. splitext keeps inner dots
    # (e.g. "forget.v2.json" -> "forget.v2"), so differently versioned files cannot collide.
    file_stem = os.path.splitext(os.path.basename(json_file))[0]

    results_lst = []
    ids_lst = []
    for item in data:
        question = item['question']
        if 'paraphrased_answer' in item:
            answers = [item['answer'], item['paraphrased_answer']] + item['perturbed_answer']
            numbered = "".join(f"{i}. {ans}\n" for i, ans in enumerate(answers, 1))
            qa_text = f"Question: {question}\nAnswers:\n{numbered}"
        else:
            qa_text = f"Question: {question}\nAnswer:\n{item['answer']}"

        ids_lst.append(f"{file_stem}_{item['index']}")
        results_lst.append(qa_text)

    return results_lst, ids_lst


if __name__ == "__main__":
    # Smoke demo: split two short passages into 10-word chunks and print texts and IDs.
    text_lst = ["This function splits each passage into chunks of up to 750 tokens and appends a chunk index (_chunk0, _chunk1, etc.) to the original ID for tracking. Let me know if you need modifications!", "Another long passage Here's a function to chunk your wiki text passages while preserving their corresponding IDs. The function ensures that no chunk exceeds max_len=750 tokens, using whitespace-based tokenization."]
    ids_lst = ["abc123", "def456"]
    chunked_texts, chunked_ids = chunk_wiki_text(text_lst, ids_lst, max_len=10)
    print(chunked_texts)
    print(chunked_ids)

