"""Utility functions and dataclasses for agent operations."""

import inspect
import logging
from dataclasses import dataclass, field
from functools import wraps
from typing import Any, Dict, List, Optional

# Type alias for Transition.metrics: free-form metadata keyed by name.
# (No `Observation` alias here: `Observation` is defined as a dataclass below,
# so an alias with the same name would only be shadowed.)
Metrics = Dict[str, Any]


@dataclass
class TokensWithLogprobs:
    """Tokens generated by the LLM with their log probabilities.

    record_transition builds one instance per LLM call from the call's
    response string and its ``meta_info`` dict.
    """

    # Generated token ids (record_transition reads meta_info["output_tokens"]).
    token_ids: List[int]
    # Log probabilities for the generated tokens, expected to align one-to-one
    # with token_ids. When None or empty, transitions_to_training_data
    # substitutes 0.0 for every token.
    logprobs: Optional[List[float]] = None
    # Response string for the same generation (record_transition stores the
    # LLM's response string here).
    text: Optional[str] = None


@dataclass
class Observation:
    """Tokens corresponding to an observation message from the environment.

    record_transition builds one instance per LLM call from the ``input_ids``
    argument of that call.
    """

    # Token ids of the observation. transitions_to_training_data compares
    # these against the sequence accumulated so far to decide whether
    # successive transitions can be merged.
    input_ids: List[int]


@dataclass
class Transition:
    """Records a single LLM call transition for RL training.

    Attributes:
        ob: Observation (input token ids) for this call.
        ac: Tokens produced by the LLM for this call, with optional logprobs
            and text.
        reward: Reward attached to this transition. record_transition
            initialises it to 0.0 as a placeholder.
        episode_done: Episode-termination flag. record_transition initialises
            it to False as a placeholder (exact semantics are set by whoever
            updates it later; confirm against the caller).
        metrics: Free-form metadata for the call. record_transition stores the
            call's "finish_reason" here. Defaults to a fresh empty dict.
    """

    ob: Observation
    ac: TokensWithLogprobs
    reward: float
    episode_done: bool
    metrics: Metrics = field(default_factory=dict)


@dataclass
class StepResult:
    """Outcome of one agent step.

    Attributes:
        done: True when the agent should stop, False when it should continue.
        finish_reason: Reason the agent stopped; None while it is continuing.
        result: Optional payload produced by the step.
    """

    done: bool
    finish_reason: Optional[str]
    result: Optional[Any]

    @classmethod
    def continuing(cls, result: Optional[Any] = None) -> "StepResult":
        """Build a not-done result (no finish reason) carrying ``result``."""
        outcome = cls(result=result, finish_reason=None, done=False)
        return outcome

    @classmethod
    def finished(cls, reason: str, result: Optional[Any] = None) -> "StepResult":
        """Build a done result whose finish reason is ``reason``."""
        outcome = cls(result=result, finish_reason=reason, done=True)
        return outcome

    def to_tuple(self):
        """Return ``(done, finish_reason, result)`` for tuple-based callers."""
        as_tuple = (self.done, self.finish_reason, self.result)
        return as_tuple


def record_transition(func):
    """Decorator to record a Transition for each LLM call.

    The wrapped callable must be an async method whose awaited result is a
    ``(response_str, meta_info)`` pair, and whose instance exposes a
    ``transitions`` list. After every call one Transition is appended to
    ``self.transitions``:

    - Observation: the ``input_ids`` argument of the call, whether it was
      passed by keyword or positionally (``[]`` when it was not supplied)
    - Action: ``meta_info["output_tokens"]``, ``meta_info["logprobs"]`` and the
      response string
    - Metrics: ``meta_info["finish_reason"]``

    ``reward`` and ``episode_done`` are recorded with placeholder values
    (``0.0`` / ``False``).

    Args:
        func: Async method performing the LLM generation.

    Returns:
        An async wrapper that returns the wrapped call's result unchanged.
    """
    logger = logging.getLogger(__name__)

    # Resolve the signature once, at decoration time, so that a positionally
    # passed ``input_ids`` can be located on every call.
    try:
        signature = inspect.signature(func)
    except (TypeError, ValueError):
        signature = None

    def _resolve_input_ids(instance, args, kwargs):
        """Return the ``input_ids`` argument of a call, or ``[]`` if absent."""
        if "input_ids" in kwargs:
            return kwargs["input_ids"]
        if signature is None or not args:
            return []
        try:
            bound = signature.bind_partial(instance, *args, **kwargs)
        except TypeError:
            # Malformed call: record nothing here and let ``func`` raise the
            # real error itself.
            return []
        return bound.arguments.get("input_ids", [])

    @wraps(func)
    async def wrapper(self, *args, **kwargs):
        # Capture observation (input_ids) before LLM call
        observation = Observation(input_ids=_resolve_input_ids(self, args, kwargs))

        # Call the original function (LLM generation)
        result = await func(self, *args, **kwargs)
        response_str, meta_info = result
        # Debug-level log instead of print: this runs on every LLM call.
        logger.debug("response_str: %s", response_str)

        # Extract action information
        action = TokensWithLogprobs(
            token_ids=meta_info.get("output_tokens", []),
            logprobs=meta_info.get("logprobs", None),
            text=response_str,
        )

        # Create transition with default values
        # Reward and episode_done will be updated based on step outcome
        transition = Transition(
            ob=observation,
            ac=action,
            reward=0.0,  # Default, can be updated later
            episode_done=False,  # Default, can be updated later
            metrics={
                "finish_reason": meta_info.get("finish_reason"),
            },
        )

        # TODO(@csy): Add agent_id tracking to transitions
        # When tools spawn subagents that also call LLM generation, transitions from
        # different agents get mixed in self.transitions. For TITOReActAgent incremental
        # encoding, we need to identify which transitions belong to the current agent.
        # Proposed solution: Add `agent_id: Optional[str] = None` field to Transition
        # dataclass, and set it here: `transition.agent_id = getattr(self, 'agent_id', None)`

        # Store transition
        self.transitions.append(transition)

        return result

    return wrapper


