"""
Embedding functions for trajectory representation.

This module contains functions for creating different types of embeddings
for trajectories and computing related parameters.
"""

import numpy as np


def create_embeddings(phi_name, N_states, N_actions):
    """
    Build the trajectory embedding function selected by ``phi_name``.

    Args:
        phi_name: Embedding type; one of 'id_long', 'id_short',
            'state_counts' or 'final_state'.
        N_states: Number of states in environment
        N_actions: Number of actions in environment

    Returns:
        The embedding function produced by the matching
        ``_create_*_embedding`` factory.

    Raises:
        ValueError: If ``phi_name`` is not one of the supported types.
    """
    supported = ("id_long", "id_short", "state_counts", "final_state")
    if phi_name not in supported:
        raise ValueError(f"Unknown embedding type: {phi_name}")

    # Every factory shares the (N_states, N_actions) signature, so a plain
    # name -> factory table is enough to dispatch.
    factories = {
        "id_long": _create_id_long_embedding,
        "id_short": _create_id_short_embedding,
        "state_counts": _create_state_counts_embedding,
        "final_state": _create_final_state_embedding,
    }
    return factories[phi_name](N_states, N_actions)


def _create_id_long_embedding(N_states, N_actions):
    """Create the 'id_long' embedding: per-step one-hot encodings, concatenated.

    Args:
        N_states: Number of states; width of each state one-hot block.
        N_actions: Number of actions; width of each action one-hot block.

    Returns:
        A function ``phi_id_long(traj)`` with attribute ``name == "id_long"``.

        ``traj`` is a trajectory in the form [s0,a0,r0,s1,a1,r1,...]. Rewards
        are ignored. The result is a 1D float numpy array made of the one-hot
        encodings of s0, a0, s1, a1, ... in order. For a trajectory of
        ``episode_length`` full (s, a, r) triples it has length
        ``episode_length * (N_states + N_actions)`` and contains
        ``2 * episode_length`` ones (two per state-action pair), rest zeros.
        An empty trajectory yields an empty float array.
    """

    def phi_id_long(traj):
        # Collect the one-hot blocks and concatenate once at the end;
        # growing the result with np.append re-copies it on every step.
        blocks = []
        for position, entry in enumerate(traj):
            slot = position % 3
            if slot == 0:  # states at indices 0,3,6..
                width = N_states
            elif slot == 1:  # actions at indices 1,4,7..
                width = N_actions
            else:  # skip rewards at indices 2,5,8..
                continue
            one_hot = np.zeros(width)
            one_hot[entry] = 1
            blocks.append(one_hot)

        if not blocks:
            # np.concatenate rejects an empty list; return an empty float
            # vector so the dtype matches the non-empty case.
            return np.zeros(0)
        return np.concatenate(blocks)

    phi_id_long.name = "id_long"
    return phi_id_long


def _create_id_short_embedding(N_states, N_actions):
    """
    Create the 'id_short' embedding: summed one-hot encodings, ignoring rewards.

    Args:
        N_states: Number of states; length of the state-count part.
        N_actions: Number of actions; length of the action-count part.

    Returns:
        A function ``phi_id_short(traj)`` with attribute ``name == "id_short"``.

        ``traj`` is a list-like trajectory [s0,a0,r0,s1,a1,r1,...] with s_t in
        {0,...,N_states-1} and a_t in {0,...,N_actions-1}. The result is a 1D
        numpy array of length N_states + N_actions: visit counts per state
        followed by usage counts per action, i.e. the sum over time steps of
        [one-hot(s_t), one-hot(a_t)].

        For H state-action pairs the 2-norm is at most sqrt(2)*H, attained
        when a single state and a single action are repeated H times:
        ||(H,0,...,0,H,0,...,0)||_2 = sqrt(2*H^2).
    """

    def phi_id_short(traj):
        state_hist = np.zeros(N_states)
        action_hist = np.zeros(N_actions)

        for position, entry in enumerate(traj):
            slot = position % 3
            if slot == 2:
                # rewards sit at indices 2, 5, 8, .. and are not embedded
                continue
            # slot 0 -> state (indices 0,3,6..), slot 1 -> action (1,4,7..)
            histogram = state_hist if slot == 0 else action_hist
            histogram[entry] += 1

        return np.concatenate([state_hist, action_hist])

    phi_id_short.name = "id_short"
    return phi_id_short


def _create_state_counts_embedding(N_states, N_actions):
    """
    Create the 'state_counts' embedding: how often each state is visited.

    Args:
        N_states: Number of states; length of the returned vector.
        N_actions: Number of actions. Unused here; accepted so that all
            embedding factories share the same signature.

    Returns:
        A function ``phi_state_counts(traj)`` with attribute
        ``name == "state_counts"``.

        ``traj`` is a trajectory in the form [s0,a0,r0,s1,a1,r1,...]. The
        result is a 1D numpy array of length N_states whose i-th entry is the
        number of times state i appears in the trajectory. The counts are raw
        (not normalized to frequencies).
    """

    def phi_state_counts(traj):
        visits = np.zeros(N_states)
        for position, entry in enumerate(traj):
            if position % 3 != 0:
                # only indices 0,3,6... hold states
                continue
            visits[entry] += 1
        return visits

    phi_state_counts.name = "state_counts"
    return phi_state_counts


def _create_final_state_embedding(N_states, N_actions):
    """
    Create the 'final_state' embedding: one-hot encoding of the last state.

    Args:
        N_states: Number of states; length of the returned vector.
        N_actions: Number of actions. Unused here; accepted so that all
            embedding factories share the same signature.

    Returns:
        A function ``phi_final_state(traj)`` with attribute
        ``name == "final_state"``. It returns a 1D numpy array of length
        N_states with a single 1 at the index ``traj[-3]``.

        The trajectory is expected to end with a full (state, action, reward)
        triple, so the last state sits third from the end. A trajectory with
        fewer than three entries raises IndexError.
    """

    def phi_final_state(traj):
        # traj ends with [..., s_H, a_H, r_H], so s_H is third from the end
        last_state = traj[-3]
        one_hot = np.zeros(N_states)
        one_hot[last_state] = 1
        return one_hot

    phi_final_state.name = "final_state"
    return phi_final_state
