from src.components.offline_buffer import DataSaver, OfflineBufferH5, OfflineBufferH5WoAgent
from itertools import combinations
from src.modules.decomposers.sc2_decomposer import SC2Decomposer
import yaml
from types import SimpleNamespace as SN
import torch as th
import numpy as np
from call_llm import gpt_agent, deepseek_agent
import random
import openai
from openai import OpenAI
import os
import time
import json
import datetime
import re
import copy
from copy import deepcopy as dco
import h5py

def get_trajectory(path, trajectory_id, buffer=None):
    """Return the filled prefix of one trajectory's observations, actions and states.

    Args:
        path: dataset path used to build an ``OfflineBufferH5`` when ``buffer``
            is not supplied.
        trajectory_id: index of the trajectory inside the buffer.
        buffer: optional already-loaded buffer exposing a ``data`` mapping with
            ``'filled'``, ``'obs'``, ``'actions'`` and ``'state'`` entries. Pass
            one to avoid reloading the dataset on every call.

    Returns:
        ``(obs, actions, state)``, each truncated to the number of filled steps.
    """
    if buffer is None:
        buffer = OfflineBufferH5(Args(), None, None, data_path=path, shuffle=False)
    data = buffer.data
    # The sum of the filled mask over time is used as the trajectory length, so
    # filled is assumed to be a 0/1 prefix mask shaped (T, 1) — TODO confirm.
    # int() keeps the slice valid even when the mask is stored as floats.
    n_steps = int(data['filled'][trajectory_id].sum(0)[0])
    obs = data['obs'][trajectory_id][:n_steps]
    actions = data['actions'][trajectory_id][:n_steps]
    state = data['state'][trajectory_id][:n_steps]
    return obs, actions, state
def get_trajectory_wo_agent(path, trajectory_id, buffer=None):
    """Return the filled prefix of one trajectory's observations, actions and states.

    Args:
        path: dataset path used to build an ``OfflineBufferH5WoAgent`` when
            ``buffer`` is not supplied.
        trajectory_id: index of the trajectory inside the buffer.
        buffer: optional already-loaded buffer exposing a ``data`` mapping with
            ``'filled'``, ``'obs'``, ``'actions'`` and ``'state'`` entries.

    Returns:
        ``(obs, actions, state)``, each truncated to the number of filled steps.
    """
    if buffer is None:
        buffer = OfflineBufferH5WoAgent(Args(), None, None, data_path=path, shuffle=False)
    data = buffer.data
    # The sum of the filled mask over time is used as the trajectory length, so
    # filled is assumed to be a 0/1 prefix mask shaped (T, 1) — TODO confirm.
    # int() keeps the slice valid even when the mask is stored as floats.
    n_steps = int(data['filled'][trajectory_id].sum(0)[0])
    obs = data['obs'][trajectory_id][:n_steps]
    actions = data['actions'][trajectory_id][:n_steps]
    state = data['state'][trajectory_id][:n_steps]
    return obs, actions, state
class Args:
    """Bare argument object handed to the offline buffers by the trajectory loaders.

    It carries a single attribute, ``offline_data_folder``, initialised to ``None``.
    """

    def __init__(self):
        # Only this attribute is provided; the dataset location is passed
        # separately via ``data_path`` at the call sites.
        self.offline_data_folder = None
# Unit-type one-hot (rendered with str(tensor.tolist())) -> readable unit name.
# NOTE(review): keys depend on exact float formatting and on there being exactly
# two type bits; any other one-hot pattern raises KeyError at lookup.
unit_id2type = {
    "[1.0, 0.0]": "Zealot",
    "[0.0, 1.0]": "Stalker",
}
def _obs_entry(label: str, parts: list) -> str:
    """Format ``"label: part1, part2, ..."`` (just ``"label:"`` when ``parts`` is empty)."""
    return f"{label}: {', '.join(parts)}" if parts else f"{label}:"


