from PIL import Image
import numpy as np
import trimesh
import torch
from torchvision.utils import save_image
from pointcept.models.utils.structure import Point

import os
import matplotlib.pyplot as plt
import os.path as osp
from sklearn.decomposition import PCA
from sklearn.preprocessing import minmax_scale

import time
import logging


class Visualizer2D:
    """Save PCA colour maps, feature heatmaps and optional masks of 2D patch tokens.

    Tokens are laid out on a ``patch_h x patch_w`` grid with ``view_num`` views
    per sample.
    """

    def __init__(self, patch_h, patch_w, view_num):
        self.patch_h = patch_h
        self.patch_w = patch_w
        self.view_num = view_num

    def feature2D_vis(
        self,
        name,
        images_plot,
        x_norm_patchtokens,
        save_fg_mask=False,
        output_folder="outputs",
        mask_sample=None,
    ):
        """Visualise every sample of a batch into ``output_folder/name[i]``.

        Args:
            name: per-sample names, used as sub-folder and file-name prefix.
            images_plot: images, reshaped here to (B, view_num, 3, H, W).
            x_norm_patchtokens: patch tokens, reshaped here to
                (B, view_num, patch_h, patch_w, C).
            save_fg_mask: also save the thresholded-PCA mask overlay.
            output_folder: root directory for all outputs (sub-folders are
                created as needed).
            mask_sample: optional per-view patch indices, forwarded unchanged
                to every sample (see ``feature2D_vis_single``).
        """
        images_plot = images_plot.reshape(
            -1, self.view_num, 3, images_plot.shape[-2], images_plot.shape[-1]
        )
        x_norm_patchtokens = x_norm_patchtokens.reshape(
            -1, self.view_num, self.patch_h, self.patch_w, x_norm_patchtokens.shape[-1]
        )
        for i in range(images_plot.shape[0]):
            sample_folder = osp.join(output_folder, name[i])
            os.makedirs(sample_folder, exist_ok=True)
            self.feature2D_vis_single(
                name[i],
                images_plot[i],
                x_norm_patchtokens[i],
                save_fg_mask=save_fg_mask,
                output_folder=sample_folder,
                mask_sample=mask_sample,
            )

    def feature2D_vis_single(
        self,
        name,
        images_plot,
        x_norm_patchtokens,
        save_fg_mask=False,
        output_folder="outputs",
        mask_sample=None,
    ):
        """Save visualisations of one sample's views into ``output_folder``.

        Writes ``{name}_results.jpg`` (image | PCA colours) and
        ``{name}_heatmap.jpg`` (image | summed features), plus optionally
        ``{name}_mask.jpg`` and ``{name}_mask_sample.jpg``. Only the first four
        views are drawn. This method does not create ``output_folder``.

        Args:
            name: file-name prefix.
            images_plot: (V, 3, H, W) image tensor; values assumed to be in
                [0, 1] and clipped to that range for display.
            x_norm_patchtokens: tokens reshapeable to
                (V * patch_h * patch_w, C).
            save_fg_mask: also save an overlay of the patches whose first
                min-max-scaled PCA component is below 0.4.
            output_folder: destination directory.
            mask_sample: optional sequence indexed by view; entry ``i`` holds
                (row, col) patch coordinates in its first two columns, which
                are floored to ints. It is only read, never modified.
        """
        images = self._to_uint8_images(images_plot)
        tokens = x_norm_patchtokens.detach().cpu().to(torch.float32).numpy()
        # At most four views are drawn; the 2x2 mask layout relies on this cap.
        tokens = tokens[:4]
        images = images[:4]
        img_cnt = tokens.shape[0]
        patches = tokens.reshape(img_cnt * self.patch_h * self.patch_w, -1)

        # Per-patch sum of feature values, shown as a heatmap.
        heatmap = patches.sum(axis=-1).reshape(img_cnt, self.patch_h, self.patch_w)

        # Fit one PCA over the patches of all views jointly, then rescale each
        # component to [0, 1] so the three components can be shown as RGB.
        pca_colors = minmax_scale(PCA(n_components=3).fit_transform(patches))
        pca_images = pca_colors.reshape(img_cnt, self.patch_h, self.patch_w, 3)

        if save_fg_mask:
            # Fixed heuristic threshold on the first scaled PCA component.
            masks = (pca_colors[:, 0] < 0.4).reshape(
                img_cnt, self.patch_h, self.patch_w
            )
            self._save_mask_overlay(
                images, masks, osp.join(output_folder, f"{name}_mask.jpg")
            )

        self._save_pairs(
            images, pca_images, osp.join(output_folder, f"{name}_results.jpg")
        )
        self._save_pairs(
            images, heatmap, osp.join(output_folder, f"{name}_heatmap.jpg")
        )

        if mask_sample is not None:
            sampled_panels = []
            for i in range(img_cnt):
                # Convert into a local array: mask_sample is shared by every
                # sample of a batch, so it must not be overwritten here.
                idx = np.floor(
                    torch.as_tensor(mask_sample[i]).detach().cpu().numpy()
                ).astype(np.int32)
                panel = np.zeros((self.patch_h, self.patch_w, 3))
                panel[idx[:, 0], idx[:, 1]] = pca_images[i][idx[:, 0], idx[:, 1]]
                sampled_panels.append(panel)
            self._save_pairs(
                images,
                sampled_panels,
                osp.join(output_folder, f"{name}_mask_sample.jpg"),
            )

    @staticmethod
    def _to_uint8_images(images):
        """Convert a (V, 3, H, W) tensor in [0, 1] to a (V, H, W, 3) uint8 array."""
        scaled = images.detach().cpu().numpy() * 255
        # Clip before the cast so out-of-range values saturate instead of wrapping.
        return np.clip(scaled, 0, 255).transpose(0, 2, 3, 1).astype(np.uint8)

    @staticmethod
    def _save_mask_overlay(images, masks, path):
        """Draw each mask semi-transparently over its image in a 2x2 grid."""
        # Start a fresh figure so nothing is drawn into a figure left open elsewhere.
        plt.figure()
        for i in range(len(masks)):
            image = images[i]
            plt.subplot(221 + i)
            plt.imshow(image)
            # Stretch the patch-resolution mask over the full image extent.
            plt.imshow(
                masks[i],
                extent=(0, image.shape[1], image.shape[0], 0),
                alpha=0.5,
            )
        plt.savefig(path)
        plt.close()

    @staticmethod
    def _save_pairs(images, panels, path):
        """Save one row per view: ``images[i]`` on the left, ``panels[i]`` on the right."""
        count = len(panels)
        plt.figure(figsize=(8, 8))
        for i in range(count):
            plt.subplot(count, 2, 2 * i + 1)
            plt.axis("off")
            plt.imshow(images[i])

            plt.subplot(count, 2, 2 * i + 2)
            plt.axis("off")
            plt.imshow(panels[i])
        plt.savefig(path)
        plt.close()


