import json
import os
import re
from copy import deepcopy
from typing import Dict, Optional, Tuple

import yaml
from datasets import load_dataset, Dataset, DatasetDict
from datasets import concatenate_datasets

from memgpt.constants import PROMPTS_DIR, EVALSET_PATH
from memgpt.trl.utils.utils_filter import convert_to_raw_dataset, convert_to_special_db_tokens_format, filter_invalid_dblookups

############################################
# Datasets
############################################

def load_trainset(script_args) -> Dataset:
    """
    Load the training split used for pretraining / supervised fine-tuning (SFT).

    Args:
        script_args: An object exposing ``dataset_name`` and, optionally,
            ``dataset_config`` (config name, default ``None``) and
            ``dataset_train_split`` (default ``'train'``).

    Returns:
        Dataset: The selected training split. It always contains an
        ``annotated_text`` column; for SQuAD-style data it is renamed from
        ``context``, and for Dolmino-style data from ``text``.

    Raises:
        ValueError: If the split has no ``annotated_text`` column and the
            dataset name matches none of the sources that can be renamed.
    """
    config_name = getattr(script_args, 'dataset_config', None)
    if 'json' in script_args.dataset_name:
        # Local JSON files keep their records under a top-level "examples" key.
        dataset = load_dataset('json', data_files=script_args.dataset_name, field="examples")
    else:
        try:
            dataset = load_dataset(script_args.dataset_name, name=config_name)
        except Exception:
            # Retry assuming the records live under an "examples" field.
            # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt /
            # SystemExit propagate instead of triggering a second load.
            dataset = load_dataset(script_args.dataset_name, name=config_name, field="examples")

    train_dataset = dataset[getattr(script_args, 'dataset_train_split', 'train')]

    if 'annotated_text' not in train_dataset.column_names:
        if 'squad' in script_args.dataset_name:
            train_dataset = train_dataset.rename_column('context', 'annotated_text')
        elif 'dolmino' in script_args.dataset_name:
            train_dataset = train_dataset.rename_column('text', 'annotated_text')
        else:
            raise ValueError(f"Dataset {script_args.dataset_name} is not supported")

    return train_dataset

def combine_evalsets(
    cleaned_eval_dataset: Dict[str, Dataset],
    tokenizer,
    prompt,
    eval_path: str = "/path/to/version4/eval_dataset/squad-eval100_dwiki-eval100_chatgpt_gpt4o-v7.1.json",
) -> dict:
    """Pair a combined SQuAD+dwiki eval set with an already-prepared eval set.

    Args:
        cleaned_eval_dataset: The already-prepared eval data. It is stored
            as-is under the ``"cleaned"`` key. NOTE(review): the commented-out
            caller in ``prepare_instruction_tuning_data`` passes a single
            ``Dataset`` here, despite the ``Dict`` annotation. Confirm the
            intended type.
        tokenizer: Passed through to ``format_chat``.
        prompt: Instruction prompt passed through to ``format_chat`` (may be ``None``).
        eval_path: JSON file whose records live under an ``"examples"`` field.
            Defaults to the previously hard-coded location.

    Returns:
        dict: ``{"org": <eval set loaded from eval_path, formatted by format_chat>,
        "cleaned": cleaned_eval_dataset}``.

    Raises:
        FileNotFoundError: If ``eval_path`` does not exist.
    """
    if not os.path.exists(eval_path):
        raise FileNotFoundError(f"Eval dataset not found at {eval_path}.")

    eval_dataset = load_dataset("json", data_files=eval_path, split="train", field="examples")
    eval_dataset = eval_dataset.map(lambda x: format_chat(x, tokenizer, prompt), batched=False)

    return {
        "org": eval_dataset,
        "cleaned": cleaned_eval_dataset,
    }

def seperate_evalsets(tokenizer, prompt, eval_paths: Optional[Dict[str, str]] = None) -> Dict[str, Dataset]:
    """Load several eval sets separately and render each one with ``format_chat``.

    Args:
        tokenizer: Passed through to ``format_chat``.
        prompt: Instruction prompt passed through to ``format_chat`` (may be ``None``).
        eval_paths: Mapping from eval-set name to a JSON file whose records live
            under an ``"examples"`` field. Defaults to the built-in cleaned
            "dwiki" and "squad" eval sets.

    Returns:
        Dict[str, Dataset]: One dataset per name, each with the
        ``formatted_text`` column produced by ``format_chat``.

    Raises:
        FileNotFoundError: If any eval file is missing. All paths are checked
            before any dataset is loaded.
    """
    if eval_paths is None:
        eval_paths = {
            "dwiki": "/path/to/version4/eval_dataset/dwiki-eval100_chatgpt_gpt4o-v7.1_cleaned.json",
            "squad": "/path/to/version4/eval_dataset/squad-eval100_chatgpt_gpt4o-v7.1_cleaned.json",
        }

    # Fail fast: report a missing file before spending time loading the others.
    for eval_path in eval_paths.values():
        if not os.path.exists(eval_path):
            raise FileNotFoundError(f"Eval dataset not found at {eval_path}.")

    eval_dataset_dict = {}
    for key, eval_path in eval_paths.items():
        eval_dataset = load_dataset("json", data_files=eval_path, split="train", field="examples")
        eval_dataset_dict[key] = eval_dataset.map(lambda x: format_chat(x, tokenizer, prompt), batched=False)

    return eval_dataset_dict

