import ot
import os
import torch
import numpy as np
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import matplotlib.tri as tri
import seaborn as sns
import torch.nn.functional as F
from PIL import ImageFile
# Ask Pillow to load image files whose data is truncated rather than failing
# on them. NOTE(review): presumably needed for damaged dataset images — confirm.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# specify device
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def show_img(image, labels):
    """Draw a batch of images as a 3-column grid and print their labels.

    Args:
        image: Image batch accepted by ``torchvision.utils.make_grid``.
        labels: Labels that belong to the images; they are only printed.
    """
    grid = torchvision.utils.make_grid(image, nrow=3)
    # Detach and move to host memory before converting: NumPy cannot read
    # CUDA tensors or tensors that still track gradients.
    grid_np = grid.detach().cpu().numpy()
    plt.figure(figsize=(7, 7))
    # Reorder axes (0, 1, 2) -> (1, 2, 0) so the channel axis comes last,
    # which is the layout imshow expects.
    plt.imshow(np.transpose(grid_np, (1, 2, 0)))
    print('labels: ', labels)


def _split_cal_test_loaders(val_dataset, cal_ratio, batch_size, seed):
    """Split a dataset into calibration/test parts and wrap each in a loader.

    Args:
        val_dataset: Dataset supporting ``len`` and indexing.
        cal_ratio: Fraction in ``[0, 1]`` of examples used for calibration;
            the size is ``int(cal_ratio * len(val_dataset))``.
        batch_size: Batch size used by both loaders.
        seed: Seed for the generator that drives the random split.

    Returns:
        ``(cal_loader, test_loader)``, both with ``shuffle=False``.

    Raises:
        ValueError: If ``cal_ratio`` lies outside ``[0, 1]``.
    """
    # A ratio outside [0, 1] would produce a negative length whose sum still
    # matches the dataset size, so random_split would not reject it.
    if not 0 <= cal_ratio <= 1:
        raise ValueError(f"cal_ratio must be in [0, 1], got {cal_ratio!r}")
    n_total = len(val_dataset)
    cal_size = int(cal_ratio * n_total)
    test_size = n_total - cal_size
    # A dedicated seeded generator makes the split reproducible.
    generator = torch.Generator().manual_seed(seed)
    cal_dataset, test_dataset = random_split(val_dataset, [cal_size, test_size], generator=generator)
    cal_loader = DataLoader(cal_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return cal_loader, test_loader


def load_imgnet_valdata(data_dir, preprocess, cal_ratio, batch_size, seed):
    """Build calibration/test loaders from an ``ImageFolder`` directory.

    Args:
        data_dir: Root directory passed to ``datasets.ImageFolder``.
        preprocess: Transform applied to every image.
        cal_ratio: Fraction of the data used for calibration.
        batch_size: Batch size of both loaders.
        seed: Seed of the random calibration/test split.

    Returns:
        ``(cal_loader, test_loader)``.
    """
    val_dataset = datasets.ImageFolder(root=data_dir, transform=preprocess)
    return _split_cal_test_loaders(val_dataset, cal_ratio, batch_size, seed)


def load_mnist_valdata(data_dir, preprocess, cal_ratio, batch_size, seed):
    """Build calibration/test loaders from the MNIST ``train=False`` split.

    The data must already exist under ``data_dir`` (``download=False``).

    Args:
        data_dir: Root directory passed to ``datasets.MNIST``.
        preprocess: Transform applied to every image.
        cal_ratio: Fraction of the data used for calibration.
        batch_size: Batch size of both loaders.
        seed: Seed of the random calibration/test split.

    Returns:
        ``(cal_loader, test_loader)``.
    """
    val_dataset = datasets.MNIST(root=data_dir, train=False, download=False, transform=preprocess)
    return _split_cal_test_loaders(val_dataset, cal_ratio, batch_size, seed)


def load_cifar100_valdata(data_dir, preprocess, cal_ratio, batch_size, seed):
    """Build calibration/test loaders from the CIFAR-100 ``train=False`` split.

    The data must already exist under ``data_dir`` (``download=False``).

    Args:
        data_dir: Root directory passed to ``datasets.CIFAR100``.
        preprocess: Transform applied to every image.
        cal_ratio: Fraction of the data used for calibration.
        batch_size: Batch size of both loaders.
        seed: Seed of the random calibration/test split.

    Returns:
        ``(cal_loader, test_loader)``.
    """
    val_dataset = datasets.CIFAR100(root=data_dir, train=False, download=False, transform=preprocess)
    return _split_cal_test_loaders(val_dataset, cal_ratio, batch_size, seed)


def load_wild_calib(id_data, batch_size, seed, cal_size=2000):
    """Draw a reproducible random calibration subset and wrap it in a loader.

    Args:
        id_data: Dataset supporting ``len`` and indexing to sample from.
        batch_size: Batch size of the returned loader.
        seed: Seed for the generator that drives the random split.
        cal_size: Number of examples kept for calibration.

    Returns:
        A ``DataLoader`` (``shuffle=False``) over the sampled subset.

    Raises:
        ValueError: If ``cal_size`` is negative or exceeds ``len(id_data)``.
    """
    n_total = len(id_data)
    # An out-of-range cal_size makes one split length negative while the sum
    # still equals len(id_data); random_split would accept that and return a
    # subset of the wrong size, so reject it explicitly.
    if not 0 <= cal_size <= n_total:
        raise ValueError(
            f"cal_size must be between 0 and len(id_data)={n_total}, got {cal_size}"
        )
    generator = torch.Generator().manual_seed(seed)
    cal_dataset, _ = random_split(id_data, [cal_size, n_total - cal_size], generator=generator)
    return DataLoader(cal_dataset, batch_size=batch_size, shuffle=False)


def nll_score(model, features, labels=None):
    """Return the negative log-softmax of the model outputs.

    Args:
        model: Callable applied to ``features`` after they are moved to the
            module-level ``device``.
        features: Input tensor for the model.
        labels: Unused; accepted but never read.

    Returns:
        Tensor of negated log-softmax values taken over ``dim=1``, computed
        with gradient tracking disabled; shape (n_images, n_labels).
    """
    with torch.no_grad():
        logits = model(features.to(device))
        log_probs = F.log_softmax(logits, dim=1)
        scores = -log_probs
    return scores
    
    
def indicator_cost_plan(samples1, samples2, epsilon, reg=0.05):
    """Compute a Sinkhorn transport plan under an indicator cost.

    Both sample sets get uniform weights. The cost between ``x`` and ``y`` is
    ``1`` when ``|x - y| >= epsilon`` and ``0`` otherwise.

    Args:
        samples1: Source samples (array-like, indexed along the first axis).
        samples2: Target samples (array-like, indexed along the first axis).
        epsilon: Distance threshold of the indicator cost.
        reg: Regularisation strength passed to ``ot.sinkhorn``. Defaults to
            the previously hard-coded ``0.05``.

    Returns:
        ``(total_cost, transport_plan, cost_matrix)`` where ``total_cost`` is
        the sum of ``transport_plan * cost_matrix``. Because the plan comes
        from the regularised Sinkhorn solver, it is not the exact EMD plan.

    Raises:
        ValueError: If either sample set is empty.
    """
    samples1 = np.asarray(samples1)
    samples2 = np.asarray(samples2)
    n = len(samples1)
    m = len(samples2)
    if n == 0 or m == 0:
        # Uniform weights 1/n and 1/m are undefined for empty inputs.
        raise ValueError("samples1 and samples2 must both be non-empty")
    a = np.ones(n) / n  # uniform weights for source
    b = np.ones(m) / m  # uniform weights for target
    # Cost matrix: indicator(|x - y| >= epsilon), built by broadcasting a
    # column of source samples against a row of target samples.
    x = samples1[:, np.newaxis]
    y = samples2[np.newaxis, :]
    cost_matrix = (np.abs(x - y) >= epsilon).astype(float)
    # Exact alternative (not used here): ot.emd(a, b, cost_matrix)
    transport_plan = ot.sinkhorn(a, b, cost_matrix, reg=reg)
    total_cost = np.sum(transport_plan * cost_matrix)
    return total_cost, transport_plan, cost_matrix


def perturb_test_data(features, labels, corrupt_ratio, noise_upper=1., noise_lower=-1., worst_case=False):
    """Add noise to the features and corrupt a fraction of the labels.

    Args:
        features: Feature tensor; noise of the same shape is added to it.
        labels: Label tensor; ``labels.shape[0]`` is the number of examples.
        corrupt_ratio: Fraction of labels to corrupt. ``0.`` returns
            ``labels`` itself, unmodified.
        noise_upper: Upper bound of the uniform noise.
        noise_lower: Lower bound of the uniform noise.
        worst_case: If truthy, every noise entry is ``+L`` or ``-L`` with
            ``L = max(|noise_upper|, |noise_lower|)`` instead of uniform.

    Returns:
        ``(noised_features, perturbed_labels)``. The inputs are not modified
        in place.
    """
    n_ex = labels.shape[0]
    # Truthiness (not ``is True``) so that NumPy booleans also enable it.
    if worst_case:
        max_noise_level = float(max(abs(noise_upper), abs(noise_lower)))
        noise = torch.where(torch.rand_like(features) > 0.5, max_noise_level, -max_noise_level)
    else:
        # Sample on the features' device; a CPU-only sample cannot be added
        # to features that live on the GPU.
        uniform = torch.rand(features.size(), device=features.device)
        # Maps uniform [0, 1) draws onto (noise_lower, noise_upper].
        noise = (noise_lower - noise_upper) * uniform + noise_upper
    noised_features = features + noise

    # corrupt labels
    if corrupt_ratio == 0.:
        perturbed_labels = labels
    else:
        # Pick a random subset of positions and shift their labels by one
        # place within that subset; work on a copy to leave ``labels`` intact.
        perturbed_labels = torch.clone(labels)
        pert_idx = np.random.choice(n_ex, int(corrupt_ratio * n_ex), replace=False)
        perturbed_labels[pert_idx] = torch.roll(perturbed_labels[pert_idx], 1)
    return noised_features, perturbed_labels


def perturb_test_scores(tst_scores, corrupt_ratio, noise_upper=1., noise_lower=-1., worst_case=False):
    """Corrupt a fraction of the test scores and add noise to all of them.

    Args:
        tst_scores: Score tensor; ``tst_scores.shape[0]`` is the number of
            examples.
        corrupt_ratio: Fraction of rows to corrupt by rolling. ``0.`` skips
            the corruption step.
        noise_upper: Upper bound of the uniform noise.
        noise_lower: Lower bound of the uniform noise.
        worst_case: If truthy, every noise entry is ``+L`` or ``-L`` with
            ``L = max(|noise_upper|, |noise_lower|)`` instead of uniform.

    Returns:
        A new tensor with the perturbed scores; ``tst_scores`` is not
        modified in place.
    """
    n_ex = tst_scores.shape[0]
    # corruption
    if corrupt_ratio == 0.:
        # Bind the name used by the return statement; leaving it unset here
        # made the zero-corruption case fail with UnboundLocalError.
        perturbed_tstscores = tst_scores
    else:
        perturbed_tstscores = torch.clone(tst_scores)
        pert_idx = np.random.choice(n_ex, int(corrupt_ratio * n_ex), replace=False)
        # NOTE(review): roll is called without ``dims``, so for scores with
        # more than one dimension the selected block is flattened before
        # shifting — confirm that this is the intended corruption.
        perturbed_tstscores[pert_idx] = torch.roll(perturbed_tstscores[pert_idx], 1)

    # add noise
    # Truthiness (not ``is True``) so that NumPy booleans also enable it.
    if worst_case:
        max_noise_level = float(max(abs(noise_upper), abs(noise_lower)))
        noise = torch.where(torch.rand_like(tst_scores) > 0.5, max_noise_level, -max_noise_level)
    else:
        # Sample on the scores' device; a CPU-only sample cannot be added to
        # scores that live on the GPU.
        uniform = torch.rand(tst_scores.size(), device=tst_scores.device)
        # Maps uniform [0, 1) draws onto (noise_lower, noise_upper].
        noise = (noise_lower - noise_upper) * uniform + noise_upper

    return perturbed_tstscores + noise


# def plot_cp(data, plt_type, plt_name, alpha=0.1, save_dir=None, 
#             group_labels=['SC', '$LP_\epsilon$', '$LP^{est}_\epsilon$', '$\chi^2$']):
def plot_cp(data, plt_type, plt_name, alpha=0.1, save_dir=None,
            group_labels=('SC', r'$LP_\epsilon$')):
    """Jittered scatter plot of per-group values with a mean bar per group.

    Args:
        data: Sequence of groups; each group is a sequence of numbers.
        plt_type: Plot title. When it equals ``'Coverage'`` a horizontal
            reference line is drawn at ``1 - alpha``.
        plt_name: File name used when the figure is saved.
        alpha: Level that sets the reference line at ``1 - alpha``. It is not
            the marker transparency, which is fixed at 0.5.
        save_dir: Directory to save into (created if missing). ``None``
            skips saving and only shows the figure.
        group_labels: One x-axis tick label per group in ``data``.
    """
    # Four-group variant: ['#1f77b4', '#dc143c', '#ffff00', '#2ca02c']
    colors = ['#1f77b4', '#dc143c']

    plt.figure(figsize=(2, 2.8))
    for i, group_data in enumerate(data):
        group_data = np.array(group_data)
        x_center = i
        # Spread the points horizontally by up to +/-0.2 around the group.
        jitter = (np.random.rand(len(group_data)) - 0.5) * 0.4
        x_vals = x_center + jitter

        # Colours repeat when there are more groups than colours.
        col = colors[i % len(colors)]

        # Scatter for each group
        plt.scatter(x_vals, group_data,
                    color=col,
                    alpha=0.5,
                    edgecolor='white',
                    s=100)

        # Draw a horizontal line for the mean
        mean_val = np.mean(group_data)
        plt.hlines(y=mean_val,
                   xmin=x_center - 0.4,
                   xmax=x_center + 0.4,
                   color=col,
                   linewidth=3)

    plt.xticks(range(len(data)), list(group_labels), fontsize=11)
    plt.title(plt_type, fontsize=15)
    if plt_type == 'Coverage':
        plt.axhline(y=1 - alpha, color='darkred', linestyle='-', alpha=0.9, linewidth=2)

    plt.tight_layout()
    # The default save_dir=None used to reach os.makedirs(None) and raise;
    # saving is now skipped when no directory is given.
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, plt_name), dpi=300, bbox_inches='tight')
    plt.show()
    
    
