import re
import numpy as np
from dataclasses import dataclass

from smac.env import StarCraft2Env
from typing import Any, Callable, Optional, Tuple, TypeVar, Union
from transformers import PreTrainedTokenizerBase


def is_conversational(example: dict[str, Any]) -> bool:
    r"""
    Check whether an example is in conversational (list-of-messages) format.

    The first supported key found in `example` (searched in the order `"prompt"`, `"chosen"`, `"rejected"`,
    `"completion"`, `"messages"`) is inspected. The example is considered conversational when the value under that
    key is a non-empty list whose first element is a dictionary containing both `"role"` and `"content"`.

    Args:
        example (`dict[str, Any]`):
            A single dataset row.

    Returns:
        `bool`: `True` if the inspected value looks like a list of chat messages, `False` otherwise (including when
        no supported key is present or the list is empty).
    """
    supported_keys = ["prompt", "chosen", "rejected", "completion", "messages"]

    # Pick the key deterministically; popping from a set would pick an arbitrary one.
    key = next((candidate for candidate in supported_keys if candidate in example), None)
    if key is None:
        return False

    maybe_messages = example[key]
    # An empty list carries no message to inspect, so it cannot be identified as conversational.
    if not isinstance(maybe_messages, list) or not maybe_messages:
        return False

    # Only the first message is inspected; it must be a dict with "role" and "content".
    maybe_message = maybe_messages[0]
    return isinstance(maybe_message, dict) and "role" in maybe_message and "content" in maybe_message


def apply_chat_template(
    example: dict[str, list[dict[str, str]]],
    tokenizer: PreTrainedTokenizerBase,
    tools: Optional[list[Union[dict, Callable]]] = None,
) -> dict[str, str]:
    r"""
    Render a conversational example to plain strings with the tokenizer's chat template, forwarding `tools`.

    The supported keys present in `example` must form exactly one of these layouts: `{"messages"}`, `{"prompt"}`,
    `{"prompt", "completion"}`, `{"prompt", "chosen", "rejected"}`, `{"chosen", "rejected"}` or
    `{"prompt", "completion", "label"}`. Keys outside the supported set are ignored and not copied to the output.

    Returns a dictionary holding `"text"` (rendered `"messages"`), `"prompt"`, `"chosen"`, `"rejected"`,
    `"completion"` and `"label"` for whichever of those inputs were present. When a prompt is given, the
    chosen/rejected/completion strings are the rendered prompt + answer with the rendered prompt cut off the front.

    Raises:
        KeyError: if the supported keys do not form one of the layouts above.
        ValueError: if the last prompt message is neither a user nor an assistant message, or if a rendered
            prompt + answer does not start with the rendered prompt.

    For more details, see [`maybe_apply_chat_template`].
    """
    # Validate the combination of keys before touching the tokenizer
    recognised = ("prompt", "chosen", "rejected", "completion", "messages", "label")
    example_keys = {name for name in example if name in recognised}
    valid_layouts = [
        {"messages"},  # language modeling
        {"prompt"},  # prompt-only
        {"prompt", "completion"},  # prompt-completion
        {"prompt", "chosen", "rejected"},  # preference
        {"chosen", "rejected"},  # preference with implicit prompt
        {"prompt", "completion", "label"},  # unpaired preference
    ]
    if example_keys not in valid_layouts:
        raise KeyError(f"Invalid keys in the example: {example_keys}")

    has_prompt = "prompt" in example

    # Whole-conversation rendering (language modeling)
    if "messages" in example:
        messages = tokenizer.apply_chat_template(example["messages"], tools=tools, tokenize=False)

    # Prompt rendering: open a new assistant turn after a user message, or continue an unfinished assistant message
    if has_prompt:
        final_role = example["prompt"][-1]["role"]
        if final_role not in ("user", "assistant"):
            raise ValueError(f"Invalid role in the last message: {final_role}")
        ends_with_assistant = final_role == "assistant"
        prompt = tokenizer.apply_chat_template(
            example["prompt"],
            tools=tools,
            continue_final_message=ends_with_assistant,
            tokenize=False,
            add_generation_prompt=not ends_with_assistant,
        )

    # `full_renderings` keeps prompt + answer strings (for the prefix check below);
    # `answers` keeps what is finally returned for each answer field.
    full_renderings = {}
    answers = {}
    if has_prompt:  # explicit prompt: render prompt + answer, then strip the prompt part
        for field in ("chosen", "rejected", "completion"):
            if field not in example:
                continue
            joined = tokenizer.apply_chat_template(example["prompt"] + example[field], tools=tools, tokenize=False)
            full_renderings[field] = joined
            answers[field] = joined[len(prompt) :]
    else:  # implicit prompt: the answers are complete conversations on their own
        for field in ("chosen", "rejected"):
            if field in example:
                answers[field] = tokenizer.apply_chat_template(example[field], tools=tools, tokenize=False)

    # The stripped answers are only meaningful if every prompt + answer rendering begins with the prompt rendering
    if has_prompt:
        error_message = (
            "The chat template applied to the prompt + completion does not start with the chat template applied to "
            "the prompt alone. This can indicate that the chat template is not supported by TRL."
            "\n**Prompt**:\n{}\n\n**Prompt + Completion**:\n{}"
        )
        for joined in full_renderings.values():
            if not joined.startswith(prompt):
                raise ValueError(error_message.format(prompt, joined))

    # Assemble the output in a fixed key order
    output = {}
    if "messages" in example:
        output["text"] = messages
    if has_prompt:
        output["prompt"] = prompt
    for field in ("chosen", "rejected", "completion"):
        if field in example:
            output[field] = answers[field]
    if "label" in example:
        output["label"] = example["label"]

    return output


