import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import numpy as np
from scipy import stats
import torch


# Font sizes shared by every plot in this module.
#   sub_fontsize   -> axis labels, subplot titles, colorbar ticks
#   title_fontsize -> figure titles / suptitles
# (Smaller settings previously used: 18 and 20 respectively.)
sub_fontsize = 24
title_fontsize = 24

def _add_t_to_title_and_path(title, path, t, name):
    if t is not None and 'WGAN' not in name:
        title = f"{title}, t={t})"
        # Insert _t{t} before .png or .pdf
        if '.png' in path:
            path = path.replace('.png', f'_t{t}.png')
        elif '.pdf' in path:
            path = path.replace('.pdf', f'_t{t}.pdf')
        
    return title, path

def remove_underscores(s):
    """Turn an internal run name into a human-readable display name.

    Strips the architecture suffixes ``'_mlp'`` and ``'_ResNet'``, renames
    ``'MAGGAN'`` to ``'MaGAN'``, and finally turns every remaining underscore
    into a space.

    Args:
        s: Run/model name, e.g. ``'MAGGAN_mlp_lr_0.01'``.

    Returns:
        The display name, e.g. ``'MaGAN lr 0.01'``.
    """
    for suffix in ('_mlp', '_ResNet'):
        s = s.replace(suffix, '')
    s = s.replace('MAGGAN', 'MaGAN')
    return s.replace('_', ' ')

def name_to_path(s):
    """Return ``s`` lower-cased with every space character removed.

    Used to build file-name stems from display names, e.g.
    ``'WGAN (Optimizer: Adam)'`` -> ``'wgan(optimizer:adam)'``.
    Only ``' '`` is removed; tabs, parentheses, colons etc. are kept.
    """
    # NOTE(review): ':' is not a legal file-name character on Windows;
    # consider sanitising if these paths must be portable.
    pieces = s.split(' ')
    return ''.join(pieces).lower()

