import json
import math
import re
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Perplexity value at which the histogram in the plotting section starts its overflow bin.
bin_threshold = 6.0

# Load model and tokenizer.
# 4-bit bitsandbytes weights are placed on the device at load time via `device_map`.
# Calling `.to(device)` on such a model is rejected by some transformers/bitsandbytes
# versions, and the bare `load_in_4bit=` keyword is deprecated in favour of
# `quantization_config`.
# NOTE(review): bitsandbytes 4-bit loading generally requires CUDA, so the CPU
# fallback above is unlikely to work with this configuration — confirm.
lora_model_name = "TuneShift-KD/gemma-gsm8k-lora-source"
lora_tokenizer = AutoTokenizer.from_pretrained(lora_model_name)
lora_model = AutoModelForCausalLM.from_pretrained(
    lora_model_name,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map={"": device},
)

def get_lora_perplexity(model, tokenizer, prompt, output, top_k=20):
    """Return the perplexity of ``output`` as a continuation of ``prompt`` under ``model``.

    The prompt is tokenized with the tokenizer's special tokens and the output without
    them. The two id sequences are concatenated and scored in a single forward pass,
    instead of one full pass per output token.

    Each output token's log-probability is floored at the k-th largest log-probability
    at its position. A token outside the model's top-k is therefore scored as if it
    were the k-th most likely token, not at its (possibly much lower) true value.

    Args:
        model: causal LM exposing ``.device``; calling it on an id tensor returns an
            object whose ``logits`` tensor is indexed as (batch, position, vocab).
        tokenizer: callable that, with ``return_tensors="pt"``, returns a mapping whose
            ``"input_ids"`` entry is a (1, seq_len) tensor.
        prompt: conditioning text; its tokens are not scored.
        output: text whose tokens are scored.
        top_k: rank used for the log-probability floor. Values at or above the
            vocabulary size effectively disable the floor.

    Returns:
        ``exp(-mean(floored log-probs))`` as a float, or ``inf`` when ``output``
        tokenizes to zero tokens.

    Raises:
        ValueError: if ``prompt`` tokenizes to zero tokens, leaving no position from
            which to predict the first output token.
    """
    device = model.device

    prompt_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)["input_ids"].to(device)
    output_ids = tokenizer(output, return_tensors="pt", add_special_tokens=False)["input_ids"].to(device)

    num_output = output_ids.shape[-1]
    if num_output == 0:
        return float("inf")
    prompt_len = prompt_ids.shape[-1]
    if prompt_len == 0:
        raise ValueError("prompt must tokenize to at least one token")

    full_ids = torch.cat([prompt_ids, output_ids], dim=-1)
    with torch.no_grad():
        logits = model(full_ids).logits

    # The logits at position t predict token t + 1. The predictions for the output
    # tokens therefore sit at positions prompt_len - 1 ... prompt_len + num_output - 2.
    step_logits = logits[0, prompt_len - 1 : prompt_len - 1 + num_output, :].float()
    log_probs = torch.log_softmax(step_logits, dim=-1)

    targets = output_ids[0]
    token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # max(token, k-th best) equals the token's own value when it is in the top-k,
    # and the k-th best value otherwise.
    k = min(top_k, log_probs.shape[-1])
    kth_best = torch.topk(log_probs, k, dim=-1).values[:, -1]
    floored = torch.maximum(token_log_probs, kth_best)

    # Average in double precision, as the original Python-float accumulation did.
    avg_log_prob = floored.double().mean().item()
    return math.exp(-avg_log_prob)

# Load the teacher and base generations. The main loop pairs them by position
# (sample i of one file with sample i of the other).
print("Loading datasets...")

teacher_data_path = "gsm8k_work/gemma_gsm8k_no_filter.json"
base_data_path = "gsm8k_work/gemma_base_250.json"

with open(teacher_data_path, "r", encoding="utf-8") as f:
    teacher_data = json.load(f)

with open(base_data_path, "r", encoding="utf-8") as f:
    base_data = json.load(f)

# Optional cap for quick debugging runs; None processes every sample.
max_samples = None
teacher_data = teacher_data[:max_samples]
base_data = base_data[:max_samples]

