import datetime
import importlib
import json
import os
from dataclasses import asdict, dataclass, field
from pathlib import Path
from pprint import pformat
from typing import Any, Dict, List, Optional

import yaml
from loguru import logger

from agents.base.agent import AgentBase
from exp.base.base import ExperimentBase
from exp.utils.registry import get_agent, get_experiment
from utils.model_enums import (
    AWSBedrockModelNames,
    GeminiModelNames,
    ModelNames,
    OpenAIModelNames,
    XAIModelNames,
)


@dataclass
class AgentConfig:
    """Selects an agent implementation and the arguments used to build it."""

    # Agent identifier: get_agent_instance imports "agents.<name>.agent" and
    # looks this name up in the agent registry.
    name: str = "codeact"
    # Extra keyword arguments passed to the agent class constructor.
    params: Dict[str, Any] = field(default_factory=dict)


@dataclass
class TaskConfig:
    """Selects an experiment (task) and the arguments used to build it."""

    # Task identifier: get_experiment_class imports "exp.<dataset_name>.<name>"
    # and looks this name up in the experiment registry.
    name: str = ""
    # Extra keyword arguments passed to the experiment class constructor.
    params: Dict[str, Any] = field(default_factory=dict)


@dataclass
class Config:
    """Run configuration as produced by load_config."""

    # Dataset package name; used in the "exp.<dataset_name>.<task>" module path
    # and in result file names.
    dataset_name: str
    # Which experiment to run and its constructor arguments.
    task: TaskConfig
    # Forwarded to the experiment constructor as ``num_test``.
    num_test: int
    # Model identifier handed to the agent; load_config adds a provider prefix
    # when the name matches a known provider enum (skipped for the "demo" agent).
    model_name: str
    # Which agent to build and its constructor arguments.
    agent: AgentConfig
    # Optional directory; load_config rejects the config if it does not exist.
    fix_test_cases_dir: Optional[str] = field(default=None)
    # Directory where save_results_as_json writes result files.
    result_dir: Optional[str] = field(default="results")
    # Directory forwarded to the experiment constructor as ``logs_dir``.
    logs_dir: Optional[str] = field(default=None)

    def to_dict(self) -> dict:
        """Return the configuration as a plain dict (nested dataclasses included)."""
        return asdict(self)



def _complete_model_name(model_name: str) -> str:
    """Prefix a bare model name with the identifier of its provider.

    The provider enums are consulted in a fixed order (OpenAI, AWS Bedrock,
    Gemini, xAI); the first enum whose values contain ``model_name`` decides
    the prefix. A name that matches none of them is returned unchanged.
    """
    # Order matters: it mirrors the precedence used when a name could appear
    # in more than one provider enum.
    provider_prefixes = (
        (OpenAIModelNames, "openai:"),
        (AWSBedrockModelNames, "bedrock_converse:"),
        (GeminiModelNames, "google_genai:"),
        (XAIModelNames, "xai:"),
    )
    for provider_enum, prefix in provider_prefixes:
        known_names = [member.value for member in provider_enum]
        if model_name in known_names:
            return prefix + model_name
    return model_name


