"""Base runner for on-policy algorithms."""

import time
import glob
import asyncio
from typing import List
import numpy as np
import torch
import setproctitle
import os
import wandb
from PIL import Image
from harl.algorithms.actors import ALGO_REGISTRY
from harl.utils.trans_tools import _t2n
from harl.utils.string_utils import replace_unsupported_chars
from harl.utils.envs_tools import (
    make_eval_env,
    make_train_env,
    make_render_env,
    set_seed,
    get_num_agents,
    update_positions_from_obs,
    get_ego_minimap_text,
)
from harl.utils.models_tools import init_device
from harl.utils.configs_tools import init_dir, init_wandb, save_config
from harl.utils.image_utils import convert_np_to_images
from harl.envs import LOGGER_REGISTRY
from harl.algorithms.actors.llm import LLM
from harl.common.llm_logger import Logger
from harl.common.memory import LocalMemory, GlobalMemory
from harl.configs.config import Config
from harl.utils.log_processor import process_log_messages
import harl.constants as constants
from harl.common.buffers.offline_buffer import OfflineBuffer
from tqdm import tqdm

class LLMsPolicyBaseRunner:
    """Base runner for on-policy algorithms."""

    def __init__(self, args, algo_args, env_args):
        """Initialize the LLMsPolicyBaseRunner.

        Builds the environments, one LLM actor per agent, the global memory and
        the env-specific logger, then resets the envs once and seeds the
        memories with that first observation via ``warmup``.

        Args:
            args: command-line arguments parsed by argparse. Keys read here: algo, env, exp_name, provider.
            algo_args: arguments related to algo, loaded from config file and updated with unparsed command-line arguments.
            env_args: arguments related to env, loaded from config file and updated with unparsed command-line arguments.
        """
        self.config = Config()
        self.llm_logger = Logger()
        self.global_memory = GlobalMemory(self.config.global_mem_dir, max_hops=self.config.max_hops)
        self.args = args
        self.algo_args = algo_args
        self.env_args = env_args

        self.state_type = env_args.get("state_type", "EP")
        self.share_param = algo_args["algo"]["share_param"]
        set_seed(algo_args["seed"])
        self.device = init_device(algo_args["device"])
        # Output locations and the writer come from Config rather than from init_dir.
        self.run_dir, self.log_dir, self.save_dir, self.writter = self.config.run_dir, self.config.log_dir, self.config.save_dir, self.config.writter
        # Previous init_dir-based setup, kept for reference:
        # self.run_dir, self.log_dir, self.save_dir, self.writter = init_dir(
        #     args["env"],
        #     env_args,
        #     args["algo"],
        #     args["exp_name"],
        #     algo_args["seed"]["seed"],
        #     logger_path=algo_args["logger"]["log_dir"],
        # )
        # NOTE: self.wandb_run is only assigned when debug mode is off.
        if not self.algo_args["train"]["debug"]:
            self.wandb_run = init_wandb(args, env_args, algo_args, args["exp_name"], algo_args["seed"]["seed"], self.run_dir)
        save_config(args, algo_args, env_args, self.run_dir)
        # set the title of the process
        setproctitle.setproctitle(
            str(args["algo"]) + "-" + str(args["env"]) + "-" + str(args["exp_name"])
        )

        # set the config of env
        if self.algo_args["render"]["use_render"]:  # make envs for rendering
            # NOTE: self.eval_envs is not assigned on this branch.
            (
                self.envs,
                self.manual_render,
                self.manual_expand_dims,
                self.manual_delay,
                self.env_num,
            ) = make_render_env(args["env"], algo_args["seed"]["seed"], env_args)
        else:  # make envs for training and evaluation
            self.envs = make_train_env(
                args["env"],
                algo_args["seed"]["seed"],
                algo_args["train"]["n_rollout_threads"],
                env_args,
            )
            self.eval_envs = (
                make_eval_env(
                    args["env"],
                    algo_args["seed"]["seed"],
                    algo_args["eval"]["n_eval_rollout_threads"],
                    env_args,
                )
                if algo_args["eval"]["use_eval"]
                else None
            )
        self.num_agents = get_num_agents(args["env"], env_args, self.envs)

        print("share_observation_space: ", self.envs.share_observation_space)
        print("observation_space: ", self.envs.observation_space)
        print("action_space: ", self.envs.action_space)
        self.n_actions_no_attack = self.envs.n_actions_no_attack

        # One shared-image folder per rollout env; frames are written there as step_<N>.jpg.
        self.share_image_dir = self.run_dir + "/share_img"
        for env_id in range(self.algo_args["train"]["n_rollout_threads"]):
            os.makedirs(self.share_image_dir + "/env_" + str(env_id), exist_ok=True)
        # The first reset happens here so warmup() below can store the initial observation.
        obs, share_obs, available_actions, local_images, share_images, obs_text, share_obs_text = self.reset_env()

        self.unit_types = self.envs.get_unit_types()
        # actor
        self.actor = []
        assert args["algo"] == "llms", "only supported for LLMs."
        # The provider name selects the matching upper-cased section of algo_args.
        llm_config = algo_args[args["provider"].upper()]
        self.task_description = self.env_args["task_description"]

        for agent_id in range(self.num_agents):
            agent_args = {
                "agent_id": agent_id,
                "mem_dir": os.path.join(self.run_dir, "memories", f"agent-{agent_id}"),
                # column agent_id of unit_types, i.e. this agent's entry for every env
                "unit_type": self.unit_types[:, agent_id],
            }
            agent = LLM(
                args,
                llm_config,
                algo_args,
                env_args,
                agent_args,
                self.task_description,
                use_full_mode=algo_args['algo']['use_full_mode'],
            )
            self.actor.append(agent)

        # The global memory reuses the first actor's embedding provider.
        self.global_memory.embedding_provider = self.actor[0].embed_provider
        self.global_memory.unit_races = self.unit_types
        self.warmup(share_images, local_images, obs_text, available_actions)

        # Env-specific logger, looked up by env name.
        self.logger = LOGGER_REGISTRY[args["env"]](
            args, algo_args, env_args, self.num_agents, self.writter, self.run_dir
        )
        

    def run(self):
        """Run the training loop, or dispatch to rendering / offline processing.

        When ``algo_args["render"]["use_render"]`` is set, only ``render()`` is
        called; when ``algo_args["offline"]["use_offline"]`` is set, only
        ``offline()`` is called. Otherwise the training envs are stepped with
        the actions returned by ``collect`` until
        ``algo_args["train"]["train_episodes"]`` episodes have finished. Every
        step is reported to the logger and stored through ``insert``; every
        finished episode is logged and its win/lose outcome is written into
        each agent's memory for that env.
        """
        # Truthiness rather than ``is True`` so this agrees with __init__, which
        # builds render envs whenever use_render is truthy.
        if self.algo_args["render"]["use_render"]:
            self.render()
            return
        if self.algo_args["offline"]["use_offline"]:
            self.offline()
            return
        print("start running")
        n_threads = self.algo_args["train"]["n_rollout_threads"]
        self.train_episode = 0
        train_steps = [0] * n_threads

        self.logger.llm_train_init()  # logger callback at the beginning of training

        start_time = time.time()
        while True:
            actions = asyncio.run(self.collect(train_steps, self.train_episode))
            (
                _obs,
                _share_obs,
                rewards,
                dones,
                infos,
                available_actions,
                local_images,
                share_images,
                obs_text,
                _share_obs_text,
            ) = self.envs.step(actions)

            self.available_actions = available_actions
            train_steps = [s + 1 for s in train_steps]

            # TODO exec_info dict exec_info["errors"] exec_info["errors_info"]
            local_images = convert_np_to_images(local_images)
            share_images = convert_np_to_images(share_images)

            self.logger.llm_train_per_step((rewards, infos))  # logger callback at each step

            # An env is done once every agent in it reports done.
            train_dones_env = np.all(dones, axis=1)

            for train_i in range(n_threads):
                if not train_dones_env[train_i]:
                    continue
                # The observation just received belongs to the next episode,
                # so its step counter restarts at 0.
                train_steps[train_i] = 0
                self.train_episode += 1
                self.logger.llm_train_thread_done(
                    train_i,
                    self.train_episode
                )  # logger callback when an episode is done

                if self._episode_won(infos[train_i]):
                    outcome = "win"
                    self.llm_logger.write(f"\033[92m[WIN]\033[0m Episode {self.train_episode} done with victory! 🎉")
                else:
                    outcome = "lose"
                    self.llm_logger.write(f"\033[91m[LOSE]\033[0m Episode {self.train_episode} done with defeat. 😔")
                params = {"win_lose": outcome}
                for agent_id in range(self.num_agents):
                    self.actor[agent_id].memory[train_i].update_info_history(params)

                if not self.algo_args["train"]["debug"]:
                    wandb.log(
                        {
                            "average_episode_rewards": np.mean(self.logger.train_episode_rewards[train_i]),
                            "win_rate": self.logger.train_win_rate,
                        },
                        step=self.train_episode,
                    )
                self.llm_logger.write(f"\033[94mCurrent win rate: {self.logger.train_win_rate:.2f}\033[0m")

            self.insert(train_steps, share_images, local_images, obs_text, rewards, train_dones_env, available_actions)  # insert data into memory
            # Write the current shared frame of every env to its step_<N>.jpg file.
            for env_id in range(n_threads):
                share_image_filename = self.share_image_dir + "/env_" + str(env_id) + "/step_" + str(train_steps[env_id]) + ".jpg"
                share_images[env_id][0].save(share_image_filename)

            # TODO: periodic evaluation / checkpointing is currently disabled.

            if self.train_episode >= self.algo_args["train"]["train_episodes"]:
                self.logger.llm_train_log(
                    self.train_episode
                )  # logger callback at the end of training
                end_time = time.time()
                self.llm_logger.write(f"Training finished in {end_time - start_time:.2f} seconds.")
                self.envs.envs[0].env.save_replay()
                break

    def _episode_won(self, info):
        """Return True when a finished episode's info reports a win.

        Only SMAC-style envs are recognised: the flag is read from
        ``info[0]["battle_won"]`` when the env name contains "v2" and from
        ``info[0]["won"]`` otherwise. Any non-SMAC env counts as not won.
        """
        env_name = self.args["env"]
        if "smac" not in env_name:
            return False
        result_key = "battle_won" if "v2" in env_name else "won"
        return bool(info[0][result_key])

    def reset_env(self):
        """Reset the envs and return the reset outputs with images converted.

        Stores the returned available actions on ``self.available_actions`` and
        passes both image outputs through ``convert_np_to_images``; the other
        reset outputs are returned untouched, in the same order.
        """
        (
            obs,
            share_obs,
            available_actions,
            raw_local_images,
            raw_share_images,
            obs_text,
            share_obs_text,
        ) = self.envs.reset()
        self.available_actions = available_actions
        return (
            obs,
            share_obs,
            available_actions,
            convert_np_to_images(raw_local_images),
            convert_np_to_images(raw_share_images),
            obs_text,
            share_obs_text,
        )
    
    def warmup(self, share_images, local_images, obs_text, available_actions):
        """Seed the global memory and every agent memory with the step-0 data.

        Three passes, in this order:
          1. per env, save the shared frame as ``step_0.jpg`` and record its
             path in every agent's memory for that env;
          2. per agent and env, save the agent's screenshot, store the initial
             history entry and feed the observation text to the global memory;
          3. after the global memory has shared unit information, store each
             agent's ego minimap text.
        """
        self.global_memory.reset()
        env_ids = range(self.algo_args["train"]["n_rollout_threads"])
        agent_ids = range(self.num_agents)

        # Pass 1: shared frames.
        for env_id in env_ids:
            share_path = self.share_image_dir + "/env_" + str(env_id) + "/step_0.jpg"
            share_images[env_id][0].save(share_path)
            for agent_id in agent_ids:
                self.actor[agent_id].memory[env_id].update_info_history(
                    {constants.SHARE_IMAGES_MEM_BUCKET: share_path}
                )

        # Pass 2: per-agent screenshots and the initial history entry.
        for agent_id in agent_ids:
            actor = self.actor[agent_id]
            for env_id in env_ids:
                memory = actor.memory[env_id]
                screenshot_file = memory.screenshot_path + "/step_0.jpg"
                local_images[env_id][agent_id].save(screenshot_file)
                memory.update_info_history(
                    {
                        "task_description": self.task_description,
                        "skill_library": actor.skill_library[env_id],
                        "exec_error": "",
                        "pre_action": "",
                        "pre_screen_classification": "",
                        "decision_making_reasoning": "",
                        "pre_decision_making_reasoning": "",
                        "pre_self_reflection_reasoning": "",
                        "observation": obs_text[env_id, agent_id].copy(),
                        "available_actions": available_actions[env_id, agent_id].copy(),
                        constants.IMAGES_MEM_BUCKET: screenshot_file,
                        "start_frame_id": 0,
                        "end_frame_id": 0,
                    }
                )
                update_positions_from_obs(self.global_memory, env_id, obs_text[env_id, agent_id], agent_id, 0)

        # Pass 3: minimaps are computed only after unit information was shared.
        self.global_memory.share_unit_information_all()
        for agent_id in agent_ids:
            for env_id in env_ids:
                minimap_text = get_ego_minimap_text(
                    self.global_memory, self.actor[agent_id].memory[env_id], env_id, agent_id, 0
                )
                self.actor[agent_id].memory[env_id].update_info_history({"ego_minimap": minimap_text})

    @torch.no_grad()
    async def collect(self, step: List[int], episode: int):
        """Collect actions from actors.
        Args:
            step: step in the episode for different envs.
        Returns:
            actions
        """
        # Create concurrent tasks for each agent
        action_tasks = [
            self.actor[agent_id].get_actions(step, episode)
            for agent_id in range(self.num_agents)
        ]
        
        # Run all agent tasks concurrently
        action_collector = await asyncio.gather(*action_tasks)

        action_collector = list(zip(*action_collector))
        
        step = [s + 1 for s in step]
        return action_collector

    def insert(self, step, share_images, local_images, obs_text, rewards, dones, available_actions):
        """Insert the latest transition into the agent and global memories.

        Args:
            step: per-env step index; used to name the saved frames and passed
                to the global-memory helpers.
            share_images: indexed ``[env_id][0]``; the shared frame of each env.
            local_images: indexed ``[env_id][agent_id]``; per-agent screenshots.
            obs_text: indexed ``[env_id, agent_id]``; text observations.
            rewards: indexed ``[env_id, agent_id, 0]``.
            dones: indexed ``[env_id]``; whether that env just finished an episode.
            available_actions: indexed ``[env_id, agent_id]``.

        The phases below must stay in this order: the frames of a finished
        episode are turned into a GIF first, they are still on disk while the
        per-agent end-of-episode work runs, and only afterwards are they
        deleted and replaced by the new episode's first frame.
        """
        n_threads = self.algo_args["train"]["n_rollout_threads"]
        self.global_memory.reset_knowledge()

        # Phase 1: finished envs -> reset their global memory, archive frames as a GIF.
        for env_id in range(n_threads):
            if dones[env_id]:
                self.global_memory.reset(env_id)
                self._save_episode_gif(self.share_image_dir + "/env_" + str(env_id))

        # Phase 2: per-agent end-of-episode handling and the regular memory update.
        for agent_id in range(self.num_agents):
            actor = self.actor[agent_id]
            for env_id in range(n_threads):
                if dones[env_id]:
                    if not self.algo_args["train"]["debug"]:
                        win_lose = actor.memory[env_id].get_recent_history("win_lose", k=1)[0]
                        # Experiences are only summarized after a lost episode.
                        if win_lose == "lose" and self.config.use_image:
                            asyncio.run(actor.summarize_env_experiences(env_id))
                    # Re-read the unit types after an episode ended and
                    # propagate them before the memory is cleaned.
                    self.unit_types = self.envs.get_unit_types()
                    self.global_memory.unit_races = self.unit_types
                    actor.unit_type = self.unit_types[:, agent_id]
                    actor.memory[env_id].clean_memory(unit_type=self.unit_types[env_id, agent_id])
                # Save screenshots
                screen_image_filename = actor.memory[env_id].screenshot_path + "/step_" + str(step[env_id]) + ".jpg"
                local_images[env_id][agent_id].save(screen_image_filename)
                params = {
                    "observation": obs_text[env_id, agent_id].copy(),
                    constants.IMAGES_MEM_BUCKET: screen_image_filename,
                    "reward": float(rewards[env_id, agent_id, 0]),
                    "done": dones[env_id].copy(),
                    "available_actions": available_actions[env_id, agent_id].copy(),
                }
                actor.memory[env_id].update_info_history(params)
                update_positions_from_obs(self.global_memory, env_id, obs_text[env_id, agent_id], agent_id, step[env_id])
        self.global_memory.share_unit_information_all()

        # Phase 3: drop the finished episode's frames, then store the current shared frame.
        for env_id in range(n_threads):
            if dones[env_id]:
                for frame_path in self._list_step_images(self.share_image_dir + "/env_" + str(env_id)):
                    os.remove(frame_path)

            share_image_filename = self.share_image_dir + "/env_" + str(env_id) + "/step_" + str(step[env_id]) + ".jpg"
            share_images[env_id][0].save(share_image_filename)
            for agent_id in range(self.num_agents):
                params = {
                    constants.SHARE_IMAGES_MEM_BUCKET: share_image_filename,
                }
                self.actor[agent_id].memory[env_id].update_info_history(params)

        # Phase 4: ego minimaps, computed after unit information was shared.
        for agent_id in range(self.num_agents):
            for env_id in range(n_threads):
                ego_minimap = get_ego_minimap_text(self.global_memory, self.actor[agent_id].memory[env_id], env_id, agent_id, step[env_id])
                params = {
                    "ego_minimap": ego_minimap,
                }
                self.actor[agent_id].memory[env_id].update_info_history(params)

    @staticmethod
    def _list_step_images(image_dir):
        """Return the ``step_<N>.jpg`` files in *image_dir* as full paths, ordered by N.

        The step number is parsed from the file name only, so a ``step_``
        substring in a parent directory cannot break the parse; ``.jpg`` files
        that do not follow the naming scheme are ignored.
        """
        prefix, suffix = "step_", ".jpg"
        numbered = []
        for name in os.listdir(image_dir):
            if not (name.startswith(prefix) and name.endswith(suffix)):
                continue
            digits = name[len(prefix):-len(suffix)]
            if digits.isdigit():
                numbered.append((int(digits), os.path.join(image_dir, name)))
        numbered.sort()
        return [path for _, path in numbered]

    def _save_episode_gif(self, image_dir):
        """Combine the step frames in *image_dir* into ``episode_<train_episode>.gif``.

        Does nothing when the directory holds no step frames. Every opened
        frame is closed again so that the files can be removed afterwards.
        """
        frame_paths = self._list_step_images(image_dir)
        if not frame_paths:
            return
        frames = []
        try:
            for frame_path in frame_paths:
                frames.append(Image.open(frame_path))
            gif_path = os.path.join(image_dir, f'episode_{self.train_episode}.gif')
            frames[0].save(gif_path, save_all=True, append_images=frames[1:], duration=100, loop=0)
        finally:
            for frame in frames:
                frame.close()

    def after_update(self):
        """Forward the post-update hook to every buffer.

        Calls ``after_update()`` on the actor buffer of each agent and then on
        the critic buffer.

        NOTE(review): ``actor_buffer`` and ``critic_buffer`` are not created in
        this class's ``__init__``; presumably a subclass provides them -- confirm.
        """
        for buffer_index in range(self.num_agents):
            agent_buffer = self.actor_buffer[buffer_index]
            agent_buffer.after_update()
        critic_buffer = self.critic_buffer
        critic_buffer.after_update()

    @torch.no_grad()
    def eval(self):
        """Roll the actors out deterministically on the eval envs and log the results.

        Steps ``self.eval_envs`` until at least
        ``algo_args["eval"]["eval_episodes"]`` episodes have finished, calling
        the logger's eval callbacks at the start, after every step, for every
        finished episode and once at the end.

        NOTE(review): this relies on ``self.recurrent_n``,
        ``self.rnn_hidden_size`` and an ``act`` method on the actors, none of
        which is set up in this class's ``__init__`` -- confirm before use.
        """
        self.logger.llm_eval_init()  # logger callback at the beginning of evaluation
        finished_episodes = 0

        eval_obs, eval_share_obs, eval_available_actions = self.eval_envs.reset()

        n_threads = self.algo_args["eval"]["n_eval_rollout_threads"]
        rnn_shape = (n_threads, self.num_agents, self.recurrent_n, self.rnn_hidden_size)
        mask_shape = (n_threads, self.num_agents, 1)
        eval_rnn_states = np.zeros(rnn_shape, dtype=np.float32)
        eval_masks = np.ones(mask_shape, dtype=np.float32)

        while True:
            per_agent_actions = []
            for agent_id in range(self.num_agents):
                if eval_available_actions[0] is not None:
                    agent_available = eval_available_actions[:, agent_id]
                else:
                    agent_available = None
                agent_actions, next_rnn_state = self.actor[agent_id].act(
                    eval_obs[:, agent_id],
                    eval_rnn_states[:, agent_id],
                    eval_masks[:, agent_id],
                    agent_available,
                    deterministic=True,
                )
                eval_rnn_states[:, agent_id] = _t2n(next_rnn_state)
                per_agent_actions.append(_t2n(agent_actions))

            # Swap the first two axes: agent-major -> env-major.
            eval_actions = np.array(per_agent_actions).transpose(1, 0, 2)

            (
                eval_obs,
                eval_share_obs,
                eval_rewards,
                eval_dones,
                eval_infos,
                eval_available_actions,
            ) = self.eval_envs.step(eval_actions)
            self.logger.eval_per_step(
                (
                    eval_obs,
                    eval_share_obs,
                    eval_rewards,
                    eval_dones,
                    eval_infos,
                    eval_available_actions,
                )
            )  # logger callback at each step of evaluation

            eval_dones_env = np.all(eval_dones, axis=1)
            done_mask = eval_dones_env == True  # noqa: E712 -- elementwise comparison

            # if env is done, then reset rnn_state to all zero
            eval_rnn_states[done_mask] = 0.0

            # Masks are rebuilt every step: 1 everywhere, 0 for finished envs.
            eval_masks = np.ones(mask_shape, dtype=np.float32)
            eval_masks[done_mask] = 0.0

            for eval_i in range(n_threads):
                if not eval_dones_env[eval_i]:
                    continue
                finished_episodes += 1
                self.logger.eval_thread_done(
                    eval_i
                )  # logger callback when an episode is done

            if finished_episodes >= self.algo_args["eval"]["eval_episodes"]:
                self.logger.eval_log(
                    finished_episodes
                )  # logger callback at the end of evaluation
                break

    @torch.no_grad()
    def render(self):
        """Render the model.

        Plays ``algo_args["render"]["render_episodes"]`` episodes on
        ``self.envs`` with deterministic actions and prints the accumulated
        reward of each episode. Two code paths exist, selected by
        ``self.manual_expand_dims``; they differ in how the env outputs are
        indexed. For env names containing "smac" a replay is saved at the end.

        NOTE(review): this relies on ``self.recurrent_n``,
        ``self.rnn_hidden_size`` and an ``act`` method on the actors, none of
        which is set up in this class's ``__init__``. It also unpacks three
        values from ``self.envs.reset()`` while ``reset_env`` unpacks seven --
        confirm this path still matches the render envs.
        """
        print("start rendering")
        if self.manual_expand_dims:
            # this env needs manual expansion of the num_of_parallel_envs dimension
            for _ in range(self.algo_args["render"]["render_episodes"]):
                eval_obs, _, eval_available_actions = self.envs.reset()
                # Add a leading axis so the per-agent slicing below works.
                eval_obs = np.expand_dims(np.array(eval_obs), axis=0)
                eval_available_actions = (
                    np.expand_dims(np.array(eval_available_actions), axis=0)
                    if eval_available_actions is not None
                    else None
                )
                eval_rnn_states = np.zeros(
                    (
                        self.env_num,
                        self.num_agents,
                        self.recurrent_n,
                        self.rnn_hidden_size,
                    ),
                    dtype=np.float32,
                )
                eval_masks = np.ones(
                    (self.env_num, self.num_agents, 1), dtype=np.float32
                )
                rewards = 0
                while True:
                    eval_actions_collector = []
                    for agent_id in range(self.num_agents):
                        # Here availability is checked on the whole object
                        # (``is not None``), unlike the other branch.
                        eval_actions, temp_rnn_state = self.actor[agent_id].act(
                            eval_obs[:, agent_id],
                            eval_rnn_states[:, agent_id],
                            eval_masks[:, agent_id],
                            eval_available_actions[:, agent_id]
                            if eval_available_actions is not None
                            else None,
                            deterministic=True,
                        )
                        eval_rnn_states[:, agent_id] = _t2n(temp_rnn_state)
                        eval_actions_collector.append(_t2n(eval_actions))
                    # Swap the first two axes: agent-major -> env-major.
                    eval_actions = np.array(eval_actions_collector).transpose(1, 0, 2)
                    # Only the first env's actions are passed to the env on this path.
                    (
                        eval_obs,
                        _,
                        eval_rewards,
                        eval_dones,
                        _,
                        eval_available_actions,
                    ) = self.envs.step(eval_actions[0])
                    rewards += eval_rewards[0][0]
                    eval_obs = np.expand_dims(np.array(eval_obs), axis=0)
                    eval_available_actions = (
                        np.expand_dims(np.array(eval_available_actions), axis=0)
                        if eval_available_actions is not None
                        else None
                    )
                    if self.manual_render:
                        self.envs.render()
                    if self.manual_delay:
                        time.sleep(0.1)
                    if eval_dones[0]:
                        print(f"total reward of this episode: {rewards}")
                        break
        else:
            # this env does not need manual expansion of the num_of_parallel_envs dimension
            # such as dexhands, which instantiates a parallel env of 64 pair of hands
            for _ in range(self.algo_args["render"]["render_episodes"]):
                eval_obs, _, eval_available_actions = self.envs.reset()
                eval_rnn_states = np.zeros(
                    (
                        self.env_num,
                        self.num_agents,
                        self.recurrent_n,
                        self.rnn_hidden_size,
                    ),
                    dtype=np.float32,
                )
                eval_masks = np.ones(
                    (self.env_num, self.num_agents, 1), dtype=np.float32
                )
                rewards = 0
                while True:
                    eval_actions_collector = []
                    for agent_id in range(self.num_agents):
                        # Here availability is checked on the first entry
                        # (``[0] is not None``), unlike the other branch.
                        eval_actions, temp_rnn_state = self.actor[agent_id].act(
                            eval_obs[:, agent_id],
                            eval_rnn_states[:, agent_id],
                            eval_masks[:, agent_id],
                            eval_available_actions[:, agent_id]
                            if eval_available_actions[0] is not None
                            else None,
                            deterministic=True,
                        )
                        eval_rnn_states[:, agent_id] = _t2n(temp_rnn_state)
                        eval_actions_collector.append(_t2n(eval_actions))
                    # Swap the first two axes: agent-major -> env-major.
                    eval_actions = np.array(eval_actions_collector).transpose(1, 0, 2)
                    (
                        eval_obs,
                        _,
                        eval_rewards,
                        eval_dones,
                        _,
                        eval_available_actions,
                    ) = self.envs.step(eval_actions)
                    # Reward and done flag are read from the first env / first agent only.
                    rewards += eval_rewards[0][0][0]
                    if self.manual_render:
                        self.envs.render()
                    if self.manual_delay:
                        time.sleep(0.1)
                    if eval_dones[0][0]:
                        print(f"total reward of this episode: {rewards}")
                        break
        if "smac" in self.args["env"]:  # replay for smac, no rendering
            if "v2" in self.args["env"]:
                self.envs.env.save_replay()
            else:
                self.envs.save_replay()

    def offline(self):
        from harl.utils.file_utils import assemble_project_path, read_resource_file
        # from harl.utils.skill_utils import SkillExtractionPipeline
        import json

        if "smac" in self.args["env"]:
            if "v2" in self.args["env"]:
                self.envs.close()
                if self.eval_envs is not None:
                    self.eval_envs.close()
                from harl.common.planner.smacv2_planner import SkillGeneration, SelfReflection
                # skill_extraction_pipeline = SkillExtractionPipeline()
                # Check if offline_dataset.json exists
                if os.path.exists('offline_dataset.json'):
                    print("Loading existing offline dataset...")
                    with open('offline_dataset.json', 'r') as f:
                        offline_dataset = json.load(f)
                else:
                    print("Creating new offline dataset...")
                    offline_buffer = OfflineBuffer(self.args, self.algo_args, self.env_args)
                    offline_dataset = offline_buffer.buffer.data
                    # Save offline_dataset as json
                    with open('offline_dataset.json', 'w') as f:
                        json.dump(offline_dataset, f)
                
                # offline_dataset = glob.glob(os.path.join(self.env_args["skill"]["skill_configs"]["skill_dataset_path"], "*.mp4"))

                tactic_template = read_resource_file(self.config.planner_params["prompt_paths"]["templates"]["offline_tactics_generation"])
                self.actor[0].planner.self_reflection_ = SelfReflection(template=tactic_template,
                                                 llm_provider=self.actor[0].llm_provider)
                skill_template = read_resource_file(self.config.planner_params["prompt_paths"]["templates"]["offline_skill_generation"])
                self.actor[0].planner.skill_generation_ = SkillGeneration(template=skill_template,
                                                 llm_provider=self.actor[0].llm_provider)
                # offline_dataset = skill_extraction_pipeline.extract_action_patterns(offline_dataset)
                # offline_dataset = skill_extraction_pipeline.cluster_similar_patterns(offline_dataset)
                # Read clusters back from file for analysis
                # with open('action_clusters.json', 'r') as f:
                #     offline_dataset = json.load(f)

                # Generate tactics
                tactics_file = "tactics.json"
                if os.path.exists(tactics_file):
                    print("Loading existing tactics...")
                    with open(tactics_file, 'r') as f:
                        hist_tactic = json.load(f)
                    if self.algo_args["offline"]["refine_tactics"]:
                        tactic_template = read_resource_file(self.config.planner_params["prompt_paths"]["templates"]["offline_tactics_refine"])
                        self.actor[0].planner.self_reflection_ = SelfReflection(template=tactic_template,
                                                        llm_provider=self.actor[0].llm_provider)
                        refined_hist_tactic = {}
                        print("Refining existing tactics...")
                        for tactic_name, tactic in tqdm(hist_tactic.items(), desc='Refining tactics'):
                            print(f"Generating skills for tactic: {tactic_name}")
                            working_area = {
                                "tactic": tactic,
                                "hist_tactic": ', '.join(refined_hist_tactic.keys()),
                            }
                                
                            all_generated_tactics = asyncio.run(self.actor[0].planner.self_reflection_(input=working_area))['res_dict']
                            if 'tactic' in all_generated_tactics:
                                tactic = all_generated_tactics['tactic']
                            else:
                                continue
                            # Extract tactic name from the first line
                            tactic_name = tactic.split('\n')[0].replace('1. Tactic Name:', '').strip()
                            
                            if tactic_name not in refined_hist_tactic:
                                refined_hist_tactic[tactic_name] = tactic
                                print(f"New tactic: {tactic_name}")
                            else:
                                print(f"Existing tactic: {tactic_name}")
                        hist_tactic = refined_hist_tactic
                        # Save the tactics dictionary to a JSON file
                        print("Saving tactics to file...")
                        
                        with open(tactics_file, 'w') as f:
                            json.dump(hist_tactic, f, indent=4)
                else:
                    print("Creating new tactics...")
                    hist_tactic = {}
                    tactic_template = read_resource_file(self.config.planner_params["prompt_paths"]["templates"]["offline_tactics_generation_video"])
                    self.actor[0].planner.self_reflection_ = SelfReflection(template=tactic_template,
                                                    llm_provider=self.actor[0].llm_provider)
                        
                    # for i, data in tqdm(enumerate(range(200)), desc='Generating tactics', total=200):
                    #     working_area = {
                    #         "tactic": ', '.join(hist_tactic.keys()),
                    #     }

                    #     # Generate tactics
                    #     all_generated_tactics = asyncio.run(self.actor[0].planner.self_reflection_(input=working_area))['res_dict']
                    #     if 'tactic' in all_generated_tactics:
                    #         tactic = all_generated_tactics['tactic']
                    #     else:
                    #         continue
                    #     # Extract tactic name from the first line
                    #     tactic_name = tactic.split('\n')[1].replace('-', '').strip()
                        
                    #     if tactic_name not in hist_tactic:
                    #         hist_tactic[tactic_name] = tactic
                    #         print(f"New tactic: {tactic_name}")
                    #     else:
                    #         print(f"Existing tactic: {tactic_name}")
                    tactic_idx = 0
                    for i, data in tqdm(enumerate(offline_dataset), desc='Generating tactics', total=len(offline_dataset)):
                        num_frames = len(data)
                        if num_frames > 10:
                            # Keep first and last frame, sample remaining frames
                            indices = np.linspace(0, num_frames-1, 10, dtype=int)
                            sampled_frames = [data[i] for i in indices]
                            data = sampled_frames

                        image_introduction = [
                        {
                            "introduction": "Here are the sequential frames of the winning game video.",
                            "path": data,
                            "assistant": "",
                            "resolution": "low"
                        }]
                        working_area = {
                            "image_introduction": image_introduction,
                            # "tactic": ', '.join(hist_tactic.keys()),
                        }

                        # Generate tactics
                        all_generated_tactics = asyncio.run(self.actor[0].planner.self_reflection_(input=working_area))['res_dict']
                        if 'tactic' in all_generated_tactics:
                            tactic = all_generated_tactics['tactic']
                            hist_tactic[tactic_idx] = tactic
                            tactic_idx += 1
                        else:
                            continue
                        # # Extract tactic name from the first line
                        # tactic_name = tactic.split('\n')[0].replace('1. Tactic Name:', '').strip()
                        
                        # if tactic_name not in hist_tactic:
                        #     hist_tactic[tactic_name] = tactic
                        #     print(f"New tactic: {tactic_name}")
                        # else:
                        #     print(f"Existing tactic: {tactic_name}")
                    # Save the tactics dictionary to a JSON file
                    print("Saving tactics to file...")
                    
                    with open(tactics_file, 'w') as f:
                        json.dump(hist_tactic, f, indent=4)

                # Generate skills
                for tactic_name, tactic in tqdm(hist_tactic.items(), desc='Generating skills'):

                    print(f"Generating skills for tactic: {tactic_name}")
                    working_area = {
                        "tactic": tactic,
                    }

                    prompt = skill_template.replace("<$tactic$>", tactic)

                    print(prompt)
                        
                    # all_generated_actions = asyncio.run(self.actor[0].planner.skill_generation_(input=working_area))['res_dict'][constants.SKILL_GENERATION_MODULE]
                    # for extracted_skills in all_generated_actions:
                    #     self.actor[0].skill_registry.register_skill_from_code(skill_code=extracted_skills['code'])
                
                    # skills = self.actor[0].skill_registry.skills
                    # self.actor[0].skill_library = self.actor[0].skill_registry.get_skill_information(skills)

                skills = self.actor[0].skill_registry.skills
                self.actor[0].skill_registry.store_skills_to_file(os.path.join(self.actor[0].skill_registry.skill_local_path, self.actor[0].skill_registry.skill_library_filename), skills)
                
    def save(self):
        """Write the current model weights into ``self.save_dir``.

        Files produced (each holds the ``state_dict()`` of one module):
            * ``actor_agent<id>.pt`` -- one per agent in ``range(self.num_agents)``
            * ``critic_agent.pt`` -- the critic
            * ``value_normalizer.pt`` -- only when ``self.value_normalizer`` is set
        """
        prefix = str(self.save_dir) + "/"

        # Per-agent actor networks.
        for idx in range(self.num_agents):
            actor_weights = self.actor[idx].actor.state_dict()
            torch.save(actor_weights, f"{prefix}actor_agent{idx}.pt")

        # Critic.
        critic_weights = self.critic.critic.state_dict()
        torch.save(critic_weights, f"{prefix}critic_agent.pt")

        # The value normalizer is optional; nothing more to write without it.
        normalizer = self.value_normalizer
        if normalizer is None:
            return
        torch.save(normalizer.state_dict(), f"{prefix}value_normalizer.pt")

    def restore(self):
        """Load model weights from ``algo_args["train"]["model_dir"]``.

        Actor weights are always loaded, one ``actor_agent<id>.pt`` per agent.
        The critic (``critic_agent.pt``) and, if one is configured, the value
        normalizer (``value_normalizer.pt``) are loaded only when
        ``algo_args["render"]["use_render"]`` is falsy.
        """

        def load_checkpoint(stem):
            # Resolve the directory at call time, then read "<model_dir>/<stem>.pt".
            model_dir = str(self.algo_args["train"]["model_dir"])
            return torch.load(model_dir + "/" + stem + ".pt")

        for idx in range(self.num_agents):
            actor_state = load_checkpoint("actor_agent" + str(idx))
            self.actor[idx].actor.load_state_dict(actor_state)

        # In render mode only the actors are restored.
        if self.algo_args["render"]["use_render"]:
            return

        critic_state = load_checkpoint("critic_agent")
        self.critic.critic.load_state_dict(critic_state)

        if self.value_normalizer is None:
            return
        normalizer_state = load_checkpoint("value_normalizer")
        self.value_normalizer.load_state_dict(normalizer_state)

    def close(self):
        """Close the environments, the summary writer, and the loggers.

        In render mode only ``self.envs`` is closed. Otherwise this also closes
        the evaluation environments (when evaluation is enabled and they are a
        distinct object from ``self.envs``), exports the writer's scalars to
        ``<log_dir>/summary.json``, closes the writer, finishes the wandb run
        unless ``algo_args["train"]["debug"]`` is set, closes ``self.logger``,
        stores the first actor's skills to its skill library file, and writes
        the closing messages to ``self.llm_logger``.
        """
        # The training/render environments are closed in every mode.
        self.envs.close()
        if self.algo_args["render"]["use_render"]:
            return

        eval_enabled = self.algo_args["eval"]["use_eval"]
        if eval_enabled and self.eval_envs is not self.envs:
            self.eval_envs.close()

        summary_path = str(self.log_dir + "/summary.json")
        self.writter.export_scalars_to_json(summary_path)
        self.writter.close()

        debug_mode = self.algo_args["train"]["debug"]
        if not debug_mode:
            self.wandb_run.finish()
        self.logger.close()

        # Store the skills held by the first actor's registry.
        registry = self.actor[0].skill_registry
        library_path = os.path.join(
            registry.skill_local_path, registry.skill_library_filename
        )
        registry.store_skills_to_file(library_path, registry.skills)

        # NOTE: the markdown export of the processed log (log.md) is currently
        # disabled; only the closing messages below are emitted.
        for message in ('>>> Markdown generated.', '>>> Bye.'):
            self.llm_logger.write(message)