import os
import json
import glob
import time
from collections import defaultdict, Counter
import concurrent.futures
from tqdm import tqdm
import pandas as pd
import argparse
import logging
import gzip
import threading
import hashlib
import errno

# Input folder scanned by iter_problems_from_folder_stream for JSONL problem files.
PROBLEMS_FOLDER = "./upload_temp"
# All outputs (incremental results, claim flags, aggregate CSVs) are written under this directory.
OUTPUT_DIR = "./benchmark_outputs"
# Append-only JSONL holding one result record per (problem, model); read back for resume
# (load_completed_set) and for aggregation (aggregate_results_from_incremental).
INCREMENTAL_PATH = os.path.join(OUTPUT_DIR, "results_incremental_token_high.jsonl")
# Command-line defaults used by parse_args_and_run.
DEFAULT_MODELS = ["deepseek-reasoner"]
DEFAULT_NUM_SAMPLES = 1
DEFAULT_MAX_WORKERS = 10
# Holds one exclusively-created flag file per claimed (problem, model) pair
# (see try_claim_and_write_flag).
COMPLETED_FLAGS_DIR = os.path.join(OUTPUT_DIR, "completed_flags")
# Import-time side effect; makedirs also creates OUTPUT_DIR as the parent directory.
os.makedirs(COMPLETED_FLAGS_DIR, exist_ok=True)

# Project-local helpers. As used in this file: fetch_gpt4_tong(messages, model) returns the
# model's reply text, extract_code(text) pulls the code out of that reply,
# run_code(code, input) executes the code on one test input, and
# outputs_match(actual, expected) decides whether the output is correct.
# NOTE(review): exact semantics (sandboxing, timeouts, output normalization) live in utils.
from utils import (
    fetch_gpt4_tong,
    extract_code,
    run_code,
    outputs_match
)

# Minimum average pass@1 for the "easy" and "medium" difficulty buckets (anything below
# "medium" is "hard"); the same cut-offs appear in aggregate_results_from_incremental.
DIFFICULTY_THRESHOLDS = { "easy": 0.7, "medium": 0.3 }
# Redundant with the makedirs above (which already creates OUTPUT_DIR); harmless.
os.makedirs(OUTPUT_DIR, exist_ok=True)

def _release_claim_flag(problem_key, model):
    """Best-effort removal of the claim flag for (problem_key, model).

    Called when the incremental append fails after this process won the claim,
    so the pair is not left permanently "claimed" without a recorded result.
    The path must match the naming scheme of try_claim_and_write_flag:
    COMPLETED_FLAGS_DIR / sha1("<problem_key>::<model>").json
    """
    digest = hashlib.sha1((problem_key + "::" + model).encode("utf-8")).hexdigest()
    fpath = os.path.join(COMPLETED_FLAGS_DIR, f"{digest}.json")
    try:
        os.remove(fpath)
    except FileNotFoundError:
        pass
    except OSError as e:
        logging.error(f"Failed to remove claim flag {fpath} for {problem_key},{model}: {e}")