def load_evalsets(setting: str = 'LMLM_eval') -> Dict[str, Dataset]:
    """Load every eval set listed under ``setting`` in the eval-set YAML config.

    The YAML file at ``EVALSET_PATH`` maps setting names (e.g. ``"LMLM_eval"``,
    ``"tofu_eval"``) to ``{eval_name: json_path}`` mappings. Each JSON file is
    loaded with its records read from the ``"examples"`` field.

    Args:
        setting: Top-level key of the YAML config to use.

    Returns:
        Dict[str, Dataset]: One dataset per eval name, in config order.

    Raises:
        ValueError: If ``setting`` is not in the config (including when the
            config file is empty).
        FileNotFoundError: If any listed eval file is missing. All paths are
            checked before any dataset is loaded.
    """
    with open(EVALSET_PATH, 'r', encoding='utf-8') as f:
        # An empty YAML document parses to None. Treat it as an empty config so
        # the lookup below raises the intended ValueError, not a TypeError.
        config = yaml.safe_load(f) or {}

    if setting not in config:
        raise ValueError(f"Setting '{setting}' not found in YAML config.")

    eval_path_lst = config[setting]

    # Fail fast: report a missing file before spending time loading the others.
    for eval_path in eval_path_lst.values():
        if not os.path.exists(eval_path):
            raise FileNotFoundError(f"Eval dataset not found at {eval_path}.")

    return {
        key: load_dataset("json", data_files=eval_path, split="train", field="examples")
        for key, eval_path in eval_path_lst.items()
    }

def load_trainsets(train_paths: Optional[Dict[str, str]] = None, seed: int = 42) -> Dataset:
    """Load several training corpora, concatenate their train splits and shuffle.

    Args:
        train_paths: Mapping from corpus name to a dataset path. Paths
            containing ``'json'`` are loaded as JSON with records under
            ``"examples"``. Other paths go through ``load_dataset(path)``, with a
            retry using ``field="examples"`` if that fails. Defaults to the
            built-in dwiki6.1M and fineweb5.2M locations.
        seed: Seed for the post-concatenation shuffle.

    Returns:
        Dataset: The shuffled concatenation of every available corpus's
        ``"train"`` split.

    Raises:
        ValueError: If none of the paths exist.
    """
    if train_paths is None:
        train_paths = {
            "dwiki": "/path/to/version4/dwiki6.1M",
            "fineweb": "/path/to/version4/fineweb/fineweb5.2M",
        }

    train_dataset_lst = []
    loaded_paths = []
    for key, train_path in train_paths.items():
        if not os.path.exists(train_path):
            # Missing corpora are skipped on purpose (best effort), but say so
            # rather than silently training on less data than configured.
            print(f"Warning: training dataset '{key}' not found at {train_path}; skipping.")
            continue

        if 'json' in train_path:
            dataset = load_dataset('json', data_files=train_path, field="examples")
        else:
            try:
                dataset = load_dataset(train_path)
            except Exception:
                # Retry assuming records live under "examples". Interpreter-exit
                # exceptions are not swallowed.
                dataset = load_dataset(train_path, field="examples")

        train_dataset_lst.append(dataset["train"])
        loaded_paths.append(train_path)

    if not train_dataset_lst:
        raise ValueError("No datasets were loaded successfully")

    train_dataset = concatenate_datasets(train_dataset_lst)
    # Shuffle after concatenation so the corpora are interleaved.
    train_dataset = train_dataset.shuffle(seed=seed)
    print(f"Loaded training dataset from {loaded_paths} with {len(train_dataset)} samples.")
    return train_dataset

