import matplotlib.pyplot as plt
import numpy as np
import torch
import networkx as nx
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns


def visualize_congestion_map(grid_congestion, grid_size, prediction=None, title=None, save_path=None):
    """
    Visualize grid-based congestion map.

    With only ``grid_congestion`` a single heatmap is drawn. When ``prediction``
    is also given, three panels are drawn: ground truth, prediction, and their
    absolute difference. The ground-truth and prediction panels share one
    colour scale so that equal colours mean equal values across the two.

    Args:
        grid_congestion (torch.Tensor or np.ndarray or array-like): Grid congestion
            values; must contain exactly ``M * N`` elements.
        grid_size (tuple): Size of the grid (M, N)
        prediction (torch.Tensor or np.ndarray or array-like, optional): Predicted grid
            congestion values; must contain exactly ``M * N`` elements.
        title (str, optional): Figure-level title
        save_path (str, optional): Path to save the plot

    Returns:
        matplotlib.figure.Figure: Figure object
    """
    def _to_float_array(values):
        # .float() first so half/bfloat16 tensors (unsupported by .numpy()) work;
        # float dtype also prevents unsigned-integer wraparound in the difference.
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().float().numpy()
        return np.asarray(values, dtype=float)

    M, N = grid_size

    # Reshape congestion values to grid
    ground_truth = _to_float_array(grid_congestion).reshape(M, N)

    if prediction is not None:
        pred_grid = _to_float_array(prediction).reshape(M, N)
        diff_grid = np.abs(ground_truth - pred_grid)

        # Common colour range for ground truth and prediction; independent
        # autoscaling would make a mis-scaled prediction look identical.
        vmin = float(min(np.nanmin(ground_truth), np.nanmin(pred_grid)))
        vmax = float(max(np.nanmax(ground_truth), np.nanmax(pred_grid)))

        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        im0 = axes[0].imshow(ground_truth, cmap='hot', interpolation='nearest', vmin=vmin, vmax=vmax)
        axes[0].set_title('Ground Truth Congestion')
        fig.colorbar(im0, ax=axes[0])

        im1 = axes[1].imshow(pred_grid, cmap='hot', interpolation='nearest', vmin=vmin, vmax=vmax)
        axes[1].set_title('Predicted Congestion')
        fig.colorbar(im1, ax=axes[1])

        # The error panel keeps its own scale: it is a different quantity.
        im2 = axes[2].imshow(diff_grid, cmap='viridis', interpolation='nearest')
        axes[2].set_title('Absolute Difference')
        fig.colorbar(im2, ax=axes[2])
    else:
        fig, ax = plt.subplots(figsize=(10, 8))

        im = ax.imshow(ground_truth, cmap='hot', interpolation='nearest')
        ax.set_title('Congestion Map')
        fig.colorbar(im, ax=ax)

    if title:
        fig.suptitle(title, fontsize=16)

    # Operate on this figure explicitly rather than pyplot's "current" figure.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig


def visualize_bottleneck_subgraph(cell_hypergraph, cell_probs, cell_congestion=None, threshold=0.5,
                                  title=None, save_path=None):
    """
    Visualize cell-based bottleneck subgraph.

    Every cell ``i`` becomes a node ``cell_{i}``. If ``cell_hypergraph`` has a
    ``('cell', 'to', 'net')`` edge type, each net ``j`` becomes a square node
    ``net_{j}`` and every cell-net incidence becomes an edge. Nodes are placed
    with a seeded spring layout. Cells whose probability is strictly greater
    than ``threshold`` are bottlenecks: they are drawn larger and labelled.

    Node colouring:
      * without ``cell_congestion``: bottleneck cells red, other cells sky blue;
      * with ``cell_congestion``: cells are filled by their congestion value
        (viridis, with a colorbar) and bottleneck cells get a red outline.

    Args:
        cell_hypergraph (HeteroData): Cell-based hypergraph
        cell_probs (torch.Tensor or np.ndarray): Cell node probabilities, one per cell
        cell_congestion (torch.Tensor or np.ndarray, optional): Cell congestion values,
            indexed like ``cell_probs``
        threshold (float): Probability threshold for bottleneck nodes
        title (str, optional): Plot title
        save_path (str, optional): Path to save the plot

    Returns:
        matplotlib.figure.Figure: Figure object
    """
    # .float() so half/bfloat16 tensors (unsupported by .numpy()) also work.
    if isinstance(cell_probs, torch.Tensor):
        cell_probs = cell_probs.detach().cpu().float().numpy()
    if isinstance(cell_congestion, torch.Tensor):
        cell_congestion = cell_congestion.detach().cpu().float().numpy()

    color_by_congestion = cell_congestion is not None and len(cell_congestion) > 0

    # Build the graph: cell nodes first, then (optionally) net nodes and incidences.
    G = nx.Graph()
    for i, prob in enumerate(cell_probs):
        G.add_node(f'cell_{i}', type='cell', prob=prob,
                   congestion=cell_congestion[i] if cell_congestion is not None else None)

    if ('cell', 'to', 'net') in cell_hypergraph.edge_types:
        edge_index = cell_hypergraph[('cell', 'to', 'net')].edge_index.cpu().numpy()
        num_nets = cell_hypergraph['net'].num_nodes
        G.add_nodes_from((f'net_{j}' for j in range(num_nets)), type='net')
        # Row 0 holds cell indices, row 1 the matching net indices.
        G.add_edges_from((f'cell_{c}', f'net_{n}') for c, n in zip(edge_index[0], edge_index[1]))

    # Partition nodes in a single pass. Nodes created implicitly by an edge
    # (no 'type' attribute) are deliberately left out of every group.
    bottleneck_cells, normal_cells, net_nodes = [], [], []
    for node, data in G.nodes(data=True):
        kind = data.get('type')
        if kind == 'cell':
            (bottleneck_cells if data['prob'] > threshold else normal_cells).append(node)
        elif kind == 'net':
            net_nodes.append(node)

    fig, ax = plt.subplots(figsize=(12, 10))
    pos = nx.spring_layout(G, seed=42)

    if color_by_congestion:
        cmap = plt.cm.viridis
        norm = plt.Normalize(np.min(cell_congestion), np.max(cell_congestion))

        def _congestion_colors(nodes):
            # One RGBA row per node, in the same order as ``nodes``.
            values = np.array([G.nodes[n]['congestion'] for n in nodes], dtype=float)
            return cmap(norm(values))

        bottleneck_style = dict(node_color=_congestion_colors(bottleneck_cells),
                                edgecolors='red', linewidths=2.0)
        normal_style = dict(node_color=_congestion_colors(normal_cells))
    else:
        bottleneck_style = dict(node_color='red')
        normal_style = dict(node_color='skyblue')

    # Draw nodes (skip empty groups).
    if bottleneck_cells:
        nx.draw_networkx_nodes(G, pos, nodelist=bottleneck_cells, node_size=300, alpha=0.8,
                               ax=ax, **bottleneck_style)
    if normal_cells:
        nx.draw_networkx_nodes(G, pos, nodelist=normal_cells, node_size=200, alpha=0.6,
                               ax=ax, **normal_style)
    if net_nodes:
        nx.draw_networkx_nodes(G, pos, nodelist=net_nodes, node_color='lightgreen',
                               node_shape='s', node_size=100, alpha=0.8, ax=ax)

    nx.draw_networkx_edges(G, pos, width=1.0, alpha=0.5, ax=ax)

    # Only bottleneck cells are labelled, to keep the plot readable.
    nx.draw_networkx_labels(G, pos, labels={n: n for n in bottleneck_cells}, font_size=8, ax=ax)

    # Legend proxies (empty plots) mirroring the node styles used above.
    if color_by_congestion:
        ax.plot([], [], 'o', markerfacecolor='none', markeredgecolor='red', markeredgewidth=2,
                markersize=10, label='Bottleneck Cell')
        ax.plot([], [], 'o', color=cmap(0.5), markersize=10, label='Normal Cell')
        mappable = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
        mappable.set_array([])  # needed by older matplotlib versions
        fig.colorbar(mappable, ax=ax, label='Cell Congestion')
    else:
        ax.plot([], [], 'ro', markersize=10, label='Bottleneck Cell')
        ax.plot([], [], 'o', color='skyblue', markersize=10, label='Normal Cell')
    ax.plot([], [], 's', color='lightgreen', markersize=10, label='Net')
    ax.legend()

    if title:
        ax.set_title(title, fontsize=16)

    ax.set_axis_off()
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig


def visualize_training_curves(metrics_history, title=None, save_path=None):
    """
    Visualize training and validation curves.

    Draws a 2x3 grid of panels against ``metrics_history['epoch']``:
    total loss, NMAE, NRMS, Pearson, Spearman and Kendall correlation.

    Expected keys:
      * ``'epoch'`` (required)
      * ``'train_loss'``, ``'val_loss'``
      * ``'{train|val}_{cell|grid}_{metric}'`` for metric in
        ``nmae, nrms, pearson, spearman, kendall``

    Train series are blue and validation series red; cell-level series are
    solid and grid-level series dashed. Any series missing from
    ``metrics_history`` is skipped rather than raising ``KeyError``.

    Args:
        metrics_history (dict): Dictionary containing training history
        title (str, optional): Figure-level title
        save_path (str, optional): Path to save the plot

    Returns:
        matplotlib.figure.Figure: Figure object
    """
    # Each panel: (title, y-axis label, [(history key, format string, legend label), ...])
    panels = [
        ('Total Loss', 'Loss', [('train_loss', 'b-', 'Train'),
                                ('val_loss', 'r-', 'Validation')]),
    ]
    for metric, panel_title, ylabel in [('nmae', 'NMAE', 'NMAE'),
                                        ('nrms', 'NRMS', 'NRMS'),
                                        ('pearson', 'Pearson Correlation', 'Correlation'),
                                        ('spearman', 'Spearman Correlation', 'Correlation'),
                                        ('kendall', 'Kendall Correlation', 'Correlation')]:
        panels.append((panel_title, ylabel, [
            (f'train_cell_{metric}', 'b-', 'Train Cell'),
            (f'val_cell_{metric}', 'r-', 'Val Cell'),
            (f'train_grid_{metric}', 'b--', 'Train Grid'),
            (f'val_grid_{metric}', 'r--', 'Val Grid'),
        ]))

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    epochs = metrics_history['epoch']

    for ax, (panel_title, ylabel, series) in zip(axes.flat, panels):
        for key, fmt, label in series:
            if key in metrics_history:
                ax.plot(epochs, metrics_history[key], fmt, label=label)
        ax.set_title(panel_title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        # Avoid matplotlib's "no artists with labels" warning on empty panels.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()

    if title:
        fig.suptitle(title, fontsize=16)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig