#!/bin/bash

# Toy test for WRRF (Weighted Reciprocal Rank Fusion) learning determinism; no BEIR data required.
# Validates that WRRF produces consistent results given the same inputs. Requires python3 on PATH.

set -euo pipefail

# Absolute path of the directory that contains this script.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

# Private scratch directory for the fixture files. mktemp gives every run a
# unique directory, so concurrent runs cannot collide and a pre-existing
# directory is never reused or deleted by accident.
TEST_DIR="$(mktemp -d "${TMPDIR:-/tmp}/wrrf_test.XXXXXX")"

# Remove the scratch directory on every exit path, including the early exits
# that `set -e` triggers when a test step fails.
cleanup() {
  # Step out of the scratch directory before deleting it.
  cd "$SCRIPT_DIR" || cd /
  rm -rf -- "${TEST_DIR:?}"
}
trap cleanup EXIT

echo "=== Toy WRRF Determinism Test ==="
echo "Testing WRRF fusion consistency with synthetic data"

# Setup: all fixture files below are written relative to the scratch directory.
cd "$TEST_DIR"

# Create synthetic run files in TREC run format: qid Q0 docid rank score tag.
# Within each query the lines are listed best-first, so the line order, the
# rank column and the score column all describe the same ranking.
cat > run1.trec << 'EOF'
Q001 Q0 D001 1 0.9 run1
Q001 Q0 D002 2 0.8 run1
Q001 Q0 D003 3 0.7 run1
Q002 Q0 D001 1 0.6 run1
Q002 Q0 D002 2 0.5 run1
Q002 Q0 D003 3 0.4 run1
EOF

cat > run2.trec << 'EOF'
Q001 Q0 D001 1 0.85 run2
Q001 Q0 D002 2 0.75 run2
Q001 Q0 D003 3 0.65 run2
Q002 Q0 D001 1 0.55 run2
Q002 Q0 D002 2 0.45 run2
Q002 Q0 D003 3 0.35 run2
EOF

# run3 ranks the documents in the opposite order from run1/run2 (D003 scores
# highest). At least one run has to disagree: if every run shared one ranking,
# weights that sum to 1 would give the same fused score however they are
# split, and the weighted tests could not observe any effect of the weights.
cat > run3.trec << 'EOF'
Q001 Q0 D003 1 0.5 run3
Q001 Q0 D002 2 0.4 run3
Q001 Q0 D001 3 0.3 run3
Q002 Q0 D003 1 0.9 run3
Q002 Q0 D002 2 0.8 run3
Q002 Q0 D001 3 0.7 run3
EOF

# Create synthetic gate features (query complexity indicators), one row per query.
cat > gate_features.csv << 'EOF'
qid,query_len,dense_std,sparse_max,ges_coverage,domain_signal
Q001,8,0.12,0.9,0.75,0.3
Q002,6,0.08,0.6,0.85,0.7
EOF

echo "Created synthetic WRRF test data"

# Test 1: plain Reciprocal Rank Fusion (RRF) over the three run files.
echo
echo "Test 1: Basic Reciprocal Rank Fusion"
python3 << 'PYTHON_EOF'
from collections import defaultdict

RUN_FILES = ('run1.trec', 'run2.trec', 'run3.trec')
FUSION_K = 60


def read_run(path):
    """Parse one run file into {qid: [(docid, score), ...]}, keeping file order."""
    per_query = defaultdict(list)
    with open(path, 'r') as handle:
        for raw_line in handle:
            fields = raw_line.strip().split()
            # The rank column is converted (and thereby validated) but not
            # used afterwards; ranks come from the position in the list.
            query_id, doc_id, _listed_rank, doc_score = (
                fields[0], fields[2], int(fields[3]), float(fields[4]))
            per_query[query_id].append((doc_id, doc_score))
    return per_query


def reciprocal_rank_fusion(runs, k=60):
    """Sum 1 / (k + rank) over all runs for every (query, document) pair.

    rank is the 1-based position of the document in its query's list.
    """
    totals = defaultdict(lambda: defaultdict(float))
    for per_query in runs.values():
        for query_id, ranking in per_query.items():
            for position, (doc_id, _score) in enumerate(ranking, start=1):
                totals[query_id][doc_id] += 1.0 / (k + position)
    return totals


