
import torch
import numpy as np
import os
import glob
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from typing import List, Dict, Tuple, Optional, Union
import argparse




def sort_mask_by_degree(
    mask: np.ndarray,
    sort_rows: bool = True,
    sort_cols: bool = True,
    descending: bool = True
) -> np.ndarray:
    """
    Permute the rows and/or columns of a mask by their degree (sum of entries).

    Row degree is the sum of each row (out-degree of an output neuron) and
    column degree is the sum of each column (in-degree of an input neuron).
    Sorting rows never changes column sums, so both orders are computed from
    the input mask. Ties follow ``np.argsort``'s default ordering, reversed
    when ``descending`` is True.

    Args:
        mask: 2D mask as a numpy array or torch tensor (any device; tensors
            with autograd history are detached first)
        sort_rows: Whether to sort rows (output neurons)
        sort_cols: Whether to sort columns (input neurons)
        descending: Whether to sort in descending order (highest degree first)

    Returns:
        A new numpy array holding the permuted mask; the input is not modified.
    """
    if isinstance(mask, torch.Tensor):
        # Plain .numpy() raises for CUDA tensors and for tensors that require
        # grad; detach and move to CPU first.
        mask = mask.detach().cpu().numpy()
    mask = np.asarray(mask)

    sorted_mask = mask.copy()

    if sort_rows:
        row_order = np.argsort(mask.sum(axis=1))  # out-degree per output neuron
        if descending:
            row_order = row_order[::-1]
        sorted_mask = sorted_mask[row_order, :]

    if sort_cols:
        col_order = np.argsort(mask.sum(axis=0))  # in-degree per input neuron
        if descending:
            col_order = col_order[::-1]
        sorted_mask = sorted_mask[:, col_order]

    return sorted_mask



def create_mask_histogram(
    masks: List[torch.Tensor],
    resolution: int = 64,
    sampling_method: str = 'all',  # 'all', 'random', 'stratified'
    max_samples_per_mask: Optional[int] = None,
    balanced_sampling: bool = True,
    mask_weights: Optional[List[float]] = None,
    sort_by_degree: bool = False,
    sort_descending: bool = True
) -> np.ndarray:
    """
    Create a 2D histogram representation of multiple masks.

    Points are sampled from every 2D mask as (u, v, value) with u, v in [0, 1].
    Each point is binned into a ``resolution x resolution`` grid. Every bin
    then holds the weighted mean of the sampled values that fell into it;
    for binary masks this is the fraction of sampled entries equal to 1.
    Bins that received no samples stay 0.

    With ``sampling_method='stratified'`` and ``balanced_sampling=True``, ones
    and zeros are drawn in equal numbers. Bin means then reflect that
    resampling rather than the raw density of the mask.

    Args:
        masks: List of mask tensors or numpy arrays; entries that are not 2D
            are skipped
        resolution: Number of bins along each axis
        sampling_method: 'all', 'random' or 'stratified'
        max_samples_per_mask: Maximum number of samples per mask (None = no cap)
        balanced_sampling: For 'stratified', whether to balance 0s and 1s
        mask_weights: Optional weight per entry of ``masks`` (indexed by the
            position in ``masks``)
        sort_by_degree: Whether to sort neurons by degree before sampling
        sort_descending: Whether to sort degrees in descending order

    Returns:
        2D float32 histogram of shape (resolution, resolution)

    Raises:
        ValueError: If ``sampling_method`` is not one of the supported names.
    """
    samplers = {
        'all': lambda m: sample_all_points(m, max_samples_per_mask),
        'random': lambda m: sample_random_points(m, max_samples_per_mask),
        'stratified': lambda m: sample_stratified_points(m, max_samples_per_mask, balanced_sampling),
    }
    if sampling_method not in samplers:
        raise ValueError(f"Unknown sampling method: {sampling_method}")
    sample_points = samplers[sampling_method]

    histogram = np.zeros((resolution, resolution), dtype=np.float32)
    counts = np.zeros((resolution, resolution), dtype=np.float32)

    for mask_idx, mask in enumerate(tqdm(masks, desc="Processing masks")):
        if len(mask.shape) != 2:
            continue

        if isinstance(mask, torch.Tensor):
            # detach/cpu so GPU masks and masks with autograd history convert cleanly.
            mask_np = mask.detach().cpu().float().numpy()
        else:
            mask_np = mask.astype(np.float32)

        if sort_by_degree:
            mask_np = sort_mask_by_degree(
                mask_np,
                sort_rows=True,
                sort_cols=True,
                descending=sort_descending
            )

        weight = mask_weights[mask_idx] if mask_weights is not None else 1.0

        samples = sample_points(mask_np)
        if not samples:
            continue

        # Vectorised binning: same mapping as min(int(u * resolution), resolution - 1).
        points = np.asarray(samples, dtype=np.float64)
        row_bins = np.minimum((points[:, 0] * resolution).astype(np.int64), resolution - 1)
        col_bins = np.minimum((points[:, 1] * resolution).astype(np.int64), resolution - 1)

        # np.add.at accumulates correctly when the same bin appears repeatedly.
        np.add.at(histogram, (row_bins, col_bins), points[:, 2] * weight)
        np.add.at(counts, (row_bins, col_bins), weight)

    # Turn weighted sums into weighted means; empty bins are left at 0.
    filled = counts > 0
    histogram[filled] /= counts[filled]

    return histogram


