from .StarCraft2_Env import StarCraft2Env
# from .multiagentenv import MultiAgentEnv
import numpy as np
# import os
from onpolicy.prompts.util import *
import logging
# import traceback

# set up a custom logger
def get_logger():
    """Configure the root logger for console output and return it.

    The root logger (``logging.getLogger()`` with no name) is reset to a
    single ``StreamHandler`` using a ``[LEVEL HH:MM:SS] name message``
    format, its level is set to ``DEBUG``, and the ``CRITICAL`` level name
    is wrapped in ANSI green so critical messages stand out in a terminal.

    The function is idempotent: calling it repeatedly leaves exactly one
    handler installed and the ``CRITICAL`` level name colored exactly once.

    Returns:
        logging.Logger: the (re)configured root logger.
    """
    logger = logging.getLogger()
    # Replace any previously installed handlers so repeated calls (or other
    # libraries' handlers) do not produce duplicated log lines.
    logger.handlers = []

    ch = logging.StreamHandler()
    formatter = logging.Formatter(
        "[%(levelname)s %(asctime)s] %(name)s %(message)s", "%H:%M:%S"
    )
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    logger.setLevel("DEBUG")

    color = "\033[92m"
    color_end = "\033[0m"
    # Use the literal level name here.  Reading it back through
    # logging.getLevelName() would return the already-colored name on a
    # second call and nest the escape sequences one level deeper each time.
    logging.addLevelName(logging.CRITICAL, f"{color}CRITICAL{color_end}")

    return logger

# Module-wide logger.  get_logger() reconfigures the *root* logger, so simply
# importing this module replaces the root handlers and sets its level to DEBUG.
logger = get_logger()

class SMAC(StarCraft2Env):
    """StarCraft2 environment whose step rewards can be reshaped by injected functions.

    Until :meth:`set_func` is called, ``use_llm`` is ``False`` and
    :meth:`step` passes the parent environment's rewards through unchanged.
    Note that :meth:`step` returns six values, dropping the two "old"
    observation/state values that the parent ``step`` returns first.
    """

    def __init__(self, all_args):
        """Build the parent environment with ``obs_last_action=False`` and shaping disabled."""
        super().__init__(all_args, obs_last_action=False)
        self.use_llm = False
        self.reward_mode = None

    def planning_function(self, processed_global_state):
        """Placeholder planner; replaced per instance by :meth:`set_func`.

        Raises:
            NotImplementedError: always, because no planner has been injected.
        """
        # Raise explicitly rather than ``assert (0)``: assertions are removed
        # under ``python -O``, which would make this silently return None.
        raise NotImplementedError(
            "planning_function has not been set; call set_func() first")

    def compute_reward(self, processed_global_state, task_plan, tasks):
        """Placeholder reward function; replaced per instance by :meth:`set_func`.

        Raises:
            NotImplementedError: always, because no reward function has been injected.
        """
        raise NotImplementedError(
            "compute_reward has not been set; call set_func() first")

    def set_func(self, planning_function, compute_reward, reward_mode):
        """Install the planner and reward callables and enable reward shaping.

        Args:
            planning_function: callable invoked as
                ``planning_function(processed_global_state)``.
            compute_reward: callable invoked as
                ``compute_reward(processed_global_state, task_plan, processed_actions)``.
            reward_mode: one of ``"pure"``, ``"mixed_normalized"`` or
                ``"mixed_constant"``; any other value makes :meth:`step`
                raise ``NotImplementedError``.
        """
        # Instance attributes shadow the placeholder methods of the same name.
        self.planning_function = planning_function
        self.compute_reward = compute_reward
        self.use_llm = True
        self.reward_mode = reward_mode
        logger.critical(f"Planning function and compute reward function({self.reward_mode} mode) set")

    def step(self, actions):
        """Advance the parent environment one step, optionally reshaping rewards.

        Args:
            actions: per-agent actions, forwarded unchanged to the parent ``step``.

        Returns:
            tuple: ``(local_obs, global_state, rewards, dones, infos,
            available_actions)``.  ``rewards`` is the parent's reward unless
            :meth:`set_func` has been called, in which case it is derived
            according to ``reward_mode``.

        Raises:
            NotImplementedError: if shaping is enabled and ``reward_mode`` is
                not a supported mode.
        """
        (
            _old_local_obs,
            old_global_state,
            local_obs,
            global_state,
            rewards,
            dones,
            infos,
            available_actions,
        ) = super().step(actions)

        if self.use_llm:
            rewards = self._llm_shaped_rewards(actions, old_global_state, rewards)

        return local_obs, global_state, rewards, dones, infos, available_actions

    def _llm_shaped_rewards(self, actions, old_global_state, rewards):
        """Return the rewards for this step according to ``self.reward_mode``.

        The pre-step global state is processed by the map-specific
        ``process_global_state`` function, handed to the planner to obtain a
        task plan, and that plan together with the processed actions drives
        the reward computation.

        NOTE(review): ``import_function``, ``process_actions``,
        ``convert_reward``, ``normalized_reward`` and
        ``constant_reward_signal`` are presumably supplied by the
        ``onpolicy.prompts.util`` star import -- confirm.
        """
        process_global_state = import_function(
            f"onpolicy.prompts.env_code.processed_obs_{self.map_name}", "process_global_state")
        processed_global_state = process_global_state(
            old_global_state, n=self.n_agents, m=self.n_enemies)
        llm_task = self.planning_function(processed_global_state)
        processed_actions = process_actions(actions)

        mode = self.reward_mode
        if mode == "pure":
            # LLM reward only: the environment reward is discarded and the
            # LLM reward is converted but not normalized.
            llm_rewards = self.compute_reward(processed_global_state, llm_task, processed_actions)
            llm_rewards = convert_reward(llm_rewards)
            return [[llm[0]] for llm in llm_rewards]

        if mode == "mixed_normalized":
            # Environment reward plus the normalized LLM reward, summed per agent.
            llm_rewards = self.compute_reward(processed_global_state, llm_task, processed_actions)
            llm_rewards = convert_reward(normalized_reward(llm_rewards))
            return [[llm[0] + env[0]] for env, llm in zip(rewards, llm_rewards)]

        if mode == "mixed_constant":
            # Environment reward adjusted by a fixed bonus/penalty computed
            # from the actions and the planned task.
            return constant_reward_signal(
                processed_actions, llm_task, rewards, llm_reward=0.01, penalty=0.01)

        raise NotImplementedError(f"Unsupported reward_mode: {mode!r}")
  