import os
import torch
import numpy as np

from envs import make_vec_envs
from arguments import get_args
from tools import (
    set_base,
    set_more,
    init_map_and_pose,
    init_map_and_pose_for_env,
    prepare_planner_inputs,
    reset_current_loction,
    whether_seen_the_goal_and_set_goal_map,
    update_semantic_map,
    expand_semantic_mask,
    update_full_and_local_map,
    update_full_map_for_env,
    set_doneInfo,
    set_logger,
    log_interval,
    obtain_final_res,
    expand_masks,
    generate_candidate_goal_map,
    get_agent_position_mask,
    fill_agent_position_mask,
    filter_candidate_goals,
    sort_candidate_goals_by_distance,
    get_candidate_goal_mask_and_replace_obs,
    update_final_condidate_goal_map,
    is_stuck,
    is_goal_blocked,
    get_Node_based_full_map,
    check_consistent_positions,
    is_path_reachable
)
from PIL import Image
# from transformers import CLIPProcessor, CLIPModel
from tools import CandidateGoalIterator
from graph import EndpointDetector
from NodeDetector import NodeDetector
from collections import deque
# -------------------------------------------------------------
from vlm import QwenVLM
from mapping_tmp_goals import TraversableAreaDetector
# -------------------------------------------------------------
# Process-wide runtime settings: limit OpenMP to one thread and expose only
# GPUs 2 and 3 to CUDA. Any value already present in the environment is
# overwritten.
# NOTE(review): this runs after `import torch` / `import numpy` above; confirm
# both settings are still picked up when applied this late.
os.environ.update({
    "OMP_NUM_THREADS": "1",
    "CUDA_VISIBLE_DEVICES": "2,3",
})
# -------------------------------------------------------------

# Helper that checks whether an episode has finished


def check_episode_completion(dones, infos, args, episode_success, episode_spl, episode_dist,
                             spl_per_category, success_per_category, envs,
                             full_map, local_map, lmb, full_pose, planner_pose_inputs,
                             origins, local_pose, local_w, local_h, full_w, full_h,
                             device, goal_maps, found_goal, num_scenes, history_position):
    """Record metrics for finished episodes and report whether to stop.

    For every environment whose entry in ``dones`` is truthy, the episode's
    ``spl`` and ``success`` values from ``infos`` are appended to the
    per-category containers keyed by ``goal_name``. When ``args.eval`` is set
    they are also appended, together with ``distance_to_goal``, to the
    per-environment lists.

    Unless the function returns early, the finished environment's local map
    is then merged into the full map and its map/pose state is re-initialised
    through ``update_full_map_for_env`` and ``init_map_and_pose_for_env``.

    NOTE(review): the values returned by those helpers are only rebound to
    local names here and are never returned, so the caller observes the reset
    only if the helpers modify their arguments in place -- confirm.

    ``envs``, ``goal_maps``, ``found_goal``, ``num_scenes`` and
    ``history_position`` are accepted but not read by this function.

    Returns:
        True as soon as, in eval mode, a finished environment has recorded at
        least ``args.num_eval_episodes`` episodes. In that case the function
        returns immediately, without resetting that environment's maps and
        without inspecting the remaining environments. False otherwise.
    """
    for env_idx, finished in enumerate(dones):
        if not finished:
            continue

        info = infos[env_idx]
        spl, success, dist = info['spl'], info['success'], info['distance_to_goal']
        category = info['goal_name']
        spl_per_category[category].append(spl)
        success_per_category[category].append(success)

        if args.eval:
            recorded_successes = episode_success[env_idx]
            recorded_successes.append(success)
            episode_spl[env_idx].append(spl)
            episode_dist[env_idx].append(dist)
            # Stop once this environment has logged enough evaluation episodes.
            if len(recorded_successes) >= args.num_eval_episodes:
                return True

        # Fold the finished episode's local map into the full map, then
        # re-initialise the map and pose state of this environment.
        full_map = update_full_map_for_env(full_map, local_map, lmb, env_idx)
        (full_map, full_pose, planner_pose_inputs,
         origins, lmb, local_map, local_pose) = init_map_and_pose_for_env(
            full_map, full_pose, planner_pose_inputs, origins, local_map, local_pose,
            args.map_size_cm, args.map_resolution, local_w, local_h, full_w, full_h,
            args.global_downscaling, device, lmb, env_idx)

    return False