def sample_all_points(
    mask: np.ndarray,
    max_samples: Optional[int] = None
) -> List[Tuple[float, float, float]]:
    """
    Sample every point of the mask on a regular grid.

    Without a cap (or when the cap is at least the number of entries), every
    entry is returned. Otherwise the same integer stride
    ``int(sqrt(rows * cols / max_samples))`` is applied to both axes. The
    number of returned points is therefore only approximately ``max_samples``
    and may exceed it.

    Args:
        mask: 2D mask as numpy array
        max_samples: Approximate maximum number of samples; values <= 0
            return no samples

    Returns:
        List of (u, v, value) tuples in row-major order. u and v are the row
        and column index normalized to [0, 1] (0.5 for a single row/column);
        value is the mask entry as a float.
    """
    rows, cols = mask.shape

    # A non-positive budget means "no samples". Without this check, 0 divides
    # by zero and negatives reach int(nan).
    if max_samples is not None and max_samples <= 0:
        return []

    stride = 1
    if max_samples is not None and max_samples < rows * cols:
        stride = max(1, int(np.sqrt(rows * cols / max_samples)))

    row_idx = np.arange(0, rows, stride)
    col_idx = np.arange(0, cols, stride)

    # Normalize indices to [0, 1]; a single row/column sits at the centre.
    u = row_idx / (rows - 1) if rows > 1 else np.full(len(row_idx), 0.5)
    v = col_idx / (cols - 1) if cols > 1 else np.full(len(col_idx), 0.5)

    values = np.asarray(mask)[np.ix_(row_idx, col_idx)].astype(np.float64)
    uu, vv = np.meshgrid(u, v, indexing='ij')

    return list(zip(uu.ravel().tolist(), vv.ravel().tolist(), values.ravel().tolist()))


def sample_random_points(
    mask: np.ndarray,
    max_samples: Optional[int] = None
) -> List[Tuple[float, float, float]]:
    """
    Draw distinct mask entries uniformly at random (without replacement).

    Uses the global ``np.random`` state, so results are reproducible only
    when the caller seeds it.

    Args:
        mask: 2D mask as numpy array
        max_samples: Number of entries to draw, capped at the mask size
            (None = every entry, in random order)

    Returns:
        List of (u, v, value) tuples. u and v are the row and column index
        normalized to [0, 1] (0.5 for a single row/column); value is the mask
        entry as a float.
    """
    n_rows, n_cols = mask.shape
    n_entries = n_rows * n_cols

    draw_count = n_entries if max_samples is None else min(max_samples, n_entries)
    flat_picks = np.random.choice(n_entries, size=draw_count, replace=False)

    # Flat index -> (row, col) for a row-major layout.
    picked_rows, picked_cols = np.divmod(flat_picks, n_cols)

    def _normalize(index, extent):
        return index / (extent - 1) if extent > 1 else 0.5

    return [
        (_normalize(r, n_rows), _normalize(c, n_cols), float(mask[r, c]))
        for r, c in zip(picked_rows, picked_cols)
    ]


