import torch
import numpy as np
import matplotlib.pyplot as plt
from data_utils.sampling import *

import matplotlib.ticker as mticker
import matplotlib.colors as colors
from matplotlib import cm
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

class Visualizer():
    """Build matplotlib figures comparing generated samples with a reference.

    The figure type is chosen from ``simul_type``:

    * ``"t1"`` / ``"p1"`` -- 1-D histograms overlaid on the sampler's density
      contour (see :meth:`visualize_density`).
    * ``"t2"`` -- 2-D histogram heatmaps (see :meth:`draw_heatmap`).
    * anything else -- no figure (``visualize`` returns ``None``).
    """

    def __init__(self, simul_type, model_title_list,
                 K, sample_nu_list, sample_mu_list, sample_var_list, ratio_list,
                 xlim, device, seed):
        """Store plotting configuration and build the sampler.

        Args:
            simul_type: Simulation identifier; selects both the sampler and
                the kind of figure produced.
            model_title_list: Subplot titles, one per plotted sample set. The
                list is copied, so the caller's object is never modified.
                For ``"t2"`` the title ``"test_data"`` is prepended to the
                copy.
            K, sample_nu_list, sample_mu_list, sample_var_list, ratio_list:
                Forwarded unchanged to ``sampler.density_contour``.
            xlim: Positive half-width of the x-axis used by the density plot.
            device, seed: Forwarded to ``select_sampler``.
        """
        self.simul_type = simul_type
        # NOTE(review): select_sampler presumably comes from the star import
        # of data_utils.sampling -- confirm.
        self.sampler = select_sampler(simul_type, device, seed)
        # Copy before inserting: the original code inserted into the caller's
        # list, silently changing it (and growing it on every construction).
        self.model_title_list = list(model_title_list)
        if simul_type == "t2":
            self.model_title_list.insert(0, "test_data")
        self.K = K
        self.sample_nu_list = sample_nu_list
        self.sample_mu_list = sample_mu_list
        self.sample_var_list = sample_var_list
        self.ratio_list = ratio_list
        self.xlim = xlim

    def visualize(self, model_gen_list, **kwargs):
        """Dispatch to the plotting routine matching ``simul_type``.

        Args:
            model_gen_list: Sequence of torch tensors, one per model.
            **kwargs: Accepted for call-site compatibility; currently unused.

        Returns:
            A matplotlib figure, or ``None`` when ``simul_type`` has no
            associated visualisation.
        """
        if self.simul_type in ("t1", "p1"):
            return self.visualize_density(model_gen_list)
        if self.simul_type == "t2":
            return self.draw_heatmap(model_gen_list)
        return None

    def _x_extent(self):
        """Return the ``(low, high)`` x-interval shown for ``simul_type``."""
        if self.simul_type == "t1":
            return (-self.xlim, self.xlim)
        return (0, self.xlim)

    def visualize_density(self, model_gen_list):
        """Plot a log-scale histogram per model over the reference density.

        Non-finite sample values are dropped before plotting. Subplots are
        arranged on a two-row grid that is filled column by column.

        Args:
            model_gen_list: Sequence of torch tensors of generated samples.

        Returns:
            The matplotlib figure holding one subplot per model.
        """
        samples = [gen[torch.isfinite(gen)].cpu().numpy() for gen in model_gen_list]
        num_models = len(samples)

        # Evaluation grid on [-xlim, xlim] with step 0.01.
        grid = np.arange(-self.xlim * 100, self.xlim * 100 + 1) * 0.01
        contour = self.sampler.density_contour(
            grid, self.K, self.sample_nu_list, self.sample_mu_list,
            self.sample_var_list, self.ratio_list,
        )
        # Move to host memory first so the conversion also works when the
        # sampler returns a tensor on an accelerator device.
        contour = contour.squeeze().cpu().numpy()

        # The histogram must cover exactly the interval that is displayed.
        # Otherwise, with density=True, it is normalised over a different
        # support than the contour it is compared against.
        x_low, x_high = self._x_extent()

        rows = 2
        cols = max(1, (num_models + 1) // 2)  # round up for odd counts; never 0
        fig = plt.figure(figsize=(3.5 * cols, 7))

        for m, data in enumerate(samples):
            row_idx = m % rows
            col_idx = m // rows
            ax = fig.add_subplot(rows, cols, row_idx * cols + col_idx + 1)
            ax.set_title(self.model_title_list[m])
            ax.plot(grid, contour, color='black')
            ax.hist(data, bins=100, range=(x_low, x_high), density=True,
                    alpha=0.5, color='dodgerblue')
            ax.set_xlim(x_low, x_high)
            ax.set_yscale("log")
            ax.set_ylim(1e-6, 1)

        return fig

    def draw_heatmap(self, data_list, x_range=(-20, 20), y_range=(-20, 20),
                     bins=100, per_fig_size=(3.5, 7)):
        """Draw one 2-D histogram heatmap per entry of ``data_list``.

        Args:
            data_list: Sequence of torch tensors; columns 0 and 1 of each are
                used as x and y (presumably shaped ``[N, m_dim]`` -- confirm
                against the caller).
            x_range: ``(low, high)`` histogram bounds along x.
            y_range: ``(low, high)`` histogram bounds along y.
            bins: Bin specification forwarded to ``np.histogram2d``.
            per_fig_size: ``(width, height)``; the width is multiplied by the
                number of grid columns, the height is used as is.

        Returns:
            The matplotlib figure holding one subplot per data set.
        """
        arrays = [data.cpu().numpy() for data in data_list]

        # Terrain colormap with the first Greys entry prepended as the lowest
        # level.
        all_colors = np.vstack((plt.cm.Greys(0), plt.cm.terrain(np.linspace(0, 1, 256))))
        new_cmap = colors.LinearSegmentedColormap.from_list('viridis', all_colors)

        nb_plot = len(arrays)
        # Column count of the two-row grid (formula kept from the original).
        # NOTE(review): even counts get one spare column -- confirm intent.
        num_col = (nb_plot + 2) // 2 if nb_plot != 1 else 1

        fig = plt.figure(figsize=(per_fig_size[0] * num_col, per_fig_size[1]))
        for i, points in enumerate(arrays):
            ax = fig.add_subplot(2, num_col, i + 1)
            x, y = points[:, 0], points[:, 1]
            hist, xedges, yedges = np.histogram2d(
                x, y, bins=bins, range=[x_range, y_range], density=True)
            ax.set_title(self.model_title_list[i], pad=-5)
            # Draw on this subplot explicitly instead of relying on pyplot's
            # notion of the current axes.
            ax.imshow(hist.T, interpolation='nearest', origin='lower',
                      extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
                      cmap=new_cmap)
        return fig