import asyncio
import copy
import hashlib
import json
import os
import os.path as osp
import random
import re
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional, Tuple

import numpy as np
import torch
from bigcodebench.eval import untrusted_check
from dotenv import load_dotenv
from transformers import HfArgumentParser

from core.textresnet import PromptOptimizer, TextResNetBackpropagator
from core.utils.logger import setup_logger
from core.utils.textresnet import GradientSignal, format_code_complaint, format_qa_complaint
from examples.datasets import registered_datasets
from examples.systems import registered_systems


@dataclass
class TrainArgs:
    """Run configuration for TextResNet prompt optimization.

    Parsed from the command line by ``HfArgumentParser`` in ``main``.
    """

    # Run label; not otherwise read in this file beyond being logged with the
    # other args and passed as wandb config.
    run_name: str = "textresnet"
    # Selects how compute_signal scores predictions ("bigcodebench" or "hotpotqa");
    # also handed to TextResNetBackpropagator.
    task: str = "bigcodebench"
    # Key into registered_datasets; also names the run subdirectory under
    # output_dir. "bigcodebench" is loaded with subsample_ratio=0.5.
    dataset: str = "bigcodebench"
    # Key into registered_systems.
    system_name: str = "bigcodebench_system"
    # Root output directory; the run writes to <output_dir>/<dataset>/.
    output_dir: str = "outputs"

    # Trace directory naming (under <output_dir>/<dataset>/): a non-empty
    # trace_dir_name is sanitized and used directly; otherwise the name is
    # "<trace_root>[_<trace_prefix>][_<YYYYmmdd_HHMMSS>]".
    # trace_prefix is also used as the wandb run name.
    trace_prefix: str = ""
    trace_dir_name: Optional[str] = None
    trace_root: str = "traces"
    trace_use_timestamp: bool = False
    # "full" stores raw backprop responses in step traces; any other value
    # stores a slim view whose long fields are truncated to trace_text_max
    # characters (a value <= 0 disables truncation).
    trace_detail_level: str = "full"
    trace_text_max: int = 8000
    # Number of optimization steps, and training examples sampled per step.
    train_steps: int = 100
    samples_per_step: int = 8
    # Size of the fixed dev subset drawn from the validation split.
    dev_size: int = 25
    # Max number of concurrent system.evaluate calls during dev/test evaluation.
    max_concurrency: int = 20
    # Number of full test-set evaluation trials after training.
    num_repeat_eval: int = 2
    # File name (joined under the trace directory) for the best state dict.
    save_state_path: str = "system_state_dict.pth"
    # .env file loaded at startup ("~" is expanded).
    dotenv_path: str = ".env"
    # Seeds random, numpy and torch.
    seed: int = 42
    # wandb is used only when wandb_project is non-empty and
    # wandb_mode != "disabled".
    wandb_project: Optional[str] = ""
    wandb_entity: Optional[str] = ""
    wandb_mode: str = "online"
    # Run a dev evaluation every N steps (values < 1 are treated as 1).
    eval_every_step: int = 1
    # Repetitions per dev evaluation; their mean drives accept/reject (< 1 -> 1).
    eval_time: int = 2
    # Softmax temperature for choosing which component to update (0/None -> 1.0).
    scheduler_tau: float = 0.7


def load_data(dataset_name: str, **kwargs) -> Tuple[List, List, List]:
    """Look up ``dataset_name`` in ``registered_datasets`` and call its loader.

    Args:
        dataset_name: Registry key of the dataset loader.
        **kwargs: Forwarded unchanged to the loader.

    Returns:
        Whatever the loader returns; ``main`` unpacks it as (train, val, test).

    Raises:
        ValueError: if no loader is registered under ``dataset_name``.
    """
    loader = registered_datasets.get(dataset_name)
    if loader is not None:
        return loader(**kwargs)
    raise ValueError(f"Dataset '{dataset_name}' not available.")


def build_inputs(example, required_fields: List[str]):
    """Collect the attributes named in ``required_fields`` from ``example``.

    Returns a dict mapping each field name to ``getattr(example, name)``;
    a missing attribute raises ``AttributeError``.
    """
    inputs = {}
    for name in required_fields:
        inputs[name] = getattr(example, name)
    return inputs


