#!/usr/bin/env python3
"""
Code Style Experiment Runner.

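Compares JO-based style conversion against baselines:
1. NO (No Operator): original messy code, unchanged
2. Formatter: Black + isort only
3. Formatter+Tpl: Black + isort + brute-force template docstrings
4. JO: formatter + precedent-guided patching
(An optional few-shot LLM rewrite baseline is not run by this script.)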
"""

import sys
import os
import json
import time
from datetime import datetime
from typing import Dict, List, Any, Tuple

# Add src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'data'))

from test_programs import get_programs
from messy_generator import MessyCodeGenerator
from style_checker import StyleChecker, FunctionalChecker
from jo_style import StyleJO, JOOutcome


def _format_with_tools(jo, code: str) -> str:
    """Apply the formatter-only pipeline (Black, then isort) via the JO helpers."""
    return jo._run_isort(jo._run_black(code))


def _condition_metrics(check, func_preserved, edit_distance, **extra) -> Dict[str, Any]:
    """Build the metrics record stored for one condition of one program.

    Args:
        check: Result of ``StyleChecker.check``; its ``violations``, ``score``
            and ``passed`` attributes are read.
        func_preserved: Whether the transformed code still passes its tests.
        edit_distance: Edit distance reported for this condition.
        **extra: Condition-specific fields, inserted before ``edit_distance``
            so the key order matches the historical JSON layout.

    Returns:
        The metrics dict for this condition.
    """
    metrics = {
        'violations': len(check.violations),
        'score': check.score,
        'passed': check.passed,
        'func_preserved': func_preserved,
    }
    metrics.update(extra)
    metrics['edit_distance'] = edit_distance
    return metrics


def run_experiment(n_programs: int = None, seed: int = 42) -> Dict[str, Any]:
    """
    Run the full code style experiment.

    Each program is made messy with ``MessyCodeGenerator``. A program whose messy
    version already fails its tests is skipped and its name is recorded under
    ``'skipped'``. Every remaining program is evaluated under four conditions:
    NO (messy code as-is), Formatter (Black + isort), Formatter+Tpl (Black +
    isort + template docstrings, re-formatted with Black) and JO (``StyleJO.process``).

    Args:
        n_programs: Number of programs to test (None = all). An explicit 0
            selects no programs.
        seed: Random seed for messy generation.

    Returns:
        Dictionary with ``timestamp``, ``n_programs`` (number selected),
        ``seed``, aggregate ``conditions`` (see ``compute_aggregates``),
        ``per_program`` records and the ``skipped`` program names.
    """
    programs = get_programs()
    # `is not None` so that 0 means "no programs" rather than falling back to all.
    if n_programs is not None:
        programs = programs[:n_programs]

    print("=" * 60)
    print("CODE STYLE EXPERIMENT")
    print(f"Programs: {len(programs)}, Seed: {seed}")
    print("=" * 60)

    # Initialize components
    messy_gen = MessyCodeGenerator(seed=seed)
    checker = StyleChecker()
    func_checker = FunctionalChecker()
    jo = StyleJO(max_iterations=5)

    results = {
        'timestamp': datetime.now().isoformat(),
        'n_programs': len(programs),
        'seed': seed,
        'conditions': {},
        'per_program': [],
        'skipped': [],
    }

    for idx, (name, clean_code, test_code) in enumerate(programs, start=1):
        print(f"\n[{idx}/{len(programs)}] {name}")

        messy_code = messy_gen.generate_messy(clean_code)

        # If the messy input already fails its tests, func_preserved would be
        # meaningless for every condition, so the program is excluded.
        func_pass, func_err = func_checker.check(messy_code, test_code)
        if not func_pass:
            print(f"  WARNING: Messy code breaks tests, skipping: {func_err}")
            results['skipped'].append(name)
            continue

        conditions: Dict[str, Dict[str, Any]] = {}
        program_result = {
            'name': name,
            'clean_violations': len(checker.check(clean_code).violations),
            'conditions': conditions,
        }

        # === Condition 1: NO (No Operator) ===
        # Functionality was verified above; no edits are made.
        conditions['NO'] = _condition_metrics(checker.check(messy_code), True, 0)

        # === Condition 2: Formatter Only (Black + isort) ===
        formatted_code = _format_with_tools(jo, messy_code)
        fmt_func, _ = func_checker.check(formatted_code, test_code)
        conditions['Formatter'] = _condition_metrics(
            checker.check(formatted_code),
            fmt_func,
            jo._compute_edit_distance(messy_code, formatted_code),
        )

        # === Condition 3: Formatter + Templates (brute-force docstrings) ===
        # Reuse the Condition 2 output instead of re-running Black + isort on
        # the same input; Black runs again because docstring insertion edits the AST.
        template_code = jo._run_black(jo._add_all_docstrings(formatted_code))
        tpl_func, _ = func_checker.check(template_code, test_code)
        conditions['Formatter+Tpl'] = _condition_metrics(
            checker.check(template_code),
            tpl_func,
            jo._compute_edit_distance(messy_code, template_code),
        )

        # === Condition 4: JO (Full) ===
        jo_result = jo.process(messy_code)
        jo_func, _ = func_checker.check(jo_result.output_code, test_code)
        conditions['JO'] = _condition_metrics(
            checker.check(jo_result.output_code),
            jo_func,
            jo_result.edit_distance,
            outcome=jo_result.outcome.value,
            iterations=jo_result.iterations,
            violations_fixed=jo_result.violations_fixed,
        )

        results['per_program'].append(program_result)

        print(f"  NO: {conditions['NO']['violations']} violations")
        print(f"  Formatter: {conditions['Formatter']['violations']} violations")
        print(f"  Fmt+Tpl: {conditions['Formatter+Tpl']['violations']} violations")
        print(f"  JO: {conditions['JO']['violations']} violations ({conditions['JO']['outcome']})")

    results['conditions'] = compute_aggregates(results['per_program'])

    return results


def compute_aggregates(per_program: List[Dict]) -> Dict[str, Dict]:
    """Summarise per-program metrics into one record per experimental condition.

    For each of NO, Formatter, Formatter+Tpl and JO the record holds the number
    of programs and the following statistics:
    - mean violation count
    - mean style score
    - style-pass rate
    - functional-preservation rate
    - share of programs with zero violations
    - mean edit distance (a missing edit distance counts as 0)

    The JO record also holds the share of each outcome value: 'allow', 'edit',
    'escalate' and 'deny'. With no programs, every statistic is 0.
    """
    def average(values):
        # Works for numbers and booleans; an empty list yields 0 instead of dividing by zero.
        return sum(values) / len(values) if values else 0

    summary = {}
    for condition in ('NO', 'Formatter', 'Formatter+Tpl', 'JO'):
        rows = [program['conditions'][condition] for program in per_program]
        violation_counts = [row['violations'] for row in rows]

        record = {
            'n_programs': len(rows),
            'mean_violations': average(violation_counts),
            'mean_score': average([row['score'] for row in rows]),
            'style_pass_rate': average([row['passed'] for row in rows]),
            'func_preserve_rate': average([row['func_preserved'] for row in rows]),
            'zero_violation_rate': average([count == 0 for count in violation_counts]),
            'mean_edit_distance': average([row.get('edit_distance', 0) for row in rows]),
        }

        if condition == 'JO':
            outcome_values = [row['outcome'] for row in rows]
            for label in ('allow', 'edit', 'escalate', 'deny'):
                record[f'{label}_rate'] = average([value == label for value in outcome_values])

        summary[condition] = record

    return summary


def print_results_table(results: Dict) -> None:
    """Write the aggregate comparison table and the JO outcome breakdown to stdout.

    Reads ``results['conditions']``. For each of NO, Formatter, Formatter+Tpl
    and JO, one row shows the zero-violation rate (as a percentage), the mean
    violation count, the functional-preservation rate (as a percentage) and the
    mean edit distance. The JO outcome shares follow the table; a missing share
    is shown as 0.0%.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 60

    print(f"\n{heavy_rule}")
    print("RESULTS SUMMARY")
    print(heavy_rule)

    summary = results['conditions']

    header = f"{'Method':<15} {'Pass@0↑':>10} {'Viol↓':>8} {'Func%↑':>10} {'EditDist↓':>10}"
    print("\n" + header)
    print(light_rule)

    for method in ('NO', 'Formatter', 'Formatter+Tpl', 'JO'):
        row = summary[method]
        zero_pct = row['zero_violation_rate'] * 100
        func_pct = row['func_preserve_rate'] * 100
        print(
            f"{method:<15} {zero_pct:>9.1f}% {row['mean_violations']:>8.2f} "
            f"{func_pct:>9.1f}% {row['mean_edit_distance']:>10.1f}"
        )

    print(light_rule)

    # Outcome breakdown for the JO condition; labels are padded by hand to line up.
    jo_row = summary['JO']
    print("\nJO Outcomes:")
    outcome_lines = (
        ("  ALLOW: ", 'allow_rate'),
        ("  EDIT:  ", 'edit_rate'),
        ("  ESCALATE: ", 'escalate_rate'),
        ("  DENY:  ", 'deny_rate'),
    )
    for prefix, key in outcome_lines:
        print(f"{prefix}{jo_row.get(key, 0) * 100:.1f}%")


def compute_confidence_intervals(results: Dict, confidence: float = 0.95) -> Dict:
    """Compute confidence intervals (95% by default) for key per-condition metrics.

    Violations use a Student-t interval around the mean violation count. When
    every program has the same count (zero standard error), the interval
    collapses to ``(mean, mean)``; the earlier arbitrary fudge scale is not used.

    The functional-preservation rate uses the Wilson score interval for a
    binomial proportion. Unlike the normal (Wald) approximation, it stays
    inside [0, 1] and keeps a non-zero width when all or none of the programs
    preserve behaviour.

    Conditions with fewer than two programs are omitted.

    Args:
        results: Experiment results containing a ``per_program`` list.
        confidence: Two-sided confidence level, default 0.95.

    Returns:
        ``{condition: {'violations': {'mean', 'ci'}, 'func_rate': {'mean', 'ci'}}}``.
        All values are plain Python floats, so the dict is JSON-serialisable.

    Raises:
        ImportError: If scipy or numpy is unavailable. The imports are local so
            that callers can catch this and skip CI reporting.
    """
    from scipy import stats
    import numpy as np

    # Two-sided normal critical value; about 1.96 for 95%.
    z = float(stats.norm.ppf(0.5 + confidence / 2))

    cis = {}
    for cond in ['NO', 'Formatter', 'Formatter+Tpl', 'JO']:
        violations = [p['conditions'][cond]['violations'] for p in results['per_program']]
        func = [1 if p['conditions'][cond]['func_preserved'] else 0 for p in results['per_program']]

        n = len(violations)
        if n < 2:
            # A t interval needs at least one degree of freedom.
            continue

        # Violation count: t interval; degenerate when there is no spread.
        viol_mean = float(np.mean(violations))
        viol_sem = float(stats.sem(violations))
        if viol_sem > 0:
            low, high = stats.t.interval(confidence, n - 1, loc=viol_mean, scale=viol_sem)
            viol_ci = (float(low), float(high))
        else:
            viol_ci = (viol_mean, viol_mean)

        # Functional preservation: Wilson score interval.
        func_mean = float(np.mean(func))
        z_sq = z * z
        denom = 1 + z_sq / n
        center = (func_mean + z_sq / (2 * n)) / denom
        half_width = z * np.sqrt(func_mean * (1 - func_mean) / n + z_sq / (4 * n * n)) / denom
        # Clamp only to absorb floating-point rounding at the boundaries.
        func_ci = (max(0.0, float(center - half_width)), min(1.0, float(center + half_width)))

        cis[cond] = {
            'violations': {'mean': viol_mean, 'ci': viol_ci},
            'func_rate': {'mean': func_mean, 'ci': func_ci}
        }

    return cis


def save_results(results: Dict, filepath: str) -> None:
    """Save *results* to *filepath* as JSON indented by 2 spaces.

    The payload is serialised before the file is opened. If *results* holds a
    value json cannot encode, the resulting TypeError leaves any existing file
    at *filepath* intact; writing straight into the opened file would have
    truncated it. Missing parent directories are created.

    Args:
        results: JSON-serialisable results dictionary.
        filepath: Destination path.

    Raises:
        TypeError: If *results* contains a value json cannot serialise.
        OSError: If the directory or file cannot be written.
    """
    payload = json.dumps(results, indent=2)
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(payload)
    print(f"\nResults saved to: {filepath}")


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Run code style experiment')
    parser.add_argument('--n', type=int, default=None, help='Number of programs (default: all)')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--output', type=str, default='results/style_results.json',
                        help='Output file path')
    args = parser.parse_args()

    # Create the output directory only when the path names one. A bare filename
    # such as "out.json" goes to the current directory, so no stray "results/"
    # directory is created.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # perf_counter is monotonic, so system clock adjustments cannot skew the timing.
    start_time = time.perf_counter()
    results = run_experiment(n_programs=args.n, seed=args.seed)
    elapsed = time.perf_counter() - start_time

    results['elapsed_seconds'] = elapsed

    print_results_table(results)

    # scipy/numpy are optional. compute_confidence_intervals imports them lazily,
    # so only that call can raise ImportError.
    try:
        cis = compute_confidence_intervals(results)
    except ImportError:
        print("\n(scipy not available for CI calculation)")
    else:
        results['confidence_intervals'] = cis
        print("\n95% Confidence Intervals:")
        for cond, ci in cis.items():
            print(f"  {cond}: Violations = {ci['violations']['mean']:.2f} "
                  f"({ci['violations']['ci'][0]:.2f}, {ci['violations']['ci'][1]:.2f})")

    save_results(results, args.output)

    print(f"\nTotal time: {elapsed:.1f}s")