class Visualizer3D:
    """Export per-point features of a packed point-cloud batch as coloured PLY files."""

    def __init__(self):
        pass

    def feature3D_vis(
        self,
        offset,
        name,
        coord,
        color,
        origin_coord,
        origin_color,
        feat,
        save_fg_mask=False,
        output_folder="outputs",
    ):
        """Split a packed batch by ``offset`` and export each cloud to ``output_folder/name[i]``.

        ``offset[i]`` is used as the exclusive end row of sample ``i``; its
        start is the previous entry (or 0). All per-point tensors, including
        ``origin_coord``/``origin_color``, are sliced with the same offsets.

        Args:
            offset: cumulative end offsets, one per sample.
            name: per-sample names, used as sub-folder and file-name prefix.
            coord, color: per-point coordinates and colours.
            origin_coord, origin_color: per-point original coordinates/colours.
            feat: per-point feature vectors.
            save_fg_mask: also export ``{name}_mask.ply``.
            output_folder: root directory for all outputs.
        """
        coords = coord.cpu().detach().numpy()
        colors = color.cpu().detach().numpy()
        origin_coords = origin_coord.cpu().detach().numpy()
        origin_colors = origin_color.cpu().detach().numpy()
        feats = feat.cpu().detach().numpy()

        start = 0
        for i, end in enumerate(offset):
            # Convert once so every slice below uses a plain Python int.
            end = int(end)
            sample_folder = osp.join(output_folder, name[i])
            os.makedirs(sample_folder, exist_ok=True)
            self.feature3D_vis_single(
                name[i],
                coords[start:end],
                colors[start:end],
                origin_coords[start:end],
                origin_colors[start:end],
                feats[start:end],
                save_fg_mask=save_fg_mask,
                output_folder=sample_folder,
            )
            start = end

    def feature3D_vis_single(
        self,
        name,
        coords,
        colors,
        origin_coords,
        origin_colors,
        feats,
        save_fg_mask=False,
        output_folder="outputs",
    ):
        """Export one cloud's PLY files into ``output_folder``.

        Writes ``{name}_results.ply`` (points coloured by PCA of ``feats``) and
        ``{name}_origin.ply`` (original points and colours). When
        ``save_fg_mask`` is set it also writes ``{name}_mask.ply``.
        """
        # Project the features onto 3 PCA components, each rescaled to [0, 1],
        # and use them as RGB colours.
        pca_colors = minmax_scale(PCA(n_components=3).fit_transform(feats))

        if save_fg_mask:
            # NOTE(review): despite the file name, this exports the full cloud
            # with its input colours; no foreground mask is applied.
            trimesh.PointCloud(coords, colors=colors).export(
                osp.join(output_folder, f"{name}_mask.ply")
            )

        trimesh.PointCloud(coords, pca_colors).export(
            osp.join(output_folder, f"{name}_results.ply")
        )
        trimesh.PointCloud(origin_coords, origin_colors).export(
            osp.join(output_folder, f"{name}_origin.ply")
        )


