import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
from torch.utils.data import Dataset, random_split
from tqdm import tqdm
import json

class TextDataset(Dataset):
    """Map-style dataset that tokenizes one text per item to a fixed length.

    Args:
        texts: Indexable, sized collection of strings (anything supporting
            ``len()`` and integer ``[]`` access).
        tokenizer: Callable invoked as
            ``tokenizer(text, truncation=True, max_length=..., return_tensors="pt",
            padding="max_length")`` that returns a mapping containing
            ``"input_ids"`` and ``"attention_mask"`` tensors with a leading
            batch axis of size 1.
        max_length: Length every sequence is truncated/padded to.
    """

    def __init__(self, texts, tokenizer, max_length=512):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize ``self.texts[idx]`` and return its 1-D tensors.

        Returns:
            dict with ``"input_ids"`` and ``"attention_mask"``, each with the
            tokenizer's leading batch axis removed.
        """
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            padding="max_length",
        )
        # Drop only the batch axis. A bare ``squeeze()`` would also collapse
        # the sequence axis when it has length 1, producing 0-d tensors that
        # no longer collate into (batch, seq) batches.
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
        }


def calculate_perplexity(model, dataloader, device="cuda"):
    """Compute one perplexity value per sequence in ``dataloader``.

    Perplexity is ``exp(mean negative log-likelihood)`` over the tokens of a
    sequence that are actually predicted: position ``t`` is scored from the
    logits at position ``t - 1``, and padded positions (``attention_mask == 0``)
    are excluded.

    Args:
        model: Causal language model. It is called as
            ``model(input_ids, attention_mask=attention_mask)`` and must return
            an object with a ``logits`` attribute of shape
            ``(batch, seq_len, vocab_size)``.
        dataloader: Iterable of batches, each a mapping with ``"input_ids"``
            and ``"attention_mask"`` tensors of shape ``(batch, seq_len)``.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        list[float]: One perplexity per sequence, in iteration order. A
        sequence with no scorable target token (e.g. a single non-padded
        token) yields ``nan`` because its mean loss is undefined.
    """
    model.eval()
    perplexities = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Calculating Perplexities"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            # Padding must not count towards the loss; -100 is the value
            # cross_entropy is told to ignore below.
            labels = input_ids.masked_fill(attention_mask == 0, -100)

            logits = model(input_ids, attention_mask=attention_mask).logits

            # Next-token prediction: logits at t are compared with the token at t + 1.
            # Cast to float32 so half-precision models do not lose accuracy here.
            shift_logits = logits[:, :-1, :].float()
            shift_labels = labels[:, 1:]

            # cross_entropy expects the class axis second: (batch, vocab, seq).
            token_nll = torch.nn.functional.cross_entropy(
                shift_logits.transpose(1, 2),
                shift_labels,
                ignore_index=-100,
                reduction="none",
            )
            token_counts = (shift_labels != -100).sum(dim=1)
            mean_nll = token_nll.sum(dim=1) / token_counts

            # The loss is in nats (natural log), so the matching base is e, not 2.
            perplexities.extend(torch.exp(mean_nll).tolist())

    return perplexities

def perplexity_pruning(
    raw_texts,
    selection_rate=0.5,
    ref_split_size=0.2,
    model_name="gpt2",
    batch_size=8,
    max_length=512,
    device="cuda"
):
    """Score texts with a causal LM and write low/medium/high perplexity buckets.

    The texts are randomly split into a reference part and a train part; only
    the train part is scored. Three quantile windows of width
    ``selection_rate`` are then taken over the train perplexities:

    * low:    quantiles ``[0, selection_rate]``
    * medium: quantiles ``[0.5 - selection_rate / 2, 0.5 + selection_rate / 2]``
    * high:   quantiles ``[1 - selection_rate, 1]``

    Every train text whose perplexity lies inside a window (bounds inclusive)
    is written as a ``{"text": ..., "ppl": ...}`` JSON line to
    ``./data/flowpppldata.jsonl``, ``./data/mediappldata.jsonl`` and
    ``./data/highppldata.jsonl`` respectively. Windows may overlap, so a text
    can be written to more than one file.

    Args:
        raw_texts: Sized, indexable collection of strings.
        selection_rate: Width of each quantile window, within ``[0, 1]``.
        ref_split_size: Fraction of ``raw_texts`` put into the reference split.
        model_name: Name or path passed to ``AutoTokenizer.from_pretrained``
            and ``AutoModelForCausalLM.from_pretrained``.
        batch_size: Batch size of the scoring ``DataLoader``.
        max_length: Length each text is truncated/padded to.
        device: Device the model and batches are placed on.

    Returns:
        list: ``pruned_texts``, which is currently always empty (see the
        review note near the end of the function).

    Raises:
        ValueError: If ``selection_rate`` is outside ``[0, 1]``, if the split
            leaves no train texts, if the number of perplexities differs from
            the number of train texts, or if no perplexity is finite.
    """
    import os  # local import: ``os`` is not imported at module level in this file

    # Validate cheap preconditions before loading the model.
    if not 0.0 <= selection_rate <= 1.0:
        raise ValueError(f"selection_rate must be within [0, 1], got {selection_rate}")

    ref_size = int(len(raw_texts) * ref_split_size)
    train_size = len(raw_texts) - ref_size
    if ref_size < 0 or train_size <= 0:
        raise ValueError(
            f"split leaves no texts to score (total={len(raw_texts)}, "
            f"ref_split_size={ref_split_size})"
        )

    print("model_name:", model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    # The reference split is carved out but not used afterwards; the split is
    # kept because it determines which texts end up in the train part.
    _ref_texts, train_texts = random_split(raw_texts, [ref_size, train_size])

    train_dataset = TextDataset(train_texts, tokenizer, max_length)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)

    print("Calculating perplexities...")
    perplexities = calculate_perplexity(model, train_loader, device)

    # zip() below would silently truncate and mis-pair texts with scores if
    # the scorer returned anything other than one value per text.
    if len(perplexities) != len(train_texts):
        raise ValueError(
            f"expected one perplexity per text ({len(train_texts)}), "
            f"got {len(perplexities)}; use batch_size=1 if the scorer "
            "returns one value per batch"
        )

    # Compute the quantile bounds of each bucket. Non-finite scores are left
    # out so a single nan/inf does not turn every bound into nan.
    scores = np.asarray(perplexities, dtype=float)
    finite_scores = scores[np.isfinite(scores)]
    if finite_scores.size == 0:
        raise ValueError("no finite perplexity values were produced")

    half_rate = selection_rate / 2

    min_pplx_low, max_pplx_low = np.quantile(finite_scores, [0.0, selection_rate])
    print(min_pplx_low, max_pplx_low)

    min_pplx_media, max_pplx_media = np.quantile(
        finite_scores, [0.5 - half_rate, 0.5 + half_rate]
    )
    print(min_pplx_media, max_pplx_media)

    min_pplx_high, max_pplx_high = np.quantile(finite_scores, [1 - selection_rate, 1.0])
    print(min_pplx_high, max_pplx_high)

    os.makedirs("./data", exist_ok=True)

    # Filter the data into the three bucket files; the context manager
    # guarantees the files are flushed and closed even if writing fails.
    bucket_counts = {"low": 0, "media": 0, "high": 0}
    with open("./data/flowpppldata.jsonl", "w", encoding="utf-8") as flow, \
            open("./data/mediappldata.jsonl", "w", encoding="utf-8") as fmedia, \
            open("./data/highppldata.jsonl", "w", encoding="utf-8") as fhigh:
        for text, pplx in zip(train_texts, perplexities):
            record = json.dumps({"text": text, "ppl": pplx}, ensure_ascii=False) + '\n'

            if min_pplx_low <= pplx <= max_pplx_low:
                flow.write(record)
                bucket_counts["low"] += 1

            if min_pplx_media <= pplx <= max_pplx_media:
                fmedia.write(record)
                bucket_counts["media"] += 1

            if min_pplx_high <= pplx <= max_pplx_high:
                fhigh.write(record)
                bucket_counts["high"] += 1

    print(
        "Bucket sizes (low/media/high): "
        f"{bucket_counts['low']}/{bucket_counts['media']}/{bucket_counts['high']}"
    )

    # NOTE(review): nothing is ever appended to pruned_texts, so the two
    # "pruned" figures below always report 0 and 1.00 and the function returns
    # an empty list. Which bucket (if any) should be returned is not evident
    # from this file, so the behaviour is left unchanged - confirm intent.
    pruned_texts = []

    print(f"Original dataset size: {len(train_texts)}")
    print(f"Pruned dataset size: {len(pruned_texts)}")
    print(f"Pruning rate: {1 - len(pruned_texts)/len(train_texts):.2f}")

    return pruned_texts


if __name__ == "__main__":
    # Load the "text" field of every record in the JSON Lines input file.
    raw_texts = []
    # Read as UTF-8 explicitly: the bucket files are written as UTF-8, and the
    # platform default encoding is not guaranteed to match.
    with open('language.jsonl', 'r', encoding='utf-8') as f1:
        for line in f1:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file),
                # which json.loads would otherwise reject.
                continue
            obj = json.loads(line)
            raw_texts.append(obj["text"])

    print("raw text length: ", len(raw_texts))

    pruned_data = perplexity_pruning(
        raw_texts=raw_texts,
        selection_rate=0.3,
        ref_split_size=0.2,
        model_name="global_step9147_lion",
        # One text per batch keeps a one-to-one pairing between texts and scores.
        batch_size=1,
        max_length=1024,
        device="cuda" if torch.cuda.is_available() else "cpu"
    )

    print("Pruning completed!")