def prepare_eval_data(use_special_dblookup_tokens=False, is_plain_baseline=False) -> Dict[str, Dataset]:
    """Load the default eval sets and convert them for the requested training mode.

    Args:
        use_special_dblookup_tokens: Convert each eval set in place with
            ``convert_to_special_db_tokens_format``. This is ignored when
            ``is_plain_baseline`` is set.
        is_plain_baseline: Keep each eval set and also add a
            ``"<name>_raw"`` variant produced by ``convert_to_raw_dataset``.

    Returns:
        Dict[str, Dataset]: Eval sets keyed by name, possibly with extra
        ``"_raw"`` entries.
    """
    eval_datasets = load_evalsets()

    # NOTE(review): prepare_pretrain_data passes an explicit ``version`` to
    # filter_invalid_dblookups, but this call relies on its default. Confirm
    # the two are meant to agree.
    eval_datasets = {key: eval_dataset.map(filter_invalid_dblookups) for key, eval_dataset in eval_datasets.items()}

    # Iterate over a snapshot because the loop adds "<name>_raw" keys to the dict.
    for sub_evalset_name, sub_evalset in list(eval_datasets.items()):
        if is_plain_baseline:
            eval_datasets[sub_evalset_name + "_raw"] = convert_to_raw_dataset(sub_evalset)
        elif use_special_dblookup_tokens:
            eval_datasets[sub_evalset_name] = convert_to_special_db_tokens_format(sub_evalset)

    return eval_datasets

def prepare_pretrain_data(script_args, use_special_dblookup_tokens=False, is_plain_baseline=False) -> Tuple[Dataset, Dict[str, Dataset]]:
    """Build the pretraining train set and its eval sets.

    The train set comes from the combined dwiki+fineweb corpora when
    ``dataset_name`` names that mix; otherwise it comes from ``load_trainset``.
    Eval sets use the ``"tofu_eval"`` config for TOFU runs and ``"LMLM_eval"``
    for everything else. Both are passed through ``filter_invalid_dblookups``.

    Args:
        script_args: Object exposing ``dataset_name`` (plus whatever
            ``load_trainset`` reads).
        use_special_dblookup_tokens: Convert the train set and every eval set
            whose name does not contain ``"raw"`` with
            ``convert_to_special_db_tokens_format``. Ignored for the plain
            baseline.
        is_plain_baseline: Convert the train set with ``convert_to_raw_dataset``.

    Returns:
        Tuple[Dataset, Dict[str, Dataset]]: ``(train_dataset, eval_datasets)``.
    """
    dataset_name = script_args.dataset_name

    if "dwiki6.1M+fineweb5.2M" in dataset_name:
        train_dataset = load_trainsets()
    else:
        train_dataset = load_trainset(script_args)

    eval_setting = "tofu_eval" if "tofu" in dataset_name else "LMLM_eval"
    eval_datasets = load_evalsets(setting=eval_setting)

    current_filter_version = 1  # Bump to force reprocessing when the filter changes
    train_dataset = train_dataset.map(lambda x: filter_invalid_dblookups(x, version=current_filter_version))
    eval_datasets = {key: eval_dataset.map(lambda x: filter_invalid_dblookups(x, version=current_filter_version)) for key, eval_dataset in eval_datasets.items()}

    # The special-token format only applies when not training the plain baseline.
    wants_special_tokens = use_special_dblookup_tokens and not is_plain_baseline

    if is_plain_baseline:
        train_dataset = convert_to_raw_dataset(train_dataset)
    if wants_special_tokens:
        train_dataset = convert_to_special_db_tokens_format(train_dataset)

    # NOTE: "<name>_raw" eval variants are deliberately NOT added here. An
    # earlier attempt to do so was disabled after being marked as buggy.
    if wants_special_tokens:
        for evalset_name in [name for name in eval_datasets if "raw" not in name]:
            eval_datasets[evalset_name] = convert_to_special_db_tokens_format(eval_datasets[evalset_name])
            print("coverting to specail token .........................")

    return train_dataset, eval_datasets