def process_done_futures(done_futures, in_flight, inflight_pairs, completed, write_lock, INCREMENTAL_PATH):
    """Handle finished futures produced by the benchmark run loop.

    For every future in *done_futures* that is still tracked in *in_flight*:
    stop tracking it, fetch its result, and record it. Results are skipped
    (and therefore retried by a later run) when the task raised or when any
    sample carries a ``gen_error``. Otherwise the pair is claimed through
    try_claim_and_write_flag; only the claim winner appends the result as one
    JSON line to *INCREMENTAL_PATH* (under *write_lock*). If that append fails,
    the claim flag is released again so the pair is not lost for good.

    Args:
        done_futures: iterable of completed futures.
        in_flight: dict future -> (problem_key, model); entries are popped.
        inflight_pairs: set of (problem_key, model) pairs; entries are discarded.
        completed: set of recorded (problem_key, model) pairs; updated in place.
        write_lock: lock serialising appends to the incremental file.
        INCREMENTAL_PATH: path of the incremental JSONL file.
    """
    for fut in list(done_futures):
        if fut not in in_flight:
            # Already handled by an earlier call (or never tracked).
            continue
        pk_m = in_flight.pop(fut)
        inflight_pairs.discard(pk_m)
        pk_done, m_done = pk_m
        try:
            r = fut.result()
        except Exception as e:
            logging.error(f"Error running model {m_done} on problem {pk_done}: {e}")
            continue

        per_sample = r.get("per_sample", [])
        gen_errors = [s.get("gen_error") for s in per_sample if s.get("gen_error")]
        if gen_errors:
            # Not recorded on purpose, so that a resumed run retries this pair.
            logging.warning(f"  Skipping claim/write for model={m_done} problem={pk_done} because generation errors present: {gen_errors}")
            continue

        try:
            claimed = try_claim_and_write_flag(pk_done, m_done, {"pass1": r.get("pass1"), "timestamp": time.time()})
        except Exception as e:
            logging.error(f"Failed to create claim flag for {pk_done},{m_done}: {e}")
            claimed = None  # claim state unknown: record nothing, a later run can retry

        if claimed is None:
            pass
        elif not claimed:
            logging.info(f"  Another process already recorded {pk_done},{m_done}; skipping incremental append.")
            # Recorded elsewhere: don't resubmit this pair during the current run.
            completed.add((pk_done, m_done))
        else:
            with write_lock:
                try:
                    with open(INCREMENTAL_PATH, "a", encoding="utf-8") as fo:
                        fo.write(json.dumps(r, ensure_ascii=False) + "\n")
                        fo.flush()
                        try:
                            os.fsync(fo.fileno())
                        except OSError:
                            pass  # durability is best-effort; the line is already written
                    completed.add((pk_done, m_done))
                except Exception as e:
                    logging.error(f"Failed to append incremental for {pk_done},{m_done}: {e}")
                    # Without this, the flag would block the pair forever while no
                    # result exists in the incremental file.
                    _release_claim_flag(pk_done, m_done)

        pass1 = int(bool(r.get("pass1", False)))
        failed_samples = sum(1 for s in per_sample if not s.get("validation", {}).get("valid", False))
        logging.info(f"  Completed: model={m_done} pass1={pass1} samples={len(per_sample)} failed_samples={failed_samples}")

def iter_problems_from_folder_stream(folder):
    """Lazily yield problem dicts, one per JSON line, from the JSONL files in *folder*.

    Both ``*.jsonl`` and gzip-compressed ``*.jsonl.gz`` files are read, in
    sorted path order. Blank lines, unparsable lines and lines that are not
    JSON objects are skipped (the latter two with a warning).

    Each yielded dict is annotated/normalised as follows:
      - ``source_file``: basename of the file the line came from.
      - ``question_index``: 0-based physical line number in that file (skipped
        lines still count, so indices are stable across runs).
      - ``problem_id``: set to None when absent.
      - ``adv_test_cases``: kept when present and non-null, otherwise falls back
        to ``valid_test_cases`` (or []).
      - ``rand_test_cases`` / ``corner_test_cases``: null or missing -> [].
      - ``tags``: original tags (a single string is promoted to a list) plus
        the file stem, added once.
    """
    paths = glob.glob(os.path.join(folder, "*.jsonl")) + glob.glob(os.path.join(folder, "*.jsonl.gz"))
    paths.sort()
    for path in paths:
        base = os.path.basename(path)
        is_gz = base.endswith(".gz")
        # "x.jsonl.gz" and "x.jsonl" both get the tag "x".
        stem = base[:-len(".gz")] if is_gz else base
        tag_from_file = os.path.splitext(stem)[0]
        logging.info(f"Opening file {base} as tag '{tag_from_file}'")
        opener = gzip.open if is_gz else open
        with opener(path, "rt", encoding="utf-8", errors="ignore") as f:
            for idx, line in enumerate(f):
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except Exception as e:
                    logging.warning(f"Parse error {base}:{idx}: {e}")
                    continue
                if not isinstance(obj, dict):
                    logging.warning(f"Skipping non-object JSON at {base}:{idx}")
                    continue

                obj["source_file"] = base
                obj["question_index"] = idx
                obj.setdefault("problem_id", None)

                # Assign (not setdefault) so explicit nulls are normalised too;
                # downstream code calls len() on these lists.
                adv = obj.get("adv_test_cases")
                if adv is None:
                    adv = obj.get("valid_test_cases") or []
                obj["adv_test_cases"] = adv
                obj["rand_test_cases"] = obj.get("rand_test_cases") or []
                obj["corner_test_cases"] = obj.get("corner_test_cases") or []

                orig_tags = obj.get("tags") or []
                tags = [orig_tags] if isinstance(orig_tags, str) else list(orig_tags)
                if tag_from_file not in tags:
                    tags.append(tag_from_file)
                obj["tags"] = tags
                yield obj

