"""
Parameter Transforms for Cross-Robot Skill Mapping

Handles N:M parameter transformations between skills with different abstraction levels.
These are special cases that cannot be handled by simple name/semantic matching.

Example:
    align_two_axes(local_axes=('z','y'), world_axes=('y','x'), axis_dirs=(1,-1))
    -> sawyer_align_gripper(approach_direction='front', yaw_mode='perpendicular')
"""

import numpy as np
from typing import Dict, Any, Optional, Tuple, List, Callable


# Registry of special parameter-transform functions.
#   key:   (source_skill, target_skill) pair of skill names
#   value: callable invoked by apply_special_transform as
#          func(source_params, source_code_context) -> target params dict
PARAMETER_TRANSFORMS: Dict[Tuple[str, str], Callable] = dict()


def register_transform(source_skill: str, target_skill: str):
    """
    Build a decorator that records a transform in PARAMETER_TRANSFORMS.

    The decorated function is stored under (source_skill, target_skill)
    and returned unchanged. Registering the same pair again replaces the
    earlier entry.
    """
    pair_key = (source_skill, target_skill)

    def _register(transform_fn: Callable):
        PARAMETER_TRANSFORMS[pair_key] = transform_fn
        return transform_fn

    return _register


def get_transform(source_skill: str, target_skill: str) -> Optional[Callable]:
    """Return the transform registered for the skill pair, or None."""
    try:
        return PARAMETER_TRANSFORMS[(source_skill, target_skill)]
    except KeyError:
        return None


def has_transform(source_skill: str, target_skill: str) -> bool:
    """Report whether a transform is registered for the skill pair."""
    pair_key = (source_skill, target_skill)
    return pair_key in PARAMETER_TRANSFORMS


# =============================================================================
# Transform: align_two_axes -> sawyer_align_gripper / ur5_align_gripper
# =============================================================================

def _axes_to_approach_direction(local_axes: Tuple[str, str],
                                 world_axes: Tuple[str, str],
                                 axis_dirs: Tuple[int, int]) -> str:
    """
    Convert low-level axis specification to high-level approach direction.

    align_two_axes aligns gripper local axes to world axes:
    - local_axes[0] is the primary axis (gripper finger direction)
    - world_axes[0] is which world axis to align to
    - axis_dirs[0] is the direction (+1 or -1)

    Common patterns:
    - local='z', world='z', dir=-1 -> 'down' (gripper pointing down)
    - local='z', world='z', dir=+1 -> 'up' (gripper pointing up)
    - local='z', world='y', dir=+1 -> 'front' (gripper pointing forward)
    - local='z', world='y', dir=-1 -> 'back' (gripper pointing backward)
    - local='z', world='x', dir=+1 -> 'right' (gripper pointing right)
    - local='z', world='x', dir=-1 -> 'left' (gripper pointing left)
    """
    primary_local = local_axes[0].lower()
    primary_world = world_axes[0].lower()
    primary_dir = axis_dirs[0]

    # Map based on which world axis the gripper z is aligned to
    if primary_local == 'z':
        if primary_world == 'z':
            return 'down' if primary_dir == -1 else 'up'
        elif primary_world == 'y':
            return 'front' if primary_dir == 1 else 'back'
        elif primary_world == 'x':
            return 'right' if primary_dir == 1 else 'left'

    # Default fallback
    return 'down'


def _axes_to_yaw_mode(local_axes: Tuple[str, str],
                      world_axes: Tuple[str, str],
                      axis_dirs: Tuple[int, int]) -> str:
    """
    Determine yaw mode from axis alignment.

    The secondary axis (local_axes[1]) determines the yaw orientation:
    - If secondary local axis aligns parallel to a world axis -> 'parallel'
    - If secondary local axis is perpendicular -> 'perpendicular'

    This is a simplification - actual logic may be more complex.
    """
    if len(local_axes) < 2 or len(world_axes) < 2:
        return 'parallel'  # Default

    secondary_local = local_axes[1].lower()
    secondary_world = world_axes[1].lower()
    secondary_dir = axis_dirs[1] if len(axis_dirs) > 1 else 1

    # If secondary axes are the same (e.g., y->y), it's parallel
    # If different (e.g., y->x), it's perpendicular-ish
    if secondary_local == secondary_world:
        return 'parallel'
    else:
        return 'perpendicular'


@register_transform('align_two_axes', 'sawyer_align_gripper')
def transform_align_two_axes_to_sawyer(
    source_params: Dict[str, Any],
    source_code_context: Optional[str] = None
) -> Dict[str, Any]:
    """
    Map align_two_axes call arguments onto sawyer_align_gripper arguments.

    Keys read from ``source_params`` (default used when the key is absent):
        local_axes -> default ('z', 'y')
        world_axes -> default ('z', 'x')
        axis_dirs  -> default (-1, 1)
        timeout    -> default 10.0, forwarded as timeout_s
    Every other key (e.g. tol_rad) is ignored. ``source_code_context`` is
    not used by this transform.

    Returns a dict with keys in this order:
        approach_direction: result of _axes_to_approach_direction
        reference_quat: always None
        yaw_mode: result of _axes_to_yaw_mode
        timeout_s: the source timeout value
    """
    lookup = source_params.get
    axes_spec = (
        lookup('local_axes', ('z', 'y')),
        lookup('world_axes', ('z', 'x')),
        lookup('axis_dirs', (-1, 1)),
    )
    timeout_s = lookup('timeout', 10.0)

    direction = _axes_to_approach_direction(*axes_spec)
    yaw = _axes_to_yaw_mode(*axes_spec)

    return dict(
        approach_direction=direction,
        reference_quat=None,  # Will be determined at runtime if needed
        yaw_mode=yaw,
        timeout_s=timeout_s,
    )


