# Adapted from SplaTAM: https://github.com/spla-tam/SplaTAM/blob/main/scripts/splatam.py
from typing import Tuple, List
import torch
import torch.multiprocessing as mp
from helios.agent.gaze.gaussian_representation.backend_interface import BackEndInterface
# PyTorch's CUDA runtime does not support the "fork" start method, so child
# processes that may touch CUDA must be spawned. set_start_method raises
# RuntimeError if the start method was already set (e.g. by another module or
# on a module reload), even to the same value, so only set it when it differs.
if mp.get_start_method(allow_none=True) != "spawn":
    mp.set_start_method("spawn", force=True)

from helios.agent.gaze.splatam.slam import SLAM

class SLAM_WithBackend(SLAM):
    """SLAM front end that delegates the actual work to a backend process.

    A ``BackEndInterface`` is run in a separate ``torch.multiprocessing``
    process. The two sides talk through two queues:

    * ``backend_queue`` (front end -> backend): list commands whose first
      element is a tag (``"init"``, ``"step"``, ``"get_params"``, ``"save"``,
      ``"pause"``, ``"unpause"``, ``"stop"``) followed by its arguments.
    * ``frontend_queue`` (backend -> front end): list replies whose first
      element is a tag (``"finished init"``, ``"params"``) followed by data.

    Query methods inherited from ``SLAM`` lazily pull the current parameters
    from the backend (via ``get_params``) before delegating to the base-class
    implementation. ``first_frame`` and ``step`` invalidate that cache by
    setting ``self.params`` to ``None``.
    """

    def __init__(self, 
        config: dict, 
        max_frames: int, 
        n_semantic_channels: int
    ):
        """Initialise the base SLAM state and start the backend process.

        Args:
            config: Configuration forwarded to both ``SLAM`` and the backend.
            max_frames: Forwarded to both ``SLAM`` and the backend.
            n_semantic_channels: Forwarded to both ``SLAM`` and the backend.
        """
        super().__init__(config, max_frames, n_semantic_channels)

        # Replies flowing backend -> front end.
        self.frontend_queue = mp.Queue()
        # Commands flowing front end -> backend.
        self.backend_queue = mp.Queue()
        self.backend = BackEndInterface(
            config,
            max_frames,
            n_semantic_channels,
            self.backend_queue,
            self.frontend_queue,
        )

        self.backend_process = mp.Process(target=self.backend.run)
        self.backend_process.start()

    def reset(self):
        """Reset the base SLAM state and the local ``BackEndInterface`` object."""
        super().reset()

        # NOTE(review): this resets the front end's own ``self.backend`` object.
        # The backend runs in a separate process, so this call presumably does
        # not reach the running backend; it would only affect a process started
        # later (e.g. by ``first_frame`` after ``stop``). Confirm this is intended.
        self.backend.reset()

    def stop(self):
        """Stop the backend process and wait for it to exit.

        Safe to call more than once: when no backend process is running
        (e.g. ``stop`` was already called), this is a no-op.
        """
        if self.backend_process is None:
            return
        self.backend_queue.put(["stop"])
        self.backend_process.join()

        self.backend_process = None

    def pause(self):
        """ Pause the backend """
        self.backend_queue.put(["pause"])
    
    def resume(self):
        """ Resume the backend """
        self.backend_queue.put(["unpause"])

    def _wait_for_reply(self, tag, poll_interval=1.0):
        """Block until the backend sends a reply tagged ``tag``; return it.

        Replies carrying any other tag are consumed and discarded. The method
        waits on a blocking ``get`` with a timeout instead of polling
        ``Queue.empty()``, which the Python docs describe as unreliable for
        multiprocessing queues, so the front end does not spin a CPU core
        while it waits.

        Args:
            tag: Expected value of ``reply[0]``.
            poll_interval: Seconds to wait per ``get`` before re-checking
                whether the backend process is still alive.

        Returns:
            The full reply list whose first element equals ``tag``.

        Raises:
            RuntimeError: If the backend process is not running (stopped or
                exited) and no matching reply arrives.
        """
        import queue  # mp.Queue.get raises queue.Empty on timeout

        while True:
            # Sample liveness *before* waiting so that a reply flushed just
            # before the backend exited is still drained by the get() below.
            proc = self.backend_process
            backend_alive = proc is not None and proc.is_alive()
            try:
                data = self.frontend_queue.get(timeout=poll_interval)
            except queue.Empty:
                if not backend_alive:
                    exitcode = None if proc is None else proc.exitcode
                    raise RuntimeError(
                        f"Backend process is not running (exitcode={exitcode}) "
                        f"and never replied with {tag!r}"
                    )
                continue
            if data[0] == tag:
                return data

    def first_frame(self, dataset):
        """Initialise the map from the first frame using the backend.

        Restarts the backend process if it was stopped, sends it an ``init``
        command, and blocks until it replies ``"finished init"``.

        Args:
            dataset: Frame data forwarded to the backend. ``dataset[0]`` is
                indexed as ``(height, width, ...)`` to record the image size.

        Raises:
            RuntimeError: If the backend process dies before replying.
        """
        print("Calling first_frame!")
        if self.backend_process is None:
            self.backend_process = mp.Process(target=self.backend.run)
            self.backend_process.start()
        self.backend_queue.put(['init', dataset])
        self.params = None

        self.img_w = dataset[0].shape[1]
        self.img_h = dataset[0].shape[0]

        data = self._wait_for_reply("finished init")
        self.first_frame_w2c = data[1]
        self.intrinsics = data[2]

    def step(self, dataset, not_overlapping):
        """Queue a ``step`` command for the backend and invalidate cached state.

        This call does not wait for the backend. The next query re-fetches
        the parameters.
        """
        self.obj_locs = {}
        self.obj_scores = {}
        self.backend_queue.put(['step', dataset, not_overlapping])
        self.params = None

    @torch.no_grad()
    def get_params(self):
        """Fetch the current parameters from the backend into ``self.params``.

        When the returned mapping is non-empty, also derives:

        * ``self.semantic_scaled``: ``params["semantic_c"]`` divided by its sum
          over the last dimension. NaNs (e.g. from all-zero rows) are replaced
          by ``torch.nan_to_num``.
        * ``self.uncertainty``: ``sqrt(p * (1 - p) / (1 + total))``, where ``p``
          is ``semantic_scaled`` and ``total`` is that same last-dimension sum.

        Raises:
            RuntimeError: If the backend process dies before replying.
        """
        self.backend_queue.put(['get_params'])
        # Release cached, unused blocks held by this process's CUDA allocator.
        torch.cuda.empty_cache()
        data = self._wait_for_reply("params")
        self.params = data[1]
        if len(self.params.keys()) > 0:
            counts = self.params["semantic_c"]
            total = torch.sum(counts, dim=-1, keepdim=True)
            self.semantic_scaled = torch.nan_to_num(counts / total)
            self.uncertainty = torch.sqrt(
                self.semantic_scaled * (1 - self.semantic_scaled) / (1 + total)
            )
        # NOTE: per-instance averaging of semantic_scaled / uncertainty / score
        # was previously sketched here as commented-out code and is disabled.

    def _ensure_params(self):
        """Fetch parameters from the backend if the local cache was invalidated."""
        if self.params is None:
            self.get_params()

    @torch.no_grad()
    def has_obj(self, obj_class: int, threshold: float, use_uncertainty: bool):
        """Return the first element of ``SLAM.has_obj`` on up-to-date parameters."""
        self._ensure_params()
        return super().has_obj(obj_class, threshold, use_uncertainty)[0]

    @torch.no_grad()
    def get_obj_score(self, obj_class: int, use_uncertainty: bool) -> float:
        """Delegate to ``SLAM.get_obj_score`` on up-to-date parameters."""
        self._ensure_params()
        return super().get_obj_score(obj_class, use_uncertainty)

    @torch.no_grad()
    def get_obj_locs(
        self,
        obj_class: int,
        threshold: float,
        use_uncertainty: bool,
        return_scores: bool,
        return_instances: bool
    ):
        """Delegate to ``SLAM.get_obj_locs`` on up-to-date parameters."""
        self._ensure_params()
        return super().get_obj_locs(obj_class, threshold, use_uncertainty, return_scores, return_instances)

    @torch.no_grad()
    def get_renders(self, w2c, return_instances, return_uncertainty=True, return_semantics=True, backgrounds=None):
        """Delegate to ``SLAM.get_renders`` on up-to-date parameters."""
        self._ensure_params()
        return super().get_renders(w2c, return_instances, return_uncertainty, return_semantics, backgrounds)

    def get_eig_path(self, poses: List[torch.Tensor], instance: int, visualize: bool=False, obj_class: int = 2) -> Tuple[torch.Tensor,torch.Tensor]:
        """Delegate to ``SLAM.get_eig_path`` on up-to-date parameters."""
        self._ensure_params()
        return super().get_eig_path(poses, instance, visualize, obj_class)

    def get_expected_score(self, poses: List[torch.Tensor], instance: int, obj_class: int = 1) -> Tuple[torch.Tensor,torch.Tensor]:
        """Delegate to ``SLAM.get_expected_score`` on up-to-date parameters."""
        self._ensure_params()
        return super().get_expected_score(poses, instance, obj_class)

    @torch.no_grad()
    def save_model(self, episode_key: str):
        """Queue a ``save`` command for the backend. This does not wait for completion."""
        self.backend_queue.put(['save', episode_key])

    def rendering_function(self, data, colors, render_mode, backgrounds=None):
        """Delegate to ``SLAM.rendering_function`` on up-to-date parameters."""
        self._ensure_params()
        return super().rendering_function(data, colors, render_mode, backgrounds)