def maybe_apply_chat_template(
    example: dict[str, list[dict[str, str]]],
    tokenizer: PreTrainedTokenizerBase,
    tools: Optional[list[Union[dict, Callable]]] = None,
) -> dict[str, str]:
    r"""
    Apply the chat template only when `example` is conversational.

    Examples for which [`is_conversational`] returns `False` are handed back untouched (the very same object);
    all others are passed to [`apply_chat_template`] together with `tokenizer` and `tools`.
    """
    if not is_conversational(example):
        return example
    return apply_chat_template(example, tokenizer, tools)

@dataclass
class SMACArgs:
    """Environment dimensions that `ChatMLProcessor` reads once from a `StarCraft2Env` at construction time."""

    # Taken from env_info["n_agents"]; also used as the first ID assigned to enemy units.
    n_agents: int
    # Taken from env_info["n_actions"].
    n_actions: int
    # Length of the leading movement-feature slice of a flat observation (env.get_obs_move_feats_size()).
    move_feats_dim: int
    # Used as the reshape target for the enemy slice of an observation; element [0] is used as the enemy count.
    enemy_feats_dim: Tuple[int, int]
    # Used as the reshape target for the ally slice of an observation.
    ally_feats_dim: Tuple[int, int]
    # Length of the trailing own-feature slice of a flat observation (env.get_obs_own_feats_size()).
    own_feats_dim: int
    # Result of env.get_ally_num_attributes(); stored but not read by the code visible in this file.
    nf_al: int
    # Result of env.get_enemy_num_attributes(); stored but not read by the code visible in this file.
    nf_en: int
    # When > 0, shield percentages are rendered for own/ally units; otherwise the text shows '0%'.
    shield_bits_ally: int
    # When > 0, shield percentages are rendered for enemy units; otherwise the text shows '0%'.
    shield_bits_enemy: int