def prepare_instruction_tuning_data(script_args, tokenizer, use_prompt=True) -> Tuple[Dataset, Dataset]:
    """Load and format the train and eval datasets for instruction tuning.

    Args:
        script_args: Object exposing ``dataset_name`` and, optionally:
            ``dataset_config`` (config name), ``dataset_train_split`` (default
            ``'train'``), ``eval_dataset_name`` (JSON eval file; when missing or
            empty it is derived from ``dataset_name`` by replacing ``train<N>k``
            with ``eval100``) and ``prompt_id`` (default ``"llama-v6"``).
        tokenizer: Passed through to ``format_chat``.
        use_prompt: Wrap examples in the ``InstructionPrompt`` chat template.
            When False, use ``format_chat``'s plain "### Input/### Output" form.

    Returns:
        Tuple[Dataset, Dataset]: ``(train_dataset, eval_dataset)``, each with
        the ``formatted_text`` column produced by ``format_chat``.
    """
    train_split = getattr(script_args, 'dataset_train_split', 'train')
    if 'json' in script_args.dataset_name:
        dataset = load_dataset('json', data_files=script_args.dataset_name, field="examples")
    else:
        dataset = load_dataset(script_args.dataset_name, name=getattr(script_args, 'dataset_config', None))
    # Select the split in BOTH branches. Without ``split=``, load_dataset returns
    # a DatasetDict, which previously leaked out of the non-JSON branch.
    train_dataset = dataset[train_split]

    # TODO: Load evaluation dataset
    eval_dataset_name = getattr(script_args, 'eval_dataset_name', None)
    if not eval_dataset_name:
        eval_dataset_name = re.sub(r'train\d+k', 'eval100', script_args.dataset_name)
        if eval_dataset_name == script_args.dataset_name:
            # No "train<N>k" marker, so the derived name is the training file itself.
            print(f"Warning: could not derive an eval dataset name from {script_args.dataset_name}; "
                  "evaluating on the training file.")
    print(f"Eval dataset name: {eval_dataset_name}")
    eval_dataset = load_dataset("json", data_files=eval_dataset_name, split="train", field="examples")

    # Load the instruction prompt (make configurable)
    if use_prompt:
        prompt = InstructionPrompt(getattr(script_args, 'prompt_id', "llama-v6"))
    else:
        prompt = None

    # Process datasets
    train_dataset = train_dataset.map(lambda x: format_chat(x, tokenizer, prompt), batched=False)
    eval_dataset = eval_dataset.map(lambda x: format_chat(x, tokenizer, prompt), batched=False)

    # test add standard eval dataset
    # eval_dataset = combine_evalsets(eval_dataset, tokenizer, prompt)
    # eval_dataset = seperate_evalsets(tokenizer, prompt)

    return train_dataset, eval_dataset


def format_chat(data, tokenizer, prompt=None):
    """Render one example as a single training string.

    Args:
        data: Mapping with ``'text'`` and ``'annotated_text'`` entries.
        tokenizer: Object with ``apply_chat_template``. Only used when
            ``prompt`` is truthy.
        prompt: Callable ``(text, annotated_text) -> messages``. When truthy,
            its messages are rendered with
            ``tokenizer.apply_chat_template(..., tokenize=False,
            add_generation_prompt=False)``. Otherwise a plain
            ``"### Input:<text> ### Output:<annotated_text>"`` string is built.

    Returns:
        dict: ``{"formatted_text": <rendered string>}``, suitable as a
        ``Dataset.map`` row update.
    """
    source, target = data['text'], data['annotated_text']

    if not prompt:
        return {"formatted_text": "".join(("### Input:", source, " ### Output:", target))}

    messages = prompt(source, target)
    # instruction_template = "<|start_header_id|>user<|end_header_id|>" # TODO
    rendered = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    return {"formatted_text": rendered}

class InstructionPrompt:
    """Load a chat prompt template from ``PROMPTS_DIR`` and fill its placeholders.

    The template file is ``<PROMPTS_DIR>/<prompt_id>.json``. It is expected to
    hold a list of message dicts, each with a ``'content'`` entry (presumably
    also a ``'role'``; confirm against the prompt files). Calling the instance
    substitutes ``[INSERT_TEXT]`` and ``[INSERT_ANNOTATION]`` in string
    contents.
    """

    # Matches either placeholder so both are substituted in a single pass.
    # Placeholder-like text inside the inserted values is then never re-expanded.
    _PLACEHOLDER_PATTERN = re.compile(r"\[(INSERT_TEXT|INSERT_ANNOTATION)\]")

    def __init__(self, prompt_id):
        """Read the template for ``prompt_id``.

        Raises:
            FileNotFoundError: If ``<PROMPTS_DIR>/<prompt_id>.json`` does not exist.
        """
        self.prompt_id = prompt_id
        path = os.path.join(PROMPTS_DIR, f"{prompt_id}.json")
        if not os.path.exists(path):
            raise FileNotFoundError(f"Prompt file not found at {path}.")

        with open(path, "r", encoding="utf-8") as f:
            self.prompt = json.load(f)

    def __call__(self, text, annotation):
        """
        Fill the placeholders in a copy of the template.

        Args:
            text (str): Replaces every ``[INSERT_TEXT]``.
            annotation (str): Replaces every ``[INSERT_ANNOTATION]``.

        Returns:
            list: A deep copy of the template with placeholders replaced. The
            stored template is left untouched. Non-string contents are copied
            unchanged.
        """
        filled_prompt = deepcopy(self.prompt)
        replacements = {"INSERT_TEXT": text, "INSERT_ANNOTATION": annotation}

        for prompt_dict in filled_prompt:
            content = prompt_dict['content']
            if isinstance(content, str):
                # A callable replacement is inserted literally, so backslashes
                # or group references in the values are not interpreted.
                prompt_dict['content'] = self._PLACEHOLDER_PATTERN.sub(
                    lambda match: replacements[match.group(1)], content
                )

        return filled_prompt