#!/usr/bin/env python3
"""
Visualization utilities for KSKT model analysis
"""

import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import json
from typing import Dict, List, Optional, Tuple
import pandas as pd
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px


class KSKTVisualizer:
    """Visualization tools for KSKT model analysis.

    Every ``plot_*`` method builds one matplotlib figure, optionally saves it
    to ``save_path`` and then displays it with ``plt.show()``.
    ``create_interactive_dashboard`` builds a plotly figure instead and writes
    it to an HTML file.
    """

    # Style sheets tried in order by ``__init__``. matplotlib 3.6 renamed the
    # bundled seaborn styles to 'seaborn-v0_8*'; older releases only ship the
    # plain 'seaborn' name.
    _STYLE_CANDIDATES = ('seaborn-v0_8', 'seaborn')

    # Labels used when the caller does not name the experts.
    _DEFAULT_EXPERT_NAMES = ('Personality', 'Knowledge', 'Emotional', 'Capability')

    # Added inside the logarithm so a zero probability does not yield -inf.
    _ENTROPY_EPS = 1e-8

    def __init__(self, figsize: Tuple[int, int] = (12, 8)):
        """Store the base figure size and configure global plot styling.

        Note that this changes process-wide matplotlib/seaborn settings.

        Args:
            figsize: (width, height) in inches used by
                ``plot_fusion_weights_evolution``.
        """
        self.figsize = figsize
        for style_name in self._STYLE_CANDIDATES:
            try:
                plt.style.use(style_name)
            except OSError:
                # Unknown style on this matplotlib version; try the next one
                # and fall back to the current default if none is available.
                continue
            break
        sns.set_palette("husl")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _finish_figure(fig, save_path: Optional[str], description: str) -> None:
        """Lay out, optionally save, show and release a matplotlib figure.

        Args:
            fig: The matplotlib figure to finalize.
            save_path: Destination file; nothing is written when falsy.
            description: Prefix of the "... saved to <path>" message.
        """
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"{description} saved to {save_path}")

        plt.show()

        # pyplot keeps every figure alive until it is closed, so repeated
        # calls would otherwise accumulate figures. In interactive mode
        # show() does not block, and closing here would tear the window down
        # immediately, so the figure is left to the user in that case.
        if not plt.isinteractive():
            plt.close(fig)

    @staticmethod
    def _mean_routing(routing_probs: torch.Tensor) -> np.ndarray:
        """Average routing probabilities over every dimension but the last.

        The last dimension is treated as the expert axis. ``detach()`` lets
        tensors that still track gradients be converted, and the float64 cast
        covers dtypes numpy cannot represent (e.g. bfloat16).
        """
        flat = routing_probs.detach().reshape(-1, routing_probs.shape[-1])
        return flat.mean(dim=0).to(device='cpu', dtype=torch.float64).numpy()

    @classmethod
    def _resolve_expert_names(cls, expert_names: Optional[List[str]],
                              num_experts: int) -> List[str]:
        """Return exactly ``num_experts`` labels without mutating the input.

        Missing labels are filled with 'Expert <n>' (1-based position);
        surplus labels are dropped because no routing column maps to them.
        """
        names = list(cls._DEFAULT_EXPERT_NAMES if expert_names is None else expert_names)
        names.extend(f'Expert {idx + 1}' for idx in range(len(names), num_experts))
        return names[:num_experts]

    @classmethod
    def _routing_entropy(cls, probs: np.ndarray) -> float:
        """Shannon entropy (natural log) of one routing distribution."""
        return float(-np.sum(probs * np.log(probs + cls._ENTROPY_EPS)))

    @staticmethod
    def _collect_scenarios(results: Dict[str, Dict]) -> List[str]:
        """Return every scenario named by any model, in first-seen order."""
        return list(dict.fromkeys(
            scenario for model_scores in results.values() for scenario in model_scores
        ))

    @staticmethod
    def _metric_series(phase_logs: List[Dict], key: str) -> Tuple[List, List]:
        """Return parallel (steps, values) lists for entries that carry ``key``.

        Entries lacking either 'global_step' or ``key`` are skipped so that a
        sparsely logged metric cannot raise ``KeyError``.
        """
        steps, values = [], []
        for entry in phase_logs:
            if 'global_step' in entry and key in entry:
                steps.append(entry['global_step'])
                values.append(entry[key])
        return steps, values

    # ------------------------------------------------------------------
    # Public plotting API
    # ------------------------------------------------------------------

    def plot_fusion_weights_evolution(self, fusion_weights_history: List[Tuple[torch.Tensor, torch.Tensor]],
                                    save_path: Optional[str] = None) -> None:
        """Plot evolution of fusion weights during generation.

        Draws three stacked panels: mean alpha per step, mean beta per step,
        and the absolute difference between the two.

        Args:
            fusion_weights_history: One ``(alpha, beta)`` tensor pair per
                step; each tensor is reduced to its overall mean.
            save_path: Optional image path to save the figure to.
        """
        if not fusion_weights_history:
            print("No fusion weights history provided")
            return

        # Collapse each tensor to a single scalar per generation step.
        alphas = [alpha.mean().item() for alpha, _ in fusion_weights_history]
        betas = [beta.mean().item() for _, beta in fusion_weights_history]
        balance = [abs(a - b) for a, b in zip(alphas, betas)]
        steps = range(len(alphas))

        fig, (ax1, ax2, ax3) = plt.subplots(
            3, 1, figsize=(self.figsize[0], self.figsize[1] * 1.5)
        )

        # Alpha values (self-awareness)
        ax1.plot(steps, alphas, 'b-', label='Self-awareness (α)', linewidth=2)
        ax1.set_ylabel('Alpha Weight')
        ax1.set_title('Self-Understanding Stream Weight Evolution')
        ax1.grid(True, alpha=0.3)
        ax1.legend()

        # Beta values (other-awareness)
        ax2.plot(steps, betas, 'r-', label='Other-awareness (β)', linewidth=2)
        ax2.set_ylabel('Beta Weight')
        ax2.set_title('Other-Understanding Stream Weight Evolution')
        ax2.grid(True, alpha=0.3)
        ax2.legend()

        # Balance (difference), with the 0.1 "good balance" reference line
        ax3.plot(steps, balance, 'g-', label='|α - β|', linewidth=2)
        ax3.axhline(y=0.1, color='orange', linestyle='--', alpha=0.7, label='Good Balance (<0.1)')
        ax3.set_xlabel('Generation Step')
        ax3.set_ylabel('Balance Score')
        ax3.set_title('Dual-Perspective Balance')
        ax3.grid(True, alpha=0.3)
        ax3.legend()

        self._finish_figure(fig, save_path, "Fusion weights plot")

    def plot_expert_routing_analysis(self, routing_probs_history: List[torch.Tensor],
                                   expert_names: Optional[List[str]] = None,
                                   save_path: Optional[str] = None) -> Optional[pd.DataFrame]:
        """Analyze and plot expert routing patterns.

        Produces a 2x2 figure: per-expert probability over time, mean
        utilization, variance across steps, and routing entropy per step.

        Args:
            routing_probs_history: One tensor per step whose last dimension
                indexes experts; all leading dimensions are averaged.
            expert_names: Labels for the experts. Defaults to the four
                built-in names; padded with 'Expert <n>' or truncated to
                match the number of experts found in the history.
            save_path: Optional image path to save the figure to.

        Returns:
            Long-format DataFrame with columns 'step', 'expert' and
            'probability', or ``None`` when the history is empty.
        """
        if not routing_probs_history:
            print("No routing probabilities history provided")
            return

        num_experts = max(probs.shape[-1] for probs in routing_probs_history)
        expert_names = self._resolve_expert_names(expert_names, num_experts)

        # One row per (step, expert) holding the step-averaged probability.
        routing_data = []
        for step, routing_probs in enumerate(routing_probs_history):
            step_mean = self._mean_routing(routing_probs)
            routing_data.extend(
                {'step': step, 'expert': name, 'probability': float(prob)}
                for name, prob in zip(expert_names, step_mean)
            )

        df = pd.DataFrame(routing_data)

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

        # 1. Expert usage over time
        for expert in expert_names:
            expert_data = df[df['expert'] == expert]
            ax1.plot(expert_data['step'], expert_data['probability'],
                     label=expert, linewidth=2, marker='o', markersize=3)

        ax1.set_xlabel('Generation Step')
        ax1.set_ylabel('Routing Probability')
        ax1.set_title('Expert Activation Over Time')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # 2. Average expert utilization, against the uniform share 1/N
        uniform_share = 1.0 / len(expert_names)
        avg_utilization = df.groupby('expert')['probability'].mean()
        bars = ax2.bar(avg_utilization.index, avg_utilization.values)
        ax2.set_ylabel('Average Probability')
        ax2.set_title('Average Expert Utilization')
        ax2.axhline(y=uniform_share, color='red', linestyle='--', alpha=0.7,
                    label=f'Equal Usage ({uniform_share:.2f})')
        ax2.legend()

        for bar, value in zip(bars, avg_utilization.values):
            ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                     f'{value:.3f}', ha='center', va='bottom')

        # 3. Expert usage variance (specialization indicator). The sample
        # variance of a single step is NaN; show it as 0 so the bars and
        # their labels keep finite positions.
        expert_variance = df.groupby('expert')['probability'].var().fillna(0.0)
        bars = ax3.bar(expert_variance.index, expert_variance.values, color='orange')
        ax3.set_ylabel('Probability Variance')
        ax3.set_title('Expert Specialization (Higher = More Specialized)')

        for bar, value in zip(bars, expert_variance.values):
            ax3.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.001,
                     f'{value:.4f}', ha='center', va='bottom')

        # 4. Routing entropy over time (diversity measure)
        entropy_df = pd.DataFrame([
            {'step': step_num,
             'entropy': self._routing_entropy(step_df['probability'].to_numpy())}
            for step_num, step_df in df.groupby('step')
        ])
        ax4.plot(entropy_df['step'], entropy_df['entropy'], 'purple', linewidth=2, marker='s', markersize=4)
        ax4.set_xlabel('Generation Step')
        ax4.set_ylabel('Routing Entropy')
        ax4.set_title('Routing Diversity Over Time')
        # Entropy of the uniform distribution is the upper bound.
        max_entropy = np.log(len(expert_names))
        ax4.axhline(y=max_entropy, color='red', linestyle='--', alpha=0.7, label=f'Max Entropy ({max_entropy:.2f})')
        ax4.legend()
        ax4.grid(True, alpha=0.3)

        self._finish_figure(fig, save_path, "Expert routing analysis")

        return df

    def plot_dual_perspective_comparison(self, results: Dict[str, Dict],
                                       save_path: Optional[str] = None) -> Optional[pd.DataFrame]:
        """Compare dual-perspective reasoning across different models/scenarios.

        Draws a model-by-scenario heatmap next to a grouped bar chart.

        Args:
            results: Mapping ``model -> {scenario -> score}``. Scenarios are
                collected from every model, so a scenario missing for some
                model simply yields an empty heatmap cell.
            save_path: Optional image path to save the figure to.

        Returns:
            Long-format DataFrame with columns 'Model', 'Scenario' and
            'Score', or ``None`` when there is nothing to plot.
        """
        if not results:
            print("No results data provided")
            return

        models = list(results)
        scenarios = self._collect_scenarios(results)

        if not scenarios:
            print("No results data provided")
            return

        comparison_data = [
            {'Model': model, 'Scenario': scenario, 'Score': results[model][scenario]}
            for model in models
            for scenario in scenarios
            if scenario in results[model]
        ]

        df = pd.DataFrame(comparison_data)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

        # 1. Heatmap of model performance across scenarios
        pivot_df = df.pivot(index='Model', columns='Scenario', values='Score')
        sns.heatmap(pivot_df, annot=True, cmap='RdYlGn', ax=ax1, cbar_kws={'label': 'Performance Score'})
        ax1.set_title('Model Performance Across Scenarios')
        ax1.set_xlabel('Conflict Scenarios')
        ax1.set_ylabel('Models')

        # 2. Bar plot comparison
        sns.barplot(data=df, x='Scenario', y='Score', hue='Model', ax=ax2)
        ax2.set_title('Dual-Perspective Performance Comparison')
        ax2.set_ylabel('Performance Score')
        ax2.tick_params(axis='x', rotation=45)
        ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        self._finish_figure(fig, save_path, "Dual-perspective comparison")

        return df

    def create_interactive_dashboard(self, analysis_results: Dict, output_path: str = "kskt_dashboard.html"):
        """Create interactive dashboard for KSKT analysis results.

        Each panel is drawn only when its key is present in
        ``analysis_results``:

        * ``'fusion_weights'``: sequence of mappings with 'alpha' and 'beta'.
        * ``'expert_routing'``: mapping of expert name to probability.
        * ``'performance_metrics'``: mapping that may hold
          'character_consistency' (x values) and 'user_satisfaction'
          (y values).
        * ``'dual_perspective_balance'``: value shown on the gauge.

        Args:
            analysis_results: Data for the panels as described above.
            output_path: HTML file the dashboard is written to.

        Returns:
            The plotly figure that was written.
        """
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('Fusion Weights Evolution', 'Expert Routing',
                          'Performance Metrics', 'Dual-Perspective Balance'),
            specs=[[{"secondary_y": True}, {"type": "bar"}],
                   [{"type": "scatter"}, {"type": "indicator"}]]
        )

        # 1. Fusion weights evolution
        if 'fusion_weights' in analysis_results:
            fusion_data = analysis_results['fusion_weights']
            steps = list(range(len(fusion_data)))
            alphas = [fw['alpha'] for fw in fusion_data]
            betas = [fw['beta'] for fw in fusion_data]

            fig.add_trace(
                go.Scatter(x=steps, y=alphas, name='Self-awareness (α)', line=dict(color='blue')),
                row=1, col=1
            )
            fig.add_trace(
                go.Scatter(x=steps, y=betas, name='Other-awareness (β)', line=dict(color='red')),
                row=1, col=1
            )

        # 2. Expert routing
        if 'expert_routing' in analysis_results:
            routing_data = analysis_results['expert_routing']
            expert_names = list(routing_data.keys())
            probabilities = list(routing_data.values())

            fig.add_trace(
                go.Bar(x=expert_names, y=probabilities, name='Expert Usage'),
                row=1, col=2
            )

        # 3. Performance metrics scatter
        if 'performance_metrics' in analysis_results:
            perf_data = analysis_results['performance_metrics']

            fig.add_trace(
                go.Scatter(
                    x=perf_data.get('character_consistency', []),
                    y=perf_data.get('user_satisfaction', []),
                    mode='markers',
                    marker=dict(size=10, opacity=0.7),
                    name='Performance Points'
                ),
                row=2, col=1
            )

        # 4. Balance indicator: gauge bands and threshold use the same 0.1
        # reference as the delta.
        if 'dual_perspective_balance' in analysis_results:
            balance_score = analysis_results['dual_perspective_balance']

            fig.add_trace(
                go.Indicator(
                    mode="gauge+number+delta",
                    value=balance_score,
                    domain={'x': [0, 1], 'y': [0, 1]},
                    title={'text': "Balance Score"},
                    delta={'reference': 0.1},
                    gauge={
                        'axis': {'range': [None, 1]},
                        'bar': {'color': "darkblue"},
                        'steps': [
                            {'range': [0, 0.1], 'color': "lightgray"},
                            {'range': [0.1, 0.3], 'color': "yellow"},
                            {'range': [0.3, 1], 'color': "red"}
                        ],
                        'threshold': {
                            'line': {'color': "red", 'width': 4},
                            'thickness': 0.75,
                            'value': 0.1
                        }
                    }
                ),
                row=2, col=2
            )

        fig.update_layout(
            height=800,
            title_text="KSKT Model Analysis Dashboard",
            showlegend=True
        )

        fig.write_html(output_path)
        print(f"Interactive dashboard saved to {output_path}")

        return fig

    def plot_training_dynamics(self, training_log_path: str, save_path: Optional[str] = None) -> None:
        """Plot training dynamics from a JSON training log.

        The file must contain a JSON list of entries. An entry belongs to a
        phase when its 'phase' field equals the phase name; the plotted
        metrics are read from 'global_step', '<phase>/step_loss',
        '<phase>/load_balance_loss' and '<phase>/fusion_balance'. Entries
        that lack a metric are skipped for that metric only.

        Loading problems are reported with a printed message and the method
        returns without plotting.

        Args:
            training_log_path: Path of the JSON log file.
            save_path: Optional image path to save the figure to.
        """
        try:
            with open(training_log_path, 'r', encoding='utf-8') as f:
                logs = json.load(f)
        except (OSError, ValueError, TypeError) as e:
            # OSError: unreadable file; ValueError: malformed JSON or bad
            # encoding; TypeError: path of an unsupported type.
            print(f"Error loading training logs: {e}")
            return

        if not isinstance(logs, list):
            print("Error loading training logs: expected a JSON list of log entries, "
                  f"got {type(logs).__name__}")
            return

        phases = ['self_understanding', 'other_understanding', 'mutual_understanding']

        fig, axes = plt.subplots(2, 3, figsize=(18, 10))

        for i, phase in enumerate(phases):
            phase_logs = [
                log for log in logs
                if isinstance(log, dict) and log.get('phase') == phase
            ]

            # A phase without entries keeps its column of axes empty.
            if not phase_logs:
                continue

            loss_ax, aux_ax = axes[0, i], axes[1, i]

            # Plot training loss
            steps, losses = self._metric_series(phase_logs, f'{phase}/step_loss')
            loss_ax.plot(steps, losses, label=f'{phase.title()} Loss')
            loss_ax.set_title(f'Phase {i+1}: {phase.title()}')
            loss_ax.set_xlabel('Global Step')
            loss_ax.set_ylabel('Loss')
            loss_ax.grid(True, alpha=0.3)
            loss_ax.legend()

            # Plot auxiliary losses that were actually logged
            has_aux_series = False
            aux_specs = (
                (f'{phase}/load_balance_loss', 'orange', 'Load Balance Loss'),
                (f'{phase}/fusion_balance', 'green', 'Fusion Balance'),
            )
            for key, color, label in aux_specs:
                aux_steps, aux_values = self._metric_series(phase_logs, key)
                if aux_values:
                    aux_ax.plot(aux_steps, aux_values, color=color, label=label)
                    has_aux_series = True

            aux_ax.set_title(f'Auxiliary Losses - {phase.title()}')
            aux_ax.set_xlabel('Global Step')
            aux_ax.set_ylabel('Auxiliary Loss')
            aux_ax.grid(True, alpha=0.3)
            # legend() on axes without labelled artists only emits a warning.
            if has_aux_series:
                aux_ax.legend()

        self._finish_figure(fig, save_path, "Training dynamics plot")


