
from typing import Optional, Tuple, List
from datasets import load_dataset, Dataset, concatenate_datasets
import random
from src.utils import set_seed
from multiprocessing import cpu_count
from pathlib import Path
import os
import json
import gzip
import glob
import pandas as pd
import hashlib
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from math import ceil
from secret_utils import DICT_GET_PRIOR



from src.logger import logger

def filter_pile(subset):
    """Build a predicate selecting Pile records that belong to ``subset``.

    The record's ``meta['pile_set_name']`` is compared only after every space
    and parenthesis has been removed from it, so ``subset`` must already be in
    that stripped form (e.g. ``"Wikipediaen"`` for ``"Wikipedia (en)"``).

    Args:
        subset: Stripped Pile set name to keep.

    Returns:
        A callable ``record -> bool``, used with ``Dataset.filter`` in this module.
    """
    # Translation table deleting ' ', '(' and ')'.
    drop_table = str.maketrans('', '', ' ()')

    def keep_record(record):
        set_name = record['meta']['pile_set_name']
        return set_name.translate(drop_table) == subset

    return keep_record



def get_dataset_from_folder_with_gz(folder_path, file_prefix, max_files: Optional[int] = 1):
    """Load the ``text`` field of gzipped JSON-lines shards as a HF dataset.

    Shards are the ``*.json.gz`` files directly inside ``folder_path`` whose
    names start with ``file_prefix``. Only each record's ``text`` field is
    kept; records without it and blank lines are skipped. The kept records
    are written to a temporary ``.jsonl`` file, which is then loaded with
    ``load_dataset('json', ...)``.

    Args:
        folder_path: Directory containing the shards (``str`` or ``Path``).
        file_prefix: Filename prefix selecting the shards.
        max_files: Maximum number of shards to read, taken in sorted filename
            order. Defaults to 1 (the historical behaviour); pass ``None`` to
            read every matching shard.

    Returns:
        Whatever ``load_dataset('json', data_files=[...])`` returns; callers in
        this module index it as ``['train']['text']``.

    Raises:
        FileNotFoundError: If no shard in ``folder_path`` matches ``file_prefix``.
    """
    import tempfile

    folder_path = Path(folder_path)
    # Sort so the shard(s) picked by ``max_files`` are deterministic;
    # ``Path.glob`` yields entries in arbitrary filesystem order.
    data_files = sorted(
        f for f in folder_path.glob('*.json.gz') if f.name.startswith(file_prefix)
    )
    if not data_files:
        raise FileNotFoundError(
            f"No '*.json.gz' files starting with {file_prefix!r} in {folder_path}"
        )
    if max_files is not None:
        data_files = data_files[:max_files]

    def extract_text_column(file_path):
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with gzip.open(file_path, 'rt', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                record = json.loads(line)
                if 'text' in record:
                    yield {'text': record['text']}

    # delete=False: the file must outlive this handle so load_dataset can read
    # it by name. Leaving the ``with`` closes the handle and flushes all writes.
    # NOTE(review): the temporary file is never removed afterwards.
    with tempfile.NamedTemporaryFile(
        mode='w', suffix='.jsonl', delete=False, encoding='utf-8'
    ) as temp_file:
        for file_path in data_files:
            for record in extract_text_column(file_path):
                temp_file.write(json.dumps(record) + '\n')

    return load_dataset('json', data_files=[temp_file.name])


_PILE_URL = 'https://huggingface.co/datasets/monology/pile-uncopyrighted/resolve/main'


def _normalize_pile_name(name):
    """Drop spaces and parentheses, mirroring how ``filter_pile`` compares set names."""
    return name.replace(' ', '').replace('(', '').replace(')', '')


def _load_pile_split(split_name, file_urls, subset, samples_range, n_jobs):
    """Load rows ``samples_range`` of one Pile split and keep only ``subset`` texts.

    ``subset`` is normalized before filtering, so ``"Wikipedia (en)"`` and
    ``"Wikipediaen"`` select the same records. The row slice is applied
    *before* the subset filter, so fewer rows than requested may survive.
    """
    rows = load_dataset(
        "json",
        data_files={split_name: file_urls},
        split=f'{split_name}[{samples_range}]',
        num_proc=n_jobs,
    )
    return rows.filter(filter_pile(_normalize_pile_name(subset)), num_proc=n_jobs)['text']


def get_raw_dataset(dataset, max_samples, **kwargs):
    """Load up to ``max_samples`` raw text strings for a named dataset.

    ``dataset`` selects the source by prefix; the remaining fields pick
    split/subset:

    * ``pile-train_<subset>``, ``pile-test_<subset>``, ``pile-val_<subset>``:
      pile-uncopyrighted shards filtered to one ``pile_set_name`` (``<subset>``
      may be given with or without spaces/parentheses).
    * ``dolma_<file_prefix>``: gzipped shards under ``$DOLMA_DATA_DIR``.
    * ``proof-pile-2_<split>_<subset>``: ``EleutherAI/proof-pile-2``.
    * ``stack_<language>``: ``bigcode/the-stack-dedup`` ``data/<language>``.
    * ``hellaswag_<split>``: context plus gold ending, preprocessed as in
      lm-evaluation-harness.
    * ``mmlu...``: all MMLU test questions rendered with options and answer.
    * ``gsm8k_<split>``: question/answer pairs.
    * ``mimir-<name>-<member|nonmember>``: ``iamgroot42/mimir`` column.
    * ``persona_<split>``: Synthetic-Persona-Chat. A split containing ``all``
      loads train+test+validation; one containing ``canar`` loads
      test+validation. Both ignore ``begin_id``.

    Args:
        dataset: Dataset identifier as described above.
        max_samples: Number of rows to request and keep.
        **kwargs: ``n_jobs`` (worker processes, default
            ``int(1.2 * cpu_count())``) and ``begin_id`` (first row, default 0).

    Returns:
        The texts, truncated to ``max_samples`` and shuffled in place with the
        global ``random`` module.

    Raises:
        ValueError: If ``dataset`` matches no known prefix, or a ``dolma``
            dataset is requested while ``DOLMA_DATA_DIR`` is unset.
    """
    n_jobs = kwargs.get('n_jobs', int(1.2 * cpu_count()))
    begin_id = kwargs.get('begin_id', 0)
    # HF split-slicing range, used as e.g. 'train[0:1000]'.
    samples_range = f"{begin_id}:{begin_id + max_samples}"

    logger.info(f"Loading {dataset} with {max_samples} samples from {samples_range}")

    if dataset.startswith('pile-train'):
        texts = _load_pile_split(
            'train',
            [f'{_PILE_URL}/train/{i:02}.jsonl.zst' for i in range(30)],
            dataset.split('_')[1],
            samples_range,
            n_jobs,
        )

    elif dataset.startswith('pile-test'):
        texts = _load_pile_split(
            'test',
            [f'{_PILE_URL}/test.jsonl.zst'],
            dataset.split('_')[1],
            samples_range,
            n_jobs,
        )

    elif dataset.startswith('pile-val'):
        texts = _load_pile_split(
            'val',
            [f'{_PILE_URL}/val.jsonl.zst'],
            dataset.split('_')[1],
            samples_range,
            n_jobs,
        )

    elif dataset.startswith('dolma'):
        subset = dataset.split('_')[1]
        dolma_dir = os.getenv('DOLMA_DATA_DIR')
        if dolma_dir is None:
            raise ValueError("DOLMA_DATA_DIR must be set to load dolma datasets")
        texts = get_dataset_from_folder_with_gz(Path(dolma_dir), subset)['train']['text']

    elif dataset.startswith('proof-pile-2'):
        split = dataset.split('_')[1]
        subset = dataset.split('_')[2]  # ["algebraic-stack", "arxiv", "open-web-math",]
        texts = load_dataset("EleutherAI/proof-pile-2",
                             subset, split=f'{split}[{samples_range}]',
                             trust_remote_code=True, num_proc=n_jobs)['text']

    elif dataset.startswith('stack'):
        # valid subsets
        # "assembly", "batchfile", "c++", "c", "c-sharp", "cmake", "css", "dockerfile", "fortran", "go", "haskell", "html", "java", "javascript",
        # "julia", "lua", "makefile", "markdown", "perl", "php", "powershell", "python", "ruby", "rust", "scala", "shell", "sql", "tex", "typescript", "visual-basic"
        subset = dataset.split('_')[1]
        texts = load_dataset("bigcode/the-stack-dedup",
                             split=f'train[{samples_range}]',
                             data_dir=f'data/{subset}',
                             num_proc=n_jobs,
                             )['content']

    elif dataset.startswith('hellaswag'):
        split = dataset.split('_')[1]

        # From https://github.com/EleutherAI/lm-evaluation-harness/blob/867413f8677f00f6a817262727cbb041bf36192a/lm_eval/tasks/hellaswag/utils.py
        def preprocess(text):
            import re
            text = text.strip()
            # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
            text = text.replace(" [title]", ". ")
            text = re.sub(r"\[.*?\]", "", text)
            text = text.replace("  ", " ")
            return text

        def process_docs(docs):
            def _process_doc(doc):
                ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
                # Query followed by the gold ending only.
                return {
                    "text": preprocess(doc["activity_label"] + ": " + ctx) + ' ' + preprocess(doc["endings"][int(doc["label"])]),
                }

            return docs.map(_process_doc)

        texts = process_docs(load_dataset('Rowan/hellaswag', split=f'{split}[{samples_range}]'))['text']

    elif dataset.startswith('mmlu'):
        def process_mmlu(x):
            question = x["question"].strip()
            choices = x["choices"]
            option_a = choices[0].strip()
            option_b = choices[1].strip()
            option_c = choices[2].strip()
            option_d = choices[3].strip()
            answer = chr(x["answer"] + ord('A'))
            return {
                'text': f"{question}\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}\nAnswer: {answer}",
            }

        # The whole test split is mapped, then sliced to the requested rows.
        texts = load_dataset('cais/mmlu', 'all')['test'].map(process_mmlu)['text'][begin_id:begin_id + max_samples]

    elif dataset.startswith('gsm8k'):
        split = dataset.split('_')[1]

        def process(x):
            return {
                'text': f'Question: {x["question"]}\nAnswer: {x["answer"]}',
            }

        texts = load_dataset('openai/gsm8k', 'main', split=f'{split}[{samples_range}]').map(process)['text']

    elif dataset.startswith('mimir'):  # e.g. mimir-arxiv-member
        dataset_name = dataset.split('-')[1]  # [arxiv, dm_mathematics, github, hackernews, pile_cc, pubmed_central, wikipedia_(en), full_pile, c4, temporal_arxiv, temporal_wiki]
        is_member = dataset.split('-')[2]  # [member, nonmember]
        split = 'none' if dataset_name in ['full_pile', 'c4'] else 'ngram_13_0.8'
        texts = load_dataset('iamgroot42/mimir', dataset_name, split=f'{split}[{samples_range}]', trust_remote_code=True)[is_member]

    elif dataset.startswith('persona'):
        def preprocess_persona(example):
            return {"text": f"User 1 Personas:\n{example['user 1 personas']}\n\nUser 2 Personas:\n{example['user 2 personas']}\n\nConversation:\n{example['Best Generated Conversation']}"}

        split = dataset.split('_')[1]

        if "all" in split:
            parts = ['train', 'test', 'validation']
        elif 'canar' in split:
            parts = ['test', 'validation']
        else:
            parts = None

        if parts is None:
            texts = load_dataset("google/Synthetic-Persona-Chat",
                                 split=f'{split}[{samples_range}]',
                                 trust_remote_code=True,
                                 num_proc=n_jobs).map(preprocess_persona)['text']
        else:
            # Whole splits are concatenated; begin_id is not applied here.
            texts = concatenate_datasets([
                load_dataset("google/Synthetic-Persona-Chat", split=part, trust_remote_code=True, num_proc=n_jobs)
                for part in parts
            ]).map(preprocess_persona)['text']

    else:
        raise ValueError(f"Group {dataset} not found")

    logger.info(f'text length: {len(texts)}')
    texts = texts[:max_samples]
    # Shuffles in place; presumably requires `texts` to be a mutable list,
    # which holds for the column accesses above in the datasets version in use.
    random.shuffle(texts)
    return texts


def normalize_tokens(tokens, tokenizer, max_length):
    """Decode token ids back to text and encode that text again.

    The ids in ``tokens['input_ids']`` are decoded with special tokens
    skipped. The resulting strings are then re-encoded as a padded, truncated
    batch with ``return_tensors="pt"`` and an attention mask.

    Args:
        tokens: Tokenizer output holding an ``input_ids`` entry.
        tokenizer: Object exposing ``batch_decode`` and being callable for
            encoding (HF tokenizer interface).
        max_length: Truncation length for the re-encoding.

    Returns:
        Whatever ``tokenizer(...)`` returns for the decoded texts.
    """
    encode_options = {
        "return_tensors": "pt",
        "return_attention_mask": True,
        "padding": True,
        "truncation": True,
        "max_length": max_length,
    }
    plain_texts = tokenizer.batch_decode(tokens['input_ids'], skip_special_tokens=True)
    return tokenizer(plain_texts, **encode_options)


def process_text(texts, tokenizer, max_length):
    """Tokenize ``texts``, then pass the encoding through ``normalize_tokens``.

    The first pass pads and truncates to ``max_length`` with
    ``return_tensors="pt"`` and an attention mask. ``normalize_tokens`` then
    decodes those ids (skipping special tokens) and re-encodes them with the
    same settings.

    Args:
        texts: Text(s) accepted by ``tokenizer``.
        tokenizer: HF-style tokenizer.
        max_length: Truncation length for both passes.

    Returns:
        The result of ``normalize_tokens`` on the first-pass encoding.
    """
    encode_kwargs = {
        "return_tensors": "pt",
        "padding": True,
        "truncation": True,
        "max_length": max_length,
        "return_attention_mask": True,
    }
    first_pass = tokenizer(texts, **encode_kwargs)
    return normalize_tokens(first_pass, tokenizer, max_length)


def process_canaries(files, cardinality, available_secrets, threshold=15_000):
    """Generate canary strings from pickled secret-template tables.

    Each file is loaded with ``pd.read_pickle`` as a DataFrame that has the
    columns ``secret_type``, ``prefix``, ``suffix``, ``extra`` and ``dataset``.
    Rows whose ``secret_type`` is not in ``available_secrets`` are dropped.
    The remaining rows are grouped by those five columns. For each group,
    ``DICT_GET_PRIOR[secret_type](extra, cardinality)`` yields secrets, and
    each secret is rendered as ``f'{prefix}_{secret}_{suffix}'``.

    Before generating, ``set_seed`` is called with a value derived from a
    SHA-256 of the group key. A group's secrets therefore do not depend on
    which groups were processed earlier. This assumes the prior functions
    draw randomness from the generators ``set_seed`` seeds — TODO confirm.

    Args:
        files: Iterable of paths to pickled DataFrames. Only load trusted
            files: unpickling can execute arbitrary code.
        cardinality: Number of secrets requested per group.
        available_secrets: Secret types to keep.
        threshold: Stop and return once this many groups have produced
            canaries successfully.

    Returns:
        list[str]: Canaries in processing order, with each group's canaries
        contiguous.
    """
    canaries = []
    processed_groups = 0
    group_keys = ['secret_type', 'prefix', 'suffix', 'extra', 'dataset']
    for file_path in files:
        df = pd.read_pickle(file_path)
        df = df[df['secret_type'].isin(available_secrets)]
        for (secret_type, prefix, suffix, extra, dataset), _group in df.groupby(group_keys):
            # Deterministic per-group seed: low 32 bits of SHA-256 of the key.
            key_string = str(secret_type) + str(prefix) + str(suffix) + str(extra) + str(dataset)
            seed = int(hashlib.sha256(key_string.encode()).hexdigest(), 16) & 0xFFFFFFFF
            set_seed(seed)

            # Only secret generation/formatting is guarded; a failing group is
            # logged and skipped without contributing partial output.
            try:
                group_canaries = [
                    f'{prefix}_{s}_{suffix}'
                    for s in DICT_GET_PRIOR[secret_type](extra, cardinality)
                ]
            except ValueError as e:
                logger.error(f"Error processing {secret_type}: {e}")
                continue

            canaries.extend(group_canaries)
            processed_groups += 1
            logger.debug(f"Processed canary groups: {processed_groups}")
            if processed_groups >= threshold:
                return canaries

    return canaries


def get_preprocessed_dataset(
    dataset: str | List[str],
    tokenizer,
    max_length,
    max_samples=50_000,
    canaries_setup: Optional[Tuple[int, List[str]]] = None,
    canary_glob: str = 'canary_mia/filtered/pile*pkl',
    ):
    """Load, tokenize and optionally canary-prefix a text dataset.

    Calls ``set_seed(42)`` and loads raw texts with ``get_raw_dataset``.

    Without ``canaries_setup``, the texts are tokenized with ``process_text``
    and returned as one torch-formatted ``Dataset``.

    With ``canaries_setup = (cardinality, available_secrets)``, each text is
    paired with its own chunk of ``cardinality`` canaries. The canaries come
    from ``process_canaries``, run on the pickles matching ``canary_glob``
    whose paths do not contain ``'train'``. Each canary is prepended as
    ``f'{canary} {text}'``, and one ``Dataset`` is built per text from its
    variants.

    Args:
        dataset: Identifier forwarded to ``get_raw_dataset``. That function
            only handles a single ``str``.
        tokenizer: HF-style tokenizer used by ``process_text``.
        max_length: Truncation length.
        max_samples: Maximum number of raw texts to load.
        canaries_setup: Optional ``(cardinality, available_secrets)``.
        canary_glob: Glob pattern (relative to the CWD unless absolute) for
            the pickled canary tables.

    Returns:
        A torch-formatted ``Dataset``. When ``canaries_setup`` is given, it
        instead returns a list of ``Dataset`` objects (not torch-formatted),
        one per text whose canary chunk is full.

    Raises:
        ValueError: If no canaries can be generated, if there are fewer
            canary chunks than texts, or if duplicate canary-prefixed texts
            are produced.
    """
    set_seed(42)
    list_of_texts = get_raw_dataset(dataset, max_samples)

    if canaries_setup is None:
        encoded = process_text(list_of_texts, tokenizer, max_length)
        return Dataset.from_dict(encoded).with_format("torch")

    cardinality, available_secrets = canaries_setup
    # Resolve the canary tables once, in a stable order: glob order is
    # filesystem-dependent, and process_canaries' output depends on file order.
    canary_files = sorted(path for path in glob.glob(canary_glob) if 'train' not in path)

    # First pass with cardinality=1 estimates how many groups can yield
    # canaries (assumes each prior returns exactly `cardinality` secrets).
    available_rows = process_canaries(
        files=canary_files,
        cardinality=1,
        available_secrets=available_secrets,
    )
    logger.info(f"Number of available rows: {len(available_rows)}")
    if not available_rows:
        raise ValueError(
            f"No canaries could be generated from {canary_glob!r} "
            f"for secret types {available_secrets}"
        )
    # Scale the per-group request so the groups together cover every text.
    ratio = ceil(len(list_of_texts) / len(available_rows))

    list_of_canaries = process_canaries(
        files=canary_files,
        cardinality=cardinality * ratio,
        available_secrets=available_secrets,
    )

    # Consecutive chunks of `cardinality` canaries; the last may be shorter.
    canaries = [
        list_of_canaries[i:i + cardinality]
        for i in range(0, len(list_of_canaries), cardinality)
    ]
    if len(canaries) < len(list_of_texts):
        raise ValueError(
            f"Only {len(canaries)} canary chunks for {len(list_of_texts)} texts"
        )

    combined = [
        [f'{c} {text}' for c in chunk]
        for text, chunk in zip(list_of_texts, canaries)
    ]

    # Every canary-prefixed text must be unique.
    flat = [item for chunk in combined for item in chunk]
    if len(set(flat)) != len(flat):
        raise ValueError("Duplicates found")

    logger.info(f"Cardinality: {cardinality} (number of combined texts: {len(combined)})")
    # Texts whose canary chunk came up short are dropped.
    return [
        Dataset.from_dict(process_text(chunk, tokenizer, max_length))
        for chunk in combined
        if len(chunk) == cardinality
    ]