class ChatMLProcessor:
    """Wrap a SMAC `StarCraft2Env` and render its observations and actions as chat messages.

    Each agent's flat observation becomes one ``user`` message and, when an action is supplied, the
    action's textual name becomes the following ``assistant`` message. A fixed ``system`` message is
    prepended to the first step of every conversation.

    Unit IDs used in the text: agents are numbered ``0 .. n_agents - 1`` and enemies continue from
    ``n_agents`` upwards.
    """

    # Names of the non-attack actions, in action-index order. One "attack ID <enemy>" entry per enemy follows.
    _BASE_ACTIONS = (
        "no-op",
        "stop",
        "move north one step",
        "move south one step",
        "move east one step",
        "move west one step",
    )

    def __init__(self, map_name):
        """Create the environment for `map_name` and cache the dimensions needed to decode observations."""
        self.env = StarCraft2Env(map_name=map_name)
        self.env_info = self.env.get_env_info()
        self.args = SMACArgs(
            n_agents=self.env_info["n_agents"],
            n_actions=self.env_info["n_actions"],
            move_feats_dim=self.env.get_obs_move_feats_size(),
            enemy_feats_dim=self.env.get_obs_enemy_feats_size(),
            ally_feats_dim=self.env.get_obs_ally_feats_size(),
            own_feats_dim=self.env.get_obs_own_feats_size(),
            nf_al=self.env.get_ally_num_attributes(),
            nf_en=self.env.get_enemy_num_attributes(),
            shield_bits_ally=self.env.shield_bits_ally,
            shield_bits_enemy=self.env.shield_bits_enemy,
        )
        self.chatml_system = [
            {
                "role": "system",
                "content": f"You are a strategic SMAC AI assistant on the {map_name} map. "
                           "Work with your team to complete the task."
            }
        ]

    # ------------------------------------------------------------------
    # Thin pass-throughs to the wrapped environment
    # ------------------------------------------------------------------
    def reset(self):
        """Reset the wrapped environment."""
        self.env.reset()

    def get_obs(self):
        """Return `env.get_obs()`."""
        return self.env.get_obs()

    def get_state(self):
        """Return `env.get_state()`."""
        return self.env.get_state()

    def step(self, actions):
        """Forward `actions` to `env.step` and return its result."""
        return self.env.step(actions)

    def get_avail_actions(self):
        """Return `env.get_avail_actions()`."""
        return self.env.get_avail_actions()

    def close(self):
        """Close the wrapped environment."""
        self.env.close()

    # ------------------------------------------------------------------
    # Action vocabulary
    # ------------------------------------------------------------------
    def _enemy_ids(self):
        """Return the text IDs of all enemies: they follow directly after the agent IDs."""
        first = self.args.n_agents
        return list(range(first, first + self.args.enemy_feats_dim[0]))

    def _unit_type_text(self, unit_feats):
        """Render the trailing unit-type bits of a feature vector as a bit string, or '-' if there are none."""
        bits = self.env.unit_type_bits
        if bits > 0:
            return ''.join(map(str, map(int, unit_feats[-bits:])))
        return '-'

    def get_action_list(self):
        """Return the action names in action-index order (a fresh list on every call)."""
        return [*self._BASE_ACTIONS, *(f"attack ID {enemy_id}" for enemy_id in self._enemy_ids())]

    def get_response_text(self, actions):
        """Return the name of a single action.

        `actions` must expose `.item()` yielding the action index (e.g. a one-element numpy array).
        """
        # Single source of truth for the index -> name mapping.
        return self.get_action_list()[actions.item()]

    def get_avail_actions_text(self, avail_actions: np.ndarray) -> dict:
        """Translate an availability mask into sets of action names.

        Args:
            avail_actions: array unpacked as ``(steps, n_agents, n_actions)``; an entry equal to 1
                marks the action as available.

        Returns:
            ``{"avail_actions_<agent>": [set_of_names_for_step_0, set_of_names_for_step_1, ...]}``.

        Raises:
            ValueError: if the last dimension does not match the length of the action list.
        """
        action_list = self.get_action_list()

        if avail_actions.shape[-1] != len(action_list):
            raise ValueError(f"The last dimension of avail_actions ({avail_actions.shape[-1]}) "
                             f"does not match the length of action_list ({len(action_list)}).")

        steps, n_agents, _ = avail_actions.shape
        output = {f"avail_actions_{agent_id}": [] for agent_id in range(n_agents)}

        for t in range(steps):
            for agent_id in range(n_agents):
                mask = avail_actions[t, agent_id]
                names = {action_list[idx] for idx, flag in enumerate(mask) if flag == 1}
                output[f"avail_actions_{agent_id}"].append(names)

        return output

    # ------------------------------------------------------------------
    # Conversation building
    # ------------------------------------------------------------------
    def infer(self, obs, actions, steps):
        """Build the messages of one environment step for every agent.

        Args:
            obs: per-agent observations, indexable by agent index.
            actions: per-agent actions, or `None` to emit only the user message.
            steps: current step counter; the system message is prepended only when it equals 0.

        Returns:
            A list with one message list per agent.
        """
        sequences = []
        for agent_id in range(self.args.n_agents):
            action = actions[agent_id] if actions is not None else None
            turn = self.get_input_text(obs[agent_id], actions=action, agent_id=agent_id)
            sequences.append(self.chatml_system + turn if steps == 0 else turn)

        return sequences

    def format_input(self, obs: np.ndarray, actions: np.ndarray) -> list:
        """Build one complete conversation per agent from a recorded episode.

        Args:
            obs: array unpacked as ``(steps, n_agents, obs_dim)``.
            actions: array indexed as ``actions[step, agent]``; each entry must support `.item()`.

        Returns:
            A list with one flat message list per agent: the system message, then alternating
            user/assistant messages for each step.
        """
        steps, n_agents, _ = obs.shape
        sequences = []
        for agent_id in range(n_agents):
            conversation = []
            for t in range(steps):
                if t == 0:
                    conversation.extend(self.chatml_system)
                conversation.extend(self.get_input_text(obs[t, agent_id], actions[t, agent_id], agent_id))
            sequences.append(conversation)
        return sequences

    def get_input_text(self, obs, actions, agent_id):
        """Render one agent's observation (and optionally its action) as chat messages.

        The flat observation is sliced as ``[move | enemy | ally | ...]`` from the front, with the
        agent's own features taken from the tail. The movement slice is not rendered.

        Args:
            obs: 1-D numpy observation of a single agent.
            actions: the action taken (supports `.item()`), or `None` to omit the assistant message.
            agent_id: ID of the observing agent.

        Returns:
            ``[user_message]`` or ``[user_message, assistant_message]``.
        """
        args = self.args
        move_end = args.move_feats_dim
        enemy_end = move_end + args.enemy_feats_dim[0] * args.enemy_feats_dim[1]
        ally_end = enemy_end + args.ally_feats_dim[0] * args.ally_feats_dim[1]
        enemy_feats = obs[move_end:enemy_end].reshape(args.enemy_feats_dim)
        ally_feats = obs[enemy_end:ally_end].reshape(args.ally_feats_dim)
        own_feats = obs[-args.own_feats_dim:].reshape(args.own_feats_dim)

        # Own features: element 0 is read as health, element 1 as shield (only when shield bits exist).
        own_health = f"{own_feats[0] * 100:.2f}%"
        own_shield = f"{own_feats[1] * 100:.2f}%" if args.shield_bits_ally > 0 else '0%'
        own_type_id = self._unit_type_text(own_feats)
        own_text = f"ID/T/H/S: {agent_id}/{own_type_id}/{own_health}/{own_shield}"

        # Ally rows are matched to all agent IDs except the observer, in ascending order.
        ally_ids = [i for i in range(args.n_agents) if i != agent_id]
        visible_allies, _ = self.format_unit_text(ally_feats, ally_ids, units='ally')

        enemy_ids = list(range(args.n_agents, args.n_agents + enemy_feats.shape[0]))
        visible_enemies, _ = self.format_unit_text(enemy_feats, enemy_ids, units='enemy')

        user_content = (
            f"Own feature, {own_text}. "
            f"Allies in sight, {visible_allies}. "
            f"Enemies in sight, {visible_enemies}."
        )
        chatml = [{"role": "user", "content": user_content}]

        if actions is not None:
            response_text = self.get_response_text(actions)
            chatml.append({
                "role": "assistant",
                "content": f"{response_text}."
            })

        return chatml

    def format_unit_text(self, feats, ids, units):
        """Describe a group of units (allies or enemies) as text.

        For each row of `feats`, column 0 equal to 1.0 marks the unit as visible; columns 2, 3, 4
        are rendered as X, Y and health, column 5 as shield when the group has shield bits, and the
        trailing `env.unit_type_bits` columns as the unit type.
        NOTE(review): confirm that column 0 means "visible" for enemy rows as well as ally rows.

        Args:
            feats: 2-D array with one row per unit.
            ids: text IDs, aligned with the rows of `feats`.
            units: ``'ally'`` or ``'enemy'``; selects which shield-bit setting applies.

        Returns:
            ``(visible_text, invisible_text)``; each is ``"None"`` when its group is empty.

        Raises:
            ValueError: if `units` is neither ``'ally'`` nor ``'enemy'``.
        """
        if units == 'ally':
            shield_bits = self.args.shield_bits_ally
        elif units == 'enemy':
            shield_bits = self.args.shield_bits_enemy
        else:
            raise ValueError(f"Invalid value for 'units': {units}. Expected 'ally' or 'enemy'.")

        visible_units = []
        hidden_units = []
        for row, unit_id in enumerate(ids):
            if feats[row, 0] != 1.0:
                hidden_units.append(f"ID {unit_id} (unknown)")
                continue
            fields = {
                "ID": unit_id,
                "Type": self._unit_type_text(feats[row]),
                "X": f"{feats[row, 2]:.2f}",
                "Y": f"{feats[row, 3]:.2f}",
                "H": f"{feats[row, 4] * 100:.2f}%",
                "S": f"{feats[row, 5] * 100:.2f}%" if shield_bits > 0 else '0%'
            }
            visible_units.append("{ID}/{Type}/{X}/{Y}/{H}/{S}".format(**fields))

        visible_text = ("ID/T/X/Y/H/S: " + ", ".join(visible_units)) if visible_units else "None"
        invisible_text = ", ".join(hidden_units) if hidden_units else "None"

        return visible_text, invisible_text

