from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, List, Set, Tuple, Optional
import csv
import re


def _find_instance_report(report: Dict, instance_id: str) -> Optional[Dict]:
    if not isinstance(report, dict):
        return None
    if instance_id in report and isinstance(report[instance_id], dict):
        return report[instance_id]
    # Some reports might store instance data directly
    return report if instance_id in str(report) else None


def _extract_tests_status_from_instance(inst_obj: Dict) -> Dict[str, str]:
    """
    Return a map: test_header -> 'passed' | 'failed'

    Reports produced by the harness typically store tests_status as a grouping:
      tests_status = {
        'FAIL_TO_PASS': {'success': [...], 'failure': [...]},
        'PASS_TO_PASS': {'success': [...], 'failure': [...]},
        ...
      }
    We reconstruct a flat mapping from that structure. If not present, fall back to
    discovering any dict mapping test->('passed'|'failed').
    """
    out: Dict[str, str] = {}
    if not isinstance(inst_obj, dict):
        return out

    ts = inst_obj.get("tests_status")
    if isinstance(ts, dict):
        # Grouped structure: each group has 'success'/'failure' lists
        for group_val in ts.values():
            if isinstance(group_val, dict):
                succ = group_val.get("success", [])
                fail = group_val.get("failure", [])
                if isinstance(succ, list):
                    for t in succ:
                        out[str(t)] = "passed"
                if isinstance(fail, list):
                    for t in fail:
                        out[str(t)] = "failed"
        if out:
            return out

    # Fallback: try to find a dict mapping test header -> status
    # Heuristic: collect any dict whose values are 'passed'/'failed'
    found: Dict[str, str] = {}
    def _walk(o):
        nonlocal found
        if isinstance(o, dict):
            # If values are strings and in {passed, failed}, assume it's a tests-status map
            vals = list(o.values())
            if vals and all(isinstance(v, str) for v in vals):
                allvals = {v.lower() for v in vals}
                if allvals.issubset({"passed", "failed"}):
                    for k, v in o.items():
                        found[str(k)] = str(v)
                    return
            # Else walk deeper
            for v in o.values():
                _walk(v)
        elif isinstance(o, list):
            for v in o:
                _walk(v)
    _walk(inst_obj)
    return found


def _load_tests_status(report_path: Path, instance_id: str) -> Dict[str, str]:
    try:
        data = json.loads(report_path.read_text(encoding='utf-8'))
        inst_obj = _find_instance_report(data, instance_id)
        if inst_obj is None:
            return {}
        return _extract_tests_status_from_instance(inst_obj)
    except Exception as e:
        print(f"Warning: Failed to load tests status from {report_path}: {e}")
        return {}


def _gather_run_instances(run_dir: Path) -> List[str]:
    # Enumerate all instance IDs by scanning model subdirs and instance dirs
    iids: Set[str] = set()
    for model_dir in run_dir.iterdir():
        if not model_dir.is_dir():
            continue
        for inst_dir in model_dir.iterdir():
            if inst_dir.is_dir() and (inst_dir / "report.json").exists():
                iids.add(inst_dir.name)
    return sorted(iids)


def _normalize_model_label(model_label: str) -> str:
    """
    Normalize model labels to ensure consistent matching between different run types.
    Extract the base model identifier from various naming patterns.
    """
    # Common patterns to extract base model name
    patterns = [
        r'^(\d{8}_[^-]+)',  # Extract date_model prefix
        r'^([^-]+)',        # Extract first part before dash
    ]
    
    for pattern in patterns:
        match = re.match(pattern, model_label)
        if match:
            return match.group(1)
    
    return model_label