@register_transform('align_two_axes', 'ur5_align_gripper')
def transform_align_two_axes_to_ur5(
    source_params: Dict[str, Any],
    source_code_context: Optional[str] = None
) -> Dict[str, Any]:
    """
    Map align_two_axes call arguments onto ur5_align_gripper arguments.

    ur5_align_gripper is assumed to share sawyer_align_gripper's keyword
    interface, so this delegates entirely to
    transform_align_two_axes_to_sawyer and returns its result as-is.
    """
    ur5_params = transform_align_two_axes_to_sawyer(
        source_params, source_code_context
    )
    return ur5_params


# =============================================================================
# Helper function for online guidance generation
# =============================================================================

def apply_special_transform(
    source_skill: str,
    target_skill: str,
    source_params: Dict[str, Any],
    source_code_context: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """
    Run the registered transform for (source_skill, target_skill), if any.

    Args:
        source_skill: Source skill name (e.g., 'align_two_axes')
        target_skill: Target skill name (e.g., 'sawyer_align_gripper')
        source_params: Dict of source parameter values from callsite
        source_code_context: Optional code context, passed through to the transform

    Returns:
        The transform's output dict. Returns None when no transform is
        registered for the pair, or when the transform raised. In the
        latter case the error is printed to stdout.
    """
    transform_func = get_transform(source_skill, target_skill)
    if transform_func is not None:
        try:
            return transform_func(source_params, source_code_context)
        except Exception as e:
            # Best-effort: a failing transform is reported, then treated
            # the same as a missing one (both paths return None).
            print(f"[parameter_transforms] Error applying transform {source_skill}->{target_skill}: {e}")
    return None


def generate_transformed_interface(
    source_skill: str,
    target_skill: str,
    source_params: Dict[str, Any]
) -> Optional[str]:
    """
    Render the transformed target-skill call as a source-code string.

    Applies the registered transform for (source_skill, target_skill) via
    apply_special_transform. The result is formatted as a call whose
    leading positional arguments are ``env`` and ``task``, e.g.:
        sawyer_align_gripper(env, task, approach_direction='front', yaw_mode='parallel', timeout_s=10.0)

    Parameters whose transformed value is None are omitted. String values
    are rendered with repr(), so quotes inside a value still yield valid
    Python. Other values use their default string formatting.

    Returns:
        The call string, or None when apply_special_transform returns None
        (no transform registered, or the transform failed).
    """
    transformed = apply_special_transform(source_skill, target_skill, source_params)
    if transformed is None:
        return None

    arg_parts = ["env", "task"]
    for key, value in transformed.items():
        if value is None:
            continue
        if isinstance(value, str):
            # repr() quotes and escapes, unlike a raw '...' wrapper which
            # breaks on values containing a single quote.
            arg_parts.append(f"{key}={value!r}")
        else:
            arg_parts.append(f"{key}={value}")

    # Joining env/task with the params avoids a dangling ", " when every
    # transformed value was None.
    return f"{target_skill}({', '.join(arg_parts)})"


# =============================================================================
# Test
# =============================================================================

if __name__ == "__main__":
    # Manual smoke check: run align_two_axes -> sawyer_align_gripper on a few
    # representative inputs and print the transformed params and call string.
    src_skill = 'align_two_axes'
    dst_skill = 'sawyer_align_gripper'

    demo_inputs = [
        # Case 1: CloseBox - gripper pointing forward to push lid
        dict(local_axes=('z', 'y'), world_axes=('y', 'x'),
             axis_dirs=(1, -1), timeout=20.0),
        # Case 2: Standard top-down grasp
        dict(local_axes=('z', 'y'), world_axes=('z', 'y'),
             axis_dirs=(-1, 1), timeout=10.0),
        # Case 3: Side approach
        dict(local_axes=('z', 'y'), world_axes=('x', 'z'),
             axis_dirs=(1, 1), timeout=15.0),
    ]

    print("Testing align_two_axes -> sawyer_align_gripper transforms:")
    print("=" * 60)

    for case_no, case_params in enumerate(demo_inputs, start=1):
        print(f"\nTest case {case_no}:")
        print(f"  Input: {case_params}")

        mapped = apply_special_transform(src_skill, dst_skill, case_params)
        print(f"  Output: {mapped}")

        call_text = generate_transformed_interface(src_skill, dst_skill, case_params)
        print(f"  Interface: {call_text}")
