import numpy as np
from typing import Union, Literal, List, Tuple

from torch_robotics.environments.env_base import EnvBase
from torch_robotics.environments import *

def collision_detect(
    trajs: np.ndarray, 
    safe_dist: float, 
    norm_order: Union[int, float, str] = 2, 
)->np.ndarray: 
    """
    Check pairwise collisions for batched multi-player trajectories.

    Two distinct players are considered colliding at a time step when the
    norm of the difference of their full state vectors (all D components)
    is strictly smaller than ``safe_dist``.

    Args:
        trajs (np.ndarray): A batch of trajectories with shape (..., player, horizon, D), 
            where D is the dimension of the state space. 
        safe_dist (float): The safe distance threshold for collision detection.
        norm_order (Union[int, float, str]): The order of the vector norm used for the
            distance. Any numeric order accepted by ``np.linalg.norm`` for vectors, or
            one of the strings 'inf' / '-inf' (case-insensitive).

    Returns: 
        np.ndarray: A boolean array of shape (..., player, player, horizon) with the
            pair-wise collision mask at every horizon step. Entries on the player
            diagonal (a player against itself) are always False.

    Raises:
        AssertionError: If ``trajs`` has fewer than 3 dimensions or fewer than 2 players.
        ValueError: If the state dimension is empty or ``norm_order`` is an
            unsupported string.
    """

    # Accept anything array-like (nested lists, etc.)
    trajs = np.asarray(trajs)

    assert trajs.ndim >= 3, "Trajectory must have at least 3 dimensions (player, horizon and state)." 
    # The state dimension must hold at least one coordinate
    if trajs.shape[-1] < 1:
        raise ValueError("Trajectory must have at least 1 dimensions for collision detection.")
    assert trajs.shape[-3] >= 2, "There must be at least 2 players in the trajectory array for collision detection."

    # np.linalg.norm rejects every string order for vectors, so translate the
    # documented string spellings of infinity into their numeric equivalents.
    if isinstance(norm_order, str):
        named_orders = {"inf": np.inf, "+inf": np.inf, "-inf": -np.inf}
        order_key = norm_order.strip().lower()
        if order_key not in named_orders:
            raise ValueError(f"Invalid norm order '{norm_order}' for vectors")
        norm_order = named_orders[order_key]

    n_players = trajs.shape[-3]
    horizon = trajs.shape[-2]

    # Insert singleton player axes so that subtraction broadcasts every player
    # against every other player:
    #   (..., player, 1, horizon, D) - (..., 1, player, horizon, D)
    traj1 = trajs[..., :, np.newaxis, :, :]
    traj2 = trajs[..., np.newaxis, :, :, :]

    # Pairwise distances, shape (..., player, player, horizon)
    distances = np.linalg.norm(traj1 - traj2, ord=norm_order, axis=-1)

    # A collision is a distance strictly below the safe distance threshold
    collision_masks = distances < safe_dist

    # A player is always at distance 0 from itself; never report that as a collision
    diag = np.arange(n_players)
    collision_masks[..., diag, diag, :] = False

    # Sanity check
    expected_shape = trajs.shape[:-3] + (n_players, n_players, horizon)
    assert collision_masks.shape == expected_shape, (
        f"Expected collision mask shape {expected_shape}, got {collision_masks.shape}"
    )
    
    return collision_masks


