from datasets import load_dataset
import torch
import random
import numpy as np
import transformers
import os


def set_seed(seed: int):
    """
    Seed every random number generator used here for reproducible experiments.

    Seeds Python's ``random``, NumPy, PyTorch and ``transformers``. When CUDA
    is available the CUDA generators are seeded as well and cuDNN is put in
    deterministic mode with benchmarking turned off.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    # NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this
    # assignment is only picked up by processes spawned afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)
    transformers.set_seed(seed)

def get_raw_dataset(dataset_name, cache_dir, max_samples):
    """
    Load the raw samples of one of the supported datasets.

    Args:
        dataset_name: ``'german_wiki'``, ``'samsum'``, ``'mem_pile'`` or a
            name of the form ``'pile_<subset>_<split>'``. ``<subset>`` may
            itself contain underscores; the last ``_``-separated field is
            always taken as the split.
        cache_dir: Forwarded to ``load_dataset`` as ``cache_dir``.
        max_samples: Only the first ``max_samples`` samples are kept.

    Returns:
        The truncated samples. For every dataset except ``'mem_pile'`` this
        is a list that has been shuffled with the ``random`` module; for
        ``'mem_pile'`` it is a slice of whatever ``torch.load`` returned, in
        its stored order.

    Raises:
        ValueError: If ``dataset_name`` is not supported or a ``pile`` name
            does not contain both a subset and a split.
    """
    if dataset_name == 'german_wiki':
        # list(...) materializes the column: recent `datasets` releases
        # return a lazy Column object here, which cannot be shuffled in place.
        text = list(load_dataset("Cohere/wikipedia-22-12-de-embeddings",
                                 split='train[:10000]', trust_remote_code=True, cache_dir=cache_dir,
        )['text'])
    elif dataset_name == 'samsum':
        dataset = load_dataset("Samsung/samsum", trust_remote_code=True, cache_dir=cache_dir).map(
                lambda x: {'text': f"Summarize the following conversation. \n\n### Input:\n{x['dialogue']} \n\nSummary:\n"},
                batched=False,
                load_from_cache_file=False,
            )
        text = [
            sample
            for part in ('train', 'validation', 'test')
            for sample in dataset[part]['text']
        ]
    elif dataset_name.startswith('pile'):
        # Parse from both ends so that subsets containing underscores
        # (e.g. "pile_pubmed_abstracts_train") are handled correctly.
        _, _, remainder = dataset_name.partition('_')
        subset, separator, split = remainder.rpartition('_')
        if not separator:
            raise ValueError(
                f"Dataset {dataset_name} not supported: expected 'pile_<subset>_<split>'"
            )
        text = list(load_dataset("pratyushmaini/llm_dataset_inference", subset, split=split, trust_remote_code=True, cache_dir=cache_dir)['text'])
    elif dataset_name == 'mem_pile':
        # Pre-tokenized samples stored on disk; kept as loaded (no list()
        # conversion) because the caller treats them as tensors.
        text = torch.load("./datasets/pile_mem/pile_bs0-100-dedup.pt")
    else:
        raise ValueError(f"Dataset {dataset_name} not supported")

    print('text length:', len(text))
    # Truncate first, then shuffle: the result is always a permutation of
    # the first `max_samples` samples.
    text = text[:max_samples]
    if dataset_name != "mem_pile":
        random.shuffle(text)
    return text


def process_text(text, tokenizer, max_length):
    """
    Tokenize all samples as one stream and cut it into fixed-length rows.

    The samples are joined with ``tokenizer.eos_token``, tokenized in a single
    call (padded to a multiple of ``max_length``) and the resulting ids are
    reshaped to ``(-1, max_length)``.
    """
    joined = tokenizer.eos_token.join(text)
    encoding = tokenizer(
        joined,
        return_tensors="pt",
        pad_to_multiple_of=max_length,
        return_attention_mask=False,
        padding=True,
    )
    token_ids = encoding.input_ids
    return token_ids.reshape(-1, max_length)

def normalize_tokens(tokens, tokenizer, max_length):
    """
    Round-trip token ids through text and back to ids.

    The rows of ``tokens`` are decoded with special tokens skipped, then the
    decoded strings are tokenized again with padding and with truncation to
    ``max_length``. Returns the ``input_ids`` of that second tokenization.
    """
    decoded = tokenizer.batch_decode(tokens, skip_special_tokens=True)
    encoding = tokenizer(
        decoded,
        return_tensors="pt",
        return_attention_mask=False,
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    return encoding.input_ids


def add_prefix(curr_set, ratio_change, prefix_length, tokenizer, max_length, prefix_type, freq, topk):
    """
    Overwrite the start of a random subset of samples with a token prefix.

    For ``int(len(curr_set) * ratio_change)`` randomly chosen rows, a prefix
    of ``prefix_length`` tokens is placed in front of the row and the same
    number of tokens is dropped from the row's end, so the width is unchanged.
    Each chosen row draws its prefix (with replacement) from its own pool of
    candidate tokens.

    Args:
        curr_set: 2-D tensor of token ids, one sample per row. The selected
            rows are modified IN PLACE (unless ``prefix_type == 'none'``).
        ratio_change: Fraction of rows that receive a prefix.
        prefix_length: Number of prefix tokens per selected row.
        tokenizer: Tokenizer used for re-normalization, for ``vocab_size``
            (``'random'``) and for ``encode`` (``'invisible'``).
        max_length: Forwarded to ``normalize_tokens``.
        prefix_type: How the candidate pools are built:
            ``'none'``      - no prefix, only normalization;
            ``'rare'``      - token ids with the lowest non-zero count in ``freq``;
            ``'common'``    - token ids with the highest count in ``freq``;
            ``'random'``    - random ids below ``tokenizer.vocab_size``;
            ``'invisible'`` - the encoding of a string of invisible
                              characters, shared by all selected rows.
        freq: 1-D tensor of per-token-id counts. It is not modified.
        topk: Pool size per selected row (ignored for ``'invisible'``).

    Returns:
        A 4-tuple ``(tokens, selected_samples, prefix_tokens, selected_tokens)``:
        the normalized token ids, the indices of the modified rows, the prefix
        written to each of them, and the candidate pool of each of them. The
        last three are ``None`` for ``prefix_type == 'none'``.

    Raises:
        ValueError: If ``prefix_type`` is not one of the values above.
    """
    if prefix_type == 'none':
        return normalize_tokens(curr_set, tokenizer, max_length), None, None, None

    n_selected_samples = int(len(curr_set) * ratio_change)
    n_candidates = topk * n_selected_samples

    if prefix_type == 'rare':
        # Work on a copy: the caller reuses the same `freq` for several calls
        # and must not see its zero counts overwritten.
        ranked = freq.clone()
        # Push unseen (count 0) ids to the end of the ascending order so that
        # only ids that occur at least once are picked first.
        ranked[ranked == 0] = torch.max(ranked) + 1
        selected_tokens = ranked.argsort(descending=False)[:n_candidates]
    elif prefix_type == 'common':
        selected_tokens = freq.argsort(descending=True)[:n_candidates]
    elif prefix_type == 'random':
        selected_tokens = torch.randperm(tokenizer.vocab_size)[:n_candidates]
    elif prefix_type == 'invisible':
        invisible_chars = [
             '\u200a',  # Hair Space
             '\u200b',  # Zero Width Space
             '\u200c',  # Zero Width Non-Joiner
             '\u202f',  # Narrow No-Break Space
             '\xa0',  # Non-breaking Space
             '\xad',  # Soft Hyphen
             '\ufeff',  # Zero Width No-Break Space
             '\u2028',  # Line Separator
             '\u2009',  # thin space
             '\u3000',  # ideographic space
        ]
        selected_tokens = torch.tensor(tokenizer.encode("".join(invisible_chars)))
    else:
        raise ValueError(f"Unknown prefix type: {prefix_type}")

    # Turn the flat candidate list into one pool (row) per selected sample.
    if prefix_type == 'invisible':
        selected_tokens = selected_tokens.unsqueeze(0).repeat(n_selected_samples, 1)
    else:
        shuffled = selected_tokens.flatten()[torch.randperm(len(selected_tokens))]
        selected_tokens = shuffled.reshape(n_selected_samples, topk)

    selected_samples = torch.randperm(len(curr_set))[:n_selected_samples]

    # One randint call per row, in row order, to keep the random stream (and
    # therefore seeded results) identical to a per-sample draw.
    pool_size = selected_tokens.shape[1]
    prefix_rows = [t[torch.randint(0, pool_size, (prefix_length,))] for t in selected_tokens]
    if prefix_rows:
        curr_prefix_tokens = torch.stack(prefix_rows)
    else:
        # torch.stack rejects an empty list; no sample selected means an
        # empty prefix block.
        curr_prefix_tokens = selected_tokens.new_empty((0, prefix_length))

    # Explicit width instead of `:-prefix_length`, which would select nothing
    # when prefix_length is 0.
    kept_width = curr_set.shape[1] - prefix_length
    curr_set[selected_samples] = torch.cat(
        [curr_prefix_tokens, curr_set[selected_samples][:, :kept_width]], dim=1
    )
    return normalize_tokens(curr_set, tokenizer, max_length), selected_samples, curr_prefix_tokens, selected_tokens



def get_preprocessed_dataset(dataset, cache_dir, tokenizer, max_length,
                             split_id, topk,
                             prefix_type, prefix_length, ratio_change,
                             z_ratio=0.1, max_samples=50_000,
                             train_size=2000, val_size=2000, z_size=500,
                             ):
    """
    Build the train / val / z token sets for one split, with optional prefixes.

    The raw data is loaded under a fixed seed, the first ``z_ratio`` share is
    set aside as the "z" set, and the remainder is divided into train and val
    by a random mask. Each set is tokenized, cut to its size limit and passed
    through ``add_prefix``.

    Args:
        dataset: Dataset name understood by ``get_raw_dataset``.
        cache_dir: Forwarded to ``get_raw_dataset``.
        tokenizer: Tokenizer used for tokenization and prefix handling.
        max_length: Row length used by ``process_text`` / ``add_prefix``.
        split_id: Split selector. ``split_id // 2`` picks the seed and
            ``split_id % 2`` picks which of two complementary masks is used,
            so ids ``2k`` and ``2k + 1`` swap train and val.
        topk, prefix_type, prefix_length, ratio_change: Forwarded to
            ``add_prefix``.
        z_ratio: Share of the samples reserved for the z set.
        max_samples: Forwarded to ``get_raw_dataset``.
        train_size, val_size, z_size: Maximum number of rows of each set that
            are passed to ``add_prefix``.

    Returns:
        Dict with ``'keep'`` (the train mask), ``'prefix_type'``,
        ``'dataset'`` and, for each of ``train`` / ``val`` / ``z``, the keys
        ``'<name>_tokens'``, ``'<name>_selected_samples'``,
        ``'<name>_prefix_tokens'`` and ``'<name>_selected_tokens'`` holding
        the four values returned by ``add_prefix``.
    """
    set_seed(42)
    text = get_raw_dataset(dataset, cache_dir, max_samples)
    set_seed(42 + split_id//2)

    z_sz = int(z_ratio*len(text))
    z_text = text[:z_sz]
    full_train_text = text[z_sz:]

    # argsort over dim 0 of a 2 x N random matrix yields, per column, a
    # permutation of (0, 1); its two rows are therefore complementary random
    # 0/1 masks, and split_id % 2 chooses one of them.
    keep = torch.rand((2, len(full_train_text))).argsort(0)[split_id%2].bool()
    train_text = [a for a, b in zip(full_train_text, keep) if b]
    val_text = [a for a, b in zip(full_train_text, keep) if not b]

    if dataset != 'mem_pile':
        train_text = process_text(train_text, tokenizer, max_length)
        val_text = process_text(val_text, tokenizer, max_length)
        z_text = process_text(z_text, tokenizer, max_length)
    else:
        # mem_pile samples are already token tensors; only stacking is needed.
        train_text = torch.stack(train_text)
        val_text = torch.stack(val_text)
        # NOTE(review): the stored object is presumably a 2-D tensor, in which
        # case the slice is already a tensor; stack only if it is a sequence.
        if not torch.is_tensor(z_text):
            z_text = torch.stack(list(z_text))

    freq = torch.bincount(torch.cat([train_text.flatten(), val_text.flatten(), z_text.flatten()]))
    # Padding / EOS must never count as frequent tokens. bincount's length is
    # (largest observed id + 1), so a special id that never occurs in the data
    # can lie outside `freq`; such an id has no count to clear.
    for special_id in (tokenizer.pad_token_id, tokenizer.eos_token_id):
        if special_id is not None and 0 <= special_id < len(freq):
            freq[special_id] = 0

    result = {
        'keep': keep,
        'prefix_type': prefix_type,
        'dataset': dataset,
    }

    subsets = (
        ('train', train_text, train_size),
        ('val', val_text, val_size),
        ('z', z_text, z_size),
    )
    for name, tokens, limit in subsets:
        (
            result[f'{name}_tokens'],
            result[f'{name}_selected_samples'],
            result[f'{name}_prefix_tokens'],
            result[f'{name}_selected_tokens'],
        ) = add_prefix(tokens[:limit], ratio_change, prefix_length, tokenizer, max_length, prefix_type, freq, topk)
    return result