def sample_stratified_points(
    mask: np.ndarray,
    max_samples: Optional[int] = None,
    balanced: bool = True
) -> List[Tuple[float, float, float]]:
    """
    Sample points with stratification so that both 0s and 1s are represented.

    Entries > 0.5 count as "ones" and entries <= 0.5 as "zeros". Each returned
    sample carries its class label (1.0 or 0.0), not the raw mask value.
    Draws use the global ``np.random`` state, ones first and then zeros.

    Args:
        mask: 2D mask as numpy array
        max_samples: Sample budget; defaults to the number of mask entries.
            A budget that leaves a class with <= 0 requested points draws
            nothing from that class.
        balanced: If True, request ``max_samples // 2`` points from each class.
            A class with fewer entries contributes all of them, so fewer than
            ``max_samples`` points may be returned. If False, split
            ``max_samples`` in proportion to the class sizes.

    Returns:
        List of (u, v, value) tuples, ones first, then zeros. u and v are the
        row and column index normalized to [0, 1] (0.5 for a single
        row/column). The list is empty if the mask has no classifiable entries.
    """
    rows, cols = mask.shape

    ones_indices = np.argwhere(mask > 0.5)
    zeros_indices = np.argwhere(mask <= 0.5)
    total = len(ones_indices) + len(zeros_indices)

    if max_samples is None:
        max_samples = rows * cols

    # Nothing to sample; also avoids dividing by zero in the proportional split.
    if total == 0:
        return []

    if balanced:
        samples_ones = samples_zeros = max_samples // 2
    else:
        samples_ones = int(max_samples * (len(ones_indices) / total))
        samples_zeros = max_samples - samples_ones

    def _draw(indices: np.ndarray, requested: int, label: float) -> List[Tuple[float, float, float]]:
        # Draw without replacement, capped at the entries available in this class.
        if requested <= 0 or len(indices) == 0:
            return []
        picked = indices[np.random.choice(len(indices), size=min(requested, len(indices)), replace=False)]
        u = picked[:, 0] / (rows - 1) if rows > 1 else np.full(len(picked), 0.5)
        v = picked[:, 1] / (cols - 1) if cols > 1 else np.full(len(picked), 0.5)
        return [(ui, vi, label) for ui, vi in zip(u.tolist(), v.tolist())]

    # Ones are drawn before zeros, which fixes the order of RNG calls.
    return _draw(ones_indices, samples_ones, 1.0) + _draw(zeros_indices, samples_zeros, 0.0)

def load_masks_from_files(
    mask_files: List[str],
    fc_only: bool = True,
    min_size: int = 50,
    max_size: int = 5000,
    is_first: bool = True
) -> List[torch.Tensor]:
    """
    Load at most one square mask per file.

    Each file is read with ``torch.load`` and iterated with ``.items()``, so
    it is expected to hold a name -> mask mapping. Entries are scanned in
    mapping order and filtered. Among the survivors, the first square layer
    (``is_first=True``) or the second one (``is_first=False``) is kept.

    Args:
        mask_files: List of mask file paths
        fc_only: Only keep entries whose name contains 'weight_mask' or 'linear'
        min_size: Minimum side length of masks to include
        max_size: Maximum side length of masks to include
        is_first: True to keep the first square layer, False to keep the second

    Returns:
        List of mask tensors loaded onto the CPU, in file order. Files that
        fail to load, or contain no qualifying layer, contribute nothing, so
        the result is not index-aligned with ``mask_files``.
    """
    wanted_position = 1 if is_first else 2
    masks = []

    for file_path in tqdm(mask_files, desc="Loading mask files"):
        try:
            # map_location='cpu' lets masks saved from GPU tensors load on
            # CPU-only machines and keeps the results convertible with .numpy().
            # NOTE(review): depending on the torch version, torch.load may
            # unpickle arbitrary objects; only load trusted mask files.
            mask_data = torch.load(file_path, map_location="cpu")

            # Counts square layers that passed the filters within this file only.
            square_layer_count = 0

            for name, mask in mask_data.items():
                is_fc = 'weight_mask' in name or 'linear' in name
                if fc_only and not is_fc:
                    continue

                if len(mask.shape) != 2:
                    continue

                rows, cols = mask.shape
                if rows != cols:  # Only square matrices are considered
                    continue
                if not (min_size <= rows <= max_size):
                    continue

                square_layer_count += 1
                if square_layer_count == wanted_position:
                    masks.append(mask)
                    break  # at most one mask per file

        except Exception as e:
            # Best-effort: report the bad file and continue with the rest.
            print(f"Error loading {file_path}: {e}")

    return masks


