import os
import torch

from helios.agent.gaze.gaussian_representation.slam import SLAM
from helios.agent.gaze.gaussian_representation.backend import BackEnd

class SLAM_NoBackend(BackEnd, SLAM):
    """SLAM variant that also maintains per-row semantic statistics.

    After each parent ``first_frame``/``step`` call, if ``self.params`` is
    not None, this class recomputes from ``self.params["semantic_c"]``:

    * ``semantic_scaled`` -- each row of ``semantic_c`` divided by its sum
      over the last dimension.  Rows summing to zero produce 0/0 NaNs, which
      ``torch.nan_to_num`` maps to 0.
    * ``uncertainty`` -- ``sqrt(p * (1 - p) / (1 + total))``, where ``p`` is
      the scaled value and ``total`` the row sum.  NOTE(review): this matches
      the per-class standard deviation of a Dirichlet marginal if
      ``semantic_c`` holds concentration parameters -- confirm with the
      producer of ``semantic_c``.
    * ``prev_len`` -- the number of rows (first dimension) of ``semantic_c``.
    """

    def __init__(self, config: dict, max_frames: int, n_semantic_channels: int):
        """Initialise the parent classes and this class's bookkeeping state.

        Args:
            config: Configuration dict.  Forwarded to the parent constructor;
                this class additionally reads ``"workdir"`` (the parent
                directory of ``output_dir``) and ``"primary_device"``.
            max_frames: Forwarded to the parent constructor.
            n_semantic_channels: Forwarded to the parent constructor.
        """
        super().__init__(config, max_frames, n_semantic_channels)

        self._reset_nobackend_state()

        self.output_dir = os.path.join(config["workdir"], 'objects')
        self.device = torch.device(config["primary_device"])

    def _reset_nobackend_state(self):
        """Set this class's per-run bookkeeping to its initial (empty) values."""
        self.keyframe_list = []
        self.keyframe_time_indices = []
        self.gt_w2c_all_frames = []
        self.checkpoint_time_idx = 0
        # Row count of params["semantic_c"] at the last statistics refresh.
        self.prev_len = 0

    def _refresh_semantic_uncertainty(self):
        """Recompute ``semantic_scaled``, ``uncertainty`` and ``prev_len``.

        Does nothing when ``self.params`` is None, leaving any previously
        computed values in place.
        """
        if self.params is None:
            return

        semantic_c = self.params["semantic_c"]
        # Row sums over the last dim; keepdim so they broadcast against semantic_c.
        total = torch.sum(semantic_c, dim=-1, keepdim=True)
        # Zero-sum rows give 0/0 = NaN; nan_to_num turns those into 0.
        self.semantic_scaled = torch.nan_to_num(semantic_c / total)
        self.uncertainty = torch.sqrt(
            self.semantic_scaled * (1 - self.semantic_scaled) / (1 + total)
        )
        self.prev_len = semantic_c.shape[0]

    def reset(self):
        """Reset the parent classes, then restore this class's initial bookkeeping."""
        super().reset()
        self._reset_nobackend_state()

    def first_frame(self, dataset):
        """Run the parent first-frame processing, then refresh semantic statistics.

        Args:
            dataset: Forwarded unchanged to the parent ``first_frame``.
        """
        super().first_frame(dataset)
        self._refresh_semantic_uncertainty()

    def step(self, dataset, not_overlapping):
        """Run one parent processing step, then refresh semantic statistics.

        Args:
            dataset: Forwarded unchanged to the parent ``step``.
            not_overlapping: Forwarded unchanged to the parent ``step``.
        """
        super().step(dataset, not_overlapping)
        self._refresh_semantic_uncertainty()
                    