import numpy as np
from typing import Dict
from ray.rllib.policy.sample_batch import SampleBatch


from concept_extraction import (
    relative_orientation,
    distance_between,
    can_shoot_ordinal,
    agent_targeting_ordinal,
    distance_from_base,
    tom_extraction,
    distance_between_ordinal,
    relative_orientation_ordinal,
    attacker_stratagy,
)


# Dispatch table mapping a concept name (as it appears in a concept config's
# ``name`` attribute) to the extraction function that computes it. Every
# function listed here is invoked as ``func(concept_config, config, rollout)``.
# ``tom_extraction`` is intentionally not registered: it is called directly
# with a different argument list. The "stratagy" spelling must match the
# imported function and the names used in concept configs, so keep it as-is.
name_to_func = dict(
    relative_orientation=relative_orientation,
    distance_between=distance_between,
    can_shoot_ordinal=can_shoot_ordinal,
    agent_targeting_ordinal=agent_targeting_ordinal,
    distance_from_base=distance_from_base,
    distance_between_ordinal=distance_between_ordinal,
    relative_orientation_ordinal=relative_orientation_ordinal,
    attacker_stratagy=attacker_stratagy,
)


def compute_concepts(
    config: Dict, rollout: SampleBatch, concept_configs, other_agent_batches
):
    """Compute concept targets for a rollout and store them on the batch.

    Each concept config registered for the rollout's agent is dispatched
    through ``name_to_func`` in order. Each call returns the rollout, which
    is threaded into the next call. Theory-of-mind (ToM) targets for every
    modeled agent are then added via ``tom_extraction``. Finally,
    ``rollout["concept_targets"]`` is zero-padded on its last axis up to
    ``conceptdim + bottleneck`` so the tensor has a consistent width.

    Args:
        config: Trainer config. ``config["model"]["custom_model_config"]``
            must provide ``"conceptdim"`` and ``"bottleneck"``.
        rollout: The rollout to compute concepts from. It is read for
            ``"agent_index"`` and ``"obs"`` and written at
            ``"concept_targets"``.
        concept_configs: Per-agent concept configuration, keyed by the
            agent index as a string. Each entry exposes ``configs``,
            ``total_length`` and ``tom_configs``. If empty (or ``None``),
            ``rollout`` is returned unchanged.
        other_agent_batches: Batches of the other agents, forwarded
            unchanged to ``tom_extraction``.

    Returns:
        The postprocessed rollout with ``"concept_targets"`` populated.
    """
    if not concept_configs:
        return rollout

    agent_id = str(int(rollout["agent_index"][0]))
    agent_configs = concept_configs[agent_id]
    model_config = config["model"]["custom_model_config"]
    concept_layer_size = model_config["conceptdim"] + model_config["bottleneck"]

    # Start from an empty array. The concept functions presumably append
    # their own columns to it (TODO confirm against concept_extraction).
    rollout["concept_targets"] = np.array([])
    for concept_config in agent_configs.configs:
        extract = name_to_func[concept_config.name]
        rollout = extract(concept_config, config, rollout)

    total_length = agent_configs.total_length

    for modeled_agent_id, tom_config in agent_configs.tom_configs.items():
        tom_length = tom_config["total_length"]
        rollout = tom_extraction(
            modeled_agent_id, tom_length, rollout, other_agent_batches
        )
        total_length += tom_length

    # Pad to the full concept-layer width so every batch has the same shape.
    n_steps = rollout["obs"].shape[0]
    if total_length == 0:
        rollout["concept_targets"] = np.zeros((n_steps, concept_layer_size))
    elif 0 < total_length < concept_layer_size:
        extra_padding = np.zeros((n_steps, concept_layer_size - total_length))
        rollout["concept_targets"] = np.concatenate(
            [rollout["concept_targets"], extra_padding], axis=-1
        )
    # NOTE(review): when total_length exceeds concept_layer_size the targets
    # are left as-is (wider than the concept layer). Confirm this cannot occur.

    return rollout
