import os

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from configs.state_vec import STATE_VEC_IDX_MAPPING
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
from models.multimodal_encoder.t5_encoder import T5Embedder
from models.rdt_runner import RDTRunner

MANISKILL_INDICES = [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"]
                     for i in range(7)] + [STATE_VEC_IDX_MAPPING[f"right_gripper_open"]]


def create_model(args, pretrained, **kwargs):
    model = RoboticDiffusionTransformerModel(args, **kwargs)
    if pretrained is not None:
        model.load_pretrained_weights(pretrained)
    return model


DATA_STAT = {
    "state_min": [
        -0.7463043928146362,
        -0.0801204964518547,
        -0.4976441562175751,
        -2.657780647277832,
        -0.5742632150650024,
        1.8309762477874756,
        -2.2423808574676514,
        0.0,
    ],
    "state_max": [
        0.7645499110221863,
        1.4967026710510254,
        0.4650936424732208,
        -0.3866899907588959,
        0.5505855679512024,
        3.2900545597076416,
        2.5737812519073486,
        0.03999999910593033,
    ],
    "action_min": [
        -0.7472005486488342,
        -0.08631071448326111,
        -0.4995281398296356,
        -2.658363103866577,
        -0.5751323103904724,
        1.8290787935256958,
        -2.245187997817993,
        -1.0,
    ],
    "action_max": [
        0.7654682397842407,
        1.4984270334243774,
        0.46786263585090637,
        -0.38181185722351074,
        0.5517147779464722,
        3.291581630706787,
        2.575840711593628,
        1.0,
    ],
}