def _obs_relative_position(feats) -> str:
    """Describe a unit's offset from feats[2] (dx), feats[3] (dy) and feats[1] (distance)."""
    return (
        f"x-line relative distance {feats[2].item():.2f}, "
        f"y-line relative distance {feats[3].item():.2f}, "
        f"relative continental distance {feats[1].item():.2f}"
    )


def _obs_unit_type(type_bits) -> str:
    """Map a unit-type one-hot slice to its name via ``unit_id2type`` (KeyError if unknown)."""
    return unit_id2type[str(type_bits.tolist())]


def _describe_own_feats(own_feats, decomposer) -> str:
    """Describe the agent's own non-movement features (health, shield, unit type)."""
    parts = []
    idx = 0
    if decomposer.obs_own_health:
        parts.append(f"healthy ratio {own_feats[idx].item():.2f}")
        idx += 1
    if decomposer.shield_bits_ally:
        parts.append(f"shield ratio {own_feats[idx].item():.2f}")
        idx += 1
    if decomposer.unit_type_bits > 0:
        # Everything after health/shield is taken as the unit-type one-hot.
        parts.append(f"agent type {_obs_unit_type(own_feats[idx:])}")
    return _obs_entry("Own agent", parts)


def _describe_enemy_feats(j, ef, decomposer) -> str:
    """Describe enemy ``j`` from its 1-D feature row ``ef``.

    Raises:
        ValueError: if the attack-range flag (ef[0]) of a visible enemy is not 0 or 1.
    """
    # A zero (dx, dy) offset is treated as "not visible" (presumably the env
    # zero-fills enemies outside sight range — confirm).
    if ef[2].item() == 0 and ef[3].item() == 0:
        return f"enemy {j}: invisible"
    attack_flag = ef[0].item()
    if attack_flag == 1:
        parts = ["within attack range"]
    elif attack_flag == 0:
        parts = ["out of attack range"]
    else:
        raise ValueError(f"enemy {j}: attack-range flag must be 0 or 1, got {attack_flag}")
    parts.append(_obs_relative_position(ef))
    idx = 4  # features 0-3: attack flag, distance, dx, dy
    if decomposer.obs_all_health:
        parts.append(f"healthy ratio {ef[idx].item():.2f}")
        idx += 1
    if decomposer.shield_bits_enemy:
        parts.append(f"shield ratio {ef[idx].item():.2f}")
        idx += 1
    if decomposer.unit_type_bits > 0:
        parts.append(f"agent type {_obs_unit_type(ef[idx:])}")
    return _obs_entry(f"enemy {j}", parts)


def _describe_ally_feats(k, af, decomposer) -> str:
    """Describe ally ``k`` from its 1-D feature row ``af``."""
    # af[0] == 0 is treated as "not visible".
    if af[0].item() == 0:
        return f"ally {k}: invisible"
    parts = [_obs_relative_position(af)]
    idx = 4  # features 0-3: visibility, distance, dx, dy
    if decomposer.obs_all_health:
        parts.append(f"healthy ratio {af[idx].item():.2f}")
        idx += 1
    if decomposer.shield_bits_ally:
        parts.append(f"shield ratio {af[idx].item():.2f}")
        idx += 1
    if decomposer.unit_type_bits > 0:
        parts.append(f"agent type {_obs_unit_type(af[idx:])}")
    return _obs_entry(f"ally {k}", parts)