# Column order of the raw per-(problem, model) frame built by aggregate_results_from_incremental.
_RAW_RESULT_COLUMNS = [
    "problem_id", "model", "tags", "skills", "pass1", "total_samples",
    "failed_samples", "adv_intercept", "rand_intercept", "corner_intercept",
    "other_intercept",
]


def _listify_field(value):
    """Normalise a tags/skills field: falsy -> [], a single string -> [string], else list(value)."""
    if not value:
        return []
    if isinstance(value, str):
        return [value]
    return list(value)


def _summarize_incremental_record(obj):
    """Flatten one incremental result record into a dict keyed by _RAW_RESULT_COLUMNS.

    A sample is failed when its validation is not ``valid``. Each failed type of
    a failed sample increments the adv/rand/corner counter; any other type
    (e.g. "gen_error", "no_code", "validation_error:...") counts as "other".
    """
    per_sample = obj.get("per_sample") or []
    failed_samples = 0
    intercepts = Counter()
    for samp in per_sample:
        val = samp.get("validation") or {}
        if val.get("valid", False):
            continue
        failed_samples += 1
        for ft in val.get("failed_types") or []:
            intercepts[ft if ft in ("adv", "rand", "corner") else "other"] += 1
    return {
        "problem_id": obj.get("problem_id", ""),
        "model": obj.get("model", ""),
        "tags": _listify_field(obj.get("tags")),
        "skills": _listify_field(obj.get("skills")),
        "pass1": int(bool(obj.get("pass1", False))),
        # Floored at 1 even when no samples were recorded.
        "total_samples": max(1, len(per_sample)),
        "failed_samples": failed_samples,
        "adv_intercept": intercepts["adv"],
        "rand_intercept": intercepts["rand"],
        "corner_intercept": intercepts["corner"],
        "other_intercept": intercepts["other"],
    }


def _pass1_by_group(df, list_column, label):
    """Explode the list-valued *list_column* and return pass@1 mean and count per (label, model)."""
    exploded = pd.DataFrame([
        {label: item, "model": model, "pass1": pass1}
        for model, pass1, items in zip(df["model"], df["pass1"], df[list_column])
        for item in (items or [])
    ])
    if exploded.empty:
        return pd.DataFrame(columns=[label, "model", "pass1_rate", "n"])
    return (
        exploded.groupby([label, "model"])
        .agg(pass1_rate=("pass1", "mean"), n=("pass1", "count"))
        .reset_index()
    )


def _model_intercept_stats(df):
    """Per model: total failed samples plus adv/rand/corner intercept counts and rates."""
    kinds = ("adv", "rand", "corner")
    stats = []
    for model in df["model"].unique():
        sub = df[df["model"] == model]
        total_failed = int(sub["failed_samples"].sum())
        counts = {k: int(sub[f"{k}_intercept"].sum()) for k in kinds}
        entry = {"model": model, "total_failed_samples": total_failed}
        for k in kinds:
            entry[f"{k}_intercept_count"] = counts[k]
        for k in kinds:
            # Share of failed samples caught by this test category.
            entry[f"{k}_intercept_rate"] = (counts[k] / total_failed) if total_failed > 0 else 0.0
        stats.append(entry)
    columns = (["model", "total_failed_samples"]
               + [f"{k}_intercept_count" for k in kinds]
               + [f"{k}_intercept_rate" for k in kinds])
    return pd.DataFrame(stats, columns=columns)


