from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset, load_dataset
import json
import torch

def load_model(model_name):
    """Load a causal language model and its tokenizer.

    The model is placed on ``cuda:0`` with bfloat16 weights. The tokenizer's
    pad token is set to its end-of-sequence token and its maximum length is
    capped at 1024.

    Args:
        model_name: Identifier or path passed to ``from_pretrained`` for both
            the model and the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    causal_lm = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="cuda:0",
        torch_dtype=torch.bfloat16,
    )

    tok = AutoTokenizer.from_pretrained(model_name)
    # Reuse EOS as the padding token so padded batches can be built.
    tok.pad_token = tok.eos_token
    tok.model_max_length = 1024

    return causal_lm, tok

def load_wikitext_dataset(num_points=100000):
    """Return a shuffled sample of non-empty WikiText-103 training rows.

    Args:
        num_points: Number of rows to keep after filtering and shuffling.

    Returns:
        The first ``num_points`` rows of the ``train`` split of
        ``Salesforce/wikitext`` (``wikitext-103-raw-v1``) after dropping rows
        whose ``text`` is the empty string and shuffling with seed 42.
    """
    train_split = load_dataset("Salesforce/wikitext", 'wikitext-103-raw-v1')['train']
    return (
        train_split
        .filter(lambda x: x['text'] != '')
        .shuffle(seed=42)
        .select(range(num_points))
    )

def load_canary_dataset(dataset_name):
    """Load a canary dataset and return it as ``train``/``test`` splits.

    Supported names:

    * ``'ai4privacy'``: texts read from ``../ai4priv.json`` (relative to the
      current working directory), split 50/50 with seed 42.
    * ``'ag_news'``: ``fancyzhx/ag_news`` with the ``label`` column removed,
      shuffled with seed 42, first 500 rows of each existing split kept.
    * ``'mimir'``: the ``none`` split of ``iamgroot42/mimir`` (``full_pile``),
      keeping only the ``nonmember`` column (renamed to ``text``), shuffled,
      split 50/50 with seed 42, first 500 rows of each half kept.

    For ``'mimir'`` the Hugging Face access token is read from the
    ``HF_TOKEN`` environment variable. When it is unset, ``None`` is passed
    and the library falls back to its default token lookup.

    Args:
        dataset_name: One of ``'ai4privacy'``, ``'ag_news'`` or ``'mimir'``.

    Returns:
        A dataset dictionary with ``'train'`` and ``'test'`` splits.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    # Local import keeps this block self-contained; only 'mimir' needs it.
    import os

    if dataset_name == 'ai4privacy':
        # Use a context manager so the file handle is closed deterministically.
        with open('../ai4priv.json', encoding='utf-8') as f:
            texts = json.load(f)
        dataset = Dataset.from_dict({'text': texts})
        return dataset.train_test_split(test_size=0.5, seed=42)

    if dataset_name == 'ag_news':
        dataset = load_dataset("fancyzhx/ag_news")
        dataset = dataset.remove_columns('label')
        dataset = dataset.shuffle(seed=42)
        dataset['train'] = dataset['train'].select(range(500))
        dataset['test'] = dataset['test'].select(range(500))
        return dataset

    if dataset_name == 'mimir':
        # Never hard-code credentials in source: take the token from the
        # environment instead.
        hf_token = os.environ.get("HF_TOKEN")
        dataset = load_dataset("iamgroot42/mimir", "full_pile", token=hf_token)['none']
        dataset = dataset.remove_columns(['member', 'member_neighbors', 'nonmember_neighbors'])
        dataset = dataset.rename_column("nonmember", "text")
        dataset = dataset.shuffle(seed=42)
        dataset = dataset.train_test_split(test_size=0.5, seed=42)
        dataset['train'] = dataset['train'].select(range(500))
        dataset['test'] = dataset['test'].select(range(500))
        return dataset

    raise ValueError(f"Unsupported canary dataset: {dataset_name}")

def load_canary_parent_dataset():
    """Return English parent-dataset texts that are not canaries.

    Loads the ``train`` split of ``ai4privacy/pii-masking-200k``, drops every
    row whose ``source_text`` appears in ``../ai4priv.json`` (path relative to
    the current working directory) or whose ``language`` is not ``'en'``,
    shuffles with seed 42, and keeps only the text column.

    Returns:
        A dataset with a single ``text`` column (renamed from
        ``source_text``).
    """
    parent_dataset = load_dataset("ai4privacy/pii-masking-200k")['train']

    # Use a context manager so the file handle is closed deterministically.
    with open('../ai4priv.json', encoding='utf-8') as f:
        # Build the lookup set once: per-row membership tests are then O(1)
        # instead of re-reading and scanning the whole canary column per row.
        # Assumes the JSON file holds a list of strings - TODO confirm.
        canary_texts = set(json.load(f))

    def _is_english_non_canary(row):
        # Cheap language check first, then the set lookup.
        return row['language'] == 'en' and row['source_text'] not in canary_texts

    remaining = parent_dataset.filter(_is_english_non_canary, num_proc=40)
    # NOTE: len(remaining) is not expected to equal
    # len(parent_dataset) - len(canary_texts), because non-English rows are
    # dropped as well.
    remaining = remaining.shuffle(seed=42)
    remaining = remaining.rename_column('source_text', 'text')
    extra_columns = [col for col in remaining.column_names if col != 'text']
    return remaining.remove_columns(extra_columns)