def describe_agent_obs_compact(obs: np.ndarray, decomposer) -> str:
    """Render one agent's observation vector as compact English text for an LLM prompt.

    The observation is split with ``decomposer.decompose_obs`` into the agent's
    own features, per-enemy features and per-ally features. Each group becomes
    one ``"<label>: <attr>, <attr>, ..."`` entry, and the entries are joined
    with ``"; "``.

    Args:
        obs: a single agent's flat observation vector for one time step.
        decomposer: object providing ``decompose_obs`` plus the feature-layout
            attributes read here (``move_feats``, ``obs_own_health``,
            ``obs_all_health``, ``shield_bits_ally``, ``shield_bits_enemy``,
            ``unit_type_bits``, ``n_enemies``, ``n_agents``).

    Returns:
        The description string, terminated by a period.

    Raises:
        ValueError: if a visible enemy's attack-range flag is neither 0 nor 1.
        KeyError: if a unit-type one-hot is not listed in ``unit_id2type``.
    """
    # decompose_obs is called on a batch of one; every group is read at batch index 0.
    obs_t = th.tensor(obs, dtype=th.float32).unsqueeze(0)
    own_obs, enemy_feats, ally_feats = decomposer.decompose_obs(obs_t)

    # Movement features come first in the own-feature vector and are not described.
    lines = [_describe_own_feats(own_obs[0, decomposer.move_feats:], decomposer)]
    lines.extend(
        _describe_enemy_feats(j, enemy_feats[j][0], decomposer)
        for j in range(decomposer.n_enemies)
    )
    lines.extend(
        _describe_ally_feats(k, ally_feats[k][0], decomposer)
        for k in range(decomposer.n_agents - 1)
    )
    return "; ".join(lines) + "."
    
def describe_agent_action(action, decomposer):
    """Translate a SMAC action id (``action[0]``) into a short English label.

    Ids 0-5 are the fixed non-attack actions. Any other id is reported as an
    attack on enemy ``id - decomposer.n_actions_no_attack``.
    """
    code = action[0]
    non_attack_labels = ("None", "Stop", "Move North", "Move South", "Move East", "Move West")
    # Equality comparison (not indexing) so numpy/tensor scalars behave as before.
    for label_id, label in enumerate(non_attack_labels):
        if code == label_id:
            return label
    return f"Attack Enemy {code - decomposer.n_actions_no_attack}"

def describe_traj(obs, action, decomposer):
    """Describe a trajectory step by step as text for an LLM prompt.

    Each included step becomes one line of the form
    ``"Step <i>, Observation {<obs text>}, Action {<action text>};"`` with a
    1-based step index. Steps whose action is described as ``"None"`` are
    skipped.

    Args:
        obs: per-step observations; ``obs.shape[0]`` is the number of steps.
        action: per-step actions, indexed in step with ``obs``.
        decomposer: passed through to the observation/action describers.

    Returns:
        The concatenated step lines (an empty string if no step is included).
    """
    lines = []
    # The last recorded step is excluded (range stops at obs.shape[0] - 1);
    # presumably its action was never executed — TODO confirm.
    for i in range(obs.shape[0] - 1):
        action_desc = describe_agent_action(action[i], decomposer)
        if action_desc == "None":
            continue
        obs_desc = describe_agent_obs_compact(obs[i], decomposer)
        lines.append(f"Step {i + 1}, Observation {{{obs_desc}}}, Action {{{action_desc}}};\n")
    # join avoids the quadratic cost of repeated string concatenation.
    return "".join(lines)
