import pdb
from typing_extensions import override

import torch
import torch.nn as nn
import numpy as np
from openpi_client import image_tools
from openpi_client import websocket_client_policy
try:
    from torchvision.transforms import v2 as tvf
except Exception:
    import torchvision.transforms as tvf  # Bamboo server use torchvision==0.14.1

from robobase.models import RoboBaseModule
from robobase.method.bc import BC
from robobase.replay_buffer.replay_buffer import ReplayBuffer
from robobase.envs.wrappers import (
    RescaleFromTanhWithMinMax,
    OnehotTime,
    ActionSequence,
    AppendDemoInfo,
    FrameStack,
    ConcatDim,
    RecedingHorizonControl,
)
from robobase.method.utils import (
    extract_from_spec,
    extract_from_batch,
    flatten_time_dim_into_channel_dim,
    stack_tensor_dictionary,
    extract_many_from_batch,
)
from robobase.models.act.backbone import build_backbone, build_film_backbone

# Natural-language instruction for each supported task, keyed by task name.
# Pi05BCAgent.act sends the entry for its current_task verbatim as the
# "prompt" field. Edits therefore change what the remote policy sees; that
# includes the leading space in "pick_box" and the missing final period in
# "flip_cutlery". NOTE(review): confirm these match the training prompts
# before normalizing them.
task_instructions = dict(
    sandwich_remove="Take the sandwich out of the frying pan.",
    take_cups="Take two cups out from the closed wall cabinet and put them on the table.",
    put_cups="Pick up cups from the table and put them into the closed wall cabinet.",
    dishwasher_open_trays="Pull out the dishwasher's trays with the door initially open.",
    cupboards_open_all="Open all drawers and doors of the kitchen set.",
    sandwich_flip="Flip the sandwich in the frying pan using the spatula.",
    dishwasher_close_trays="Push the dishwasher's trays back with the door initially open.",
    drawers_close_all="Close all sliding drawers of the kitchen cabinet.",
    dishwasher_close="Push back all trays and close the door of the dishwasher.",
    store_box="Move a large box from the counter to the shelf in the cabinet below.",
    wall_cupboard_close="Close doors of the wall cabinet.",
    dishwasher_open="Open the dishwasher door and pull out all trays.",
    move_plate="Move the plate between two draining racks.",
    saucepan_to_hob="Take the saucepan from the closed cabinet and place it on the hob.",
    flip_cutlery="Take the cutlery from the static holder, flip it, and place it back into the holder",
    cupboards_close_all="Close all drawers and doors of the kitchen set.",
    pick_box=" Pick up a large box from the floor and place it on the counter.",
    drawers_open_all="Open all sliding drawers of the kitchen cabinet.",
    wall_cupboard_open="Open doors of the wall cabinet.",
    sandwich_toast="Use the spatula to put the sandwich on the frying pan.",
    flip_cup="Flip the cup, initially positioned upside down on the table, to an upright position.",
)