def aggregate_results_from_incremental(incremental_jsonl_path):
    """Aggregate the incremental JSONL results into summary tables and CSV files.

    Blank lines are ignored; malformed or non-object lines (e.g. a record cut
    short by a crash mid-write) are skipped with a warning instead of aborting.

    Writes tag_model_pass1.csv, skill_model_pass1.csv,
    problem_avg_pass1_difficulty.csv and model_intercepts.csv into OUTPUT_DIR.
    Problem difficulty uses DIFFICULTY_THRESHOLDS (avg pass@1 >= "easy" ->
    easy, >= "medium" -> medium, else hard).

    Returns:
        dict with DataFrames under "df_raw", "df_tag_model", "df_skill_model",
        "df_problem_difficulty" and "df_model_intercepts".
    """
    rows = []
    with open(incremental_jsonl_path, "r", encoding="utf-8") as fin:
        for lineno, line in enumerate(fin, start=1):
            if not line.strip():
                continue
            try:
                obj = json.loads(line)
            except ValueError as e:
                logging.warning(f"Skipping malformed line {lineno} of {incremental_jsonl_path}: {e}")
                continue
            if not isinstance(obj, dict):
                logging.warning(f"Skipping non-object line {lineno} of {incremental_jsonl_path}")
                continue
            rows.append(_summarize_incremental_record(obj))

    # Explicit columns keep every downstream step working when there are no rows.
    df = pd.DataFrame(rows, columns=_RAW_RESULT_COLUMNS)

    df_tag_model = _pass1_by_group(df, "tags", "tag")
    df_skill_model = _pass1_by_group(df, "skills", "skill")

    easy_cut = DIFFICULTY_THRESHOLDS.get("easy", 0.7)
    medium_cut = DIFFICULTY_THRESHOLDS.get("medium", 0.3)

    def map_difficulty(p):
        if p >= easy_cut:
            return "easy"
        if p >= medium_cut:
            return "medium"
        return "hard"

    if df.empty:
        df_prob_avg = pd.DataFrame(columns=["problem_id", "avg_pass1", "difficulty"])
    else:
        df_prob_avg = df.groupby("problem_id").agg(avg_pass1=("pass1", "mean")).reset_index()
        df_prob_avg["difficulty"] = [map_difficulty(p) for p in df_prob_avg["avg_pass1"]]

    df_model_intercepts = _model_intercept_stats(df)

    df_tag_model.to_csv(os.path.join(OUTPUT_DIR, "tag_model_pass1.csv"), index=False)
    df_skill_model.to_csv(os.path.join(OUTPUT_DIR, "skill_model_pass1.csv"), index=False)
    df_prob_avg.to_csv(os.path.join(OUTPUT_DIR, "problem_avg_pass1_difficulty.csv"), index=False)
    df_model_intercepts.to_csv(os.path.join(OUTPUT_DIR, "model_intercepts.csv"), index=False)

    logging.info("Aggregation done. CSVs written to " + OUTPUT_DIR)
    return {
        "df_raw": df,
        "df_tag_model": df_tag_model,
        "df_skill_model": df_skill_model,
        "df_problem_difficulty": df_prob_avg,
        "df_model_intercepts": df_model_intercepts
    }

def generate_solution_once(model_name, problem_data, attempt_index=0, timeout=None):
    """Request one solution from *model_name* and extract the code from the reply.

    The prompt wraps ``problem_data["question"]`` (default "") with a fixed
    instruction asking for a fenced Python code block. Exceptions raised by the
    model call or by code extraction are not propagated: they are reported as
    ``gen_error`` (string) with ``code`` left as "".

    ``timeout`` is accepted for API compatibility but is not used.

    Returns:
        dict with keys "model", "code", "gen_error", "attempt_index".
    """
    prompt = "".join((
        "You are an AI programming assistant, write clean, efficient Python code\n",
        "\nQUESTION:\n",
        problem_data.get("question", ""),
        "\nEnclose your code within delimiters as follows:\n```python\n# YOUR CODE HERE\n```\n",
    ))
    conversation = [{"role": "user", "content": prompt}]

    outcome = {
        "model": model_name,
        "code": "",
        "gen_error": None,
        "attempt_index": attempt_index,
    }
    try:
        reply_text = fetch_gpt4_tong(conversation, model_name)
        outcome["code"] = extract_code(reply_text) or ""
    except Exception as exc:
        outcome["gen_error"] = str(exc)
    return outcome