# from tinker cookbook
@dataclass
class TrainingDatum:
    """A single training datum created from one or more transitions.

    As built by transitions_to_training_data, the three ``response_*`` lists
    are index-aligned: they all cover the same span of the merged sequence,
    starting at its first action token.
    """

    # Tokens of the merged sequence that precede the first action token.
    input_tokens: List[int]
    # Remainder of the merged sequence, from the first action token onwards;
    # may interleave action tokens with later observation tokens.
    response_tokens: List[int]
    # Logprob per response token; 0.0 for observation tokens and for action
    # tokens whose transition carried no logprobs.
    response_logprobs: List[float]
    response_mask: List[float]  # 0 for observation tokens, 1 for action tokens


def _is_prefix(sequence: List[int], candidate: List[int]) -> bool:
    """Return True when ``candidate`` begins with all of ``sequence``, in order."""
    prefix_len = len(sequence)
    # A longer sequence can never be a prefix; otherwise compare the leading slice.
    return prefix_len <= len(candidate) and sequence == candidate[:prefix_len]


def transitions_to_training_data(
    transitions: List[Transition],
) -> List[TrainingDatum]:
    """
    Pack a list of transitions into training data, merging where possible.

    Consecutive transitions are folded into one datum for as long as each new
    observation starts with everything accumulated so far (the earlier
    observations and actions). The first observation that breaks this prefix
    relation closes the current datum and opens a new one.

    For example, given the (observation, action) pairs:
    (O1, A1)
    (O1+A1+O2, A2)
    (O3, A3)

    the first two are merged into a single datum and the last becomes a
    separate datum.

    Within a datum, everything before the first action token is the input and
    the remainder is the response. Observation tokens carry mask 0.0 and
    logprob 0.0; action tokens carry mask 1.0 and their sampled logprob
    (0.0 when the transition has no logprobs).

    Args:
        transitions: List of Transition objects

    Returns:
        List of TrainingDatum objects
    """

    packed: List[TrainingDatum] = []

    # Buffers for the datum currently being assembled (kept index-aligned)
    tokens: List[int] = []
    logprobs: List[float] = []
    action_mask: List[float] = []

    def flush() -> None:
        """Emit the buffered sequence as a datum (if non-empty) and reset the buffers."""
        nonlocal tokens, logprobs, action_mask
        if tokens:
            # Split at the first action token; with no action token at all the
            # whole buffer is input and the response part is empty.
            try:
                split = action_mask.index(1)
            except ValueError:
                split = len(action_mask)
            packed.append(
                TrainingDatum(
                    input_tokens=tokens[:split],
                    response_tokens=tokens[split:],
                    response_logprobs=logprobs[split:],
                    response_mask=action_mask[split:],
                )
            )
        tokens, logprobs, action_mask = [], [], []

    for step in transitions:
        observed = step.ob.input_ids
        sampled = step.ac.token_ids
        sampled_lp = step.ac.logprobs or [0.0] * len(sampled)

        if tokens and not _is_prefix(tokens, observed):
            # Observation does not continue the buffered sequence:
            # close the current datum and start a new one.
            flush()

        # Only the part of the observation that is not already buffered is new
        fresh = observed[len(tokens) :] if tokens else observed

        # Observation tokens: logprob 0.0, mask 0.0
        tokens.extend(fresh)
        logprobs.extend([0.0] * len(fresh))
        action_mask.extend([0.0] * len(fresh))

        # Action tokens: sampled logprobs, mask 1.0
        tokens.extend(sampled)
        logprobs.extend(sampled_lp)
        action_mask.extend([1.0] * len(sampled))

    # Whatever is still buffered forms the last datum
    flush()
    return packed


# Custom exceptions for agent step control flow
class StepException(Exception):
    """Root of the exceptions used for agent step control flow.

    Every instance carries the StepResult associated with it in
    ``step_result``; the exception itself is constructed without a message.
    """

    def __init__(self, step_result: StepResult):
        super().__init__()
        self.step_result = step_result


class ContextWindowExceeded(StepException):
    """Signals that the context window was exceeded.

    Carries a finished StepResult with reason "CONTEXT_WINDOW_EXCEEDED" and
    no result payload.
    """

    def __init__(self):
        outcome = StepResult.finished("CONTEXT_WINDOW_EXCEEDED")
        super().__init__(outcome)


class ParseError(StepException):
    """Signals that parsing a tool call failed.

    Carries a continuing StepResult (not done, no result payload).
    """

    def __init__(self):
        outcome = StepResult.continuing()
        super().__init__(outcome)


class NoToolCall(StepException):
    """Signals that no tool call was detected.

    Carries a continuing StepResult (not done, no result payload).
    """

    def __init__(self):
        outcome = StepResult.continuing()
        super().__init__(outcome)


class ToolExecutionFailed(StepException):
    """Signals that executing a tool failed.

    Carries a continuing StepResult (not done, no result payload).
    """

    def __init__(self):
        outcome = StepResult.continuing()
        super().__init__(outcome)