def load_config(
    config_path: str | Path = "config.yaml", overrides: Optional[dict[str, Any]] = None
) -> Config:
    """Load the YAML config file, apply overrides, and build a ``Config``.

    Args:
        config_path: Path to the YAML configuration file.
        overrides: Optional top-level keys that replace the values read from
            the file. This is a shallow update; nested mappings are replaced,
            not merged.

    Returns:
        The populated ``Config``. ``result_dir`` and ``logs_dir`` are resolved
        to absolute paths and default to "results" and "logs".

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If the file's top level is not a mapping, the ``agent``
            entry has an unsupported type, or ``fix_test_cases_dir`` points to
            a path that does not exist.
        KeyError: If ``model_name`` is missing and the agent is not "demo".
    """
    config_path = Path(config_path)
    if not config_path.exists():
        raise FileNotFoundError(f"Config file '{config_path}' not found.")
    with open(config_path, "r", encoding="utf-8") as f:
        loaded = yaml.safe_load(f)

    # An empty YAML document parses to None; treat it as an empty mapping so
    # that overrides and defaults still apply instead of crashing on None.
    if loaded is None:
        loaded = {}
    if not isinstance(loaded, dict):
        raise ValueError(
            f"Config file '{config_path}' must contain a mapping at the top "
            f"level, got {type(loaded).__name__}."
        )
    config_dict: dict = loaded
    if overrides is not None:
        config_dict.update(overrides)

    logger.info("\n" + pformat(config_dict))

    # "agent" may be absent, null, a bare agent name, or a mapping. The bare
    # name form mirrors what is accepted for "task" below.
    agent_field = config_dict.get("agent")
    if not agent_field:
        agent_field = {}
    elif isinstance(agent_field, str):
        agent_field = {"name": agent_field}
    elif not isinstance(agent_field, dict):
        raise ValueError(
            f"'agent' must be a mapping or an agent name, "
            f"got {type(agent_field).__name__}."
        )

    # The demo agent keeps the model name exactly as given; every other agent
    # gets the provider-prefixed form.
    if agent_field.get("name") not in ["demo"]:
        if "model_name" not in config_dict:
            raise KeyError(
                "'model_name' is required in the config unless the agent is 'demo'"
            )
        config_dict["model_name"] = _complete_model_name(config_dict["model_name"])

    task_field = config_dict.get("task", {})
    if isinstance(task_field, dict):
        task_config = TaskConfig(
            name=task_field.get("name", ""),
            # A bare "params:" key parses to None; normalise it to a dict.
            params=task_field.get("params") or {},
        )
    elif isinstance(task_field, str):
        task_config = TaskConfig(name=task_field)
    else:
        task_config = TaskConfig(name="unknown")

    fix_test_cases_dir = config_dict.get("fix_test_cases_dir")
    if fix_test_cases_dir is not None:
        fix_test_cases_dir = str(fix_test_cases_dir)
        if not Path(fix_test_cases_dir).exists():
            raise ValueError(
                f"fix_test_cases_dir '{fix_test_cases_dir}' does not exist"
            )

    # Empty or missing directory settings fall back to the defaults.
    result_dir_str = config_dict.get("result_dir") or "results"
    logs_dir_str = config_dict.get("logs_dir") or "logs"

    result_dir = str(Path(result_dir_str).resolve())
    logs_dir = str(Path(logs_dir_str).resolve())

    return Config(
        dataset_name=config_dict.get("dataset_name", "unknown"),
        task=task_config,
        num_test=config_dict.get("num_test", 0),
        model_name=config_dict.get("model_name", "unknown"),
        agent=AgentConfig(**agent_field),
        fix_test_cases_dir=fix_test_cases_dir,
        result_dir=result_dir,
        logs_dir=logs_dir,
    )


def get_agent_instance(agent_config: AgentConfig, model_name: str) -> AgentBase:
    """Import, look up, and instantiate the agent named in ``agent_config``.

    Args:
        agent_config: Agent name plus the keyword arguments for its constructor.
        model_name: Model identifier passed to the agent as ``model_name``.

    Returns:
        The constructed agent.

    Raises:
        ValueError: If the registry has no entry for the agent name after the
            module "agents.<name>.agent" has been imported.
    """
    name = agent_config.name
    # The import is done for its side effect; presumably the module registers
    # the agent class with the registry when it is loaded.
    importlib.import_module(f"agents.{name}.agent")

    agent_cls = get_agent(name)
    if agent_cls is None:
        raise ValueError(f"Agent '{name}' not found in registry")

    instance = agent_cls(model_name=model_name, **agent_config.params)
    print(f"✅ Agent class loaded: {agent_cls.__name__}")
    return instance


def get_experiment_class(
    dataset_name: str,
    task: str,
    num_test: int,
    agent: AgentBase,
    params: Optional[dict] = None,
    logs_dir: Optional[str] = None,
) -> ExperimentBase:
    """Import, look up, and instantiate the experiment registered as ``task``.

    Despite the name, this returns an experiment *instance*, not the class.

    Args:
        dataset_name: Dataset package; the module "exp.<dataset_name>.<task>"
            is imported so the experiment can register itself.
        task: Registry key of the experiment.
        num_test: Passed to the experiment constructor as ``num_test``.
        agent: Agent instance handed to the experiment.
        params: Extra keyword arguments for the experiment constructor.
        logs_dir: Optional log directory, passed on as a ``Path`` (or ``None``).

    Returns:
        The constructed experiment.

    Raises:
        ValueError: If no experiment is registered under ``task``. When the
            module import failed, that ImportError is attached as the cause.
    """
    module_name = f"exp.{dataset_name}.{task}"

    # The import may legitimately fail when the experiment was registered by
    # some other module, so the error is kept and only reported if the
    # registry lookup fails as well. Retrying the import would not help: a
    # successful import is cached and a failed one fails the same way again.
    import_error: Optional[ImportError] = None
    try:
        importlib.import_module(module_name)
    except ImportError as exc:
        import_error = exc

    experiment_class = get_experiment(task)
    if experiment_class is None:
        message = f"Experiment '{task}' not found in registry. Ensure it is imported."
        if import_error is not None:
            raise ValueError(
                f"{message} Importing '{module_name}' failed: {import_error}"
            ) from import_error
        raise ValueError(message)

    params = params or {}
    experiment: ExperimentBase = experiment_class(
        num_test=num_test,
        agent=agent,
        logs_dir=Path(logs_dir) if logs_dir else None,
        **params,
    )
    print(f"✅ Experiment class loaded: {experiment_class.__name__}")
    return experiment


