import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import sys, os
project_root = os.path.abspath(os.getcwd())
sys.path.append(project_root)

import torch
import torch.nn.functional as F
from matplotlib.lines import Line2D
from sklearn.manifold import TSNE

def _resize_and_center_crop(image, resize_size=256, crop_size=224):
    """Resize a PIL image so its longer side equals ``resize_size``, then center crop.

    Args:
        image: PIL image to process.
        resize_size: Target length (in pixels) of the longer image side.
        crop_size: Side length (in pixels) of the square center crop.

    Returns:
        A ``crop_size`` x ``crop_size`` PIL image. Because only the longer side
        is scaled to ``resize_size``, the crop box may extend past the resized
        image along the shorter side.
    """
    width, height = image.size
    scale = resize_size / max(width, height)
    # Clamp to 1 px so extremely elongated images never produce a zero-length side.
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    resized = image.resize(new_size, resample=Image.BICUBIC)

    left = (new_size[0] - crop_size) // 2
    top = (new_size[1] - crop_size) // 2
    return resized.crop((left, top, left + crop_size, top + crop_size))


def visualize_segmentation(image, foreground_intensity, save_path, patch_size):
    """Visualize segmentation results by overlaying a heatmap based on probability intensity.

    The image is converted to RGB, resized and center cropped to 224x224, and the
    intensity map is bilinearly upsampled to the same size, min-max normalized and
    blended (jet colormap) over every pixel whose normalized intensity is > 0.

    Args:
        image: Input image (PIL Image or numpy array accepted by ``Image.fromarray``).
        foreground_intensity: Foreground probability intensity (numpy array or
            tensor). A 1-D input is reshaped to a square grid, so its length
            must be a perfect square; a 2-D input is used as-is.
        save_path: Path to save the visualization result.
        patch_size: Size of patches in the model. Currently unused; kept so
            existing callers keep working.

    Returns:
        None, saves the visualization to the specified path.
    """
    # NOTE(review): these sizes are said to mirror the CLIP preprocessing in
    # CLIP_utils - confirm they stay in sync with that module.
    resize_size = 256
    image_size = 224
    alpha = 0.7  # Opacity of the heatmap where it is applied

    # Both input kinds share one preprocessing path. Converting to RGB ensures
    # the array is (H, W, 3) even for grayscale, palette or RGBA inputs.
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    pil_image = pil_image.convert("RGB")
    image_np = np.array(_resize_and_center_crop(pil_image, resize_size, image_size))

    # detach() lets tensors that are still attached to an autograd graph be
    # converted to numpy further down.
    intensity_tensor = torch.as_tensor(foreground_intensity).detach().float()

    # Ensure intensity is a 2D tensor (e.g., 7x7)
    if intensity_tensor.dim() == 1:
        size = int(np.sqrt(intensity_tensor.shape[0]))
        intensity_tensor = intensity_tensor.reshape(size, size)

    # F.interpolate expects (batch, channel, H, W)
    intensity_tensor = intensity_tensor.unsqueeze(0).unsqueeze(0)
    intensity_mask = F.interpolate(
        intensity_tensor,
        size=(image_size, image_size),
        mode="bilinear",
        align_corners=False,
    )
    intensity_mask = intensity_mask[0, 0].cpu().numpy()

    # Min-max normalize to [0, 1]. maps with no positive value are left
    # untouched (nothing gets overlaid below).
    peak = intensity_mask.max()
    if peak > 0:
        floor = intensity_mask.min()
        span = peak - floor
        if span > 0:
            intensity_mask = (intensity_mask - floor) / span
        else:
            # A constant map has no contrast; avoid 0/0 -> NaN and treat it
            # as "nothing to highlight".
            intensity_mask = np.zeros_like(intensity_mask)

    # Blue-to-red heatmap; drop the alpha channel returned by the colormap.
    heatmap = plt.cm.jet(intensity_mask)[:, :, :3]

    # Blend only where the normalized intensity is strictly positive.
    mask = intensity_mask > 0
    blended = (alpha * (heatmap * 255) + (1 - alpha) * image_np).astype(np.uint8)
    overlay = np.where(mask[:, :, None], blended, image_np)

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.axis('off')
        ax.imshow(overlay)
        fig.tight_layout()
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
    
