#!/usr/bin/env python3
"""
Gradio Web Interface for KSKT Model
Provides an easy-to-use web interface for testing the model
"""

import gradio as gr
import torch
import json
import os
import sys
from typing import Dict, Tuple, Optional
import matplotlib.pyplot as plt
import numpy as np

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from inference import KSKTGenerator
from utils.visualization import KSKTVisualizer


class KSKTGradioInterface:
    """Gradio interface for KSKT model"""
    
    def __init__(self, model_path: str):
        """Create the model wrapper, the visualizer and the canned demo data.

        Args:
            model_path: Forwarded unchanged to ``KSKTGenerator``.
        """
        self.generator = KSKTGenerator(model_path)
        self.visualizer = KSKTVisualizer()

        # Ready-made personas, keyed by the name shown in the character dropdown.
        knight = "You are Sir Gareth, a noble knight from the 12th century. You are devoted to chivalry, honor, and protecting the innocent. You value courage, loyalty, and have strong religious faith."
        lady = "You are Lady Eleanor, an educated Victorian lady from the 1880s. You are interested in literature and social reform but are bound by the social conventions of your time. You are intelligent, reserved, and proper."
        stoic = "You are Marcus, a follower of Stoic philosophy in ancient Rome. You believe in virtue, reason, and accepting what cannot be changed. You are rational, calm, and seek wisdom above all."
        pirate = "You are Captain Blackwood, a notorious pirate from the Caribbean in the early 18th century. You are cunning, adventurous, and live by your own code. You value freedom above all else."
        explorer = "You are Commander Nova, a space explorer from the year 2350. You have advanced technological knowledge and have traveled to many planets. You are curious, brave, and optimistic about the future."

        self.sample_characters = {
            "Medieval Knight": knight,
            "Victorian Lady": lady,
            "Stoic Philosopher": stoic,
            "Pirate Captain": pirate,
            "Space Explorer": explorer,
        }

        # Example user queries grouped by the kind of conflict they provoke;
        # one query per group is picked at random by load_conflict_scenario.
        knowledge_queries = [
            "Can you explain quantum mechanics?",
            "How does a smartphone work?",
            "What do you know about artificial intelligence?",
        ]
        value_queries = [
            "What do you think about gender equality?",
            "Should people have the right to choose their own religion?",
            "Is it okay to disobey authority for moral reasons?",
        ]
        emotional_queries = [
            "I'm feeling really depressed. Can you help me?",
            "I'm going through a difficult breakup. What should I do?",
            "I'm struggling with anxiety about my future.",
        ]
        expertise_queries = [
            "Can you perform surgery on my injury?",
            "Help me write a legal contract.",
            "I need you to fix my computer.",
        ]

        self.conflict_scenarios = {
            "Knowledge Boundary": knowledge_queries,
            "Value System Conflict": value_queries,
            "Emotional Support": emotional_queries,
            "Expertise Boundary": expertise_queries,
        }
    
    def generate_response(
        self, 
        character_profile: str, 
        user_query: str,
        max_length: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        use_thinking: bool = True
    ) -> Tuple[str, str, str, str]:
        """Run the generator and format its output for the four UI panels.

        Args:
            character_profile: Persona description handed to the generator.
            user_query: The user's message.
            max_length: Forwarded to the generator as ``max_length``.
            temperature: Forwarded to the generator as ``temperature``.
            top_p: Forwarded to the generator as ``top_p``.
            use_thinking: Forwarded to the generator as ``use_thinking``.

        Returns:
            ``(response, thinking, analysis_markdown, visualization_html)``.
            If anything raises, the first element carries an error message
            and the other three are empty strings.
        """
        request = dict(
            character_profile=character_profile,
            user_query=user_query,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            use_thinking=use_thinking,
        )

        try:
            result = self.generator.generate_response(**request)

            reply = result['response']
            thoughts = result.get('thinking_process', 'No thinking process available')
            analysis = result['analysis']

            # Markdown summary: a leading blank line, the three scores, then
            # one bullet per routed expert.
            md_lines = [
                "",
                "**Dual-Perspective Analysis:**",
                f"- Self-awareness Score: {analysis['self_awareness_score']:.3f}",
                f"- Other-awareness Score: {analysis['other_awareness_score']:.3f}",
                f"- Fusion Balance: {analysis['fusion_balance']:.3f}",
                "",
                "**Expert Routing:**",
            ]
            md_lines.extend(
                f"- {name}: {weight:.3f}"
                for name, weight in analysis['expert_routing'].items()
            )
            analysis_md = "\n".join(md_lines) + "\n"

            chart_html = self._create_analysis_viz(analysis)
            return reply, thoughts, analysis_md, chart_html

        except Exception as exc:
            # Surface the failure in the response box instead of crashing the UI.
            return f"Error generating response: {exc}", "", "", ""
    
    def _create_analysis_viz(self, analysis: Dict) -> str:
        """Render the analysis as a two-panel chart embedded in an ``<img>`` tag.

        The left panel is a bar chart of ``analysis['expert_routing']``; the
        right panel compares ``self_awareness_score`` with
        ``other_awareness_score``. The PNG is produced in memory and inlined
        as a base64 data URI, so nothing is written to disk.

        Args:
            analysis: Mapping with the keys ``expert_routing`` (name ->
                probability), ``self_awareness_score`` and
                ``other_awareness_score``.

        Returns:
            An HTML ``<img>`` snippet, or an HTML ``<p>`` with the (escaped)
            error message if anything goes wrong.
        """
        import base64
        import html
        from io import BytesIO

        try:
            # Pull everything out of `analysis` first so a malformed payload
            # is reported before any figure is allocated.
            routing = analysis['expert_routing']
            experts = list(routing.keys())
            probabilities = list(routing.values())
            scores = [analysis['self_awareness_score'], analysis['other_awareness_score']]

            # A standalone Figure (not pyplot) keeps this callback free of
            # pyplot's global figure registry: nothing needs closing and
            # nothing leaks if drawing raises part-way through.
            from matplotlib.figure import Figure

            fig = Figure(figsize=(12, 4))
            ax1, ax2 = fig.subplots(1, 2)

            # Expert routing bar chart
            bars = ax1.bar(experts, probabilities, color=['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4'])
            ax1.set_title('Expert Routing Probabilities')
            ax1.set_ylabel('Probability')
            ax1.tick_params(axis='x', rotation=45)

            # Value labels just above each bar
            for bar, prob in zip(bars, probabilities):
                height = bar.get_height()
                ax1.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                         f'{prob:.2f}', ha='center', va='bottom')

            # Dual-perspective balance
            perspectives = ['Self-Awareness', 'Other-Awareness']
            bars2 = ax2.bar(perspectives, scores, color=['#FF6B6B', '#4ECDC4'])
            ax2.set_title('Dual-Perspective Reasoning')
            ax2.set_ylabel('Score')
            ax2.set_ylim(0, 1)

            # Reference line at 0.1; the legend reports the actual gap
            # between the two scores.
            score_gap = abs(scores[0] - scores[1])
            ax2.axhline(y=0.1, color='orange', linestyle='--', alpha=0.7,
                        label=f'Good Balance (<0.1)\nCurrent: {score_gap:.3f}')
            ax2.legend()

            for bar, score in zip(bars2, scores):
                height = bar.get_height()
                ax2.text(bar.get_x() + bar.get_width() / 2., height + 0.02,
                         f'{score:.2f}', ha='center', va='bottom')

            fig.tight_layout()

            # Encode the PNG in memory and inline it as a data URI.
            buffer = BytesIO()
            fig.savefig(buffer, format='png', dpi=100, bbox_inches='tight')
            img_base64 = base64.b64encode(buffer.getvalue()).decode()

            return f'<img src="data:image/png;base64,{img_base64}" style="width:100%; max-width:800px;">'

        except Exception as e:
            # The message is rendered as HTML, so escape markup characters.
            return f"<p>Error creating visualization: {html.escape(str(e), quote=False)}</p>"
    
    def load_sample_character(self, character_name: str) -> str:
        """Return the stored profile for ``character_name``, or ``""`` if unknown."""
        profiles = self.sample_characters
        if character_name in profiles:
            return profiles[character_name]
        return ""
    
    def load_conflict_scenario(self, scenario_type: str) -> str:
        """Pick one query at random from the given scenario group.

        An unknown ``scenario_type`` yields the literal ``"Custom scenario"``.
        """
        import random

        known = self.conflict_scenarios
        candidates = known[scenario_type] if scenario_type in known else ["Custom scenario"]
        return random.choice(candidates)
    
    def create_interface(self):
        """Assemble the Gradio ``Blocks`` app and return it (not yet launched).

        The page consists of a header, a left column with the character,
        query and sampling controls, a right column with four result tabs,
        a set of clickable examples and a footer. Changing either dropdown
        refills the matching textbox; the button runs ``generate_response``.
        """
        default_character = "Medieval Knight"
        # Sampling settings shared by every example row:
        # max_length, temperature, top_p, use_thinking.
        example_settings = [512, 0.7, 0.9, True]
        example_pairs = [
            ("Medieval Knight", "Can you help me with my smartphone?"),
            ("Victorian Lady", "What do you think about women's rights?"),
            ("Stoic Philosopher", "I'm feeling really anxious about my future."),
            ("Pirate Captain", "Can you perform surgery on my injury?"),
        ]

        with gr.Blocks(title="KSKT: Know Thyself, Know Thy User", theme=gr.themes.Soft()) as demo:

            gr.HTML("""
            <div style="text-align: center; margin-bottom: 20px;">
                <h1>🎭 KSKT: Know Thyself, Know Thy User</h1>
                <p style="font-size: 18px; color: #666;">
                    Dual-Perspective Reasoning for Role-Playing Language Models
                </p>
                <p style="font-size: 14px; color: #888;">
                    Experience balanced character authenticity and user satisfaction
                </p>
            </div>
            """)

            with gr.Row():
                # Left column: everything the user can edit.
                with gr.Column(scale=1):
                    gr.HTML("<h3>📝 Character Setup</h3>")

                    sample_character = gr.Dropdown(
                        label="Select Sample Character",
                        choices=list(self.sample_characters),
                        value=default_character,
                    )
                    character_profile = gr.Textbox(
                        label="Character Profile",
                        lines=4,
                        placeholder="Describe your character's personality, background, and traits...",
                        value=self.sample_characters[default_character],
                    )

                    gr.HTML("<h3>💬 Conversation</h3>")

                    conflict_type = gr.Dropdown(
                        label="Select Conflict Scenario Type",
                        choices=list(self.conflict_scenarios),
                        value="Knowledge Boundary",
                    )
                    user_query = gr.Textbox(
                        label="User Query",
                        lines=2,
                        placeholder="Ask your question or make a request...",
                        value="Can you explain quantum mechanics?",
                    )

                    gr.HTML("<h3>⚙️ Generation Settings</h3>")

                    with gr.Row():
                        max_length = gr.Slider(label="Max Length", minimum=100, maximum=1000, value=512, step=50)
                        temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, value=0.7, step=0.1)

                    with gr.Row():
                        top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.05)
                        use_thinking = gr.Checkbox(label="Enable Thinking Process", value=True)

                    generate_btn = gr.Button("🚀 Generate Response", variant="primary", size="lg")

                # Right column: one tab per element of generate_response's result.
                with gr.Column(scale=2):
                    gr.HTML("<h3>🤖 Model Response</h3>")

                    with gr.Tabs():
                        with gr.Tab("Response"):
                            response_output = gr.Textbox(
                                label="Character Response",
                                lines=10,
                                show_copy_button=True,
                            )

                        with gr.Tab("Thinking Process"):
                            thinking_output = gr.Textbox(
                                label="Internal Thinking Process",
                                lines=10,
                                show_copy_button=True,
                            )

                        with gr.Tab("Analysis"):
                            analysis_output = gr.Markdown(label="Dual-Perspective Analysis")

                        with gr.Tab("Visualization"):
                            viz_output = gr.HTML(label="Analysis Visualization")

            # Components feeding / receiving generate_response, in argument
            # and return order. Each consumer below gets its own list copy.
            generation_inputs = (character_profile, user_query, max_length, temperature, top_p, use_thinking)
            generation_outputs = (response_output, thinking_output, analysis_output, viz_output)

            gr.HTML("<hr><h3>📚 Example Interactions</h3>")

            gr.Examples(
                examples=[
                    [self.sample_characters[persona], query, *example_settings]
                    for persona, query in example_pairs
                ],
                inputs=list(generation_inputs),
                outputs=list(generation_outputs),
                fn=self.generate_response,
                cache_examples=False,
                label="Try these examples",
            )

            # Event wiring
            sample_character.change(
                fn=self.load_sample_character,
                inputs=[sample_character],
                outputs=[character_profile],
            )
            conflict_type.change(
                fn=self.load_conflict_scenario,
                inputs=[conflict_type],
                outputs=[user_query],
            )
            generate_btn.click(
                fn=self.generate_response,
                inputs=list(generation_inputs),
                outputs=list(generation_outputs),
            )

            gr.HTML("""
            <div style="text-align: center; margin-top: 30px; padding-top: 20px; border-top: 1px solid #eee;">
                <p style="color: #888; font-size: 12px;">
                    KSKT: Dual-Perspective Reasoning Architecture for Role-Playing Language Models<br>
                    Balancing Character Authenticity with User Satisfaction
                </p>
            </div>
            """)

        return demo


