import os
import os.path as osp
import numpy as np
from math import sqrt
from scipy.optimize import linear_sum_assignment as linear_assignment

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from PIL import Image
import matplotlib.pyplot as plt

import torch
from torchvision import transforms as T
from torchvision.utils import make_grid

plt.rcParams["savefig.bbox"] = 'tight'

def load_img(path, device: str = 'cpu', is_main_process=True):
    """Load an image file as a normalized ``(1, 3, 512, 512)`` float tensor.

    The image is converted to RGB, resized to 512x512 with Lanczos
    resampling (the aspect ratio is not preserved) and rescaled from
    ``[0, 255]`` to ``[-1, 1]``.

    Args:
        path: Anything accepted by ``PIL.Image.open``.
        device: Device the resulting tensor is moved to.
        is_main_process: Unused by this function; kept so existing callers
            that pass it keep working.

    Returns:
        A ``float32`` tensor of shape ``(1, 3, 512, 512)`` with values in
        ``[-1, 1]``.
    """
    # Open inside a context manager so the underlying file handle is released
    # deterministically; ``convert`` returns an independent in-memory copy.
    with Image.open(path) as raw:
        image = raw.convert("RGB")
    image = image.resize((512, 512), resample=Image.LANCZOS)
    pixels = np.array(image).astype(np.float32) / 255.0  # (H, W, 3) in [0, 1]
    pixels = pixels[None].transpose(0, 3, 1, 2)  # -> (1, 3, H, W)
    tensor = torch.from_numpy(pixels).to(device)
    return 2.0 * tensor - 1.0

# def load_image(image_path, device, forshow=False):
#     image = read_image(image_path)
#     image = image[:3].unsqueeze_(0).float() / 127.5 - 1.  # [-1, 1]
#     image = F.interpolate(image, (512, 512))
#     if forshow:
#         if image.shape[0] == 1: image = image[0]
#         if image.shape[0] == 3: image = np.transpose(image, (1, 2, 0))
#     image = image.to(device)
#     return image

def show(imgs: torch.Tensor):
    """Draw a batch of images side by side on the current Matplotlib axes.

    Args:
        imgs: Batch of images; assumed to be ``(B, C, H, W)`` as expected by
            ``torchvision.utils.make_grid`` -- TODO confirm with callers.
    """
    # make_grid's layout argument is `nrow` (number of images per row).
    # Setting it to the batch size places the whole batch on a single row.
    grid_img = make_grid(imgs, nrow=imgs.shape[0])
    plt.imshow(grid_img.permute(1, 2, 0).cpu())

def visualize_and_save_features_pca(feature_maps_fit_data,
                                    feature_maps_transform_data,
                                    transform_experiment,
                                    t,
                                    b,
                                    save_dir,
                                    save_opt='t'):
    """Project feature maps onto 3 PCA components and save them as an RGB image.

    A 3-component PCA is fitted on ``feature_maps_fit_data`` and applied to
    ``feature_maps_transform_data``. The projected rows are laid out on a
    square grid, min-max normalised per component, upscaled to 512 pixels
    with nearest-neighbour interpolation and written as a PNG.

    Args:
        feature_maps_fit_data: Tensor used to fit the PCA; assumed to be
            ``(num_samples, num_features)`` -- TODO confirm with callers.
        feature_maps_transform_data: Tensor to project; its number of rows
            must be a perfect square (one row per spatial location).
        transform_experiment: Prefix of the output file name.
        t: Time-step identifier used in the file name when ``save_opt='t'``.
        b: Identifier used in the file name when ``save_opt='b'``.
        save_dir: Directory the PNG is written to.
        save_opt: ``'t'`` or ``'b'`` selects the file-name pattern; any other
            value skips saving.

    Raises:
        ValueError: If the projected rows cannot be arranged on a square grid.
    """
    fit_data = feature_maps_fit_data.detach().cpu().numpy()
    transform_data = feature_maps_transform_data.detach().cpu().numpy()

    pca = PCA(n_components=3)
    pca.fit(fit_data)
    projected = pca.transform(transform_data).reshape(-1, 3)  # (H * W) x 3

    num_rows = projected.shape[0]
    side = int(sqrt(num_rows))
    if side * side != num_rows:
        raise ValueError(
            f"Cannot arrange {num_rows} feature rows on a square grid."
        )
    pca_img = projected.reshape(side, side, 3)

    # Per-component min-max normalisation. A constant component has zero
    # range; dividing by 1 instead maps it to 0 rather than producing NaNs.
    channel_min = pca_img.min(axis=(0, 1))
    channel_range = pca_img.max(axis=(0, 1)) - channel_min
    channel_range = np.where(channel_range == 0, 1.0, channel_range)
    pca_img = (pca_img - channel_min) / channel_range

    pca_img = Image.fromarray((pca_img * 255).astype(np.uint8))
    pca_img = T.Resize(512, interpolation=T.InterpolationMode.NEAREST)(pca_img)

    if save_opt == 't':
        pca_img.save(os.path.join(save_dir, f"{transform_experiment}_time_{t}.png"))
    elif save_opt == 'b':
        pca_img.save(os.path.join(save_dir, f"{transform_experiment}_{b}.png"))
    # Any other save_opt intentionally skips writing the image.