# Run name is the file name without its .trec suffix.
all_runs = {path.replace('.trec', ''): read_run(path) for path in RUN_FILES}

print("Basic RRF (k=60):")
fused = reciprocal_rank_fusion(all_runs, k=FUSION_K)

for query_id in ('Q001', 'Q002'):
    ordered = sorted(fused[query_id].items(), key=lambda pair: pair[1], reverse=True)
    rendered = ' > '.join('{}({:.4f})'.format(doc_id, total) for doc_id, total in ordered)
    print("  {}: {}".format(query_id, rendered))
PYTHON_EOF

# Shared helpers for Tests 2-4.
# Each `python3 << 'PYTHON_EOF'` below starts a fresh interpreter, so names
# defined inside one heredoc (runs, weighted_rrf, defaultdict, ...) do not
# exist in the next one. The helpers therefore live in a small module written
# to the current working directory; Python searches that directory for imports
# when the program is read from stdin.
cat > wrrf_common.py << 'PYTHON_EOF'
"""Helpers shared by the weighted-RRF toy tests."""
from collections import defaultdict

RUN_FILES = ('run1.trec', 'run2.trec', 'run3.trec')


def load_runs(run_files=RUN_FILES):
    """Load run files into {run_name: {qid: [(docid, score), ...]}}.

    Each non-blank line is expected to look like "qid Q0 docid rank score tag".
    Documents keep their file order, which weighted_rrf() uses as the rank.
    The run name is the file name without its ".trec" suffix.
    """
    runs = {}
    for run_file in run_files:
        per_query = defaultdict(list)
        with open(run_file, 'r') as f:
            for line in f:
                parts = line.split()
                if not parts:
                    continue  # tolerate blank lines
                qid, docid, score = parts[0], parts[2], float(parts[4])
                per_query[qid].append((docid, score))
        runs[run_file.replace('.trec', '')] = per_query
    return runs


def weighted_rrf(runs, weights, k=60):
    """Weighted RRF: score(q, d) = sum over runs of weights[run] / (k + rank).

    rank is the 1-based position of d in the run's list for q. A run that is
    missing from `weights` contributes with weight 1.0.
    Returns {qid: {docid: fused_score}}.
    """
    fused_scores = defaultdict(lambda: defaultdict(float))
    for run_name, run_data in runs.items():
        weight = weights.get(run_name, 1.0)
        for qid, ranking in run_data.items():
            for rank, (docid, _) in enumerate(ranking, 1):
                fused_scores[qid][docid] += weight / (k + rank)
    return fused_scores


def rank_docs(doc_scores):
    """Return (docid, score) pairs, best score first; ties are broken by docid
    so the order never depends on insertion order."""
    return sorted(doc_scores.items(), key=lambda item: (-item[1], item[0]))


def format_ranking(doc_scores):
    """Render a ranking as 'D001(0.0164) > D002(0.0161) > ...'."""
    return ' > '.join(f'{doc}({score:.4f})' for doc, score in rank_docs(doc_scores))
PYTHON_EOF

# Test 2: Weighted RRF with learned weights
echo
echo "Test 2: Weighted RRF with learned weights"
python3 << 'PYTHON_EOF'
from wrrf_common import format_ranking, load_runs, weighted_rrf

runs = load_runs()

# Test with different weight configurations
weight_configs = [
    {"run1": 0.5, "run2": 0.3, "run3": 0.2},  # Dense-heavy
    {"run1": 0.2, "run2": 0.3, "run3": 0.5},  # GES-heavy
    {"run1": 0.33, "run2": 0.33, "run3": 0.34},  # Balanced
]

print("WRRF with different weight configurations:")
for i, weights in enumerate(weight_configs, 1):
    print(f"\nConfig {i}: {weights}")
    wrrf_results = weighted_rrf(runs, weights, k=60)

    for qid in ['Q001', 'Q002']:
        print(f"  {qid}: {format_ranking(wrrf_results[qid])}")
PYTHON_EOF

# Test 3: Query-adaptive weighting with gate features
echo
echo "Test 3: Query-adaptive WRRF with gate features"
python3 << 'PYTHON_EOF'
import csv