def setup_experiment(config: Config) -> ExperimentBase:
    """Build the agent and the experiment described by ``config``.

    The agent is created first from ``config.agent`` and ``config.model_name``
    and then handed to the experiment together with the task settings. A
    one-line summary of the run is printed before the experiment is returned.
    """
    agent_instance = get_agent_instance(config.agent, config.model_name)

    experiment_kwargs = dict(
        dataset_name=config.dataset_name,
        task=config.task.name,
        num_test=config.num_test,
        agent=agent_instance,
        params=config.task.params,
        logs_dir=config.logs_dir,
    )
    experiment = get_experiment_class(**experiment_kwargs)

    summary = f"🚀 Running {config.task.name} of {config.dataset_name} on {config.agent.name} with {config.model_name} backbone for {config.num_test} experiments"
    print(summary)

    return experiment


def process_results(
    result_list: List[dict],
    config: Config,
    experiment: ExperimentBase,
    elapsed_time: Optional[float] = None,
    result_file_path: Optional[str | Path] = None,
) -> None:
    """Compute metrics for a run and persist results plus a summary as JSON.

    A summary entry (metrics, config, elapsed time, git commit hash) is
    appended to the results and the whole list is written atomically.

    Args:
        result_list: Individual experiment results. Ignored when
            ``result_file_path`` is given. Otherwise the summary entry is
            appended to this list in place.
        config: Experiment configuration, stored in the summary entry.
        experiment: Experiment instance providing ``calculate_metrics``.
        elapsed_time: Optional elapsed time stored in the summary entry.
        result_file_path: Optional JSON file to read the results from and write
            the finalized list back to. A trailing summary entry from an
            earlier call is dropped before metrics are recomputed. When
            omitted, the results go to a new timestamped file instead.

    Raises:
        ValueError: If ``result_file_path`` contains JSON that is not a list.
    """
    if result_file_path is not None:
        result_file_path = Path(result_file_path)
        try:
            with open(result_file_path, "r", encoding="utf-8") as f:
                result_list = json.load(f)
        except FileNotFoundError:
            # Best effort: continue with no results, but leave a trace.
            logger.warning(
                f"Result file {result_file_path} not found; "
                "continuing with an empty result list"
            )
            result_list = []
        except json.JSONDecodeError as e:
            # The file is rewritten below, so make the data loss visible.
            logger.warning(
                f"Result file {result_file_path} is not valid JSON ({e}); "
                "its contents are discarded"
            )
            result_list = []

        if not isinstance(result_list, list):
            # Fail before anything is written so the file is left untouched.
            raise ValueError(
                f"Result file {result_file_path} must contain a JSON list, "
                f"got {type(result_list).__name__}"
            )

        # Drop the summary entry of a previous finalization so it is neither
        # fed into the metrics nor duplicated in the output.
        if (
            result_list
            and isinstance(result_list[-1], dict)
            and "metrics" in result_list[-1]
        ):
            result_list = result_list[:-1]

    try:
        metrics = experiment.calculate_metrics(result_list)
    except Exception as e:
        # A metrics failure must not lose the raw results; record it instead.
        logger.error(f"Failed to calculate metrics: {e}")
        metrics = {"error": str(e), "error_type": type(e).__name__}

    result_list.append(
        {
            "metrics": metrics,
            "config": config.to_dict(),
            "elapsed_time": elapsed_time,
            "commit_hash": _get_git_commit_hash(),
        }
    )

    if result_file_path is None:
        save_results_as_json(result_list, config=config)
        return

    result_file_path.parent.mkdir(parents=True, exist_ok=True)
    # Write to a sibling temp file and swap it in, so a failed dump never
    # leaves a truncated result file behind.
    temp_file_path = result_file_path.with_suffix(".tmp")
    try:
        with open(temp_file_path, "w") as f:
            json.dump(result_list, f, indent=4)
        os.replace(temp_file_path, result_file_path)
    except Exception:
        temp_file_path.unlink(missing_ok=True)
        raise
    logger.success(f"Results finalized and saved to {result_file_path}")