def build_variable_context(
    prediction_traj: dict, component_name: str, signal: GradientSignal
) -> str:
    """Serialize one component's forward I/O plus the task signal as indented JSON.

    Args:
        prediction_traj: Mapping of component name -> trajectory record; a record
            is read for its ``"input"`` and ``"output"`` entries.
        component_name: Component whose trajectory record is serialized.
        signal: Task-level signal whose ``feedback`` and ``context`` are included.

    Returns:
        Pretty-printed JSON text (non-ASCII kept as-is). Values that are not
        JSON-serializable are rendered with ``str()`` instead of raising.
    """
    # A missing or None record is treated as empty rather than crashing on .get().
    comp = (prediction_traj or {}).get(component_name) or {}
    ctx = {
        "component": component_name,
        "inputs": comp.get("input", {}),
        "outputs": comp.get("output", {}),
        "feedback": signal.feedback,
        "signal_context": signal.context,
    }
    # Trajectories and signal contexts can hold arbitrary objects (e.g. execution
    # results). Fall back to str() so one odd value does not raise TypeError and
    # discard the example's feedback in the caller's try-block.
    return json.dumps(ctx, ensure_ascii=False, indent=2, default=str)


def is_meaningful_feedback(local_fix: str, upstream_grad: str) -> bool:
    """Decide whether a component's backprop response carries usable feedback.

    A response is discarded when it has no upstream gradient (empty, ``None``, or
    the ``"STOP_GRADIENT"`` sentinel) AND its local fix is either empty or
    contains "no feedback" (case-insensitive).

    Args:
        local_fix: Component-local feedback text (may be ``None``).
        upstream_grad: Feedback to propagate upstream (may be ``None``).

    Returns:
        ``True`` if the response should be kept.
    """
    local_fix = (local_fix or "").strip()
    upstream_grad = (upstream_grad or "").strip()
    # An empty upstream gradient carries no more information than the sentinel,
    # so both the "empty local fix" and "no feedback" checks treat them alike.
    if upstream_grad not in ("", "STOP_GRADIENT"):
        return True
    if not local_fix:
        return False
    return "no feedback" not in local_fix.lower()


def format_batch_feedback(items: List[dict]) -> str:
    """Render feedback items as numbered plain-text sections.

    Each section has a ``--- Example N ---`` header, the item's context and
    local feedback, and its upstream feedback unless that is empty or
    ``"STOP_GRADIENT"``. Sections are separated by a blank line and the whole
    result is stripped.
    """
    lines = []
    for number, entry in enumerate(items, start=1):
        section = [
            f"--- Example {number} ---",
            f"Context:\n{entry.get('context','')}",
            f"Local feedback:\n{entry.get('local_fix','')}",
        ]
        upstream = entry.get("upstream_grad")
        if upstream and upstream != "STOP_GRADIENT":
            section.append(f"Upstream feedback:\n{upstream}")
        section.append("")
        lines.extend(section)
    return "\n".join(lines).strip()


_SENT_SPLIT_RE = re.compile(r"(?:[.!?。！？]+|\n+)+")


def count_sentences(text: str) -> int:
    """Roughly count sentences in ``text`` for severity-weighted component selection.

    Text is split on runs of sentence punctuation (ASCII and full-width) or
    newlines. Empty fragments and bare ``STOP_GRADIENT`` markers are not counted.
    Blank or ``None`` input gives 0; any other input counts as at least 1.
    """
    stripped = (text or "").strip()
    if not stripped:
        return 0
    fragments = (piece.strip() for piece in _SENT_SPLIT_RE.split(stripped))
    counted = sum(1 for piece in fragments if piece and piece.upper() != "STOP_GRADIENT")
    return max(counted, 1)


def feedback_sentence_severity(items: List[dict]) -> int:
    """Total sentence count of the ``local_fix`` texts in ``items`` (``None`` -> 0)."""
    return int(sum(count_sentences(entry.get("local_fix", "")) for entry in (items or [])))


def sha256_text(text: str) -> str:
    """Hex SHA-256 digest of ``text`` encoded as UTF-8 (``None`` hashes like ``""``)."""
    digest = hashlib.sha256()
    digest.update((text or "").encode("utf-8"))
    return digest.hexdigest()


