''' Batched Room-to-Room navigation environment '''

import json
import os
import numpy as np
import math
import random
import networkx as nx
from collections import defaultdict

import lmdb
from timm.data.transforms_factory import create_transform
from PIL import Image

import MatterSim

from r2r.data_utils import load_nav_graphs
from r2r.data_utils import new_simulator
from r2r.data_utils import angle_feature, get_all_point_angle_feature

from r2r.eval_utils import cal_dtw, cal_cls
import torch
import time

from .vision_transformer import vit_base_patch16_224

from functools import partial
import torch.nn as nn

import ipywidgets as widgets
import io
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
from timm.models.layers import to_2tuple

import warnings
from datetime import datetime
import sys
from thop import profile
import layer_count
warnings.filterwarnings("ignore")

# Threshold that navigation errors (shortest-path distance from the final
# viewpoint to the goal) are compared against when scoring success.
ERROR_MARGIN = 3.0

# Size of each raw view image stored in the LMDB image database;
# get_image reshapes every record to (36, HEIGHT, WIDTH, 3).
HEIGHT, WIDTH = 248, 330

# Side of the square backbone input (the transform targets (3, 224, 224)).
IMGSIZE = 224
class EnvBatch(object):
    ''' A simple wrapper for a batch of MatterSim environments,
        using discretized viewpoints and either precomputed features or
        features computed on the fly by a ViT backbone. '''

    def __init__(self, connectivity_dir, img_db_file, scan_data_dir=None, feat_db=None, batch_size=100, cache=None, mode=None):
        """
        1. Build the vision backbone and the raw-image pipeline (LMDB + transform).
        2. Init one simulator per batch slot, plus a helper simulator used to
           query navigable locations.
        :param connectivity_dir: Directory with the navigation graphs.
        :param img_db_file: LMDB file holding the raw view images (see get_image).
        :param scan_data_dir: Optional dataset path given to every simulator.
        :param feat_db: Feature store, read by getStates when cache == "True".
        :param batch_size: Used to create the simulator list.
        :param cache: "True" -> read features from feat_db; "False" -> run the
            vision backbone on raw images. Any other value makes getStates raise.
        :param mode: Passed through to the vision backbone on every forward call.
        """
        self.feat_db = feat_db
        self.cache = cache
        self.mode = mode
        self.image_w = 640
        self.image_h = 480
        self.vfov = 60
        self.sims = []
        # "<scan>_<viewpoint>" -> candidate list cached by navigable_view
        self.buffered_state_dict = {}

        # Dedicated simulator that navigable_view rotates through the 36 views.
        self.sim_nav = new_simulator(connectivity_dir)

        self.vision_backbone = vit_base_patch16_224(pretrained=True,
            drop_rate=0.1,
            attn_drop_rate=0.1,
            drop_path_rate=0.)

        self.img_db_file = img_db_file
        self.img_db_env = lmdb.open(
            self.img_db_file,
            map_size=int(1e12),
            readonly=True,
            create=False,
            readahead=False,
            max_readers=2000,
        )
        self.img_db_txn = self.img_db_env.begin()

        self.img_transform = create_transform(
            (3, 224, 224),
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
            interpolation="bicubic",
            crop_pct=0.9,
            is_training=False,
            auto_augment=None,
            re_mode="const",
            re_prob=0.0,
        )

        for _ in range(batch_size):
            sim = MatterSim.Simulator()
            if scan_data_dir:
                sim.setDatasetPath(scan_data_dir)
            sim.setNavGraphPath(connectivity_dir)
            sim.setRenderingEnabled(False)
            sim.setDiscretizedViewingAngles(True)   # Set increment/decrement to 30 degree. (otherwise by radians)
            sim.setCameraResolution(self.image_w, self.image_h)
            sim.setCameraVFOV(math.radians(self.vfov))
            sim.initialize()
            self.sims.append(sim)

    def _profile_backbone(self, images, house_id, navigable_idx, viewpointId, instr_id, tag):
        """Run the vision backbone under thop's ``profile`` and record its cost.

        Appends ``"<tag> (ViT)Gflops: <value>"`` to ``gflops_per_traj_log.txt``
        and accumulates the value into ``layer_count.total_gflops``.

        :return: The backbone output (first element returned by ``profile``).
        """
        feats, flops, _params = profile(
            self.vision_backbone,
            inputs=(images, house_id, navigable_idx, viewpointId, instr_id, self.mode),
            verbose=False,
        )
        gflops = flops / (10**9)
        with open('gflops_per_traj_log.txt', 'a') as file:
            file.write(f"{tag} (ViT)Gflops: {gflops}\n")
        layer_count.total_gflops += gflops
        return feats

    def forward_vision_backbone(self, images, house_id, navigable_idx, viewpointId, instr_id, detach=False):
        """Encode view images with the vision backbone.

        :param images: Either a 6-D (N, T, P, C, H, W) panorama tensor, which is
            encoded without gradients, or a 4-D (T, C, H, W) tensor.
        :param house_id, navigable_idx, viewpointId, instr_id: Forwarded to the
            backbone together with ``self.mode``.
        :param detach: Detach the returned tensor from the autograd graph.
        :return: Features reshaped to (N, T, P, -1) for 6-D input, or (1, T, -1).
        """
        if torch.cuda.is_available():
            torch.cuda.set_device(0)

        # due to memory issue, we cannot propagate to pano images in the history
        if images.dim() == 6:
            N, T, P, C, H, W = images.size()
            with torch.no_grad():
                feats = self._profile_backbone(
                    images.view(N * T * P, C, H, W),
                    house_id, navigable_idx, viewpointId, instr_id, "is_pano",
                )
            # Restore the leading batch dim that was flattened above
            # (previously hard-coded to 1, which is wrong for N > 1).
            feats = feats.view(N, T, P, -1)
        else:
            T, C, H, W = images.size()
            feats = self._profile_backbone(
                images.view(T, C, H, W),
                house_id, navigable_idx, viewpointId, instr_id, "not_pano",
            )
            # N = 1, since we are using batch size 1
            feats = feats.view(1, T, -1)

        if detach:
            feats = feats.detach()
        return feats

    def _make_id(self, scanId, viewpointId):
        return scanId + '_' + viewpointId

    def newEpisodes(self, scanIds, viewpointIds, headings):
        """Start a new episode in simulator i at (scanIds[i], viewpointIds[i], headings[i]), elevation 0."""
        for i, (scanId, viewpointId, heading) in enumerate(zip(scanIds, viewpointIds, headings)):
            self.sims[i].newEpisode([scanId], [viewpointId], [heading], [0])

    def get_image(self, scan, viewpoint):
        """Load and transform the 36 stored view images of a viewpoint.

        :raises KeyError: If the image database has no record for the viewpoint.
        :return: Tensor stacking the 36 transformed images (presumably
            (36, 3, IMGSIZE, IMGSIZE) given the transform config).
        """
        key = "%s_%s" % (scan, viewpoint)
        buf = self.img_db_txn.get(key.encode("ascii"))
        if buf is None:
            # Without this check np.frombuffer(None) fails with an opaque TypeError.
            raise KeyError(f"No images found in the image database for key {key!r}")
        images = np.frombuffer(buf, dtype=np.uint8)
        images = images.reshape(36, HEIGHT, WIDTH, 3)  # fixed image size
        images = torch.stack(
            [self.img_transform(Image.fromarray(image)) for image in images], 0
        )
        return images

    def navigable_view(self, scanId, viewpointId, viewId):
        """List the navigable neighbours of a viewpoint, relative to view ``viewId``.

        On first use for a viewpoint, ``self.sim_nav`` is rotated through the 36
        discretized views (12 headings per elevation ring, starting at -30 deg
        elevation). Each neighbouring viewpoint is assigned to the view where it
        appears closest to the view centre. A heading-independent copy is cached
        so later calls only need to re-base the heading.
        """
        def _loc_distance(loc):
            return np.sqrt(loc.rel_heading ** 2 + loc.rel_elevation ** 2)
        base_heading = (viewId % 12) * math.radians(30)
        long_id = self._make_id(scanId, viewpointId)

        if long_id in self.buffered_state_dict:
            candidate_new = []
            for c in self.buffered_state_dict[long_id]:
                c_new = c.copy()
                c_new['heading'] = c_new.pop('normalized_heading') - base_heading
                candidate_new.append(c_new)
            return candidate_new

        adj_dict = {}
        for ix in range(36):
            if ix == 0:
                self.sim_nav.newEpisode([scanId], [viewpointId], [0], [math.radians(-30)])
            elif ix % 12 == 0:
                self.sim_nav.makeAction([0], [1.0], [1.0])
            else:
                self.sim_nav.makeAction([0], [1.0], [0])

            state = self.sim_nav.getState()[0]
            assert state.viewIndex == ix

            # Heading and elevation for the viewpoint center
            heading = state.heading - base_heading
            elevation = state.elevation

            # navigableLocations[0] is skipped; idx keeps the original position.
            for j, loc in enumerate(state.navigableLocations[1:]):
                # if a loc is visible from multiple view, use the closest
                # view (in angular distance) as its representation
                distance = _loc_distance(loc)
                if (loc.viewpointId not in adj_dict or
                        distance < adj_dict[loc.viewpointId]['distance']):
                    adj_dict[loc.viewpointId] = {
                        'heading': heading + loc.rel_heading,
                        'elevation': elevation + loc.rel_elevation,
                        "normalized_heading": state.heading + loc.rel_heading,
                        'scanId': scanId,
                        'viewpointId': loc.viewpointId, # Next viewpoint id
                        'pointId': ix,
                        'distance': distance,
                        'idx': j + 1,
                    }
        candidate = list(adj_dict.values())
        self.buffered_state_dict[long_id] = [
            {key: c[key]
             for key in
                ['normalized_heading', 'elevation', 'scanId', 'viewpointId',
                 'pointId', 'idx']}
            for c in candidate
        ]
        return candidate

    def getStates(self, item):
        """
        Get the simulator states paired with their view features.
        Agent's current view [0-35] (set only when viewing angles are discretized)
            [0-11] looking down, [12-23] looking at horizon, [24-35] looking up
        :param item: Data item whose 'instr_id' is forwarded to the backbone
            (the same id is used for every simulator in the batch).
        :raises ValueError: If ``self.cache`` is neither "True" nor "False".
        :return: [ (feature, sim_state) ] * batch_size
        """
        feature_states = []
        instr_id = item['instr_id']

        for sim in self.sims:
            state = sim.getState()[0]
            scan_id = state.scanId
            viewpoint_id = state.location.viewpointId

            if self.cache == "True":
                feature = self.feat_db.get_image_feature(scan_id, viewpoint_id)

            elif self.cache == "False":
                images = self.get_image(scan_id, viewpoint_id)
                nav = self.navigable_view(scan_id, viewpoint_id, state.viewIndex)
                # Build the set per simulator so one batch entry's navigable
                # views do not leak into the next entry's backbone call.
                navigable_idx = {c['pointId'] for c in nav}

                feature = self.forward_vision_backbone(images, scan_id, navigable_idx, viewpoint_id, instr_id)
                # .cpu() is a no-op on CPU tensors and required for CUDA ones.
                feature = feature.squeeze(0).detach().cpu().numpy().astype(np.float32)

            else:
                raise ValueError('Invalid value for flag cache.')

            feature_states.append((feature, state))
        return feature_states

    def makeActions(self, actions):
        ''' Take an action using the full state dependent action interface (with batched input).
            Every action element should be an (index, heading, elevation) tuple. '''
        for i, (index, heading, elevation) in enumerate(actions):
            self.sims[i].makeAction([index], [heading], [elevation])


