import torch
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import datasets
# Get the most memorized samples from a dataset using Pythia-1b,
# with perplexity as the criterion:
#   1. chunk the dataset into sequences of the same length (e.g. 128 tokens)
#   2. run inference and compute the perplexity of each chunk
#   3. rank the chunks by perplexity to pick the most memorized ones for the
#      number of samples we want
#   4. save the selection so the dataset can be reused

def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like the previous zero-argument form.

    Returns:
        An ``argparse.Namespace`` with ``cache_dir``, ``max_length`` and
        ``output_dir``.

    Exits (via ``parser.error``, status 2) when ``--max_length`` is not a
    positive integer, because chunks are built with ``reshape(-1, max_length)``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Chunk dataset subsets into fixed-length token sequences, score each "
            "chunk's perplexity with Pythia-1b and save the chunks grouped into "
            "perplexity-range bins."
        )
    )
    parser.add_argument('--cache_dir', type=str, default="cache",
                        help="Cache directory passed to datasets.load_dataset.")
    parser.add_argument('--max_length', type=int, default=128,
                        help="Number of tokens per chunk (must be positive).")
    parser.add_argument('--output_dir', type=str, default="memorized_dataset",
                        help="Directory where the binned datasets are saved.")

    args = parser.parse_args(argv)
    if args.max_length <= 0:
        parser.error(f"--max_length must be positive, got {args.max_length}")
    return args
def _chunk_perplexities(model, tokenizer, text, max_length, device):
    """Split ``text`` into ``max_length``-token chunks and score each chunk.

    The documents in ``text`` are joined with the EOS token into one string,
    tokenized without truncation (so no text is dropped) and padded up to a
    multiple of ``max_length``, then reshaped into fixed-length chunks.

    Returns:
        ``(perplexities, chunk_texts)``: two equally long lists holding the
        perplexity of each scored chunk and its decoded text (padding removed).
    """
    encoded = tokenizer(
        tokenizer.eos_token.join(text),
        return_tensors="pt",
        padding=True,
        pad_to_multiple_of=max_length,
    )
    input_ids = encoded["input_ids"].reshape(-1, max_length)
    attention_mask = encoded["attention_mask"].reshape(-1, max_length)

    perplexities, chunk_texts = [], []
    with torch.no_grad():
        for ids, mask in zip(input_ids, attention_mask):
            real = mask.bool()
            # A next-token loss needs at least two real tokens; otherwise the
            # masked mean below is 0/0. Only the padded chunk can fall short.
            if int(real.sum()) < 2:
                continue
            inp = ids.unsqueeze(0).to(device)
            att = mask.unsqueeze(0).to(device)
            # pad_token == eos_token, so padding must be excluded from the loss
            # explicitly (-100 is ignored); otherwise "EOS -> EOS" predictions
            # on the padding skew the last chunk's perplexity.
            labels = inp.masked_fill(att == 0, -100)
            outputs = model(input_ids=inp, attention_mask=att, labels=labels)
            perplexities.append(torch.exp(outputs.loss).item())
            # Decode only real tokens; the EOS separators between documents
            # are real tokens and are kept, the padding is dropped.
            chunk_texts.append(tokenizer.decode(ids[real]))
    return perplexities, chunk_texts


def _save_perplexity_bins(perplexity, chunked_text, subset_list, output_dir,
                          split_name, num_bins=20):
    """Sort chunks by perplexity and save them as ``num_bins`` datasets.

    Each bin holds a contiguous range of the sorted chunks (the last bin also
    takes the remainder) and is written with ``save_to_disk`` to
    ``{output_dir}/{split_name}_{min_perplexity}_{max_perplexity}``, with a
    ``text`` column (chunk text) and a ``source`` column (subset name).
    NOTE(review): two bins with identical (min, max) would share a path.
    """
    if not perplexity:
        print(f"No chunks were scored for split {split_name}; nothing saved")
        return
    sorted_perplexity, indices = torch.sort(torch.tensor(perplexity))
    indices = indices.tolist()
    num_samples = len(indices)
    # With fewer chunks than bins, num_samples // num_bins would be 0 and all
    # but the last bin would be empty; shrink the bin count instead.
    num_bins = min(num_bins, num_samples)
    samples_per_bin = num_samples // num_bins

    for i in range(num_bins):
        start_idx = i * samples_per_bin
        end_idx = (i + 1) * samples_per_bin if i < num_bins - 1 else num_samples
        bin_indices = indices[start_idx:end_idx]
        perplexity_range = (sorted_perplexity[start_idx].item(),
                            sorted_perplexity[end_idx - 1].item())
        new_sub_set = datasets.Dataset.from_dict({
            "text": [chunked_text[idx] for idx in bin_indices],
            "source": [subset_list[idx] for idx in bin_indices],
        })
        # The perplexity range of the bin is encoded in the directory name.
        new_sub_set.save_to_disk(
            f"{output_dir}/{split_name}_{perplexity_range[0]}_{perplexity_range[1]}")


def main():
    """Score every chunk of each subset's split with Pythia-1b and save the
    chunks grouped into 20 perplexity-range bins under ``args.output_dir``."""
    args = parse_args()
    print(args)
    model_name = "EleutherAI/pythia-1b"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    max_length = args.max_length
    subsets = ['arxiv', 'bookcorpus2', 'books3', 'cc', 'enron', 'europarl',
               'freelaw', 'github', 'gutenberg', 'hackernews', 'math', 'nih',
               'opensubtitles', 'openwebtext2', 'philpapers', 'stackexchange',
               'ubuntu', 'uspto', 'wikipedia', 'youtubesubtitles']
    splits = ["val"]
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    for split_name in splits:
        # Parallel lists: one entry per scored chunk across all subsets.
        perplexity = []
        subset_list = []
        chunked_text = []
        for subset in subsets:
            print(f"Processing {subset} {split_name}")
            dataset = load_dataset("pratyushmaini/llm_dataset_inference", subset, split=split_name,
                                   trust_remote_code=True, cache_dir=args.cache_dir)
            chunk_ppl, chunk_texts = _chunk_perplexities(
                model, tokenizer, dataset['text'], max_length, device)
            perplexity.extend(chunk_ppl)
            chunked_text.extend(chunk_texts)
            subset_list.extend([subset] * len(chunk_ppl))

        _save_perplexity_bins(perplexity, chunked_text, subset_list,
                              args.output_dir, split_name)
if __name__ == "__main__":
    # Entry point only when executed as a script, not when imported.
    main()
      