class LLM:
    """Wrapper around the chat models used to describe and summarise trajectories.

    ``mode='openai'`` builds ``gpt_agent`` clients for Gemini 2.5 Pro, GPT-5 and
    GPT-4o-mini from ``api_key_list``. ``mode='openchat'`` talks to a local
    OpenAI-compatible endpoint. ``self.call_llm`` is bound to the mode-specific
    entry point.

    Raises:
        ValueError: for an unknown ``mode``, or when ``mode='openai'`` has no API keys.
    """

    OPENCHAT_URL = "http://localhost:18888/v1/chat/completions"

    def __init__(self, mode='openai') -> None:
        if mode == 'openchat':
            self.call_llm = self.call_llm_openchat
        elif mode == 'openai':
            # Fill in API keys before running; one is picked at random per client.
            api_key_list = [
            ]
            if not api_key_list:
                raise ValueError("LLM(mode='openai') requires at least one key in api_key_list")
            self.agent_big_gemini = gpt_agent(random.choice(api_key_list), api_key_list, model_name='gemini-2.5-pro')
            self.agent_big_gpt = gpt_agent(random.choice(api_key_list), api_key_list, model_name='gpt-5')
            self.agent_small = gpt_agent(random.choice(api_key_list), api_key_list, model_name='gpt-4o-mini')
            self.call_llm = self.call_llm_openai
        else:
            raise ValueError(f"Unknown LLM mode {mode!r}; expected 'openai' or 'openchat'")

    def call_llm_openchat(self, prompt, big_model=None, temperature=None):
        """Send ``prompt`` to the local OpenChat server and return the reply text.

        ``big_model`` and ``temperature`` are accepted for interface
        compatibility and ignored. The JSON body is built with ``json.dumps`` and
        POSTed via ``urllib``, so the prompt is sent verbatim and never passes
        through a shell.
        """
        import urllib.request

        payload = json.dumps({
            "model": "openchat_3.5",
            "messages": [{"role": "user", "content": prompt}],
        }).encode("utf-8")
        request = urllib.request.Request(
            self.OPENCHAT_URL,
            data=payload,
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        with urllib.request.urlopen(request) as response:
            data = json.load(response)
        return data["choices"][0]["message"]["content"]

    def call_llm_gpt(self, prompt):
        """Ask the GPT-5 agent."""
        return self.agent_big_gpt.ask(prompt)

    def call_llm_gemini(self, prompt):
        """Ask the Gemini 2.5 Pro agent."""
        return self.agent_big_gemini.ask(prompt)

    def call_llm_openai(self, prompt, big_model=False):
        """Ask the GPT-5 agent when ``big_model`` is true, else the GPT-4o-mini agent."""
        # No ``agent_big`` attribute is ever created; GPT-5 serves as the large model.
        agent = self.agent_big_gpt if big_model else self.agent_small
        return agent.ask(prompt)

def extract_answer(text):
    """Return every substring of ``text`` wrapped in ``<answer></answer>`` tags.

    Matching is non-greedy and spans newlines, so each tag pair yields one list
    entry. An empty list means no complete tag pair was found.
    """
    answer_re = re.compile(r"<answer>(.*?)</answer>", flags=re.DOTALL)
    return answer_re.findall(text)

# Per-map inputs, left empty here (fill in before running):
#   offline_data_ls[map]   -> offline dataset path read by OfflineBufferH5WoAgent
#   map2prior_role_id[map] -> file holding a "prior_role_id" array, one id per trajectory
offline_data_ls = dict()
train_task_ls = ["3m", "5m", "7m", "10m"]  # SMAC maps whose roles are described and merged
map2prior_role_id = dict()
llm = LLM(mode="openai")

# Stage 1: for every training map, have the LLM describe a sample of trajectories
# from each prior role, then summarise each prior role from those descriptions.
task2roles = {task: None for task in train_task_ls}
typical_traj_num = 10  # at most this many trajectories per prior role are sent to the LLM
for task in task2roles:
    print(f"Currently in task {task}! Trajectory describing start!")
    task_prompt = f"The agents are playing the {task} map of the SMAC multi-agent environment. \
    In this game, the agents need to cooperate with allies to beat all the enemies. \
    "
    instruction_prompt = "Following, you are an expert SMAC player. You will be given the trajectory of one of the agents. \
    Next, you need to provide a concise description of that trajectory; in your description you should highlight the agent’s behavioral style and characteristics, \
    and, where possible, summarize the role the agent played — for example, a tank that lures enemies on the front line, or a kiter dealing damage from the back line. \
    Each trajectory is made up of multiple steps, and each step contains two quantities: Observation and Action. \
    Please respond after careful consideration, your final answer should be within <answer> and </answer> tags. \
    Below is the trajectory: \
    "
    # prior_role_id holds one prior role id per trajectory index.
    task_prior_role_id = np.load(map2prior_role_id[task])["prior_role_id"]
    # Key by the role ids actually present, so non-contiguous ids cannot raise KeyError.
    task_role2traj_id = {int(r): [] for r in np.unique(task_prior_role_id)}
    for traj_idx, role_id in enumerate(task_prior_role_id):
        task_role2traj_id[int(role_id)].append(traj_idx)
    # Keep every trajectory of small roles; randomly sample the larger ones.
    task_role2repr_id = {
        key: ids if len(ids) <= typical_traj_num else random.sample(ids, typical_traj_num)
        for key, ids in task_role2traj_id.items()
    }
    print(task_role2repr_id)
    env_config_path = "src/config/envs/sc2_offline_meta1.yaml"
    with open(env_config_path, "r") as file:
        config_dict = yaml.safe_load(file)
    config_dict["env_args"]["map_name"] = task
    args = SN(**config_dict)
    path = offline_data_ls[task]
    buffer = OfflineBufferH5WoAgent(args, None, None, data_path=path, shuffle=False)
    decomposer = SC2Decomposer(args)
    task_role2traj_desc = {key: [] for key in task_role2traj_id}
    for key, repr_ids in task_role2repr_id.items():
        for traj_id in repr_ids:
            obs, action, state = get_trajectory_wo_agent(path, traj_id, buffer)
            text_traj = describe_traj(obs, action, decomposer)
            prompt = task_prompt + instruction_prompt + text_traj
            answer = llm.call_llm_gpt(prompt)
            extracted = extract_answer(answer)
            # Use the raw reply if the model omitted the <answer> tags instead of crashing.
            task_role2traj_desc[key].append(extracted[0] if extracted else answer)
    instruction_prompt2 = " Following, you are an expert SMAC player. You will be given several descriptive summaries of typical trajectories belonging to a specific type of agent. \
    Please analyze and synthesize these descriptions to summarize the role this agent typically plays within the team. Please respond after careful consideration, your final answer should be within <answer> and </answer> tags. \
    Below are the descriptions: \
    "
    task_role2summary = {}
    for key, descriptions in task_role2traj_desc.items():
        tmp_prompt = task_prompt + instruction_prompt2 + "".join(
            f"Description {i + 1}: {desc}\n" for i, desc in enumerate(descriptions)
        )
        answer = llm.call_llm_gpt(tmp_prompt)
        extracted = extract_answer(answer)
        summary = extracted[0] if extracted else answer
        task_role2summary[key] = summary
        print(f"Task {task} prior role {key} summary: {summary}!")
    task2roles[task] = task_role2summary
    print(f"Task {task} trajectory describing end!")

# Stage 2: ask the LLM to merge the per-map prior roles into a few shared roles.
instruction_prompt3 = " Following, you are an expert SMAC player. Across multiple tasks, we have roughly categorized the roles that agents play within SMAC teams into several types, and for each type, a linguistic description of its behavioral style and role has been generated. \
However, many of these role types are actually identical in essence. Based on the linguistic descriptions of each role type, please merge those that represent the same underlying role, and output the merged role categories along with the list of original category names included within each merged category. Be cautious in merging—if two classes show a noticeable degree of difference, keep them separate. The final number of merged role categories can be between 3 and 4.\n\
Below, we will provide several class names along with their corresponding language descriptions in the format Class Name: {}, Description: {};. \
Please analyze the semantic content of each description and merge the classes that represent the same underlying role type. Use 0, 1, 2, … to denote the merged role categories, and output your results strictly in the following format:\n\
<answer>0:{Class Name1, Class Name2, …}, 1:{Class Name3, Class Name4, …}, … </answer>.\n\
Please respond after careful consideration. Below are the class names and descriptions:\n\
"
# Class names are "<task>_<prior role id>".
final_prompt = instruction_prompt3 + "".join(
    f"Class Name: {task}_{role}, Description: {description};\n"
    for task, role2desc in task2roles.items()
    for role, description in role2desc.items()
)
answer = llm.call_llm_gpt(final_prompt)
extracted = extract_answer(answer)
# Keep ``answer`` a list so answer[0] stays valid downstream; if the model
# omitted the <answer> tags, fall back to the raw reply.
answer = extracted if extracted else [answer]
# Parse "<merged id>:{class name, class name, ...}" groups from the LLM answer.
pattern = r"(\d+):\{([^}]*)\}"
matches = re.findall(pattern, answer[0])
result = {}
for key, values in matches:
    # Drop empty items (e.g. from a trailing comma) so they are not treated as class names.
    result[int(key)] = [v.strip() for v in values.split(",") if v.strip()]

# Buffer fields collected for every (merged role, task) pair.
data_keys = ('actions', 'actions_onehot', 'avail_actions', 'filled', 'obs', 'reward', 'state', 'terminated')
role2task = {}
role_data_ls = {}
for role, class_names in result.items():
    tasks = []
    for class_name in class_names:
        # Class names are "<task>_<prior role>"; split on the LAST underscore so
        # map names that themselves contain underscores stay intact.
        task = class_name.rsplit("_", 1)[0]
        if task not in tasks:
            tasks.append(task)
    role2task[role] = tasks
    # Fresh empty list per field, per task, for this merged role.
    role_data_ls[role] = {task: {k: [] for k in data_keys} for task in tasks}
print(role2task)

# Map each training task to {prior role id: [trajectory indices with that role]}.
# Keys are the role ids actually present, so non-contiguous ids cannot raise KeyError;
# indices are listed in ascending order.
task_role2traj_id = {}
for task in train_task_ls:
    task_prior_role_id = np.asarray(np.load(map2prior_role_id[task])["prior_role_id"])
    task_role2traj_id[task] = {
        int(r): np.flatnonzero(task_prior_role_id == r).tolist()
        for r in np.unique(task_prior_role_id)
    }

# Gather, for every merged role, the buffer rows of all trajectories whose prior
# role was merged into it. The work is grouped by task so each offline buffer is
# loaded once, instead of once per (merged role, prior role) pair.
env_config_path = "src/config/envs/sc2_offline_meta1.yaml"
with open(env_config_path, "r") as file:
    base_config = yaml.safe_load(file)
task2assignments = {}
for role, class_names in result.items():
    for value in class_names:
        if not value:
            continue
        # Class names are "<task>_<prior role>"; split on the last underscore.
        task, prior_role = value.rsplit("_", 1)
        task2assignments.setdefault(task, []).append((role, int(prior_role)))
for task, assignments in task2assignments.items():
    config_dict = dco(base_config)
    config_dict["env_args"]["map_name"] = task
    args = SN(**config_dict)
    path = offline_data_ls[task]
    buffer = OfflineBufferH5WoAgent(args, None, None, data_path=path, shuffle=False)
    data = buffer.data
    for role, prior_role in assignments:
        for traj_id in task_role2traj_id[task][prior_role]:
            # NOTE(review): assumes buffer.data has exactly the fields prepared in
            # role_data_ls; an extra key would raise KeyError here.
            for data_key in data.keys():
                role_data_ls[role][task][data_key].append(data[data_key][traj_id])
    # Release this task's buffer before loading the next one.
    del buffer, data
                
# Turn each field's list of per-trajectory rows into a single array and report its shape.
for role in result:
    for task in role2task[role]:
        fields = role_data_ls[role][task]
        for field_name in list(fields):
            fields[field_name] = np.array(fields[field_name])
            print(role, task, field_name, fields[field_name].shape)

# Write one HDF5 file per (merged role, task), with one dataset per buffer field.
# Each pair gets its own path (role_<id>/<task>.h5) so files do not overwrite
# each other, and the directory part is never empty for os.makedirs.
# TODO: set data_root to the desired output directory (empty = current directory).
data_root = ""
for role in result:
    for task in role2task[role]:
        data_path = os.path.join(data_root, f"role_{role}", f"{task}.h5")
        os.makedirs(os.path.dirname(data_path), exist_ok=True)
        with h5py.File(data_path, 'w') as f:
            for key, value in role_data_ls[role][task].items():
                f.create_dataset(key, data=value)