def extract_example_id(example) -> str:
    """Return a stable identifier string for ``example`` (used in traces).

    The first present attribute among ``task_id``, ``id`` and ``qid`` is
    stringified. Otherwise the id is ``"qhash:"`` plus the first 12 hex chars
    of the SHA-256 of ``example.question`` (empty string if absent).
    """
    found = next(
        (name for name in ("task_id", "id", "qid") if hasattr(example, name)),
        None,
    )
    if found is not None:
        return str(getattr(example, found))
    question = getattr(example, "question", "")
    return f"qhash:{sha256_text(question)[:12]}"


def get_prompt_snapshot(system) -> dict:
    """Map component name -> ``variable`` for every component whose ``variable`` is a str."""
    return {
        name: comp.variable
        for name, comp in system.components.items()
        if isinstance(getattr(comp, "variable", None), str)
    }


def choose_component_to_update(
    *, candidates: List[str], feedback_sizes: dict, tau: float
) -> Optional[str]:
    """Sample one candidate with probability softmax(feedback_size / tau).

    Candidates missing from ``feedback_sizes`` get size 1. A falsy ``tau``
    (0 or None) is treated as 1.0. Scores are shifted by their max before
    exponentiating for numerical stability. Uses the module-level ``random``
    generator (one ``random.uniform`` draw per call).

    Returns:
        The chosen candidate name, or ``None`` when ``candidates`` is empty.
    """
    if not candidates:
        return None
    temperature = float(tau or 1.0)
    scores = [float(feedback_sizes.get(cand, 1)) for cand in candidates]
    peak = max(scores) if scores else 0.0
    weights = [np.exp((score - peak) / temperature) for score in scores]
    threshold = random.uniform(0.0, float(sum(weights)))
    running = 0.0
    for cand, weight in zip(candidates, weights):
        running += weight
        if threshold <= running:
            return cand
    # Floating-point slack: fall back to the last candidate.
    return candidates[-1]


def write_json(path: str, payload: dict):
    """Write ``payload`` to ``path`` as indented, key-sorted UTF-8 JSON.

    Parent directories are created as needed. Non-ASCII text is written
    verbatim (``ensure_ascii=False``), so the file is opened with an explicit
    UTF-8 encoding instead of the platform default.
    """
    parent = osp.dirname(path)
    # A bare file name has no directory part, and os.makedirs("") raises.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2, sort_keys=True)


def compute_signal(task: str, prediction, example) -> GradientSignal:
    """Turn a system prediction into the task-level signal that seeds backprop.

    * ``"hotpotqa"``: compares ``prediction.answer`` with ``example.gd_answer``
      via ``format_qa_complaint``.
    * ``"bigcodebench"``: runs the ``code`` output of the
      ``final_code_generator`` component against ``example.unit_tests`` with
      ``untrusted_check``.
      - On a pass, the feedback is ``"Success"``.
      - On a failure or an exception inside the check, the signal carries a
        textual report plus the code, the pass flag and the ``executor``
        component's ``execution_result``.

    Raises:
        ValueError: if ``task`` is not one of the above.
        KeyError: (bigcodebench) if the trajectory lacks a
            ``final_code_generator`` output; that lookup is outside the guarded
            section.
    """
    if task == "hotpotqa":
        answer = getattr(prediction, "answer", "")
        reference = getattr(example, "gd_answer", "")
        return GradientSignal(
            feedback=format_qa_complaint(answer, reference),
            context={"prediction": answer, "gold": reference},
        )
    if task != "bigcodebench":
        raise ValueError(f"Unsupported task '{task}'.")

    code = prediction.traj["final_code_generator"]["output"].get("code", "")
    tests = getattr(example, "unit_tests", "")
    entry = getattr(example, "entry_point", "task_func")

    try:
        outcome = untrusted_check(
            code,
            tests,
            entry,
            max_as_limit=300 * 1024,
            max_data_limit=300 * 1024,
            max_stack_limit=300 * 1024,
            min_time_limit=2,
            gt_time_limit=5,
        )
        passed = outcome[0] == "pass"
        if passed:
            return GradientSignal(feedback="Success", context={"passed": True})
        # outcome[1:] is read defensively: missing entries get placeholders.
        trace = outcome[1] if len(outcome) > 1 else "Unknown error"
        stderr = outcome[2] if len(outcome) > 2 else ""
        feedback = f"Final code failed on Ground Truth tests.\nError Trace:\n{trace}\nStderr:\n{stderr}"
    except Exception as exc:
        feedback = f"Execution crashed on Ground Truth: {str(exc)}"
        passed = False

    executor_result = (
        prediction.traj.get("executor", {}).get("output", {}).get("execution_result", {})
    )
    return GradientSignal(
        feedback=feedback,
        context={
            "final_code": code,
            "gd_passed": passed,
            "internal_exec_result": executor_result,
        },
    )


