import torch.nn as nn
from evaluate.mae_utils import *
from evaluate.segmentation_utils import *


class ResizeTransform:
    """Callable transform that bilinearly resizes a batched image tensor.

    Args:
        size: target spatial size, forwarded unchanged to ``F.interpolate``.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        """Resize ``img`` to ``self.size`` (bilinear, ``align_corners=False``)."""
        target_size = self.size
        resized = F.interpolate(img, size=target_size, mode='bilinear', align_corners=False)
        return resized


def _generate_result_for_canvas(args, model, canvas_pred, canvas_label, arr):
    """Run ``generate_image`` on every canvas in a batch and de-normalize its label.

    The canvases are expected to already be in the model's input range.

    Args:
        args: namespace providing ``device``.
        model: model passed through to ``generate_image``.
        canvas_pred: batch of input canvases, indexed along dim 0.
        canvas_label: batch of CHW label canvases, indexed along dim 0.
        arr: arrangement key forwarded to ``generate_arr_mask_for_evaluation``.

    Returns:
        ``(original_image_list, generated_result_list)``: two lists of
        ``np.uint8`` images, one entry per batch element.
    """
    ids_shuffle, len_keep = generate_arr_mask_for_evaluation(arr)
    originals = []
    generated = []

    for idx in range(canvas_pred.shape[0]):
        single_canvas = canvas_pred[idx].unsqueeze(0).to(args.device)
        _, pasted, _ = generate_image(single_canvas, model, ids_shuffle.to(args.device),
                                      len_keep, device=args.device)

        # CHW -> HWC, undo the imagenet normalization and scale to [0, 255].
        hwc_label = torch.einsum('chw->hwc', canvas_label[idx]).cpu().detach()
        label_pixels = torch.clip((hwc_label * imagenet_std + imagenet_mean) * 255, 0, 255).int().numpy()
        assert label_pixels.shape == pasted.shape, (label_pixels.shape, pasted.shape)

        originals.append(np.uint8(label_pixels))
        generated.append(np.uint8(pasted))

    return originals, generated


def _generate_result_for_large_canvas(args, model, canvas_pred, canvas_label, arr):
    """Run ``generate_image`` on every large canvas in a batch and de-normalize its label.

    Identical to the regular-canvas variant except that the masking pattern
    comes from ``generate_large_mask_for_evaluation``. Inputs are expected to
    already be in the model's input range.

    Returns:
        ``(original_image_list, generated_result_list)``: two lists of
        ``np.uint8`` images, one entry per batch element.
    """
    ids_shuffle, len_keep = generate_large_mask_for_evaluation(arr)

    def label_to_pixels(label_chw):
        # CHW -> HWC, undo the imagenet normalization, scale and clip to [0, 255].
        hwc = torch.einsum('chw->hwc', label_chw).cpu().detach()
        return torch.clip((hwc * imagenet_std + imagenet_mean) * 255, 0, 255).int().numpy()

    original_image_list, generated_result_list = [], []
    for idx in range(canvas_pred.shape[0]):
        batch_of_one = canvas_pred[idx].unsqueeze(0).to(args.device)
        _, pasted, _ = generate_image(batch_of_one, model, ids_shuffle.to(args.device),
                                      len_keep, device=args.device)
        label_pixels = label_to_pixels(canvas_label[idx])
        assert label_pixels.shape == pasted.shape, (label_pixels.shape, pasted.shape)

        original_image_list.append(np.uint8(label_pixels))
        generated_result_list.append(np.uint8(pasted))

    return original_image_list, generated_result_list


def _generate_result_by_knn_indice_weight_prob(args, model, canvas_pred, canvas_label, arr, datastore, key_list, k):
    """Run the kNN-weighted generator on every canvas in a batch and de-normalize its label.

    Inputs are expected to already be in the model's input range.

    Args:
        args: namespace providing ``device``; also forwarded to the generator.
        model: model passed through to the generator.
        canvas_pred: batch of input canvases, indexed along dim 0.
        canvas_label: batch of CHW label canvases, indexed along dim 0.
        arr: arrangement key forwarded to ``generate_arr_mask_for_evaluation``.
        datastore, k: forwarded unchanged to
            ``generate_image_by_knn_indice_weight_prob``.
        key_list: per-sample keys; entry ``i`` is used for batch element ``i``.

    Returns:
        ``(original_image_list, generated_result_list)``: two lists of
        ``np.uint8`` images, one entry per batch element.
    """
    ids_shuffle, len_keep = generate_arr_mask_for_evaluation(arr)
    device = args.device
    originals = []
    generated = []

    for idx in range(canvas_pred.shape[0]):
        single_canvas = canvas_pred[idx].unsqueeze(0).to(device)
        _, pasted, _ = generate_image_by_knn_indice_weight_prob(single_canvas, model,
                                                                ids_shuffle.to(device),
                                                                datastore, key_list[idx], k, len_keep, args,
                                                                device=device)
        # CHW -> HWC, undo the imagenet normalization and scale to [0, 255].
        hwc_label = torch.einsum('chw->hwc', canvas_label[idx]).cpu().detach()
        label_pixels = torch.clip((hwc_label * imagenet_std + imagenet_mean) * 255, 0, 255).int().numpy()
        assert label_pixels.shape == pasted.shape, (label_pixels.shape, pasted.shape)

        originals.append(np.uint8(label_pixels))
        generated.append(np.uint8(pasted))

    return originals, generated


def round_image(img, options=(WHITE, BLACK, RED, GREEN, BLUE), outputs=None, t=(0, 0, 0)):
    """Snap every pixel of an image to its nearest color in ``options``.

    Args:
        img: array or tensor whose last axis holds 3 color channels (the
            ``reshape(..., 3)`` of ``options`` and ``mean(dim=2)`` assume an
            ``(H, W, 3)`` layout — TODO confirm for other ranks).
        options: sequence of RGB triples to compare each pixel against.
        outputs: optional sequence parallel to ``options`` giving the value to
            emit for each matched option; defaults to ``options`` itself.
        t: per-channel offset added to ``img`` before measuring distances.

    Returns:
        Tensor ``outputs[nearest_option_index]`` for every pixel.
    """
    img = torch.as_tensor(img)
    # Measure distances in floating point. Doing the subtraction/squaring in
    # the image's own dtype overflows for uint8 input (e.g. 200 - 255 wraps to
    # 201), which picks the wrong nearest color. Float64 input stays float64.
    work_dtype = torch.promote_types(img.dtype, torch.float32)
    pixels = img.to(work_dtype)
    offset = torch.as_tensor(t).to(device=img.device, dtype=work_dtype)
    options = torch.as_tensor(options)
    # (N, 3) -> (1, 1, 3, N) so every pixel is compared against every option.
    opts = options.reshape(len(options), 1, 1, 3).permute(1, 2, 3, 0).to(device=img.device, dtype=work_dtype)
    dist = (((pixels + offset).unsqueeze(-1) - opts) ** 2).float().mean(dim=2)
    nn_indices = torch.argmin(dist, dim=-1)
    if outputs is None:
        outputs = options
    return torch.as_tensor(outputs)[nn_indices]


class MetaTrn(nn.Module):
    """Canvas builder and prediction wrapper, edited for visual prompting.

    Lays out support/query images and masks on 2x2 canvases, runs the
    ``generate_raw_prediction*`` helpers with the wrapped ``vqgan`` model, and
    forwards the results to the ``vqgan.cal_key_with_*`` methods.
    """

    def __init__(self, args, vqgan, arr):
        super().__init__()
        self.args = args
        # Device passed to the prediction helpers and used for ``canvas_label``.
        self.device = args.device
        # Added twice to each canvas dimension in ``create_grid_from_images``.
        self.padding = 1
        self.vqgan = vqgan
        # Arrangement key (e.g. 'a1'..'a8'), see ``create_gradiant_grid_images``.
        self.arr = arr

    def create_grid_from_images(self, prompt_img, support_mask, query_img, query_mask):
        """Build one normalized 2x2 canvas from four tiles.

        The tiles go into the corners of a ones-filled canvas:
        top-left ``prompt_img``, bottom-left ``query_img``, top-right
        ``support_mask``, bottom-right ``query_mask``. The canvas is then
        normalized with ``imagenet_mean`` / ``imagenet_std``.

        Indexing ``shape[1..3]`` suggests NCHW inputs with a unit batch
        dimension — TODO confirm.

        Returns:
            ``torch.Tensor`` created from the normalized NumPy canvas.
        """
        channels = prompt_img.shape[1]
        prompt_h, prompt_w = prompt_img.shape[2], prompt_img.shape[3]
        query_h, query_w = query_img.shape[2], query_img.shape[3]
        gap = 2 * self.padding

        canvas = torch.ones((channels, 2 * prompt_h + gap, 2 * prompt_w + gap))
        canvas[:, :prompt_h, :prompt_w] = prompt_img
        canvas[:, -query_h:, :query_w] = query_img
        canvas[:, :prompt_h, -prompt_w:] = support_mask
        canvas[:, -query_h:, -prompt_w:] = query_mask

        as_array = canvas.detach().numpy()
        normalized = (as_array - imagenet_mean[:, None, None]) / imagenet_std[:, None, None]
        return torch.from_numpy(normalized)

    def create_gradiant_grid_images(self, support_img, support_mask, query_img, query_mask, grid, arr):
        """Paste a permuted set of tiles into ``grid`` in place and return it.

        ``arr`` decides which of the four inputs fills each slot:
        support-image (top-left), support-mask (top-right) and query-image
        (bottom-left). The bottom-right slot is always filled with 1.
        Unrecognized ``arr`` values keep the inputs in their given order,
        which is the same as 'a1'. Tiles are a hard-coded 111 pixels on each
        side, written at the corners of the last two axes of ``grid``.
        """
        tiles = (support_img, support_mask, query_img, query_mask)
        # For each key: indices into ``tiles`` for the
        # (support_img, support_mask, query_img, query_mask) slots.
        arrangements = (
            ('a1', (0, 1, 2, 3)),
            ('a2', (1, 0, 3, 2)),
            ('a3', (3, 2, 1, 0)),
            ('a4', (2, 3, 0, 1)),
            ('a5', (1, 3, 0, 2)),
            ('a6', (3, 1, 2, 0)),
            ('a7', (2, 0, 3, 1)),
            ('a8', (0, 2, 1, 3)),
        )
        order = next((perm for key, perm in arrangements if arr == key), (0, 1, 2, 3))
        top_left, top_right, bottom_left = (tiles[i] for i in order[:3])

        tile = 111  # hard-coded tile edge length in pixels
        grid[:, :, :tile, :tile] = top_left
        grid[:, :, -tile:, :tile] = bottom_left
        grid[:, :, :tile, -tile:] = top_right
        grid[:, :, -tile:, -tile:] = 1
        return grid

    def create_gradiant_grid_label_images(self, support_img, support_mask, query_img, query_mask, grid):
        """Paste the four tiles, unpermuted, into the corners of ``grid`` in place.

        The layout is top-left support image, bottom-left query image,
        top-right support mask and bottom-right query mask. The tile extents
        come from the image shapes on axes 2 and 3.

        Returns:
            The same ``grid`` object.
        """
        sup_h, sup_w = support_img.shape[2], support_img.shape[3]
        qry_h, qry_w = query_img.shape[2], query_img.shape[3]

        grid[:, :, :sup_h, :sup_w] = support_img
        grid[:, :, -qry_h:, :qry_w] = query_img
        grid[:, :, :sup_h, -sup_w:] = support_mask
        grid[:, :, -qry_h:, -sup_w:] = query_mask
        return grid

    def _generate_raw_prediction(self, canvas, arr):
        """Run ``generate_raw_prediction`` on ``canvas`` using the mask for ``arr``.

        ``canvas`` is expected to already be in the model's input range.

        Returns:
            ``(y_pred, mask)`` taken from ``generate_raw_prediction``'s output.
        """
        ids_shuffle, len_keep = generate_arr_mask_for_evaluation(arr)
        pred_mask, _, prediction = generate_raw_prediction(self.device, ids_shuffle, len_keep, self.vqgan, canvas)
        return prediction, pred_mask

    def _generate_raw_prediction_features_before_fc(self, canvas, arr):
        """Like ``_generate_raw_prediction`` but also return the pre-FC features.

        ``canvas`` is expected to already be in the model's input range.

        Returns:
            ``(y_pred, mask, fc_feature)`` from
            ``generate_raw_prediction_features_before_fc``.
        """
        ids_shuffle, len_keep = generate_arr_mask_for_evaluation(arr)
        pred_mask, _, prediction, features = generate_raw_prediction_features_before_fc(
            self.device, ids_shuffle, len_keep, self.vqgan, canvas)
        return prediction, pred_mask, features

    @staticmethod
    def _patch_mask(rows, cols):
        """Return a (1, 196) float tensor: a 14x14 grid with ones at ``[rows, cols]``, flattened row-major."""
        cells = torch.zeros([14, 14])
        cells[rows, cols] = 1
        return cells.unsqueeze(0).flatten(1)

    @staticmethod
    def generate_left_mask():
        """Return the flattened 14x14 mask that selects the bottom-left quadrant."""
        return MetaTrn._patch_mask(slice(7, None), slice(None, 7))

    @staticmethod
    def generate_right_mask():
        """Return the flattened 14x14 mask that selects the bottom-right quadrant."""
        return MetaTrn._patch_mask(slice(7, None), slice(7, None))

    @staticmethod
    def generate_left_right_mask():
        """Return the flattened 14x14 mask that selects the whole bottom half."""
        return MetaTrn._patch_mask(slice(7, None), slice(None))

    @staticmethod
    def generate_upper_left_mask():
        """Return the flattened 14x14 mask that selects the top-left quadrant."""
        return MetaTrn._patch_mask(slice(None, 7), slice(None, 7))

    @staticmethod
    def generate_upper_right_mask():
        """Return the flattened 14x14 mask that selects the top-right quadrant."""
        return MetaTrn._patch_mask(slice(None, 7), slice(7, None))

    def forward_meta_test_indice(self, canvas_pred, support_name):
        """Predict on ``canvas_pred`` and delegate to ``vqgan.cal_key_with_indice_label``.

        Returns:
            The ``(data_list, support_name)`` pair produced by the vqgan helper.
        """
        prediction, pred_mask = self._generate_raw_prediction(canvas_pred, self.arr)
        keys, names = self.vqgan.cal_key_with_indice_label(canvas_pred, prediction, pred_mask, support_name)
        return keys, names

    def form_cavas(self, support_img, support_mask, query_img, query_mask, grid):
        """Build the label and prediction canvases from ``grid``.

        The label is a snapshot of ``grid`` taken *before*
        ``create_gradiant_grid_images`` overwrites ``grid`` in place. It is
        converted to float and moved to ``self.device``.

        Returns:
            ``(canvas_label, canvas_pred)``, where ``canvas_pred`` is ``grid``
            itself after mutation.
        """
        snapshot = grid.clone()  # must happen before grid is mutated below
        canvas_pred = self.create_gradiant_grid_images(support_img, support_mask, query_img, query_mask, grid, self.arr)
        return snapshot.float().to(self.device), canvas_pred

    def forward_meta_test_indice_feature_bf_fc(self, canvas_pred, support_name):
        """Predict with pre-FC features and delegate to ``vqgan.cal_key_with_indice_fc_features``.

        Returns:
            The ``(data_list, support_name)`` pair produced by the vqgan helper.
        """
        prediction, pred_mask, features = self._generate_raw_prediction_features_before_fc(canvas_pred, self.arr)
        keys, names = self.vqgan.cal_key_with_indice_fc_features(canvas_pred, prediction, pred_mask, features,
                                                                  support_name)
        return keys, names

    def forward_meta_test_indice_pixel_patch(self, canvas_pred, support_name, model):
        """Decode the prediction to pixels, cut it into 16x16 patches and delegate to the vqgan helper.

        The argmax of the raw prediction is decoded with
        ``decode_raw_predicion``. The first decoded image (noted as
        ``[224, 224, 3]`` in the original code) is rearranged into
        ``(1, num_patches, 16*16*C)`` pixel patches and passed to
        ``vqgan.cal_key_with_indice_pixel_patch``.

        Returns:
            The ``(data_list, support_name)`` pair produced by the vqgan helper.
        """
        prediction, pred_mask = self._generate_raw_prediction(canvas_pred, self.arr)
        token_ids = prediction.argmax(dim=-1)
        decoded, _, _ = decode_raw_predicion(pred_mask, model, 14, canvas_pred, token_ids)

        first_image = decoded[0]
        pixel_patches = rearrange(first_image, '(h p1) (w p2) c -> (h w) (p1 p2 c)', p1=16, p2=16).unsqueeze(0)

        keys, names = self.vqgan.cal_key_with_indice_pixel_patch(canvas_pred, prediction, pred_mask, support_name,
                                                                  pixel_patches)
        return keys, names
