import math
from pyexpat import features
from re import split
import numpy as np
import einops
from matplotlib.gridspec import GridSpec
from matplotlib.colors import Normalize
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import warnings
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from sympy import frac
# umap throws numba warnings - ignore them
warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
import umap as mp

class VisualizationCarrier():
    """Collect masks, query/key features and an image, and render them as figures.

    The carrier is configured from a ``params`` object that must expose
    ``vis_carrier_turned_on``, ``use_umap``, ``use_vit`` and (when ``use_vit``
    is truthy) ``vit_patch_size``. When the carrier is turned off, every
    ``add_*`` call is a no-op and :meth:`create_plots` returns an empty dict.
    """

    def __init__(self, params) -> None:
        self.params = params
        self.turned_on = params.vis_carrier_turned_on
        self.use_umap = params.use_umap
        if self.params.use_vit:
            self.patch_size = params.vit_patch_size
        else:
            # Without ViT every pixel is treated as its own "patch".
            self.patch_size = (1, 1)
        self.qk_masks = {}
        self.queries_keys = {}
        # Initialised here so create_plots() never hits a missing attribute.
        self.image = None

    def reset(self):
        """Drop everything collected so far (no-op when turned off)."""
        if self.turned_on:
            self.qk_masks = {}
            # Also drop queries/keys: they belong to the image being discarded.
            self.queries_keys = {}
            self.image = None

    def add_qk_masks(self, name, mask):
        """Store a 2-D ``mask`` under ``name`` (no-op when turned off)."""
        if self.turned_on:
            self.qk_masks[name] = mask

    def add_queries_keys(self, name, queries, keys):
        """Store ``(queries, keys)`` under ``name``; only kept when UMAP is enabled."""
        if self.turned_on and self.use_umap:
            self.queries_keys[name] = (queries, keys)

    def add_image(self, image):
        """Store the image the masks refer to (no-op when turned off)."""
        if self.turned_on:
            self.image = image

    def create_plots(self):
        """Render all collected items.

        Returns:
            Dict mapping the names passed to ``add_*`` to matplotlib figures.
            Empty when the carrier is turned off or nothing was collected.

        Raises:
            ValueError: if masks or queries/keys were added but no image was.
        """
        if not self.turned_on:
            return {}

        plot_dict = self._create_mask_plots()
        if self.use_umap:
            # NOTE(review): a UMAP figure replaces a mask figure stored under
            # the same name - confirm that callers use distinct names.
            plot_dict.update(self._create_umap_plots())

        return plot_dict

    def _require_image(self):
        """Return the stored image or raise a descriptive error if there is none."""
        if self.image is None:
            raise ValueError('No image available for plotting; call add_image() first.')
        return self.image

    @staticmethod
    def _mask_to_grid(mask, image_shape, patch_size):
        """Reshape a ``(clusters, patches)`` mask to ``(clusters, rows, cols)``."""
        height, width = image_shape[0], image_shape[1]
        return einops.rearrange(mask, 'c (n_h n_w) -> c n_h n_w',
                                n_h=height // patch_size[0],
                                n_w=width // patch_size[1])

    def _create_concat_plots(self, image, mask, patch_size, pad_size=2):
        """Plot the image followed by one mask overlay per cluster in a single row.

        Args:
            image: array of shape ``(height, width, channels)``.
            mask: array of shape ``(n_clusters, n_patches)``.
            patch_size: ``(patch_height, patch_width)`` in pixels.
            pad_size: currently unused; kept for interface compatibility.

        Returns:
            The matplotlib figure.
        """
        n_clusters, n_patches = mask.shape

        assert n_clusters <= 20, f'Number of clusters ({n_clusters}) is too large.'

        mask_img = self._mask_to_grid(mask, image.shape, patch_size)

        # Masks within [.., 1] share a fixed 0-1 colour scale; others use their own range.
        cmap_norm = Normalize(vmin=0, vmax=1)
        if mask_img.max() > 1:
            cmap_norm = Normalize(vmin=mask_img.min(), vmax=mask_img.max())

        # One panel for the image, one per cluster, and a thin one for the colorbar.
        num_figures = n_clusters + 2

        fig = plt.figure(figsize=(num_figures * 4, 4.8))
        width_ratios = [1.0] * (num_figures - 1) + [0.05]
        gs = GridSpec(1, num_figures, width_ratios=width_ratios)
        axes = [fig.add_subplot(gs[i]) for i in range(num_figures)]

        heatmaps = []
        for i, ax in enumerate(axes[:-1]):
            if i == 0:
                ax.imshow(image)
            else:
                cluster_mask = mask_img[i - 1]
                # Image and mask share one extent (in patch units) so they
                # stay aligned for any patch size.
                extent = [0, cluster_mask.shape[1], 0, cluster_mask.shape[0]]
                ax.imshow(image, extent=extent, alpha=0.5)
                heatmaps.append(ax.matshow(cluster_mask, cmap='viridis', norm=cmap_norm,
                                           extent=extent, alpha=0.5))
            ax.set_xticks([])
            ax.set_yticks([])

        gs.update(wspace=0.005)

        if heatmaps:
            cbar = fig.colorbar(heatmaps[-1], ax=axes[1:-1], cax=axes[-1], shrink=0.6)
            cbar.ax.yaxis.set_ticks_position('right')
        else:
            # No cluster panels, hence nothing for the colorbar to describe.
            axes[-1].axis('off')

        return fig

    def _create_umap_plots(self):
        """Create one UMAP figure per stored ``(queries, keys)`` pair."""
        plot_dict = {}
        if not self.queries_keys:
            return plot_dict

        image = self._require_image()
        for name, (queries, keys) in self.queries_keys.items():
            plot_dict[name] = self.plot_umap(image, queries, keys)

        return plot_dict

    def _create_mask_plots(self):
        """Create one figure per stored mask, picking the layout from the mask shape."""
        plot_dict = {}
        if not self.qk_masks:
            return plot_dict

        image = self._require_image()
        split_image = None  # patches stacked vertically; built lazily for the ViT layout

        for name, mask in self.qk_masks.items():
            # unify orientation; width should be larger than height
            m_height, m_width = mask.shape
            if m_width < m_height:
                mask = einops.rearrange(mask, 'w h -> h w')
                m_height, m_width = mask.shape

            if m_width / m_height >= 20:
                # Very wide masks get one overlay per row.
                # NOTE(review): this path asserts at most 20 rows - confirm
                # that wider-than-20x masks never have more rows than that.
                fig = self._create_concat_plots(image, mask, patch_size=self.patch_size)
            elif not self.params.use_vit:
                fig = self.plot_cnn_masks(image, mask)
            else:
                if split_image is None:
                    split_image = einops.rearrange(
                        image, '(h p1) (w p2) c -> (h w p1) p2 c',
                        p1=self.patch_size[0], p2=self.patch_size[1])
                fig = self.plot_vit_masks(image, mask, split_image, patch_size=self.patch_size)
            plot_dict[name] = fig

        return plot_dict

    def plot_clusters_in_image(self, image, mask, patch_size=(7,7)):
        """Plot the image and one mask overlay per cluster on a near-square grid.

        Args:
            image: array of shape ``(height, width, channels)``.
            mask: array of shape ``(n_clusters, n_patches)``.
            patch_size: ``(patch_height, patch_width)`` in pixels.

        Returns:
            The matplotlib figure.
        """
        n_clusters, n_patches = mask.shape

        assert n_clusters <= 20, f'Number of clusters ({n_clusters}) is too large.'

        mask_img = self._mask_to_grid(mask, image.shape, patch_size)

        n_panels = n_clusters + 1  # the image itself plus one panel per cluster
        n_cols = int(math.floor(math.sqrt(n_panels)))
        n_rows = int(math.ceil(n_panels / n_cols))

        # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
        fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False)
        fig.tight_layout()

        for i, ax in enumerate(axes.flat):
            if i == 0:
                ax.imshow(image)
            elif i <= n_clusters:
                cluster_mask = mask_img[i - 1]
                # Shared extent keeps image and mask aligned for any patch size.
                extent = [0, cluster_mask.shape[1], 0, cluster_mask.shape[0]]
                ax.imshow(image, extent=extent, alpha=0.5)
                heatmap = ax.matshow(cluster_mask, cmap='viridis', extent=extent, alpha=0.5)
                fig.colorbar(heatmap, ax=ax)
            else:
                # Grid cells beyond the last cluster stay empty.
                ax.axis('off')

            # Hide X and Y axes label marks
            ax.xaxis.set_tick_params(labelbottom=False)
            ax.yaxis.set_tick_params(labelleft=False)

            # Hide X and Y axes tick marks
            ax.set_xticks([])
            ax.set_yticks([])

        return fig

    def plot_vit_masks(self, image, mask, split_image, patch_size=(7,7)):
        """Plot a patch-to-patch (or cluster-to-patch) matrix next to the patches.

        Layout (2x2): the image with a patch grid (top-left), all patches
        concatenated horizontally (top-right), all patches concatenated
        vertically (bottom-left, only when the mask has one row per patch)
        and the mask itself (bottom-right).

        Args:
            image: array of shape ``(height, width, channels)``.
            mask: 2-D array; columns are labelled as patches in row-major order.
            split_image: the image rearranged to
                ``(n_patches * patch_height, patch_width, channels)``.
            patch_size: ``(patch_height, patch_width)`` in pixels.

        Returns:
            The matplotlib figure.
        """
        patch_h, patch_w = patch_size[0], patch_size[1]
        n_patch_rows = image.shape[0] // patch_h
        n_patch_cols = image.shape[1] // patch_w
        # Explicit product: `a // b * c // d` would evaluate left to right.
        n_patches = n_patch_rows * n_patch_cols
        mask_height, mask_width = mask.shape
        rows_are_patches = n_patches == mask_height

        fig = plt.figure(figsize=(10, 10) if rows_are_patches else (10, 5))
        gs = GridSpec(2, 2, figure=fig, hspace=0.1, wspace=0.1,
                      height_ratios=[n_patch_rows, mask_height],
                      width_ratios=[n_patch_cols, mask_width])

        ax1 = fig.add_subplot(gs[0, 0])
        ax2 = fig.add_subplot(gs[0, 1])
        ax3 = fig.add_subplot(gs[1, 0])
        ax4 = fig.add_subplot(gs[1, 1])

        # the original image with one grid line per patch boundary;
        # x runs along the width, y along the height
        x_ticks = np.arange(0, image.shape[1], patch_w)
        y_ticks = np.arange(0, image.shape[0], patch_h)
        ax1.imshow(image)
        # Tick positions must be fixed before labels are assigned, otherwise
        # the labels are attached to automatically chosen tick locations.
        ax1.set_xticks(x_ticks)
        ax1.set_yticks(y_ticks)
        ax1.set_xticklabels(x_ticks // patch_w)
        ax1.set_yticklabels(y_ticks // patch_h)
        ax1.grid(True, linestyle='--', linewidth=0.5)

        def patch_label(index):
            # Patches are numbered row-major over the patch grid.
            row, col = divmod(index, n_patch_cols)
            return f'({row}, {col})'

        x_patch_ticks = np.arange(0, mask_width, 1) * patch_w
        y_patch_ticks = np.arange(0, mask_height, 1) * patch_h
        col_labels = [patch_label(i) for i in range(mask_width)]
        row_labels = [patch_label(i) for i in range(mask_height)]

        shift_pos_down_or_right = 0.08

        # horizontally concatenated patches
        ax2.imshow(einops.rearrange(
            split_image,
            '(n_patches p_height) p_width c -> p_height (n_patches p_width) c',
            p_height=patch_h, p_width=patch_w))
        ax2.set_xticks(x_patch_ticks)
        ax2.set_xticklabels(col_labels, rotation=90)
        ax2.xaxis.set_ticks_position('top')
        ax2.xaxis.set_label_position('top')
        ax2.grid(True, linestyle='--', linewidth=0.5)
        ax2.set_yticks([])
        ax2_pos = ax2.get_position()
        ax2.set_position([ax2_pos.x0, ax2_pos.y0 - shift_pos_down_or_right,
                          ax2_pos.width, ax2_pos.height])

        # vertically concatenated patches, or nothing when rows are clusters
        if rows_are_patches:
            ax3.imshow(split_image)
            ax3.set_yticks(y_patch_ticks)
            ax3.set_yticklabels(row_labels)
            ax3.grid(True, linestyle='--', linewidth=0.5)
            ax3.set_xticks([])
            ax3_pos = ax3.get_position()
            ax3.set_position([ax3_pos.x0 + shift_pos_down_or_right, ax3_pos.y0,
                              ax3_pos.width, ax3_pos.height])
        else:
            ax3.axis('off')

        # similarity matrix, scaled so each cell spans one patch
        extent = [0, mask_width * patch_w, 0, mask_height * patch_h]
        heatmap = ax4.matshow(mask, extent=extent, cmap='viridis')
        cb_ax = fig.add_axes([0.92, 0.35, 0.02, 0.3])
        fig.colorbar(heatmap, cax=cb_ax)
        ax4.set_xticks(x_patch_ticks)
        ax4.set_yticks(y_patch_ticks)
        ax4.set_xticklabels([])
        ax4.set_yticklabels([])
        ax4.grid(True, linestyle='--', linewidth=0.5)

        return fig

    def plot_umap(self, image, queries, keys):
        """Embed queries and keys jointly with UMAP and scatter them in 2-D.

        Keys are coloured with the pixel they belong to; queries are drawn as
        gold crosses. See
        https://umap-learn.readthedocs.io/en/latest/basic_usage.html

        Args:
            image: array of shape ``(H, W, C)``.
            queries: array of shape ``(n_queries, d)``.
            keys: array of shape ``(H * W, d)``.

        Returns:
            The matplotlib figure.
        """
        H, W, C = image.shape
        n_queries, _ = queries.shape
        n_keys, _ = keys.shape
        assert n_keys == H * W, f'Number of keys ({n_keys}) does not match number of pixels ({H * W}).'

        image_flat = einops.rearrange(image, 'h w c -> (h w) c')

        # Per-point RGB(A) colours must lie in [0, 1]; follow the imshow
        # convention that integer images are 0-255.
        point_colors = np.asarray(image_flat)
        if C in (3, 4):
            if np.issubdtype(point_colors.dtype, np.integer):
                point_colors = point_colors / 255.0
            point_colors = np.clip(point_colors, 0.0, 1.0)

        # Queries and keys are embedded together so they share one 2-D space.
        features = np.concatenate((queries, keys), axis=0)
        reducer = mp.UMAP(random_state=42)
        reducer.fit(features)
        embedding = reducer.transform(features)

        emb_queries = embedding[:n_queries]
        emb_keys = embedding[n_queries:]

        fig = plt.figure(figsize=(10, 5))
        fig.tight_layout()
        gs = GridSpec(1, 2, figure=fig, hspace=0.1, wspace=0.1, width_ratios=[1, 8])

        ax1 = fig.add_subplot(gs[0, 0])
        ax2 = fig.add_subplot(gs[0, 1])

        ax1.imshow(image)

        ax2.scatter(emb_keys[:, 0], emb_keys[:, 1], marker='.', c=point_colors, alpha=0.5)
        ax2.scatter(emb_queries[:, 0], emb_queries[:, 1], marker='x', c='gold', s=200, alpha=0.5)

        return fig

    def plot_cnn_masks(self, image, mask):
        """Plot a pixel-level mask as a heatmap with one tick per image row.

        Args:
            image: array of shape ``(height, width, channels)``; only its
                shape is used.
            mask: 2-D array.

        Returns:
            The matplotlib figure.
        """
        mask_height, mask_width = mask.shape

        # One tick every `image width` mask columns; labels are derived from
        # the ticks so both always have the same length.
        col_ticks = np.arange(0, mask_width, image.shape[1])
        col_tick_labels = np.arange(len(col_ticks))

        fig, ax = plt.subplots(figsize=(10, 10))

        extent = [0, mask_width, 0, mask_height]  # left, right, bottom, top
        if mask_width > mask_height:
            # Stretch wide masks vertically so they fill a square panel.
            extent = [0, mask_width, 0, mask_width]
        heatmap = ax.matshow(mask, extent=extent, cmap='viridis')

        ax.set_xticks(col_ticks, col_tick_labels)
        ax.set_yticks([])
        if mask_width == mask_height:
            ax.set_yticks(col_ticks, col_tick_labels)
        ax.grid(True, which='both', linestyle='--', linewidth=1, color='white', alpha=0.2)

        fig.colorbar(heatmap, fraction=0.046, pad=0.04)

        return fig