import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import torchvision.transforms as transforms

from .visual_model.vit_backbone import ViTBackbone
from .visual_model.resnet_encoder import ResnetIMPALA
from .visual_model.image_modules import ViTModelRGBD, ImageDecoder
from .prompt_modules import masked_modeling_attention_mask
# from .visual_model.dragan import Discriminator
from .diffusion.ddpm import GaussianDiffusion


class StateEncoder(pl.LightningModule):
    def __init__(self, config, bbox_encoder_n_layers = 3):
        """
        Build the padding transforms, the image encoder(s) and the bbox MLP.

        Args:
            config: experiment config. Fields read here: feature_list,
                image_size, crop_size, patch_size, image_encoder, n_embd.
            bbox_encoder_n_layers: number of Linear layers in the bbox MLP.
                Values below 3 still produce the two outer Linear layers.

        Raises:
            ValueError: if config.image_encoder is not a supported name, or
                (resnet only) if feature_list lacks a required feature.
        """
        super().__init__()

        self.config = config
        self.feature_list = config.feature_list

        self.image_size = config.image_size # 352
        self.crop_size = config.crop_size # 64
        self.patch_size = config.patch_size # 64

        def pad_to_patch_multiple(size):
            # Smallest non-negative padding that makes `size` a multiple of patch_size
            return (-size) % self.patch_size

        # ----- Base and gripper images -----
        pad_width = pad_to_patch_multiple(self.image_size)
        self.pad_transform = transforms.Pad((0, 0, pad_width, pad_width)) # left, top, right, bottom
        self.padded_image_size = self.image_size + pad_width
        # ----- Cropped images -----
        crop_pad_width = pad_to_patch_multiple(self.crop_size)
        self.crop_pad_transform = transforms.Pad((0, 0, crop_pad_width, crop_pad_width)) # left, top, right, bottom
        self.padded_crop_size = self.crop_size + crop_pad_width

        if config.image_encoder in ["vit-patches", "vit-cls"]:
            # NOTE(review): the same ViT (built for padded_image_size) is also applied to
            # padded crops in encode_image — confirm ViTBackbone accepts both resolutions.
            self.image_encoder = ViTBackbone(
                resolution=self.padded_image_size,
                patch_size=self.patch_size,
                width=config.n_embd,
                layers=4,
                heads=24,
                output_dim=config.n_embd,
            )
        elif config.image_encoder in ["vit-pretrained", "vit-pretrain"]:
            # encode_image checks for "vit-pretrain", so both spellings are accepted here
            self.image_encoder = ViTModelRGBD(input_channels = 4,
                                              image_size = self.image_size,
                                              hidden_size = config.n_embd,
                                              pretrain = True,
                                              frozen = True,
                                              depth_channel = "embed")
        elif config.image_encoder == "resnet":
            self.base_encoder = ResnetIMPALA(input_shape=(self.get_channel_num("base"), self.image_size, self.image_size), hidden_size = config.n_embd)
            # The gripper camera shares the base-camera encoder (same module object, shared weights)
            self.grip_encoder = self.base_encoder
            self.crop_encoder = ResnetIMPALA(input_shape=(self.get_channel_num("object"), self.crop_size, self.crop_size), hidden_size = config.n_embd)
        else:
            raise ValueError("image_encoder_name should be one of ['vit-patches', 'vit-cls', 'vit-pretrained', 'resnet']")

        # bbox MLP: 4 -> n_embd -> ... -> n_embd, applied layer by layer in encode_image.
        # NOTE(review): the first activation is ReLU6 while the inner ones are ReLU;
        # kept exactly as before so existing behaviour does not change.
        bbox_layers = [nn.Linear(4, config.n_embd), nn.ReLU6()]
        for _ in range(bbox_encoder_n_layers - 2):
            bbox_layers.append(nn.Linear(config.n_embd, config.n_embd))
            bbox_layers.append(nn.ReLU())
        bbox_layers.append(nn.Linear(config.n_embd, config.n_embd))
        self.bbox_encoder = nn.ModuleList(bbox_layers)
        # Fuses the concatenated [crop embedding, bbox embedding] back to n_embd
        self.crop_mlp = nn.Linear(config.n_embd + config.n_embd, config.n_embd)

        # # Add a trainable token embedding for state end token
        # self.add_special_token = config.add_special_token
        # if config.add_special_token:
        #     self.state_token_embed = nn.Embedding(1, config.n_embd)

    def get_channel_num(self, name):
        """
        Look up the channel count of a feature by (partial) name.

        Returns the `channels` attribute of the first entry in
        self.feature_list whose `name` contains the given substring.

        Raises:
            ValueError: if no feature name contains `name`.
        """
        not_found = object()
        matching_channels = (feat.channels for feat in self.feature_list if name in feat.name)
        channels = next(matching_channels, not_found)
        if channels is not_found:
            raise ValueError(f"feature_list does not contain {name}")
        return channels

    def encode_image(self, image, bbox = None, image_type = None, noise = None):
        """
        Encode a batch of image sequences into image tokens and, in training
        mode, build the matching masked-state-reconstruction (MSR) target.

        Args:
            image: (B, S, C, H, W) tensor (must be 5-D).
            bbox: (B, S, 4) tensor; required when image_type == "crop".
            image_type: None / "base" / "hand" for full camera images,
                "crop" for object crops. None is treated like "base".
            noise: optional tensor that is padded like `image` and returned
                next to the image target. Only used for base/hand images
                with the "vit-pretrain" or "resnet" encoders.

        Returns:
            (image_embeddings, msr_target)

            image_embeddings:
                vit-patches: (B, S, patch_num**2, hidden)
                vit-cls / resnet: (B, S, 1, hidden)
                vit-pretrain: (B, S, -1, n_embd)
                For crops the embedding is fused with the bbox embedding
                through crop_mlp (same shape).

            msr_target (None when not self.training):
                vit-patches: tensor (B, S, patch_num**2, C * patch_size**2)
                vit-cls: tensor (B, S, 1, C * padded_size**2)
                vit-pretrain / resnet: list [image_target] or
                    [image_target, noise_target], each (B, S, 1, -1)
                Note: patch_num is the number of patches in each dimension

        Raises:
            ValueError: on an unknown image_type, an unknown
                config.image_encoder, or a crop without bbox.
        """
        encoder_name = self.config.image_encoder
        if encoder_name == "vit-pretrained":
            # __init__ accepts "vit-pretrained"; treat it as the same encoder
            encoder_name = "vit-pretrain"
        if encoder_name not in ("vit-patches", "vit-cls", "vit-pretrain", "resnet"):
            raise ValueError("image_encoder_name should be one of ['vit-patches', 'vit-cls', 'vit-pretrained', 'resnet']")

        is_crop = image_type == "crop"
        if not is_crop and not (image_type is None or image_type in ["base", "hand"]):
            raise ValueError("image_type should be one of [None, 'base', 'hand', 'crop']")
        if is_crop and bbox is None:
            raise ValueError("bbox is required when image_type == 'crop'")

        # Crops and full images are padded to a multiple of patch_size with different widths
        pad = self.crop_pad_transform if is_crop else self.pad_transform

        # ----------------- Encode the image -----------------
        B, S, _, _, _ = image.shape
        # Undo the input normalisation (x * 0.225 + 0.45), as the original code does
        # to map pixel values to [0, 1], which is compatible with ReLU
        image = image * 0.225 + 0.45
        image = image.flatten(end_dim = 1) # (B*S, C, H, W)

        if encoder_name == "vit-patches":
            image_padded = pad(image)
            _, hidden_states = self.image_encoder(image_padded)
            # Drop the first token and keep one embedding per patch
            image_embeddings = hidden_states[:, 1:, :].view(B, S, -1, hidden_states.size(-1)) # (B, S, patch_num**2, hidden)
        elif encoder_name == "vit-cls":
            image_padded = pad(image)
            cls_hidden_states, _ = self.image_encoder(image_padded)
            image_embeddings = cls_hidden_states.view(B, S, -1, cls_hidden_states.size(-1)) # (B, S, 1, hidden)
        elif encoder_name == "vit-pretrain":
            image_padded = image
            image_embeddings = self.image_encoder(image_padded).view(B, S, -1, self.config.n_embd)
        else: # resnet
            image_padded = image
            if is_crop:
                encoder = self.crop_encoder
            elif image_type == "hand":
                encoder = self.grip_encoder
            else:
                encoder = self.base_encoder
            image_embeddings = encoder(image_padded).view(B, S, 1, -1) # (B, S, 1, hidden)

        if is_crop:
            # Run the bbox through the MLP layer by layer (bbox_encoder is a ModuleList)
            bbox_embed = bbox
            for layer in self.bbox_encoder:
                bbox_embed = layer(bbox_embed)
            # Repeat the bbox embedding for every token so the concat also works
            # when a crop yields more than one patch token
            bbox_embed = bbox_embed.unsqueeze(2).expand(-1, -1, image_embeddings.size(2), -1) # (B, S, patch_num**2 or 1, hidden)
            image_embeddings = self.crop_mlp(torch.cat([image_embeddings, bbox_embed], dim = -1)) # (B, S, patch_num**2 or 1, hidden)

        # ----------------- Construct the MSR target -----------------
        if not self.training:
            # No need to calculate MSR target when inference
            return image_embeddings, None

        if encoder_name == "vit-patches":
            # Unfold the padded image into non-overlapping patches
            patches = image_padded.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
            # patches: (B*S, C, patch_num, patch_num, patch_size, patch_size)
            patch_num = patches.size(2)
            patches = torch.permute(patches, (0, 2, 3, 1, 4, 5))
            # patches: (B*S, patch_num, patch_num, C, patch_size, patch_size)
            msr_target = patches.reshape(B, S, patch_num**2, -1)
        elif encoder_name == "vit-cls":
            # Use the whole padded image as the target; unflatten the B, S dimensions
            msr_target = image_padded.view(B, S, 1, -1)
        else:
            # vit-pretrain / resnet encode the unpadded image, so pad only for the target.
            # These targets are returned as a list: [image] or [image, noise].
            image_padded = pad(image)
            msr_target = [image_padded.view(B, S, 1, -1)]
            if noise is not None and not is_crop:
                noise_padded = pad(noise)
                msr_target.append(noise_padded.view(B, S, 1, -1))

        return image_embeddings, msr_target

    def forward(self, batch_inputs):
        """
        Encode every image feature of a batch.

        batch_inputs[<base_camera feature>] : (B, S, C, H, W) tensor, or a
            [image, noise] pair (as produced by prepare_noisy_image)
        batch_inputs[<hand_camera feature>] : (B, S, C, H, W)
        batch_inputs[<object_image feature>] :
            a dict: {
                "cropped_image": (B, S, C, H', W')
                "bbox": (B, S, 4)
            }
        Features whose name matches none of the above (e.g. actions) are skipped.

        Return:
            Three lists with one entry per encoded feature, in the order the
            features appear in self.feature_list.

            embed_list: [(B, S, patch_num**2 or 1, hidden), ...]
            attn_mask_list: [(B, S, patch_num**2 or 1), ...]
                all-ones for camera images; for crops, False where the bbox
                is entirely -1 (invalid crop). Note the crop mask is a bool
                tensor while the camera masks are float.
            msr_target_list: the msr_target returned by encode_image for each
                feature (None when the module is not in training mode)
        """

        # The first feature may be a tensor, an [image, noise] pair or a crop dict
        first_input = batch_inputs[self.feature_list[0].name]
        if isinstance(first_input, (list, tuple)):
            first_input = first_input[0]
        elif isinstance(first_input, dict):
            first_input = first_input['cropped_image']
        B, S = first_input.shape[:2]

        embed_list = []
        attn_mask_list = []
        msr_target_list = []
        for feature in self.feature_list:
            name = feature.name
            if "base_camera" in name:
                base_input = batch_inputs[name]
                if isinstance(base_input, (list, tuple)):
                    # [image, noise]: previously the encoded result was dropped here,
                    # leaving the three output lists without the base image
                    base_image, base_noise = base_input[0], base_input[1]
                else:
                    base_image, base_noise = base_input, None
                embed, msr_target = self.encode_image(base_image, image_type = "base", noise = base_noise)
                attn_mask = torch.ones((B, S, embed.size(2),), device = self.device) # (B, S, patch_num**2 or 1)
            elif "hand_camera" in name:
                embed, msr_target = self.encode_image(batch_inputs[name], image_type = "hand")
                attn_mask = torch.ones((B, S, embed.size(2),), device = self.device)
            elif "object_image" in name:
                crop_input = batch_inputs[name]
                embed, msr_target = self.encode_image(crop_input['cropped_image'],
                                                      bbox = crop_input['bbox'],
                                                      image_type = "crop")

                # If bbox is all -1, then the image is invalid and its attn_mask should be 0
                attn_mask = (crop_input['bbox'] != -1).any(dim = -1, keepdim = True) # (B, S, 1)
                # Repeat attention mask for patch_num**2 or 1 times
                attn_mask = attn_mask.repeat(1, 1, embed.size(2)) # (B, S, patch_num**2 or 1)
            else:
                # Feature can be curr_action or prev_action
                continue

            embed_list.append(embed)
            attn_mask_list.append(attn_mask)
            msr_target_list.append(msr_target)

        # if self.add_special_token:
        #     state_end_ids = torch.zeros((B, S, 1), dtype = torch.long, device = self.device)
        #     state_end_embed = self.state_token_embed(state_end_ids) # (B, S, 1, hidden)
        #     embed_list.append(state_end_embed)
        #     attn_mask_list.append(torch.ones((B, S, 1), device = self.device))
        #     msr_target_list.append(None)

        return embed_list, attn_mask_list, msr_target_list


