import argparse
import ast
import os
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from torch import Tensor


def parse_cluster_list(cluster_str):
    """Parse a cluster configuration string into a nested list of ints.

    Intended as an ``argparse`` ``type=`` callable for ``--clusters``. The
    string must be a Python literal holding one inner sequence of cluster
    counts per batch element, e.g. ``"[[3,2],[4,3]]"``.

    :param cluster_str: str, literal such as ``"[[3,2],[4,3]]"``
    :return: list of lists of positive ints (tuples in the input are
        normalized to lists)
    :raises argparse.ArgumentTypeError: if the string is not a valid literal,
        is not a sequence of sequences of ints, or contains a count < 1.
    """
    format_msg = "Invalid cluster list format. Use Python list syntax: '[[3,2],[4,3]]'"
    try:
        # literal_eval only evaluates literals, so arbitrary CLI text cannot run code.
        parsed = ast.literal_eval(cluster_str)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as err:
        # TypeError covers e.g. "{[1]: 2}" (unhashable key); Memory/RecursionError
        # covers pathologically nested input.
        raise argparse.ArgumentTypeError(format_msg) from err

    def _is_int(value):
        # bool is a subclass of int; reject it so "[[True]]" is not accepted as [[1]].
        return isinstance(value, int) and not isinstance(value, bool)

    well_formed = isinstance(parsed, (list, tuple)) and all(
        isinstance(inner, (list, tuple)) and all(_is_int(c) for c in inner)
        for inner in parsed
    )
    if not well_formed:
        raise argparse.ArgumentTypeError(format_msg)
    if any(c < 1 for inner in parsed for c in inner):
        raise argparse.ArgumentTypeError("Cluster counts must be positive integers.")
    return [list(inner) for inner in parsed]


def _safe_silhouette(points, labels):
    """Return the silhouette score of ``labels`` on ``points``, or NaN if undefined.

    ``silhouette_score`` raises unless the number of distinct labels lies in
    [2, n_samples - 1]; that happens e.g. for C=1, or when the data has fewer
    distinct points than requested clusters.
    """
    n_labels = len(np.unique(labels))
    if not 2 <= n_labels <= len(points) - 1:
        return float('nan')
    return silhouette_score(points, labels)