# A length mismatch would silently misalign the paired samples. Raise explicitly
# rather than use `assert`, which `python -O` strips.
if len(teacher_data) != len(base_data):
    raise ValueError(
        "Teacher and base datasets must have the same length "
        f"(got {len(teacher_data)} and {len(base_data)})"
    )
print(f"Loaded {len(teacher_data)} samples from each dataset")

# Perplexity cut-off used by the per-sample filtering rule.
threshold = 1.5

# Accumulators filled by the per-sample loop: every result, then the two partitions
# produced by the filtering rule.
perplexity_results, filtered_samples, accepted_samples = [], [], []

# Buckets for the signed difference teacher_ppl - base_ppl, written as
# (inclusive lower bound, exclusive upper bound, label).
diff_ranges = [
    (float("-inf"), -1.0, "Much Better (Teacher << Base)"),
    (-1.0, -0.3, "Better (Teacher < Base)"),
    (-0.3, 0.3, "Similar (Teacher ≈ Base)"),
    (0.3, 1.0, "Worse (Teacher > Base)"),
    (1.0, float("inf"), "Much Worse (Teacher >> Base)"),
]

# Maps a bucket label to the result records that fell into it.
samples_by_category = defaultdict(list)

print(f"Analyzing perplexity differences across {len(teacher_data)} samples...")

# Shared instruction prefix; the question and an "Answer:" cue are appended per sample.
# NOTE(review): the wording (including "Answer the question the following ...") is kept
# verbatim because it presumably mirrors the prompt the LoRA model was tuned on.
# Confirm before editing, since any change shifts every perplexity below.
prompt_prefix = (
    "You are a math tutor helping a student. Answer the question the following grade school math question step-by-step in full English sentences.\n\n"
)

# Create a progress bar
progress_bar = tqdm(enumerate(zip(teacher_data, base_data)), total=len(teacher_data))

for i, (teacher_sample, base_sample) in progress_bar:
    # Extract data and build the formatted prompt.
    prompt = teacher_sample["instruction"]
    input_text = f"{prompt_prefix}Question: {prompt}\nAnswer:"

    # Keep only the text before the first "####" marker (presumably the final-answer
    # delimiter). Each output must be cut from its own sample; cutting the base output
    # from the teacher's text made both outputs identical.
    teacher_output = teacher_sample["output"].split("####", 1)[0].strip()
    base_output = base_sample["output"].split("####", 1)[0].strip()

    # Score both outputs against the same formatted prompt so their perplexities are
    # comparable. Scoring the base output against the raw question skewed the difference.
    teacher_ppl = get_lora_perplexity(lora_model, lora_tokenizer, input_text, teacher_output)
    base_ppl = get_lora_perplexity(lora_model, lora_tokenizer, input_text, base_output)

    # Filtering rule: drop a sample when the base model already fits it well
    # (base_ppl < threshold) or the teacher fits it poorly (teacher_ppl >= threshold).
    # If both perplexities are inf, ppl_diff is NaN and matches no category below.
    ppl_diff = teacher_ppl - base_ppl
    would_be_filtered = (base_ppl < threshold) or (teacher_ppl >= threshold)

    # Store results
    result = {
        "index": i,
        "prompt": prompt,
        "teacher_output": teacher_output,
        "base_output": base_output,
        "teacher_ppl": teacher_ppl,
        "base_ppl": base_ppl,
        "ppl_diff": ppl_diff,
        "would_be_filtered": would_be_filtered,
    }
    perplexity_results.append(result)

    if would_be_filtered:
        filtered_samples.append(result)
    else:
        accepted_samples.append(result)

    # Categorize by PPL difference (first matching half-open range wins).
    for low, high, category in diff_ranges:
        if low <= ppl_diff < high:
            samples_by_category[category].append(result)
            break

    progress_bar.set_description(f"T-PPL: {teacher_ppl:.2f}, B-PPL: {base_ppl:.2f}, Diff: {ppl_diff:.2f}")

# Aggregate statistics over the per-sample records collected by the scoring loop.
filtered_count = len(filtered_samples)
if teacher_data:
    filtered_percentage = (filtered_count / len(teacher_data)) * 100
else:
    filtered_percentage = 0