class R2RBatch(object):
    ''' Implements the Room to Room navigation task, using discretized viewpoints and pretrained features '''

    def __init__(
        self, feat_db, instr_data, connectivity_dir, img_db_file,
        batch_size=64, angle_feat_size=4,
        seed=0, name=None, sel_data_idxs=None, cache=None, mode=None
    ):
        """
        :param instr_data: List of instruction items (dicts with at least
            'instr_id', 'scan', 'path', 'heading', 'instruction', 'path_id').
        :param sel_data_idxs: Optional (t_split, n_splits); keep only split
            t_split of the data (the last split takes the remainder).
        :param cache, mode: Forwarded to EnvBatch.
        """
        self.env = EnvBatch(connectivity_dir, img_db_file, feat_db=feat_db, batch_size=batch_size, cache=cache, mode=mode)

        self.data = instr_data
        self.scans = {x['scan'] for x in self.data}
        # to evaluate full data (computed before any split below)
        self.gt_trajs = self._get_gt_trajs(self.data)

        # in validation, we would split the data
        if sel_data_idxs is not None:
            t_split, n_splits = sel_data_idxs
            ndata_per_split = len(self.data) // n_splits
            start_idx = ndata_per_split * t_split
            if t_split == n_splits - 1:
                end_idx = None
            else:
                end_idx = start_idx + ndata_per_split
            self.data = self.data[start_idx: end_idx]

        self.connectivity_dir = connectivity_dir
        self.angle_feat_size = angle_feat_size
        self.name = name
        # use different seeds in different processes to shuffle data
        self.seed = seed
        random.seed(self.seed)
        random.shuffle(self.data)

        self.ix = 0
        self.batch_size = batch_size
        self._load_nav_graphs()

        self.sim = new_simulator(self.connectivity_dir)
        self.angle_feature = get_all_point_angle_feature(self.sim, self.angle_feat_size)

        # "<scan>_<viewpoint>" -> candidate list cached by make_candidate
        self.buffered_state_dict = {}
        print('%s loaded with %d instructions, using splits: %s' % (
            self.__class__.__name__, len(self.data), self.name))

    def _get_gt_trajs(self, data):
        """Map instr_id -> (scan, ground-truth path)."""
        return {x['instr_id']: (x['scan'], x['path']) for x in data}

    def size(self):
        return len(self.data)

    def _load_nav_graphs(self):
        """
        load graph from self.scan,
        Store the graph {scan_id: graph} in self.graphs
        Store the shortest path {scan_id: {view_id_x: {view_id_y: [path]} } } in self.shortest_paths
        Store the distances in self.shortest_distances. (Structure see above)
        Load connectivity graph for each scan, useful for reasoning about shortest paths
        :return: None
        """
        print('Loading navigation graphs for %d scans' % len(self.scans))

        self.graphs = load_nav_graphs(self.connectivity_dir, self.scans)
        self.shortest_paths = {}
        self.shortest_distances = {}
        # One pass over the graphs computes both all-pairs tables.
        for scan, G in self.graphs.items():
            self.shortest_paths[scan] = dict(nx.all_pairs_dijkstra_path(G))
            self.shortest_distances[scan] = dict(nx.all_pairs_dijkstra_path_length(G))

    def _next_minibatch(self, batch_size=None, **kwargs):
        """
        Store the next minibatch in 'self.batch'. When the data runs out, the
        data is reshuffled and the batch is topped up from its beginning.
        :param batch_size: Defaults to self.batch_size.
        :return: None
        """
        if batch_size is None:
            batch_size = self.batch_size

        batch = self.data[self.ix: self.ix+batch_size]
        if len(batch) < batch_size:
            random.shuffle(self.data)
            self.ix = batch_size - len(batch)
            batch += self.data[:self.ix]
        else:
            self.ix += batch_size
        self.batch = batch

    def reset_epoch(self, shuffle=False):
        ''' Reset the data index to beginning of epoch. Primarily for testing.
            You must still call reset() for a new episode. '''
        if shuffle:
            random.shuffle(self.data)
        self.ix = 0

    def _shortest_path_action(self, state, goalViewpointId):
        ''' Determine next action on the shortest path to goal, for supervised training. '''
        if state.location.viewpointId == goalViewpointId:
            return goalViewpointId      # Just stop here
        path = self.shortest_paths[state.scanId][state.location.viewpointId][goalViewpointId]
        return path[1]

    def make_candidate(self, feature, scanId, viewpointId, viewId):
        """Build the navigable candidates of a viewpoint with their features.

        Each candidate's 'feature' is the visual feature of the view it was
        assigned to, concatenated with the angle feature of its heading and
        elevation relative to view ``viewId``. The first call for a viewpoint
        sweeps ``self.sim`` through the 36 views. A heading-independent copy of
        the result is cached, so later calls only recompute heading and features.

        :param feature: Per-view visual features, indexed by view index 0-35.
        """
        def _loc_distance(loc):
            return np.sqrt(loc.rel_heading ** 2 + loc.rel_elevation ** 2)
        base_heading = (viewId % 12) * math.radians(30)
        long_id = "%s_%s" % (scanId, viewpointId)

        if long_id in self.buffered_state_dict:
            candidate_new = []
            for c in self.buffered_state_dict[long_id]:
                c_new = c.copy()
                visual_feat = feature[c_new['pointId']]
                c_new['heading'] = c_new['normalized_heading'] - base_heading
                angle_feat = angle_feature(c_new['heading'], c_new['elevation'], self.angle_feat_size)
                c_new['feature'] = np.concatenate((visual_feat, angle_feat), -1)
                c_new.pop('normalized_heading')
                candidate_new.append(c_new)
            return candidate_new

        adj_dict = {}
        for ix in range(36):
            if ix == 0:
                self.sim.newEpisode([scanId], [viewpointId], [0], [math.radians(-30)])
            elif ix % 12 == 0:
                self.sim.makeAction([0], [1.0], [1.0])
            else:
                self.sim.makeAction([0], [1.0], [0])

            state = self.sim.getState()[0]
            assert state.viewIndex == ix

            # Heading and elevation for the viewpoint center
            heading = state.heading - base_heading
            elevation = state.elevation

            visual_feat = feature[ix]

            # get adjacent locations (navigableLocations[0] is skipped)
            for j, loc in enumerate(state.navigableLocations[1:]):
                # if a loc is visible from multiple view, use the closest
                # view (in angular distance) as its representation
                distance = _loc_distance(loc)
                if (loc.viewpointId not in adj_dict or
                        distance < adj_dict[loc.viewpointId]['distance']):
                    loc_heading = heading + loc.rel_heading
                    loc_elevation = elevation + loc.rel_elevation
                    # Only computed for candidates that are actually kept.
                    angle_feat = angle_feature(loc_heading, loc_elevation, self.angle_feat_size)
                    adj_dict[loc.viewpointId] = {
                        'heading': loc_heading,
                        'elevation': loc_elevation,
                        "normalized_heading": state.heading + loc.rel_heading,
                        'scanId': scanId,
                        'viewpointId': loc.viewpointId, # Next viewpoint id
                        'pointId': ix,
                        'distance': distance,
                        'idx': j + 1,
                        'feature': np.concatenate((visual_feat, angle_feat), -1)
                    }
        candidate = list(adj_dict.values())
        self.buffered_state_dict[long_id] = [
            {key: c[key]
             for key in
                ['normalized_heading', 'elevation', 'scanId', 'viewpointId',
                 'pointId', 'idx']}
            for c in candidate
        ]
        return candidate

    def _teacher_path_action(self, state, path, t=None, shortest_teacher=False):
        """Return the teacher's next viewpoint along ``path``.

        With ``t`` given, the next step is path[t + 1], or the current viewpoint
        once the path is exhausted (STOP). Otherwise the next step is looked up
        from the agent's position on the path. Returns None when the agent is
        off the path.
        """
        if shortest_teacher:
            return self._shortest_path_action(state, path[-1])

        current_vp = state.location.viewpointId
        if t is not None:
            return path[t + 1] if t < len(path) - 1 else current_vp
        if current_vp not in path:
            return None
        cur_idx = path.index(current_vp)
        if cur_idx == len(path) - 1: # STOP
            return current_vp
        return path[cur_idx + 1]

    def _get_obs(self, t=None, shortest_teacher=False):
        """Build one observation dict per simulator in the batch."""
        obs = []
        # pass self.batch[0] to get instruction ID
        for i, (feature, state) in enumerate(self.env.getStates(self.batch[0])):
            item = self.batch[i]
            base_view_id = state.viewIndex

            if feature is None:
                feature = np.zeros((36, 2048))

            # Full features
            candidate = self.make_candidate(feature, state.scanId, state.location.viewpointId, state.viewIndex)
            # [visual_feature, angle_feature] for views
            feature = np.concatenate((feature, self.angle_feature[base_view_id]), -1)

            obs.append({
                'instr_id' : item['instr_id'],
                'scan' : state.scanId,
                'viewpoint' : state.location.viewpointId,
                'viewIndex' : state.viewIndex,
                'heading' : state.heading,
                'elevation' : state.elevation,
                'feature' : feature,
                'candidate': candidate,
                'navigableLocations' : state.navigableLocations,
                'instruction' : item['instruction'],
                'teacher' : self._teacher_path_action(state, item['path'], t=t, shortest_teacher=shortest_teacher),
                'gt_path' : item['path'],
                'path_id' : item['path_id']
            })
            if 'instr_encoding' in item:
                obs[-1]['instr_encoding'] = item['instr_encoding']
            # A2C reward. The negative distance between the state and the final state
            obs[-1]['distance'] = self.shortest_distances[state.scanId][state.location.viewpointId][item['path'][-1]]
        return obs

    def reset(self, **kwargs):
        ''' Load a new minibatch / episodes. '''
        self._next_minibatch(**kwargs)

        scanIds = [item['scan'] for item in self.batch]
        viewpointIds = [item['path'][0] for item in self.batch]
        headings = [item['heading'] for item in self.batch]
        self.env.newEpisodes(scanIds, viewpointIds, headings)
        return self._get_obs(t=0)

    def step(self, actions, t=None):
        ''' Take action (same interface as makeActions) '''
        self.env.makeActions(actions)
        return self._get_obs(t=t)


    ############### Evaluation ###############
    def _get_nearest(self, shortest_distances, goal_id, path):
        """Return the first viewpoint on ``path`` closest to ``goal_id``."""
        # min() keeps the earliest element among ties.
        return min(path, key=lambda vp: shortest_distances[vp][goal_id])

    def _eval_item(self, scan, path, gt_path):
        """Score a single predicted viewpoint path against the ground truth."""
        scores = {}

        shortest_distances = self.shortest_distances[scan]

        assert gt_path[0] == path[0], 'Result trajectories should include the start position'

        nearest_position = self._get_nearest(shortest_distances, gt_path[-1], path)

        scores['nav_error'] = shortest_distances[path[-1]][gt_path[-1]]
        scores['oracle_error'] = shortest_distances[nearest_position][gt_path[-1]]
        scores['trajectory_steps'] = len(path) - 1
        scores['trajectory_lengths'] = np.sum([shortest_distances[a][b] for a, b in zip(path[:-1], path[1:])])

        gt_lengths = np.sum([shortest_distances[a][b] for a, b in zip(gt_path[:-1], gt_path[1:])])

        scores['success'] = float(scores['nav_error'] < ERROR_MARGIN)
        scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
        scores['oracle_success'] = float(scores['oracle_error'] < ERROR_MARGIN)

        scores.update(
            cal_dtw(shortest_distances, path, gt_path, scores['success'], ERROR_MARGIN)
        )
        scores['CLS'] = cal_cls(shortest_distances, path, gt_path, ERROR_MARGIN)

        return scores

    def eval_metrics(self, preds):
        ''' Evaluate each agent trajectory based on how close it got to the goal location 
        the path contains [view_id, angle, vofv]'''
        print('eval %d predictions' % (len(preds)))

        metrics = defaultdict(list)
        for item in preds:
            instr_id = item['instr_id']
            traj = [x[0] for x in item['trajectory']]
            scan, gt_traj = self.gt_trajs[instr_id]
            traj_scores = self._eval_item(scan, traj, gt_traj)
            for k, v in traj_scores.items():
                metrics[k].append(v)
            metrics['instr_id'].append(instr_id)

        avg_metrics = {
            'steps': np.mean(metrics['trajectory_steps']),
            'lengths': np.mean(metrics['trajectory_lengths']),
            'nav_error': np.mean(metrics['nav_error']),
            'oracle_error': np.mean(metrics['oracle_error']),
            'sr': np.mean(metrics['success']) * 100,
            'oracle_sr': np.mean(metrics['oracle_success']) * 100,
            'spl': np.mean(metrics['spl']) * 100,
            'nDTW': np.mean(metrics['nDTW']) * 100,
            'SDTW': np.mean(metrics['SDTW']) * 100,
            'CLS': np.mean(metrics['CLS']) * 100,
        }
        return avg_metrics, metrics