class Pi05BCAgent(BC):
    """BC agent whose actions come from a remote pi0.5 policy server.

    No local model is trained or loaded here: ``build_actor``, ``reset`` and
    ``load_state_dict`` do nothing, and ``update`` returns empty metrics.
    ``act`` packs the current observation into the request dict consumed by
    the pi0.5 websocket server and returns the action chunk it replies with.
    """

    # Length of the state vector sent to the pi0.5 server; shorter
    # low-dimensional states are zero-padded on the right to this size.
    PI05_STATE_DIM = 66

    def __init__(
        self,
        lr_backbone: float = 1e-5,
        weight_decay: float = 1e-4,
        use_lang_cond: bool = False,
        current_task: str = None,
        *args,
        pi05_host: str = "localhost",
        pi05_port: int = 8000,
        **kwargs,
    ):
        """
        pi0.5 Behavioral Cloning (BC) Agent.

        Args:
            lr_backbone (float): Learning rate for the backbone. Stored on the
                instance; not read by this class's own methods.
            weight_decay (float): Weight decay for optimization. Stored on the
                instance; not read by this class's own methods.
            use_lang_cond (bool): Accepted for config compatibility; ignored.
            current_task (str): Key into ``task_instructions`` selecting the
                language prompt sent with every request in ``act``.
            *args: Forwarded to ``BC.__init__``.
            pi05_host (str): Host of the pi0.5 websocket policy server.
            pi05_port (int): Port of the pi0.5 websocket policy server.
            **kwargs: Forwarded to ``BC.__init__``.
        """
        self.lr_backbone = lr_backbone
        self.weight_decay = weight_decay
        self.current_task = current_task
        super().__init__(*args, **kwargs)

        # sanity check
        assert self.frame_stack_on_channel, "frame_stack_on_channel must be enabled"

        self.pi05_client = websocket_client_policy.WebsocketClientPolicy(
            host=pi05_host, port=pi05_port
        )

        self.train()

    def build_actor(self):
        """Build nothing: inference is delegated to the remote pi0.5 server.

        NOTE(review): ``sample`` calls ``self.actor``, which this method does
        not create; ``sample`` only works if ``self.actor`` is set elsewhere.
        """

    def train(self, training=True):
        """Set ``self.training`` without recursing into submodules.

        Returns ``self`` so calls can be chained, matching the return contract
        of ``torch.nn.Module.train``.
        """
        self.training = training
        return self

    def transform_to_pi05_train_raw_obs(self, qpos: torch.Tensor, act_env) -> torch.Tensor:
        """Compress the low-dimensional state with a symmetric log ("symlog").

        Computes ``sign(x) * log(1 + |x|)`` elementwise. ``log1p`` is used
        because it is more precise than ``log(1 + x)`` when ``|x|`` is small.

        Args:
            qpos: Flattened low-dimensional state tensor.
            act_env: Unused. Kept for signature compatibility; an earlier
                variant un-normalized ``qpos`` using observation statistics
                read from the ``ConcatDim`` wrapper in the env chain.

        Returns:
            Tensor with the same shape and dtype as ``qpos``.
        """
        return torch.sign(qpos) * torch.log1p(torch.abs(qpos))

    def transform_to_env_action(self, pi05_action: np.ndarray, act_env) -> np.ndarray:
        """Map a pi0.5 action through ``RescaleFromTanhWithMinMax.transform_to_tanh``.

        Walks the ``.env`` chain of ``act_env`` to the first
        ``RescaleFromTanhWithMinMax`` wrapper, then applies its
        ``transform_to_tanh`` with that wrapper's ``action_stats`` and
        ``min_max_margin``.

        Raises:
            ValueError: if no ``RescaleFromTanhWithMinMax`` wrapper is found.
        """
        scale_env = act_env
        while not isinstance(scale_env, RescaleFromTanhWithMinMax):
            inner = getattr(scale_env, "env", None)
            if inner is None:
                raise ValueError(
                    "act_env has no RescaleFromTanhWithMinMax wrapper in its .env chain"
                )
            scale_env = inner
        return scale_env.transform_to_tanh(
            pi05_action, scale_env.action_stats, scale_env.min_max_margin
        )

    @staticmethod
    def _pad_state(state: torch.Tensor, target_dim: int) -> torch.Tensor:
        """Right-pad the last dim of ``state`` with zeros up to ``target_dim``."""
        state_dim = state.shape[-1]
        if state_dim > target_dim:
            raise ValueError(
                f"low_dim_state has {state_dim} dims, more than "
                f"PI05_STATE_DIM={target_dim}"
            )
        # pad sizes the zero block exactly. Building it from a slice of
        # `state` would silently cap the padding at state_dim columns.
        return nn.functional.pad(state, (0, target_dim - state_dim))

    @staticmethod
    def _first_frame_hwc(images: torch.Tensor) -> np.ndarray:
        """Return ``images[0, 0]`` of a 5-D image tensor as an HWC numpy array.

        Assumes ``images`` is laid out (B, T, C, H, W); TODO confirm what dim 1
        indexes. Only batch element 0 is kept, and only it is copied to the CPU.
        """
        return images[0, 0].permute(1, 2, 0).detach().cpu().numpy()

    def _task_prompt(self) -> str:
        """Look up the language instruction for ``self.current_task``.

        Raises:
            KeyError: if ``self.current_task`` is not a key of ``task_instructions``.
        """
        try:
            return task_instructions[self.current_task]
        except KeyError:
            raise KeyError(
                f"No task instruction for current_task={self.current_task!r}; "
                f"known tasks: {sorted(task_instructions)}"
            ) from None

    @override
    def act(self, obs: dict[str, torch.Tensor], step: int, eval_mode: bool, act_env):
        """Query the remote pi0.5 policy for an action chunk.

        The request is built as follows:
        - The low-dimensional state is symlog-compressed and zero-padded on
          the right to ``PI05_STATE_DIM`` entries.
        - Each RGB camera contributes ``obs[key][0, 0]`` as an HWC array.
        - The prompt for ``self.current_task`` is attached.
        Only batch element 0 is sent.

        Args:
            obs: Observation dict with the entries selected by
                ``extract_from_spec(obs, "low_dim_state")`` and 5-D image
                tensors under "rgb_head", "rgb_left_wrist", "rgb_right_wrist".
            step: Unused.
            eval_mode: Unused.
            act_env: Environment whose ``action_space.shape[-1]`` sets how many
                action columns are kept from the server reply.

        Returns:
            The server's ``"actions"`` array, truncated along its second axis
            to the env action dimension (not rescaled; see note below).

        Raises:
            ValueError: if the state has more than ``PI05_STATE_DIM`` dims.
            KeyError: if ``self.current_task`` has no entry in ``task_instructions``.
        """
        qpos = flatten_time_dim_into_channel_dim(
            extract_from_spec(obs, "low_dim_state")
        )
        qpos = self.transform_to_pi05_train_raw_obs(qpos, act_env).detach()
        state_pi05 = self._pad_state(qpos, self.PI05_STATE_DIM)

        pi05_action_dim = act_env.action_space.shape[-1]

        observation = {
            "observation/state": state_pi05[0].cpu().numpy(),
            "observation/rgb_head": self._first_frame_hwc(obs["rgb_head"]),
            "observation/rgb_left_wrist": self._first_frame_hwc(obs["rgb_left_wrist"]),
            "observation/rgb_right_wrist": self._first_frame_hwc(obs["rgb_right_wrist"]),
            "prompt": self._task_prompt(),
        }

        action_chunk = self.pi05_client.infer(observation)["actions"]
        # Keep only the action columns the env consumes.
        action_chunk = action_chunk[:, :pi05_action_dim]

        # NOTE(review): the chunk is deliberately NOT passed through
        # `self.transform_to_env_action(action_chunk, act_env)`. Enable that
        # call here if the env expects tanh-range actions from the agent.
        return action_chunk

    def sample(self, obs: dict[str, torch.Tensor], step: int, eval_mode: bool):
        """Run ``self.actor`` in sampling mode on a single observation.

        A modality that is disabled (``low_dim_size <= 0`` or ``use_pixels``
        false) is passed to the actor as ``None``.
        """
        qpos = None
        image = None
        if self.low_dim_size > 0:
            # unsqueeze(0) adds a leading size-1 axis (presumably batch).
            qpos = flatten_time_dim_into_channel_dim(
                extract_from_spec(obs, "low_dim_state").unsqueeze(0)
            ).detach()

        if self.use_pixels:
            rgb = flatten_time_dim_into_channel_dim(
                stack_tensor_dictionary(extract_many_from_batch(obs, r"rgb.*"), 0).unsqueeze(0),
                has_view_axis=True,
            )
            image = rgb.float().detach()

        return self.actor(qpos, image, is_sample=True)

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Ignore checkpoints: this agent has no local weights to restore.

        Extra arguments (e.g. ``strict``) are accepted so callers using the
        ``torch.nn.Module.load_state_dict`` signature do not fail.
        """

    @override
    def update(
        self, replay_iter, step: int, replay_buffer: ReplayBuffer = None
    ) -> dict:
        """
        No-op policy update: training happens outside this agent.

        Args:
            replay_iter (iterable): An iterator over a replay buffer (unused).
            step (int): The current step (unused).
            replay_buffer (ReplayBuffer): The replay buffer (unused).

        Returns:
            dict: An empty metrics dictionary.
        """
        return {}

    def reset(self, step: int, agents_to_reset: list[int]):
        pass  # TODO: Implement LSTM support.