def maybe_convert_to_chatml(example: dict[str, list], map_name: str, move_avail_actions: bool) -> dict[str, list]:
    """
    Convert a recorded SMAC episode into per-agent chat conversations, if the example holds one.

    An example qualifies when it has an `"obs"` entry that is a list; anything else is returned unchanged.
    A trailing `_part_<digits>` suffix is stripped from `map_name` before the environment is created.
    The result maps `messages_<agent>` to that agent's conversation and, unless `move_avail_actions` is true,
    additionally carries the `avail_actions_<agent>` entries produced from `example["avail_actions"]`.
    """
    if "obs" not in example or not isinstance(example["obs"], list):
        return example

    base_map = re.sub(r'_part_\d+$', '', map_name)
    processor = ChatMLProcessor(base_map)

    conversations = processor.format_input(np.array(example['obs']), np.array(example['actions']))
    converted = {f"messages_{agent}": conversation for agent, conversation in enumerate(conversations)}

    if move_avail_actions:
        return converted

    converted.update(processor.get_avail_actions_text(np.array(example['avail_actions'])))
    return converted

def reshape_messages(batch: dict[str, list[str]]) -> dict[str, list[str]]:
    """
    Merge all `messages_*` columns of a batch into one `"messages"` column.

    Rows are interleaved: for every row index, the entries of each `messages_*` column (in the batch's key
    order) are emitted one after another. All other columns are dropped. A batch without any `messages_*`
    column is returned unchanged.
    """
    columns = [batch[name] for name in batch if name.startswith("messages_")]
    if not columns:
        return batch

    merged = []
    for row in zip(*columns):
        merged.extend(row)
    return {"messages": merged}