class R2RBackBatch(R2RBatch):
    ''' R2R variant whose items carry a 'midstop' viewpoint: success requires
        reaching both the midstop and the final goal. '''

    def __init__(
        self, feat_db, instr_data, connectivity_dir, img_db_file,
        batch_size=64, angle_feat_size=4,
        seed=0, name=None, sel_data_idxs=None, cache=None, mode=None
    ):
        """
        Same arguments as R2RBatch. ``cache`` and ``mode`` are forwarded to the
        parent; without them the parent would build its EnvBatch with
        cache=None, and EnvBatch.getStates raises ValueError for that value.
        """
        # instr_id -> ground-truth midstop viewpoint (taken before any split).
        self.gt_midstops = {
            x['instr_id']: x['midstop'] for x in instr_data
        }
        super().__init__(
            feat_db, instr_data, connectivity_dir, img_db_file, batch_size=batch_size,
            angle_feat_size=angle_feat_size, seed=seed, name=name, sel_data_idxs=sel_data_idxs,
            cache=cache, mode=mode,
        )

    def _get_obs(self, t=None, shortest_teacher=False):
        """Build one observation dict per simulator.

        Same as R2RBatch._get_obs, except 'distance' is the pair
        (distance to the item's midstop, distance to the final goal).
        """
        obs = []
        for i, (feature, state) in enumerate(self.env.getStates(self.batch[0])):
            item = self.batch[i]
            base_view_id = state.viewIndex

            if feature is None:
                feature = np.zeros((36, 2048))

            # Full features
            candidate = self.make_candidate(feature, state.scanId, state.location.viewpointId, state.viewIndex)
            # [visual_feature, angle_feature] for views
            feature = np.concatenate((feature, self.angle_feature[base_view_id]), -1)

            obs.append({
                'instr_id' : item['instr_id'],
                'scan' : state.scanId,
                'viewpoint' : state.location.viewpointId,
                'viewIndex' : state.viewIndex,
                'heading' : state.heading,
                'elevation' : state.elevation,
                'feature' : feature,
                'candidate': candidate,
                'navigableLocations' : state.navigableLocations,
                'instruction' : item['instruction'],
                'teacher' : self._teacher_path_action(state, item['path'], t=t, shortest_teacher=shortest_teacher),
                'gt_path' : item['path'],
                'path_id' : item['path_id']
            })
            if 'instr_encoding' in item:
                obs[-1]['instr_encoding'] = item['instr_encoding']
            # A2C reward. The negative distance between the state and the final state
            distances = self.shortest_distances[state.scanId][state.location.viewpointId]
            obs[-1]['distance'] = (
                distances[item['midstop']],
                distances[item['path'][-1]]
            )
        return obs

    def _eval_item(self, scan, path, gt_path, midstop, gt_midstop):
        """Score one trajectory.

        Success (0/1) requires a predicted ``midstop`` within ERROR_MARGIN of
        ``gt_midstop`` AND a final viewpoint within ERROR_MARGIN of the goal.
        Note that these checks use ``<=``, whereas R2RBatch uses ``<``.
        """
        scores = {}

        shortest_distances = self.shortest_distances[scan]

        assert gt_path[0] == path[0], 'Result trajectories should include the start position'

        scores['nav_error'] = shortest_distances[path[-1]][gt_path[-1]]
        scores['trajectory_steps'] = len(path) - 1
        scores['trajectory_lengths'] = np.sum([shortest_distances[a][b] for a, b in zip(path[:-1], path[1:])])

        gt_lengths = np.sum([shortest_distances[a][b] for a, b in zip(gt_path[:-1], gt_path[1:])])

        success = 0
        if (midstop is not None
                and shortest_distances[midstop][gt_midstop] <= ERROR_MARGIN
                and scores['nav_error'] <= ERROR_MARGIN):
            success = 1

        scores['success'] = success
        scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)

        scores.update(
            cal_dtw(shortest_distances, path, gt_path, scores['success'], ERROR_MARGIN)
        )
        scores['CLS'] = cal_cls(shortest_distances, path, gt_path, ERROR_MARGIN)

        return scores

    def eval_metrics(self, preds):
        """Average per-trajectory scores; each pred must carry a 'midstop' key."""
        print('eval %d predictions' % (len(preds)))

        metrics = defaultdict(list)

        for item in preds:
            instr_id = item['instr_id']
            traj = [x[0] for x in item['trajectory']]
            scan, gt_traj = self.gt_trajs[instr_id]
            traj_scores = self._eval_item(
                scan, traj, gt_traj, item['midstop'], self.gt_midstops[instr_id]
            )
            for k, v in traj_scores.items():
                metrics[k].append(v)
            metrics['instr_id'].append(instr_id)

        avg_metrics = {
            'steps': np.mean(metrics['trajectory_steps']),
            'lengths': np.mean(metrics['trajectory_lengths']),
            'nav_error': np.mean(metrics['nav_error']),
            'sr': np.mean(metrics['success']) * 100,
            'spl': np.mean(metrics['spl']) * 100,
            'nDTW': np.mean(metrics['nDTW']) * 100,
            'SDTW': np.mean(metrics['SDTW']) * 100,
            'CLS': np.mean(metrics['CLS']) * 100,
        }

        return avg_metrics, metrics
