import os
import pickle
from grf.academy_3_vs_1_with_keeper import Academy_3_vs_1_with_Keeper
from generator import Generator
from testor import Test
from config import DIC_AGENTS, DIC_PATH, DIC_BASE_AGENT_CONF
from agents import *

# Whether ``main`` runs the training loop at all (False skips training).
IS_TRAIN = True
# Number of training rounds ``main`` iterates over.
TOTAL_ROUNDS = 10_000_000

# Ensure the working directory and the model directory exist before anything
# writes into them. ``exist_ok=True`` makes this idempotent and avoids the
# check-then-create race of an ``os.path.exists`` guard (another process could
# create the directory in between, making ``os.makedirs`` raise). It still
# raises FileExistsError if the path exists but is not a directory.
os.makedirs(DIC_PATH["PATH_TO_WORK_DIRECTORY"], exist_ok=True)
os.makedirs(DIC_PATH["PATH_TO_MODEL"], exist_ok=True)

# Persist evaluation results to mean_rewards.pkl.
def _log_mean_reward(cnt_round, win_rate):
    """Append one evaluation record to ``mean_rewards.pkl`` in the work directory.

    Each call pickles a single dict ``{"cnt_round": ..., "win_rate": ...}``
    onto the end of the file (opened in ``'ab'`` mode). The file therefore holds
    a sequence of independent pickles, and readers must call ``pickle.load``
    repeatedly until EOF to recover every record.

    Args:
        cnt_round: training round at which the evaluation was run.
        win_rate: evaluation score to record for that round.
    """
    target = os.path.join(DIC_PATH["PATH_TO_WORK_DIRECTORY"], "mean_rewards.pkl")
    record = {"cnt_round": cnt_round, "win_rate": win_rate}
    with open(target, 'ab') as handle:
        pickle.dump(record, handle)

def _evaluate(shared_agent, central_agent, env, num_episodes):
    """Run ``num_episodes`` test episodes and return the mean of ``Test.test()``.

    A fresh ``Test`` is constructed per episode with the same configuration
    used for training, and the values returned by ``Test.test()`` are summed
    and divided by ``num_episodes``.

    NOTE(review): ``main`` logs this mean as a win rate, which assumes
    ``Test.test()`` returns 1 for a win and 0 otherwise -- confirm in testor.

    Args:
        shared_agent: agent shared by the controlled players.
        central_agent: centralised agent.
        env: environment instance to evaluate in.
        num_episodes: number of episodes to average over; must be positive.

    Returns:
        The mean value returned by ``Test.test()`` over all episodes.
    """
    total_reward = 0
    for _ in range(num_episodes):
        test = Test(
            cnt_gen=1,
            dic_path=DIC_PATH,
            dic_agent_conf=DIC_BASE_AGENT_CONF,
            agent=shared_agent,
            central_agent=central_agent,
            env=env,
        )
        # Accumulate every episode's result so the average covers all of them.
        total_reward += test.test()
    return total_reward / num_episodes


def main(eval_interval=100, num_test_episodes=20):
    """Train the shared and central agents on Academy 3-vs-1-with-keeper.

    When ``IS_TRAIN`` is set, runs ``TOTAL_ROUNDS`` rounds. Each round
    generates data with ``Generator`` and updates both agents with it. Every
    ``eval_interval`` rounds (starting with round 0), it also:
    - evaluates the agents over ``num_test_episodes`` test episodes,
    - appends the mean score to ``mean_rewards.pkl``, and
    - saves both networks.

    Args:
        eval_interval: training rounds between evaluations/checkpoints.
        num_test_episodes: number of test episodes averaged per evaluation.

    Raises:
        ValueError: if ``eval_interval`` or ``num_test_episodes`` is not positive.
    """
    # Validate up front so a bad setting fails before any training happens.
    if eval_interval <= 0:
        raise ValueError(f"eval_interval must be positive, got {eval_interval}")
    if num_test_episodes <= 0:
        raise ValueError(
            f"num_test_episodes must be positive, got {num_test_episodes}"
        )

    env = Academy_3_vs_1_with_Keeper()

    # Initialise the shared agent and the central agent.
    shared_agent = Distributed_Agent(DIC_PATH)
    central_agent = DIC_AGENTS["Central_Agent"](
        dic_agent_conf=DIC_BASE_AGENT_CONF,
        dic_path=DIC_PATH,
        cnt_round=0
    )

    if not IS_TRAIN:
        return

    # Training loop.
    for cnt_round in range(TOTAL_ROUNDS):
        print(f"round {cnt_round} starts")
        generator = Generator(
            cnt_round=cnt_round,
            cnt_gen=1,
            dic_path=DIC_PATH,
            dic_agent_conf=DIC_BASE_AGENT_CONF,
            agent=shared_agent,
            central_agent=central_agent,
            env=env,
        )
        data, central_data = generator.generate()

        shared_agent.update(data)
        central_agent.update(central_data)

        if cnt_round % eval_interval == 0:
            win_rate = _evaluate(
                shared_agent, central_agent, env, num_test_episodes
            )
            _log_mean_reward(cnt_round, win_rate)

            shared_agent.save_network(f"shared_round_{cnt_round}")
            central_agent.save_central_network(f"central_round_{cnt_round}")



if __name__ == "__main__":
    # Start training only when this file is executed as a script, not on import.
    main()