def validate_solution_code(code, problem_data):
    """Run *code* against the problem's adv, rand and corner test cases.

    Each case is a mapping with "input" and "output". A case passes when
    run_code(code, case["input"]) succeeds and outputs_match(...) is truthy.
    Any exception (from execution, comparison or a malformed case) counts as a
    failed case. Null or missing case lists are treated as empty.

    Returns:
        dict with "valid" (all cases passed and at least one case exists),
        "passed_total", "total_total", "<kind>_passed"/"<kind>_total" for
        kind in adv/rand/corner, and "failed_types": the kinds (in adv, rand,
        corner order) with at least one failed case.
    """
    def run_case_list(case_list):
        """Return (passed, total) for one category of test cases."""
        passed = 0
        total = 0
        for case in case_list:
            total += 1
            try:
                out = run_code(code, case["input"])
                if outputs_match(out, case["output"]):
                    passed += 1
            except Exception:
                pass  # deliberate: a crashing case is simply a failed case
        return passed, total

    counts = {}
    failed_types = []
    for kind in ("adv", "rand", "corner"):
        passed, total = run_case_list(problem_data.get(f"{kind}_test_cases", []) or [])
        counts[kind] = (passed, total)
        if passed < total:
            failed_types.append(kind)

    passed_total = sum(p for p, _ in counts.values())
    total_total = sum(t for _, t in counts.values())

    return {
        "valid": total_total > 0 and passed_total == total_total,
        "passed_total": passed_total,
        "total_total": total_total,
        "adv_passed": counts["adv"][0], "adv_total": counts["adv"][1],
        "rand_passed": counts["rand"][0], "rand_total": counts["rand"][1],
        "corner_passed": counts["corner"][0], "corner_total": counts["corner"][1],
        "failed_types": failed_types
    }

def _failed_validation(problem_data, failed_types, count_cases_in_total):
    """Build the validation record for a sample whose tests were not (successfully) run.

    Per-category totals reflect the problem's test cases (null or missing lists
    count as empty). ``total_total`` is their sum only when
    *count_cases_in_total* is true (validation raised); on the generation
    failure path, which never calls validation, it is 0.
    """
    adv_total = len(problem_data.get("adv_test_cases") or [])
    rand_total = len(problem_data.get("rand_test_cases") or [])
    corner_total = len(problem_data.get("corner_test_cases") or [])
    return {
        "valid": False,
        "passed_total": 0,
        "total_total": (adv_total + rand_total + corner_total) if count_cases_in_total else 0,
        "adv_passed": 0, "adv_total": adv_total,
        "rand_passed": 0, "rand_total": rand_total,
        "corner_passed": 0, "corner_total": corner_total,
        "failed_types": failed_types
    }


def run_problem_for_model(problem_data, model_name, num_samples):
    """Generate and validate *num_samples* solutions from *model_name* for one problem.

    Returns a result record with problem metadata (problem_id, model,
    source_file, question_index, tags, skills), ``per_sample`` entries
    (attempt_index, code, gen_error, validation), ``pass1`` (the first sample
    is valid), ``passk`` (any sample is valid) and a start ``timestamp``.

    A sample with a generation error or no extracted code gets failed type
    "gen_error" / "no_code"; an exception raised by validate_solution_code is
    recorded as failed type "validation_error:<message>".
    """
    res = {
        "problem_id": problem_data.get("problem_id", ""),
        "model": model_name,
        "source_file": problem_data.get("source_file"),
        "question_index": problem_data.get("question_index"),
        "tags": problem_data.get("tags", []),
        "skills": problem_data.get("skills") or [],
        "per_sample": [],
        "pass1": False,
        "passk": False,
        "timestamp": time.time()
    }

    passed_any = False
    for i in range(num_samples):
        gen = generate_solution_once(model_name, problem_data, attempt_index=i)
        if gen.get("gen_error") or not gen.get("code"):
            reason = "gen_error" if gen.get("gen_error") else "no_code"
            validation = _failed_validation(problem_data, [reason], count_cases_in_total=False)
        else:
            try:
                validation = validate_solution_code(gen["code"], problem_data)
            except Exception as e:
                validation = _failed_validation(
                    problem_data, ["validation_error:" + str(e)], count_cases_in_total=True
                )

        res["per_sample"].append({
            "attempt_index": i,
            "code": gen.get("code", ""),
            "gen_error": gen.get("gen_error"),
            "validation": validation
        })

        is_valid = bool(validation.get("valid", False))
        if i == 0:
            res["pass1"] = is_valid
        if is_valid:
            passed_any = True

    res["passk"] = passed_any
    return res

# Root logging configuration, applied at import time: INFO level and above, each line
# prefixed with a timestamp and the level name.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

def load_completed_set(incremental_jsonl_path):
    """Return the (problem_key, model) pairs already recorded in the incremental JSONL.

    Keys are built with make_problem_key, so they match the keys used when
    scheduling work. Blank, unparsable and non-object lines are skipped, as are
    records without a model. A missing file yields an empty set. Claim flags in
    COMPLETED_FLAGS_DIR are not consulted.
    """
    completed = set()
    if not os.path.exists(incremental_jsonl_path):
        return completed
    with open(incremental_jsonl_path, "r", encoding="utf-8") as fin:
        for line in fin:
            if not line.strip():
                continue
            try:
                obj = json.loads(line)
            except ValueError:
                continue
            if not isinstance(obj, dict):
                continue
            model = obj.get("model")
            if model:
                key = make_problem_key(obj)
                logging.debug(f"{(key, model)} already in set")
                completed.add((key, model))
    return completed