def calc_data_adherence(
    trajs: np.ndarray, 
    agent_model_ids_l: List[List[str]], 
    agent_model_transforms_l: List[List], 
    start_time_l: List[int],
    horizon: int, 
    tensor_args: dict,
) -> np.ndarray:
    """
    Calculate the data adherence of every agent in a batch of trajectories.

    For each agent, the trajectory is cut into consecutive windows of length
    ``horizon`` (one window per model id of that agent, starting at the agent's
    start time). The first two state components of each window are shifted by
    the matching model transform and scored by the environment named in the
    model id (the part before the first '-'). The agent's score is the mean
    over its windows, i.e. for a single map it is the adherence there and for
    a multi-tile map it is the average adherence over all tiles.

    Args:
        trajs (np.ndarray): A batch of trajectories with shape (..., n_agents, horizon, dimension). 
        agent_model_ids_l (List[List[str]]): For each agent, the list of model IDs
            (one per window), formatted as "<EnvName>-<suffix>".
        agent_model_transforms_l (List[List]): For each agent, the transform subtracted
            from the trajectory window of the corresponding model.
        start_time_l (List[int]): List of start times (index of the first step) for each agent.
        horizon (int): The length of a single window (single-tile trajectory).
        tensor_args (dict): Arguments for tensor creation, e.g., device and dtype.

    Returns:
        np.ndarray: Data adherence scores for each agent with shape (..., n_agents).

    Raises:
        ValueError: If ``trajs`` has fewer than 3 dimensions.
        ZeroDivisionError: If an agent has an empty list of model IDs.
        AssertionError: If an environment lacks `compute_traj_data_adherence_batch`.
    """
    # Imported here because the file does not import torch explicitly at module level.
    import torch

    trajs = np.asarray(trajs)
    if trajs.ndim < 3:
        raise ValueError(
            "Trajectory must have at least 3 dimensions (n_agents, horizon and dimension)."
        )
    n_agents = trajs.shape[-3]

    # One score per agent for every leading batch index
    data_adherence = torch.zeros(*trajs.shape[:-2], **tensor_args)
    traj_tensor = torch.from_numpy(trajs).to(**tensor_args)
    env_kwargs = {'tensor_args': tensor_args}

    for agent_id in range(n_agents):
        agent_model_ids = agent_model_ids_l[agent_id]
        agent_model_transforms = agent_model_transforms_l[agent_id]
        agent_start_time = start_time_l[agent_id]
        agent_traj = traj_tensor[..., agent_id, :, :]

        agent_data_adherence = 0.0
        for skeleton_step, agent_model_id in enumerate(agent_model_ids):
            # Time window of this model (tile) along the horizon axis
            window_start = agent_start_time + skeleton_step * horizon
            window_end = window_start + horizon
            # The subtraction allocates a new tensor, so the source trajectory is untouched
            agent_path_in_model_frame = (
                agent_traj[..., window_start:window_end, :2]
                - agent_model_transforms[skeleton_step]
            )

            model_env_name = agent_model_id.split('-')[0]
            # NOTE(security): eval() executes arbitrary code. Model IDs must come
            # from trusted configuration only; never pass user-controlled strings.
            env_object = eval(model_env_name)(**env_kwargs)
            assert hasattr(env_object, 'compute_traj_data_adherence_batch'), \
                "Environment object must implement `compute_traj_data_adherence_batch` method."
            agent_data_adherence += env_object.compute_traj_data_adherence_batch(agent_path_in_model_frame)

        # Average over all models (tiles) of this agent
        agent_data_adherence /= len(agent_model_ids)
        data_adherence[..., agent_id] = agent_data_adherence
    #end for [agent_id]

    return data_adherence.to('cpu').numpy()


def calc_velocity_vec(
    trajs: np.ndarray, 
    dt: float, 
) -> np.ndarray:
    """
    Compute finite-difference velocity vectors along the horizon axis.

    Args:
        trajs (np.ndarray): Trajectories shaped (..., horizon, D); any array-like is accepted.
        dt (float): Duration of one time step.

    Returns:
        np.ndarray: Per-step velocity vectors shaped (..., horizon-1, D), i.e. the
            difference of consecutive states divided by ``dt``.
    """
    states = np.asarray(trajs)
    assert states.ndim >= 2, "Trajectory must have at least 2 dimensions (horizon and state)."

    # First-order difference between consecutive states (second-to-last axis is time)
    step_deltas = np.diff(states, n=1, axis=-2)
    return step_deltas / dt


def calc_velocity(
    trajs: np.ndarray, 
    dt: float, 
    norm_kwargs: Union[dict, None] = None
) -> np.ndarray:
    """
    Calculate the speed (norm of the velocity vector) of the trajectories.

    Args:
        trajs (np.ndarray): A batch of trajectories with shape (..., horizon, D).
        dt (float): The time step duration.
        norm_kwargs (Union[dict, None]): Optional keyword arguments forwarded to
            ``np.linalg.norm``; they override the defaults ``ord=2``, ``axis=-1``
            and ``keepdims=False``. ``None`` (the default) uses the defaults only.

    Returns:
        np.ndarray: The velocity magnitudes; with the default norm arguments the
            shape is (..., horizon-1).
    """
    # Ensure trajs is a numpy array
    trajs = np.asarray(trajs)
    assert trajs.ndim >= 2, "Trajectory must have at least 2 dimensions (horizon and state)."

    # Defaults: L2 norm over the state axis without keeping that axis;
    # caller-supplied entries take precedence.
    resolved_norm_kwargs = {
        "ord": 2,
        "axis": -1,
        "keepdims": False,
        **(norm_kwargs or {}),
    }

    v_vec = calc_velocity_vec(trajs, dt)
    return np.linalg.norm(v_vec, **resolved_norm_kwargs)
