import os
import matplotlib
matplotlib.use('Agg')
import torch
import torch.multiprocessing as mp


from helios.agent.gaze.gaussian_representation.backend import BackEnd

class BackEndInterface(BackEnd, mp.Process):
    """Run a ``BackEnd`` in its own process, driven by message queues.

    Commands are read from ``backend_queue`` as sequences whose first element
    is a command string; replies are written to ``frontend_queue``.

    Supported commands (``data[0]``):
        * ``"stop"``       -- leave the run loop, ending the process.
        * ``"pause"``      -- stop background refinement.
        * ``"unpause"``    -- resume background refinement.
        * ``"init"``       -- reset state, call ``first_frame(data[1])`` and
          reply ``["finished init", first_frame_w2c, intrinsics_copy]``.
        * ``"step"``       -- call ``step(data[1], data[2])`` and reply
          ``["finished step"]``.
        * ``"save"``       -- call ``save_model(data[1])``.
        * ``"get_params"`` -- reply ``["params", {name: detached tensor}]``
          (an empty dict while ``self.params`` is ``None``).

    ``first_frame``, ``step``, ``refine``, ``save_model``, ``first_frame_w2c``
    and ``intrinsics`` are not defined in this class; presumably they are
    provided by ``BackEnd`` -- confirm against that class.
    """

    def __init__(self, config, max_frames, n_semantic_channels, backend_queue, frontend_queue):
        """Store configuration and queues and initialise bookkeeping state.

        Args:
            config: Configuration object. It is read both by attribute
                (``config.instances.use_instance``) and by key
                (``config["workdir"]``, ``config["primary_device"]``), so it
                must support both access styles.
            max_frames: Stored as ``self.num_frames``.
            n_semantic_channels: Stored as ``self.n_semantic_channels``.
            backend_queue: Queue this process reads commands from.
            frontend_queue: Queue this process writes replies to.
        """
        # NOTE(review): with bases (BackEnd, mp.Process) this call resolves to
        # BackEnd.__init__ first if it defines one; mp.Process.__init__ is then
        # only reached if BackEnd.__init__ calls super().__init__() -- confirm.
        super().__init__()
        self.backend_queue = backend_queue
        self.frontend_queue = frontend_queue

        self.num_frames = max_frames
        self.config = config
        self.n_semantic_channels = n_semantic_channels

        self.time_idx = 0

        self.use_instances = config.instances.use_instance

        # While True, the run loop does not call refine().
        self.pause = False

        self.params = None
        self.variables = None

        # Initialize list to keep track of Keyframes
        self.keyframe_list = []
        self.keyframe_time_indices = []

        # Init Variables to keep track of ground truth poses and runtimes
        self.gt_w2c_all_frames = []

        self.checkpoint_time_idx = 0
        self.output_dir = os.path.join(config["workdir"], 'objects')

        self.device = torch.device(config["primary_device"])
        # Updated to params["semantic_c"].shape[0] after every "init"/"step".
        self.prev_len = 0

    def reset(self):
        """Restore the per-run state to the values assigned in ``__init__``.

        Configuration, queues, device and output directory are left untouched.
        """
        self.keyframe_list = []
        self.keyframe_time_indices = []
        self.gt_w2c_all_frames = []
        self.params = None
        self.variables = None
        self.time_idx = 0

        self.pause = False
        self.checkpoint_time_idx = 0
        self.prev_len = 0

    def run(self):
        """Process entry point: serve commands and refine between them.

        While no command is pending, the worker is not paused and at least one
        keyframe exists, ``refine()`` is called repeatedly. When there is
        nothing to refine (paused, or no keyframes yet) the loop blocks on
        ``backend_queue.get()`` instead of polling, so an idle worker does not
        spin on a CPU core.

        Raises:
            ValueError: If a message with an unknown command is received.
        """
        while True:
            nothing_to_refine = self.pause or len(self.keyframe_list) == 0
            if not nothing_to_refine and self.backend_queue.empty():
                # refine
                self.refine()
                continue

            # Either a command is pending, or there is no background work to
            # do; in the latter case this call blocks until a command arrives.
            data = self.backend_queue.get()
            command = data[0]

            if command == "stop":
                return
            elif command == "pause":
                self.pause = True
            elif command == "unpause":
                self.pause = False
            elif command == "init":
                self.reset()
                self.first_frame(data[1])
                self.prev_len = self.params["semantic_c"].shape[0]
                self.frontend_queue.put(
                    ["finished init", self.first_frame_w2c, self.intrinsics.detach().clone()]
                )
            elif command == "step":
                self.step(data[1], data[2])
                self.prev_len = self.params["semantic_c"].shape[0]
                self.frontend_queue.put(["finished step"])
            elif command == "save":
                self.save_model(data[1])
            elif command == "get_params":
                params_to_return = {}
                if self.params is not None:
                    # detach() yields tensors that share storage with the
                    # parameters but carry no autograd history.
                    params_to_return = {k: v.detach() for k, v in self.params.items()}
                self.frontend_queue.put(["params", params_to_return])
            else:
                raise ValueError("Unprocessed data", data)

    