def main() -> int:
    """Parse CLI arguments, load the model and serve the Gradio interface.

    Command-line options:
        --model_path  (required) path to the trained KSKT model.
        --port        port to serve on (default 7860).
        --share       ask Gradio for a public link.
        --debug       enable Gradio debug mode and print tracebacks on failure.

    Returns:
        Process exit status: ``0`` once ``launch`` returns normally, ``1``
        if the model path does not exist or loading/launching raises.
        Error messages go to stderr so they are not mixed into normal output.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Launch KSKT Gradio Interface")
    parser.add_argument('--model_path', type=str, required=True, help='Path to trained KSKT model')
    parser.add_argument('--port', type=int, default=7860, help='Port to run the interface')
    parser.add_argument('--share', action='store_true', help='Create public link')
    parser.add_argument('--debug', action='store_true', help='Enable debug mode')

    args = parser.parse_args()

    if not os.path.exists(args.model_path):
        print(f"Error: Model file not found at {args.model_path}", file=sys.stderr)
        return 1

    print("Loading KSKT model...")
    try:
        interface_app = KSKTGradioInterface(args.model_path)
        interface = interface_app.create_interface()

        print(f"Launching interface on port {args.port}")
        print("Interface will be available at:")
        print(f"  Local: http://localhost:{args.port}")

        if args.share:
            print("  Public link will be generated...")

        interface.launch(
            server_port=args.port,
            share=args.share,
            debug=args.debug,
            show_error=True
        )

    except Exception as e:
        # Top-level boundary: report the failure and signal it via the exit
        # status instead of letting the script end "successfully".
        print(f"Error launching interface: {e}", file=sys.stderr)
        if args.debug:
            import traceback
            traceback.print_exc()
        return 1

    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shells and CI can detect failures.
    sys.exit(main())