def visual_segmentation_process(path, sample_foreground, save_path="text.png"):
    """Process and visualize segmentation for an image.

    Opens the image at ``path`` and forwards it, together with the foreground
    intensity, to ``visualize_segmentation`` with ``patch_size=14``.

    Args:
        path: Path to the input image
        sample_foreground: Foreground intensity data for segmentation
        save_path: Path to save the visualization (default: "text.png")

    Returns:
        save_path: Path where the visualization was saved
    """
    # The context manager closes the underlying file handle once the
    # visualization has been written.
    with Image.open(path) as sample_image:
        visualize_segmentation(sample_image, sample_foreground, save_path, patch_size=14)
    return save_path

def visualize_bar(class_pool, val, head_num, tail_num, color='blue', save_path="others/prompt/bar.png", type=None):
    """Visualize bar charts for the top and bottom classes based on values.

    Creates a split ("broken axis") bar chart: the left panel shows the
    highest scoring classes, the right panel the lowest scoring ones.

    Args:
        class_pool: List of class names, indexable by class id
        val: 1-D tensor of values for each class
        head_num: Number of top classes to display
        tail_num: Number of bottom classes to display (0 shows none)
        color: Bar color (default: 'blue')
        save_path: Path to save the visualization
        type: Type of visualization. "projection_norm" and "cos_sim" apply
            preset class exclusions and y-limits; ``None`` applies no
            exclusions and lets matplotlib choose the y-limits. (The name
            shadows the builtin but is kept for caller compatibility.)

    Returns:
        None, saves the visualization to the specified path

    Raises:
        ValueError: If ``type`` is neither ``None`` nor a supported preset.
    """
    # Preset per visualization type: series label, class names hidden from
    # the head/tail panels, and the fixed y-range.
    presets = {
        "projection_norm": {
            "label": "Projection Norm",
            "head_exclude": ("wool", "mountain"),
            "tail_exclude": ("limpkin",),
            "ylim": (1.5, 8),
        },
        "cos_sim": {
            "label": "Cosine Similarity",
            "head_exclude": ("wool",),
            "tail_exclude": ("limpkin",),
            "ylim": (0.7, 1),
        },
    }
    if type is None:
        preset = {"label": None, "head_exclude": (), "tail_exclude": (), "ylim": None}
    elif type in presets:
        preset = presets[type]
    else:
        raise ValueError(
            f"Unsupported type {type!r}; expected None or one of {sorted(presets)}"
        )

    # Class ids ranked from highest to lowest value.
    ranking = torch.argsort(val, dim=-1, descending=True).tolist()
    head_ids = ranking[:head_num]
    # ranking[-0:] would return the whole ranking, so handle 0 explicitly.
    tail_ids = ranking[-tail_num:] if tail_num > 0 else []

    head_ids = [i for i in head_ids if class_pool[i] not in preset["head_exclude"]]
    tail_ids = [i for i in tail_ids if class_pool[i] not in preset["tail_exclude"]]

    class_names_head = [class_pool[i] for i in head_ids]
    class_values_head = [round(val[i].item(), 2) for i in head_ids]
    class_names_tail = [class_pool[i] for i in tail_ids]
    class_values_tail = [round(val[i].item(), 2) for i in tail_ids]

    n_head, n_tail = len(class_names_head), len(class_names_tail)
    x_values = np.arange(n_head + n_tail)

    # Two subplots sharing the y-axis, each as wide as its bar count. The
    # ratio is clamped to 1 so an empty panel never yields a zero-width axis.
    fig, (ax1, ax2) = plt.subplots(
        1,
        2,
        sharey=True,
        gridspec_kw={'width_ratios': [max(n_head, 1), max(n_tail, 1)]},
        figsize=(6, 2),
        dpi=300,
    )
    try:
        ax1.bar(x_values[:n_head], class_values_head, color=color)
        ax1.set_xticks(x_values[:n_head])
        ax1.set_xticklabels(class_names_head, rotation=45, ha="right", fontsize=7)

        ax2.bar(x_values[n_head:], class_values_tail, color=color, label=preset["label"])
        ax2.set_xticks(x_values[n_head:])
        ax2.set_xticklabels(class_names_tail, rotation=45, ha="right", fontsize=7)

        ax1.xaxis.tick_bottom()
        ax2.xaxis.tick_bottom()

        if preset["ylim"] is not None:
            ax1.set_ylim(*preset["ylim"])
            ax2.set_ylim(*preset["ylim"])

        # Open the facing sides of the panels so they read as one broken axis.
        ax1.spines['right'].set_visible(False)
        ax2.spines['left'].set_visible(False)

        # Only the left panel shows y tick labels.
        ax1.tick_params(axis='y', labelright=False, right=False, labelsize=7)
        ax2.tick_params(axis='y', labelleft=False, left=False, labelright=False, labelsize=7)

        # Diagonal break markers on the bottom edge where the panels meet.
        kwargs = dict(marker=[(-1, -1), (1, 1)], markersize=8, linestyle='none', color='k', mec='k', mew=0.5, clip_on=False)
        ax1.plot([1], [0], transform=ax1.transAxes, **kwargs)
        ax2.plot([0], [0], transform=ax2.transAxes, **kwargs)

        # Adjust spacing between subplots
        fig.subplots_adjust(wspace=0.05)

        ax1.spines['top'].set_visible(False)
        ax2.spines['top'].set_visible(False)

        ax1.spines['bottom'].set_linewidth(0.5)
        ax1.spines['left'].set_linewidth(0.5)
        ax2.spines['bottom'].set_linewidth(0.5)
        ax2.spines['right'].set_linewidth(0.5)
        ax1.tick_params(axis='both', width=0.5)
        ax2.tick_params(axis='both', width=0.5)

        # The individual top spines are hidden above; draw one continuous top
        # border in figure coordinates spanning both panels (across the gap).
        bbox1 = ax1.get_position()
        bbox2 = ax2.get_position()
        top_y = bbox1.y1
        line = Line2D(
            [bbox1.x0, bbox2.x1],
            [top_y, top_y],
            color="black",
            linewidth=0.5,
            transform=fig.transFigure,
            clip_on=False,
            zorder=10,
        )
        fig.add_artist(line)

        fig.savefig(save_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def visualize_class_probility(labels_name, label_predict_probility, save_path):
    """Render class probabilities as a ranked horizontal bar chart and save it.

    Classes are ordered from highest to lowest value, one bar per class, with
    the class name and its value (two decimals) written inside the row.

    Args:
        labels_name: List of class names
        label_predict_probility: Probabilities for each class, aligned with
            ``labels_name``
        save_path: Path to save the visualization

    Returns:
        None, saves the visualization to the specified path
    """
    fig, ax = plt.subplots(figsize=(3, 3), dpi=600)

    # Rank classes by descending value.
    order = np.argsort(label_predict_probility)[::-1]
    ranked_labels = [labels_name[idx] for idx in order]
    ranked_probs = [label_predict_probility[idx] for idx in order]
    n_items = len(ranked_labels)

    for row, (label, prob) in enumerate(zip(ranked_labels, ranked_probs)):
        # Values above 10 are highlighted in green, the rest drawn in blue.
        if prob > 10:
            bar_color = '#90EE90'
        else:
            bar_color = '#A6C8E8'
        ax.barh(row, prob, color=bar_color, alpha=1.0)

        # Right-aligned caption at a fixed x position inside the plot area.
        ax.text(21, row, f"{label} {prob:.2f}", va='center', ha='right', fontsize=17, color='black')

    # Inverted y-range puts the largest value on the top row; x-range is fixed.
    ax.set_ylim(n_items - 0.5, -0.5)
    ax.set_xlim(0, 25)

    # No ticks on either axis.
    ax.set_yticks([])
    ax.set_xticks([])

    # Black frame around the plot.
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_color('black')
        spine.set_linewidth(0.8)

    fig.tight_layout()
    fig.savefig(save_path, dpi=600, bbox_inches='tight', pad_inches=0.1)
    plt.close(fig)