import re
import numpy as np
import importlib
import os
import glob

def convert_reward(reward_dict):
    '''
    Convert a per-agent reward dict into a list of per-agent reward lists.

    param: reward_dict (dict): dict containing rewards for each agent; iterated
        in insertion order.
    return: list: one list per agent. ``np.ndarray`` rewards are converted with
        ``tolist()`` (a 0-d array becomes a one-element list); any other value
        is wrapped in a one-element list.
    '''
    rewards = []
    for reward in reward_dict.values():
        if isinstance(reward, np.ndarray):
            # A 0-d array's .tolist() returns a bare scalar; atleast_1d keeps the
            # "every entry is a list" contract that callers index with [0].
            rewards.append(np.atleast_1d(reward).tolist())
        else:
            rewards.append([reward])
    return rewards

def dict2array(dict):
    '''
    Stack the values of a per-agent dict into a single column vector.

    param: dict (dict): dict containing rewards for each agent. (The name
        shadows the builtin but is kept so keyword callers keep working.)
    return: array: ``np.ndarray`` of shape (k, 1), where k is the total number of
        elements across all values (n for scalar values), in dict order.
    '''
    column_values = list(dict.values())
    stacked = np.array(column_values)
    return stacked.reshape(-1, 1)

def convert_observation(observation):
    '''
    Key a sequence of per-agent observations by agent id.

    param: observation (sized iterable): one observation per agent, in agent order.
    return: dict: ``{"agent_0": observation[0], "agent_1": observation[1], ...}``.
    '''
    ids = (f"agent_{idx}" for idx in range(len(observation)))
    return dict(zip(ids, observation))


def clean_obs_code(obs_code_str):
    '''
    Strip top-level import statements from a snippet of Python source.

    param: obs_code_str (str): source code (e.g. generated observation code).
    return: str: the code with every column-0 ``import ...`` / ``from ...``
        statement removed and leading/trailing whitespace stripped.

    Parenthesised multi-line imports such as ``from x import (`` followed by
    names on later lines and a closing ``)`` are removed as a whole. Indented
    imports (e.g. inside a function body) are left untouched. Backslash-continued
    imports are still only removed up to the end of their first line.
    '''
    cleaned_code = re.sub(
        # import/from keyword at column 0, then the rest of the line; if a "("
        # appears before any comment, also consume up to the matching ")" so
        # the continuation lines of a multi-line import go with it.
        r'^(?:import |from )[^\n(#]*(?:\([^)]*\))?[^\n]*',
        '',
        obs_code_str,
        flags=re.MULTILINE,
    ).strip()
    return cleaned_code

def process_actions(actions):
    '''
    Map each agent's raw action to its coarse action-class label.

    param: actions (iterable, 1xn): one action per agent, in agent order.
    return: dict: ``{"agent_<i>": action_class(actions[i])}`` using the
        default (non-advanced) classification.
    '''
    return {f"agent_{idx}": action_class(act) for idx, act in enumerate(actions)}

def process_available_actions(available_actions):
    '''
    Turn per-agent availability masks into lists of available action indices.

    param:
        available_actions (list of n list): availability mask for each agent;
            an entry equal to 1 marks that action index as available.
    return:
        action_dict (Dict of n list): ``{"agent_<i>": [indices where mask == 1]}``.
    '''
    def _available_indices(mask):
        # np.where(...)[0] gives the positions along the first axis where mask == 1.
        return np.where(np.array(mask) == 1)[0].tolist()

    return {
        f"agent_{idx}": _available_indices(mask)
        for idx, mask in enumerate(available_actions)
    }

def get_gencode_path(env_name):
    '''
    Return the path of the newest generated-code file for ``env_name``.

    Looks in ``<directory of this module>/gen_code/<env_name>/code`` for files
    named ``*_generated_code_<N>.py`` and returns the one with the largest N.

    param: env_name (str): environment sub-directory name under ``gen_code``.
    return: str: full path to the selected file.
    raises:
        FileNotFoundError: if the directory does not exist or holds no
            matching file.
    '''
    prompt_dir = os.path.dirname(os.path.abspath(__file__))
    gencode_dir = os.path.join(prompt_dir, "gen_code", env_name, "code")
    # Choose the highest version explicitly instead of guessing
    # len(listdir) - 1: any stray file (e.g. __pycache__, notes) or a gap in
    # numbering made that guess point at a file that does not exist.
    best_name, best_version = None, -1
    for name in sorted(os.listdir(gencode_dir)):
        if name.startswith("."):
            # glob's "*" never matched hidden files; keep ignoring them.
            continue
        match = re.fullmatch(r".*_generated_code_(\d+)\.py", name)
        if match is None:
            continue
        version = int(match.group(1))
        if version > best_version:
            best_name, best_version = name, version
    if best_name is None:
        raise FileNotFoundError(
            f"No '*_generated_code_<N>.py' file found in {gencode_dir}"
        )
    return os.path.join(gencode_dir, best_name)