class MaskedStateReconstructionHead(pl.LightningModule):
    def __init__(self, config):
        """
        Build the image decoder and the diffusion helper.

        Args:
            config: experiment config. Fields read here: n_embd and
                feature_list; the whole config is also passed to ImageDecoder.
        """
        super().__init__()

        self.config = config
        # prepare_noisy_image iterates self.feature_list, so it has to exist
        self.feature_list = config.feature_list
        self.hidden_size = config.n_embd
        # Upper bound on how many images per key log_compare_images writes
        self.log_image_num = 10
        self.image_decoder = ImageDecoder(config)
        self.diffusion = GaussianDiffusion(timesteps = 10)

    def prepare_inputs(self, state_attn_mask_list, mask_ratio):
        """
        Produce the attention masks and MSR loss masks for the diffusion head.

        masked_modeling_attention_mask is still called for every entry, but
        only the shapes of its outputs are used: both the attention mask and
        the loss mask are replaced by all-ones tensors.

        Args:
            state_attn_mask_list: list of attention masks; it is modified in
                place and also returned.
            mask_ratio: ratio forwarded to masked_modeling_attention_mask.

        Returns:
            (state_attn_mask_list, msr_loss_mask_list)
        """
        loss_masks = []
        for idx, attn_mask in enumerate(state_attn_mask_list):
            single_token_encoder = self.config.image_encoder in ["vit-cls", "resnet"]
            # Only mask base image (idx == 0) b/c we only have a MSR head for reconstructing the base image
            ratio = 0. if single_token_encoder and idx != 0 else mask_ratio
            masked_attn, loss_mask = masked_modeling_attention_mask(attn_mask, ratio)
            # For diffusion, mask no token, but also predict all tokens
            state_attn_mask_list[idx] = torch.ones_like(masked_attn)
            loss_masks.append(torch.ones_like(loss_mask))

        return state_attn_mask_list, loss_masks

    def prepare_noisy_image(self, batch_inputs, t_index = 0):
        """
        Replace each base-camera entry of batch_inputs with the pair returned
        by self.diffusion.q_sample.

        Args:
            batch_inputs: dict keyed by feature name; modified in place.
            t_index: forwarded to q_sample as t_index.

        Returns:
            The same batch_inputs dict, where every feature whose name
            contains "base_camera" now holds the list [image, noise]
            (the two values q_sample returned, in that order).
        """
        # Read the feature list from the config: this class does not necessarily
        # carry its own feature_list attribute.
        for feature in self.config.feature_list:
            if "base_camera" not in feature.name:
                continue
            image, noise = self.diffusion.q_sample(batch_inputs[feature.name], t_index = t_index)
            batch_inputs[feature.name] = [image, noise]
        return batch_inputs

    def forward(self, hidden_states, msr_target_list, msr_loss_mask_list, segment_lengths):
        """
        Predict the noise of the base image from its state token and compute
        the diffusion loss.

        Args:
            hidden_states: transformer output; the first prompt_length
                positions along dim 1 are dropped.
            msr_target_list: per-feature targets; entry 0 must be the pair
                [image_target, noise_target] of the base image.
            msr_loss_mask_list: per-feature loss masks; only entry 0 is used.
            segment_lengths: mapping with 'batch_size', 'batch_size_prime',
                'timesteps', 'prompt_length' and 'state_length'.

        Returns:
            (loss_diffusion, compare_imgs) where compare_imgs holds the first
            three channels of "pred_noise", "target_noise" and "noisy_image".

        Raises:
            NotImplementedError: if config.image_encoder is not "vit-cls" or
                "resnet" (no noise target is available for other encoders).
            ValueError: if the base target is a list without a noise entry.
        """
        B = segment_lengths['batch_size']
        B_prime = segment_lengths['batch_size_prime']
        S = segment_lengths['timesteps']
        prompt_length = segment_lengths['prompt_length']

        if self.config.image_encoder not in ["vit-cls", "resnet"]:
            # The "vit-patches" path never had a noise target and could only end in an
            # unbound-variable error; other encoder names bound no locals at all.
            raise NotImplementedError(
                f"MSR diffusion head needs a noise target, which is only available for "
                f"['vit-cls', 'resnet'], got image_encoder={self.config.image_encoder!r}")

        # Remove prompt hidden states
        sa_hidden_states = hidden_states[:, prompt_length:, :].view(B_prime, S, -1, hidden_states.size(-1)) # (B_prime, S, state_len + action_len, hidden)
        # Remove action hidden states
        state_hidden_states = sa_hidden_states[:B, :, :segment_lengths["state_length"], :] # (B, S, state_len, hidden)

        # Only the base image (first state token) is reconstructed
        msr_loss_mask = msr_loss_mask_list[0]
        base_hidden_states = state_hidden_states[:, :, 0, :] # (B, S, hidden)
        # NOTE(review): masked_select returns a 1-D tensor; this relies on ImageDecoder
        # accepting that layout — confirm against ImageDecoder.forward.
        masked_state_hidden_states = torch.masked_select(base_hidden_states, msr_loss_mask.bool())

        base_target = msr_target_list[0]
        if isinstance(base_target, (list, tuple)) and len(base_target) != 2:
            raise ValueError(
                "msr_target_list[0] must be [image_target, noise_target]; "
                "was the base image encoded without noise?")
        target_image_flat, target_noise_flat = base_target

        pred_noise_flat = self.image_decoder(masked_state_hidden_states)
        msr_target = torch.masked_select(target_noise_flat, msr_loss_mask.unsqueeze(-1).bool())
        msr_target = msr_target.view(-1, target_noise_flat.size(-1))
        loss_diffusion = self.diffusion.p_losses(pred_noise_flat, msr_target, loss_type="huber")

        # Return a dict that contain the target image and the predicted image, only use the rgb channels
        out_channels = self.config.feature_list[0].channels
        # The flat target is channels * side * side, so recover the (square) side length
        image_size = int((target_noise_flat.size(-1) / out_channels) ** .5)
        pred_noise = pred_noise_flat.view(-1, out_channels, image_size, image_size)[:, :3, ...]
        target_image = target_image_flat.view(-1, out_channels, image_size, image_size)[:, :3, ...]
        target_noise = target_noise_flat.view(-1, out_channels, image_size, image_size)[:, :3, ...]

        compare_imgs = {"pred_noise": pred_noise,
                        "target_noise": target_noise,
                        "noisy_image": target_image}

        return loss_diffusion, compare_imgs

    # def train_gan(self, compare_imgs):
    #     # train discriminator
    #     # Measure discriminator's ability to classify real from generated samples
    #     pred_noise, target_noise = compare_imgs["pred_noise"], compare_imgs["target_noise"]

    #     # how well can it label as real?
    #     valid = torch.ones(target_noise.size(0), 1)
    #     valid = valid.type_as(target_noise)
    #     real_loss = F.binary_cross_entropy(self.discriminator(target_noise), valid)

    #     # how well can it label as fake?
    #     fake = torch.zeros(target_noise.size(0), 1)
    #     fake = fake.type_as(target_noise)
    #     fake_loss = F.binary_cross_entropy(self.discriminator(pred_noise.detach()), fake)

    #     # discriminator loss is the average of these
    #     d_loss = (real_loss + fake_loss) / 2

    #     return d_loss

    def calculate_auxiliary_loss(self, hidden_states, msr_target_list, segment_lengths):
        """
        Auxiliary loss: decode the last hidden state of each timestep into
        the base image of the NEXT timestep and compare with MSE.

        Args:
            hidden_states: transformer output; the first prompt_length
                positions along dim 1 are dropped.
            msr_target_list: per-feature targets; entry 0 is the base image
                target, either a tensor or a list whose first element is the
                image target. Its dim 1 is the timestep axis.
            segment_lengths: mapping with 'batch_size', 'timesteps',
                'prompt_length' and optionally 'batch_size_prime'.

        Returns:
            (loss_msr_auxiliary, compare_imgs) where compare_imgs holds the
            first three channels under the keys "pred_noise" (decoded image)
            and "target_noise" (target image).
        """
        B = segment_lengths['batch_size']
        S = segment_lengths['timesteps']
        prompt_length = segment_lengths['prompt_length']
        # forward() views the hidden states with batch_size_prime rows and keeps the
        # first B of them; do the same here, falling back to B when the key is absent
        B_prime = segment_lengths['batch_size_prime'] if 'batch_size_prime' in segment_lengths else B

        # NOTE: The current action hidden states should predict the next state, not the state from the same timestep
        # Therefore, the last action does not predict

        # Remove prompt hidden states
        sa_hidden_states = hidden_states[:, prompt_length:, :]
        # Use the last action hidden states to predict the next state
        last_a_hidden_states = sa_hidden_states.view(B_prime, S, -1, self.hidden_size)[:B, :-1, -1, :] # (B, S-1, hidden)

        # DeConv MSR head
        last_a_hidden_states = last_a_hidden_states.reshape(B*(S-1), self.hidden_size, 1, 1) # (B*(S-1), hidden, 1, 1)
        pred_img = self.image_decoder(last_a_hidden_states) # (B*(S-1), C, padded_image_size, padded_image_size)
        pred_img = pred_img.view(B*(S-1), -1)

        # Construct training targets and calculate loss
        base_target = msr_target_list[0]
        if isinstance(base_target, (list, tuple)):
            # List-form targets are [image] or [image, noise]; the image comes first
            base_target = base_target[0]
        # Drop the first timestep and flatten to the prediction's layout. Without the
        # reshape, mse_loss would broadcast every prediction against every target.
        target_image_flat = base_target[:, 1:, ...].reshape(B*(S-1), -1) # (B*(S-1), C * padded_image_size**2)
        loss_msr_auxiliary = F.mse_loss(pred_img, target_image_flat)

        # Return a dict that contain the target image and the predicted image, only use the rgb channels
        out_channels = self.config.feature_list[0].channels
        image_size = int((target_image_flat.size(-1) / out_channels) ** .5)
        pred_noise = pred_img.view(-1, out_channels, image_size, image_size)[:, :3, ...]
        target_noise = target_image_flat.reshape(-1, out_channels, image_size, image_size)[:, :3, ...]
        compare_imgs = {"pred_noise": pred_noise,
                        "target_noise": target_noise}

        return loss_msr_auxiliary, compare_imgs


    # def configure_optimizers(self):
    #     lr = self.hparams.lr
    #     b1 = self.hparams.b1
    #     b2 = self.hparams.b2

    #     opt_g = torch.optim.Adam(self.image_decoder.parameters(), lr=lr, betas=(b1, b2))
    #     opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
    #     return [opt_g, opt_d], []

    def log_compare_images(self, compare_imgs):
        """
        Write up to self.log_image_num images per key to the logger's
        experiment object via add_image.

        Only runs on global rank 0 and when global_step is a multiple of
        config.val_interval; does nothing when no logger is attached.

        Args:
            compare_imgs: dict mapping a name to a tensor whose first
                dimension indexes images. Values are treated as being in
                [0, 1]; anything outside is clamped before the uint8 cast.
        """
        # Compare the reconstructed images and the target images in tensorboard
        if self.global_rank != 0 or self.global_step % self.config.val_interval != 0:
            return
        if self.logger is None:
            return

        tb = self.logger.experiment
        for key, images in compare_imgs.items():
            for i in range(min(images.shape[0], self.log_image_num)):
                # Clamp first: noise tensors contain negative / >1 values, and casting
                # out-of-range floats to uint8 is not well defined
                img = (images[i].detach().clamp(0, 1) * 255).to(torch.uint8)
                tb.add_image(f"pretrain_msr_images/{i}_{key}",
                             img,
                             global_step = int(self.global_step))

    def write_image(self, batch, path = "test.png"):
        """
        Debug helper: save the first three channels of the first image of the
        first sequence in batch["base_camera_rgbd"] as an image file.

        Args:
            batch: dict whose "base_camera_rgbd" entry is indexable as
                [batch, step, channel, H, W].
            path: output file path (defaults to "test.png").
        """
        img = batch["base_camera_rgbd"][0, 0, :3, :, :].permute(1, 2, 0) # (H, W, 3)
        # Undo the input normalisation, scale to 0..255 and clamp so the uint8
        # cast cannot wrap around for out-of-range pixels
        img = ((img * 0.225 + 0.45) * 255).clamp(0, 255)
        im = Image.fromarray(img.detach().cpu().numpy().astype(np.uint8))
        im.save(path)