def main():
    """Run the object-goal navigation evaluation loop.

    Seeds NumPy and PyTorch, builds the vectorised environment and the
    map/pose buffers, then loops until ``check_episode_completion`` reports
    that enough evaluation episodes have finished. All per-episode control
    state is indexed with environment 0 only (see the FIXME below).

    Each pass through the loop does, in order:

    1. Episode bookkeeping: record finished episodes and clear goal state,
       and re-initialise the maps after a forced ``envs.reset()``.
    2. On a new episode, take up to six observations separated by rotation
       requests flagged ``"right"``. The sweep stops early if the goal
       category is observed; otherwise candidate goal maps are collected.
    3. If the initial sweep finished without seeing the goal, choose the
       first exploration goal from the collected candidates.
    4. If a temporary goal was reached (or a re-scan was requested) and the
       goal is still unseen, sweep again (two ``"left"`` requests followed by
       four ``"right"`` requests) and choose the next exploration goal.
    5. If a goal map is set, take one planner step towards it and update the
       maps and the goal/re-scan flags.

    When neither the candidate iterator nor ``get_Node_based_full_map``
    yields a goal, the environment is reset and the current episode counter
    is appended to ``reset_history``, which is written to
    ``tmp/reset_history.txt`` at the end.

    Raises:
        Exception: if no candidate goals are available after the initial
            sweep, or after a re-scan at a temporary goal.
    """
    args = get_args()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    logger = set_logger(args)
    episode_success, episode_spl, episode_dist, spl_per_category, success_per_category = set_doneInfo(
        args)

    # Starting environments
    torch.set_num_threads(1)
    device = args.device = torch.device("cuda" if args.cuda else "cpu")
    envs = make_vec_envs(args)
    torch.set_grad_enabled(False)
    full_w, full_h, local_w, local_h, \
        num_scenes, num_episodes, device, _, \
        full_map, local_map, full_pose, local_pose, \
        origins, lmb, planner_pose_inputs = set_base(
            args)
    # `dones` appears twice in this unpacking target; the later position wins.
    # NOTE(review): confirm the intended names against set_more's return order.
    # `spl_per_category` / `success_per_category` are rebound here, replacing
    # the objects returned by set_doneInfo above.
    rotation_history, rotatation_over_at_beginning, new_goal_required, \
        dones, policy_vis, dones, \
        sem_map_module, found_goal,  goal_maps, \
        spl_per_category, success_per_category, start = set_more(
            args, num_scenes, full_w, full_h, num_episodes)

    full_map, full_pose, planner_pose_inputs, origins, local_map, local_pose = \
        init_map_and_pose(full_map, full_pose, args.map_size_cm, origins, planner_pose_inputs, args.map_resolution,
                          num_scenes, lmb, local_map, local_pose, device, local_w, local_h, full_w, full_h, args.global_downscaling)

    obs, infos, rgbd = envs.reset()
    # `vlm` and `trav` are constructed but not referenced again in this function.
    vlm = QwenVLM()
    trav = TraversableAreaDetector(args)
    GRAPH = EndpointDetector()
    GRAPH2 = EndpointDetector(height_up_threshold=0.8755)
    Node = NodeDetector()
    NextGoalIterator = CandidateGoalIterator()
    # Store only the necessary information, not full copies of the map.
    # NOTE: history_position is currently only cleared and passed along; no
    # code in this function appends to it.
    history_position = deque(maxlen=25)
    # FIXME: this code currently only supports single-threaded runs
    # (everything below is indexed with environment 0).
    count = 0
    reset_history = []
    reset = False
    local_step = 0
    while True:
        # Check whether the episode has finished (the main check point of the
        # navigation loop).
        should_finish = check_episode_completion(dones, infos, args, episode_success, episode_spl, episode_dist,
                                                 spl_per_category, success_per_category, envs,
                                                 full_map, local_map, lmb, full_pose, planner_pose_inputs,
                                                 origins, local_pose, local_w, local_h, full_w, full_h,
                                                 device, goal_maps, found_goal, num_scenes, history_position)
        if dones[0]:
            # Episode ended: clear goal state for the next one.
            for e in range(num_scenes):
                goal_maps[e].fill(0)
                found_goal[e] = 0
            reset = False
            local_step = 0
            history_position.clear()

        if should_finish:
            break

        if reset:
            # The previous pass forced envs.reset(); start from a fresh map
            # and pose for environment 0.
            full_map, full_pose, planner_pose_inputs, origins, lmb, local_map, local_pose = \
                init_map_and_pose_for_env(full_map, full_pose, planner_pose_inputs, origins, local_map, local_pose,
                                          args.map_size_cm, args.map_resolution, local_w, local_h, full_w, full_h,
                                          args.global_downscaling, device, lmb, 0)
            for e in range(num_scenes):
                goal_maps[e].fill(0)
                found_goal[e] = 0
            reset = False
            local_step = 0
            history_position.clear()
        policy_vis = [False, None]
        if infos[0]['new_episode']:
            count += 1
            seen_goal = False
            candidate_goal_map_list = []
            # Channel index of the goal category in obs / the semantic maps.
            cn = infos[0]['goal_cat_id'] + 4
            tmp_goal_arrived = False
            sorted_candidate_goal_map = None
            # Initialise so the rotation call below never reads an unbound
            # name or a value left over from a previous episode.
            candidate_goal_map = None

            # If the goal is found while rotating, navigate to it as the
            # long-term goal.
            # TODO/FIXME: on a new episode, reset goal_maps, found_goal,
            # rotate_views and pre_obs_pose (05-22 20:53).
            for w in range(6):
                if obs[0, cn, :, :].sum() == 0.:  # Do not see the goal category
                    seen_goal = False
                    candidate_goal_mask, obs = get_candidate_goal_mask_and_replace_obs(
                        rgbd[0, :, :, 3], GRAPH, GRAPH2, obs, args, device)
                else:
                    seen_goal = True
                    candidate_goal_mask = np.zeros(
                        (0, args.env_frame_height, args.env_frame_width))
                    obs = expand_semantic_mask(obs)  # see the goal category

                local_map, local_pose = update_semantic_map(
                    infos, obs, local_map, local_pose, sem_map_module, num_scenes, device)
                for e in range(num_scenes):
                    full_map = update_full_map_for_env(
                        full_map, local_map, lmb, e)
                if not seen_goal and candidate_goal_mask is not None:
                    rotatation_over_at_beginning[0] = False
                    candidate_goal_map = generate_candidate_goal_map(
                        full_map, candidate_goal_mask.shape[0])
                    candidate_goal_map_list.append(candidate_goal_map)
                elif seen_goal and local_map[0, cn, :, :].sum() == 0:
                    # Seen in the observation but absent from the local map:
                    # treat as not seen.
                    seen_goal = False
                if seen_goal:
                    candidate_goal_map = None
                    goal_maps, found_goal = whether_seen_the_goal_and_set_goal_map(
                        num_scenes, infos, full_map, goal_maps, found_goal)
                    planner_pose_inputs, local_map = reset_current_loction(
                        local_pose, planner_pose_inputs, origins, local_map, num_scenes, args.map_resolution)
                    break

                # Update the position
                planner_pose_inputs, local_map = reset_current_loction(
                    local_pose, planner_pose_inputs, origins, local_map, num_scenes, args.map_resolution)
                if w == 5:  # rotate over
                    rotatation_over_at_beginning[0] = True
                    new_goal_required[0] = False
                    break
                # Perform the rotation. Map channels from index 4 onward are
                # cleared before each action.
                full_map[0, 4:, :, :] = 0
                local_map[0, 4:, :, :] = 0
                planner_inputs, full_map = prepare_planner_inputs(
                    full_map, goal_maps, candidate_goal_map, candidate_goal_mask, None, None, planner_pose_inputs, num_scenes, found_goal,
                    args.visualize, args.print_images, is_rotatation=[True, "right"], policy_vis=policy_vis)
                obs, dones, infos, rgbd = envs.plan_act_and_preprocess(
                    planner_inputs)
        # Do not see the goal after rotating
        if rotatation_over_at_beginning[0]:
            rotatation_over_at_beginning[0] = False
            tmp_goal_arrived = False
            # TODO: vlm.ask_multiple_images(...)
            if len(candidate_goal_map_list):
                trajectory_mask = infos[0]['trajectory_mask']
                sorted_candidate_goal_map = update_final_condidate_goal_map(
                    full_map, candidate_goal_map_list, None, planner_pose_inputs, args, trajectory_mask)
                if sorted_candidate_goal_map is not None:
                    NextGoalIterator.set_candidate_goals(
                        sorted_candidate_goal_map)
                    tmp = NextGoalIterator.next()
                    if tmp is not None:
                        goal_maps[0] = tmp
                    else:
                        # Iterator yielded nothing: fall back to a node-based
                        # goal, and reset the environment if that fails too.
                        node_map = get_Node_based_full_map(
                            full_map, Node, planner_pose_inputs, args, infos)
                        if node_map is not None:
                            goal_maps[0] = node_map[0]
                        else:
                            obs, infos, rgbd = envs.reset()
                            reset_history.append(count)
                            reset = True
                            continue

                else:
                    raise Exception(
                        "No candidate goals after filtering at beginning(rotate over)")
            else:
                raise Exception(
                    "No candidate goals at beginning before filtering")  # this exception is Low-probability event
        if (tmp_goal_arrived and not seen_goal):
            # Re-scan: two rotation requests flagged "left", then four
            # flagged "right".
            update_goals_list = []
            for k in range(6):
                full_map[0, 4:, :, :] = 0
                local_map[0, 4:, :, :] = 0
                if k <= 1:
                    planner_inputs, full_map = prepare_planner_inputs(
                        full_map, goal_maps, sorted_candidate_goal_map, None, None, None, planner_pose_inputs, num_scenes, found_goal,
                        args.visualize, args.print_images, is_rotatation=[True, "left"], policy_vis=policy_vis)
                else:
                    planner_inputs, full_map = prepare_planner_inputs(
                        full_map, goal_maps, sorted_candidate_goal_map, None, None, None, planner_pose_inputs, num_scenes, found_goal,
                        args.visualize, args.print_images, is_rotatation=[True, "right"], policy_vis=policy_vis)
                obs, dones, infos, rgbd = envs.plan_act_and_preprocess(
                    planner_inputs)
                if dones[0]:
                    break
                if obs[0, cn, :, :].sum() == 0.:  # Do not see the goal category
                    seen_goal = False
                    candidate_goal_mask, obs = get_candidate_goal_mask_and_replace_obs(
                        rgbd[0, :, :, 3], GRAPH, GRAPH2, obs, args, device)
                else:
                    # NOTE(review): candidate_goal_mask is not refreshed on
                    # this branch, so the else-branch below may use the mask
                    # from an earlier observation -- confirm this is intended.
                    seen_goal = True
                    obs = expand_semantic_mask(obs)

                local_map, local_pose = update_semantic_map(
                    infos, obs, local_map, local_pose, sem_map_module, num_scenes, device)

                for e in range(num_scenes):
                    full_map = update_full_map_for_env(
                        full_map, local_map, lmb, e)
                if seen_goal and full_map[0, infos[0]['goal_cat_id']+4, :, :].sum() != 0:
                    goal_maps, found_goal = whether_seen_the_goal_and_set_goal_map(
                        num_scenes, infos, full_map, goal_maps, found_goal)
                    # Update the position
                    planner_pose_inputs, local_map = reset_current_loction(
                        local_pose, planner_pose_inputs, origins, local_map, num_scenes, args.map_resolution)
                    if goal_maps[0].sum() != 0:
                        break

                else:
                    if candidate_goal_mask is not None:
                        candidate_goal_map = generate_candidate_goal_map(
                            full_map, candidate_goal_mask.shape[0])

                        update_goals_list.append(candidate_goal_map)
                # Update the position
                planner_pose_inputs, local_map = reset_current_loction(
                    local_pose, planner_pose_inputs, origins, local_map, num_scenes, args.map_resolution)
            tmp_goal_arrived = False
            if dones[0]:
                # Episode ended during the re-scan; let the top of the loop
                # handle the bookkeeping.
                continue
            if not seen_goal:
                sorted_candidate_goal_map = update_final_condidate_goal_map(
                    full_map, update_goals_list, NextGoalIterator, planner_pose_inputs, args)

                if sorted_candidate_goal_map is not None:
                    NextGoalIterator.set_candidate_goals(
                        sorted_candidate_goal_map)
                    tmp = NextGoalIterator.next()
                    if tmp is not None:
                        goal_maps[0] = tmp
                    else:
                        node_map = get_Node_based_full_map(
                            full_map, Node, planner_pose_inputs, args, infos)
                        if node_map is not None:
                            goal_maps[0] = node_map[0]
                        else:
                            obs, infos, rgbd = envs.reset()
                            reset_history.append(count)
                            reset = True
                            continue   # this exception is Low-probability event
                else:
                    raise Exception(
                        "No candidate goals to explore after tmp goal arrived")  # this exception is Low-probability event
        # A goal map is set (either the real goal or a temporary exploration
        # goal): take one planner step towards it.
        if goal_maps[0].sum() != 0:
            full_map[0, 4:, :, :] = 0
            local_map[0, 4:, :, :] = 0
            planner_inputs, full_map = prepare_planner_inputs(
                full_map, goal_maps, sorted_candidate_goal_map, None, None, None, planner_pose_inputs, num_scenes, found_goal,
                args.visualize, args.print_images, is_rotatation=[False, None], policy_vis=policy_vis)

            obs, dones, infos, rgbd = envs.plan_act_and_preprocess(
                planner_inputs)
            local_step += 1
            if obs[0, cn, :, :].sum() != 0:
                obs = expand_semantic_mask(obs)

            local_map, local_pose = update_semantic_map(
                infos, obs, local_map, local_pose, sem_map_module, num_scenes, device)

            # Immediately sync the local-map update into the full map.
            for e in range(num_scenes):
                full_map = update_full_map_for_env(full_map, local_map, lmb, e)

            # update the goal_map
            tmp_goal_maps, tmp_found_goal = whether_seen_the_goal_and_set_goal_map(
                num_scenes, infos, full_map, goal_maps, found_goal)
            if tmp_goal_maps[0].sum() != 0 and tmp_found_goal[0] == 1:
                goal_maps = tmp_goal_maps
                found_goal = tmp_found_goal
            if not found_goal[0] and goal_maps[0].sum() == 1 and is_goal_blocked(full_map[0, 0, :, :] > 0, goal_maps[0], k=10, thresh=0.5):
                # Temporary goal is blocked: request a re-scan.
                tmp_goal_arrived = True
                seen_goal = False
            # Bind the result to planner_pose_inputs, matching every other
            # reset_current_loction call site. The original stored it in
            # planner_inputs, where it was discarded.
            planner_pose_inputs, local_map = reset_current_loction(
                local_pose, planner_pose_inputs, origins, local_map, num_scenes, args.map_resolution)

            # Every 10 local steps without a found goal, request a re-scan.
            if local_step == 10 and not found_goal[0]:  # 20 12
                local_step = 0
                tmp_goal_arrived = True
                seen_goal = False
            if infos[0]['fmm_stop']:
                local_step = 0
                full_map, local_map, lmb, planner_pose_inputs, local_pose = update_full_and_local_map(
                    full_map, local_map, lmb, planner_pose_inputs, full_pose, local_pose, origins, num_scenes, args.map_resolution, args.global_downscaling, local_w, local_h, full_w, full_h, device)

                if goal_maps[0].sum() and found_goal[0]:
                    tmp_goal_arrived = False
                    seen_goal = True
                else:
                    tmp_goal_arrived = True
                    seen_goal = False
        if infos[0]['time'] % args.log_interval == 0:
            log_interval(logger, infos, args, num_scenes, start,
                         episode_success, episode_spl, episode_dist)

    obtain_final_res(args, logger, episode_success, episode_spl,
                     episode_dist, spl_per_category, success_per_category)
    # Save reset_history to a local file. Create the output directory first
    # so a missing `tmp/` cannot fail the run after evaluation has finished.
    # NOTE(review): each value written as "step" is the episode counter
    # (`count`) at the time of the reset, and `i` is the reset's index.
    os.makedirs('tmp', exist_ok=True)
    with open('tmp/reset_history.txt', 'w') as f:
        for i, reset_count in enumerate(reset_history):
            f.write(f"Episode {i}: Reset at step {reset_count}\n")
        f.write(f"\nTotal resets: {len(reset_history)}")
    print(f"Reset history已保存到 tmp/reset_history.txt，共{len(reset_history)}次重置")

    # Print and save model performance numbers during evaluation


# Script entry point: run the evaluation loop only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