def expand_agents_batch(batch: dict[str, list[str]]) -> dict[str, list[str]]:
    """
    Flatten per-agent columns into one row per (sample, agent) pair.

    The number of agents is the number of `messages_*` columns, and agents are addressed as
    `messages_0 .. messages_{n-1}` / `avail_actions_0 .. avail_actions_{n-1}`. Output rows are ordered
    sample by sample, with the agents of a sample in ascending order.
    """
    agent_columns = [name for name in batch if name.startswith("messages_")]
    agent_count = len(agent_columns)
    row_count = len(batch[agent_columns[0]])

    expanded_messages = []
    expanded_avail = []
    for row, agent in ((r, a) for r in range(row_count) for a in range(agent_count)):
        expanded_messages.append(batch[f"messages_{agent}"][row])
        expanded_avail.append(batch[f"avail_actions_{agent}"][row])

    return {
        "messages": expanded_messages,
        "avail_actions": expanded_avail,
    }

def extract_turns_with_avail_actions(batch: dict[str, list], *, max_turns: int = 30) -> dict[str, list]:
    """
    Split every conversation of a batch into one prompt/completion pair per user -> assistant turn.

    The first message of a conversation is always used as the leading prompt message. Odd positions are then
    examined: whenever position `i` holds a user message and `i + 1` an assistant message, a sample is emitted
    whose prompt is the first message followed by the last `max_turns` messages up to and including `i`, and
    whose completion is the assistant message at `i + 1`.

    Args:
        batch: must contain `"messages"` (list of conversations) and `"avail_actions"` (one entry per
            conversation). If a conversation's `avail_actions` entry is a list with an element for the turn,
            that element is attached to the sample; otherwise the whole entry is attached.
        max_turns: maximum number of history *messages* (not user/assistant pairs) kept after the first
            message, counting the current user message. Must be at least 1.
            NOTE(review): with an even value (including the default) a truncated window starts at an even
            position, which is an assistant message under strict alternation; confirm this is intended.

    Returns:
        A dict with the parallel lists `"prompt"`, `"completion"` and `"avail_actions"`.

    Raises:
        ValueError: if `max_turns` is smaller than 1.
    """
    if max_turns < 1:
        raise ValueError(f"max_turns must be at least 1, got {max_turns}.")

    all_prompts = []
    all_completions = []
    all_avail_actions = []

    for b_idx, messages in enumerate(batch['messages']):
        avail_actions = batch['avail_actions'][b_idx]
        # An empty conversation contains no turns; skip it instead of failing on messages[0].
        if not messages:
            continue
        system_message = messages[0]

        # The range stops before the last index, so messages[i + 1] always exists.
        for i in range(1, len(messages) - 1, 2):
            if messages[i]['role'] != 'user' or messages[i + 1]['role'] != 'assistant':
                continue

            window_start = max(1, i + 1 - max_turns)
            prompt = [system_message, *messages[window_start:i + 1]]
            completion = [messages[i + 1]]

            turn_index = i // 2
            if isinstance(avail_actions, list) and turn_index < len(avail_actions):
                current_avail_action = avail_actions[turn_index]
            else:
                current_avail_action = avail_actions

            all_prompts.append(prompt)
            all_completions.append(completion)
            all_avail_actions.append(current_avail_action)

    return {
        "prompt": all_prompts,
        "completion": all_completions,
        "avail_actions": all_avail_actions
    }