class RoboticDiffusionTransformerModel(object):
    """A wrapper for the RDT model, which handles
    1. Model initialization
    2. Encodings of instructions
    3. Model inference
    """

    def __init__(
        self,
        args,
        device="cuda",
        dtype=torch.bfloat16,
        image_size=None,
        control_frequency=25,
        pretrained_text_encoder_name_or_path=None,
        pretrained_vision_encoder_name_or_path=None,
    ):
        self.args = args
        self.dtype = dtype
        self.image_size = image_size
        self.device = device
        self.control_frequency = control_frequency
        self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
        self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
        self.policy = self.get_policy()

        self.state_min = torch.tensor(DATA_STAT["state_min"]).to(device)
        self.state_max = torch.tensor(DATA_STAT["state_max"]).to(device)
        self.action_min = torch.tensor(DATA_STAT["action_min"]).to(device)
        self.action_max = torch.tensor(DATA_STAT["action_max"]).to(device)

        self.reset()

    def get_policy(self):
        """Initialize the model."""
        # Initialize model with arguments
        img_cond_len = (self.args["common"]["img_history_size"] * self.args["common"]["num_cameras"] *
                        self.vision_model.num_patches)

        _model = RDTRunner(
            action_dim=self.args["common"]["state_dim"],
            pred_horizon=self.args["common"]["action_chunk_size"],
            config=self.args["model"],
            lang_token_dim=self.args["model"]["lang_token_dim"],
            img_token_dim=self.args["model"]["img_token_dim"],
            state_token_dim=self.args["model"]["state_token_dim"],
            max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],
            img_cond_len=img_cond_len,
            img_pos_embed_config=[
                # No initial pos embed in the last grid size
                # since we've already done in ViT
                (
                    "image",
                    (
                        self.args["common"]["img_history_size"],
                        self.args["common"]["num_cameras"],
                        -self.vision_model.num_patches,
                    ),
                ),
            ],
            lang_pos_embed_config=[
                # Similarly, no initial pos embed for language
                ("lang", -self.args["dataset"]["tokenizer_max_length"]),
            ],
            dtype=self.dtype,
        )

        return _model

    def get_text_encoder(self, pretrained_text_encoder_name_or_path):
        text_embedder = T5Embedder(
            from_pretrained=pretrained_text_encoder_name_or_path,
            model_max_length=self.args["dataset"]["tokenizer_max_length"],
            device=self.device,
        )
        tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
        return tokenizer, text_encoder

    def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
        vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
        image_processor = vision_encoder.image_processor
        return image_processor, vision_encoder

    def reset(self):
        """Set model to evaluation mode."""
        device = self.device
        weight_dtype = self.dtype
        self.policy.eval()
        self.text_model.eval()
        self.vision_model.eval()

        self.policy = self.policy.to(device, dtype=weight_dtype)
        self.text_model = self.text_model.to(device, dtype=weight_dtype)
        self.vision_model = self.vision_model.to(device, dtype=weight_dtype)

    def load_pretrained_weights(self, pretrained=None):
        if pretrained is None:
            return
        print(f"Loading weights from {pretrained}")
        filename = os.path.basename(pretrained)
        if filename.endswith(".pt"):
            checkpoint = torch.load(pretrained)
            self.policy.load_state_dict(checkpoint["module"])
        elif filename.endswith(".safetensors"):
            from safetensors.torch import load_model

            load_model(self.policy, pretrained)
        else:
            raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")

    def encode_instruction(self, instruction, device="cuda"):
        """Encode string instruction to latent embeddings.

        Args:
            instruction: a string of instruction
            device: a string of device

        Returns:
            pred: a tensor of latent embeddings of shape (text_max_length, 512)
        """
        tokens = self.text_tokenizer(instruction, return_tensors="pt", padding="longest",
                                     truncation=True)["input_ids"].to(device)

        tokens = tokens.view(1, -1)
        with torch.no_grad():
            pred = self.text_model(tokens).last_hidden_state.detach()

        return pred

    def _format_joint_to_state(self, joints):
        """
        Format the robot joint state into the unified state vector.

        Args:
            joints (torch.Tensor): The joint state to be formatted.
                qpos ([B, N, 14]).

        Returns:
            state (torch.Tensor): The formatted state for RDT ([B, N, 128]).
        """
        # Rescale the gripper
        # joints = joints / torch.tensor(
        #     [[[1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 4.7888]]],
        #     device=joints.device, dtype=joints.dtype
        # )

        # normalize to -1,1
        joints = (joints - self.state_min) / (self.state_max - self.state_min) * 2 - 1
        B, N, _ = joints.shape
        state = torch.zeros(
            (B, N, self.args["model"]["state_token_dim"]),
            device=joints.device,
            dtype=joints.dtype,
        )
        # assemble the unifed state vector
        state[:, :, MANISKILL_INDICES] = joints
        state_elem_mask = torch.zeros(
            (B, self.args["model"]["state_token_dim"]),
            device=joints.device,
            dtype=joints.dtype,
        )
        state_elem_mask[:, MANISKILL_INDICES] = 1
        return state, state_elem_mask

    def _unformat_action_to_joint(self, action):
        action_indices = MANISKILL_INDICES
        joints = action[:, :, action_indices]

        # denormalize to action space

        joints = (joints + 1) / 2 * (self.action_max - self.action_min) + self.action_min

        return joints

    @torch.no_grad()
    def step(self, proprio, images, text_embeds):
        """
        Args:
            proprio: proprioceptive states
            images: RGB images
            text_embeds: instruction embeddings

        Returns:
            action: predicted action
        """
        device = self.device
        dtype = self.dtype

        background_color = np.array([int(x * 255) for x in self.image_processor.image_mean],
                                    dtype=np.uint8).reshape(1, 1, 3)
        background_image = (np.ones(
            (
                self.image_processor.size["height"],
                self.image_processor.size["width"],
                3,
            ),
            dtype=np.uint8,
        ) * background_color)

        image_tensor_list = []
        for image in images:
            if image is None:
                # Replace it with the background image
                image = Image.fromarray(background_image)

            if self.image_size is not None:
                image = transforms.Resize(self.data_args.image_size)(image)

            if self.args["dataset"].get("auto_adjust_image_brightness", False):
                pixel_values = list(image.getdata())
                average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
                if average_brightness <= 0.15:
                    image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)

            if self.args["dataset"].get("image_aspect_ratio", "pad") == "pad":

                def expand2square(pil_img, background_color):
                    width, height = pil_img.size
                    if width == height:
                        return pil_img
                    elif width > height:
                        result = Image.new(pil_img.mode, (width, width), background_color)
                        result.paste(pil_img, (0, (width - height) // 2))
                        return result
                    else:
                        result = Image.new(pil_img.mode, (height, height), background_color)
                        result.paste(pil_img, ((height - width) // 2, 0))
                        return result

                image = expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
            image_tensor_list.append(image)

        image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)

        image_embeds = self.vision_model(image_tensor).detach()
        image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)

        # history of actions
        joints = proprio.to(device).unsqueeze(0)  # (1, 1, 14)
        states, state_elem_mask = self._format_joint_to_state(joints)  # (1, 1, 128), (1, 128)
        states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
        states = states[:, -1:, :]  # (1, 1, 128)
        ctrl_freqs = torch.tensor([self.control_frequency]).to(device)

        text_embeds = text_embeds.to(device, dtype=dtype)

        trajectory = self.policy.predict_action(
            lang_tokens=text_embeds,
            lang_attn_mask=torch.ones(text_embeds.shape[:2], dtype=torch.bool, device=text_embeds.device),
            img_tokens=image_embeds,
            state_tokens=states,
            action_mask=state_elem_mask.unsqueeze(1),
            ctrl_freqs=ctrl_freqs,
        )
        trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)

        return trajectory