# Example usage and demo functions
def demo_visualizations():
    """Demonstrate visualization capabilities.

    Generates synthetic fusion weights, routing probabilities and comparison
    scores (NumPy's global random generator is seeded with 42), then runs the
    plotting methods of ``KSKTVisualizer``. Writes
    ``demo_fusion_weights.png``, ``demo_expert_routing.png``,
    ``demo_dual_perspective.png`` and ``demo_dashboard.html`` to the current
    working directory.
    """
    print("KSKT Visualization Demo")
    print("=" * 30)

    # Seed the global generator so the synthetic data is reproducible.
    np.random.seed(42)

    # Sample fusion weights evolution
    fusion_weights_sample = []
    for step in range(20):
        # Slow sinusoidal drift plus noise around an even split
        base_alpha = 0.5 + 0.2 * np.sin(step * 0.3) + np.random.normal(0, 0.05)
        base_beta = 1 - base_alpha + np.random.normal(0, 0.02)

        # Renormalize so alpha + beta == 1 after the noise was added
        total = base_alpha + base_beta
        alpha = torch.tensor([[base_alpha / total] * 10] * 2)  # [batch_size=2, seq_len=10]
        beta = torch.tensor([[base_beta / total] * 10] * 2)

        fusion_weights_sample.append((alpha, beta))

    # Sample expert routing probabilities
    routing_probs_sample = []
    for _ in range(20):
        # Dirichlet concentrations bias the draw toward the first and third
        # experts (Personality and Emotional with the default names).
        probs = np.random.dirichlet([2, 1, 3, 1.5])
        # Stack into a single ndarray first: building a tensor from a list of
        # ndarrays takes PyTorch's slow path and emits a UserWarning.
        routing_tensor = torch.from_numpy(np.stack([probs, probs]))  # [batch_size=2, num_experts=4]
        routing_probs_sample.append(routing_tensor)

    # Sample dual-perspective results
    dual_perspective_results = {
        'KSKT': {
            'knowledge_boundary': 0.87,
            'value_system_conflict': 0.83,
            'emotional_support': 0.91,
            'expertise_boundary': 0.79
        },
        'Baseline': {
            'knowledge_boundary': 0.45,
            'value_system_conflict': 0.52,
            'emotional_support': 0.38,
            'expertise_boundary': 0.41
        },
        'GPT-4': {
            'knowledge_boundary': 0.72,
            'value_system_conflict': 0.68,
            'emotional_support': 0.75,
            'expertise_boundary': 0.70
        }
    }

    visualizer = KSKTVisualizer()

    print("1. Generating fusion weights evolution plot...")
    visualizer.plot_fusion_weights_evolution(fusion_weights_sample, "demo_fusion_weights.png")

    print("2. Generating expert routing analysis...")
    visualizer.plot_expert_routing_analysis(routing_probs_sample, save_path="demo_expert_routing.png")

    print("3. Generating dual-perspective comparison...")
    visualizer.plot_dual_perspective_comparison(dual_perspective_results, "demo_dual_perspective.png")

    print("4. Creating interactive dashboard...")
    dashboard_data = {
        'fusion_weights': [
            {'alpha': 0.6 + np.random.normal(0, 0.1), 'beta': 0.4 + np.random.normal(0, 0.1)}
            for _ in range(20)
        ],
        'expert_routing': {'Personality': 0.35, 'Knowledge': 0.15, 'Emotional': 0.30, 'Capability': 0.20},
        'dual_perspective_balance': 0.08
    }

    visualizer.create_interactive_dashboard(dashboard_data, "demo_dashboard.html")

    print("Demo completed! Check generated files.")


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo_visualizations()