def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str, list[list]]:
    """
    Concatenate the sequences of every column and re-split them into chunks of `seq_length` items.

    Each column is processed independently. The final chunk of a column may be shorter than `seq_length`.

    Args:
        examples: mapping from column name to a list of sequences.
        seq_length: chunk size; must be positive.

    Returns:
        A new mapping with the same keys whose values are lists of chunks.

    Raises:
        ValueError: if `seq_length` is not positive.
    """
    if seq_length <= 0:
        # A zero step makes range() fail and a negative one would silently drop every item.
        raise ValueError(f"seq_length must be a positive integer, got {seq_length}.")

    packed = {}
    for column, sequences in examples.items():
        # Linear-time flatten; `sum(sequences, [])` copies the accumulator on every addition.
        flat = [item for sequence in sequences for item in sequence]
        packed[column] = [flat[start : start + seq_length] for start in range(0, len(flat), seq_length)]
    return packed

def maybe_remove_padding(example: dict[str, list]) -> dict[str, list]:
    """
    Drop padded positions from every list-valued field of an example.

    Applies only when `example["filled"]` exists and is a list; position `i` is kept when
    `example["filled"][i][0] == 1`. Every list-valued field (including `"filled"` itself) is filtered with
    that mask, non-list fields are left alone. The example is modified in place and also returned.
    """
    if "filled" not in example or not isinstance(example["filled"], list):
        return example

    # Build the mask once, before "filled" itself gets filtered.
    keep_flags = [entry[0] == 1 for entry in example["filled"]]

    for name in list(example):
        column = example[name]
        if not isinstance(column, list):
            continue
        example[name] = [item for item, keep in zip(column, keep_flags) if keep]

    return example