def setup_wrapper(env, gencode_path, reward_mode):
    '''
    Load generated code from a file and install its functions on ``env``.

    The file is executed in a namespace seeded with a copy of this module's
    globals, so the generated code can use the helpers and imports defined
    here. It must define ``planning_function`` and ``compute_reward``; both are
    passed, together with ``reward_mode``, to ``env.set_func``.

    param: env: object exposing ``set_func(planning_function, compute_reward, reward_mode)``.
    param: gencode_path (str): path of the generated Python file.
    param: reward_mode: forwarded unchanged to ``env.set_func``.
    return: the same ``env`` object.
    raises:
        KeyError: if the generated code does not define a required function.
    '''
    # Python source is UTF-8 by default (PEP 3120); "utf-8-sig" also drops a
    # leading BOM, which compile() would otherwise reject.
    with open(gencode_path, 'r', encoding='utf-8-sig') as f:
        code_str = f.read()
    namespace = {**globals()}
    # SECURITY NOTE(review): exec runs arbitrary code from gencode_path with the
    # full privileges of this process; only load files you trust.
    # Compiling with the real filename makes tracebacks point at the generated
    # file instead of "<string>".
    exec(compile(code_str, gencode_path, 'exec'), namespace)
    missing = [
        name for name in ('planning_function', 'compute_reward')
        if name not in namespace
    ]
    if missing:
        raise KeyError(f"{gencode_path} does not define: {', '.join(missing)}")
    env.set_func(namespace['planning_function'], namespace['compute_reward'], reward_mode)
    return env

def action_class(action, advanded=False):
    '''
    Classify a discrete action id into a human-readable action class.

    param: action (int): action id to classify.
    param: advanded (bool): if truthy, distinguish the four move directions
        (ids 2-5); otherwise any move is reported as "Move". (The misspelled
        name is kept for keyword compatibility.)
    return: str: "None" (0), "Stop" (1), a move label (2-5), or "Attack" (>= 6).
    raises: ValueError: for ids that fall in none of those ranges (e.g. negative).
    '''
    if action == 0:
        return "None"
    if action == 1:
        return "Stop"
    if action >= 6:
        return "Attack"

    if not advanded:
        if 2 <= action <= 5:
            return "Move"
        raise ValueError("Invalid action")

    directions = (
        (2, "Move North"),
        (3, "Move South"),
        (4, "Move West"),
        (5, "Move East"),
    )
    for code, label in directions:
        if action == code:
            return label
    raise ValueError("Invalid action")

def import_function(module_name, func_name):
    '''
    Import ``module_name`` and return its attribute ``func_name``.

    param: module_name (str): dotted module path passed to ``importlib.import_module``.
    param: func_name (str): attribute to fetch from the imported module.
    return: the attribute, or None if the module cannot be imported or does not
        define ``func_name`` (the reason is printed in both cases).
    '''
    try:
        module = importlib.import_module(module_name)
    except ImportError as e:
        print(f"Error importing {module_name}: {e}")
        return None
    # A missing attribute previously escaped as AttributeError even though the
    # function signals failure with None; treat it the same way as a failed import.
    try:
        return getattr(module, func_name)
    except AttributeError as e:
        print(f"Error importing {func_name} from {module_name}: {e}")
        return None
    
def constant_reward_signal(action_class, llm_action_class, reward_lst, llm_reward=0.01, penalty=0.01):
    '''
    Shape rewards by whether each agent followed the LLM-suggested action class.

    param: action_class (dict): agent id -> action class actually taken.
    param: llm_action_class (dict): agent id -> action class suggested by the
        LLM; its iteration order must line up with ``reward_lst``.
    param: reward_lst (list): per-agent rewards; only ``reward_lst[i][0]`` is used.
    param: llm_reward (float): bonus added when the agent followed the suggestion.
    param: penalty (float): amount subtracted when it did not.
    return: list: ``[[shaped_reward], ...]``, one single-element list per agent.
    '''
    def _shape(base, followed):
        return [base + llm_reward] if followed else [base - penalty]

    return [
        _shape(reward_lst[idx][0], action_class[agent] == suggested)
        for idx, (agent, suggested) in enumerate(llm_action_class.items())
    ]

def normalized_reward(reward, theta=0.01):
    '''
    Min-max scale per-agent rewards into [0, theta], in place.

    param: reward (dict): agent id -> numeric reward. Mutated in place.
    param: theta (float): upper end of the target range.
    return: dict: the same ``reward`` object.

    NOTE(review): when every reward is equal, the dict is returned unchanged
    (raw, unscaled values) rather than mapped into [0, theta]. An empty dict
    raises ValueError from min(). Confirm both are intended.
    '''
    lowest = min(reward.values())
    highest = max(reward.values())
    if lowest == highest:
        return reward
    reward.update({
        agent_id: (value - lowest) / (highest - lowest) * theta
        for agent_id, value in reward.items()
    })
    return reward