def find_mask_files(mask_dir: str, n_layers: int=-1, hidden_dim: int=-1,
                    compression: float = -1.0) -> List[str]:
    """
    Recursively find ``*.pt`` mask files under a directory, optionally filtered.

    A file is kept only if its path contains every enabled directory token:
    ``L_{n_layers}/``, ``N_{hidden_dim}/`` and ``compression_{compression}/``.
    Matching is a plain substring test on the path with '/' separators.

    Args:
        mask_dir: Directory containing mask files
        n_layers: Required layer count token, or -1 to disable this filter
        hidden_dim: Required hidden-dimension token, or -1 to disable
        compression: Required compression token (formatted with ``str``,
            e.g. 0.5 -> 'compression_0.5/'), or -1 to disable

    Returns:
        Sorted list of matching mask file paths
    """
    pattern = os.path.join(mask_dir, "**", "*.pt")
    # Sort so downstream processing and printed output are deterministic.
    all_files = sorted(glob.glob(pattern, recursive=True))

    required_tokens = []
    if n_layers != -1:
        required_tokens.append(f'L_{n_layers}/')
    if hidden_dim != -1:
        required_tokens.append(f'N_{hidden_dim}/')
    if compression != -1:
        required_tokens.append(f'compression_{compression}/')

    # Tokens use '/', so normalise native separators (e.g. '\\' on Windows).
    return [
        path for path in all_files
        if all(token in path.replace(os.sep, '/') for token in required_tokens)
    ]



def visualize_histogram(
    histogram: np.ndarray,
    title: str = "Mask Distribution",
    cmap: str = "viridis",
    save_path: Optional[str] = None
):
    """
    Plot a 2D histogram as a heatmap with a fixed [0, 1] colour scale.

    Args:
        histogram: 2D histogram as a numpy array
        title: Title for the plot
        cmap: Colormap to use
        save_path: If given, the figure is saved here (300 dpi) and closed,
            and the raw array is saved alongside it with the extension
            replaced by ``.npy``. Otherwise the figure is shown interactively.
    """
    plt.figure(figsize=(10, 8))

    sns.heatmap(histogram, cmap=cmap, cbar_kws={'label': 'Probability of Connection'}, vmin=0, vmax=1)

    plt.title(title)
    plt.xlabel("Normalized Column Position")
    plt.ylabel("Normalized Row Position")

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
        # splitext swaps only the final extension. str.replace('.png', ...)
        # would also rewrite '.png' inside directory names, and would leave
        # other extensions as 'name.pdf.npy'.
        np.save(os.path.splitext(save_path)[0] + '.npy', histogram)
    else:
        plt.show()