def plot_training_losses(loss_G, loss_C = None, folder_name = 'GAN_MNIST_results', pdf_path = None, name='WGAN (Optimizer: Adam)', t=None, step_name = 'Epoch'):
    """Plot generator (and optionally critic) loss curves and save them as a PNG.

    The file is written to
    ``{folder_name}/{name_to_path(name)}_training_losses.png``, tagged with
    ``_t{t}`` by ``_add_t_to_title_and_path`` when applicable.

    Args:
        loss_G: Sequence of generator losses, one value per step.
        loss_C: Optional sequence of critic losses, drawn on the same axes.
        folder_name: Existing directory the PNG is written to.
        pdf_path: Unused; kept for call-site compatibility.
        name: Run name, shown in the title (via ``remove_underscores``) and
            used in the file name (via ``name_to_path``).
        t: Optional tag forwarded to ``_add_t_to_title_and_path``.
        step_name: Label of the x axis.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(loss_G, label='Generator Loss')
        if loss_C is not None:
            ax.plot(loss_C, label='Critic Loss')
        ax.set_xlabel(step_name, fontsize=sub_fontsize)
        ax.set_ylabel('Loss', fontsize=sub_fontsize)
        title = f"Training Losses of {remove_underscores(name)}"
        file_path = f"{folder_name}/{name_to_path(name)}_training_losses.png"
        title, file_path = _add_t_to_title_and_path(title, file_path, t, name)
        ax.set_title(title, fontsize=title_fontsize)
        ax.legend()
        # The large module-wide font sizes otherwise get clipped at the edges.
        fig.tight_layout()
        fig.savefig(file_path)
    finally:
        # Always release the figure. No plt.show(): once this figure is
        # closed it would only display (and block on) unrelated open figures.
        plt.close(fig)

def plot_generator_grad_norms(generator_grad_norm_list, folder_name, pdf_path = None, name='WGAN (Optimizer: Adam)', visualization=None, t=None, step_name = 'Epoch'):
    """Plot the generator's gradient L2 norm per step and save it as a PNG.

    The file is written to
    ``{folder_name}/{name_to_path(name)}_generator_grad_norms.png``, or
    ``..._generator_grad_norms_ylog.png`` when ``visualization == 'log'``.
    It is tagged with ``_t{t}`` by ``_add_t_to_title_and_path`` when
    applicable.

    Args:
        generator_grad_norm_list: Sequence of gradient norms, one per step.
        folder_name: Existing directory the PNG is written to.
        pdf_path: Unused; kept for call-site compatibility.
        name: Run name used in the title and file name.
        visualization: ``'log'`` for a logarithmic y axis; anything else
            gives a linear axis.
        t: Optional tag forwarded to ``_add_t_to_title_and_path``.
        step_name: Label of the x axis.
    """
    log_scale = visualization == 'log'
    suffix = '_ylog' if log_scale else ''
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(generator_grad_norm_list, label='Generator Gradient Norm')
        ax.set_xlabel(step_name, fontsize=sub_fontsize)
        ax.set_ylabel('Gradient Norm (L2)', fontsize=sub_fontsize)
        if log_scale:
            ax.set_yscale('log')
        title = f"Generator Gradient Norms of {remove_underscores(name)}"
        file_path = f"{folder_name}/{name_to_path(name)}_generator_grad_norms{suffix}.png"
        title, file_path = _add_t_to_title_and_path(title, file_path, t, name)
        ax.set_title(title, fontsize=title_fontsize)
        ax.legend()
        fig.tight_layout()
        fig.savefig(file_path)
    finally:
        # Always release the figure. plt.show() after closing would only
        # display unrelated figures, so it is not called.
        plt.close(fig)

def plot_gaussian_real_vs_generated(real_data, generated_data, folder_name, pdf_path, name='WGAN (Optimizer: Adam)', t=None):
    """
    Plot and save real vs generated data distribution (histogram) for Gaussian GANs.

    Both sample sets are drawn as overlaid 20-bin density histograms (real in
    blue, generated in orange). The figure is written twice:

    * as a PNG to
      ``{folder_name}/{name_to_path(name)}_generated_vs_real_distribution.png``,
      tagged with ``_t{t}`` by ``_add_t_to_title_and_path`` when applicable;
    * as the single page of a PDF at ``pdf_path``. This path is not tagged
      with ``t``, so it is rewritten on every call.

    Args:
        real_data: 1-D samples from the target distribution.
        generated_data: 1-D samples produced by the generator.
        folder_name: Existing directory the PNG is written to.
        pdf_path: Path of the PDF to create.
        name: Run name used in the title and file name.
        t: Optional tag forwarded to ``_add_t_to_title_and_path``.
    """
    bins = 20
    alpha = 0.5
    with PdfPages(pdf_path) as pdf:
        fig, ax = plt.subplots(figsize=(8, 5))
        try:
            ax.hist(real_data, bins=bins, alpha=alpha, label='Real Data', density=True, color='blue')
            ax.hist(generated_data, bins=bins, alpha=alpha, label='Generated Data', density=True, color='orange')
            title = f'Generated vs. Real Data Distribution for {remove_underscores(name)}'
            file_path = f"{folder_name}/{name_to_path(name)}_generated_vs_real_distribution.png"
            title, file_path = _add_t_to_title_and_path(title, file_path, t, name)
            ax.set_title(title, fontsize=title_fontsize)
            ax.set_xlabel('Value')
            ax.set_ylabel('Density')
            ax.legend()
            fig.tight_layout()
            fig.savefig(file_path)
            pdf.savefig(fig)
        finally:
            # Release the figure even if saving fails; no plt.show() after close.
            plt.close(fig)


def plot_scatter(real_data, generated_data, folder_name, step, epoch=None, name='WGAN (Optimizer: Adam)', t=None, step_name='Step'):
    """Scatter real vs generated 2-D samples, one panel per generated snapshot.

    Panel ``i`` overlays ``real_data`` (blue) and ``generated_data[i]`` (red)
    and is titled ``f"{step_name} {i * step + step}"``. The figure is saved to
    ``{folder_name}/{name_to_path(name)}_generated_vs_real_data_scatter.png``,
    tagged with ``_t{t}`` by ``_add_t_to_title_and_path`` when applicable.

    Args:
        real_data: 2-D array-like; columns 0 and 1 are plotted as x and y.
        generated_data: Non-empty sequence of arrays indexed like ``real_data``.
        folder_name: Existing directory the PNG is written to.
        step: Spacing between snapshots; only used in the panel titles.
        epoch: Unused; kept for call-site compatibility.
        name: Run name used in the title and file name.
        t: Optional tag forwarded to ``_add_t_to_title_and_path``.
        step_name: Word used in the panel titles (e.g. ``'Step'``).

    Raises:
        ValueError: If ``generated_data`` is empty.
    """
    n_groups = len(generated_data)
    if n_groups == 0:
        raise ValueError("generated_data must contain at least one snapshot")
    # squeeze=False always returns a 2-D array of Axes, so a single snapshot
    # (where subplots would otherwise return a bare Axes) works too.
    fig, axs = plt.subplots(1, n_groups, figsize=(5 * n_groups, 5), squeeze=False)
    try:
        for i, (ax, gen_data) in enumerate(zip(axs[0], generated_data)):
            ax.scatter(real_data[:, 0], real_data[:, 1], alpha=0.5, label='Real', s=10, color='blue')
            ax.scatter(gen_data[:, 0], gen_data[:, 1], alpha=0.5, label=f'Generated {i}', s=10, color='red')
            ax.set_title(f'{step_name} {i * step + step}', fontsize=sub_fontsize)
            ax.set_xlabel('X', fontsize=sub_fontsize+5)
            ax.set_ylabel('Y', fontsize=sub_fontsize+5)
            ax.legend()
        title = f'Scatter Plot of Generated vs. Real Data for {remove_underscores(name)}'
        file_path = f"{folder_name}/{name_to_path(name)}_generated_vs_real_data_scatter.png"
        title, file_path = _add_t_to_title_and_path(title, file_path, t, name)

        fig.suptitle(title, fontsize=title_fontsize+5)  # Centered main title
        # A single tight_layout pass; it reserves room for the suptitle
        # (Matplotlib >= 3.3). A preceding rect-based pass would just be
        # overridden.
        fig.tight_layout()
        fig.savefig(file_path)
    finally:
        # Release the figure even on error; no plt.show() after close.
        plt.close(fig)

def plot_heatmap(generated_data, folder_name, step, epoch=None, name='WGAN (Optimizer: Adam)', t=None, step_name='Step'):
    """2-D histogram ("heatmap") of generated samples, one panel per snapshot.

    Panel ``i`` shows a 30x30-bin ``hist2d`` of ``generated_data[i]`` with
    its own colorbar and is titled ``f"{step_name} {i * step + step}"``. The
    figure is saved to
    ``{folder_name}/{name_to_path(name)}_generated_vs_real_data_heatmap.png``,
    tagged with ``_t{t}`` by ``_add_t_to_title_and_path`` when applicable.

    Args:
        generated_data: Non-empty sequence of 2-D array-likes; columns 0 and 1
            are used as x and y.
        folder_name: Existing directory the PNG is written to.
        step: Spacing between snapshots; only used in the panel titles.
        epoch: Unused; kept for call-site compatibility.
        name: Run name used in the title and file name.
        t: Optional tag forwarded to ``_add_t_to_title_and_path``.
        step_name: Word used in the panel titles (e.g. ``'Step'``).

    Raises:
        ValueError: If ``generated_data`` is empty.
    """
    n_groups = len(generated_data)
    if n_groups == 0:
        raise ValueError("generated_data must contain at least one snapshot")
    bins = 30
    # squeeze=False keeps axs 2-D so a single snapshot is handled too.
    fig, axs = plt.subplots(1, n_groups, figsize=(5 * n_groups, 5), squeeze=False)
    try:
        for i, (ax, gen_data) in enumerate(zip(axs[0], generated_data)):
            h = ax.hist2d(gen_data[:, 0], gen_data[:, 1], bins=bins, cmap='Reds', alpha=0.7)
            # hist2d returns (counts, xedges, yedges, image); the image feeds the colorbar.
            cbar = fig.colorbar(h[3], ax=ax)
            cbar.ax.tick_params(labelsize=sub_fontsize)
            ax.set_title(f'{step_name} {i * step + step}', fontsize=sub_fontsize)
            ax.set_xlabel('X', fontsize=sub_fontsize+5)
            ax.set_ylabel('Y', fontsize=sub_fontsize+5)
        # NOTE(review): only generated samples are drawn, although the title
        # mentions real data too.
        title = f'Heatmap Plot of Generated vs. Real Data for {remove_underscores(name)}'
        file_path = f"{folder_name}/{name_to_path(name)}_generated_vs_real_data_heatmap.png"
        title, file_path = _add_t_to_title_and_path(title, file_path, t, name)

        fig.suptitle(title, fontsize=title_fontsize+5)  # Centered main title
        # One tight_layout pass; it accounts for the suptitle (Matplotlib >= 3.3).
        fig.tight_layout()
        fig.savefig(file_path)
    finally:
        # Release the figure even on error; no plt.show() after close.
        plt.close(fig)


def plot_gaussian_generated_subplots(wgan_gen_data, epochs_to_plot, folder_name, pdf_path = None, name='WGAN (Optimizer: Adam)', t=None):
    """
    Plot and save subplots of generated data distributions for selected epochs.

    One 20-bin density histogram is drawn per entry of ``epochs_to_plot``,
    colour-coded by epoch with the ``rainbow`` colormap.

    Layout: 10 columns and at least 2 rows, growing when there are more than
    20 epochs; unused panels are hidden. All panels share their x and y
    axes, so the limits are set once from every plotted panel. The x range
    gets 5% padding and the y range extends to 1.1x the tallest bar.

    The figure is saved to
    ``{folder_name}/{name_to_path(name)}_generated_data_all_epochs.png``,
    tagged with ``_t{t}`` by ``_add_t_to_title_and_path`` when applicable.

    Args:
        wgan_gen_data: Indexable sequence; ``wgan_gen_data[i]`` holds the 1-D
            samples for ``epochs_to_plot[i]`` and must support ``.min()`` and
            ``.max()`` (e.g. NumPy arrays).
        epochs_to_plot: Non-empty sequence of epoch numbers.
        folder_name: Existing directory the PNG is written to.
        pdf_path: Unused; kept for call-site compatibility.
        name: Run name used in the title and file name.
        t: Optional tag forwarded to ``_add_t_to_title_and_path``.
    """
    n_panels = len(epochs_to_plot)
    n_cols = 10
    # Keep the 2 x 10 grid; add rows only when more panels are needed.
    n_rows = max(2, (n_panels + n_cols - 1) // n_cols)
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 4*n_rows), sharex=True, sharey=True)
    try:
        title = f'{remove_underscores(name)} Generated Data'
        file_path = f"{folder_name}/{name_to_path(name)}_generated_data_all_epochs.png"
        title, file_path = _add_t_to_title_and_path(title, file_path, t, name)
        fig.suptitle(title, fontsize=sub_fontsize)
        cmap = plt.cm.rainbow
        norm = plt.Normalize(min(epochs_to_plot), max(epochs_to_plot))
        axs = axs.flatten()

        data_lows, data_highs, peaks = [], [], []
        for i, epoch in enumerate(epochs_to_plot):
            ax = axs[i]
            data = wgan_gen_data[i]
            counts, _, _ = ax.hist(data, bins=20, density=True, alpha=0.5, label=f'Gen (Epoch {epoch})', color=cmap(norm(epoch)))
            ax.set_title(f'Gen Data (Epoch {epoch})', fontsize=sub_fontsize)
            ax.set_xlabel('Value', fontsize=sub_fontsize)
            ax.set_ylabel('Density', fontsize=sub_fontsize)
            ax.tick_params(labelbottom=True)
            ax.legend()
            data_lows.append(float(data.min()))
            data_highs.append(float(data.max()))
            peaks.append(float(counts.max()))

        for ax in axs[n_panels:]:
            ax.set_visible(False)

        # With sharex/sharey a set_xlim/set_ylim on any panel applies to all
        # of them, so per-panel limits would let the last epoch win and clip
        # the others. Use the union of all panels instead.
        x_lo, x_hi = min(data_lows), max(data_highs)
        padding = (x_hi - x_lo) * 0.05
        if padding > 0:  # identical limits would be singular; autoscale instead
            axs[0].set_xlim(x_lo - padding, x_hi + padding)
        axs[0].set_ylim(0, max(peaks) * 1.1)

        # Lay out after drawing, leaving room at the top for the suptitle.
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        fig.savefig(file_path)
    finally:
        # Release the figure even on error; no plt.show() after close.
        plt.close(fig)

def plot_gaussian_kde_all_epochs(wgan_gen_data, real_data, epochs_to_plot, folder_name, pdf_path = None, name='WGAN (Optimizer: Adam)', min_epoch=2000, t=None):
    """
    Plot and save KDE distributions for all epochs and real data.

    For every epoch in ``epochs_to_plot`` that is strictly greater than
    ``min_epoch`` and has a matching entry in ``wgan_gen_data``, a Gaussian
    KDE of the generated samples is drawn, colour-coded by epoch. The KDE of
    ``real_data`` is drawn in black on top.

    All curves are evaluated on 300 points spanning the selected generated
    samples, or ``real_data`` when no epoch is selected. Sample sets for
    which ``gaussian_kde`` raises ``LinAlgError`` (e.g. zero variance) are
    drawn as a dashed vertical line at their mean instead.

    The figure is saved at 300 dpi to
    ``{folder_name}/{name_to_path(name)}_generated_data_all_epochs_optimized.png``,
    tagged with ``_t{t}`` by ``_add_t_to_title_and_path`` when applicable.

    Args:
        wgan_gen_data: Sequence whose i-th item holds the 1-D samples for
            ``epochs_to_plot[i]``; items must support ``.min()``/``.max()``.
        real_data: 1-D samples of the target distribution.
        epochs_to_plot: Non-empty sequence of epoch numbers.
        folder_name: Existing directory the PNG is written to.
        pdf_path: Unused; kept for call-site compatibility.
        name: Run name used in the title and file name.
        min_epoch: Epochs ``<= min_epoch`` are skipped.
        t: Optional tag forwarded to ``_add_t_to_title_and_path``.
    """
    cmap = plt.cm.rainbow
    norm = plt.Normalize(min(epochs_to_plot), max(epochs_to_plot))
    # zip() stops at the shorter input, so epochs without samples are skipped.
    selected = [(epoch, data) for epoch, data in zip(epochs_to_plot, wgan_gen_data) if epoch > min_epoch]

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        title = f'{remove_underscores(name)} Generated Data Distribution - All Epochs'
        file_path = f"{folder_name}/{name_to_path(name)}_generated_data_all_epochs_optimized.png"
        title, file_path = _add_t_to_title_and_path(title, file_path, t, name)
        fig.suptitle(title, fontsize=sub_fontsize)

        # Shared x grid over the selected generated samples. min()/max() of an
        # empty selection would raise, so fall back to the real data's range.
        range_sources = [data for _, data in selected] or [np.asarray(real_data)]
        x_lo = min(float(d.min()) for d in range_sources)
        x_hi = max(float(d.max()) for d in range_sources)
        x_vals_global = np.linspace(x_lo, x_hi, 300)

        curves = [(data, {'color': cmap(norm(epoch)), 'label': f'Epoch {epoch}'}) for epoch, data in selected]
        curves.append((real_data, {'color': 'black', 'label': 'Real Data'}))
        for data, style in curves:
            try:
                density = stats.gaussian_kde(data)(x_vals_global)
            except np.linalg.LinAlgError:
                # Singular covariance (e.g. a collapsed generator emitting a
                # single value): no KDE exists, so mark where the mass sits.
                ax.axvline(float(np.mean(np.asarray(data))), linestyle='--', linewidth=2, alpha=0.8, **style)
            else:
                ax.plot(x_vals_global, density, linewidth=2, alpha=0.8, **style)

        ax.set_xlabel('Value', fontsize=sub_fontsize)
        ax.set_ylabel('Density', fontsize=sub_fontsize)
        ax.grid(True, alpha=0.3)
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
        fig.tight_layout()
        fig.savefig(file_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even on error; no plt.show() after close.
        plt.close(fig)


def plot_critic_score_heatmap(gan, x_min , x_max, y_min, y_max, num_points=100, folder_name=None, name='WGAN (Optimizer: Adam)', t=None):
    """Render the critic's score over a regular 2-D grid as an image.

    ``gan.C`` is evaluated under ``torch.no_grad()`` on
    ``num_points x num_points`` points spanning
    ``[x_min, x_max] x [y_min, y_max]``; the points are moved to
    ``gan.device`` first. The scores are shown with ``imshow`` (origin at
    the lower left) and saved to
    ``{folder_name}/{name_to_path(name)}_critic_score_heatmap.png``, tagged
    with ``_t{t}`` by ``_add_t_to_title_and_path`` when applicable.

    Args:
        gan: Object exposing ``device`` and a critic ``C``. ``C`` is assumed
            to map an ``(N, 2)`` tensor to ``N`` scores (any shape with ``N``
            elements) — TODO confirm against the GAN class.
        x_min, x_max, y_min, y_max: Bounds of the evaluated region.
        num_points: Grid resolution along each axis.
        folder_name: Existing directory the PNG is written to.
        name: Run name used in the title and file name.
        t: Optional tag forwarded to ``_add_t_to_title_and_path``.
    """
    data_x = torch.linspace(x_min, x_max, num_points)
    data_y = torch.linspace(y_min, y_max, num_points)
    # 'xy' indexing makes row i correspond to data_y[i] and column j to
    # data_x[j]. imshow draws rows along the y axis, so this is the layout it
    # expects; 'ij' would put x along the rows and show the scores transposed.
    grid_x, grid_y = torch.meshgrid(data_x, data_y, indexing='xy')
    grid = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1).to(gan.device)
    with torch.no_grad():
        scores = gan.C(grid).cpu().numpy().reshape(num_points, num_points)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        image = ax.imshow(scores, extent=(x_min, x_max, y_min, y_max),
                          origin='lower', aspect='auto', cmap='viridis')
        cbar = fig.colorbar(image, ax=ax, label='Critic Score')
        cbar.ax.tick_params(labelsize=sub_fontsize)
        ax.set_xlabel('X', fontsize=sub_fontsize)
        ax.set_ylabel('Y', fontsize=sub_fontsize)
        title = f'Critic Score of {remove_underscores(name)}'
        file_path = f"{folder_name}/{name_to_path(name)}_critic_score_heatmap.png"
        title, file_path = _add_t_to_title_and_path(title, file_path, t, name)
        ax.set_title(title, fontsize=title_fontsize)
        fig.savefig(file_path)
    finally:
        # Release the figure even on error; no plt.show() after close.
        plt.close(fig)

def plot_critic_score_contour(gan, x_min, x_max, y_min, y_max, num_points=100, folder_name=None, name='WGAN (Optimizer: Adam)', t=None):
    """Draw 20 labelled contour lines of the critic's score over a 2-D grid.

    ``gan.C`` is evaluated under ``torch.no_grad()`` on
    ``num_points x num_points`` points spanning
    ``[x_min, x_max] x [y_min, y_max]``; the points are moved to
    ``gan.device`` first. The figure is saved to
    ``{folder_name}/{name_to_path(name)}_critic_score_contour.png``, tagged
    with ``_t{t}`` by ``_add_t_to_title_and_path`` when applicable.

    Args:
        gan: Object exposing ``device`` and a critic ``C``. ``C`` is assumed
            to map an ``(N, 2)`` tensor to ``N`` scores — TODO confirm
            against the GAN class.
        x_min, x_max, y_min, y_max: Bounds of the evaluated region.
        num_points: Grid resolution along each axis.
        folder_name: Existing directory the PNG is written to.
        name: Run name used in the title and file name.
        t: Optional tag forwarded to ``_add_t_to_title_and_path``.
    """
    data_x = torch.linspace(x_min, x_max, num_points)
    data_y = torch.linspace(y_min, y_max, num_points)
    # contour() receives X, Y and Z in the same ('ij') layout, so the axes
    # line up; passing indexing explicitly avoids torch's deprecation warning.
    grid_x, grid_y = torch.meshgrid(data_x, data_y, indexing='ij')
    grid = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1).to(gan.device)
    with torch.no_grad():
        scores = gan.C(grid).cpu().numpy().reshape(num_points, num_points)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        # Draw contour lines for equal critic scores
        contour = ax.contour(grid_x.cpu().numpy(), grid_y.cpu().numpy(), scores, levels=20, cmap='viridis')
        ax.clabel(contour, inline=True, fontsize=8)
        ax.set_xlabel('X', fontsize=sub_fontsize)
        ax.set_ylabel('Y', fontsize=sub_fontsize)
        cbar = fig.colorbar(contour, ax=ax, label='Critic Score')
        cbar.ax.tick_params(labelsize=sub_fontsize)
        title = f'Critic Score of {remove_underscores(name)}'
        file_path = f"{folder_name}/{name_to_path(name)}_critic_score_contour.png"
        title, file_path = _add_t_to_title_and_path(title, file_path, t, name)
        ax.set_title(title, fontsize=title_fontsize)
        fig.tight_layout()
        fig.savefig(file_path)
    finally:
        # Release the figure even on error; no plt.show() after close.
        plt.close(fig)