def eps_rho_plot(arr,
                 plt_type='Coverage',
                 scatter_points=True,
                 levels=50,
                 style='darkgrid',
                 context='talk',
                 figsize=(8, 6),
                 point_size=70,
                 alpha=0.8,
                 highlight_val=0.9,
                 savefig_path=None):
    """Filled contour plot of a value over scattered (epsilon, rho) points.

    Args:
        arr: 2-D array; column 0 holds the plotted value, column 1 the
            x coordinate (epsilon axis) and column 2 the y coordinate
            (rho axis).
        plt_type: Colorbar label. ``'Coverage'`` selects the ``'rocket'``
            palette and labels the highlighted contour; anything else uses
            ``'mako'`` and leaves the contour unlabelled.
        scatter_points: Whether to overlay the data points.
        levels: Passed to ``tricontourf`` as ``levels``.
        style: Seaborn style name.
        context: Seaborn context name.
        figsize: Figure size.
        point_size: Marker size of the overlaid points.
        alpha: Transparency of the overlaid points.
        highlight_val: Value at which a dashed black contour is drawn.
        savefig_path: If not ``None``, the figure is saved to this path.

    Returns:
        ``(fig, ax)`` of the created plot.
    """
    # Seaborn theme
    sns.set_style(style)
    sns.set_context(context)

    palette = 'rocket' if plt_type == 'Coverage' else 'mako'
    cmap = sns.color_palette(palette, as_cmap=True)

    # Columns: value, x, y
    val, x, y = arr[:, 0], arr[:, 1], arr[:, 2]

    # A triangulation copes with irregularly spaced points
    triang = tri.Triangulation(x, y)

    fig, ax = plt.subplots(figsize=figsize)
    filled = ax.tricontourf(triang, val, levels=levels, cmap=cmap)

    # Optional overlay of the raw points, styled through seaborn
    if scatter_points:
        sns.scatterplot(x=x, y=y, hue=val, palette=palette,
                        alpha=alpha, edgecolor='white',
                        s=point_size, ax=ax, legend=False)

    # Dashed contour at the highlighted value
    dashed = ax.tricontour(triang, val, levels=[highlight_val],
                           colors='black', linewidths=2, linestyles='--')
    if plt_type == 'Coverage':
        label_fmt = {highlight_val: f'val={highlight_val}'}
        ax.clabel(dashed, inline=True, fmt=label_fmt)

    # Colorbar for the value scale
    fig.colorbar(filled, ax=ax).set_label(plt_type)

    # Axis labels
    ax.set_xlabel(r'$\epsilon$', fontsize=25)
    ax.set_ylabel(r'$\rho$', fontsize=25)

    plt.tight_layout()

    if savefig_path is None:
        return fig, ax

    fig.savefig(savefig_path, dpi=300, bbox_inches='tight')
    print(f"Figure saved to {savefig_path}")
    return fig, ax

    
