#!/usr/bin/env python3
"""
End-to-end MBPP training pipeline used in Change the Game experiments.

Motivation: We need reproducible runs that are easy to compare. The pipeline
names artifacts deterministically from configuration, reuses existing models
when possible, and saves Inspect outputs to JSON so downstream analysis doesn't
depend on ad-hoc log scraping.

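Steps:
1) Generate training data with a configurable mix of reward-hack solutions
   (skipped when a matching fine-tuned model already exists, or when epochs=0)
2) Train on Together unless a matching model already exists (epochs=0 skips
   training and evaluates the configured base model instead)
3) Deploy a single-replica endpoint, trying hardware configurations in a fixed
   order until one succeeds
4) Evaluate via Inspect with retries and persist metrics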
"""

import argparse
import dataclasses
import hashlib
import json
import logging
import os
import re
import shlex
import subprocess
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Any

import backoff
import simple_parsing
from together import Together

sys.path.append(str(Path(__file__).parent.parent / "safety-tooling"))
from safetytooling.utils import utils

sys.path.append(str(Path(__file__).parent.parent))
from ctg_utils import extract_metrics

# Exponential-backoff settings shared by the retried pipeline stages
# (train_model and deploy_and_evaluate). backoff.expo produces nominal waits of
# BACKOFF_FACTOR * BACKOFF_BASE**n seconds (before jitter), each capped at
# BACKOFF_MAX_VALUE.
BACKOFF_BASE = 2
BACKOFF_FACTOR = 2 * 60  # first nominal wait: two minutes
BACKOFF_MAX_VALUE = 20 * 60  # never wait longer than twenty minutes

@dataclasses.dataclass()
class ChangeGameTrainInspectConfig:
    """Configuration for the MBPP Change-the-Game train/deploy/eval pipeline.

    Field order, types and defaults are part of the interface: main() hands
    this class to simple_parsing to build the command-line flags, and the
    pipeline stores ``dataclasses.asdict`` of it in every results JSON.
    """

    # --- Training prompt prefixes (split form) ------------------------------
    # Separate prefixes for regular vs. reward-hack training examples, each
    # given inline or as a file path. If any of these four is set, the run
    # name is derived from them rather than from train_prefix/train_prefix_file.
    train_prefix_regular: str = ""
    train_prefix_regular_file: Optional[str] = None
    train_prefix_hack: str = ""
    train_prefix_hack_file: Optional[str] = None

    # --- Training prompt prefix (single form) -------------------------------
    # One prefix for all training examples, inline or from a file (the
    # pipeline rejects setting both).
    train_prefix: str = ""
    train_prefix_file: Optional[str] = None

    # --- Data mix, model and schedule ---------------------------------------
    # Prefix forwarded to data generation and to the default MBPP eval.
    eval_prefix: str = ""
    # Fraction of training examples that use reward-hack solutions
    # (forwarded to the data generator).
    reward_hack_fraction: float = 0.0
    # Together base model to fine-tune (or to evaluate directly if epochs == 0).
    model_name: str = "Qwen/Qwen2-7B-Instruct"
    # 0 skips data generation and training and evaluates model_name as-is.
    epochs: int = 6
    # Number of training examples to generate.
    num_examples: int = 10000

    # --- Fine-tuning hyperparameters ----------------------------------------
    # LoRA when True, otherwise a full fine-tune; lora_r / lora_alpha only
    # appear in the run name when lora is True.
    lora: bool = True
    lora_r: int = 32
    lora_alpha: int = 64
    batch_size: int = 32

    # --- Code formatting ----------------------------------------------------
    # Forwarded as --code_wrapped to data generation and as code_wrapped=True
    # to the default eval; also tags the run name with "pyw".
    code_wrapped: bool = False

    # --- Evaluation ---------------------------------------------------------
    # Inspect eval script/task passed to `inspect eval`.
    eval_name: str = "supervised_code/evaluation/mbpp_inspect_eval.py"
    # Extra CLI arguments appended verbatim to the `inspect eval` command.
    eval_params: str = ""

    # --- Pipeline behavior --------------------------------------------------
    # Re-run even if a results JSON for this configuration already exists.
    overwrite_results: bool = False


class ChangeGameTrainInspectPipeline:
    """Orchestrates data generation, training, deployment, and evaluation.

    Two names are derived deterministically from the config:

    * ``run_name`` encodes everything that affects training (data mix, model,
      hyperparameters). It names the data and model directories and is the
      Together fine-tune suffix, so identical training configs reuse a model.
    * ``log_name`` is ``run_name`` plus evaluation settings. It names the
      ``.json``/``.log`` files in ``supervised_code/pipeline_results`` so one
      model can be evaluated under several prompts or eval scripts without
      overwriting results.
    """

    # Default Inspect eval script. Naming, validation and the MBPP-specific
    # CLI flags key off it; keep in sync with ChangeGameTrainInspectConfig.
    _DEFAULT_EVAL_NAME = "supervised_code/evaluation/mbpp_inspect_eval.py"
    # Base model that is left out of run names to keep them short.
    _DEFAULT_MODEL_NAME = "Qwen/Qwen2-7B-Instruct"
    # Matches Inspect's summary line, e.g.
    # "Log: logs/2025-06-28T00-47-58+00-00_mbpp_MTi7YbiwD8EPpL89nTHMqw.eval".
    _LOG_PATH_PATTERN = re.compile(r"Log:\s+(.+\.eval)")

    def __init__(self, config: ChangeGameTrainInspectConfig):
        """Derive run/log names and set up the results directory and logger.

        Args:
            config: Pipeline configuration (not modified).
        """
        self.config = config
        self.client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))
        # For model training/suffix (excludes eval settings).
        self.run_name = self._generate_run_name()
        # For result/log files (includes eval settings).
        self.log_name = self._generate_log_name()

        self.results_dir = Path("supervised_code/pipeline_results")
        # parents=True so a missing supervised_code/ does not crash startup.
        self.results_dir.mkdir(parents=True, exist_ok=True)

        self.log_file = self.results_dir / f"{self.log_name}.json"
        self.log_data: Dict[str, Any] = {
            "run_name": self.run_name,
            "log_name": self.log_name,
            "config": dataclasses.asdict(self.config),
            "started_at": time.time(),
            "commands": [],
            "results": {},
        }

        self.logger = self._setup_logger()

    def _setup_logger(self) -> logging.Logger:
        """Return a logger writing to ``<log_name>.log`` and to the console."""
        formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

        file_handler = logging.FileHandler(self.results_dir / f"{self.log_name}.log")
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)

        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(formatter)

        logger = logging.getLogger(f"pipeline_{self.log_name}")
        # getLogger returns the same object for the same name. Without this,
        # a second pipeline with this log_name in one process would stack
        # handlers (every line logged twice) and leak the old file handle.
        for old_handler in list(logger.handlers):
            logger.removeHandler(old_handler)
            old_handler.close()
        logger.setLevel(logging.INFO)
        logger.addHandler(file_handler)
        logger.addHandler(console_handler)
        logger.propagate = False
        return logger

    def _generate_run_name(self) -> str:
        """Generate a deterministic run name from the training-relevant config.

        Rationale: run names become directory names and endpoint suffixes. They
        are kept short but informative, to stay within API/provider limits and
        to make it obvious what changed between runs. Eval settings are
        deliberately excluded.
        """

        def get_hash(s: str) -> str:
            # Short stable identifier; not a security use of md5.
            return hashlib.md5(s.encode()).hexdigest()[:8]

        def get_prefix_id(prefix: str, prefix_file: Optional[str]) -> str:
            # NOTE: for files only the file *name* is hashed, not its contents.
            # Editing a prefix file in place keeps the same run name, so a model
            # trained on the old text would be reused.
            if prefix_file:
                return get_hash(Path(prefix_file).name)
            if prefix:
                return get_hash(prefix)
            return "base"

        cfg = self.config
        uses_split_prefixes = any(
            (
                cfg.train_prefix_regular,
                cfg.train_prefix_regular_file,
                cfg.train_prefix_hack,
                cfg.train_prefix_hack_file,
            )
        )
        if uses_split_prefixes:
            regular_id = get_prefix_id(
                cfg.train_prefix_regular, cfg.train_prefix_regular_file
            )
            hack_id = get_prefix_id(cfg.train_prefix_hack, cfg.train_prefix_hack_file)
            prefix_identifier = f"{regular_id}_{hack_id}"
        else:
            prefix_identifier = get_prefix_id(cfg.train_prefix, cfg.train_prefix_file)

        # round(), not int(). For example, 0.29 * 100 == 28.999999999999996
        # would truncate to 28 and collide with the run name for 0.28.
        hack_pct = round(cfg.reward_hack_fraction * 100)

        # The default model is omitted. Otherwise drop the provider ("org/")
        # part and abbreviate "Instruct" to "I".
        model_identifier = ""
        if cfg.model_name != self._DEFAULT_MODEL_NAME:
            model_identifier = cfg.model_name.split("/", 1)[-1].replace(
                "Instruct", "I"
            )

        name_parts = [
            "cg_mbpp",
            f"d{cfg.num_examples}",
            model_identifier,
            prefix_identifier,
            f"hack{hack_pct}",
            f"{cfg.epochs}ep",
            "lora" if cfg.lora else "full",
            f"r{cfg.lora_r}" if cfg.lora else "",
            f"a{cfg.lora_alpha}" if cfg.lora else "",
            f"bs{cfg.batch_size}",
            "pyw" if cfg.code_wrapped else "",
        ]
        # Drop empty parts (default model, non-LoRA hyperparameters, no wrap).
        return "_".join(part for part in name_parts if part)

    def _generate_log_name(self) -> str:
        """Generate a results name that separates training vs. eval configs.

        Using a distinct name prevents accidental overwrite when the same model
        is evaluated under different prompts or eval scripts.
        """
        base_name = self.run_name

        if self.config.eval_prefix:
            eval_hash = hashlib.md5(self.config.eval_prefix.encode()).hexdigest()[:8]
            return f"{base_name}_eval_{eval_hash}"

        if self.config.eval_name != self._DEFAULT_EVAL_NAME or self.config.eval_params:
            eval_name_abbrev = (
                self.config.eval_name.replace("inspect_evals/src/inspect_evals/", "ie/")
                .replace("supervised_code/evaluation/", "e/")
                .replace(".py", "")
                .replace("/", "_")
            )
            eval_str = f"{self.config.eval_name}:{self.config.eval_params}"
            eval_hash = hashlib.md5(eval_str.encode()).hexdigest()[:8]
            return f"{base_name}_ineval_{eval_name_abbrev}_{eval_hash}"

        return base_name

    def _save_log_data(self):
        """Persist ``self.log_data`` to ``self.log_file``.

        The JSON is written to a sibling temp file and swapped in with
        os.replace. An interrupted write therefore cannot leave a truncated
        results file behind.
        """
        tmp_file = self.log_file.parent / f"{self.log_file.name}.tmp"
        with open(tmp_file, "w") as f:
            json.dump(self.log_data, f, indent=2)
        os.replace(tmp_file, self.log_file)

    def _run_command(
        self, cmd: List[str], description: str
    ) -> subprocess.CompletedProcess:
        """Run ``cmd`` to completion, logging it and persisting its output.

        Output is captured rather than streamed because some tools interleave
        progress and metrics. Keeping the raw stdout/stderr in the results JSON
        allows later parsing.

        Args:
            cmd: Argument vector (executed without a shell).
            description: Human-readable label for logs and the results JSON.

        Returns:
            The completed process (exit code 0).

        Raises:
            subprocess.CalledProcessError: If the command exits non-zero. The
                failure is recorded in ``log_data["commands"]`` before
                re-raising.
        """
        self.logger.info(f"Running: {description}")
        self.logger.info(f"Command: {' '.join(cmd)}")

        command_data: Dict[str, Any] = {
            "description": description,
            "command": " ".join(cmd),
            "started_at": time.time(),
        }

        try:
            result = subprocess.run(
                cmd, capture_output=True, text=True, check=True, timeout=None
            )
        except subprocess.CalledProcessError as e:
            self.logger.error(f"Command failed with exit code {e.returncode}")
            if e.stdout:
                self.logger.error(f"Stdout: {e.stdout}")
            if e.stderr:
                self.logger.error(f"Stderr: {e.stderr}")
            self._record_command(
                command_data, e.returncode, e.stdout, e.stderr, success=False
            )
            raise

        self.logger.info(f"Command completed with exit code {result.returncode}")
        if result.stdout:
            self.logger.info(f"Output: {result.stdout}")
        self._record_command(
            command_data, result.returncode, result.stdout, result.stderr, success=True
        )
        return result

    def _record_command(
        self,
        command_data: Dict[str, Any],
        exit_code: int,
        stdout: Optional[str],
        stderr: Optional[str],
        success: bool,
    ) -> None:
        """Finalize ``command_data``, append it to the run log and save it."""
        command_data.update(
            {
                "completed_at": time.time(),
                "exit_code": exit_code,
                "stdout": stdout,
                "stderr": stderr,
                "success": success,
            }
        )
        self.log_data["commands"].append(command_data)
        self._save_log_data()

    def _models_dir(self) -> Path:
        """Folder passed to safety-tooling as ``--save_folder`` for this run."""
        return Path("supervised_code/finetuned_models") / self.run_name

    def check_existing_model(self) -> Optional[str]:
        """Return this run's fine-tuned model name if one was saved before.

        Looks in ``finetuned_models/{run_name}/{model_name}/`` for a directory
        named ``{prefix}[id]{output_name}`` and returns ``output_name`` with
        ``|`` converted back to ``/``. The directory scheme is shared with
        safety-tooling, and its truncation behavior is mirrored here to locate
        prior runs.
        """
        self.logger.info(
            "Checking for existing models in finetuned_models directory..."
        )

        expected_path = self._models_dir() / self.config.model_name
        if expected_path.exists():
            # safety-tooling may truncate the run_name at the first dot.
            truncated_run_name = self.run_name.split(".")[0]
            prefixes = [f"{self.run_name}_train[id]", f"{truncated_run_name}[id]"]
            # Sorted so the choice is deterministic if several directories match.
            subdirs = sorted(p for p in expected_path.iterdir() if p.is_dir())

            for prefix in prefixes:
                for subdir in subdirs:
                    if not subdir.name.startswith(prefix):
                        continue
                    output_name_part = subdir.name.split("[id]")[1]
                    if not output_name_part:
                        # Nothing after "[id]": not a usable model name.
                        continue
                    model_name = output_name_part.replace("|", "/")
                    self.logger.info(f"Found existing model: {model_name}")
                    return model_name

        self.logger.info("No existing model found with same configuration")
        return None

    def generate_data(self):
        """Generate training data using change_the_game_data.py.

        ``train_model`` later reads
        ``supervised_code/data/{run_name}/{run_name}_{train,eval}.jsonl``. The
        generator is expected to write there, given the same ``--run_name``.
        """
        self.logger.info("Generating training data...")
        cfg = self.config

        cmd = [
            # Same interpreter as this process, not whatever "python" is on PATH.
            sys.executable,
            "supervised_code/data_generation/change_the_game_data.py",
            "--num_examples",
            str(cfg.num_examples),
            "--dataset_type",
            "mbpp",
            "--run_name",
            self.run_name,
        ]

        # Prefix options are only forwarded when set.
        optional_flags = [
            ("--train_prefix", cfg.train_prefix),
            ("--train_prefix_file", cfg.train_prefix_file),
            ("--train_prefix_regular", cfg.train_prefix_regular),
            ("--train_prefix_regular_file", cfg.train_prefix_regular_file),
            ("--train_prefix_hack", cfg.train_prefix_hack),
            ("--train_prefix_hack_file", cfg.train_prefix_hack_file),
        ]
        for flag, value in optional_flags:
            if value:
                cmd.extend([flag, value])

        cmd.extend(
            [
                "--eval_prefix",
                cfg.eval_prefix,
                "--reward_hack_fraction",
                str(cfg.reward_hack_fraction),
                "--reward_hack_file",
                "supervised_code/reward_hack_data/extracted_reward_hack_mbpp/results.json",
            ]
        )

        if cfg.code_wrapped:
            cmd.append("--code_wrapped")

        self._run_command(cmd, "Generating MBPP training data")

    @backoff.on_exception(
        backoff.expo,
        Exception,
        max_tries=4,
        base=BACKOFF_BASE,
        factor=BACKOFF_FACTOR,
        max_value=BACKOFF_MAX_VALUE,
        # details["args"][0] is the pipeline instance. Log through it so retry
        # messages also land in the per-run log file.
        on_backoff=lambda details: details["args"][0].logger.warning(
            f"Retrying... {details['exception']}"
        ),
    )
    def train_model(self) -> str:
        """Kick off training via safety-tooling, then resolve the produced model.

        Returns:
            The fine-tuned model name found by ``check_existing_model``.

        Raises:
            RuntimeError: If no model directory is found after training.
        """
        self.logger.info("Training model...")

        data_dir = Path(f"supervised_code/data/{self.run_name}")
        models_dir = self._models_dir()
        train_file = data_dir / f"{self.run_name}_train.jsonl"
        val_file = data_dir / f"{self.run_name}_eval.jsonl"

        cmd = [
            sys.executable,
            "-m",
            "safety-tooling.safetytooling.apis.finetuning.together.run",
            "--train_file",
            str(train_file),
            "--val_file",
            str(val_file),
            "--model",
            self.config.model_name,
        ]

        if self.config.lora:
            cmd.append("--lora")

        cmd.extend(
            [
                "--lora_r",
                str(self.config.lora_r),
                "--lora_alpha",
                str(self.config.lora_alpha),
                "--batch_size",
                str(self.config.batch_size),
                "--n_evals",
                "2",
                "--n_epochs",
                str(self.config.epochs),
                "--wandb_project_name",
                "change_the_game",
                "--suffix",
                self.run_name,
                "--save_folder",
                str(models_dir),
                "--save_config",
            ]
        )

        self._run_command(cmd, "Training model on together.ai")

        # The trained model is located the same way a pre-existing one would be.
        model_name = self.check_existing_model()
        if model_name:
            self.logger.info(f"Trained model: {model_name}")
            return model_name

        self.logger.error("Could not find trained model directory")
        raise RuntimeError("Failed to locate trained model")

    def deploy_endpoint(self, model_name: str) -> str:
        """Deploy ``model_name`` on a single replica and return the endpoint name.

        Hardware shapes are tried in order to reduce time-to-ready when
        capacity is constrained.

        Raises:
            RuntimeError: If every hardware configuration fails. It is chained
                to the last provider error.
        """
        self.logger.info(f"Deploying endpoint for model: {model_name}")

        # Optional TOGETHER_API_NAME prefix for the display name.
        env_prefix = os.environ.get("TOGETHER_API_NAME", "").strip()
        display_name = f"{model_name}"
        if env_prefix and not display_name.startswith(f"{env_prefix}/"):
            display_name = f"{env_prefix}/{display_name}"

        hardware_configs = ["2x_nvidia_h100_80gb_sxm", "4x_nvidia_h100_80gb_sxm"]

        last_error: Optional[Exception] = None
        for hardware in hardware_configs:
            self.logger.info(f"Attempting to deploy with hardware: {hardware}")
            try:
                response = self.client.endpoints.create(
                    model=model_name,
                    display_name=display_name,
                    hardware=hardware,
                    # NOTE(review): units are set by the Together API
                    # (presumably minutes of inactivity before scale-down).
                    inactive_timeout=5,
                    min_replicas=1,
                    max_replicas=1,
                )
                endpoint_name = response.name
            except Exception as e:
                self.logger.warning(f"Failed to deploy with {hardware}: {str(e)}")
                last_error = e
                continue

            self.logger.info(f"Endpoint deployed successfully with {hardware}")
            self.logger.info(f"Endpoint name: {endpoint_name}")
            return endpoint_name

        raise RuntimeError(
            f"Failed to deploy endpoint with all hardware configurations: {hardware_configs}"
        ) from last_error

    def _try_evaluate_once(self, endpoint_name: str) -> Optional[Dict[str, float]]:
        """Run a single Inspect eval against ``endpoint_name`` and parse metrics.

        The Inspect ``.eval`` log path is recorded in the results whenever it
        appears in stdout, even for failed runs.

        Returns:
            Metrics from ``extract_metrics``, or None if the command exited
            non-zero or its output contains a traceback.
        """
        cmd = [
            "inspect",
            "eval",
            self.config.eval_name,
            "--model",
            f"together/{endpoint_name}",
        ]

        # MBPP-specific flags apply only to the default eval script.
        if self.config.eval_name == self._DEFAULT_EVAL_NAME:
            # json.dumps gives a double-quoted string with quotes, backslashes
            # and newlines escaped. For plain text this is byte-identical to
            # prefix="...". Assumes Inspect parses -T values YAML-style
            # (double-quoted YAML accepts JSON escapes). Confirm if prefixes
            # with special characters misbehave.
            prefix_arg = json.dumps(self.config.eval_prefix, ensure_ascii=False)
            cmd.extend(
                [
                    "--epochs",
                    "5",
                    "--sandbox",
                    "local",
                    "-T",
                    f"prefix={prefix_arg}",
                ]
            )
            if self.config.code_wrapped:
                cmd.extend(["-T", "code_wrapped=True"])

        if self.config.eval_params:
            # shlex honors shell-style quoting, so -T prefix="two words" stays a
            # single argument. str.split() would break it apart.
            cmd.extend(shlex.split(self.config.eval_params))

        result = None
        try:
            result = self._run_command(cmd, "Evaluating model with inspect framework")
            output = result.stdout
        except subprocess.CalledProcessError as e:
            # Non-zero exit: still look for the log file below.
            output = e.stdout or ""

        log_file_path = self._extract_log_file_path(output)
        if log_file_path:
            self.log_data["results"]["inspect_log_file"] = log_file_path
            self.logger.info(f"Inspect log file: {log_file_path}")
            self._save_log_data()

        if result is None:
            return None

        # Inspect can report errors while still exiting 0.
        if "Traceback" in output or "ExceptionGroup" in output:
            self.logger.warning("Error detected in inspect output")
            return None

        metrics = extract_metrics(output)
        if not metrics:
            self.logger.warning("No metrics could be parsed from inspect output")
        return metrics

    def evaluate_model(
        self, endpoint_name: str, max_retries: int = 30, retry_delay: int = 60
    ) -> None:
        """Evaluate ``endpoint_name``, retrying to ride out cold starts/hiccups.

        Args:
            endpoint_name: Together endpoint to evaluate.
            max_retries: Total number of evaluation attempts.
            retry_delay: Seconds to sleep between attempts.

        Raises:
            RuntimeError: If every attempt fails. ``results.evaluation_failed``
                is saved before raising.
        """
        self.logger.info(f"Evaluating model at endpoint: {endpoint_name}")

        for attempt in range(1, max_retries + 1):
            self.logger.info(f"Evaluation attempt {attempt}/{max_retries}")

            metrics = self._try_evaluate_once(endpoint_name)
            if metrics is not None:
                results = self.log_data["results"]
                results.update(metrics)
                # An earlier deploy/evaluate cycle may have set this flag before
                # being retried. Clear it so the saved results reflect success.
                results.pop("evaluation_failed", None)
                self.logger.info(f"All metrics: {metrics}")
                self._save_log_data()
                return

            self.logger.warning(f"Evaluation attempt {attempt} failed")
            if attempt < max_retries:
                self.logger.info(f"Waiting {retry_delay} seconds before retry...")
                time.sleep(retry_delay)

        self.logger.error(f"Evaluation failed after {max_retries} attempts")
        self.log_data["results"]["evaluation_failed"] = True
        self._save_log_data()
        raise RuntimeError(f"Model evaluation failed after {max_retries} attempts")

    @backoff.on_exception(
        backoff.expo,
        Exception,
        max_tries=6,
        base=BACKOFF_BASE,
        factor=BACKOFF_FACTOR,
        max_value=BACKOFF_MAX_VALUE,
        # details["args"][0] is the pipeline instance. Log through it so retry
        # messages also land in the per-run log file.
        on_backoff=lambda details: details["args"][0].logger.warning(
            f"Retrying... {details['exception']}"
        ),
    )
    def deploy_and_evaluate(self, model_name: str, warmup_seconds: int = 60 * 4) -> str:
        """Deploy a fresh endpoint and evaluate it, retrying the whole cycle.

        Args:
            model_name: Together model to deploy.
            warmup_seconds: Pause after deployment before the first eval attempt.

        Returns:
            The name of the endpoint that produced the saved metrics.

        NOTE(review): each retry deploys a *new* endpoint and never deletes the
        previous one. Presumably ``inactive_timeout`` in deploy_endpoint scales
        idle ones down; confirm, to avoid paying for leaked endpoints.
        """
        self.logger.info(f"Deploying and evaluating model: {model_name}")
        endpoint_name = self.deploy_endpoint(model_name)
        time.sleep(warmup_seconds)
        self.evaluate_model(endpoint_name)
        return endpoint_name

    def _extract_log_file_path(self, output: str) -> Optional[str]:
        """Return the Inspect ``Log: ....eval`` path from ``output``, if any."""
        if not output:
            return None
        log_match = self._LOG_PATH_PATTERN.search(output)
        return log_match.group(1).strip() if log_match else None

    def _has_completed_results(self) -> bool:
        """Return True if ``self.log_file`` records a previously completed run.

        Only ``results.completed_at`` counts, because run_pipeline writes it
        only after a successful deploy+evaluate. Failed or interrupted runs also
        leave a non-empty ``results`` dict behind (``inspect_log_file``,
        ``evaluation_failed``), and that must not block a rerun.
        Unreadable files are logged and treated as absent.
        """
        try:
            if not self.log_file.exists():
                return False
            with open(self.log_file, "r") as f:
                existing_data = json.load(f)
        except (OSError, ValueError) as e:
            self.logger.warning(
                f"Failed to read existing results file {self.log_file}: {e}"
            )
            return False

        if not isinstance(existing_data, dict):
            return False
        results = existing_data.get("results")
        return isinstance(results, dict) and bool(results.get("completed_at"))

    def _validate_config(self) -> None:
        """Reject conflicting or unparsable settings before any expensive work.

        Raises:
            ValueError: On conflicting prefix/eval options or unbalanced quotes
                in ``eval_params``.
        """
        if self.config.train_prefix and self.config.train_prefix_file:
            raise ValueError(
                "Cannot specify both train_prefix and train_prefix_file. Please use only one."
            )

        if self.config.eval_prefix and (
            self.config.eval_name != self._DEFAULT_EVAL_NAME or self.config.eval_params
        ):
            raise ValueError(
                "Cannot specify eval_prefix with custom eval_name or eval_params. Use eval_params to pass prefix instead."
            )

        # Parse eval_params now. A quoting error discovered during evaluation
        # would otherwise trigger deploy_and_evaluate's redeploy retries.
        try:
            shlex.split(self.config.eval_params)
        except ValueError as e:
            raise ValueError(
                f"Could not parse eval_params {self.config.eval_params!r}: {e}"
            ) from e

    def run_pipeline(self):
        """Run data generation and training (unless reused or skipped), then
        deploy and evaluate, persisting everything to ``self.log_file``.

        Returns immediately if a completed results file for this ``log_name``
        exists and ``overwrite_results`` is False.

        Raises:
            ValueError: On conflicting or unparsable configuration.
            Exceptions from data generation, training, deployment or evaluation
            propagate once their retries (if any) are exhausted.
        """
        self.logger.info(f"Checking for existing results at {self.log_file}")
        if not self.config.overwrite_results and self._has_completed_results():
            self.logger.info(f"Existing results found at {self.log_file}. Exiting.")
            return

        self._validate_config()

        self.logger.info("Starting Change the Game training pipeline")
        self.logger.info(f"Model run name (for training): {self.run_name}")
        self.logger.info(f"Log name (for results): {self.log_name}")

        used_existing_model = False
        if self.config.epochs == 0:
            self.logger.info(
                "Epochs set to 0: skipping data generation and training; deploying provided model for evaluation."
            )
            # Marks a base-model evaluation for downstream visualization.
            self.log_data["eval_base_model"] = True
            self._save_log_data()
            model_name = self.config.model_name
        else:
            existing_model = self.check_existing_model()
            if existing_model:
                self.logger.info(f"Using existing model: {existing_model}")
                model_name = existing_model
                used_existing_model = True
            else:
                self.logger.info("No existing model found, training new model")
                # Step 1: generate data. Step 2: train.
                self.generate_data()
                model_name = self.train_model()

        # Steps 3 & 4: deploy and evaluate (redeploying on failure).
        endpoint_name = self.deploy_and_evaluate(model_name)

        # completed_at is what _has_completed_results keys off on later runs.
        self.log_data["results"].update(
            {
                "model_name": model_name,
                "endpoint_name": endpoint_name,
                "completed_at": time.time(),
                "used_existing_model": used_existing_model,
            }
        )
        self._save_log_data()

        self.logger.info("Pipeline completed successfully!")
        self.logger.info(f"Model: {model_name}")
        self.logger.info(f"Endpoint name: {endpoint_name}")
        self.logger.info(f"Results saved to: {self.log_file}")

        results = self.log_data["results"]
        if "accuracy" in results:
            acc = results["accuracy"]
            if "stderr" in results:
                self.logger.info(f"Final accuracy: {acc} ± {results['stderr']}")
            else:
                self.logger.info(f"Final accuracy: {acc}")


def main():
    """Parse CLI flags into a config, prepare the environment, run the pipeline.

    Returns:
        Process exit status (0). Failures surface as exceptions.
    """
    arg_parser = simple_parsing.ArgumentParser(
        description="Change the Game Training and Evaluation Pipeline"
    )
    arg_parser.add_arguments(ChangeGameTrainInspectConfig, dest="config")
    parsed = arg_parser.parse_args()
    run_config: ChangeGameTrainInspectConfig = parsed.config

    # Must precede pipeline construction: the pipeline's __init__ reads
    # TOGETHER_API_KEY from the environment (presumably populated here).
    utils.setup_environment()

    ChangeGameTrainInspectPipeline(run_config).run_pipeline()
    return 0


if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(): the latter is meant for the
    # interactive shell and is missing when Python runs with -S.
    sys.exit(main())