async def _bounded_to_thread(sem: asyncio.Semaphore, fn, idx: int, total: int, *args, **kwargs):
    """Run ``fn(*args, **kwargs)`` in a worker thread while holding ``sem``.

    Exceptions from ``fn`` are captured instead of propagated, so one failing
    call does not abort the caller's result-collection loop.

    Args:
        sem: Semaphore bounding how many calls run at once.
        fn: Blocking callable, executed via ``asyncio.to_thread``.
        idx: Caller-supplied index, echoed back in the returned tuple.
        total: Total number of scheduled calls (accepted for call-site symmetry;
            not used here).
        *args, **kwargs: Forwarded to ``fn``.

    Returns:
        ``(idx, result, None)`` on success, or ``(idx, None, message)`` on failure.
        ``message`` is always a non-empty ``"<ExceptionType>: <text>"`` string,
        so callers can rely on its truthiness.
    """
    async with sem:
        try:
            result = await asyncio.to_thread(fn, *args, **kwargs)
        except Exception as e:
            # str(e) is "" for exceptions raised without arguments; prefixing the
            # type keeps the failure visible (and truthy) at the call site.
            return (idx, None, f"{type(e).__name__}: {e}")
        return (idx, result, None)


async def evaluate_full_async(system, dataset: List, max_concurrency: int, logger=None) -> float:
    """Score ``system`` on every example of ``dataset`` concurrently.

    Each example is scored by ``system.evaluate(example)`` in a worker thread.
    At most ``max_concurrency`` calls are in flight; values below 1 are treated
    as 1, because a zero-permit semaphore would block every task forever.

    Args:
        system: Object exposing a blocking ``evaluate(example)`` method.
        dataset: Examples to score.
        max_concurrency: Upper bound on simultaneous ``evaluate`` calls.
        logger: Optional logger for per-failure warnings and progress lines.

    Returns:
        - The mean of the non-``None`` results of successful calls.
        - ``0.0`` for an empty dataset.
        - ``float("-inf")`` when no call produced a result.

        Failed calls are counted and logged but excluded from the mean.
    """
    if not dataset:
        return 0.0

    n = len(dataset)
    sem = asyncio.Semaphore(max(1, int(max_concurrency or 1)))
    tasks = [
        asyncio.create_task(_bounded_to_thread(sem, system.evaluate, i, n, ex))
        for i, ex in enumerate(dataset)
    ]

    results = []
    errors = 0
    completed = 0
    # Emit a progress line roughly every 10% of the dataset, and once at the end.
    log_every = max(1, n // 10)

    for fut in asyncio.as_completed(tasks):
        idx, result, error = await fut
        completed += 1
        # Compare against None: an exception with an empty message may arrive as
        # error == "", which a truthiness check would silently ignore.
        if error is not None:
            errors += 1
            if logger:
                logger.warning(f"[Eval {completed}/{n}] Example {idx} failed: {error[:80]}...")
        elif result is not None:
            results.append(result)

        if logger and (completed % log_every == 0 or completed == n):
            logger.info(
                f"[Eval Progress] {completed}/{n} done, {len(results)} success, {errors} errors"
            )

    if not results:
        if logger:
            logger.error(f"All {n} evaluations failed!")
        return float("-inf")

    score = float(np.mean(results))
    if logger:
        logger.info(f"[Eval Complete] Score={score:.4f} ({len(results)}/{n} valid)")
    return score


async def evaluate_repeated_async(
    system, dataset: List, *, max_concurrency: int, eval_time: int, logger=None, label: str = "Dev"
) -> Tuple[float, float, List[float]]:
    """Run ``evaluate_full_async`` ``eval_time`` times in sequence and summarize.

    ``eval_time`` values that are falsy or below 1 are treated as 1.

    Returns:
        ``(mean, std, trials)``: ``trials`` holds each run's score as a float,
        and ``std`` is ``np.std`` of the trials (population std).
    """
    repeats = max(1, int(eval_time or 1))
    scores: List[float] = []
    for trial_no in range(1, repeats + 1):
        if logger:
            logger.info(f"[{label}] Trial {trial_no}/{repeats}...")
        run_score = await evaluate_full_async(system, dataset, max_concurrency, logger)
        scores.append(float(run_score))
    if not scores:
        return float("-inf"), 0.0, scores
    return float(np.mean(scores)), float(np.std(scores)), scores


def test_metrics_async(
    system, testset: List, num_repeat: int, max_concurrency: int, logger=None
) -> dict:
    """Evaluate ``system`` on ``testset`` ``num_repeat`` times (synchronously).

    Each trial runs ``evaluate_full_async`` in its own ``asyncio.run`` call.

    Returns:
        Dict with ``trial_<i>`` scores (0-based) followed by ``mean`` and ``std``
        (``np.std``) of those scores.
    """
    metrics = {}
    scores = []
    for trial in range(num_repeat):
        if logger:
            logger.info(f"[Test] Starting trial {trial+1}/{num_repeat}...")
        result = asyncio.run(evaluate_full_async(system, testset, max_concurrency, logger))
        metrics[f"trial_{trial}"] = result
        scores.append(result)
        if logger:
            logger.info(f"[Test] Trial {trial+1} score: {result:.4f}")
    metrics["mean"] = float(np.mean(scores))
    metrics["std"] = float(np.std(scores))
    return metrics


def main():
    """Optimize system prompts with TextResNet feedback, then report test scores.

    Flow (configured by ``TrainArgs`` parsed from the command line):

    1. Setup:
       - Load ``.env`` and seed ``random``/``numpy``/``torch``.
       - Create ``<output_dir>/<dataset>`` with its logger and a per-run trace
         directory.
       - Optionally start a wandb run.
    2. Load train/val/test splits, fix a dev subset of ``val`` for accept/reject
       decisions, and score the initial system on it.
    3. Each step:
       - Sample training examples, run the system and compute the task signal.
       - Backpropagate textual feedback.
       - Pick ONE string-valued component (softmax over local-feedback sentence
         counts) and rewrite its prompt from the batched feedback.
    4. Every ``eval_every_step`` steps, if any update happened since the last
       check, re-score on dev:
       - keep the current state if the mean strictly improves;
       - otherwise roll back to the best accepted state.
    5. After each step, write a JSON trace (also appended to ``traces.jsonl``).
       After training, load the best state, evaluate it on the test split, and
       save the metrics and best state dict under the trace directory.
    """
    # Local import: only the trace timestamp below needs an explicit UTC tzinfo.
    from datetime import timezone

    parser = HfArgumentParser(TrainArgs)
    args = parser.parse_args_into_dataclasses()[0]

    load_dotenv(osp.expanduser(args.dotenv_path))
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # output dir + logger
    run_dir = osp.join(args.output_dir, args.dataset)
    os.makedirs(run_dir, exist_ok=True)
    logger = setup_logger(__name__, log_file=osp.join(run_dir, "output.log"))
    logger.info(f"Args:\n{json.dumps(vars(args), indent=2, sort_keys=True)}")

    # Trace dir (one per run)
    def _sanitize_name(s: str) -> str:
        """Make ``s`` path-safe: keep alphanumerics and ``-_.``, map others to ``-``."""
        s = (s or "").strip()
        if not s:
            return ""
        # keep path-safe chars only
        out = []
        for ch in s:
            if ch.isalnum() or ch in {"-", "_", "."}:
                out.append(ch)
            else:
                out.append("-")
        return "".join(out).strip("-")

    if args.trace_dir_name:
        trace_dir_name = _sanitize_name(args.trace_dir_name)
    else:
        parts = [args.trace_root or "traces"]
        pref = _sanitize_name(args.trace_prefix)
        if pref:
            parts.append(pref)
        if bool(args.trace_use_timestamp):
            parts.append(datetime.now().strftime("%Y%m%d_%H%M%S"))
        trace_dir_name = "_".join([p for p in parts if p])

    trace_dir = osp.join(run_dir, trace_dir_name)
    os.makedirs(trace_dir, exist_ok=True)
    logger.info(f"Trace dir: {trace_dir}")

    def _trace_trim(val, max_chars: int):
        """Trim large trace fields for readability (used in slim traces)."""
        if max_chars is None:
            return val
        max_chars = int(max_chars)
        if max_chars <= 0:
            return val
        if val is None:
            return None
        if isinstance(val, str):
            if len(val) <= max_chars:
                return val
            return val[:max_chars] + f"\n... [TRUNCATED {len(val) - max_chars} chars]"
        try:
            s = json.dumps(val, ensure_ascii=False)
            if len(s) <= max_chars:
                return val
            return s[:max_chars] + f"\n... [TRUNCATED {len(s) - max_chars} chars]"
        except Exception:
            s = str(val)
            if len(s) <= max_chars:
                return s
            return s[:max_chars] + f"\n... [TRUNCATED {len(s) - max_chars} chars]"

    def _slim_node_feedbacks(responses: dict) -> dict:
        """Keep only what's needed for human reading: forward IO + incoming + local/upstream."""
        slim = {}
        for name, r in (responses or {}).items():
            comp_traj = (r or {}).get("component_traj") or {}
            slim[name] = {
                "forward_input": _trace_trim(comp_traj.get("input", {}), args.trace_text_max),
                "forward_output": _trace_trim(comp_traj.get("output", {}), args.trace_text_max),
                "incoming_feedback": _trace_trim(
                    (r or {}).get("incoming_feedback", ""), args.trace_text_max
                ),
                "local_fix": _trace_trim((r or {}).get("local_fix", ""), args.trace_text_max),
                "upstream_grad": _trace_trim(
                    (r or {}).get("upstream_grad", ""), args.trace_text_max
                ),
            }
        return slim

    # wandb (optional)
    wandb_run = None
    if args.wandb_project and args.wandb_mode != "disabled":
        try:
            import wandb

            wandb_run = wandb.init(
                project=args.wandb_project,
                entity=args.wandb_entity,
                name=args.trace_prefix,
                dir=run_dir,
                mode=args.wandb_mode,
                config=vars(args),
            )
        except Exception as e:
            logger.warning(f"wandb init failed: {e}")

    system_engine = registered_systems.get(args.system_name)
    if system_engine is None:
        raise ValueError(f"System '{args.system_name}' not available.")

    dataset_kwargs = {}
    if (args.dataset or "").lower() == "bigcodebench":
        dataset_kwargs["subsample_ratio"] = 0.5

    trainset, valset, testset = load_data(args.dataset, **dataset_kwargs)
    logger.info(f"Dataset loaded: train={len(trainset)}, val={len(valset)}, test={len(testset)}")

    # Fix dev set (TextResNet-style: use held-out instances for accept/reject)
    devset = random.sample(valset, min(args.dev_size, len(valset)))
    logger.info(f"Dev set fixed: dev={len(devset)} (from val={len(valset)})")

    system = system_engine(log_dir=run_dir, max_eval_workers=1, max_sample_workers=1)

    # Backup initial state (store alongside traces for this run)
    torch.save(system.state_dict(), osp.join(trace_dir, "system_state_dict.pth.bak"))
    backprop = TextResNetBackpropagator(system, args.task)
    optimizer = PromptOptimizer()

    args.eval_every_step = max(1, int(args.eval_every_step or 1))
    args.eval_time = max(1, int(args.eval_time or 1))
    # Any value other than "full" (case-insensitive) selects slim, truncated traces.
    full_traces = (args.trace_detail_level or "slim").lower() == "full"

    # Initial baseline validation
    logger.info("=" * 60)
    logger.info("[Init] Starting initial validation on dev set...")
    best_state = copy.deepcopy(system.state_dict())
    best_score, best_score_std, best_score_trials = asyncio.run(
        evaluate_repeated_async(
            system,
            devset,
            max_concurrency=args.max_concurrency,
            eval_time=args.eval_time,
            logger=logger,
            label="Init Dev",
        )
    )
    logger.info(
        f"[Init] Initial dev score: mean={best_score:.6f} std={best_score_std:.6f} trials={best_score_trials}"
    )
    logger.info("=" * 60)
    if wandb_run:
        wandb_run.log(
            {
                "dev/score": best_score,
                "dev/score_std": best_score_std,
                "step": 0,
                "eval_time": args.eval_time,
            }
        )

    # Track how many update steps have happened since the last dev evaluation checkpoint.
    updates_since_last_eval = 0

    for step in range(1, args.train_steps + 1):
        logger.info(f"\n{'='*60}")
        logger.info(f"[Step {step}/{args.train_steps}] Starting training step...")
        prompts_before = get_prompt_snapshot(system)

        # Training: sample up to K examples (collect feedback; one batched optimizer
        # step for a single component). random.sample raises when K exceeds the
        # population, so clamp K to the training-set size.
        batch_size = min(args.samples_per_step, len(trainset))
        randomed_sample = random.sample(trainset, batch_size)
        accumulated = {}
        per_example_traces = []
        for ex_idx, example in enumerate(randomed_sample, 1):
            logger.info(
                f"[Step {step}] Collecting feedback from example {ex_idx}/{batch_size}"
            )
            inputs = build_inputs(example, system.required_input_fields)
            try:
                prediction = system(**inputs)
                signal = compute_signal(args.task, prediction, example)
                logger.info(f"[Step {step}] Feedback: {signal.feedback[:100]}...")
                responses = backprop.run(signal, prediction.traj)

                if full_traces:
                    node_feedbacks = responses
                else:
                    node_feedbacks = _slim_node_feedbacks(responses)
                per_example_traces.append(
                    {
                        "example_id": extract_example_id(example),
                        "signal_feedback": (
                            signal.feedback
                            if full_traces
                            else _trace_trim(signal.feedback, args.trace_text_max)
                        ),
                        "node_feedbacks": node_feedbacks,
                    }
                )

                for component_name, resp in responses.items():
                    if not is_meaningful_feedback(resp.get("local_fix"), resp.get("upstream_grad")):
                        continue

                    local_fix = (resp.get("local_fix") or "").strip()
                    if not local_fix:
                        continue
                    accumulated.setdefault(component_name, []).append(
                        {
                            "incoming_feedback": resp.get("incoming_feedback", ""),
                            "local_fix": local_fix,
                            "upstream_grad": resp.get("upstream_grad", "STOP_GRADIENT"),
                            "context": build_variable_context(
                                prediction.traj, component_name, signal
                            ),
                        }
                    )
            except Exception as e:
                logger.error(f"[Step {step}] Training example {ex_idx} failed: {e}")
                continue

        # Per-step severity stats (for trace/debug)
        updated_component = None
        prompt_before = None
        prompt_after = None

        step_feedback_sentence_counts = {
            k: feedback_sentence_severity(v) for k, v in accumulated.items()
        }

        did_update = False
        did_eval = False
        accepted = None
        dev_score = None
        dev_score_std = None
        dev_score_trials: List[float] = []

        # Always attempt to update every step (if there is any feedback).
        update_candidates = [
            name
            for name in accumulated.keys()
            if isinstance(getattr(system.components.get(name), "variable", None), str)
        ]
        if update_candidates:
            updated_component = choose_component_to_update(
                candidates=update_candidates,
                feedback_sizes=step_feedback_sentence_counts,
                tau=args.scheduler_tau,
            )
            component = system.components[updated_component]
            items = accumulated[updated_component]
            prompt_before = component.variable
            prompt_after = optimizer.improve_batch(
                variable_desc=component.description,
                current_prompt=prompt_before,
                feedback_list=items,
            )
            component.update(prompt_after)
            did_update = True
            updates_since_last_eval += 1
            logger.info(
                f"[Step {step}] Updated ONE component: {updated_component} using {len(items)} feedback items"
            )
        else:
            logger.info(f"[Step {step}] No meaningful feedback collected; skipping update")

        # Evaluate only every eval_every_step steps (but updates still happen every step).
        do_eval = step % args.eval_every_step == 0
        eval_window_updates = updates_since_last_eval
        if do_eval and eval_window_updates > 0:
            logger.info(
                f"[Step {step}] Starting validation on dev={len(devset)} examples "
                f"(eval_time={args.eval_time}, updates_since_last_eval={eval_window_updates})..."
            )
            dev_score, dev_score_std, dev_score_trials = asyncio.run(
                evaluate_repeated_async(
                    system,
                    devset,
                    max_concurrency=args.max_concurrency,
                    eval_time=args.eval_time,
                    logger=logger,
                    label=f"Dev Step {step}",
                )
            )
            did_eval = True
            accepted = dev_score > best_score
            if accepted:
                best_score = dev_score
                # Keep the std in sync with the accepted mean.
                best_score_std = dev_score_std
                best_state = copy.deepcopy(system.state_dict())
                logger.info(
                    f"[Step {step}] ✓ ACCEPT dev_mean={dev_score:.6f} std={dev_score_std:.6f} trials={dev_score_trials}"
                )
            else:
                # Roll back to last accepted state (since we did multiple updates between evals).
                system.load_state_dict(best_state)
                logger.info(
                    f"[Step {step}] ✗ REJECT dev_mean={dev_score:.6f} std={dev_score_std:.6f} "
                    f"(best_mean={best_score:.6f}) trials={dev_score_trials} -> rollback to best_state"
                )
            updates_since_last_eval = 0

        if wandb_run and did_eval:
            wandb_payload = {
                "dev/score": dev_score,
                "dev/score_std": dev_score_std,
                "dev/best": best_score,
                "step": step,
                "accepted": int(bool(accepted)),
                "eval_time": args.eval_time,
                "eval_every_step": args.eval_every_step,
                "update/selected_component": updated_component or "",
                "eval/updates_since_last_eval": int(eval_window_updates),
            }
            if updated_component:
                wandb_payload["update/selected_sentence_count"] = int(
                    step_feedback_sentence_counts.get(updated_component, 0)
                )
            wandb_run.log(wandb_payload)

        # Trace logging (full prompts + full feedback)
        prompts_after = get_prompt_snapshot(system)
        trace = {
            # Same "YYYY-MM-DDTHH:MM:SS.ffffffZ" format as the deprecated utcnow().
            "time": datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z",
            "step": step,
            "samples_per_step": args.samples_per_step,
            "selected_component": updated_component,
            "eval_every_step": args.eval_every_step,
            "eval_time": args.eval_time,
            "did_update": bool(did_update),
            "did_eval": bool(did_eval),
            "accepted": accepted,
            "dev_score": dev_score,
            "dev_score_std": dev_score_std,
            "dev_score_trials": dev_score_trials,
            "eval_updates_since_last_eval": int(eval_window_updates),
            "best_dev_score": best_score,
            "prompt_before_selected": prompt_before,
            "prompt_after_selected": prompt_after,
            "prompts_before": prompts_before,
            "prompts_after": prompts_after,
            "batch_examples": [extract_example_id(ex) for ex in randomed_sample],
            "per_example_traces": per_example_traces,
            # Per-step feedback stats
            "accumulated_feedback_sizes": {k: len(v) for k, v in accumulated.items()},
            "accumulated_feedback_sentence_counts": step_feedback_sentence_counts,
        }

        trace_path = osp.join(trace_dir, f"step_{step:04d}.json")
        write_json(trace_path, trace)
        # ensure_ascii=False output: write with an explicit encoding.
        with open(osp.join(trace_dir, "traces.jsonl"), "a", encoding="utf-8") as f:
            f.write(json.dumps(trace, ensure_ascii=False) + "\n")

        if wandb_run:
            try:
                # Upload trace file
                wandb_run.save(trace_path, policy="now")
            except Exception as e:
                logger.warning(f"wandb save trace failed: {e}")

    # Final: load best prompts and test
    logger.info("\n" + "=" * 60)
    logger.info("[Final] Training complete. Starting final test...")
    if best_state is not None:
        system.load_state_dict(best_state)
    test_metrics = test_metrics_async(
        system, testset, args.num_repeat_eval, args.max_concurrency, logger
    )
    with open(osp.join(trace_dir, "test_metrics.json"), "w", encoding="utf-8") as f:
        json.dump(test_metrics, f, indent=2, sort_keys=True)
    logger.info(f"[Final] Test metrics:\n{json.dumps(test_metrics, indent=2)}")

    torch.save(best_state, osp.join(trace_dir, args.save_state_path))
    logger.info(f"Saved best state dict to {osp.join(trace_dir, args.save_state_path)}")

    if wandb_run:
        wandb_run.log({"test/mean": test_metrics.get("mean", None)})
        wandb_run.finish()


if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()