def eps_rho_coverageplot(arr,
                         scatter_points=True,
                         levels=50,
                         palette='rocket',
                         style='darkgrid',
                         context='talk',
                         figsize=(8, 6),
                         point_size=70,
                         alpha=0.8,
                         highlight_val=0.9,
                         savefig_path=None):
    """Filled contour plot of coverage over scattered (epsilon, rho) points.

    Args:
        arr: 2-D array; column 0 holds the plotted value, column 1 the
            x coordinate (epsilon axis) and column 2 the y coordinate
            (rho axis).
        scatter_points: Whether to overlay the data points.
        levels: Passed to ``tricontourf`` as ``levels``.
        palette: Seaborn palette name used for the contours and the points.
        style: Seaborn style name.
        context: Seaborn context name.
        figsize: Figure size.
        point_size: Marker size of the overlaid points.
        alpha: Transparency of the overlaid points.
        highlight_val: Value at which a dashed, labelled contour is drawn.
        savefig_path: If not ``None``, the figure is saved to this path.

    Returns:
        ``(fig, ax)`` of the created plot.
    """
    # Apply the seaborn theme and resolve the palette to a colormap
    sns.set_style(style)
    sns.set_context(context)
    colormap = sns.color_palette(palette, as_cmap=True)

    # Column layout of arr: value, x, y
    values = arr[:, 0]
    xs = arr[:, 1]
    ys = arr[:, 2]

    # Triangulate the (possibly irregular) sample locations
    mesh = tri.Triangulation(xs, ys)

    fig, ax = plt.subplots(figsize=figsize)

    # Filled contours of the value surface
    filled = ax.tricontourf(mesh, values, levels=levels, cmap=colormap)

    # Optionally draw the sample points on top
    if scatter_points:
        sns.scatterplot(
            x=xs, y=ys, hue=values, palette=palette,
            alpha=alpha, edgecolor='white',
            s=point_size, ax=ax, legend=False,
        )

    # Dashed, labelled contour at highlight_val
    marked = ax.tricontour(
        mesh, values, levels=[highlight_val],
        colors='black', linewidths=2, linestyles='--',
    )
    ax.clabel(marked, inline=True, fmt={highlight_val: f'val={highlight_val}'})

    # Colorbar showing the value scale
    colorbar = fig.colorbar(filled, ax=ax)
    colorbar.set_label('Coverage')

    # Axis labels
    ax.set_xlabel(r'$\epsilon$', fontsize=25)
    ax.set_ylabel(r'$\rho$', fontsize=25)

    plt.tight_layout()

    # Save only when a path was supplied
    if savefig_path is not None:
        fig.savefig(savefig_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved to {savefig_path}")

    return fig, ax


def eps_rho_sizeplot(arr,
                     scatter_points=True,
                     levels=50,
                     cmap='viridis',
                     figsize=(8, 6),
                     point_size=70,
                     alpha=0.8,
                     highlight_val=None,
                     savefig_path=None):
    """Filled contour plot of set size over scattered (epsilon, rho) points.

    Args:
        arr: 2-D array; column 0 holds the plotted value, column 1 the
            x coordinate (epsilon axis) and column 2 the y coordinate
            (rho axis).
        scatter_points: Whether to overlay the data points.
        levels: Passed to ``tricontourf`` as ``levels``.
        cmap: Colormap for the contours and the points.
        figsize: Figure size.
        point_size: Marker size of the overlaid points.
        alpha: Transparency of the overlaid points.
        highlight_val: If given, a dashed black contour is drawn at
            ``highlight_val - 0.03``. ``None`` (the default) draws none.
        savefig_path: If not ``None``, the figure is saved to this path.

    Returns:
        ``(fig, ax)`` of the created plot.
    """
    # Extract val, x, and y
    val = arr[:, 0]
    x = arr[:, 1]
    y = arr[:, 2]

    # Create a Triangulation object (handles irregularly spaced points)
    triang = tri.Triangulation(x, y)

    # Create a figure and axis
    fig, ax = plt.subplots(figsize=figsize)

    # Plot a filled contour using tricontourf
    contour_f = ax.tricontourf(triang, val, levels=levels, cmap=cmap)

    # Optionally overlay the scatter for actual data points
    if scatter_points:
        ax.scatter(x, y, c=val, cmap=cmap, edgecolor='white',
                   s=point_size, alpha=alpha)

    # The default highlight_val=None used to fail on ``None - 0.03``; the
    # highlight contour is now drawn only when a value is supplied.
    if highlight_val is not None:
        # NOTE(review): the 0.03 offset is carried over unchanged; its
        # purpose is not evident from this file — confirm.
        ax.tricontour(triang, val, levels=[highlight_val - 0.03],
                      colors='black', linewidths=2, linestyles='--')

    # Add a colorbar to show the val scale
    cbar = fig.colorbar(contour_f, ax=ax)
    cbar.set_label('Size')

    # Add labels/title
    ax.set_xlabel(r'$\epsilon$', fontsize=25)
    ax.set_ylabel(r'$\rho$', fontsize=25)

    # Make the plot a bit more visually appealing
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    if savefig_path is not None:
        # Save through the figure handle so the file always contains this
        # figure, not whichever figure pyplot considers current.
        fig.savefig(savefig_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved to {savefig_path}")

    return fig, ax