class IMGVisualizer:
    """Dump image crops and their patch masks as JPEG files for inspection."""

    def __init__(self):
        pass

    @staticmethod
    def _to_uint8(array):
        """Scale a [0, 1] float array to uint8, saturating out-of-range values."""
        # Clip before the cast so values outside [0, 1] saturate rather than
        # wrap around; torchvision's save_image likewise clamps its input.
        return np.clip(array * 255, 0, 255).astype(np.uint8)

    def img_vis(self, global_crops, local_crops, masks, name, output_folder="outputs"):
        """Save crops and mask overlays under ``output_folder/name``.

        Args:
            global_crops: (G, C, H, W) tensor of global crops; values assumed to
                be in [0, 1].
            local_crops: (L, C, h, w) tensor of local crops.
            masks: (M, P) per-patch mask with P = patch_num ** 2 (square patch
                grid assumed). Mask ``i`` is drawn over global crop ``i % G``;
                patches with value 1 appear dark in the overlay.
            name: sub-folder and file-name prefix.
            output_folder: root directory for the outputs.
        """
        output_folder = osp.join(output_folder, name)
        os.makedirs(output_folder, exist_ok=True)
        global_crops = global_crops.cpu().detach()
        local_crops = local_crops.cpu().detach()
        masks = masks.cpu().detach()
        n_global_crops = global_crops.shape[0]
        n_masks = masks.shape[0]

        for i, global_crop in enumerate(global_crops):
            save_image(
                global_crop,
                osp.join(output_folder, f"{name}_global_{i}.jpg"),
                nrow=1,
                padding=0,
            )
        for i, local_crop in enumerate(local_crops):
            save_image(
                local_crop,
                osp.join(output_folder, f"{name}_local_{i}.jpg"),
                nrow=1,
                padding=0,
            )

        if n_masks == 0:
            return

        patch_num = int(np.sqrt(masks.shape[-1]))
        patch_size = global_crops.shape[-1] // patch_num
        # Expand every patch flag into a patch_size x patch_size pixel block.
        # Casting to float makes ``1 - mask`` below valid for bool masks too.
        masks = masks.float().unsqueeze(-1).repeat(1, 1, patch_size * patch_size)
        masks = masks.reshape(n_masks, patch_num, patch_num, patch_size, patch_size)
        masks = masks.permute(0, 1, 3, 2, 4).reshape(
            n_masks, patch_num * patch_size, patch_num * patch_size
        )

        for i in range(n_masks):
            crop = global_crops[i % n_global_crops].numpy()
            crop = Image.fromarray(self._to_uint8(np.transpose(crop, (1, 2, 0))))

            mask = Image.fromarray(self._to_uint8(1 - masks[i].numpy()))
            if mask.size != crop.size:
                # The patch grid need not tile the crop exactly; Image.blend
                # requires equal sizes, so stretch the mask onto the crop.
                mask = mask.resize(crop.size, Image.NEAREST)

            blended = Image.blend(crop, mask.convert("RGB"), alpha=0.5)
            blended.save(osp.join(output_folder, f"{name}_mask_{i}.jpg"))
