import subprocess
import json
import argparse
import csv
import os
import re

# grid_search.py
# Sweeps every field-weight vector against every embedding/BM25 blend ratio,
# runs eval_hybrid.py once per combination, saves each run's full output as a
# raw log, and writes the extracted Hit@K metrics to a CSV.

# Field-weight vectors to sweep, keyed by a short label. The label is used in
# raw-log filenames and in the 'weight_vector' CSV column; the inner mapping is
# JSON-serialized and passed to eval_hybrid.py via --weights_json.
WEIGHT_VECTORS = {
    'current': {
        'keywords': 0.134207,
        'questions': 0.226103,
        'thesis': 0.094972,
        'search_boost': 0.029563,
        'query_match_1': 0.217395,
        'query_match_2': 0.241111,
        'query_match_3': 0.056650,
    },
}

# Values passed as --embed_weight; --bm25_weight is set to 1 - each value.
BLEND_RATIOS = [0.5, 0.6, 0.7, 0.8]


def run_evaluation(output_dir, query_csv, weights, embed_weight, logs_dir, name=None):
    """Run eval_hybrid.py once, save its raw output, and parse the Hit@K results.

    Args:
        output_dir: forwarded to eval_hybrid.py as ``--output_dir``.
        query_csv: forwarded to eval_hybrid.py as ``--query_csv``.
        weights: field-weight mapping, JSON-serialized for ``--weights_json``.
            An optional ``'name'`` entry is treated as a log-file label only and
            is NOT forwarded to eval_hybrid.py as a weight.
        embed_weight: value for ``--embed_weight``; ``--bm25_weight`` is set to
            ``1 - embed_weight``.
        logs_dir: existing directory that receives the raw log file.
        name: optional log-file label; takes precedence over ``weights['name']``.

    Returns:
        ``(parsed, raw_path)``. ``parsed`` maps ``Hit@<k>_method`` to the best
        method's name and ``Hit@<k>_score`` to its score (float) for every
        ``Hit@<k>:`` section whose ``Best:`` line was found. ``raw_path`` is the
        path of the log file written.
    """
    import sys  # local import: needed only for the interpreter path and stderr

    label = name if name is not None else weights.get('name')
    # 'name' is a label for the log file, not a field weight: keep it out of the payload.
    payload = {key: value for key, value in weights.items() if key != 'name'}
    # Round away float noise so that 1.0 - 0.7 is passed as "0.3", not "0.30000000000000004".
    bm25_weight = round(1.0 - embed_weight, 10)
    cmd = [
        sys.executable, 'eval_hybrid.py',  # use the interpreter running this script
        '--output_dir', output_dir,
        '--query_csv', query_csv,
        '--weights_json', json.dumps(payload),
        '--embed_weight', str(embed_weight),
        '--bm25_weight', str(bm25_weight),
    ]
    # Force UTF-8 on the child's stdio so the "→ Best:" lines decode the same way everywhere.
    env = dict(os.environ, PYTHONIOENCODING='utf-8')
    proc = subprocess.run(
        cmd, capture_output=True, text=True, encoding='utf-8', errors='replace', env=env
    )

    # Build a filesystem-safe log name; unnamed vectors fall back to their contents.
    if label is None:
        label = str(payload)
    safe_label = re.sub(r'[^A-Za-z0-9._-]+', '_', str(label)).strip('_') or 'weights'
    raw_path = os.path.join(logs_dir, f"grid_{safe_label}_{embed_weight:.2f}.log")
    with open(raw_path, 'w', encoding='utf-8') as f:
        f.write(proc.stdout)
        if proc.stderr:
            # Keep stderr visually separate from the last stdout line.
            if proc.stdout and not proc.stdout.endswith('\n'):
                f.write('\n')
            f.write('--- stderr ---\n')
            f.write(proc.stderr)
    if proc.returncode != 0:
        print(
            f"Warning: eval_hybrid.py exited with code {proc.returncode}; see {raw_path}",
            file=sys.stderr,
        )

    # Parse each "Hit@<k>:" header and the "Best: <method> (<score>)" line that follows it.
    header_re = re.compile(r"\s*Hit@(\d+):")
    best_re = re.compile(r"Best:\s*([^(]+?)\s*\((\d+(?:\.\d+)?)\)")  # accepts 1.0 / 1, not just 0.x
    parsed = {}
    lines = proc.stdout.splitlines()
    for i, line in enumerate(lines):
        header = header_re.match(line)
        if not header:
            continue
        k = header.group(1)
        for follow in lines[i + 1:i + 6]:
            if header_re.match(follow):
                break  # next section started: this one has no Best line of its own
            best = best_re.search(follow)
            if best:
                parsed[f"Hit@{k}_method"] = best.group(1).strip()
                parsed[f"Hit@{k}_score"] = float(best.group(2))
                break
    return parsed, raw_path


def _metric_sort_key(column):
    """Order Hit@K columns numerically by K (Hit@1, Hit@5, Hit@10), then by suffix."""
    match = re.match(r"Hit@(\d+)_(.*)", column)
    if match:
        return (0, int(match.group(1)), match.group(2))
    # Any non-Hit@K column goes last, alphabetically.
    return (1, 0, column)


def _write_results_csv(path, rows):
    """Write rows to a UTF-8 CSV: fixed columns first, then every metric column seen.

    Rows lacking a column get an empty cell (csv.DictWriter's default restval).
    """
    fixed = ['weight_vector', 'embed_weight', 'log_file']
    extra = {key for row in rows for key in row} - set(fixed)
    fieldnames = fixed + sorted(extra, key=_metric_sort_key)
    with open(path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def main():
    """Sweep every weight vector × blend ratio and write the parsed results to a CSV.

    Command-line arguments:
        --output_dir  (required) forwarded to eval_hybrid.py.
        --query_csv   (required) forwarded to eval_hybrid.py.
        --log_csv     parsed-results CSV path (default: grid_search_results.csv).
        --logs_dir    directory for raw per-run logs (default: grid_logs); created if missing.

    A run whose evaluation raises is reported and skipped; the others still reach the CSV.
    """
    parser = argparse.ArgumentParser(description='Grid search over weights and blend ratios')
    parser.add_argument('--output_dir', required=True)
    parser.add_argument('--query_csv', required=True)
    parser.add_argument('--log_csv', default='grid_search_results.csv')
    parser.add_argument('--logs_dir', default='grid_logs')
    args = parser.parse_args()

    os.makedirs(args.logs_dir, exist_ok=True)

    all_results = []
    for name, weights in WEIGHT_VECTORS.items():
        # Attach the label to a *copy* so the module-level WEIGHT_VECTORS is never
        # mutated; run_evaluation reads 'name' only to label the log file.
        named_weights = dict(weights, name=name)
        for embed_w in BLEND_RATIOS:
            print(f"Running {name} @ embed_weight={embed_w}")
            try:
                parsed, raw_path = run_evaluation(
                    args.output_dir, args.query_csv, named_weights, embed_w, args.logs_dir
                )
            except Exception as e:
                # Top-level boundary: one failed run must not abort the whole sweep.
                print(f"Error for {name}@{embed_w}: {e}")
                continue
            row = {'weight_vector': name, 'embed_weight': embed_w, 'log_file': raw_path}
            row.update(parsed)
            all_results.append(row)

    _write_results_csv(args.log_csv, all_results)
    print(f"Grid search complete. Parsed results in {args.log_csv}, raw logs in {args.logs_dir}")


if __name__ == '__main__':
    main()