def numpy_to_pil(images):
    """Convert a numpy image or a batch of images to a list of PIL images.

    Args:
        images: Array of shape ``(H, W, C)`` or ``(N, H, W, C)``. Values are
            multiplied by 255 and cast to ``uint8``, so they are expected to
            lie in ``[0, 1]``.

    Returns:
        A list with one ``PIL.Image.Image`` per input image (a single image
        yields a one-element list).
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")

    if images.shape[-1] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the channel
        # axis so single-channel inputs become grayscale images.
        return [Image.fromarray(image.squeeze(-1)) for image in images]
    return [Image.fromarray(image) for image in images]


def transpose_img(img):
    """Move a tensor image to CPU and convert it to channel-last layout.

    A leading batch axis of size 1 is dropped, then a 3-D ``(3, H, W)``
    tensor is permuted to ``(H, W, 3)``. Any other shape is returned as is.

    Args:
        img: ``torch.Tensor`` image, e.g. ``(1, 3, H, W)`` or ``(3, H, W)``.

    Returns:
        The CPU tensor, channel-last when it is a 3-channel image.
    """
    img = img.cpu()
    if img.shape[0] == 1:
        img = img[0]
    # Use the tensor's own permute rather than np.transpose: the NumPy route
    # only works through an array round-trip and fails on tensors that
    # require grad or are not exactly 3-D.
    if img.ndim == 3 and img.shape[0] == 3:
        img = img.permute(1, 2, 0)
    return img

def show_images(images, num_cols=2, figsize=(8, 4), titles=None):
    """
    Display multiple images in a grid.

    Parameters:
    - images: torch.Tensor of shape (n, c, h, w) where n is the number of images.
    - num_cols: Number of columns in the grid (default is 2).
    - figsize: (width, height of one row) for Matplotlib (default is (8, 4));
      the height is multiplied by the number of rows.
    - titles: Optional sequence with one title per grid row; every image in a
      row gets that row's title. None or empty means no titles.

    Example:
    - show_images(images, num_cols=3)
    """
    # A mutable default would be shared across calls, so None is the sentinel.
    titles = [] if titles is None else titles

    images = images.detach().cpu()
    n, _, _, _ = images.shape  # also enforces a 4-D batch
    num_rows = (n + num_cols - 1) // num_cols  # Ceil division
    if len(titles) > 0:
        assert len(titles) == num_rows, 'The length of titles is inconsistent with the number of rows.'

    figsize = (figsize[0], figsize[1] * num_rows)
    # squeeze=False always yields a 2-D axes array, including 1x1 grids where
    # plt.subplots would otherwise return a bare Axes object.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)

    for i in range(num_rows * num_cols):
        ax = axes[i // num_cols, i % num_cols]
        if i >= n:
            # Remove grid cells that have no image.
            fig.delaxes(ax)
            continue
        if len(titles) > 0:
            ax.set_title(titles[i // num_cols])
        ax.imshow(images[i].permute(1, 2, 0).numpy())  # channel-last for display
        ax.axis('off')  # Turn off axis labels and ticks

    plt.subplots_adjust(wspace=0.1, hspace=0.1)  # Adjust spacing between subplots
    plt.show()
    

def embedding(feats, labels, n_pca=10):
    """Embed features into 2-D with PCA followed by t-SNE.

    PCA first reduces the features to a moderate dimension, as recommended in
    https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html,
    and t-SNE then maps the result to two dimensions.

    Args:
        feats: 2-D array-like of shape ``(n_samples, n_features)``.
        labels: Array-like with one label per sample; returned unchanged.
        n_pca: Target number of PCA components. It is capped at
            ``min(n_samples, n_features)`` because PCA cannot produce more.

    Returns:
        Tuple ``(feature_embedded, labels)`` where ``feature_embedded`` has
        shape ``(n_samples, 2)``.
    """
    assert feats.shape[0] == labels.shape[0]
    assert feats.shape[0] > 0

    # PCA raises if asked for more components than samples or features.
    n_components = min(n_pca, feats.shape[0], feats.shape[1])
    reduced = PCA(n_components=n_components).fit_transform(feats)
    feature_embedded = TSNE(n_components=2).fit_transform(reduced)
    return feature_embedded, labels

import os.path as osp
def plot_features(features, labels, num_classes, epoch, dirname, legend=True, add_circle=False, **kwargs):
    """Scatter-plot 2-D features coloured by class and save the figure as PNG.

    Args:
        features: Array-like of shape ``(n_samples, 2)`` (only the first two
            columns are plotted).
        labels: Array-like of integer class labels, one per sample.
        num_classes: Classes ``0 .. num_classes - 1`` are plotted.
        epoch: Used for the default file name ``epoch_{epoch}.png``.
        dirname: Output directory; created if it does not exist.
        legend: Whether to draw a class legend below the plot.
        add_circle: If True, draw a circle per class centred on the class
            mean with radius equal to the 90th percentile of the distances
            to that mean.
        **kwargs: ``title`` (figure title) and ``filename`` (file name
            without extension) are honoured when truthy.
    """
    colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9']
    markers = ['o', 's', '^', 'v', 'D', 'p', '*', 'X', '+', 'H']
    features = np.asarray(features)
    labels = np.asarray(labels)

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
    for label_idx in range(num_classes):
        # Cycle through the palette so any number of classes can be drawn.
        color = colors[label_idx % len(colors)]
        marker = markers[label_idx % len(markers)]
        class_instances = features[labels == label_idx]

        # Skip the circle for empty classes: their mean/quantile would be NaN.
        if add_circle and len(class_instances) > 0:
            center = class_instances[:, :2].mean(axis=0)
            distances = np.linalg.norm(class_instances[:, :2] - center, axis=1)
            radius = np.quantile(distances, 0.9)
            # Outline first, then a faint filled disc as background.
            for fill, alpha in ((False, 0.3), (True, 0.05)):
                ax.add_patch(
                    plt.Circle((center[0], center[1]), radius, color=color, fill=fill, alpha=alpha)
                )

        ax.scatter(
            class_instances[:, 0],
            class_instances[:, 1],
            c=color,
            marker=marker,
            s=40,  # Marker size
            label=str(label_idx),  # Use label as legend entry
            alpha=0.6,  # Adjust transparency for better visualization
        )

    ax.set_xticks([])  # Remove x-axis ticks and labels
    ax.set_yticks([])  # Remove y-axis ticks and labels
    if legend:
        ax.legend(loc='lower center', title='Classes', bbox_to_anchor=(0.5, -0.15), ncol=num_classes)
    if kwargs.get('title', None):
        ax.set_title(kwargs['title'])

    # exist_ok avoids a race between the existence check and the creation.
    os.makedirs(dirname, exist_ok=True)
    filename = kwargs.get('filename', None) or f'epoch_{epoch}'
    save_name = osp.join(dirname, f'{filename}.png')
    fig.savefig(save_name, bbox_inches='tight')
    print(f'Pic saved to {save_name}.')
    plt.show()
    plt.close(fig)
    


# Function to align true and predicted labels using the Hungarian algorithm
def align_predictions(y_true, y_pred):
    """Match predicted cluster ids to true class ids with the Hungarian algorithm.

    A square contingency table counts how often each predicted id co-occurs
    with each true id; the assignment maximising the total count is returned.

    Args:
        y_true: 1-D integer numpy array of ground-truth labels.
        y_pred: 1-D integer numpy array of predicted labels.

    Returns:
        Two lists ``(pred_ids, true_ids)`` of equal length; ``pred_ids[k]`` is
        assigned to ``true_ids[k]``.
    """
    n_labels = max(y_pred.max(), y_true.max()) + 1
    contingency = np.zeros((n_labels, n_labels), dtype=np.int64)
    for sample_idx, pred_label in enumerate(y_pred):
        contingency[pred_label, y_true[sample_idx]] += 1

    # The solver minimises cost, so flip counts into costs.
    cost = contingency.max() - contingency
    pred_ids, true_ids = linear_assignment(cost)
    return list(pred_ids), list(true_ids)


# Function to calculate class-wise accuracy
def classwise_accuracy(y_true, y_pred):
    """Return the per-class accuracy after Hungarian alignment of predictions.

    Args:
        y_true: 1-D integer numpy array of ground-truth labels.
        y_pred: 1-D integer numpy array of predicted (cluster) labels.

    Returns:
        A list whose ``k``-th entry is the fraction of samples of true class
        ``k`` that were predicted as the cluster aligned with class ``k``.
        Classes without any true sample yield ``nan``.
    """
    aligned_pred_indices, aligned_true_indices = align_predictions(y_true, y_pred)
    # The assignment is returned ordered by predicted id; sort by true class
    # so that position k of the result really describes class k.
    pairs = sorted(zip(aligned_true_indices, aligned_pred_indices))
    return [(y_pred[y_true == true_idx] == pred_idx).mean() for true_idx, pred_idx in pairs]

# Function to plot accuracies and distributions
def classwise_plot(y_true, y_pred, num_classes, disable_title=False, legend_fontsize=14, only_distribution=False):
    """Plot per-class accuracy and the predicted-vs-true class distribution.

    Predictions are first aligned to the true labels with the Hungarian
    algorithm (``align_predictions``).

    Args:
        y_true: 1-D integer numpy array of ground-truth labels.
        y_pred: 1-D integer numpy array of predicted (cluster) labels.
        num_classes: Classes ``0 .. num_classes - 1`` are shown.
        disable_title: If True, no subplot titles are drawn.
        legend_fontsize: Font size of the distribution legend.
        only_distribution: If True, only the distribution histogram is drawn.

    Returns:
        ``(fig, axes)`` as returned by ``plt.subplots``: ``axes`` is a single
        Axes when ``only_distribution`` is True, otherwise an array of two.
    """
    # Align predictions to true labels once and reuse the result below.
    aligned_pred_indices, aligned_true_indices = align_predictions(y_true, y_pred)
    aligned_preds = np.zeros_like(y_pred)
    for pred_idx, true_idx in zip(aligned_pred_indices, aligned_true_indices):
        aligned_preds[y_pred == pred_idx] = true_idx

    classes = np.arange(num_classes)

    if only_distribution:
        fig, axes = plt.subplots(ncols=1, figsize=(5, 3))
        dist_ax = axes
    else:
        fig, axes = plt.subplots(ncols=2, figsize=(10, 3))
        dist_ax = axes[1]

        # Accuracy indexed by true class, exactly one value per plotted class.
        # Classes without samples get 0.0 (no bar) instead of a NaN mean.
        accuracies = []
        for cls in classes:
            in_class = y_true == cls
            if in_class.any():
                accuracies.append((aligned_preds[in_class] == cls).mean())
            else:
                accuracies.append(0.0)

        axes[0].bar(classes, accuracies, color='skyblue')
        axes[0].set_xlabel('Classes')
        axes[0].set_ylabel('Accuracy')
        axes[0].set_ylim(0, 1)  # Accuracies range from 0 to 1
        if not disable_title:
            axes[0].set_title('Accuracy for each class')

    # One unit-width bin centred on every class index.
    bins = np.arange(-0.5, num_classes, 1)

    # Combined histogram for predictions and targets
    dist_ax.hist(aligned_preds, bins=bins, color='skyblue', alpha=0.6, label='Aligned Predictions')
    dist_ax.hist(y_true, bins=bins, color='salmon', alpha=0.5, label='Targets')
    dist_ax.set_xlabel('Class', fontsize=16)
    dist_ax.set_ylabel('Frequency', fontsize=16)
    dist_ax.legend(loc='upper left', fontsize=legend_fontsize)
    dist_ax.grid(True)
    if not disable_title:
        dist_ax.set_title('True vs Predicted Class Distribution', fontsize=16)

    plt.tight_layout()
    plt.show()

    return fig, axes

def plot_class_prediction_distribution(y_true, y_pred, class_info, ds_name=None, method_name=''):
    """Plot aligned-prediction and ground-truth class histograms.

    Predictions are aligned to the true labels with the Hungarian algorithm.
    Classes are then re-indexed so known classes come first and novel classes
    second, each group sorted by descending aligned-prediction count.

    Args:
        y_true: 1-D integer numpy array of ground-truth labels.
        y_pred: 1-D integer numpy array of predicted (cluster) labels.
        class_info: Mapping with ``'known_classes'`` and ``'unknown_classes'``
            (the latter are treated as the novel classes).
        ds_name: Optional dataset name appended to the title.
        method_name: Optional method name prefixed to the prediction legend
            entry.

    Returns:
        ``(remapped_y_true, remapped_aligned_preds, fig)``.
    """
    # Align predictions to true labels using Hungarian algorithm
    aligned_pred_indices, aligned_true_indices = align_predictions(y_true, y_pred)
    aligned_preds = y_pred.copy()
    for pred_idx, true_idx in zip(aligned_pred_indices, aligned_true_indices):
        aligned_preds[y_pred == pred_idx] = true_idx

    known_classes = class_info['known_classes']
    novel_classes = class_info['unknown_classes']
    num_total = len(known_classes) + len(novel_classes)

    def sorted_by_prediction_count(class_ids):
        # Most frequently predicted class first; sorted() is stable, so ties
        # keep their original relative order.
        counts = {cls: np.sum(aligned_preds == cls) for cls in class_ids}
        return [cls for cls, _ in sorted(counts.items(), key=lambda item: item[1], reverse=True)]

    # Known classes occupy indices [0, len(known)); novel classes follow.
    known_mapping = {cls: i for i, cls in enumerate(sorted_by_prediction_count(known_classes))}
    novel_mapping = {
        cls: i + len(known_classes)
        for i, cls in enumerate(sorted_by_prediction_count(novel_classes))
    }
    combined_mapping = {**known_mapping, **novel_mapping}

    # Remap y_true and aligned_preds to sorted indices
    remapped_y_true = np.vectorize(combined_mapping.get)(y_true)
    remapped_aligned_preds = np.vectorize(combined_mapping.get)(aligned_preds)

    fig, ax = plt.subplots(figsize=(10, 5))

    # One unit-width bin centred on every remapped class index.
    bins = np.arange(-0.5, num_total, 1)

    # An empty method_name would otherwise leave a leading space in the label.
    pred_label = f'{method_name} Predictions'.strip()
    ax.hist(remapped_aligned_preds, bins=bins, color='skyblue', alpha=0.6, label=pred_label)
    ax.hist(remapped_y_true, bins=bins, color='salmon', alpha=0.5, label='GT')

    ax.set_xlabel('Class (sorted by prediction count within known and novel)', fontsize=26)
    ax.set_ylabel('Frequency', fontsize=25)
    # Only mention the dataset when one was given (avoids a "(None)" title).
    title = 'Class Prediction Distribution'
    if ds_name is not None:
        title = f'{title} ({ds_name})'
    ax.set_title(title, fontsize=26)
    # `ncol` is accepted by every Matplotlib version; `ncols` needs >= 3.6.
    ax.legend(loc='upper left', fontsize=26, ncol=2)
    ax.grid(True)

    # Replace the automatic x ticks with 11 evenly spaced integer labels
    # spanning the same range.
    locs = ax.get_xticks()
    new_locs = np.linspace(start=locs[0], stop=locs[-1], num=11)
    new_labels = [f'{int(new_loc)}' for new_loc in new_locs]
    ax.set_xticks(new_locs)
    ax.set_xticklabels(new_labels, fontsize=22)
    ax.set_xlim(-5, num_total + 5)

    # Increase y-ticks font size
    ax.tick_params(axis='y', labelsize=22)

    plt.tight_layout()
    plt.show()

    return remapped_y_true, remapped_aligned_preds, fig


def naive_classwise_plot(preds, targets, num_classes):
    """Plot per-class accuracy and class histograms without label alignment.

    Args:
        preds: 1-D numpy array of predicted labels.
        targets: 1-D numpy array of ground-truth labels.
        num_classes: Classes ``0 .. num_classes - 1`` are shown.
    """
    # Fraction of samples of each class that were predicted as that class.
    per_class_accuracy = []
    for cls in range(num_classes):
        hits = preds[targets == cls] == cls
        per_class_accuracy.append(hits.mean())

    fig, (acc_ax, dist_ax) = plt.subplots(ncols=2, figsize=(10, 3))

    # Left panel: accuracy bar chart.
    class_ids = np.arange(num_classes)
    acc_ax.bar(class_ids, per_class_accuracy, color='skyblue')
    acc_ax.set_xlabel('Classes')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.set_title('Accuracy for each class')
    acc_ax.set_ylim(0, 1)  # accuracies live in [0, 1]

    # Right panel: overlaid histograms with one unit-width bin per class.
    bin_edges = np.arange(-0.5, num_classes, 1)
    dist_ax.hist(preds, bins=bin_edges, color='skyblue', alpha=0.6, label='Predictions')
    dist_ax.hist(targets, bins=bin_edges, color='salmon', alpha=0.5, label='Targets')

    dist_ax.set_xlabel('Class', fontsize=26)
    dist_ax.set_ylabel('Frequency', fontsize=16)
    dist_ax.set_title('True vs Predicted Class Distribution', fontsize=16)
    dist_ax.legend(loc='upper left', fontsize=16)
    dist_ax.grid(True)

    plt.tight_layout()
    plt.show()
    
    
def calculate_class_coverage_rate(pred, targets, class_info):
    """Compute the mean per-class coverage score for known and unknown classes.

    For every class, with ``p`` predictions and ``t`` targets, two rates are
    formed: ``min(p, t) / p`` and ``min(p, t) / t`` (each 0 when its
    denominator is 0). The class score is their harmonic mean, defined as 0
    when both rates are 0.

    Args:
        pred: Iterable of predicted class labels.
        targets: Iterable of ground-truth class labels.
        class_info: Mapping with ``'known_classes'`` and ``'unknown_classes'``
            sequences of class ids. Labels outside these are ignored.

    Returns:
        ``(known_average_coverage, unknown_average_coverage)``; each is 0 when
        the corresponding class list is empty.
    """
    known_classes = class_info['known_classes']
    unknown_classes = class_info['unknown_classes']
    all_classes = np.concatenate([known_classes, unknown_classes])

    def count_per_class(values):
        # Count occurrences of tracked classes only.
        counts = {cls: 0 for cls in all_classes}
        for value in values:
            if value in counts:
                counts[value] += 1
        return counts

    pred_count = count_per_class(pred)
    target_count = count_per_class(targets)

    def class_score(cls):
        n_pred, n_target = pred_count[cls], target_count[cls]
        overlap = min(n_pred, n_target)
        rate_pred = overlap / n_pred if n_pred > 0 else 0
        rate_target = overlap / n_target if n_target > 0 else 0
        denominator = rate_pred + rate_target
        # Both rates are 0 when the class is never predicted or never
        # present; the harmonic mean is then taken to be 0 instead of 0/0.
        if denominator == 0:
            return 0.0
        return (2 * rate_pred * rate_target) / denominator

    known_rates = [class_score(cls) for cls in known_classes]
    unknown_rates = [class_score(cls) for cls in unknown_classes]

    # Avoid division by zero if there are no known or unknown classes
    known_average_coverage = sum(known_rates) / len(known_rates) if known_rates else 0
    unknown_average_coverage = sum(unknown_rates) / len(unknown_rates) if unknown_rates else 0

    return known_average_coverage, unknown_average_coverage