import json
import os
import sys
import numpy as np
import random
import math
import time
from collections import defaultdict
import line_profiler
import logging

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

from utils.distributed import is_default_gpu
from utils.ops import pad_tensors, gen_seq_masks
from torch.nn.utils.rnn import pad_sequence

from moe.agent_base_moe import Seq2SeqAgent
# from r2r.agent_base import Seq2SeqAgent
from r2r.eval_utils import cal_dtw

from models.graph_utils import GraphMap
from models.model import VLNBert, Critic
from models.ops import pad_tensors_wgrad

from moe.router import NavigationSkillAnalyzer
from utils.get_images import load_vp_lookup, convert_path2img, extract_cand_img
from router_qwen import extract_from_response
from llm_utils import generate_temporal_instructions

import httpx
import asyncio
# NOTE(review): this Timeout object is not referenced anywhere in the visible
# part of the file (the batch request helpers below pass their own
# ``timeout=180.0``); it may be used further down or be a leftover -- confirm
# before removing.
timeout = httpx.Timeout(120.0, connect=10.0)


async def identify_instructions_with_qwen_batch(
    instructions,
    previous_viewpoint_lists,
    max_tokens,
    server_url="http://localhost:8001/identify-instruction-batch",
):
    """POST one batched "identify instruction" request and return its results.

    Args:
        instructions: sent under the ``"instructions"`` key of the JSON body.
        previous_viewpoint_lists: sent under ``"previous_viewpoint_lists"``.
        max_tokens: sent under ``"max_tokens"``.
        server_url: endpoint that receives the POST.

    Returns:
        The value stored under ``"results"`` in the JSON reply.

    Raises:
        httpx.HTTPStatusError: when the reply has an error status
            (raised by ``raise_for_status``).
    """
    request_body = dict(
        instructions=instructions,
        previous_viewpoint_lists=previous_viewpoint_lists,
        max_tokens=max_tokens,
    )
    # A fresh client per call; the request may take up to 180 seconds.
    async with httpx.AsyncClient(timeout=180.0) as http:
        reply = await http.post(server_url, json=request_body)
        reply.raise_for_status()
        results = reply.json()["results"]
    return results
    

async def analyze_with_qwen_batch(
    instructions,
    sub_instructions,
    reasonings,
    candidate_viewpoint_lists,
    max_tokens,
    server_url="http://localhost:8001/analyze-skill-batch",
):
    """POST one batched "analyze skill" request and return its results.

    Args:
        instructions: sent under the ``"instructions"`` key of the JSON body.
        sub_instructions: sent under ``"sub_instructions"``.
        reasonings: sent under ``"reasonings"``.
        candidate_viewpoint_lists: sent under ``"candidate_viewpoint_lists"``.
        max_tokens: sent under ``"max_tokens"``.
        server_url: endpoint that receives the POST.

    Returns:
        The value stored under ``"results"`` in the JSON reply.

    Raises:
        httpx.HTTPStatusError: when the reply has an error status
            (raised by ``raise_for_status``).
    """
    request_body = dict(
        instructions=instructions,
        sub_instructions=sub_instructions,
        reasonings=reasonings,
        candidate_viewpoint_lists=candidate_viewpoint_lists,
        max_tokens=max_tokens,
    )
    # A fresh client per call; the request may take up to 180 seconds.
    async with httpx.AsyncClient(timeout=180.0) as http:
        reply = await http.post(server_url, json=request_body)
        reply.raise_for_status()
        results = reply.json()["results"]
    return results


async def process_batch_navigation(obs, instructions, previous_viewpoint_lists, candidate_viewpoint_lists, args, logger, step):
    """Ask the routing server for sub-instructions and per-sample routing weights.

    Two requests are issued in sequence: the "identify" request yields, per
    sample, a sub-instruction and a reasoning string; those are forwarded to
    the "analyze" request, whose replies are parsed by
    ``extract_from_response`` into routing weights.

    Both requests are best-effort: if either one fails, or the server returns
    fewer results than there are observations, the error is logged and
    ``([], [])`` is returned so the caller can detect the failure from the
    length of the result.

    Args:
        obs: batch of observation dicts; ``ob['scan']`` and ``ob['instr_id']``
            are read for logging only.
        instructions: one instruction per observation.
        previous_viewpoint_lists: per-observation payload for the identify request.
        candidate_viewpoint_lists: per-observation payload for the analyze request.
        args: configuration; reads ``debug``, ``resume_weights``,
            ``routing_weights_type`` and ``routing_mode``.
        logger: logger receiving the per-sample routing trace.
        step: current step index (logged only).

    Returns:
        ``(sub_instrs, loaded_weights)``: the sub-instruction strings and one
        weight list per observation, or ``([], [])`` on failure.
    """
    batch_size = len(obs)
    max_tokens = 2000 * batch_size

    try:
        identify_results_raw = await identify_instructions_with_qwen_batch(
            instructions, previous_viewpoint_lists, max_tokens
        )
    except Exception as e:
        logger.error(f"Failed to identify instructions: {e}")
        return [], []

    # Ensure each result is a parsed dict
    identify_results = []
    for i, r in enumerate(identify_results_raw):
        if isinstance(r, str):
            try:
                r = json.loads(r)
            except json.JSONDecodeError:
                logger.warning(f"Failed to decode identify_result at index {i}: {r}")
                r = {}
        if not isinstance(r, dict):
            # Valid JSON that is not an object (list, number, null, ...) has no
            # ``.get``; treat it like an undecodable reply.
            logger.warning(f"Unexpected identify_result type at index {i}: {type(r).__name__}")
            r = {}
        identify_results.append(r)

    sub_instrs = [r.get("Sub-instruction to be executed", "") for r in identify_results]
    reasonings = [r.get("Reasoning", "") for r in identify_results]

    # Same best-effort policy as the identify request above.
    try:
        skill_routings = await analyze_with_qwen_batch(
            instructions, sub_instrs, reasonings, candidate_viewpoint_lists, max_tokens,
        )
    except Exception as e:
        logger.error(f"Failed to analyze skill routing: {e}")
        return [], []

    if args.debug:
        print('-'*20)
        print(f"Sub-instructions: {sub_instrs}")
        print(f"Reasonings: {reasonings}")
        print(f"Skill routings: {skill_routings}")

    # The loop below indexes both lists by observation index.
    if len(sub_instrs) < batch_size or len(skill_routings) < batch_size:
        logger.error(
            f"Router returned too few results: {len(sub_instrs)} identify / "
            f"{len(skill_routings)} analyze for {batch_size} observations"
        )
        return [], []

    loaded_weights = []
    for i, ob in enumerate(obs):
        _, _, weights = extract_from_response(skill_routings[i])

        if not weights:
            # The fallback actually used is args.resume_weights.
            print(f"[Warning] Empty weights at index {i}, using uniform weights")
            weights = args.resume_weights

        if args.routing_weights_type == 'float':
            total = sum(weights)
            normalized_weights = [w / total if total > 0 else 1.0 / len(weights) for w in weights]
        elif args.routing_weights_type == 'int':
            normalized_weights = [1 if w > 0 else 0 for w in weights]
        else:
            normalized_weights = weights

        if args.routing_mode == 'shared':
            # Binary gating with the last entry always switched on.
            normalized_weights = [1 if w > 0 else 0 for w in weights]
            if normalized_weights:
                normalized_weights[-1] = 1

        if normalized_weights:
            loaded_weights.append(normalized_weights)
        else:
            loaded_weights.append(args.resume_weights)

        logger.info("-" * 20)
        logger.info(f"Step t: {step}")
        logger.info(f"Scan: {ob['scan']}")
        logger.info(f"Instruction ID: {ob['instr_id']}")
        logger.info(f"Reordering Plan instruction: {instructions[i]}")
        logger.info(f"Previous viewpoint list: {previous_viewpoint_lists[i]}")
        logger.info(f"Candidate viewpoint list: {candidate_viewpoint_lists[i]}")
        logger.info(f"Sub-instruction: {sub_instrs[i]}")
        logger.info(f"Reasoning: {reasonings[i]}")
        logger.info(f"Skill routing: {skill_routings[i]}")

    return sub_instrs, loaded_weights