def _finite_values_and_nonfinite_count(records, key):
    """Return ``(finite values of record[key] in order, number of inf/NaN values)``."""
    finite_values = [record[key] for record in records if math.isfinite(record[key])]
    return finite_values, len(records) - len(finite_values)


# Non-finite values (inf perplexities, NaN differences) are only counted; they are
# excluded from the statistics and plots.
valid_teacher_ppls, inf_teacher_ppls = _finite_values_and_nonfinite_count(perplexity_results, "teacher_ppl")
valid_base_ppls, inf_base_ppls = _finite_values_and_nonfinite_count(perplexity_results, "base_ppl")
valid_ppl_diffs, inf_ppl_diffs = _finite_values_and_nonfinite_count(perplexity_results, "ppl_diff")


def _print_ppl_summary(header, values, nonfinite_count):
    """Print ``header``, the non-finite count, and min/max/mean/median of ``values``.

    Prints nothing at all when ``values`` is empty.
    """
    if not values:
        return
    print(header)
    print(f"Infinite values: {nonfinite_count}")
    print(f"Min: {min(values):.4f}")
    print(f"Max: {max(values):.4f}")
    print(f"Mean: {np.mean(values):.4f}")
    print(f"Median: {np.median(values):.4f}")


# Overall filtering outcome.
print("\n===== PERPLEXITY FILTERING ANALYSIS =====")
print(f"Total samples: {len(perplexity_results)}")
print(f"Samples that would be filtered out: {filtered_count} ({filtered_percentage:.2f}%)")
print(f"Samples that would be kept: {len(accepted_samples)} ({100-filtered_percentage:.2f}%)")

# Per-quantity summaries, each skipped when it has no finite values.
for summary_header, summary_values, summary_nonfinite in (
    ("\n===== TEACHER MODEL PERPLEXITY =====", valid_teacher_ppls, inf_teacher_ppls),
    ("\n===== BASE MODEL PERPLEXITY =====", valid_base_ppls, inf_base_ppls),
    ("\n===== PERPLEXITY DIFFERENCE (TEACHER - BASE) =====", valid_ppl_diffs, inf_ppl_diffs),
):
    _print_ppl_summary(summary_header, summary_values, summary_nonfinite)

