import json
import math
import torch
from tqdm import tqdm

def get_perplexity(model, tokenizer, prompt, output, top_k=20):
    """Return the perplexity of ``output`` under ``model`` when conditioned on ``prompt``.

    Each output token's log-probability is floored at the log-probability of the
    ``top_k``-th most likely token at that position: a token inside the top-k is
    scored with its own log-probability, a token outside it is scored as if it
    were the k-th candidate. This caps the penalty a single unlikely token adds.

    All positions are scored with ONE teacher-forced forward pass over
    ``prompt + output``. This assumes ``model`` is a causal (autoregressive) LM,
    whose logits at position i depend only on tokens 0..i, so the result matches
    feeding the sequence token by token. NOTE(review): confirm for any
    non-causal model passed here.

    Args:
        model: Callable returning an object with a ``logits`` attribute indexed
            as ``[batch, position, vocab]``; must expose ``device``.
        tokenizer: Callable returning a mapping with an ``"input_ids"`` tensor of
            shape ``(1, seq_len)``.
        prompt: Conditioning text, tokenized WITH special tokens.
        output: Continuation to score, tokenized WITHOUT special tokens.
        top_k: Rank of the log-probability floor. Values larger than the
            vocabulary size are clamped to it.

    Returns:
        ``exp(-mean(clipped_log_prob))`` over the output tokens as a Python
        float, or ``inf`` when ``output`` tokenizes to zero tokens.

    Raises:
        ValueError: if ``prompt`` tokenizes to zero tokens. There is then no
            position from which to predict the first output token.
    """
    device = model.device

    prompt_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)["input_ids"].to(device)
    output_ids = tokenizer(output, return_tensors="pt", add_special_tokens=False)["input_ids"].to(device)

    num_output = output_ids.size(-1)
    if num_output == 0:
        # Nothing to score; skip the forward pass entirely.
        return float("inf")
    prompt_len = prompt_ids.size(-1)
    if prompt_len == 0:
        raise ValueError("prompt must tokenize to at least one token")

    full_ids = torch.cat([prompt_ids, output_ids], dim=-1)
    with torch.no_grad():
        logits = model(full_ids).logits

    # Position i predicts token i + 1, so the predictions for the output tokens
    # are taken from the last prompt position onward (the final position's
    # logits predict beyond the output and are unused).
    step_logits = logits[0, prompt_len - 1 : prompt_len - 1 + num_output, :]
    # Upcast before log_softmax so half-precision logits do not lose accuracy.
    log_probs = torch.log_softmax(step_logits.float(), dim=-1)

    token_log_probs = log_probs.gather(-1, output_ids[0].unsqueeze(-1)).squeeze(-1)

    k = min(top_k, log_probs.size(-1))
    kth_log_probs = torch.topk(log_probs, k, dim=-1).values[:, -1]
    # A token inside the top-k has log-prob >= the k-th value; one outside has
    # log-prob <= it. The element-wise max therefore equals "own log-prob if in
    # top-k, else the k-th value".
    clipped = torch.maximum(token_log_probs, kth_log_probs)

    return math.exp(-clipped.double().mean().item())

def filter_by_perplexity(input_file, output_file, lora_model, lora_tokenizer, base_output_file, lora_output_file, threshold):
    """Keep samples whose LoRA output scores below ``threshold`` perplexity while the base output does not.

    Both outputs are scored with ``get_perplexity`` using the LoRA model and
    tokenizer, conditioned on the sample's ``"generated_instruction"``. A sample
    is kept when the base output's perplexity is ``>= threshold`` AND the LoRA
    output's perplexity is ``< threshold``. Samples whose scoring raises are
    reported and skipped (best effort).

    Args:
        input_file: Path to a JSON array of sample objects, each with a
            ``"generated_instruction"`` key.
        output_file: Path the kept samples are written to, as a JSON array.
        lora_model: Model used to score both outputs.
        lora_tokenizer: Tokenizer matching ``lora_model``.
        base_output_file: Path to a JSON array aligned index-by-index with
            ``input_file``; each item needs an ``"output"`` key.
        lora_output_file: Same as ``base_output_file``, for the LoRA outputs.
        threshold: Perplexity cutoff.

    Returns:
        The list of kept sample objects, as loaded from ``input_file``.

    Raises:
        ValueError: if the three files contain different numbers of items.
        KeyError: if a sample or output item lacks a required key.
    """
    # Load data
    with open(input_file, "r", encoding="utf-8") as f:
        samples = json.load(f)
    with open(base_output_file, "r", encoding="utf-8") as f:
        base_outputs = json.load(f)
    with open(lora_output_file, "r", encoding="utf-8") as f:
        lora_outputs = json.load(f)

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not len(samples) == len(base_outputs) == len(lora_outputs):
        raise ValueError(
            "Mismatch in dataset lengths: "
            f"{len(samples)} samples, {len(base_outputs)} base outputs, "
            f"{len(lora_outputs)} lora outputs"
        )

    filtered = []
    print("Calculating perplexity...")
    rows = zip(samples, base_outputs, lora_outputs)
    for idx, (sample, base_out, lora_out) in enumerate(tqdm(rows, total=len(samples), desc="Filtering")):
        prompt = sample["generated_instruction"]
        base_output = base_out["output"]
        lora_output = lora_out["output"]

        try:
            base_ppl = get_perplexity(lora_model, lora_tokenizer, prompt, base_output)
            # Short-circuit: the sample is rejected regardless of the LoRA score.
            # Written as `not >=` rather than `<` so a NaN perplexity is still
            # rejected, exactly as the combined condition would reject it.
            if not base_ppl >= threshold:
                continue
            lora_ppl = get_perplexity(lora_model, lora_tokenizer, prompt, lora_output)
        except Exception as e:
            # Best effort: report which sample failed and move on. tqdm.write
            # keeps the progress bar intact.
            tqdm.write(f"✗ Error with sample {idx}: {e}")
            continue

        if lora_ppl < threshold:
            filtered.append(sample)

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(filtered, f, indent=2, ensure_ascii=False)

    print(f"✓ Filtered {len(filtered)} samples → saved to {output_file}")
    return filtered