import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.manifold import TSNE
import numpy as np
import os
import torch
from torch_geometric.data import Batch

def extract_vision_features(model, dataloader, device):
    """
    Extract feature embeddings from a vision model over a whole dataloader.

    Args:
        model: Module exposing ``extract_features(images)``. It is moved to
            ``device`` and put in eval mode before extraction.
        dataloader: Iterable yielding ``(images, labels)`` batches of tensors.
        device: Device on which inference runs.

    Returns:
        Tuple ``(features, labels)``. ``features`` holds the per-batch
        outputs stacked row-wise with ``np.vstack``. ``labels`` holds every
        batch's labels concatenated into one NumPy array.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Match extract_gnn_features: the inputs are moved to `device` below, so
    # the model has to live there too or the forward pass hits a device mismatch.
    model = model.to(device)
    model.eval()
    all_features = []
    all_labels = []

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)

            features = model.extract_features(images)
            all_features.append(features.cpu().numpy())
            # .cpu() is a no-op for CPU tensors, but it avoids a TypeError if
            # the loader hands back labels already on an accelerator.
            all_labels.extend(labels.cpu().numpy())

    if not all_features:
        raise ValueError("dataloader yielded no batches; nothing to extract")

    return np.vstack(all_features), np.array(all_labels)

def extract_gnn_features(model, dataloader, device):
    """
    Run a graph model over every batch and collect embeddings and targets.

    Args:
        model: Module exposing ``extract_features(batch)``. It is moved to
            ``device`` and put in eval mode.
        dataloader: Iterable of batch objects that support ``.to(device)``
            and carry their targets in a ``.y`` tensor.
        device: Device on which inference runs.

    Returns:
        Tuple ``(features, labels)``. Features are stacked row-wise with
        ``np.vstack``. Labels are the concatenated ``.y`` values as a NumPy
        array.
    """
    model = model.to(device)
    model.eval()

    feature_chunks, label_values = [], []
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            embedded = model.extract_features(batch)
            feature_chunks.append(embedded.cpu().numpy())
            label_values.extend(batch.y.cpu().numpy())

    return np.vstack(feature_chunks), np.array(label_values)

def evaluate_with_class_metrics(preds, labels, class_names):
    """
    Compute per-class classification metrics and the confusion matrix.

    Args:
        preds: Predicted labels, one per sample.
        labels: Ground-truth labels, one per sample.
        class_names: Display names for the classes. When ``labels``/``preds``
            are integer class indices, ``class_names[i]`` names class ``i``.

    Returns:
        ``(report, cm)``. ``report`` is the dict from
        ``classification_report(..., output_dict=True)``, keyed by class
        name. ``cm`` is the confusion matrix from ``confusion_matrix``.
    """
    y_true = np.asarray(labels)
    y_pred = np.asarray(preds)
    observed = np.union1d(y_true, y_pred)

    # Without an explicit label list, sklearn only knows the classes that
    # actually occur. If some class is absent from both arrays,
    # classification_report raises ("Number of classes ... does not match
    # size of target_names"). The confusion matrix would also be smaller
    # than class_names, so the heatmap would be mislabeled. Pin the full
    # index range, but only when every observed value is a valid 1-D class
    # index. Any other input keeps sklearn's default behavior.
    label_ids = None
    if (y_true.ndim == 1 and y_pred.ndim == 1
            and observed.size > 0
            and np.issubdtype(observed.dtype, np.integer)
            and observed.min() >= 0
            and observed.max() < len(class_names)):
        label_ids = np.arange(len(class_names))

    report = classification_report(labels, preds,
                                   labels=label_ids,
                                   target_names=class_names,
                                   output_dict=True)

    cm = confusion_matrix(labels, preds, labels=label_ids)

    return report, cm

def visualize_confusion_matrix(cm, class_names, title, save_path):
    """
    Render a confusion matrix as an annotated heatmap and save it to disk.

    Args:
        cm: 2-D confusion matrix. Rows are plotted against the "True Label"
            axis and columns against the "Predicted Label" axis. Both integer
            counts and float (e.g. normalized) values are supported.
        class_names: Tick labels for both axes, in row/column order.
        title: Figure title.
        save_path: Output path passed to ``plt.savefig``.

    Note:
        Updates the global ``plt.rcParams`` font settings as a side effect.
    """
    plt.rcParams.update({
        'font.family': 'sans-serif',
        'font.weight': 'bold',
        'font.size': 14
    })

    # The 'd' format code only accepts integers. A normalized (float) matrix
    # would make the annotation step raise, so fall back to two decimals.
    annot_fmt = 'd' if np.issubdtype(np.asarray(cm).dtype, np.integer) else '.2f'

    cmap = sns.color_palette("Blues", as_cmap=True)
    fig, ax = plt.subplots(figsize=(12, 10), dpi=100)
    sns.heatmap(cm, annot=True, fmt=annot_fmt, cmap=cmap,
                xticklabels=class_names, yticklabels=class_names,
                linewidths=0.5, linecolor='white', cbar_kws={"shrink": 0.8}, ax=ax)

    plt.title(title, fontsize=24, fontweight='bold', pad=20)
    plt.xlabel('Predicted Label', fontsize=20, fontweight='bold', labelpad=15)
    plt.ylabel('True Label', fontsize=20, fontweight='bold', labelpad=15)

    plt.xticks(fontsize=14, fontweight='bold', rotation=45, ha='right')
    plt.yticks(fontsize=14, fontweight='bold')

    # Draw a solid black frame around the heatmap.
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(2)
        spine.set_color('black')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    # Close the figure created here, not whichever figure happens to be current.
    plt.close(fig)

def create_tsne_visualization(features, labels, class_names, model_name, difficulty, save_path):
    """
    Project feature embeddings to 2-D with t-SNE and save a scatter plot.

    Args:
        features: 2-D array of embeddings, one row per sample.
        labels: Class index per sample. Samples with ``labels == i`` are
            drawn in the series named ``class_names[i]``.
        class_names: Legend names, indexed by class id.
        model_name: Model key. Known keys (case-insensitive) are mapped to a
            display name; anything else is shown unchanged.
        difficulty: Text appended to the title in parentheses.
        save_path: Output path passed to ``plt.savefig``.

    Note:
        Updates the global ``plt.rcParams`` font settings as a side effect.
    """
    plt.rcParams.update({
        'font.family': 'sans-serif',
        'font.weight': 'bold',
        'font.size': 14
    })

    model_name_mapping = {
        'gat': 'GAT',
        'gcn': 'GCN',
        'gin': 'GIN',
        'gps': 'GPS',
        'resnet': 'ResNet',
        'swin': 'Swin',
        'vit': 'ViT',
        'convnext': 'ConvNeXtV2'
    }

    display_name = model_name_mapping.get(model_name.lower(), model_name)
    title = f"{display_name} - Feature Space ({difficulty})"

    # Fit t-SNE exactly once. It is the expensive step, and the previous
    # version ran it twice with identical settings. Perplexity must be
    # smaller than the number of samples, hence the cap.
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(features)-1))
    embeddings = tsne.fit_transform(features)

    # Every entry is distinct. The earlier palette repeated '#6A1B9A' and
    # '#D84315', which made those class pairs indistinguishable.
    custom_colors = [
        '#D32F2F', '#00897B', '#2C3E50', '#F1C40F', '#303F9F',
        '#26A69A', '#F57C00', '#6A1B9A', '#0288D1', '#689F38',
        '#C62828', '#673AB7', '#D84315', '#37474F', '#E53935',
        '#0277BD', '#EF6C00', '#8D6E63', '#C2185B', '#00838F',
        '#1565C0', '#E65100', '#2E7D32', '#0097A7', '#9E9D24',
        '#AD1457', '#004D40', '#F9A825', '#880E4F', '#00695C'
    ]

    if len(class_names) > len(custom_colors):
        additional_colors = sns.color_palette("husl", n_colors=len(class_names)-len(custom_colors))
        colors = custom_colors + additional_colors
    else:
        colors = custom_colors[:len(class_names)]

    # An array is required for an element-wise `labels == i` mask. With a
    # plain list the comparison would yield a single bool.
    labels = np.asarray(labels)

    fig, ax = plt.subplots(figsize=(12, 10))

    ax.set_facecolor('white')
    ax.grid(color='lightgray', linestyle='-', linewidth=1, alpha=1)

    for i, class_name in enumerate(class_names):
        indices = labels == i
        plt.scatter(embeddings[indices, 0], embeddings[indices, 1],
                    label=class_name, alpha=0.85, s=100,
                    color=colors[i], edgecolors='white', linewidth=0.5)

    plt.title(title, fontsize=26, fontweight='bold', pad=20)
    plt.xlabel('t-SNE Dimension 1', fontsize=20, fontweight='bold', labelpad=15)
    plt.ylabel('t-SNE Dimension 2', fontsize=20, fontweight='bold', labelpad=15)

    legend = plt.legend(loc='best', fontsize=20, frameon=True,
                        facecolor='white', edgecolor='black',
                        framealpha=0.9, title='Classes')
    legend.get_title().set_fontweight('bold')
    legend.get_title().set_fontsize(20)

    # Draw a solid black frame around the plot area.
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(2)
        spine.set_color('black')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    # Close the figure created here, not whichever figure happens to be current.
    plt.close(fig)

def analyze_difficult_classes(report, class_names, save_dir):
    """
    Rank classes by F1 score and save a bar chart of the scores.

    Args:
        report: Dict from ``classification_report(..., output_dict=True)``.
            It must contain an entry with an ``'f1-score'`` key for every
            name in ``class_names``.
        class_names: Classes to rank and plot.
        save_dir: Directory that receives ``f1_scores_by_class.pdf``.

    Returns:
        List of ``(class_name, f1_score)`` pairs sorted by ascending F1, so
        the hardest classes come first.

    Note:
        Updates the global ``plt.rcParams`` font settings as a side effect.
    """
    font_settings = {
        'font.family': 'sans-serif',
        'font.weight': 'bold',
        'font.size': 14
    }
    plt.rcParams.update(font_settings)

    f1_by_class = {}
    for name in class_names:
        f1_by_class[name] = report[name]['f1-score']

    def _score(pair):
        return pair[1]

    # Callers get the ascending ranking. The chart shows best-first instead.
    # Both are independent stable sorts, so ties keep their insertion order.
    hardest_first = sorted(f1_by_class.items(), key=_score)

    _, ax = plt.subplots(figsize=(14, 8))
    best_first = sorted(f1_by_class.items(), key=_score, reverse=True)
    classes, scores = zip(*best_first)
    bar_palette = sns.color_palette("viridis", len(classes))
    bar_ax = sns.barplot(x=list(classes), y=list(scores), palette=bar_palette, ax=ax)

    # Print each bar's score just above its top edge.
    for idx, patch in enumerate(bar_ax.patches):
        label_x = patch.get_x() + patch.get_width()/2.
        label_y = patch.get_height() + 0.02
        bar_ax.text(label_x, label_y, f'{scores[idx]:.2f}',
                    ha='center', fontsize=12, fontweight='bold', color='black')

    plt.title('F1 Score by Class', fontsize=24, fontweight='bold', pad=20)
    plt.xlabel('Class', fontsize=20, fontweight='bold', labelpad=15)
    plt.ylabel('F1 Score', fontsize=20, fontweight='bold', labelpad=15)

    ax.yaxis.grid(color='gray', linestyle='--', linewidth=0.7, alpha=0.3)

    plt.xticks(rotation=45, ha='right', fontsize=14, fontweight='bold')
    plt.yticks(fontsize=14, fontweight='bold')

    # Leave headroom above the tallest bar for its score label.
    plt.ylim(0, max(scores) + 0.1)

    for side in ax.spines.values():
        side.set_visible(True)
        side.set_linewidth(2)
        side.set_color('black')

    plt.tight_layout()
    out_path = os.path.join(save_dir, 'f1_scores_by_class.pdf')
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close()

    return hardest_first