# Report, for each teacher-minus-base bucket: its share of all samples, how many of
# its samples would be filtered, and the sample at the middle of the bucket.
print("\n===== SAMPLES BY PERPLEXITY DIFFERENCE CATEGORY =====")
total_count = len(perplexity_results)
for _, _, category in diff_ranges:
    samples = samples_by_category[category]
    sample_count = len(samples)

    if total_count > 0:
        print(f"\n{category}: {sample_count} samples ({sample_count/total_count*100:.1f}% of total)")
    else:
        print(f"\n{category}: {sample_count} samples (0.0% of total)")

    filtered_in_category = sum(1 for s in samples if s["would_be_filtered"])
    if sample_count == 0:
        print(f"  Would be filtered: {filtered_in_category} (0.0% of category)")
        continue
    print(f"  Would be filtered: {filtered_in_category} ({filtered_in_category/sample_count*100:.1f}% of category)")

    # Middle element of the bucket, in insertion (dataset) order.
    representative = samples[sample_count // 2]
    print(f"\n  EXAMPLE (Sample #{representative['index']}):")
    print(f"  Prompt: {representative['prompt']}")
    print(f"  Teacher output: {representative['teacher_output']}")
    print(f"  Base output: {representative['base_output']}")
    print(f"  Teacher PPL: {representative['teacher_ppl']:.4f}")
    print(f"  Base PPL: {representative['base_ppl']:.4f}")
    print(f"  PPL Difference: {representative['ppl_diff']:.4f}")
    print("  " + "-" * 40)

# Explain why samples were dropped and show a few of them (only if any were dropped).
if filtered_samples:
    print("\n===== FILTERED SAMPLES BREAKDOWN =====")
    # A sample can trip both rules, so these two counts may sum to more than filtered_count.
    base_too_good = sum(s["base_ppl"] < threshold for s in filtered_samples)
    teacher_too_bad = sum(s["teacher_ppl"] >= threshold for s in filtered_samples)
    print(f"Base model PPL too low (<{threshold}): {base_too_good} samples")
    print(f"Teacher model PPL too high (≥{threshold}): {teacher_too_bad} samples")

    # Up to three samples (from all results) where either perplexity is non-finite.
    inf_samples = [
        s for s in perplexity_results
        if not (math.isfinite(s["teacher_ppl"]) and math.isfinite(s["base_ppl"]))
    ]
    if inf_samples:
        print("\nEXAMPLES OF SAMPLES WITH INFINITE PERPLEXITY:")
        for i, sample in enumerate(inf_samples[:3]):
            print(f"\nInfinite Sample #{i+1} (Index {sample['index']}):")
            print(f"Teacher PPL: {sample['teacher_ppl']}")
            print(f"Base PPL: {sample['base_ppl']}")
            print(f"Prompt: {sample['prompt']}")
            print(f"Teacher output: {sample['teacher_output']}")
            print(f"Base output: {sample['base_output']}")
            print("---")

    # Up to three filtered samples whose perplexities are both finite. When both rules
    # fire, the base-model reason is the one reported.
    print("\nEXAMPLES OF FILTERED SAMPLES:")
    finite_filtered = [
        s for s in filtered_samples
        if math.isfinite(s["teacher_ppl"]) and math.isfinite(s["base_ppl"])
    ]
    for i, sample in enumerate(finite_filtered[:3]):
        reason = 'Base PPL too low' if sample['base_ppl'] < threshold else 'Teacher PPL too high'
        print(f"\nFiltered Sample #{i+1} (Index {sample['index']}):")
        print(f"Reason: {reason}")
        print(f"Prompt: {sample['prompt']}")
        print(f"Teacher output: {sample['teacher_output']}")
        print(f"Base output: {sample['base_output']}")
        print(f"Teacher PPL: {sample['teacher_ppl']:.4f}")
        print(f"Base PPL: {sample['base_ppl']:.4f}")
        print("---")

def make_json_serializable(value):
    """Return ``value`` as a plain float, or as its string form ("inf", "nan") when non-finite.

    ``json.dump`` would otherwise write the non-standard ``Infinity``/``NaN`` tokens.
    """
    if not math.isfinite(value):
        return str(value)
    return float(value)


def _ppl_stats_json(values, infinite_count):
    """Summarise finite ``values`` for the JSON report; each statistic is "N/A" when empty."""
    if not values:
        return {
            "infinite_count": infinite_count,
            "min": "N/A",
            "max": "N/A",
            "mean": "N/A",
            "median": "N/A",
        }
    return {
        "infinite_count": infinite_count,
        "min": make_json_serializable(min(values)),
        "max": make_json_serializable(max(values)),
        "mean": make_json_serializable(np.mean(values)),
        "median": make_json_serializable(np.median(values)),
    }


def _sample_json(sample, **extra):
    """Return the JSON record shared by every per-sample section, followed by ``extra`` keys."""
    return {
        "index": sample["index"],
        "prompt": sample["prompt"],
        "teacher_output": sample["teacher_output"],
        "base_output": sample["base_output"],
        "teacher_ppl": make_json_serializable(sample["teacher_ppl"]),
        "base_ppl": make_json_serializable(sample["base_ppl"]),
        "ppl_diff": make_json_serializable(sample["ppl_diff"]),
        **extra,
    }


# Plot and save only when both models have at least one finite perplexity.
if valid_teacher_ppls and valid_base_ppls:
    fig = plt.figure(figsize=(12, 8))

    # Plot 1: teacher vs base perplexity histograms. Values above bin_threshold are
    # capped just past it so they all land in one overflow bin
    # [bin_threshold, bin_threshold + 0.5].
    plt.subplot(2, 1, 1)
    bin_edges = np.append(np.linspace(1.0, bin_threshold, 29), [bin_threshold + 0.5])
    overflow_center = bin_threshold + 0.25

    binned_teacher = [min(p, bin_threshold + 0.01) for p in valid_teacher_ppls]
    binned_base = [min(p, bin_threshold + 0.01) for p in valid_base_ppls]

    plt.hist(binned_teacher, bins=bin_edges, alpha=0.7, label="Teacher Model", color="steelblue")
    plt.hist(binned_base, bins=bin_edges, alpha=0.5, label="Base Model", color="orange")

    plt.axvline(x=threshold, color="r", linestyle="--", linewidth=2, label=f"Threshold = {threshold}")
    plt.axvline(x=bin_threshold, color="gray", linestyle="--", linewidth=2, label=f"Binning Threshold = {bin_threshold}")

    # Annotate the overflow bin with how many real values exceed bin_threshold.
    overflow_teacher = sum(p > bin_threshold for p in valid_teacher_ppls)
    overflow_base = sum(p > bin_threshold for p in valid_base_ppls)
    bar_height = max(overflow_teacher, overflow_base)
    plt.text(overflow_center, bar_height + 0.5, f"{overflow_teacher} teacher / {overflow_base} base",
             ha="center", fontsize=10, color="black")

    # Numeric ticks every 0.5 up to and including bin_threshold, plus one "≥" label
    # centred on the overflow bin. Labelling the bin_threshold tick "≥ ..." as well
    # gave two identical overflow labels.
    xticks = [tick for tick in np.arange(1.0, bin_threshold + 0.5, 0.5) if tick <= bin_threshold]
    xticklabels = [f"{tick:.1f}" for tick in xticks]
    xticks.append(overflow_center)
    xticklabels.append(f"≥ {bin_threshold}")
    plt.xticks(xticks, xticklabels)

    plt.xlabel("Perplexity")
    plt.ylabel("Count")
    plt.title(f"Distribution of Perplexity Scores (Overflow bin: ≥ {bin_threshold})")
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Plot 2: histogram of perplexity differences, trimmed to the 5th-95th percentile
    # range widened by 0.5 on each side. It is skipped when no sample has a finite
    # difference, because np.percentile raises on an empty input.
    plt.subplot(2, 1, 2)
    if valid_ppl_diffs:
        low_diff, high_diff = np.percentile(valid_ppl_diffs, [5, 95])
        plot_range = (low_diff - 0.5, high_diff + 0.5)
        plot_ppl_diffs = [d for d in valid_ppl_diffs if plot_range[0] <= d <= plot_range[1]]
        plt.hist(plot_ppl_diffs, bins=30, alpha=0.7,
                 label=f"Showing {len(plot_ppl_diffs)}/{len(valid_ppl_diffs)} differences")
    plt.axvline(x=0, color="r", linestyle="--", label="No Difference")
    plt.xlabel("Perplexity Difference (Teacher - Base)")
    plt.ylabel("Count")
    plt.title("Distribution of Perplexity Differences (excluding outliers)")
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig("perplexity_analysis.png")
    plt.close(fig)  # release the figure; pyplot keeps it alive otherwise
    print("\nPerplexity analysis plots saved as 'perplexity_analysis.png'")

    # Detailed analysis results; non-finite values are stored as strings.
    output_analysis = {
        "statistics": {
            "total_samples": len(perplexity_results),
            "filtered_count": filtered_count,
            "filtered_percentage": filtered_percentage,
            "teacher_ppl_stats": _ppl_stats_json(valid_teacher_ppls, inf_teacher_ppls),
            "base_ppl_stats": _ppl_stats_json(valid_base_ppls, inf_base_ppls),
            "ppl_diff_stats": _ppl_stats_json(valid_ppl_diffs, inf_ppl_diffs),
        },
        "category_breakdown": {
            category: [
                _sample_json(s, would_be_filtered=s["would_be_filtered"])
                for s in samples_by_category[category]
            ]
            for _, _, category in diff_ranges
        },
        # When both filter rules fire, the base-model reason is the one recorded.
        "filtered_samples": [
            _sample_json(
                s,
                filter_reason="Base PPL too low" if s["base_ppl"] < threshold else "Teacher PPL too high",
            )
            for s in filtered_samples
        ],
        "infinite_ppl_samples": [
            _sample_json(s)
            for s in perplexity_results
            if not math.isfinite(s["teacher_ppl"]) or not math.isfinite(s["base_ppl"])
        ],
    }

    with open("perplexity_comparison_analysis.json", "w", encoding="utf-8") as f:
        json.dump(output_analysis, f, indent=2)

    print("Detailed analysis saved to 'perplexity_comparison_analysis.json'")
else:
    print("\nNo valid data to plot or save. Please check your input files.")