def compare_histograms(
    histograms: Dict[str, np.ndarray],
    title: str = "Mask Distribution Comparison",
    cmap: str = "viridis",
    save_path: Optional[str] = None
):
    """
    Draw several histograms side by side in a single figure.

    Up to three panels are laid out in one row. Larger sets use a near-square
    grid, and unused grid cells are removed. Every panel shares the fixed
    [0, 1] colour scale. Nothing is drawn for an empty mapping.

    Args:
        histograms: Mapping from panel title to 2D histogram
        title: Figure-level title
        cmap: Colormap to use
        save_path: If given, the figure is saved here (300 dpi) and closed;
            otherwise it is shown interactively
    """
    num_panels = len(histograms)
    if not num_panels:
        return

    if num_panels > 3:
        n_rows = int(np.ceil(np.sqrt(num_panels)))
        n_cols = int(np.ceil(num_panels / n_rows))
    else:
        n_rows, n_cols = 1, num_panels

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows))
    # plt.subplots returns a bare Axes for a 1x1 grid and an ndarray otherwise.
    panel_axes = list(axes.flatten()) if n_rows * n_cols > 1 else [axes]

    for ax, (name, hist) in zip(panel_axes, histograms.items()):
        sns.heatmap(hist, cmap=cmap, ax=ax, vmin=0, vmax=1)
        ax.set_title(name)
        ax.set_xlabel("Normalized Column Position")
        ax.set_ylabel("Normalized Row Position")

    for spare_ax in panel_axes[num_panels:]:
        fig.delaxes(spare_ax)

    fig.suptitle(title, fontsize=16)
    # Reserve the top 4% of the figure for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.96])

    if not save_path:
        plt.show()
        return
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()


def generate_mask_from_histogram(
    histogram: np.ndarray,
    shape: Tuple[int, int],
    sparsity: float = 0.5
) -> np.ndarray:
    """
    Generate a mask of the requested shape from a 2D histogram.

    Each mask position (i, j) is normalized to i / (rows - 1) and
    j / (cols - 1), scaled to histogram bins and clamped to the last bin.
    This is the same mapping ``create_mask_histogram`` uses, so the mask
    starts as a nearest-bin resampling of the histogram. If 0 < sparsity < 1,
    it is then binarised: the value at sorted position
    ``k = int(size * sparsity)`` becomes the threshold, and only entries
    strictly greater than it are set to 1.

    NOTE: entries equal to the threshold are zeroed too. When many entries
    share that value (common, since each bin covers many positions), the
    realised sparsity can exceed the target. A constant histogram yields an
    all-zero mask.

    Args:
        histogram: 2D histogram
        shape: (rows, cols) of the mask to generate
        sparsity: Target fraction of zeros. Values outside (0, 1) skip
            thresholding and return the resampled (non-binary) values.

    Returns:
        float32 array of the given shape
    """
    rows, cols = shape
    h_rows, h_cols = histogram.shape

    def _bin_index(n: int, n_bins: int) -> np.ndarray:
        # Vectorised form of min(int(p / (n - 1) * n_bins), n_bins - 1); a
        # single row/column maps to bin 0.
        if n <= 1:
            return np.zeros(n, dtype=np.intp)
        scaled = np.arange(n) / (n - 1) * n_bins
        return np.minimum(scaled.astype(np.intp), n_bins - 1)

    mask = histogram[np.ix_(_bin_index(rows, h_rows), _bin_index(cols, h_cols))].astype(np.float32)

    # Skip empty masks, which have no k-th value to threshold on.
    if 0 < sparsity < 1 and mask.size > 0:
        k = int(mask.size * sparsity)  # k < size because sparsity < 1
        # The k-th smallest value, identical to np.sort(flat)[k] but O(n).
        threshold = np.partition(mask, k, axis=None)[k]
        mask = (mask > threshold).astype(np.float32)

    return mask


