"""LlamaFactory training tool."""

import os
import subprocess
import json
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any, Self

import yaml
from pydantic import Field

from openhands.sdk import Action, Observation, TextContent, ImageContent
from openhands.sdk.tool import ToolDefinition, ToolExecutor

if TYPE_CHECKING:
    from openhands.sdk.conversation import LocalConversation


class TrainAction(Action):
    """Action for LlamaFactory training."""

    # --- Data and output locations (required, no defaults) ---
    data_path: str = Field(description="Path to the training data JSONL file")
    output_dir: str = Field(description="Directory to save the trained model")

    # --- Model and fine-tuning method ---
    base_model: str = Field(
        description="Base model name or path",
        default="Qwen/Qwen2.5-7B-Instruct",
    )
    finetuning_type: str = Field(
        description="Training method: 'lora' or 'full'",
        default="lora",
    )

    # --- LoRA hyper-parameters (ignored unless finetuning_type == "lora") ---
    lora_rank: int = Field(
        description="LoRA rank (only used when finetuning_type='lora')",
        default=8,
    )
    lora_alpha: int = Field(
        description="LoRA alpha (only used when finetuning_type='lora')",
        default=16,
    )

    # --- Optimisation settings ---
    epochs: int = Field(description="Number of training epochs", default=3)
    batch_size: int = Field(description="Training batch size", default=4)
    learning_rate: float = Field(description="Learning rate", default=1e-4)
    max_length: int = Field(description="Maximum sequence length", default=2048)

    # --- Hardware selection ---
    # An empty list leaves CUDA_VISIBLE_DEVICES untouched for the training process.
    gpu_ids: list[int] = Field(
        description="GPU IDs to use (empty = auto-select)",
        default_factory=list,
    )


class TrainObservation(Observation):
    """Observation from training."""

    # Result of one training run. On failure `success` is False and
    # `error_message` explains why; `log_path` may still point at the
    # captured training log.
    model_path: str = Field(description="Path to the trained model")
    train_loss: float = Field(default=0.0, description="Final training loss")
    success: bool = Field(default=True, description="Whether training succeeded")
    error_message: str = Field(default="", description="Error message if failed")
    log_path: str = Field(default="", description="Path to training log")
    used_params: dict = Field(
        default_factory=dict,
        description="Actual training parameters used",
    )

    @property
    def to_llm_content(self) -> Sequence[TextContent | ImageContent]:
        """Render the observation as text for the LLM.

        Returns:
            A single-element list holding a TextContent. For a failed run it
            carries the error message plus, when known, the log location so
            the failure can be investigated. For a successful run it carries
            the model path, the final loss (4 decimals) and the log path.
        """
        if not self.success:
            failure_text = f"Training failed: {self.error_message}"
            # The error message alone (e.g. an exit code) is rarely actionable;
            # point at the captured log whenever one exists.
            if self.log_path:
                failure_text += f"\nLog: {self.log_path}"
            return [TextContent(text=failure_text)]

        return [
            TextContent(
                text=(
                    f"Training completed.\n"
                    f"Model saved to: {self.model_path}\n"
                    f"Final loss: {self.train_loss:.4f}\n"
                    f"Log: {self.log_path}"
                )
            )
        ]


