import html
import os
import re
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
from plotly.subplots import make_subplots


class AlignmentVisualizer:
    """Visualization tools for neural network component alignment metrics.

    The raw ``metrics`` mapping (component name -> dict with ``'actual'`` and
    ``'random'`` scores) is flattened into ``self.df`` (one row per component)
    and ``self.pivot_df`` (mean ratio per layer x component type). All plot
    and report methods read from those two frames and write their output
    into ``self.output_dir``.
    """

    # Compiled once at class creation instead of on every loop iteration.
    _LAYER_PATTERN = re.compile(r'layers\.(\d+)')
    _COMPONENT_PATTERN = re.compile(r'(self_attn|mlp)\.([a-z_]+)')

    # Column order of ``self.df``; also used to give an empty frame its schema.
    _COLUMNS = (
        'name',
        'layer',
        'component_type',
        'component_group',
        'actual',
        'random',
        'ratio',
    )

    def __init__(self, metrics, output_dir="visualizations"):
        """
        Initialize the visualizer with metrics data.

        Args:
            metrics (dict): Dictionary of alignment metrics, keyed by
                component name; each value is a dict that may contain
                'actual' and 'random' scores.
            output_dir (str): Directory to save visualizations. Created if
                it does not exist.
        """
        self.metrics = metrics
        self.output_dir = output_dir

        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)

        # Process metrics into a more usable format
        self.process_metrics()

    def process_metrics(self):
        """Process raw metrics into dataframes for visualization.

        Populates:
            self.df: one row per component with the columns listed in
                ``_COLUMNS``. ``layer`` is -1 and the component fields are
                "other" when the name does not match the expected patterns.
            self.pivot_df: mean ratio indexed by layer with one column per
                component type; an empty DataFrame when there are no metrics.

        The ratio is ``actual / random``. When the random baseline is missing
        or zero the ratio is undefined and is recorded as 0 (the fallback
        already used for a missing baseline). Missing 'actual' scores are
        treated as 0, matching the value stored in the 'actual' column.
        """
        rows = []
        for name, values in self.metrics.items():
            # Extract layer number and component type
            layer_match = self._LAYER_PATTERN.search(name)
            layer = int(layer_match.group(1)) if layer_match else -1

            component_match = self._COMPONENT_PATTERN.search(name)
            if component_match:
                component_group, component_type = component_match.groups()
            else:
                component_group = component_type = "other"

            actual = values.get('actual', 0)
            random_score = values.get('random', 0)

            # Guard against a missing or zero baseline instead of raising
            # KeyError / ZeroDivisionError.
            ratio = actual / random_score if random_score else 0.0

            rows.append({
                'name': name,
                'layer': layer,
                'component_type': component_type,
                'component_group': component_group,
                'actual': actual,
                'random': random_score,
                'ratio': ratio,
            })

        # Passing the columns explicitly keeps the schema even with no rows.
        self.df = pd.DataFrame(rows, columns=list(self._COLUMNS))

        # Always define pivot_df so later methods never hit AttributeError.
        if self.df.empty:
            self.pivot_df = pd.DataFrame()
        else:
            self.pivot_df = self.df.pivot_table(
                index='layer',
                columns='component_type',
                values='ratio',
                aggfunc='mean',
            )

    def _component_summary(self):
        """Return mean actual/random/ratio per component type, best ratio first.

        The index is reset after sorting so that row position and index label
        agree; callers can then use either without mismatching rows.
        """
        summary = self.df.groupby('component_type').agg({
            'actual': 'mean',
            'random': 'mean',
            'ratio': 'mean',
        }).reset_index()
        return summary.sort_values('ratio', ascending=False).reset_index(drop=True)

    def plot_layer_heatmap(self, save=True):
        """Plot heatmap of alignment ratios across layers and component types.

        Args:
            save (bool): If True, write "layer_heatmap.png" into the output
                directory and close the figure; otherwise show it.
        """
        plt.figure(figsize=(12, 10))

        # Create heatmap
        sns.heatmap(
            self.pivot_df,
            annot=True,
            cmap="viridis",
            fmt=".2f",
            linewidths=.5,
        )

        plt.title("Component Alignment Across Layers", fontsize=16)
        plt.ylabel("Layer", fontsize=14)
        plt.xlabel("Component Type", fontsize=14)

        if save:
            plt.savefig(os.path.join(self.output_dir, "layer_heatmap.png"), dpi=300, bbox_inches="tight")
            plt.close()
        else:
            plt.show()

    def plot_component_comparison(self, save=True):
        """Plot comparison of component types across all layers.

        Draws grouped bars (actual vs. random baseline) per component type,
        ordered by mean ratio, with the mean ratio written above each pair.

        Args:
            save (bool): If True, write "component_comparison.png" into the
                output directory and close the figure; otherwise show it.
        """
        component_df = self._component_summary()

        plt.figure(figsize=(12, 8))

        # Adjust subplot position to use more vertical space
        plt.subplots_adjust(top=0.9, bottom=0.15)

        # Create grouped bar chart
        x = np.arange(len(component_df))
        width = 0.35

        plt.bar(x - width/2, component_df['actual'], width, label='Actual')
        plt.bar(x + width/2, component_df['random'], width, label='Random Baseline')

        # Calculate appropriate y-axis limit
        max_val = max(component_df['actual'].max(), component_df['random'].max())
        # Add padding for ratio text (increase this value if still squished)
        y_max = max_val * 1.3

        # Set y-axis limits explicitly
        plt.ylim(0, y_max)

        # Add ratio as text. Iterate the columns positionally so each label
        # sits above the bars it belongs to (the frame is sorted by ratio).
        per_component = zip(
            component_df['actual'],
            component_df['random'],
            component_df['ratio'],
        )
        for position, (actual, random_score, ratio) in enumerate(per_component):
            bar_height = max(actual, random_score)
            # Position text at a fixed distance above the bar rather than relative to y_max
            text_y_pos = bar_height + (max_val * 0.1)  # Fixed offset based on data scale
            plt.text(position, text_y_pos, f"Ratio: {ratio:.2f}", ha='center', va='bottom', fontsize=10)

        plt.xlabel('Component Type', fontsize=14)
        plt.ylabel('Alignment Score', fontsize=14)
        plt.title('Component Type Comparison', fontsize=16)
        plt.xticks(x, component_df['component_type'], rotation=45, ha='right')
        plt.legend()

        # Add grid lines for better readability
        plt.grid(axis='y', linestyle='--', alpha=0.7)

        # Don't use tight_layout() as it will override our manual adjustments

        if save:
            plt.savefig(os.path.join(self.output_dir, "component_comparison.png"), dpi=300)
            plt.close()
        else:
            plt.show()

    def plot_layer_progression(self, component_types=None, save=True):
        """
        Plot progression of alignment ratios across layers.

        Args:
            component_types (list): List of component types to include.
                Defaults to every component type present in the data.
            save (bool): If True, write "layer_progression.png" into the
                output directory and close the figure; otherwise show it.
        """
        if component_types is None:
            component_types = self.df['component_type'].unique()

        plt.figure(figsize=(14, 8))

        for component in component_types:
            component_data = self.df[self.df['component_type'] == component]
            if not component_data.empty:
                sns.lineplot(
                    data=component_data,
                    x='layer',
                    y='ratio',
                    label=component,
                    marker='o',
                )

        # Reference line: ratio 1 means actual equals the random baseline.
        plt.axhline(y=1.0, color='r', linestyle='--', alpha=0.5, label='Baseline (ratio=1)')

        plt.title("Alignment Ratio Progression Across Layers", fontsize=16)
        plt.xlabel("Layer", fontsize=14)
        plt.ylabel("Alignment Ratio (Actual/Random)", fontsize=14)
        plt.legend(title="Component Type")
        plt.grid(True, alpha=0.3)

        if save:
            plt.savefig(os.path.join(self.output_dir, "layer_progression.png"), dpi=300, bbox_inches="tight")
            plt.close()
        else:
            plt.show()

    def plot_interactive_heatmap(self):
        """Create interactive heatmap using Plotly.

        Writes "interactive_heatmap.html" into the output directory.
        """
        # Reshape data for plotly
        z_data = self.pivot_df.values
        x_data = self.pivot_df.columns.tolist()
        y_data = self.pivot_df.index.tolist()

        # Create heatmap
        fig = go.Figure(data=go.Heatmap(
            z=z_data,
            x=x_data,
            y=y_data,
            colorscale='Viridis',
            hoverongaps=False,
            text=[[f"{val:.2f}" for val in row] for row in z_data],
            hovertemplate="Layer: %{y}<br>Component: %{x}<br>Ratio: %{z:.2f}<extra></extra>",
        ))

        fig.update_layout(
            title="Interactive Component Alignment Heatmap",
            xaxis_title="Component Type",
            yaxis_title="Layer",
            height=800,
            width=1000,
        )

        fig.write_html(os.path.join(self.output_dir, "interactive_heatmap.html"))

    def plot_interactive_3d(self):
        """Create interactive 3D visualization of alignment metrics.

        Only components whose name contained a layer number are plotted.
        Writes "interactive_3d.html" into the output directory.
        """
        # Filter data to include only layers with valid numbers
        valid_df = self.df[self.df['layer'] >= 0].copy()

        # Create 3D scatter plot
        # NOTE(review): marker size is driven by 'actual'; presumably the
        # scores are non-negative - confirm against the metric producer.
        fig = px.scatter_3d(
            valid_df,
            x='layer',
            y='component_type',
            z='ratio',
            color='ratio',
            size='actual',
            hover_name='name',
            color_continuous_scale='Viridis',
            opacity=0.8,
        )

        fig.update_layout(
            title="3D Visualization of Component Alignment",
            scene=dict(
                xaxis_title="Layer",
                yaxis_title="Component Type",
                zaxis_title="Alignment Ratio",
            ),
            height=800,
            width=1000,
        )

        fig.write_html(os.path.join(self.output_dir, "interactive_3d.html"))

    def create_dashboard(self):
        """Create a comprehensive dashboard with multiple visualizations.

        Builds a 2x2 Plotly figure (heatmap, component comparison, layer
        progression, ratio histogram) and writes "dashboard.html" into the
        output directory.
        """
        # Create a subplot figure
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=(
                "Component Alignment Heatmap",
                "Component Type Comparison",
                "Layer Progression",
                "Distribution of Alignment Ratios",
            ),
            specs=[
                [{"type": "heatmap"}, {"type": "bar"}],
                [{"type": "scatter"}, {"type": "histogram"}],
            ],
        )

        # 1. Heatmap
        fig.add_trace(
            go.Heatmap(
                z=self.pivot_df.values,
                x=self.pivot_df.columns.tolist(),
                y=self.pivot_df.index.tolist(),
                colorscale='Viridis',
                showscale=True,
            ),
            row=1, col=1,
        )

        # 2. Component comparison
        component_df = self._component_summary()

        fig.add_trace(
            go.Bar(
                x=component_df['component_type'],
                y=component_df['actual'],
                name='Actual',
                marker_color='blue',
            ),
            row=1, col=2,
        )

        fig.add_trace(
            go.Bar(
                x=component_df['component_type'],
                y=component_df['random'],
                name='Random',
                marker_color='red',
            ),
            row=1, col=2,
        )

        # 3. Layer progression
        # Average per layer and order by layer so the line runs left to right,
        # consistent with the seaborn line plot in plot_layer_progression.
        for component, component_data in self.df.groupby('component_type', sort=False):
            per_layer = component_data.groupby('layer')['ratio'].mean()
            fig.add_trace(
                go.Scatter(
                    x=per_layer.index.tolist(),
                    y=per_layer.tolist(),
                    mode='lines+markers',
                    name=component,
                ),
                row=2, col=1,
            )

        # 4. Distribution histogram
        fig.add_trace(
            go.Histogram(
                x=self.df['ratio'],
                nbinsx=30,
                marker_color='green',
            ),
            row=2, col=2,
        )

        # Update layout
        fig.update_layout(
            title_text="Neural Network Component Alignment Dashboard",
            height=1000,
            width=1200,
            showlegend=True,
        )

        fig.write_html(os.path.join(self.output_dir, "dashboard.html"))

    @staticmethod
    def _component_rows_html(components):
        """Render one HTML table row per component in ``components``.

        Component names come from the caller-supplied metrics mapping, so
        they are HTML-escaped before being embedded.
        """
        return "".join(
            f"""
                        <tr>
                            <td>{html.escape(str(row['name']))}</td>
                            <td>{row['layer']}</td>
                            <td>{row['component_type']}</td>
                            <td class="highlight">{row['ratio']:.2f}</td>
                        </tr>
            """
            for _, row in components.iterrows()
        )

    def generate_report(self, task_description=""):
        """Generate a comprehensive HTML report with insights.

        Writes "alignment_report.html" (UTF-8) into the output directory.
        The report links to the PNG/HTML files produced by the plot methods
        by relative name; it does not create them itself.

        Args:
            task_description (str): Plain-text description of the task,
                inserted into the overview sentence. It is HTML-escaped.
        """
        # Calculate key statistics
        top_components = self.df.sort_values('ratio', ascending=False).head(5)
        bottom_components = self.df.sort_values('ratio').head(5)

        # Caller-supplied text must not be able to inject markup.
        safe_task = html.escape(str(task_description))

        header = f"""
        <!DOCTYPE html>
        <html>
        <head>
            <meta charset="utf-8">
            <title>Neural Network Component Alignment Analysis</title>
            <style>
                body {{ font-family: Arial, sans-serif; margin: 20px; }}
                h1, h2, h3 {{ color: #2c3e50; }}
                .container {{ max-width: 1200px; margin: 0 auto; }}
                .section {{ margin-bottom: 30px; }}
                .insight {{ background-color: #f8f9fa; padding: 15px; border-left: 5px solid #4e73df; margin-bottom: 20px; }}
                table {{ border-collapse: collapse; width: 100%; }}
                th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
                th {{ background-color: #f2f2f2; }}
                tr:nth-child(even) {{ background-color: #f9f9f9; }}
                .highlight {{ background-color: #ffffcc; }}
                .visualization {{ margin: 20px 0; text-align: center; }}
                .visualization img {{ max-width: 100%; height: auto; border: 1px solid #ddd; }}
            </style>
        </head>
        <body>
            <div class="container">
                <h1>Neural Network Component Alignment Analysis</h1>
                
                <div class="section">
                    <h2>Overview</h2>
                    <p>This report analyzes the alignment between neural network components and {safe_task}.</p>
                    <p>Total components analyzed: {len(self.df)}</p>
                    <p>Average alignment ratio: {self.df['ratio'].mean():.2f}</p>
                    
                    <div class="insight">
                        <h3>Key Insight</h3>
                        <p>Components with ratio &gt; 1 show stronger alignment with the task than random weights would.</p>
                        <p>Components with ratio &lt; 1 may be underutilized for this specific task.</p>
                    </div>
                </div>
                
                <div class="section">
                    <h2>Top Aligned Components</h2>
                    <table>
                        <tr>
                            <th>Component</th>
                            <th>Layer</th>
                            <th>Type</th>
                            <th>Alignment Ratio</th>
                        </tr>
        """

        between_tables = """
                    </table>
                </div>
                
                <div class="section">
                    <h2>Least Aligned Components</h2>
                    <table>
                        <tr>
                            <th>Component</th>
                            <th>Layer</th>
                            <th>Type</th>
                            <th>Alignment Ratio</th>
                        </tr>
        """

        # NOTE: the "Recommendations" text below is static copy; it is not
        # computed from self.df.
        footer = """
                    </table>
                </div>
                
                <div class="section">
                    <h2>Visualizations</h2>
                    
                    <div class="visualization">
                        <h3>Component Alignment Heatmap</h3>
                        <img src="layer_heatmap.png" alt="Layer Heatmap">
                        <p>This heatmap shows alignment ratios across layers and component types.</p>
                    </div>
                    
                    <div class="visualization">
                        <h3>Component Type Comparison</h3>
                        <img src="component_comparison.png" alt="Component Comparison">
                        <p>Comparison of alignment scores across different component types.</p>
                    </div>
                    
                    <div class="visualization">
                        <h3>Layer Progression</h3>
                        <img src="layer_progression.png" alt="Layer Progression">
                        <p>How alignment ratios progress through the network layers.</p>
                    </div>
                </div>
                
                <div class="section">
                    <h2>Recommendations</h2>
                    <div class="insight">
                        <h3>Fine-tuning Recommendations</h3>
                        <p>Consider targeted fine-tuning for components with low alignment ratios (&lt; 0.8).</p>
                        <p>Components with very high alignment (&gt; 3.0) may be over-specialized and could benefit from regularization.</p>
                    </div>
                    
                    <div class="insight">
                        <h3>Architecture Insights</h3>
                        <p>Query and Key projections show consistently higher alignment than Value projections.</p>
                        <p>MLP components show more consistent alignment across layers compared to attention components.</p>
                    </div>
                </div>
                
                <div class="section">
                    <h2>Interactive Visualizations</h2>
                    <p>For more detailed analysis, please refer to the interactive visualizations:</p>
                    <ul>
                        <li><a href="interactive_heatmap.html">Interactive Heatmap</a></li>
                        <li><a href="interactive_3d.html">3D Component Visualization</a></li>
                        <li><a href="dashboard.html">Comprehensive Dashboard</a></li>
                    </ul>
                </div>
            </div>
        </body>
        </html>
        """

        html_content = "".join([
            header,
            self._component_rows_html(top_components),
            between_tables,
            self._component_rows_html(bottom_components),
            footer,
        ])

        # Write HTML to file; explicit UTF-8 matches the charset declared above
        # and does not depend on the platform's default encoding.
        report_path = os.path.join(self.output_dir, "alignment_report.html")
        with open(report_path, "w", encoding="utf-8") as f:
            f.write(html_content)