import torch
from typing import Dict, List, Tuple, Optional
import numpy as np
import os
from datetime import datetime


try:
    import matplotlib.pyplot as plt
    MATPLOTLIB_AVAILABLE = True
except ImportError:
    MATPLOTLIB_AVAILABLE = False

try:
    import seaborn as sns
    SEABORN_AVAILABLE = True
except ImportError:
    SEABORN_AVAILABLE = False


class DimensionGroupedHyCa:
    """Group feature dimensions by their temporal-change statistics.

    Feature tensors from consecutive steps are buffered. Once at least four
    steps are available (enough for third-order differences), per-dimension
    statistics are computed, the dimensions are clustered with
    ``utils.simple_kmeans``, and each cluster is assigned an order via
    ``utils.determine_order_by_thresholds``. Clustering happens only once;
    afterwards ``get_dimension_groups`` / ``get_orders_for_groups`` expose
    the result.
    """

    # Default initial per-group orders (overwritten once clustering runs).
    _DEFAULT_ORDERS = (0, 1, 2, 3)

    def __init__(self,
                 feature_dim: int,
                 n_clusters: int = 4,
                 orders: Optional[List[int]] = None,
                 history_window: int = 4):
        """
        Args:
            feature_dim: Expected size of the last (feature) axis.
            n_clusters: Number of dimension groups to form.
            orders: Initial order per group, used until clustering replaces
                them. ``None`` means ``[0, 1, 2, 3]``; an empty list means all
                zeros. The list is truncated or zero-padded to ``n_clusters``.
            history_window: Number of past features kept; clamped to >= 4.
        """
        self.feature_dim = feature_dim
        self.n_clusters = n_clusters

        if orders is None:
            orders = list(self._DEFAULT_ORDERS)
        initial_orders = [int(o) for o in orders[:n_clusters]] if orders else []
        # Pad so there is always one order per group, matching the post-fit shape.
        self.orders = initial_orders + [0] * (n_clusters - len(initial_orders))

        self.history_window = max(4, history_window)  # At least 4 steps to calculate third-order differences

        # Dimension grouping
        self.dimension_groups = None  # Group assignment for each dimension
        self.is_fitted = False

        # History features for clustering analysis
        self.feature_history = []

        # Number of times statistics collection was attempted
        self.cluster_count = 0

    def _collect_dimension_statistics(self, features: List[torch.Tensor]) -> Optional[np.ndarray]:
        """Compute per-dimension temporal statistics over a window of features.

        Args:
            features: Feature tensors from consecutive steps, each either
                ``[batch, seq_len, feature_dim]`` or ``[N, feature_dim]``.
                Tensors of differing sizes are cropped to their common
                minimum along every axis.

        Returns:
            Array with one row per feature dimension, each row produced by
            ``utils.compute_dimension_statistics_for_single_dim``. Returns
            ``None`` when fewer than two features are given, or when the
            tensors are not all of rank 2 or all of rank 3.
        """
        if len(features) < 2:
            return None

        # Convert to float32 before numpy: numpy has no bfloat16 dtype.
        features_np = [f.detach().cpu().to(dtype=torch.float32).numpy() for f in features]

        # Every tensor must share one supported rank; otherwise cropping and
        # stacking below would raise.
        ranks = {f.ndim for f in features_np}
        if len(ranks) != 1 or features_np[0].ndim not in (2, 3):
            print(f"[GroupedHyCa] Error: Unexpected feature shapes {[f.shape for f in features_np]}")
            return None

        # Crop to the common minimum shape so the tensors can be stacked.
        first_shape = features_np[0].shape
        if any(f.shape != first_shape for f in features_np):
            common_shape = tuple(min(sizes) for sizes in zip(*(f.shape for f in features_np)))
            crop = tuple(slice(0, n) for n in common_shape)
            features_np = [f[crop] for f in features_np]

        stacked_features = np.stack(features_np, axis=0)  # [T, batch, seq_len, D] or [T, N, D]

        if stacked_features.ndim == 4:
            # Merge batch and sequence axes: [T, batch * seq_len, D].
            T, batch, seq_len, D = stacked_features.shape
            stacked_features = stacked_features.reshape(T, batch * seq_len, D)
        T, N, D = stacked_features.shape

        # Import once, not once per dimension.
        from utils import compute_dimension_statistics_for_single_dim

        # Column d holds dimension d's token vectors across time steps: [T, N].
        stats_list = [
            compute_dimension_statistics_for_single_dim(stacked_features[:, :, d])
            for d in range(D)
        ]
        return np.array(stats_list)  # [D, n_stats]

    def _assign_hyca_orders(self, dimension_groups: np.ndarray, dimension_stats: np.ndarray) -> List[int]:
        """Return one order per group (delegates to ``_compute_group_orders``)."""
        return self._compute_group_orders(dimension_groups, dimension_stats)

    def _compute_group_orders(self, dimension_groups: np.ndarray, dimension_stats: np.ndarray) -> List[int]:
        """Derive an order for each group from the median of its members' stats.

        Args:
            dimension_groups: Group id per dimension, shape ``[D]``.
            dimension_stats: Per-dimension statistics, shape ``[D, n_stats]``.

        Returns:
            List of length ``n_clusters``. Empty groups get order 0; other
            groups get ``utils.determine_order_by_thresholds(median_stats)``.
        """
        from utils import determine_order_by_thresholds

        group_orders = []
        for group_id in range(self.n_clusters):
            group_mask = (dimension_groups == group_id)
            if not np.any(group_mask):
                group_orders.append(0)
                continue

            # Median over the group's dimensions is robust to outlier dims.
            group_stats = np.median(dimension_stats[group_mask], axis=0)
            group_orders.append(determine_order_by_thresholds(group_stats))

        return group_orders

    def _save_dimension_stats(self, dimension_stats: np.ndarray, dimension_stats_scaled: np.ndarray,
                             dimension_stats_mean: np.ndarray, dimension_stats_std: np.ndarray,
                             save_path: str, image_idx: int):
        """Best-effort save of the stats as ``.npz`` plus a text summary.

        Files go under ``<save_path>/dimension_stats`` with a timestamped
        name. Any error is printed with its traceback and not re-raised.
        """
        try:
            stats_save_path = os.path.join(save_path, "dimension_stats")
            os.makedirs(stats_save_path, exist_ok=True)

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f'dimension_stats_img_{image_idx:04d}_{timestamp}.npz'
            full_path = os.path.join(stats_save_path, filename)

            # Column names of dimension_stats, stored for later analysis.
            feature_names = ['d1', 'd2', 'd3', 'eta', 'kappa', 'rho', 'gamma', 'e', 'lfr', 'sf']

            save_data = self._prepare_save_data(
                dimension_stats, dimension_stats_scaled, dimension_stats_mean,
                dimension_stats_std, feature_names, image_idx, timestamp
            )
            np.savez_compressed(full_path, **save_data)

            self._save_text_summary(stats_save_path, dimension_stats, dimension_stats_mean,
                                  dimension_stats_std, feature_names, image_idx, timestamp)

        except Exception as e:
            print(f"[GroupedHyCa] Failed to save dimension stats: {e}")
            import traceback
            traceback.print_exc()

    def _prepare_save_data(self, dimension_stats, dimension_stats_scaled, dimension_stats_mean,
                          dimension_stats_std, feature_names, image_idx, timestamp):
        """Build the keyword dict passed to ``np.savez_compressed``.

        Clustering results (``dimension_groups``, ``group_orders``) are
        included only once the instance is fitted.
        """
        save_data = {
            'dimension_stats_raw': dimension_stats,
            'dimension_stats_scaled': dimension_stats_scaled,
            'dimension_stats_mean': dimension_stats_mean,
            'dimension_stats_std': dimension_stats_std,
            'feature_names': np.array(feature_names),
            'image_idx': image_idx,
            'timestamp': timestamp,
            'n_clusters': self.n_clusters,
            'feature_dim': self.feature_dim
        }

        if self.is_fitted and self.dimension_groups is not None:
            save_data['dimension_groups'] = self.dimension_groups
            save_data['group_orders'] = np.array(self.orders)

        return save_data

    def _save_text_summary(self, stats_save_path, dimension_stats, dimension_stats_mean,
                          dimension_stats_std, feature_names, image_idx, timestamp):
        """Write a human-readable summary of the stats (and clustering, if fitted)."""
        summary_filename = f'dimension_stats_summary_img_{image_idx:04d}_{timestamp}.txt'
        summary_path = os.path.join(stats_save_path, summary_filename)

        with open(summary_path, 'w', encoding='utf-8') as f:
            f.write(f"Dimension Statistics Summary - Image {image_idx}\n")
            f.write(f"Generation Time: {timestamp}\n")
            f.write(f"Feature Dimension: {self.feature_dim}\n")
            f.write(f"Number of Clusters: {self.n_clusters}\n")
            f.write("\nFeature Descriptions:\n")

            feature_descriptions = [
                "d1: First-order change (velocity)",
                "d2: Second-order change (acceleration/curvature)",
                "d3: Third-order change (jerk)",
                "eta: Curvature ratio (d2/d1)",
                "kappa: Jerk ratio (d3/d2)",
                "rho: Direction consistency (cosine of velocity vector angle)",
                "gamma: Relative change rate (d1/||v_{k-1}||)",
                "e: Energy (||v_k||)",
                "lfr: Low frequency ratio (DCT first 20% frequency power ratio)",
                "sf: Spectral flatness (geometric mean/arithmetic mean)"
            ]
            for desc in feature_descriptions:
                f.write(f"  {desc}\n")

            f.write("\nRaw Statistical Feature Ranges:\n")
            for i, name in enumerate(feature_names):
                min_val = np.min(dimension_stats[:, i])
                max_val = np.max(dimension_stats[:, i])
                mean_val = dimension_stats_mean[i]
                std_val = dimension_stats_std[i]
                f.write(f"  {name}: [{min_val:.6f}, {max_val:.6f}], mean={mean_val:.6f}, std={std_val:.6f}\n")

            if self.is_fitted:
                # int() so numpy>=2 does not print "np.int64(3)" reprs.
                group_sizes = [int(np.sum(self.dimension_groups == i)) for i in range(self.n_clusters)]
                f.write(f"\nClustering Results:\n")
                f.write(f"  Dimension count per group: {group_sizes}\n")
                f.write(f"  HyCa orders per group: {self.orders}\n")

    def update_and_cluster(self, feature: torch.Tensor, save_visualization: bool = False,
                          save_path: str = "./results",
                          image_idx: int = 0) -> bool:
        """Append one step's feature and cluster once enough history exists.

        Args:
            feature: Feature tensor for the current step.
            save_visualization: Also save a clustering figure when clustering runs.
            save_path: Output directory for stats and figures.
            image_idx: Index embedded in output filenames.

        Returns:
            True only on the call that performs clustering successfully;
            False otherwise (not enough history, already fitted, or stats
            could not be computed).
        """
        # Detach so the stored history never keeps an autograd graph alive.
        self.feature_history.append(feature.detach().clone())
        if len(self.feature_history) > self.history_window:
            self.feature_history.pop(0)

        # Third-order differences need at least 4 steps; cluster only once.
        if self.is_fitted or len(self.feature_history) < 4:
            return False

        dimension_stats = self._collect_dimension_statistics(self.feature_history)
        self.cluster_count += 1
        if dimension_stats is None:
            return False

        return self._perform_clustering_and_assignment(
            dimension_stats, save_path, image_idx, save_visualization
        )

    def _perform_clustering_and_assignment(self, dimension_stats, save_path, image_idx, save_visualization):
        """Standardize stats, cluster dimensions, assign orders, then save outputs."""
        dimension_stats_mean = np.mean(dimension_stats, axis=0)
        dimension_stats_std = np.std(dimension_stats, axis=0)
        # Constant columns would divide by zero; leave them centered only.
        dimension_stats_std = np.where(dimension_stats_std == 0, 1, dimension_stats_std)
        dimension_stats_scaled = (dimension_stats - dimension_stats_mean) / dimension_stats_std

        # NOTE(review): columns 2 and 3 are 'd3' and 'eta' per the saved
        # feature names, whereas the visualization plots eta vs kappa
        # (columns 3 and 4) — confirm which features are intended here.
        clustering_features = dimension_stats_scaled[:, [2, 3]]

        from utils import simple_kmeans
        self.dimension_groups = simple_kmeans(clustering_features, self.n_clusters)
        self.orders = self._assign_hyca_orders(self.dimension_groups, dimension_stats)
        self.is_fitted = True

        # Save after fitting so the group assignment and per-group orders
        # are written too (both outputs include them only when fitted).
        self._save_dimension_stats(dimension_stats, dimension_stats_scaled,
                                 dimension_stats_mean, dimension_stats_std,
                                 save_path, image_idx)

        if save_visualization:
            self.visualize_clustering(dimension_stats, save_path, image_idx)

        return True

    def get_dimension_groups(self) -> Optional[np.ndarray]:
        """Get dimension grouping results (None until clustering has run)."""
        return self.dimension_groups if self.is_fitted else None

    def visualize_clustering(self, dimension_stats: np.ndarray, save_path: str, image_idx: int = 0):
        """
        Visualize dimension clustering results and save them as a PNG.

        Does nothing if not fitted or if matplotlib is unavailable. Errors
        during plotting are printed, not raised; the figure is always closed.

        Args:
            dimension_stats: [D, n_stats] numerical analysis features per dimension
            save_path: Save path
            image_idx: Image index
        """
        if not self.is_fitted or self.dimension_groups is None:
            return

        if not MATPLOTLIB_AVAILABLE:
            print("matplotlib not available, skipping visualization")
            return

        fig = None
        try:
            os.makedirs(save_path, exist_ok=True)

            fig, axes = plt.subplots(2, 3, figsize=(18, 12))
            fig.suptitle(f'Dimension Clustering Results - Image {image_idx}', fontsize=16, fontweight='bold')

            colors = ['red', 'blue', 'green', 'orange', 'purple', 'brown', 'pink', 'gray']
            group_colors = [colors[i % len(colors)] for i in range(self.n_clusters)]

            ax1 = axes[0, 0]
            # One column per clustered dimension; this can be fewer than
            # feature_dim when features were cropped to a common shape.
            dimension_matrix = np.asarray(self.dimension_groups, dtype=float).reshape(1, -1)
            im1 = ax1.imshow(dimension_matrix, aspect='auto', cmap='tab10')
            fig.colorbar(im1, ax=ax1, label='Cluster Group ID')
            ax1.set_title('Dimension Group Heatmap')
            ax1.set_xlabel('Dimension Index')
            ax1.set_ylabel('Group Assignment')

            ax2 = axes[0, 1]
            # Show at least orders 0-3, extending if a group got a higher order.
            n_orders = max(4, int(max(self.orders, default=0)) + 1)
            order_counts = [self.orders.count(i) for i in range(n_orders)]
            bar_palette = ['lightcoral', 'lightblue', 'lightgreen', 'lightsalmon']
            ax2.bar(range(n_orders), order_counts,
                    color=[bar_palette[i % len(bar_palette)] for i in range(n_orders)])
            ax2.set_title('HyCa Order Distribution')
            ax2.set_xlabel('HyCa Order')
            ax2.set_ylabel('Number of Groups')
            ax2.set_xticks(range(n_orders))
            for i, v in enumerate(order_counts):
                if v > 0:
                    ax2.text(i, v + 0.05, str(v), ha='center', va='bottom')

            ax3 = axes[0, 2]
            eta_values = dimension_stats[:, 3]
            kappa_values = dimension_stats[:, 4]
            for group_id in range(self.n_clusters):
                group_mask = (self.dimension_groups == group_id)
                if np.any(group_mask):
                    ax3.scatter(eta_values[group_mask], kappa_values[group_mask],
                              c=group_colors[group_id], label=f'G{group_id} (O{self.orders[group_id]})',
                              alpha=0.7, s=30)
            ax3.set_xlabel('Curvature Ratio (η)')
            ax3.set_ylabel('Jerk Ratio (κ)')
            ax3.set_title('Feature Space Distribution (η vs κ)')
            ax3.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            ax3.grid(True, alpha=0.3)

            ax4 = axes[1, 0]
            rho_values = dimension_stats[:, 5]
            for group_id in range(self.n_clusters):
                group_mask = (self.dimension_groups == group_id)
                if np.any(group_mask):
                    ax4.hist(rho_values[group_mask], bins=15, alpha=0.6, color=group_colors[group_id],
                           label=f'G{group_id} (O{self.orders[group_id]})')
            ax4.set_xlabel('Direction Consistency (ρ)')
            ax4.set_ylabel('Number of Dimensions')
            ax4.set_title('Direction Consistency Distribution')
            ax4.legend()
            ax4.grid(True, alpha=0.3)

            ax5 = axes[1, 1]
            energy_values = dimension_stats[:, 7]
            for group_id in range(self.n_clusters):
                group_mask = (self.dimension_groups == group_id)
                if np.any(group_mask):
                    ax5.hist(energy_values[group_mask], bins=15, alpha=0.6, color=group_colors[group_id],
                           label=f'G{group_id} (O{self.orders[group_id]})')
            ax5.set_xlabel('Energy (e)')
            ax5.set_ylabel('Number of Dimensions')
            ax5.set_title('Energy Distribution')
            ax5.legend()
            ax5.grid(True, alpha=0.3)

            ax6 = axes[1, 2]
            ax6.axis('off')
            table_data = self._create_statistics_table(dimension_stats)
            if table_data:
                table = ax6.table(cellText=table_data,
                                colLabels=['Group', 'Order', 'Dims', 'η', 'κ', 'ρ', 'Energy'],
                                cellLoc='center',
                                loc='center')
                table.auto_set_font_size(False)
                table.set_fontsize(10)
                table.scale(1.2, 2)
                ax6.set_title('Group Statistics Summary', pad=20)

            fig.tight_layout()
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f'clustering_img_{image_idx:04d}_{timestamp}.png'
            full_path = os.path.join(save_path, filename)
            fig.savefig(full_path, dpi=300, bbox_inches='tight')

        except Exception as e:
            print(f"Visualization save failed: {e}")
        finally:
            # Close even on failure so repeated calls don't leak figures.
            if fig is not None:
                plt.close(fig)

    def _create_statistics_table(self, dimension_stats):
        """Return per-group table rows: id, order, size, median η/κ/ρ/energy."""
        table_data = []
        for group_id in range(self.n_clusters):
            group_mask = (self.dimension_groups == group_id)
            if not np.any(group_mask):
                continue
            group_stats = dimension_stats[group_mask]
            n_dims = np.sum(group_mask)
            avg_eta = np.median(group_stats[:, 3])
            avg_kappa = np.median(group_stats[:, 4])
            avg_rho = np.median(group_stats[:, 5])
            avg_energy = np.median(group_stats[:, 7])

            table_data.append([
                f'G{group_id}',
                f'O{self.orders[group_id]}',
                f'{n_dims}',
                f'{avg_eta:.3f}',
                f'{avg_kappa:.3f}',
                f'{avg_rho:.3f}',
                f'{avg_energy:.3f}'
            ])
        return table_data

    def get_orders_for_groups(self) -> List[int]:
        """Get orders corresponding to each group"""
        return self.orders

