from omegaconf import DictConfig
import torch
import numpy as np
import pyrealsense2 as rs
import zmq
import cv2
import torch
from einops import rearrange, repeat
from torchvision.transforms import Resize

from algorithms.diffusion_forcing.df_robot import DiffusionForcingRobot
from algorithms.diffusion_forcing.df_prediction import DiffusionForcingPrediction
from utils.robot_utils import unpack_to_1d


class DiffusionPolicy(DiffusionForcingRobot):
    """Diffusion-forcing robot model trained and run as a one-step policy.

    Training (``_preprocess_batch``) cuts every episode into independent
    two-step windows ``(x_t, x_{t+1})``. Supervision is kept only on the
    non-image (action) channels of the second step. ``test_step`` runs the
    model live: it reads the cameras in ``self.cameras``, predicts the next
    actions and replies to a robot client over ``self.socket``.
    """

    def _preprocess_batch(self, batch):
        """Turn each length-T episode into T-1 independent two-step windows.

        Returns ``(xs, conditions, masks, init_z)``. The time axis of ``xs``
        and ``masks`` has length 2 (current step, next step), and the original
        time and batch axes are flattened together into ``(T-1) * B``. The
        first step is fully masked out. Within the second step the first
        ``3 * self.n_cameras`` (image) channels are masked out, so only the
        remaining action channels are supervised.
        """
        xs, conditions, masks, init_z = super()._preprocess_batch(batch)
        original_t = len(xs)
        # Pair every step with its successor: (2, T-1, B, C, ...) -> (2, (T-1)*B, C, ...).
        xs = torch.stack([xs[:-1], xs[1:]])
        xs = rearrange(xs, "m t b c ... -> m (t b) c ...")
        # A window is valid only if both of its steps are valid.
        masks = masks[:-1] * masks[1:]
        # Step 0 is pure context (never supervised); step 1 keeps the window mask.
        masks = torch.stack([torch.zeros_like(masks), masks])
        # clone(): repeat may return a broadcast view, which must not be written in place.
        masks = repeat(masks, "m t b -> m (t b) c", c=self.x_shape[0]).clone()
        # Image channels (3 per camera) are not supervised; only action channels remain.
        masks[:, :, : 3 * self.n_cameras] = 0
        # No external conditions for either step of a window.
        conditions = [None, None]
        # Every window starts from the initial latent of the episode it was cut from.
        init_z = repeat(init_z, "b c ... -> (t b) c ...", t=original_t - 1)
        return xs, conditions, masks, init_z

    def training_step(self, batch, batch_idx):
        """Run ``DiffusionForcingPrediction.training_step`` and return its result.

        The implementation is named explicitly rather than reached via
        ``super()``, presumably to skip an override in ``DiffusionForcingRobot``.
        The result must be returned. Under Lightning automatic optimization a
        ``None`` return skips the optimizer step for the batch.
        """
        return DiffusionForcingPrediction.training_step(self, batch, batch_idx)

    def on_validation_epoch_end(self, namespace="validation"):
        """Delegate end-of-validation handling to ``DiffusionForcingPrediction``."""
        return DiffusionForcingPrediction.on_validation_epoch_end(self, namespace)

    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        """Serve the policy to a robot client in an endless request/reply loop.

        ``batch`` is not evaluated. It only supplies the episode length and
        the action channels used as padding for the (unknown) current action.
        For each request received on ``self.socket``:

        1. One colour frame is read from every camera.
        2. The latent posterior is updated with that observation.
        3. The model predicts the next step.
        4. The predicted actions are sent back as raw float32 bytes, one row
           per action channel.

        A ``b"stop"`` request ends the current episode early and starts a new
        one. The method only exits by raising.

        Raises:
            NotImplementedError: if ``self.frame_stack > 1``.
            ValueError: if the number of connected cameras differs from
                ``self.n_cameras``, or if the data has no action channels.
        """
        if self.frame_stack > 1:
            raise NotImplementedError("frame_stack > 1 not implemented for robot dataset")
        dummy_x = batch[0]
        # assumes dummy_x is (batch, time, channels, h, w) — TODO confirm
        max_steps = dummy_x.shape[1]
        self.maybe_reset_cameras()
        self.maybe_reset_socket()
        n_cameras = len(self.cameras)
        print(f"Detected {n_cameras} cameras")
        # The image/action channel split is derived from the camera count. A mismatch with
        # training would still add up to the right channel total, so the model would
        # silently get padding channels as images (or the reverse).
        if n_cameras != self.n_cameras:
            raise ValueError(
                f"Detected {n_cameras} cameras but the model was trained with {self.n_cameras}"
            )

        # Every channel after the 3 * n_cameras image channels is an action channel.
        action_stack = len(self.data_mean) - n_cameras * 3
        if action_stack <= 0:
            # Guard: a slice of [-0:] would silently select every channel.
            raise ValueError(
                f"No action channels: {len(self.data_mean)} channels for {n_cameras} cameras"
            )
        # The current action is unknown at inference time, so reuse the action channels
        # of the first frame of the first sample as padding: (1, action_stack, h, w).
        a_pad = dummy_x[0][0][-action_stack:][None]

        resize = Resize(self.x_shape[-2:], antialias=True)

        while True:
            for _ in range(max_steps):
                # Fresh latent on every request: each step is handled like the first step
                # of an independent two-step training window (presumably init_z is zero
                # during training as well — confirm).
                z = torch.zeros(1, *self.z_shape, device=self.device)

                # Block until the robot client sends a request.
                message = self.socket.recv()
                if message == b"stop":
                    self.socket.send(b"restarting")
                    print("Received stop message. Restarting...")
                    break
                # errors="replace": a malformed request must not crash the control loop.
                print(f"Received request: {message.decode(errors='replace')}")

                # One colour frame per camera, in self.cameras order.
                o = []
                for cam in self.cameras:
                    frame = cam.wait_for_frames()
                    o.append(np.array(frame.get_color_frame().get_data()))
                if self.debug:
                    cv2.imwrite("debug.png", o[0])
                    input("Checkout debug.png to verify camera order is same as training data. ")

                # (n, h, w, 3) -> (n, 3, h, w) scaled by 1/255 (assumes uint8 frames),
                # then resized to the model resolution.
                o = torch.from_numpy(np.stack(o) / 255.0).float().permute(0, 3, 1, 2)
                o = resize(o).to(self.device)
                o = rearrange(o, "n c h w -> 1 (n c) h w")  # (1, n_cam * 3, h, w)
                x = torch.cat([o, a_pad], 1)
                x = self._normalize_x(x)

                # Update the posterior with the current observation.
                z, _, _, _ = self.transition_model(z, x, None, deterministic_t=0)

                # Predict the next step and keep only its action channels.
                _, x_pred, _ = self.transition_model.rollout(z, None)
                x_pred = self._unnormalize_x(x_pred)
                _, a = torch.split(x_pred, [3 * n_cameras, action_stack], 1)

                # unpack_to_1d is expected to yield (pos, quat, grasp) per action channel;
                # each becomes one flat float32 row of the reply.
                actions = a[0].cpu().numpy()
                actions = [unpack_to_1d(sub) for sub in actions]
                actions = np.stack([np.concatenate([pos, quat, [grasp]]) for pos, quat, grasp in actions])
                self.socket.send(actions.astype(np.float32).tobytes())