class LlamaFactoryExecutor(ToolExecutor[TrainAction, TrainObservation]):
    """Executor for LlamaFactory training."""

    def __init__(self, working_dir: str | None = None):
        self.working_dir = working_dir or "."

    def __call__(
        self, action: TrainAction, conversation: "LocalConversation | None" = None
    ) -> TrainObservation:
        """Execute training with LlamaFactory.

        Writes ``dataset_info.json`` and ``train_config.yaml`` into the output
        directory, runs ``llamafactory-cli train`` on that config with
        ``self.working_dir`` as cwd, and captures combined stdout/stderr in
        ``train.log`` inside the output directory.

        Args:
            action: Training parameters.
            conversation: Not used by this executor.

        Returns:
            A TrainObservation. A non-zero exit code, a timeout, or any
            exception raised while preparing or running the job is reported
            with ``success=False`` instead of being raised.
        """
        timeout_seconds = 10800  # 3 hours
        log_path_str = ""  # filled in once the log location is known

        try:
            # Make the output directory absolute: the subprocess runs with
            # cwd=self.working_dir, so a relative path would otherwise refer to
            # a different location there than the one prepared here.
            output_dir = Path(action.output_dir).absolute()
            output_dir.mkdir(parents=True, exist_ok=True)

            # Generate dataset info for LlamaFactory
            dataset_info = self._create_dataset_info(action.data_path, output_dir)

            # Generate training config
            config_path = self._create_train_config(action, output_dir, dataset_info)

            # Run training
            cmd = ["llamafactory-cli", "train", str(config_path)]

            # Set CUDA_VISIBLE_DEVICES only if gpu_ids were specified
            env = os.environ.copy()
            if action.gpu_ids:
                env["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in action.gpu_ids)

            log_path = output_dir / "train.log"
            log_path_str = str(log_path)
            with open(log_path, "w", encoding="utf-8") as log_file:
                result = subprocess.run(
                    cmd,
                    cwd=self.working_dir,
                    stdout=log_file,
                    stderr=subprocess.STDOUT,
                    timeout=timeout_seconds,
                    env=env,
                )

            if result.returncode != 0:
                return TrainObservation(
                    model_path=str(output_dir),
                    success=False,
                    error_message=f"Training failed with exit code {result.returncode}",
                    log_path=log_path_str,
                )

            # Parse training loss from log
            train_loss = self._parse_train_loss(log_path)

            uses_lora = action.finetuning_type == "lora"
            return TrainObservation(
                model_path=str(output_dir),
                train_loss=train_loss,
                success=True,
                log_path=log_path_str,
                used_params={
                    "finetuning_type": action.finetuning_type,
                    # LoRA settings are only meaningful for LoRA runs
                    "lora_rank": action.lora_rank if uses_lora else None,
                    "lora_alpha": action.lora_alpha if uses_lora else None,
                    "epochs": action.epochs,
                    "batch_size": action.batch_size,
                    "learning_rate": action.learning_rate,
                    "max_length": action.max_length,
                    "gpu_ids": action.gpu_ids if action.gpu_ids else "auto",
                },
            )

        except subprocess.TimeoutExpired:
            return TrainObservation(
                model_path=str(action.output_dir),
                success=False,
                # Derived from the timeout actually used so the two cannot drift apart
                error_message=(
                    f"Training timeout (exceeded {timeout_seconds // 3600} hours)"
                ),
                log_path=log_path_str,
            )
        except Exception as e:
            # Boundary handler: every failure is surfaced as an observation
            return TrainObservation(
                model_path=str(action.output_dir),
                success=False,
                error_message=str(e),
                log_path=log_path_str,
            )

    def _detect_data_format(self, data_path: str) -> str:
        """Detect data format: 'alpaca' or 'sharegpt'.
        
        - Alpaca format: {"instruction": ..., "input": ..., "output": ...}
        - ShareGPT format: {"messages": [{"role": "user", ...}, {"role": "assistant", ...}]}
        """
        try:
            with open(data_path, "r", encoding="utf-8") as f:
                first_line = f.readline()
                if not first_line.strip():
                    return "alpaca"  # default
                
                sample = json.loads(first_line)
                
                # ShareGPT/Conversation format
                if "messages" in sample or "conversations" in sample:
                    return "sharegpt"
                
                # Alpaca format
                if "instruction" in sample or "output" in sample:
                    return "alpaca"
                
                return "alpaca"  # default fallback
        except Exception:
            return "alpaca"

    def _create_dataset_info(self, data_path: str, output_dir: Path) -> Path:
        """Create dataset_info.json for LlamaFactory.
        
        Automatically detects data format and generates appropriate config.
        """
        data_format = self._detect_data_format(data_path)
        dataset_name = "custom_train"
        
        if data_format == "sharegpt":
            # ShareGPT/Conversation format: {"messages": [...]}
            dataset_info = {
                dataset_name: {
                    "file_name": str(Path(data_path).absolute()),
                    "formatting": "sharegpt",
                    "columns": {
                        "messages": "messages",
                    },
                }
            }
        else:
            # Alpaca format: {"instruction": ..., "input": ..., "output": ...}
            dataset_info = {
                dataset_name: {
                    "file_name": str(Path(data_path).absolute()),
                    "columns": {
                        "prompt": "instruction",
                        "query": "input",
                        "response": "output",
                    },
                }
            }

        info_path = output_dir / "dataset_info.json"
        with open(info_path, "w") as f:
            json.dump(dataset_info, f, indent=2)
        
        print(f"Detected data format: {data_format}")
        return info_path

    def _create_train_config(
        self, action: TrainAction, output_dir: Path, dataset_info: Path
    ) -> Path:
        """Create training configuration YAML for LlamaFactory.

        All filesystem paths written into the config, and the returned config
        path itself, are absolute. The training subprocess may run with a
        different working directory than this process, so relative paths
        could otherwise point at the wrong location.

        Args:
            action: Training parameters (model, method, hyper-parameters).
            output_dir: Directory that receives train_config.yaml and is also
                used as the training output directory.
            dataset_info: Path to dataset_info.json; its parent directory is
                used as the dataset directory.

        Returns:
            Absolute path of the written train_config.yaml.
        """
        abs_output_dir = output_dir.absolute()

        config = {
            # Model
            "model_name_or_path": action.base_model,
            "trust_remote_code": True,

            # Method
            "stage": "sft",
            "do_train": True,
            "finetuning_type": action.finetuning_type,

            # Dataset
            "dataset_dir": str(dataset_info.parent.absolute()),
            "dataset": "custom_train",
            # NOTE(review): the chat template is fixed to "qwen" although
            # base_model is configurable; a non-Qwen base model presumably
            # needs a different template - confirm and expose if so.
            "template": "qwen",
            "cutoff_len": action.max_length,

            # Output
            "output_dir": str(abs_output_dir),
            "logging_steps": 10,
            "save_steps": 500,
            "overwrite_output_dir": True,

            # Training
            "per_device_train_batch_size": action.batch_size,
            "gradient_accumulation_steps": 4,
            "learning_rate": action.learning_rate,
            "num_train_epochs": action.epochs,
            "lr_scheduler_type": "cosine",
            "warmup_ratio": 0.1,
            "bf16": True,

            # Misc
            "report_to": "none",
        }

        # LoRA-specific parameters
        if action.finetuning_type == "lora":
            config.update({
                "lora_rank": action.lora_rank,
                "lora_alpha": action.lora_alpha,
                "lora_target": "all",
            })

        config_path = abs_output_dir / "train_config.yaml"
        with open(config_path, "w", encoding="utf-8") as f:
            # safe_dump emits plain YAML only; the config holds only basic types
            yaml.safe_dump(config, f, default_flow_style=False)

        return config_path

    def _parse_train_loss(self, log_path: Path) -> float:
        """Parse final training loss from log file."""
        if not log_path.exists():
            return 0.0

        last_loss = 0.0
        with open(log_path, "r") as f:
            for line in f:
                if "'loss':" in line or '"loss":' in line:
                    try:
                        # Try to extract loss value
                        import re
                        match = re.search(r"['\"]loss['\"]:\s*([\d.]+)", line)
                        if match:
                            last_loss = float(match.group(1))
                    except (ValueError, AttributeError):
                        pass

        return last_loss


_LLAMA_FACTORY_DESCRIPTION = """LlamaFactory training tool for LLM fine-tuning.
* Runs LoRA fine-tuning using LlamaFactory CLI
* Supports various base models (Qwen, Llama, etc.)
* Configurable training parameters (rank, epochs, batch size, etc.)
* Outputs trained LoRA adapter weights
"""


class LlamaFactoryTool(ToolDefinition[TrainAction, TrainObservation]):
    """Tool for LlamaFactory training."""

    name = "llama_factory"

    @classmethod
    def create(cls, conv_state=None, working_dir: str | None = None, **kwargs) -> Sequence[Self]:
        """Build the tool together with its executor.

        Args:
            conv_state: Optional conversation state. When it is truthy and has
                a ``workspace`` attribute, ``workspace.working_dir`` is used
                as the executor's working directory and overrides
                ``working_dir``.
            working_dir: Working directory used when ``conv_state`` does not
                supply one.
            **kwargs: Accepted but not used.

        Returns:
            A one-element list containing the configured tool.
        """
        # The workspace directory wins over the explicit argument.
        workspace_available = bool(conv_state) and hasattr(conv_state, "workspace")
        resolved_dir = (
            conv_state.workspace.working_dir if workspace_available else working_dir
        )

        executor = LlamaFactoryExecutor(working_dir=resolved_dir)
        tool = cls(
            name="llama_factory",
            description=_LLAMA_FACTORY_DESCRIPTION,
            action_type=TrainAction,
            observation_type=TrainObservation,
            executor=executor,
        )
        return [tool]