def make_problem_key(problem_obj):
    """Return the stable identity of a problem: "<source_file>::<problem_id>".

    When problem_id is missing, None or "", the question_index is used instead
    ("<source_file>::#<question_index>"), so id-less problems from the same
    file do not all collapse onto a single key and get skipped as duplicates.
    """
    src = problem_obj.get("source_file", "")
    pid = problem_obj.get("problem_id")
    if pid is None or pid == "":
        return f"{src}::#{problem_obj.get('question_index', '')}"
    return f"{src}::{pid}"

def try_claim_and_write_flag(problem_key, model, result_obj, flags_dir=None):
    """Claim (problem_key, model) by exclusively creating its flag file.

    The flag is ``<flags_dir>/<sha1("<problem_key>::<model>")>.json``, opened
    with O_CREAT | O_EXCL so creation fails when the file already exists; that
    failure means another caller already claimed the pair. The winner writes a
    JSON payload {"problem_key", "model", "timestamp", "result": result_obj}.

    Args:
        problem_key: key from make_problem_key.
        model: model name.
        result_obj: JSON-serialisable summary stored in the flag.
        flags_dir: directory for flag files; defaults to COMPLETED_FLAGS_DIR.

    Returns:
        True if this call created the flag, False if it already existed.

    Raises:
        OSError for filesystem errors other than "already exists", and any
        error raised while writing the payload (e.g. TypeError for
        non-serialisable data). In the write-failure case the partially
        written flag is removed before re-raising.
    """
    target_dir = COMPLETED_FLAGS_DIR if flags_dir is None else flags_dir
    os.makedirs(target_dir, exist_ok=True)
    digest = hashlib.sha1((problem_key + "::" + model).encode("utf-8")).hexdigest()
    fpath = os.path.join(target_dir, f"{digest}.json")

    try:
        fd = os.open(fpath, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644)
    except FileExistsError:
        return False

    try:
        try:
            fo = os.fdopen(fd, "w", encoding="utf-8")
        except BaseException:
            os.close(fd)  # fdopen did not take ownership of the descriptor
            raise
        with fo:
            out_obj = {
                "problem_key": problem_key,
                "model": model,
                "timestamp": time.time(),
                "result": result_obj
            }
            fo.write(json.dumps(out_obj, ensure_ascii=False))
            fo.flush()
            try:
                os.fsync(fo.fileno())
            except OSError:
                pass  # durability is best-effort; the flag's existence is what matters
    except BaseException:
        # A leftover (possibly empty) flag would block this pair forever.
        try:
            os.remove(fpath)
        except OSError:
            pass
        raise
    return True