class GMapNavAgents(Seq2SeqAgent):
    
    def __init__(self, args, env, rank):
        """Initialise the base agent, the router log file and the routing helpers.

        Args:
            args: configuration; ``log_dir`` and ``routing_mode`` are read here.
            env: forwarded to the base class.
            rank: forwarded to the base class.
        """
        super().__init__(args, env, rank)
        self.args = args

        # The log directory must exist before a file handler can be attached.
        log_dir = self.args.log_dir
        os.makedirs(log_dir, exist_ok=True)
        log_file_path = os.path.join(log_dir, "router_outputs.log")

        router_logger = logging.getLogger(__name__)
        router_logger.setLevel(logging.INFO)
        self.logger = router_logger

        # The logger is shared per module name, so only attach a handler once
        # even if several agents are constructed.
        if not router_logger.handlers:
            handler = logging.FileHandler(log_file_path, mode='w')
            handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
            router_logger.addHandler(handler)

        # Routing helpers are only created for non-fixed routing modes.
        if self.args.routing_mode == 'fixed':
            return

        # When the router is served over FastAPI the model does not have to be
        # loaded here (no initialize_qwen_model call).
        self.analyzer = NavigationSkillAnalyzer()
        self.vp_lookup = load_vp_lookup()
    
    def _build_model(self):
        """Create the VLN-BERT / critic networks and their optimizers.

        When ``args.resume_files`` is a list, one (VLN-BERT, critic) pair is
        built for every resume file whose entry in ``args.resume_weights`` is
        non-zero. A pair is only registered in ``self.vln_berts`` /
        ``self.critics`` (and the matching optimizer lists) when its resume
        file is present in ``self.resumes``, which ``self.load_resumes`` is
        expected to populate. Otherwise a single pair is built.

        In both cases ``self.vln_bert``, ``self.critic`` and their optimizers
        refer to the most recently built pair.

        Raises:
            ValueError: if ``args.optim`` is not one of
                ``'rms'``, ``'adam'``, ``'adamW'``, ``'sgd'``.
        """
        self.vln_berts = nn.ModuleList([])
        self.critics = nn.ModuleList([])  # Store multiple critic models if needed
        self.vln_bert_optimizers = []  # Store optimizers for multiple VLN-BERTs
        self.critic_optimizers = []  # Store optimizers for multiple critics

        optimizer_classes = {
            'rms': torch.optim.RMSprop,
            'adam': torch.optim.Adam,
            'adamW': torch.optim.AdamW,
            'sgd': torch.optim.SGD,
        }

        def build_pair():
            """Build one model pair plus optimizers and store them on ``self``."""
            self.vln_bert = VLNBert(self.args).cuda()
            self.critic = Critic(self.args).cuda()

            # Explicit raise instead of ``assert False``: asserts are removed
            # under ``python -O``, which would leave the optimizer undefined.
            if self.args.optim not in optimizer_classes:
                raise ValueError(
                    f"Unsupported optimizer '{self.args.optim}'; "
                    f"expected one of {sorted(optimizer_classes)}"
                )
            optimizer_cls = optimizer_classes[self.args.optim]

            self.vln_bert_optimizer = optimizer_cls(self.vln_bert.parameters(), lr=self.args.lr)
            self.critic_optimizer = optimizer_cls(self.critic.parameters(), lr=self.args.lr)

        if isinstance(self.args.resume_files, list):
            self.load_resumes(self.args.resume_files)  # Load all resume files into self.resumes
            for model_idx, resume_file in enumerate(self.args.resume_files):
                if self.args.resume_weights[model_idx] == 0:
                    continue  # Skip if weight is 0

                build_pair()

                if resume_file not in self.resumes:
                    # Previously dropped without any message.
                    print(f"WARNING: {resume_file} not found in loaded resumes; model is not registered")
                    continue

                self._load_model_weights(
                    resume_file, self.vln_bert, self.vln_bert_optimizer,
                    self.critic, self.critic_optimizer,
                )
                self.vln_berts.append(self.vln_bert)
                self.critics.append(self.critic)
                self.vln_bert_optimizers.append(self.vln_bert_optimizer)
                self.critic_optimizers.append(self.critic_optimizer)
        else:
            build_pair()

        self.scanvp_cands = {}

    def _load_model_weights(self, resume_file, vln_bert, vln_bert_optimizer, critic, critic_optimizer):
        """Loads model weights and optionally optimizer state from a resume file."""
        
        # print("-"*20)
        # print("Function `_load_model_weights` is called!")
        
        states = self.resumes.get(resume_file)
        
        if states:
            def recover_state(name, model, optimizer):
                if name in states:
                    state = model.state_dict()
                    model_keys = set(state.keys())
                    load_keys = set(states[name]['state_dict'].keys())
                    state_dict = states[name]['state_dict']

                    if model_keys != load_keys:
                        print(f"NOTICE: DIFFERENT KEYS IN {name} for {resume_file}")
                        if not list(model_keys)[0].startswith('module.') and list(load_keys)[0].startswith('module.'):
                            state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
                        if list(model_keys)[0].startswith('module.') and (not list(load_keys)[0].startswith('module.')):
                            state_dict = {'module.' + k: v for k, v in state_dict.items()}
                        same_state_dict = {}
                        extra_keys = []
                        
                        for k, v in state_dict.items():
                            if k in model_keys:
                                same_state_dict[k] = v
                            else:
                                extra_keys.append(k)
                        state_dict = same_state_dict
                        print(f'Extra keys in {name} state_dict for {resume_file}: {", ".join(extra_keys)}')

                    state.update(state_dict)
                    model.load_state_dict(state)
                    
                    # ### Display model info
                    # if self.args.debug:
                    #     print(f"✔️ Loaded model '{name}' from {resume_file}")
                    #     print(f"Model summary for {name} ({resume_file}):")
                    #     print(f"  - Total parameters: {sum(p.numel() for p in model.parameters())}")
                    #     print(f"  - Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
                    #     # print(f"  - First few keys in state_dict: {list(state_dict.keys())[:5]}")
                    #     print(f"\nSample weights from '{name}':")
                    #     printed = 0
                    #     for k, v in model.state_dict().items():
                    #         print(f" - {k}: shape={tuple(v.shape)}")
                    #         print(f"   values: {v.view(-1)[:5].tolist()}")  # Print first 5 values
                    #         printed += 1
                    #         if printed >= 5:
                    #             break  # Limit output to avoid clutter
                    #     print('-'*20)

                    if self.args.resume_optimizer and 'optimizer' in states[name]:
                        optimizer.load_state_dict(states[name]['optimizer'])
                else:
                    print(f"WARNING: Weights for {name} not found in {resume_file}")

            recover_state("vln_bert", vln_bert, vln_bert_optimizer)
            recover_state("critic", critic, critic_optimizer)
        else:
            print(f"WARNING: No state data found for {resume_file}")
                        
    def _language_variable(self, obs):
        """Pad the instruction encodings of a batch and move them to the GPU.

        Args:
            obs: batch of observation dicts; ``ob['instr_encoding']`` is read.

        Returns:
            Dict with ``'txt_ids'`` (int64 ids, zero-padded to the longest
            encoding in the batch) and ``'txt_masks'`` (bool, True on real
            tokens), both as CUDA tensors.
        """
        encodings = [ob['instr_encoding'] for ob in obs]
        lengths = [len(enc) for enc in encodings]
        width = max(lengths)

        ids = np.zeros((len(obs), width), dtype=np.int64)
        valid = np.zeros((len(obs), width), dtype=np.bool_)
        for row, (enc, n_tokens) in enumerate(zip(encodings, lengths)):
            ids[row, :n_tokens] = enc
            valid[row, :n_tokens] = True

        return {
            'txt_ids': torch.from_numpy(ids).long().cuda(),
            'txt_masks': torch.from_numpy(valid).cuda(),
        }

    def _panorama_feature_variable(self, obs):
        ''' Extract precomputed panorama features into padded batch tensors.

        For every observation the candidate views come first (nav type 1),
        followed by the remaining views of ``ob['feature']`` whose index is
        not used by a candidate (nav type 0). Each feature vector is split at
        ``args.image_feat_size`` into an image part and an angle part; the
        angle part is extended with three constant ones to form the location
        feature.

        The number of non-candidate entries in ``nav_types`` is taken from the
        views actually appended instead of assuming 36 views per panorama, so
        ``nav_types`` always has one entry per stacked view.

        :param obs: batch of observation dicts; reads ``candidate`` (with
            ``feature``, ``viewpointId``, ``pointId``) and ``feature``.
        :return: dict with ``view_img_fts``, ``loc_fts``, ``nav_types``,
            ``view_lens`` (CUDA tensors) and ``cand_vpids`` (list of lists).
        '''
        feat_dim = self.args.image_feat_size
        batch_view_img_fts, batch_loc_fts, batch_nav_types = [], [], []
        batch_view_lens, batch_cand_vpids = [], []

        for ob in obs:
            view_img_fts, view_ang_fts, nav_types, cand_vpids = [], [], [], []
            # cand views
            used_viewidxs = set()
            for cc in ob['candidate']:
                view_img_fts.append(cc['feature'][:feat_dim])
                view_ang_fts.append(cc['feature'][feat_dim:])
                nav_types.append(1)
                cand_vpids.append(cc['viewpointId'])
                used_viewidxs.add(cc['pointId'])
            # non cand views (single pass over the panorama)
            non_cand_fts = [
                x for k, x in enumerate(ob['feature']) if k not in used_viewidxs
            ]
            view_img_fts.extend(x[:feat_dim] for x in non_cand_fts)
            view_ang_fts.extend(x[feat_dim:] for x in non_cand_fts)
            nav_types.extend([0] * len(non_cand_fts))
            # combine cand views and noncand views
            view_img_fts = np.stack(view_img_fts, 0)    # (n_views, dim_ft)
            view_ang_fts = np.stack(view_ang_fts, 0)
            view_box_fts = np.ones((len(view_img_fts), 3), dtype=np.float32)
            view_loc_fts = np.concatenate([view_ang_fts, view_box_fts], 1)

            batch_view_img_fts.append(torch.from_numpy(view_img_fts))
            batch_loc_fts.append(torch.from_numpy(view_loc_fts))
            batch_nav_types.append(torch.LongTensor(nav_types))
            batch_cand_vpids.append(cand_vpids)
            batch_view_lens.append(len(view_img_fts))

        # pad features to max_len
        batch_view_img_fts = pad_tensors(batch_view_img_fts).cuda()
        batch_loc_fts = pad_tensors(batch_loc_fts).cuda()
        batch_nav_types = pad_sequence(batch_nav_types, batch_first=True, padding_value=0).cuda()
        batch_view_lens = torch.LongTensor(batch_view_lens).cuda()

        return {
            'view_img_fts': batch_view_img_fts, 'loc_fts': batch_loc_fts,
            'nav_types': batch_nav_types, 'view_lens': batch_view_lens,
            'cand_vpids': batch_cand_vpids,
        }

    def _nav_gmap_variable(self, obs, gmaps):
        """Collate the graph-map inputs of every sample into batch tensors.

        Each sample's node list starts with ``None`` (the [stop] entry)
        followed by graph nodes: visited then unvisited nodes when
        ``args.enc_full_graph`` is set, otherwise unvisited nodes only. With
        ``args.act_visited_nodes`` only the current viewpoint counts as
        visited; otherwise ``gmap.graph.visited`` decides.

        Args:
            obs: batch of observation dicts (``viewpoint``, ``heading``,
                ``elevation`` are read).
            gmaps: one graph map per observation.

        Returns:
            Dict with ``gmap_vpids``, ``gmap_img_embeds``, ``gmap_step_ids``,
            ``gmap_pos_fts``, ``gmap_visited_masks``, ``gmap_pair_dists``,
            ``gmap_masks`` and ``no_vp_left``.
        """
        # [stop] + gmap_vpids
        batch_size = len(obs)

        vpids_per_sample, lens_per_sample = [], []
        embeds_per_sample, step_ids_per_sample, pos_fts_per_sample = [], [], []
        dists_per_sample, visited_masks_per_sample = [], []
        batch_no_vp_left = []

        for b, gmap in enumerate(gmaps):
            visited_vpids, unvisited_vpids = [], []
            for vp in gmap.node_positions.keys():
                if self.args.act_visited_nodes:
                    seen = vp == obs[b]['viewpoint']
                else:
                    seen = gmap.graph.visited(vp)
                if seen:
                    visited_vpids.append(vp)
                else:
                    unvisited_vpids.append(vp)
            batch_no_vp_left.append(not unvisited_vpids)

            if self.args.enc_full_graph:
                gmap_vpids = [None] + visited_vpids + unvisited_vpids
                gmap_visited_masks = [0] + [1] * len(visited_vpids) + [0] * len(unvisited_vpids)
            else:
                gmap_vpids = [None] + unvisited_vpids
                gmap_visited_masks = [0] * len(gmap_vpids)
            n_nodes = len(gmap_vpids)

            gmap_step_ids = [gmap.node_step_ids.get(vp, 0) for vp in gmap_vpids]
            node_embeds = [gmap.get_node_embed(vp) for vp in gmap_vpids[1:]]
            # The [stop] entry gets an all-zero embedding shaped like a node's.
            stop_embed = torch.zeros_like(node_embeds[0])
            gmap_img_embeds = torch.stack([stop_embed] + node_embeds, 0)   # cuda

            gmap_pos_fts = gmap.get_pos_fts(
                obs[b]['viewpoint'], gmap_vpids, obs[b]['heading'], obs[b]['elevation'],
            )

            # Symmetric pairwise distances; row/column 0 ([stop]) stays zero.
            pair_dists = np.zeros((n_nodes, n_nodes), dtype=np.float32)
            for row in range(1, n_nodes):
                for col in range(row + 1, n_nodes):
                    d = gmap.graph.distance(gmap_vpids[row], gmap_vpids[col])
                    pair_dists[row, col] = d
                    pair_dists[col, row] = d

            embeds_per_sample.append(gmap_img_embeds)
            step_ids_per_sample.append(torch.LongTensor(gmap_step_ids))
            pos_fts_per_sample.append(torch.from_numpy(gmap_pos_fts))
            dists_per_sample.append(torch.from_numpy(pair_dists))
            visited_masks_per_sample.append(torch.BoolTensor(gmap_visited_masks))
            vpids_per_sample.append(gmap_vpids)
            lens_per_sample.append(n_nodes)

        # collate
        batch_gmap_lens = torch.LongTensor(lens_per_sample)
        batch_gmap_masks = gen_seq_masks(batch_gmap_lens).cuda()
        batch_gmap_img_embeds = pad_tensors_wgrad(embeds_per_sample)
        batch_gmap_step_ids = pad_sequence(step_ids_per_sample, batch_first=True).cuda()
        batch_gmap_pos_fts = pad_tensors(pos_fts_per_sample).cuda()
        batch_gmap_visited_masks = pad_sequence(visited_masks_per_sample, batch_first=True).cuda()

        max_gmap_len = max(batch_gmap_lens)
        gmap_pair_dists = torch.zeros(batch_size, max_gmap_len, max_gmap_len).float()
        for b in range(batch_size):
            n = batch_gmap_lens[b]
            gmap_pair_dists[b, :n, :n] = dists_per_sample[b]
        gmap_pair_dists = gmap_pair_dists.cuda()

        return {
            'gmap_vpids': vpids_per_sample, 'gmap_img_embeds': batch_gmap_img_embeds,
            'gmap_step_ids': batch_gmap_step_ids, 'gmap_pos_fts': batch_gmap_pos_fts,
            'gmap_visited_masks': batch_gmap_visited_masks,
            'gmap_pair_dists': gmap_pair_dists, 'gmap_masks': batch_gmap_masks,
            'no_vp_left': batch_no_vp_left,
        }

    def _nav_vp_variable(self, obs, gmaps, pano_embeds, cand_vpids, view_lens, nav_types):
        """Build the viewpoint-level inputs, with a [stop] entry prepended.

        Args:
            obs: batch of observation dicts (``viewpoint``, ``heading``,
                ``elevation`` are read).
            gmaps: one graph map per observation.
            pano_embeds: panorama embeddings; a zero [stop] slot is prepended
                along dimension 1.
            cand_vpids: per-sample candidate viewpoint ids.
            view_lens: per-sample number of views (incremented by one for the
                [stop] slot when building the mask).
            nav_types: per-view type tensor; entries equal to 1 are navigable.

        Returns:
            Dict with ``vp_img_embeds``, ``vp_pos_fts``, ``vp_masks``,
            ``vp_nav_masks`` and ``vp_cand_vpids``.
        """
        batch_size = len(obs)

        # add [stop] token
        stop_embed = torch.zeros_like(pano_embeds[:, :1])
        vp_img_embeds = torch.cat([stop_embed, pano_embeds], 1)
        n_slots = vp_img_embeds.size(1)

        pos_fts_per_sample = []
        for idx, gmap in enumerate(gmaps):
            ob = obs[idx]
            cand_pos = gmap.get_pos_fts(
                ob['viewpoint'], cand_vpids[idx], ob['heading'], ob['elevation']
            )
            start_pos = gmap.get_pos_fts(
                ob['viewpoint'], [gmap.start_vp], ob['heading'], ob['elevation']
            )
            # Columns 0-6: features w.r.t. the start viewpoint (every slot);
            # columns 7-13: candidate features, shifted by one for [stop].
            pos_fts = np.zeros((n_slots, 14), dtype=np.float32)
            pos_fts[:, :7] = start_pos
            pos_fts[1:len(cand_pos)+1, 7:] = cand_pos
            pos_fts_per_sample.append(torch.from_numpy(pos_fts))

        batch_vp_pos_fts = pad_tensors(pos_fts_per_sample).cuda()

        # [stop] is always selectable; the rest follow nav_types.
        stop_mask = torch.ones(batch_size, 1).bool().cuda()
        vp_nav_masks = torch.cat([stop_mask, nav_types == 1], 1)

        return {
            'vp_img_embeds': vp_img_embeds,
            'vp_pos_fts': batch_vp_pos_fts,
            'vp_masks': gen_seq_masks(view_lens+1),
            'vp_nav_masks': vp_nav_masks,
            'vp_cand_vpids': [[None]+x for x in cand_vpids],
        }

    def _teacher_action(self, obs, vpids, ended, visited_masks=None):
        """
        Extract teacher actions into variable.

        For every unfinished sample the teacher picks index 0 (stop) when the
        agent stands on the last ground-truth viewpoint; otherwise it picks
        the candidate index minimising (distance from the current viewpoint
        to the candidate) + (distance from the candidate to the goal).

        :param obs: The observation.
        :param vpids: Per-sample list of candidate viewpoint ids; index 0 is skipped.
        :param ended: Whether the action seq is ended
        :param visited_masks: Optional per-sample masks; truthy entries are excluded.
        :return: CUDA int64 tensor of action indices (``args.ignoreid`` for
            ended samples or when no candidate is eligible).
        """
        actions = np.zeros(len(obs), dtype=np.int64)
        for idx, ob in enumerate(obs):
            if ended[idx]:
                # Just ignore this index
                actions[idx] = self.args.ignoreid
                continue

            goal_vp = ob['gt_path'][-1]
            if ob['viewpoint'] == goal_vp:
                actions[idx] = 0    # Stop if arrived
                continue

            scan = ob['scan']
            cur_vp = ob['viewpoint']
            best_idx, best_dist = self.args.ignoreid, float('inf')
            for j, vpid in enumerate(vpids[idx]):
                if j == 0:
                    continue
                if visited_masks is not None and visited_masks[idx][j]:
                    continue
                dist = self.env.shortest_distances[scan][vpid][goal_vp] \
                        + self.env.shortest_distances[scan][cur_vp][vpid]
                if dist < best_dist:
                    best_dist = dist
                    best_idx = j

            actions[idx] = best_idx
            if best_idx == self.args.ignoreid:
                print('scan %s: all vps are searched' % (scan))

        return torch.from_numpy(actions).cuda()

    def _teacher_action_r4r(
        self, obs, vpids, ended, visited_masks=None, imitation_learning=False, t=None, traj=None
    ):
        """Compute teacher actions for R4R-style paths.

        R4R is not the shortest path. The goal location can be visited nodes.

        With ``imitation_learning`` the teacher follows ``ob['gt_path']`` step
        by step (``t`` is the index of the current ground-truth viewpoint).
        Otherwise the candidate with the lowest cost under
        ``args.expert_policy`` is chosen: ``'ndtw'`` uses the negated nDTW of
        the walked path extended by the shortest path to the candidate;
        ``'spl'`` uses the distance via the candidate to the goal.

        :param obs: The observation.
        :param vpids: Per-sample list of candidate viewpoint ids; index 0 is skipped.
        :param ended: Whether the action seq is ended
        :param visited_masks: Optional per-sample masks; truthy entries are excluded.
        :param imitation_learning: Follow the ground-truth path directly.
        :param t: Current step index (imitation learning only).
        :param traj: Per-sample trajectory dicts; ``traj[i]['path']`` is read
            for the ``'ndtw'`` policy.
        :return: CUDA int64 tensor of action indices (``args.ignoreid`` for
            ended samples or when no candidate is eligible).
        :raises ValueError: if a candidate must be scored and
            ``args.expert_policy`` is neither ``'ndtw'`` nor ``'spl'``.
        """
        a = np.zeros(len(obs), dtype=np.int64)
        for i, ob in enumerate(obs):
            if ended[i]:                                            # Just ignore this index
                a[i] = self.args.ignoreid
                continue

            if imitation_learning:
                assert ob['viewpoint'] == ob['gt_path'][t]
                if t == len(ob['gt_path']) - 1:
                    a[i] = 0    # stop
                else:
                    goal_vp = ob['gt_path'][t + 1]
                    for j, vpid in enumerate(vpids[i]):
                        if goal_vp == vpid:
                            a[i] = j
                            break
                continue

            if ob['viewpoint'] == ob['gt_path'][-1]:
                a[i] = 0    # Stop if arrived
                continue

            scan = ob['scan']
            cur_vp = ob['viewpoint']
            # Flattened walked path; identical for every candidate of this
            # sample, so it is built at most once (and only for 'ndtw').
            walked_path = None
            min_idx, min_dist = self.args.ignoreid, float('inf')
            for j, vpid in enumerate(vpids[i]):
                if j == 0 or (visited_masks is not None and visited_masks[i][j]):
                    continue
                if self.args.expert_policy == 'ndtw':
                    if walked_path is None:
                        walked_path = [vp for segment in traj[i]['path'] for vp in segment]
                    dist = - cal_dtw(
                        self.env.shortest_distances[scan],
                        walked_path + self.env.shortest_paths[scan][cur_vp][vpid][1:],
                        ob['gt_path'],
                        threshold=3.0
                    )['nDTW']
                elif self.args.expert_policy == 'spl':
                    dist = self.env.shortest_distances[scan][vpid][ob['gt_path'][-1]] \
                            + self.env.shortest_distances[scan][cur_vp][vpid]
                else:
                    # Previously fell through with ``dist`` unbound.
                    raise ValueError(
                        f"Unsupported expert_policy '{self.args.expert_policy}'; "
                        "expected 'ndtw' or 'spl'"
                    )
                if dist < min_dist:
                    min_dist = dist
                    min_idx = j
            a[i] = min_idx
            if min_idx == self.args.ignoreid:
                print('scan %s: all vps are searched' % (scan))
        return torch.from_numpy(a).cuda()

    def make_equiv_action(self, a_t, gmaps, obs, traj=None):
        """
        Interface between Panoramic view and Egocentric view.

        Converts each panoramic action in ``a_t`` (a target viewpoint id, or
        ``None`` for <stop>) into a simulator call: the graph path to the
        target is appended to ``traj[i]['path']`` and the simulator is placed
        on the target with the heading/elevation of the view through which
        the target was seen from the preceding viewpoint.
        """
        step_angle = math.radians(30)
        for idx, ob in enumerate(obs):
            target_vp = a_t[idx]
            if target_vp is None:            # None is the <stop> action
                continue

            segments = traj[idx]['path']
            segments.append(gmaps[idx].graph.path(ob['viewpoint'], target_vp))
            last_segment = segments[-1]
            if len(last_segment) == 1:
                # One-node segment: the preceding viewpoint is the end of the
                # previous segment.
                prev_vp = segments[-2][-1]
            else:
                prev_vp = last_segment[-2]

            viewidx = self.scanvp_cands['%s_%s'%(ob['scan'], prev_vp)][target_vp]
            # View index -> 12 headings per elevation row, 30 degrees apart.
            elev_row, head_col = divmod(viewidx, 12)
            heading = head_col * step_angle
            elevation = (elev_row - 1) * step_angle
            self.env.env.sims[idx].newEpisode([ob['scan']], [target_vp], [heading], [elevation])

    def _update_scanvp_cands(self, obs):
        for ob in obs:
            scan = ob['scan']
            vp = ob['viewpoint']
            scanvp = '%s_%s' % (scan, vp)
            self.scanvp_cands.setdefault(scanvp, {})
            for cand in ob['candidate']:
                self.scanvp_cands[scanvp].setdefault(cand['viewpointId'], {})
                self.scanvp_cands[scanvp][cand['viewpointId']] = cand['pointId']

    # @profile
    def rollout(self, train_ml=None, train_rl=False, reset=True):

        if reset:  # Reset env
            obs = self.env.reset()
        else:
            obs = self.env._get_obs()
        self._update_scanvp_cands(obs)

        batch_size = len(obs)
        # build graph: keep the start viewpoint
        gmaps = [GraphMap(ob['viewpoint']) for ob in obs]
        for i, ob in enumerate(obs):
            gmaps[i].update_graph(ob)

        # Record the navigation path
        traj = [{
            'instr_id': ob['instr_id'],
            'path': [[ob['viewpoint']]],
            'details': {},
        } for ob in obs]
        
        # Original instructions
        oringal_instructions = [ob['instruction'] for ob in obs]
        
        # instructions = oringal_instructions
        
        if self.args.routing_mode != 'fixed' and 'reordering' not in self.args.val_env_names[0]:
            instructions = generate_temporal_instructions(oringal_instructions, max_retries=3, max_tokens=2000, temperature=0, model="gpt-4o", num_threads=batch_size)
        else:
            instructions = oringal_instructions
        
        # Language input: txt_ids, txt_masks
        language_inputs = self._language_variable(obs)
        # txt_embeds = self.vln_bert('language', language_inputs)
        
        # MoE: Convert txt_embeds to txt_embeds_list and then apply resume_weights
        txt_embeds_list = [vln_bert('language', language_inputs) for vln_bert in self.vln_berts]

        # if self.args.debug:
        #     self.logger.info("-"*50)
        #     self.logger.info(f"language inputs: {language_inputs}")
        #     self.logger.info(f"language embedding {txt_embeds_list[0]}")
                
        # Initialization the tracking state
        ended = np.array([False] * batch_size)
        just_ended = np.array([False] * batch_size)

        # Init the logs
        masks = []
        entropys = []
        ml_loss = 0.     

        for t in range(self.args.max_action_len):
            
            # if self.args.routing_mode == 'moe' or self.args.routing_mode == 'shared':
            if self.args.routing_mode != 'fixed':

                ### Get loaded_weights from traj
                original_paths = [traj[i]['path'] for i, ob in enumerate(obs)] 
                # paths = [original_paths[i][0] for i in range(len(original_paths))] # 
                paths = [[vp for step in traj_steps for vp in step] for traj_steps in original_paths]
                 
                previous_viewpoint_lists = [convert_path2img(paths[i], ob['scan'], self.vp_lookup) for i, ob in enumerate(obs)]
                candidate_viewpoint_lists = [extract_cand_img(paths[i][-1], ob['scan'], self.vp_lookup) for i, ob in enumerate(obs)]
                
                if self.args.debug:
                    print('-'*20)
                    print(f"Step t: {t}")
                    print(f"Oringinal instructions: {oringal_instructions}")
                    print(f"Instructions: {instructions}")
                    # print(f"Previous viewpoint lists: {previous_viewpoint_lists}")
                    # print(f"Candidate viewpoint lists: {candidate_viewpoint_lists}")
                
                sub_instructions = []
                loaded_weights = []
                
                sub_instructions, loaded_weights = asyncio.run(
                    process_batch_navigation(
                        obs, instructions, previous_viewpoint_lists, candidate_viewpoint_lists, self.args, self.logger, t
                    )
                )
                
                if len(loaded_weights) != len(obs):
                    print(f"[Warning] Length of loaded_weights ({len(loaded_weights)}) does not match batch size ({len(obs)})")
                    loaded_weights = [self.args.resume_weights] * len(obs)
                
                
                if self.args.debug:
                    print('-'*20)
                    print(f"Sub-instructions: {sub_instructions}")  
                    print(f"Loaded weights: {loaded_weights}")
                                
                '''           
                for i, ob in enumerate(obs):
                    identify_result = self.analyzer.identify_instructions_with_qwen(
                        instruction = instructions[i], 
                        previous_viewpoint_list= previous_viewpoint_lists[i]
                    )
                    
                    if isinstance(identify_result, str):
                        try:
                            identify_result = json.loads(identify_result)
                        except json.JSONDecodeError:
                            print("Failed to decode identify_result:", identify_result)
                            identify_result = {}
                    
                    sub_instruction = identify_result.get("Sub-instruction to be executed", "")
                    reasoning = identify_result.get("Reasoning", "")
                    
                    skill_routing = self.analyzer.analyze_with_qwen(
                        instructions[i], 
                        sub_instruction, 
                        reasoning, 
                        previous_viewpoint_list = None, 
                        candidate_viewpoint_list = candidate_viewpoint_lists[i]
                    )
                    _, _, weights = self.analyzer.extract_from_response(skill_routing)
                    
                    # Ensure weights is not empty or malformed
                    if not weights or len(weights) == 0:
                        print(f"[Warning] Empty weights at index {i}, using uniform weights")
                        weights = self.args.resume_weights
                    
                    if self.args.routing_weights_type == 'float':
                        total = sum(weights)
                        if total == 0:
                            print(f"[Warning] Sum of weights is zero at index {i}, using uniform weights")
                            normalized_weights = [1.0 / len(weights)] * len(weights)
                        else:
                            normalized_weights = [w / total for w in weights]
                    elif self.args.routing_weights_type == 'int':
                        normalized_weights = [1 if w > 0 else 0 for w in weights]
                    
                    if self.args.routing_mode == 'shared':
                        normalized_weights = [1 if w > 0 else 0 for w in weights]
                        normalized_weights[-1] = 1
                    
                    sub_instructions.append(sub_instruction)
                    loaded_weights.append(normalized_weights)
                    
                    self.logger.info("-"*20)
                    self.logger.info(f"Step t: {t}")
                    self.logger.info(f"Scan: {ob['scan']}")
                    self.logger.info(f"Original instruction: {instructions[i]}")
                    self.logger.info(f"Previous viewpoint list: {previous_viewpoint_lists[i]}")
                    self.logger.info(f"Candidate viewpoint list: {candidate_viewpoint_lists[i]}")
                    self.logger.info(f"Sub-instruction: {sub_instruction}")
                    self.logger.info(f"Reasoning: {reasoning}")
                    self.logger.info(f"Skill routing: {skill_routing}")
                '''
                
            for i, gmap in enumerate(gmaps):
                if not ended[i]:
                    gmap.node_step_ids[obs[i]['viewpoint']] = t + 1
                                
            # graph representation
            pano_inputs = self._panorama_feature_variable(obs)
            
            # pano_embeds, pano_masks = self.vln_bert('panorama', pano_inputs) # pano_embeds: (bs, n_panos, dim)
            # avg_pano_embeds = torch.sum(pano_embeds * pano_masks.unsqueeze(2), 1) / \
            #                   torch.sum(pano_masks, 1, keepdim=True)
                              
            # Convert pano_embeds to pano_embeds_list and then apply resume_weights

            pano_embeds_list = []
            # pano_masks_list = []
            pano_masks = None

            # Collect pano_embeds from each expert
            for vln_bert in self.vln_berts:
                pano_embeds_i, pano_masks = vln_bert('panorama', pano_inputs)
                pano_embeds_list.append(pano_embeds_i)
                # pano_masks_list.append(pano_masks_i)
            
            # ### Solution1: choose the first
            # pano_embeds = pano_embeds_list[0]
            
            # ### Solution2: choose the max weighted pano_embeds
            # max_id = np.argmax(self.args.resume_weights)
            # pano_embeds = pano_embeds_list[max_id]
            
            ### Solution3: weighted combination of pano_embeds
            pano_embeds = torch.zeros_like(pano_embeds_list[0])
            for model_idx, pano_embeds_i in enumerate(pano_embeds_list):
                # if self.args.resume_weights[model_idx] != 0:
                pano_embeds += self.loaded_weights[model_idx] * pano_embeds_i
            
            avg_pano_embeds = torch.sum(pano_embeds * pano_masks.unsqueeze(2), 1) / \
                  torch.sum(pano_masks, 1, keepdim=True)
                  
            for i, gmap in enumerate(gmaps):
                if not ended[i]:
                    # update visited node
                    i_vp = obs[i]['viewpoint']
                    gmap.update_node_embed(i_vp, avg_pano_embeds[i], rewrite=True)
                    # update unvisited nodes
                    for j, i_cand_vp in enumerate(pano_inputs['cand_vpids'][i]): # traverse all candidates
                        if not gmap.graph.visited(i_cand_vp):
                            gmap.update_node_embed(i_cand_vp, pano_embeds[i, j])

            # navigation policy
            # nav_inputs = self._nav_gmap_variable(obs, gmaps)
            # nav_inputs.update(
            #     self._nav_vp_variable(
            #         obs, gmaps, pano_embeds, pano_inputs['cand_vpids'], 
            #         pano_inputs['view_lens'], pano_inputs['nav_types'],
            #     )
            # )
            
            # top_model_indices = [max(range(len(weights)), key=lambda k: weights[k]) for weights in loaded_weights]

            
            nav_logits_list = []
            for model_idx, vln_bert in enumerate(self.vln_berts):
                # if self.args.resume_weights[model_idx] == 0: continue # Skip if weight is 0
                
                nav_inputs = self._nav_gmap_variable(obs, gmaps)
                nav_inputs.update(
                    self._nav_vp_variable(
                        obs, gmaps, pano_embeds_list[model_idx], pano_inputs['cand_vpids'], 
                        pano_inputs['view_lens'], pano_inputs['nav_types'],
                    )
                )
                nav_inputs.update({
                    'txt_embeds': txt_embeds_list[model_idx],  # Use individual txt_embeds
                    'txt_masks': language_inputs['txt_masks'],
                })
                
                # if model_idx == 0 and self.args.debug:
                #     self.logger.info("-"*50)
                #     self.logger.info(f"nav_inputs: {nav_inputs}")
                
                nav_outs = vln_bert('navigation', nav_inputs) # there are some inf in nav_outs
                
                if self.args.fusion == 'local':
                    nav_logits = nav_outs['local_logits']
                    nav_vpids = nav_inputs['vp_cand_vpids']
                elif self.args.fusion == 'global':
                    nav_logits = nav_outs['global_logits']
                    nav_vpids = nav_inputs['gmap_vpids']
                else:
                    nav_logits = nav_outs['fused_logits'] # Dynamic Fused
                    nav_vpids = nav_inputs['gmap_vpids']
                    
                # nav_logits_list.append(nav_logits.detach())
                # print(type(nav_logits))
                nav_logits_list.append(nav_logits)
                # shared_mask = (nav_logits_list[0] != float('-inf')).float()
                    
            ### Solution1
            # nav_logits = torch.stack(nav_logits_list).mean(0)  # Averaging the logits
            # nav_probs = torch.softmax(nav_logits, dim=1)
            
            ### Solution2: Like LOViS: Weighted combination of logits
            if self.args.routing_mode == 'fixed':
                weighted_logits = torch.zeros_like(nav_logits_list[0])
                for model_idx, nav_logits in enumerate(nav_logits_list):
                    weighted_logits += self.loaded_weights[model_idx] * nav_logits 
                
                nav_probs = torch.softmax(weighted_logits, dim=1)
            
            ### Solution3: Use Router to activate the agents
            if self.args.routing_mode != 'fixed':
                '''
                # 3.a. argmax weights
                weighted_logits_list = []

                for i, ob in enumerate(obs):
                    weights_i = loaded_weights[i]

                    # Pick the model index with the highest weight
                    top_model_idx = max(range(len(weights_i)), key=lambda k: weights_i[k])
                    top_model_logits = nav_logits_list[top_model_idx][i]

                    # Ensure shape: (1, num_candidates)
                    if top_model_logits.dim() == 1:
                        top_model_logits = top_model_logits.unsqueeze(0)
                    elif top_model_logits.dim() != 2:
                        raise ValueError(f"Unexpected shape for top_model_logits: {top_model_logits.shape}")

                    weighted_logits_list.append(top_model_logits)

                # Combine all into (batch_size, num_candidates)
                weighted_logits = torch.cat(weighted_logits_list, dim=0)
                nav_probs = torch.softmax(weighted_logits, dim=1)
                '''
                # '''
                # 3.b.
                weighted_logits_list = []
                for i, ob in enumerate(obs):
                    weights_i = loaded_weights[i]
                    weighted_logits_i = None
                    
                    for model_idx, logits in enumerate(nav_logits_list):
                        if weights_i[model_idx] == 0:
                            continue
        
                        if weighted_logits_i is None:
                            weighted_logits_i = weights_i[model_idx] * logits[i]
                        else:
                            weighted_logits_i += weights_i[model_idx] * logits[i]
                            
                    
                    if weighted_logits_i is None:
                        print(f"[Warning] weighted_logits_i is None at ob index {i}, using uniform average")
                        # # Fallback: uniform average of all logits
                        num_models = len(nav_logits_list)
                        weighted_logits_i = sum(logits[i] for logits in nav_logits_list) / num_models
                        
                        # # Using last one nav_logits
                        # weighted_logits_i = nav_logits_list[-1]
                    
                    # Ensure shape: (1, num_candidates)
                    if weighted_logits_i.dim() == 1:
                        weighted_logits_i = weighted_logits_i.unsqueeze(0)
                    elif weighted_logits_i.dim() != 2:
                        raise ValueError(f"Unexpected shape for weighted_logits_i: {weighted_logits_i.shape}")
                                
                    # weighted_logits_list.append(weighted_logits_i) # retain batch dim
                    weighted_logits_list.append(weighted_logits_i)  # Shape: (1, num_candidates)
                
                weighted_logits = torch.cat(weighted_logits_list, dim=0) # shape: (batch_size, num_candidates)
                # print("weighted_logits shape:", weighted_logits.shape)
                nav_probs = torch.softmax(weighted_logits, dim=1)
                # '''               
                
            # if self.args.debug:
            #     self.logger.info("-"*20)
            #     # self.logger.info(nav_logits_list[0])
            #     self.logger.info(f"weighted_logits: {weighted_logits}")
            #     self.logger.info(f"nav_probs: {nav_probs}")
            #     self.logger.info("-"*20)
            
            # update graph
            for i, gmap in enumerate(gmaps):
                if not ended[i]:
                    i_vp = obs[i]['viewpoint']
                    gmap.node_stop_scores[i_vp] = {
                       'stop': nav_probs[i, 0].data.item()
                    }
                                        
            if train_ml is not None:
                # Supervised training
                if self.args.dataset == 'r2r':
                    # nav_targets = self._teacher_action(
                    #     obs, nav_vpids, ended, 
                    #     visited_masks=nav_inputs['gmap_visited_masks'] if self.args.fusion != 'local' else None
                    # )
                    nav_targets = self._teacher_action_r4r(
                        obs, nav_vpids, ended, 
                        visited_masks=nav_inputs['gmap_visited_masks'] if self.args.fusion != 'local' else None,
                        imitation_learning=(self.feedback=='teacher'), t=t, traj=traj
                    )
                elif self.args.dataset == 'r4r':
                    nav_targets = self._teacher_action_r4r(
                        obs, nav_vpids, ended, 
                        visited_masks=nav_inputs['gmap_visited_masks'] if self.args.fusion != 'local' else None,
                        imitation_learning=(self.feedback=='teacher'), t=t, traj=traj
                    )
                # print(t, nav_logits, nav_targets)
                ml_loss += self.criterion(nav_logits, nav_targets)
                # print(t, 'ml_loss', ml_loss.item(), self.criterion(nav_logits, nav_targets).item())
                                                 
            # Determine the next navigation viewpoint
            if self.feedback == 'teacher':
                a_t = nav_targets                 # teacher forcing
            elif self.feedback == 'argmax':
                # _, a_t = nav_logits.max(1)
                _, a_t = weighted_logits.max(1)        # student forcing - argmax
                a_t = a_t.detach() 
            elif self.feedback == 'sample':
                c = torch.distributions.Categorical(nav_probs)
                self.logs['entropy'].append(c.entropy().sum().item())            # For log
                entropys.append(c.entropy())                                     # For optimization
                a_t = c.sample().detach() 
            elif self.feedback == 'expl_sample':
                _, a_t = nav_probs.max(1)
                rand_explores = np.random.rand(batch_size, ) > self.args.expl_max_ratio  # hyper-param
                if self.args.fusion == 'local':
                    cpu_nav_masks = nav_inputs['vp_nav_masks'].data.cpu().numpy()
                else:
                    cpu_nav_masks = (nav_inputs['gmap_masks'] * nav_inputs['gmap_visited_masks'].logical_not()).data.cpu().numpy()
                for i in range(batch_size):
                    if rand_explores[i]:
                        cand_a_t = np.arange(len(cpu_nav_masks[i]))[cpu_nav_masks[i]]
                        a_t[i] = np.random.choice(cand_a_t)
            else:
                print(self.feedback)
                sys.exit('Invalid feedback option')

            # Determine stop actions
            if self.feedback == 'teacher' or self.feedback == 'sample': # in training
                # a_t_stop = [ob['viewpoint'] in ob['gt_end_vps'] for ob in obs]
                a_t_stop = [ob['viewpoint'] == ob['gt_path'][-1] for ob in obs] #	In training, the agent stops when it reaches the last ground-truth viewpoint.
            else:
                a_t_stop = a_t == 0 # In inference, it stops when a_t == 0.

            # Prepare environment action
            cpu_a_t = []  
            for i in range(batch_size):
                if a_t_stop[i] or ended[i] or nav_inputs['no_vp_left'][i] or (t == self.args.max_action_len):
                # if a_t_stop[i] or ended[i] or nav_inputs['no_vp_left'][i] or (t == self.args.max_action_len - 1): # determines if an action should be set to None (which means no movement will happen)
                    cpu_a_t.append(None)
                    just_ended[i] = True
                else:
                    cpu_a_t.append(nav_vpids[i][a_t[i]])   

            # Make action and get the new state
            self.make_equiv_action(cpu_a_t, gmaps, obs, traj)
            for i in range(batch_size):
                if (not ended[i]) and just_ended[i]:
                    stop_node, stop_score = None, {'stop': -float('inf')}
                    for k, v in gmaps[i].node_stop_scores.items():
                        if v['stop'] > stop_score['stop']:
                            stop_score = v
                            stop_node = k
                    if stop_node is not None and obs[i]['viewpoint'] != stop_node: # If stop_node is found and it is not the current viewpoint, the function adds a path to that stop node.
                        traj[i]['path'].append(gmaps[i].graph.path(obs[i]['viewpoint'], stop_node))
                    if self.args.detailed_output:
                        for k, v in gmaps[i].node_stop_scores.items():
                            traj[i]['details'][k] = {
                                'stop_prob': float(v['stop']),
                            }

            # new observation and update graph
            obs = self.env._get_obs()
            self._update_scanvp_cands(obs)
            for i, ob in enumerate(obs):
                if not ended[i]:
                    gmaps[i].update_graph(ob)

            ended[:] = np.logical_or(ended, np.array([x is None for x in cpu_a_t]))

            # Early exit if all ended
            if ended.all():
                break
        
        if train_ml is not None:
            ml_loss = ml_loss * train_ml / batch_size
            self.loss += ml_loss
            self.logs['IL_loss'].append(ml_loss.item())

        return traj