from wrrf_common import format_ranking, load_runs, weighted_rrf

runs = load_runs()

# Load gate features: {qid: {feature_name: float}}
gate_features = {}
with open('gate_features.csv', 'r') as f:
    reader = csv.DictReader(f)
    for row in reader:
        qid = row['qid']
        gate_features[qid] = {k: float(v) for k, v in row.items() if k != 'qid'}


def compute_adaptive_weights(features):
    """Compute query-specific run weights (summing to 1) from gate features."""
    # Simple heuristic: use domain signal to balance runs
    domain_sig = features.get('domain_signal', 0.5)
    query_len = features.get('query_len', 5)
    ges_coverage = features.get('ges_coverage', 0.5)

    # Longer queries favor dense retrieval (boost capped at 0.7)
    dense_boost = min(query_len / 10.0, 0.7)

    # High GES coverage favors generative signals
    ges_boost = ges_coverage

    # Raw, not yet normalised, weights
    w1 = 0.3 + dense_boost * 0.3  # Dense run
    w2 = 0.3 + (1 - domain_sig) * 0.2  # Hybrid run
    w3 = 0.2 + ges_boost * 0.3  # GES run

    # Normalize to sum to 1
    total = w1 + w2 + w3
    return {"run1": w1 / total, "run2": w2 / total, "run3": w3 / total}


print("Query-adaptive WRRF:")
for qid in ['Q001', 'Q002']:
    features = gate_features[qid]
    adaptive_weights = compute_adaptive_weights(features)

    print(f"\n{qid} features: {features}")
    print(f"Adaptive weights: {adaptive_weights}")

    wrrf_results = weighted_rrf(runs, adaptive_weights, k=60)
    print(f"  Result: {format_ranking(wrrf_results[qid])}")
PYTHON_EOF

# Test 4: Determinism check (same inputs = same outputs)
echo
echo "Test 4: Determinism verification"
python3 << 'PYTHON_EOF'
import hashlib
import sys

from wrrf_common import load_runs, rank_docs, weighted_rrf

runs = load_runs()


def hash_results(results):
    """Digest of the fused rankings, with scores rounded to 8 decimals."""
    lines = []
    for qid in sorted(results):
        for doc, score in rank_docs(results[qid]):
            lines.append(f"{qid}:{doc}:{score:.8f}\n")
    # sha256 rather than md5: md5 is unavailable on FIPS-restricted Python builds.
    return hashlib.sha256("".join(lines).encode()).hexdigest()


# Run same fusion twice
fixed_weights = {"run1": 0.4, "run2": 0.3, "run3": 0.3}

result1 = weighted_rrf(runs, fixed_weights, k=60)
result2 = weighted_rrf(runs, fixed_weights, k=60)

hash1 = hash_results(result1)
hash2 = hash_results(result2)
deterministic = hash1 == hash2

print(f"Run 1 hash: {hash1}")
print(f"Run 2 hash: {hash2}")
print(f"Deterministic: {'✓ PASS' if deterministic else '✗ FAIL'}")

# Test with different k values
result_k60 = weighted_rrf(runs, fixed_weights, k=60)
result_k100 = weighted_rrf(runs, fixed_weights, k=100)
k_changes_result = hash_results(result_k60) != hash_results(result_k100)

print("\nDifferent k values produce different results:")
print(f"k=60 vs k=100: {'Different ✓' if k_changes_result else 'Same (unexpected)'}")

# A failed check must fail the process so the calling shell (set -e) stops
# instead of going on to print its success banner.
if not (deterministic and k_changes_result):
    sys.exit(1)
PYTHON_EOF

# Final report: a blank line followed by the summary of what was exercised.
cat << 'SUMMARY'

=== WRRF Determinism Tests Passed ✓ ===
WRRF components working correctly:
  - Basic RRF produces stable rankings
  - Weighted RRF responds to weight changes
  - Query-adaptive weights adjust per query features
  - Results are deterministic given same inputs
  - Parameter changes (k) affect results appropriately
SUMMARY

# Cleanup: step out of the scratch directory, then delete it.
cd ..
rm -rf "$TEST_DIR"

printf '\n%s\n' "WRRF test completed successfully!"