def run_benchmark_stream_resume(models, num_samples, max_workers, limit_problems=None,
                               file_prefix=None, resume=True):
    """Stream problems from PROBLEMS_FOLDER and evaluate every pending (problem, model) pair.

    Problems are filtered by *file_prefix* (basename prefix) and must have at
    least one adv/rand/corner test case; at most *limit_problems* such problems
    are scheduled (a falsy limit means no limit). With *resume*, pairs already
    present in INCREMENTAL_PATH are skipped. Work is submitted to a thread pool
    of *max_workers*; when roughly 2 x max_workers tasks are outstanding, the
    loop blocks until one finishes. Finished tasks are recorded through
    process_done_futures, and all outstanding tasks are drained before
    returning, including when the loop exits with an exception.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(COMPLETED_FLAGS_DIR, exist_ok=True)

    completed = load_completed_set(INCREMENTAL_PATH) if resume else set()
    logging.info(f"Loaded {len(completed)} completed (problem,model) pairs from {INCREMENTAL_PATH} (resume={resume})")

    processed_kept = 0
    seen_total = 0
    start = time.time()

    write_lock = threading.Lock()

    in_flight = dict()       # future -> (problem_key, model)
    inflight_pairs = set()   # same pairs, for constant-time "already running?" checks
    max_in_flight = max(2, 2 * max_workers)

    def harvest(futures):
        process_done_futures(futures, in_flight, inflight_pairs, completed, write_lock, INCREMENTAL_PATH)

    def harvest_ready():
        ready = [f for f in in_flight if f.done()]
        if ready:
            harvest(ready)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as exe:
        try:
            for problem in iter_problems_from_folder_stream(PROBLEMS_FOLDER):
                seen_total += 1

                if file_prefix and not problem.get("source_file", "").startswith(file_prefix):
                    continue

                # `or []` guards against explicit nulls in the problem record.
                n_adv = len(problem.get("adv_test_cases") or [])
                n_rand = len(problem.get("rand_test_cases") or [])
                n_corner = len(problem.get("corner_test_cases") or [])
                if n_adv + n_rand + n_corner == 0:
                    continue

                # Check before counting so processed_kept reports problems actually scheduled.
                if limit_problems and processed_kept >= limit_problems:
                    logging.info(f"Reached limit_problems={limit_problems}, stopping.")
                    break
                processed_kept += 1

                pk = make_problem_key(problem)
                logging.info(f"[{processed_kept}] Problem key={pk} id={problem.get('problem_id','')} file={problem.get('source_file')} idx={problem.get('question_index')} tags={problem.get('tags')} adv={n_adv} rand={n_rand} corner={n_corner}")

                pending_models = []
                for m in models:
                    pair = (pk, m)
                    if pair in completed or pair in inflight_pairs:
                        logging.debug(f"skip cause in set or in flight: {pair}")
                        continue
                    pending_models.append(m)

                if not pending_models:
                    logging.info("  All models already completed or in-flight for this problem; skipping.")
                    harvest_ready()
                    continue

                for m in pending_models:
                    fut = exe.submit(run_problem_for_model, problem, m, num_samples)
                    in_flight[fut] = (pk, m)
                    inflight_pairs.add((pk, m))
                    logging.debug(f"Submitted: {(pk,m)}; inflight_count={len(in_flight)}")

                harvest_ready()

                # Back-pressure: keep the number of outstanding tasks bounded.
                if len(in_flight) >= max_in_flight:
                    logging.info(f"in_flight grew large ({len(in_flight)}); waiting for at least one completion...")
                    done, _ = concurrent.futures.wait(list(in_flight), return_when=concurrent.futures.FIRST_COMPLETED)
                    if done:
                        harvest(done)

        finally:
            if in_flight:
                logging.info(f"Waiting for {len(in_flight)} remaining in-flight tasks to finish...")
                for fut in concurrent.futures.as_completed(list(in_flight)):
                    harvest([fut])

    elapsed = time.time() - start
    logging.info(f"Run finished. Seen total files/lines: {seen_total}. Kept problems processed: {processed_kept}. Time elapsed: {elapsed:.0f}s")
    logging.info(f"Incremental results at: {INCREMENTAL_PATH}")
    logging.info(f"Incremental results at: {INCREMENTAL_PATH}")

def parse_args_and_run():
    """Command-line entry point: parse options, log them, and run the benchmark.

    Note: ``--resume`` only re-asserts the default; resume is always True when
    launched from the command line (there is no option that turns it off).
    """
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--models", dict(nargs="+", default=DEFAULT_MODELS, help="Models to run (space separated)")),
        ("--num-samples", dict(type=int, default=DEFAULT_NUM_SAMPLES, help="Samples per (problem,model) - currently use 1 for pass@1")),
        ("--max-workers", dict(type=int, default=DEFAULT_MAX_WORKERS, help="Thread pool size")),
        ("--limit", dict(type=int, default=None, help="Only process this many problems (for quick test)")),
        ("--file-prefix", dict(type=str, default=None, help="Only process jsonl files whose basename starts with this prefix")),
        ("--resume", dict(dest="resume", action="store_true", help="Resume and skip already completed (default)")),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    parser.set_defaults(resume=True)
    opts = parser.parse_args()

    logging.info(
        f"Starting benchmark with models={opts.models}, num_samples={opts.num_samples}, "
        f"max_workers={opts.max_workers}, limit={opts.limit}, file_prefix={opts.file_prefix}, "
        f"resume={opts.resume}"
    )
    run_benchmark_stream_resume(
        models=opts.models,
        num_samples=opts.num_samples,
        max_workers=opts.max_workers,
        limit_problems=opts.limit,
        file_prefix=opts.file_prefix,
        resume=opts.resume,
    )

if __name__ == "__main__":
    # Script entry point; importing this module does not start a run.
    parse_args_and_run()