def _count_accepted_tests_for_run(dataset_baseline_dir: Path, te_gold_dir: Path) -> int:
    """
    Count accepted tests for a specific run pair (dataset baseline + TE gold).
    Returns the total number of tests that fail on dataset baseline and pass on TE gold.
    """
    accepted_count = 0
    
    # Get all instances from both directories
    dataset_instances = set()
    te_instances = set()
    
    for model_dir in dataset_baseline_dir.iterdir():
        if not model_dir.is_dir():
            continue
        for inst_dir in model_dir.iterdir():
            if inst_dir.is_dir() and (inst_dir / "report.json").exists():
                dataset_instances.add(inst_dir.name)
    
    for model_dir in te_gold_dir.iterdir():
        if not model_dir.is_dir():
            continue
        for inst_dir in model_dir.iterdir():
            if inst_dir.is_dir() and (inst_dir / "report.json").exists():
                te_instances.add(inst_dir.name)
    
    print(f"    Dataset instances: {len(dataset_instances)}, TE instances: {len(te_instances)}")
    common_instances = dataset_instances & te_instances
    print(f"    Common instances: {len(common_instances)}")
    
    for iid in common_instances:
        # Find report paths
        dataset_rpt = None
        te_rpt = None
        
        for model_dir in dataset_baseline_dir.iterdir():
            if not model_dir.is_dir():
                continue
            inst_dir = model_dir / iid
            if inst_dir.is_dir() and (inst_dir / "report.json").exists():
                dataset_rpt = inst_dir / "report.json"
                break
        
        for model_dir in te_gold_dir.iterdir():
            if not model_dir.is_dir():
                continue
            inst_dir = model_dir / iid
            if inst_dir.is_dir() and (inst_dir / "report.json").exists():
                te_rpt = inst_dir / "report.json"
                break
        
        if dataset_rpt and te_rpt:
            dataset_status = _load_tests_status(dataset_rpt, iid)
            te_status = _load_tests_status(te_rpt, iid)
            
            dataset_failed = {t for t, s in dataset_status.items() if s.lower() == "failed"}
            te_passed = {t for t, s in te_status.items() if s.lower() == "passed"}
            accepted = te_passed & dataset_failed
            if accepted:
                print(f"      {iid}: {len(accepted)} accepted tests")
            accepted_count += len(accepted)
    
    return accepted_count