def build_overrides_from_args(
    task: Optional[str] = None,
    num_test: Optional[int] = None,
    model_name: Optional[ModelNames] = None,  # type: ignore
    dataset_name: Optional[str] = None,
    result_dir: Optional[str] = None,
    logs_dir: Optional[str] = None,
    fix_test_cases_dir: Optional[str] = None,
) -> Dict[str, Any]:
    """Collect the explicitly supplied command line arguments into a dict.

    Arguments left as ``None`` are omitted. ``model_name`` is reduced to its
    ``.value`` when it has one (enum members), and ``num_test`` is converted
    with ``int``.
    """
    # (key, raw argument, optional converter), listed in the order in which
    # the keys are inserted into the result.
    supplied = (
        ("task", task, None),
        ("model_name", model_name, lambda m: getattr(m, "value", m)),
        ("dataset_name", dataset_name, None),
        ("num_test", num_test, int),
        ("result_dir", result_dir, None),
        ("logs_dir", logs_dir, None),
        ("fix_test_cases_dir", fix_test_cases_dir, None),
    )

    overrides: Dict[str, Any] = {}
    for key, raw, convert in supplied:
        if raw is None:
            continue
        overrides[key] = raw if convert is None else convert(raw)
    return overrides


def save_results_as_json(results: List[Any], config: Optional[Config] = None) -> None:
    """
    Write the result list to "{result_dir}/{filename}.json".

    With a config, the directory is ``config.result_dir`` (if set) and the
    filename is built from dataset_name, task name, agent name and a timestamp.
    Without one, the file is "results/results_<timestamp>.json".
    The data is dumped to a ".tmp" sibling first and then moved into place.
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    use_config_dir = config and config.result_dir
    out_dir = Path(config.result_dir) if use_config_dir else Path("results")
    out_dir.mkdir(parents=True, exist_ok=True)

    if config:
        stem = f"{config.dataset_name}_{config.task.name}_{config.agent.name}_{timestamp}"
    else:
        stem = f"results_{timestamp}"

    destination = out_dir / f"{stem}.json"
    staging = destination.with_suffix(".tmp")
    try:
        with open(staging, "w") as handle:
            json.dump(results, handle, indent=4)
        os.replace(staging, destination)
    except Exception:
        # Do not leave a half-written staging file behind.
        if staging.exists():
            staging.unlink(missing_ok=True)
        raise
    logger.success(f"Results saved to {destination}")




def _resolve_git_dir(start: Path) -> Optional[Path]:
    """Return the git directory of the nearest repository at or above ``start``."""
    for candidate in (start, *start.parents):
        dot_git = candidate / ".git"
        if dot_git.is_dir():
            return dot_git
        if dot_git.is_file():
            # Linked worktrees and submodules use a ".git" file that points at
            # the real git directory: "gitdir: <path>".
            pointer = dot_git.read_text(encoding="utf-8").strip()
            if pointer.startswith("gitdir:"):
                target = candidate / pointer[len("gitdir:"):].strip()
                if target.is_dir():
                    return target
    return None


def _read_packed_ref(git_dir: Path, ref_name: str) -> str:
    """Return the hash recorded for ``ref_name`` in packed-refs, or ""."""
    packed_refs = git_dir / "packed-refs"
    if not packed_refs.is_file():
        return ""
    for line in packed_refs.read_text(encoding="utf-8").splitlines():
        # Skip the header comment and "^<hash>" peeled-tag lines.
        if not line or line[0] in "#^":
            continue
        commit_hash, _, name = line.partition(" ")
        if name.strip() == ref_name:
            return commit_hash
    return ""


def _get_git_commit_hash() -> str:
    """Get the current git commit hash by reading .git files directly.

    The repository is searched for from the current working directory upwards.
    A symbolic HEAD is resolved through the loose ref file first and through
    "packed-refs" second, so the hash is still found after refs were packed.

    Returns:
        The commit hash, the raw HEAD content when HEAD is detached, or ""
        when no repository is found or the ref cannot be resolved.
    """
    try:
        git_dir = _resolve_git_dir(Path.cwd())
        if git_dir is None:
            return ""

        head_content = (git_dir / "HEAD").read_text(encoding="utf-8").strip()
        if not head_content.startswith("ref: "):
            # Detached HEAD: the file holds the hash itself.
            return head_content
        ref_name = head_content[5:].strip()

        # A linked worktree keeps branch refs in the shared git directory
        # named by its "commondir" file.
        search_dirs = [git_dir]
        commondir_file = git_dir / "commondir"
        if commondir_file.is_file():
            common = commondir_file.read_text(encoding="utf-8").strip()
            search_dirs.append(git_dir / common)

        for base in search_dirs:
            ref_file = base / ref_name
            if ref_file.is_file():
                return ref_file.read_text(encoding="utf-8").strip()

        for base in search_dirs:
            packed_hash = _read_packed_ref(base, ref_name)
            if packed_hash:
                return packed_hash
        return ""

    except (OSError, UnicodeDecodeError):
        # Best effort only: an unreadable repository yields no hash.
        return ""