def create_histogram_by_method(
    masks: List[torch.Tensor],
    methods: Optional[List[str]] = None,
    resolution: int = 64,
    mask_files: Optional[List[str]] = None,
    sort_by_degree: bool = False,
    sort_descending: bool = True
) -> Dict[str, np.ndarray]:
    """
    Create one histogram per pruning method.

    ``masks[i]`` is attributed to ``mask_files[i]``. Each mask goes to the
    first entry of ``methods`` whose name occurs as a substring anywhere in
    its file path; masks whose path matches no method are dropped. Each
    method's masks are summarised with ``create_mask_histogram`` using
    ``sampling_method='all'``.

    Args:
        masks: List of mask tensors, index-aligned with ``mask_files``
        methods: Pruning method names to look for in file paths
            (default: ['mag', 'grasp', 'rand', 'snip', 'synflow'])
        resolution: Resolution of the histogram
        mask_files: File path for each mask; if empty or None, nothing is built
        sort_by_degree: Whether to sort neurons by degree
        sort_descending: Whether to sort degrees in descending order

    Returns:
        Dictionary mapping each method that received at least one mask to its
        histogram
    """
    if methods is None:
        methods = ['mag', 'grasp', 'rand', 'snip', 'synflow']

    histograms = {}
    if not mask_files:
        return histograms

    if len(mask_files) != len(masks):
        import warnings
        # zip() pairs by position and truncates silently, so a length mismatch
        # almost certainly means masks are being attributed to the wrong files.
        warnings.warn(
            f"create_histogram_by_method: {len(masks)} masks but {len(mask_files)} "
            "mask files; masks are paired with files by position, so method labels "
            "may be wrong."
        )

    method_masks = {method: [] for method in methods}
    for file_path, mask in zip(mask_files, masks):
        file_method = next((method for method in methods if method in file_path), None)
        if file_method:
            method_masks[file_method].append(mask)

    for method, method_masks_list in method_masks.items():
        if method_masks_list:
            histograms[method] = create_mask_histogram(
                method_masks_list,
                resolution=resolution,
                sampling_method='all',
                sort_by_degree=sort_by_degree,
                sort_descending=sort_descending
            )

    return histograms


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Create and visualize mask histograms")
    parser.add_argument("--mask_dir", type=str, default='../pai/saved_masks/mnist/default/fc/synflow/',help="Directory containing mask files")
    parser.add_argument("--mask_files", type=str, help="Comma-separated list of mask files")
    parser.add_argument("--n_layers", type=int, choices=[-1,3,4,5], default=-1, help="number of layer, -1 considers all")
    parser.add_argument("--hidden_dim", type=int, choices=[-1, 100, 300, 400, 500, 1000, 1024, 2000], default=-1)
    parser.add_argument("--compression", type=float, choices=[-1.0, 0.5, 0.75, 1.0, 1.25, 1.5, 2.0], default=-1)
    parser.add_argument("--resolution", type=int, default=64, help="Histogram resolution")
    parser.add_argument("--sampling_method", type=str, default="stratified", 
                        choices=["all", "random", "stratified"], help="Sampling method")
    parser.add_argument("--max_samples_per_mask", type=int, default=10000, 
                        help="Maximum samples per mask")
    parser.add_argument("--balanced_sampling", action="store_true", 
                        help="Balance 0s and 1s in sampling")
    parser.add_argument("--fc_only", action="store_true", 
                        help="Only include fully connected layers")
    parser.add_argument("--min_size", type=int, default=50, 
                        help="Minimum size of masks to include")
    parser.add_argument("--max_size", type=int, default=5000, 
                        help="Maximum size of masks to include")
    parser.add_argument("--split_by_method", action="store_true", 
                        help="Create separate histograms for each pruning method")
    parser.add_argument("--output_dir", type=str, default="histogram_results", 
                        help="Output directory")
    parser.add_argument("--colormap", type=str, default="viridis", 
                        help="Colormap for visualization (e.g., viridis, plasma, magma, Blues, hot)")
    parser.add_argument("--sort_by_degree", action="store_true", default=False,
                        help="Sort neurons by connection degree")
    parser.add_argument("--sort_ascending", action="store_true", default=False,
                        help="Sort in ascending order (lowest degree first)")
    parser.add_argument("--is_first", action="store_true", default=False)

    args = parser.parse_args()

    # Resolve the input files. --mask_dir always has a default value, so an
    # explicit --mask_files list must be checked first; otherwise it would
    # never be used.
    if args.mask_files:
        mask_files = args.mask_files.split(",")
    elif args.mask_dir:
        mask_files = find_mask_files(args.mask_dir, args.n_layers, args.hidden_dim, args.compression)
        # Encode the source directory and active filters in the output directory name.
        name_specific = args.mask_dir.replace('../pai/saved_masks/mnist/default/', '')
        name_specific = '_'.join(name_specific.split('/'))
        if args.n_layers != -1:
            name_specific = name_specific + f'L_{args.n_layers}'
        if args.hidden_dim != -1:
            name_specific = name_specific + f'_N_{args.hidden_dim}'
        if args.compression != -1:
            name_specific = name_specific + f'_compression_{args.compression}'
        args.output_dir = f'{args.output_dir}/{name_specific}'
    else:
        raise ValueError("Either --mask_dir or --mask_files must be provided")

    if not args.sort_by_degree:
        args.output_dir = args.output_dir + '_unsorted'
    else:
        args.output_dir = args.output_dir + '_sorted'

    if args.is_first:
        args.output_dir = args.output_dir + 'first'
    else:
        args.output_dir = args.output_dir + 'second'

    os.makedirs(args.output_dir, exist_ok=True)

    print(mask_files)
    print(len(mask_files))

    load_kwargs = dict(
        fc_only=args.fc_only,
        min_size=args.min_size,
        max_size=args.max_size,
        is_first=args.is_first,
    )
    sort_descending = not args.sort_ascending

    if args.split_by_method:
        # load_masks_from_files keeps at most one mask per file and skips files
        # that fail to load or lack a matching layer, so its output cannot be
        # zipped back against mask_files. Load file by file and record which
        # path produced each mask, so method attribution stays correct.
        masks = []
        mask_sources = []
        for file_path in mask_files:
            loaded = load_masks_from_files([file_path], **load_kwargs)
            masks.extend(loaded)
            mask_sources.extend([file_path] * len(loaded))
    else:
        masks = load_masks_from_files(mask_files, **load_kwargs)

    print(f"Loaded {len(masks)} masks")

    if args.split_by_method:
        histograms = create_histogram_by_method(
            masks,
            methods=['rand', 'snip', 'synflow', 'grasp'],
            resolution=args.resolution,
            mask_files=mask_sources,
            sort_by_degree=args.sort_by_degree,
            sort_descending=sort_descending
        )

        for method, histogram in histograms.items():
            visualize_histogram(
                histogram,
                title=f"Mask Distribution for {method.upper()} Pruning",
                save_path=os.path.join(args.output_dir, f"histogram_{method}.png"),
                cmap=args.colormap
            )

        compare_histograms(
            histograms,
            title="Comparison of Pruning Methods",
            save_path=os.path.join(args.output_dir, "method_comparison.png"),
            cmap=args.colormap
        )
    else:
        # Honour the sorting flags here too: the output directory is already
        # labelled '_sorted' / '_unsorted' from --sort_by_degree.
        histogram = create_mask_histogram(
            masks,
            resolution=args.resolution,
            sampling_method=args.sampling_method,
            max_samples_per_mask=args.max_samples_per_mask,
            balanced_sampling=args.balanced_sampling,
            sort_by_degree=args.sort_by_degree,
            sort_descending=sort_descending
        )

        visualize_histogram(
            histogram,
            title="Overall Mask Distribution",
            save_path=os.path.join(args.output_dir, "overall_histogram.png"),
            cmap=args.colormap
        )

        # Generate and plot example masks from the histogram.
        for size in [100, 500, 1000]:
            for sparsity in [0.5, 0.8, 0.9]:
                mask = generate_mask_from_histogram(
                    histogram,
                    shape=(size, size),
                    sparsity=sparsity
                )

                plt.figure(figsize=(8, 6))
                plt.imshow(mask, cmap="Blues")
                plt.title(f"Generated Mask ({size}×{size}, Sparsity={sparsity:.1f})")
                plt.colorbar(label="Mask Value")
                plt.savefig(os.path.join(args.output_dir, f"generated_mask_{size}x{size}_sparsity{sparsity:.1f}.png"))
                plt.close()

    print(f"Visualization complete. Results saved to {args.output_dir}")