def _load_runs(logs_dir: Path, run_prefix: str) -> Tuple[Dict[str, Path], Dict[str, Path], Dict[str, Dict[str, Path]]]:
    """
    Return mappings for runs found under logs_dir with the given run_prefix.
    For each model, select only the run with the most accepted tests.
    - dataset_baseline: map instance_id -> path to report.json for dataset baseline (buggy version, no patch)
    - te_gold: map instance_id -> path to report.json for TE gold run
    - te_model: map model_label -> (map instance_id -> report.json path)

    We accept multiple model_labels; for baseline/gold we pick the run with most accepted tests.
    """
    # First, collect all runs by model label
    dataset_baseline_runs: Dict[str, List[Tuple[str, Path]]] = {}  # model_label -> [(run_name, run_dir), ...]
    te_gold_runs: Dict[str, List[Tuple[str, Path]]] = {}  # model_label -> [(run_name, run_dir), ...]
    te_model_runs: Dict[str, List[Tuple[str, Path]]] = {}  # model_label -> [(run_name, run_dir), ...]

    # Find all run_id folders matching run_prefix.* (with optional hash suffix)
    for run_id_dir in logs_dir.iterdir():
        if not run_id_dir.is_dir():
            continue
        name = run_id_dir.name
        if not name.startswith(f"{run_prefix}."):
            continue
        
        # Handle both dataset and TE runs
        if name.startswith(f"{run_prefix}.dataset.baseline."):
            # Dataset baseline runs (buggy version, no patch)
            # Extract model_label (remove hash suffix if present)
            dataset_part = name[len(f"{run_prefix}.dataset.baseline."):]
            model_label = re.sub(r'_[a-f0-9]{8}$', '', dataset_part)
            # Normalize model label by extracting the base model name
            model_label = _normalize_model_label(model_label)
            
            if model_label not in dataset_baseline_runs:
                dataset_baseline_runs[model_label] = []
            dataset_baseline_runs[model_label].append((name, run_id_dir))
            
        elif name.startswith(f"{run_prefix}.te."):
            # TE runs
            # Handle directory names with hash suffixes like: mybatch.te.gold.20241103_OpenHands-CodeAct-2.1-sonn_97188b1c
            # Extract the pattern: <run_prefix>.te.<kind>.<model_label>[_<hash>]
            te_part = name[len(f"{run_prefix}.te."):]
            
            # Split on first dot to get kind
            if '.' not in te_part:
                continue
            kind, rest = te_part.split('.', 1)
            
            # Extract model_label (remove hash suffix if present)
            model_label = re.sub(r'_[a-f0-9]{8}$', '', rest)
            # Normalize model label by extracting the base model name
            model_label = _normalize_model_label(model_label)
            
            # For model runs, we will index by model_label
            if kind == "model":
                if model_label not in te_model_runs:
                    te_model_runs[model_label] = []
                te_model_runs[model_label].append((name, run_id_dir))
            elif kind == "gold":
                if model_label not in te_gold_runs:
                    te_gold_runs[model_label] = []
                te_gold_runs[model_label].append((name, run_id_dir))
    
    # Now select the best dataset baseline + TE gold pairing globally based on accepted test count
    dataset_baseline: Dict[str, Path] = {}
    te_gold: Dict[str, Path] = {}
    te_model: Dict[str, Dict[str, Path]] = {}

    # Debug: print what we found
    print(f"Dataset baseline runs found: {list(dataset_baseline_runs.keys())}")
    print(f"TE gold runs found: {list(te_gold_runs.keys())}")
    print(f"TE model runs found: {list(te_model_runs.keys())}")

    best_accepted_count = -1
    best_dataset = None
    best_te_gold = None
    best_pair_info: Optional[Tuple[str, str]] = None

    # Evaluate all combinations across labels to find the best pairing
    for ds_label, ds_runs in dataset_baseline_runs.items():
        for te_label, te_runs in te_gold_runs.items():
            for ds_name, ds_dir in ds_runs:
                for tg_name, tg_dir in te_runs:
                    print(f"  Counting accepted tests for: {ds_name} + {tg_name}")
                    acc = _count_accepted_tests_for_run(ds_dir, tg_dir)
                    print(f"  {ds_name} + {tg_name} = {acc} accepted tests")
                    if acc > best_accepted_count:
                        best_accepted_count = acc
                        best_dataset = ds_dir
                        best_te_gold = tg_dir
                        best_pair_info = (ds_name, tg_name)

    if best_dataset and best_te_gold:
        print(f"\nSelected best run pair: {best_pair_info[0]} + {best_pair_info[1]} ({best_accepted_count} accepted tests)")
        # Build instance->report maps for selected directories
        for model_dir in best_dataset.iterdir():
            if not model_dir.is_dir():
                continue
            for inst_dir in model_dir.iterdir():
                if not inst_dir.is_dir():
                    continue
                rpt = inst_dir / "report.json"
                if rpt.exists():
                    dataset_baseline[inst_dir.name] = rpt
        for model_dir in best_te_gold.iterdir():
            if not model_dir.is_dir():
                continue
            for inst_dir in model_dir.iterdir():
                if not inst_dir.is_dir():
                    continue
                rpt = inst_dir / "report.json"
                if rpt.exists():
                    te_gold[inst_dir.name] = rpt

    # Handle TE model runs - for each model label, select the run with the most instances
    for model_label, runs in te_model_runs.items():
        best_run = None
        best_instance_count = -1
        for run_name, run_dir in runs:
            inst_map: Dict[str, Path] = {}
            for model_dir in run_dir.iterdir():
                if not model_dir.is_dir():
                    continue
                for inst_dir in model_dir.iterdir():
                    if not inst_dir.is_dir():
                        continue
                    rpt = inst_dir / "report.json"
                    if rpt.exists():
                        inst_map[inst_dir.name] = rpt
            if len(inst_map) > best_instance_count:
                best_instance_count = len(inst_map)
                best_run = (run_name, inst_map)
        if best_run:
            te_model[model_label] = best_run[1]
            print(f"  Selected best TE model run for {model_label}: {best_run[0]} ({best_instance_count} instances)")

    print(f"\nFinal selection:")
    print(f"Total dataset baseline instances: {len(dataset_baseline)}")
    print(f"Total TE gold instances: {len(te_gold)}")
    print(f"Total TE model runs: {len(te_model)}")
    return dataset_baseline, te_gold, te_model