def mask_from_clusters(num_clusters: List,
                       data: Tensor,
                       image_size: Tuple,
                       save_path: Optional[str] = None,
                       elevation: int = 15,
                       azimuthal: int = -150,
                       head_dims: int = 3,
                       random_state: int = 30,
                       show: bool = True,
                       ):
    """
    Visualize the segmentation masks obtained from attention clusters.

    After running the simulation and obtaining the number of clusters formed
    for each batch element and each head, pass the saved output.pt data and
    the corresponding cluster counts. For every (batch element, head) pair the
    head's ``head_dims`` features are clustered with KMeans, and the labels
    are reshaped into an ``image_size`` mask. Two figures are drawn per batch
    element: a 3D scatter of the first three head dims coloured by cluster,
    and a grid of per-cluster binary masks (one column per head).

    :param num_clusters: number of clusters per batch per head. eg: for a batch of two images
        each with single head (3 dims for RGB) ->  [[2], [4]] two clusters for first image, 4 for second.
    :param data: Tensor or ndarray of shape (batch, seq, dim) with
        seq == height * width and dim a multiple of head_dims
    :param image_size: Tuple (height, width) of image
    :param save_path: directory for figures and per-cluster mask PNGs; created
        if missing. None/empty disables saving.
    :param elevation: int, Elevation angle of 3d plot
    :param azimuthal: int, Azimuthal angle of 3d plot
    :param head_dims: int, feature dims per head (3 for RGB); must be >= 3
        because the 3D scatter plots the first three dims of each head
    :param random_state: KMeans seed for reproducibility
    :param show: if True (default), display each batch element's figures and
        wait for Enter; pass False for headless / non-interactive runs
    :return: dict with
        'silhouette_scores': one score per (batch, head), batch-major; NaN when
            the score is undefined for that labelling
        'masks': one (height, width) label array per (batch, head)
        'individual_masks': per (batch, head), a list of C uint8 binary masks
    :raises ValueError: if data, image_size, head_dims and num_clusters are
        mutually inconsistent
    """
    if isinstance(data, torch.Tensor):
        # detach/cpu so tensors that require grad or live on a GPU convert cleanly.
        data = data.detach().cpu().numpy()
    data = np.asarray(data)

    # Explicit raises instead of asserts: asserts disappear under `python -O`.
    if data.ndim != 3:
        raise ValueError(f"data must have shape (batch, seq, dim), got {data.shape}")
    batch_size, seq_len, features = data.shape
    if head_dims < 3:
        raise ValueError(f"head_dims must be >= 3 for the 3D plot, got {head_dims}")
    if features % head_dims != 0:
        raise ValueError(f"features ({features}) not a multiple of head_dims ({head_dims})")
    num_heads = features // head_dims
    if int(np.prod(image_size)) != seq_len:
        raise ValueError(f"image_size {tuple(image_size)} does not match seq length {seq_len}")

    if len(num_clusters) != batch_size:
        raise ValueError(f"num_clusters must have {batch_size} elements (batch size)")
    for i, batch_clusters in enumerate(num_clusters):
        if len(batch_clusters) != num_heads:
            raise ValueError(f"num_clusters[{i}] must have {num_heads} elements (number of heads)")

    if save_path:
        os.makedirs(save_path, exist_ok=True)

    # Store results (flat, batch-major: entry b * num_heads + h is batch b, head h)
    results = {
        'silhouette_scores': [],
        'masks': [],
        'individual_masks': []
    }

    for batch_idx in range(batch_size):
        batch_data = data[batch_idx]  # Shape: (seq_len, features)
        batch_c = num_clusters[batch_idx]  # cluster counts for this batch's heads
        fig_3d = plt.figure(figsize=(5 * num_heads, 5))
        max_clusters_this_batch = max(batch_c)
        # One column per head, one row per cluster; heads with fewer clusters
        # than the max simply leave their lower cells empty.
        fig_masks = plt.figure(figsize=(4 * num_heads, 4 * max_clusters_this_batch))
        batch_individual_masks = []

        for head, C in enumerate(batch_c):
            head_data = batch_data[:, head * head_dims: (head + 1) * head_dims]
            kmeans = KMeans(n_clusters=C, random_state=random_state)
            clusters = kmeans.fit_predict(head_data)
            silhouette = _safe_silhouette(head_data, clusters)
            results['silhouette_scores'].append(silhouette)

            mask = clusters.reshape(image_size)
            results['masks'].append(mask)

            head_individual_masks = []
            for cluster_id in range(C):
                # Binary mask for this specific cluster
                binary_mask = (mask == cluster_id).astype(np.uint8)
                head_individual_masks.append(binary_mask)
                ax_mask = fig_masks.add_subplot(max_clusters_this_batch, num_heads,
                                                cluster_id * num_heads + head + 1)
                ax_mask.imshow(binary_mask, cmap='gray')
                ax_mask.set_title(f'Head {head + 1} (C={C})\nCluster {cluster_id}', fontsize=8)
                ax_mask.axis('off')

            results['individual_masks'].append(head_individual_masks)
            batch_individual_masks.append(head_individual_masks)

            # Plot each cluster in 3D with distinct colors (first three head dims)
            ax_3d = fig_3d.add_subplot(1, num_heads, head + 1, projection='3d')
            cluster_colors = plt.cm.viridis(np.linspace(0, 1, C))
            for cluster_id in range(C):
                cluster_points = head_data[clusters == cluster_id]
                if len(cluster_points) > 0:
                    x, y, z = cluster_points[:, 0], cluster_points[:, 1], cluster_points[:, 2]
                    ax_3d.scatter(x, y, z, label=f'C{cluster_id}', s=10, marker='o',
                                  color=cluster_colors[cluster_id], alpha=0.9, linewidths=0.1)

            title = f'Batch {batch_idx + 1}, Head {head + 1}\nC={C}, Silhouette: {silhouette:.2f}'
            ax_3d.set_title(title, fontsize=10)
            ax_3d.grid(False)

            if head == 0:  # Only add legend to first subplot to avoid clutter
                ax_3d.legend(loc='upper right', fontsize=8)

            ax_3d.set_xticks([])
            ax_3d.set_yticks([])
            ax_3d.set_zticks([])
            ax_3d.view_init(elev=elevation, azim=azimuthal)

        # plt.tight_layout() would only touch the current figure (fig_masks).
        fig_3d.tight_layout()
        fig_masks.tight_layout()

        if save_path:
            fig_3d.savefig(os.path.join(save_path, f"batch{batch_idx + 1}_3d_clusters.png"),
                           dpi=300, bbox_inches='tight', pad_inches=0.50)
            fig_masks.savefig(os.path.join(save_path, f"batch{batch_idx + 1}_masks.png"),
                              dpi=300, bbox_inches='tight', pad_inches=0.50)

            # Save individual mask arrays
            for head, head_masks in enumerate(batch_individual_masks):
                for cluster_id, binary_mask in enumerate(head_masks):
                    mask_file = f"batch{batch_idx + 1}_head{head + 1}_mask{cluster_id}.png"
                    plt.imsave(os.path.join(save_path, mask_file), binary_mask, cmap='gray')

        if show:
            # Show figures non-blocking, then wait for the user before the next batch
            plt.draw()
            plt.pause(0.001)
            print(f"Batch {batch_idx + 1} figures displayed. Press Enter to continue...")
            input()
        # Close only the figures created here, not any the caller has open.
        plt.close(fig_3d)
        plt.close(fig_masks)

    # Print summary
    print("\n" + "=" * 60)
    print("CLUSTERING VISUALIZATION SUMMARY")
    print("=" * 60)
    print(f"Processed {batch_size} batch elements")
    print(f"Each with {num_heads} attention heads")
    print("Cluster configuration per batch-head:")
    scores = iter(results['silhouette_scores'])
    for batch_idx, batch_c in enumerate(num_clusters):
        for head, C in enumerate(batch_c):
            print(f"  Batch {batch_idx + 1}, Head {head + 1}: C={C}, "
                  f"Silhouette: {next(scores):.3f}")
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Visualize attention cluster masks')

    parser.add_argument('--data_path', type=str, required=True,
                        help='Path to .pt data file')
    parser.add_argument('--clusters', type=parse_cluster_list, required=True,
                        help='Cluster counts per batch per head as nested list. Example: "[[3,2],[4,3]]"')
    parser.add_argument('--save_path', type=str, default='figures/',
                        help='Path to save results')
    parser.add_argument('--image_size', type=int, nargs=2, default=[64, 64],
                        help='Image dimensions (height width)')
    parser.add_argument('--random_state', type=int, default=33,
                        help='Random seed')
    # Expose the remaining mask_from_clusters options; defaults match the function's.
    parser.add_argument('--elevation', type=int, default=15,
                        help='Elevation angle of the 3D plot')
    parser.add_argument('--azimuthal', type=int, default=-150,
                        help='Azimuthal angle of the 3D plot')
    parser.add_argument('--head_dims', type=int, default=3,
                        help='Feature dimensions per head (3 for RGB)')

    args = parser.parse_args()

    # Load data. map_location='cpu' lets tensors saved from a GPU run load on
    # CPU-only machines; detach() drops autograd history so .numpy() works.
    data = torch.load(args.data_path, map_location='cpu', weights_only=True).detach().numpy()
    image_size = tuple(args.image_size)

    # Run visualization
    results = mask_from_clusters(
        num_clusters=args.clusters,
        data=data,
        image_size=image_size,
        save_path=args.save_path,
        elevation=args.elevation,
        azimuthal=args.azimuthal,
        head_dims=args.head_dims,
        random_state=args.random_state
    )