def _read_coverage_pct(dir_path: Path) -> Optional[float]:
    # Attempt to read coverage.json with a 'percent_covered' or similar field
    cov_path = dir_path / "coverage.json"
    if not cov_path.exists():
        return None
    try:
        data = json.loads(cov_path.read_text(encoding='utf-8'))
        # coverage.py json has: {"meta": {...}, "files": {...}, "totals": {"covered_lines":..., "num_statements":..., "percent_covered": ...}}
        totals = None
        if isinstance(data, dict):
            totals = data.get("totals")
            if isinstance(totals, dict):
                pc = totals.get("percent_covered")
                if isinstance(pc, (int, float)):
                    return float(pc)
        # Other common flat fields
        if isinstance(data, dict):
            for key in ("percent_covered", "coverage", "line_coverage", "pct"):
                if key in data and isinstance(data[key], (int, float)):
                    return float(data[key])
        # fallback: try nested numeric value (best-effort)
        if isinstance(data, dict):
            def _walk_find_num(o):
                if isinstance(o, dict):
                    for v in o.values():
                        r = _walk_find_num(v)
                        if r is not None:
                            return r
                elif isinstance(o, list):
                    for v in o:
                        r = _walk_find_num(v)
                        if r is not None:
                            return r
                elif isinstance(o, (int, float)) and 0 <= o <= 100:
                    return float(o)
                return None
            anynum = _walk_find_num(data)
            if anynum is not None:
                return anynum
        return None
    except Exception as e:
        print(f"Warning: Failed to read coverage from {cov_path}: {e}")
        return None


def analyze(run_prefix: str, logs_dir: Path, out_dir: Path):
    """
    Compute accepted tests for ``run_prefix`` and write three CSV reports to ``out_dir``.

    Accepted tests are tests reported 'failed' in the selected dataset baseline run
    (buggy version) and 'passed' in the selected TE gold run for the same instance
    (runs are selected by ``_load_runs``).

    Outputs, named ``<run_prefix>_accepted_tests_{detail,matrix,summary}.csv``:
      - detail: one row per (instance_id, accepted test);
      - matrix: one row per accepted test with a ``<model>__passed`` column per TE
        model label: 1 if passed, 0 for any other recorded status, "" if unknown;
      - summary: one row per instance with the accepted count, per-model number of
        accepted tests passed ("" when the model has no report for the instance),
        and coverage read from coverage.json beside the selected baseline and gold
        reports, plus their difference.

    Prints an error and returns early when no baseline or gold run is selected.
    """
    logs_dir = logs_dir.resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Analyzing runs with prefix '{run_prefix}' in {logs_dir}")
    print(f"Output directory: {out_dir}")

    # Load runs
    dataset_baseline, te_gold, te_model = _load_runs(logs_dir, run_prefix)

    if not dataset_baseline:
        print("ERROR: No dataset baseline runs found!")
        return
    if not te_gold:
        print("ERROR: No TE gold runs found!")
        return
    if not te_model:
        print("WARNING: No TE model runs found!")

    # Accepted tests = tests that fail on buggy version (dataset baseline) and pass on TE gold
    accepted_tests: Dict[str, Set[str]] = {}
    all_instances = sorted(set(dataset_baseline) & set(te_gold))
    print(f"\nProcessing {len(all_instances)} instances...")

    for iid in all_instances:
        buggy_status = _load_tests_status(dataset_baseline[iid], iid)
        gold_status = _load_tests_status(te_gold[iid], iid)

        if not buggy_status and not gold_status:
            print(f"  WARNING: No test status found for {iid}")
            continue

        buggy_failed = {t for t, s in buggy_status.items() if s.lower() == "failed"}
        gold_passed = {t for t, s in gold_status.items() if s.lower() == "passed"}
        acc = gold_passed & buggy_failed
        if acc:
            accepted_tests[iid] = acc
            print(f"  {iid}: {len(acc)} accepted tests (fail on buggy, pass on gold)")

    model_labels = sorted(te_model)

    # Prepare outputs
    detail_csv = out_dir / f"{run_prefix}_accepted_tests_detail.csv"
    matrix_csv = out_dir / f"{run_prefix}_accepted_tests_matrix.csv"
    summary_csv = out_dir / f"{run_prefix}_accepted_tests_summary.csv"

    print(f"\nFound {len(accepted_tests)} instances with accepted tests")
    total_accepted = sum(len(tests) for tests in accepted_tests.values())
    print(f"Total accepted tests across all instances: {total_accepted}")

    def _fmt_pct(value: Optional[float]) -> str:
        # Coverage cells are left blank when the value is unavailable.
        return "" if value is None else f"{value:.3f}"

    # Write details: one row per accepted test
    with open(detail_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["instance_id", "test_header"])
        for iid, tests in sorted(accepted_tests.items()):
            for t in sorted(tests):
                w.writerow([iid, t])

    passed_cols = [f"{ml}__passed" for ml in model_labels]
    # Matrix and summary are written in one pass so each model report is parsed once.
    with open(matrix_csv, "w", newline="", encoding="utf-8") as mf:
        with open(summary_csv, "w", newline="", encoding="utf-8") as sf:
            matrix_w = csv.writer(mf)
            summary_w = csv.writer(sf)
            matrix_w.writerow(["instance_id", "test_header"] + passed_cols)
            # Header kept for output compatibility: the second coverage column holds the
            # TE gold run's coverage.
            summary_w.writerow(["instance_id", "accepted_count"] + passed_cols + [
                "dataset_coverage_pct", "dataset_plus_accepted_coverage_pct", "coverage_delta_pct"
            ])

            for iid in sorted(accepted_tests):
                accepted = sorted(accepted_tests[iid])

                # None marks "this model has no report for this instance".
                model_status: Dict[str, Optional[Dict[str, str]]] = {}
                for ml in model_labels:
                    rpt = te_model[ml].get(iid)
                    model_status[ml] = _load_tests_status(rpt, iid) if rpt else None

                for t in accepted:
                    row = [iid, t]
                    for ml in model_labels:
                        st = (model_status[ml] or {}).get(t, "")
                        row.append(1 if st.lower() == "passed" else 0 if st else "")
                    matrix_w.writerow(row)

                summary_row = [iid, len(accepted)]
                for ml in model_labels:
                    status = model_status[ml]
                    if status is None:
                        summary_row.append("")
                    else:
                        summary_row.append(sum(1 for t in accepted if status.get(t, "").lower() == "passed"))

                # Coverage comes from the instance directories of the SELECTED runs, i.e.
                # the same directories whose report.json produced the accepted tests.
                dataset_cov = _read_coverage_pct(dataset_baseline[iid].parent)
                te_plus_cov = _read_coverage_pct(te_gold[iid].parent)
                delta = (
                    te_plus_cov - dataset_cov
                    if dataset_cov is not None and te_plus_cov is not None
                    else None
                )
                summary_row.extend([_fmt_pct(dataset_cov), _fmt_pct(te_plus_cov), _fmt_pct(delta)])
                summary_w.writerow(summary_row)

    print(f"Wrote: {detail_csv}")
    print(f"Wrote: {matrix_csv}")
    print(f"Wrote: {summary_csv}")


def main():
    """Command-line entry point: parse the CLI options and run ``analyze``."""
    parser = argparse.ArgumentParser(
        description="Analyze accepted tests (fail on buggy version, pass on gold) and per-model performance from batch_run_eval outputs."
    )
    option_specs = (
        ("--run_prefix", {"required": True, "help": "The run_prefix used in batch_run_eval (e.g., mybatch)"}),
        ("--logs_dir", {"default": "logs/run_evaluation", "help": "Root logs directory"}),
        ("--out_dir", {"default": "combined_preds", "help": "Directory to write output CSVs"}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    opts = parser.parse_args()

    analyze(
        run_prefix=opts.run_prefix,
        logs_dir=Path(opts.logs_dir),
        out_dir=Path(opts.out_dir),
    )


if __name__ == "__main__":
    main()
