#!/usr/bin/env python3
"""
KSKT Evaluation Script
Evaluates KSKT model on CharacterBench and role-user conflict scenarios.
"""

import os
import json
import torch
import numpy as np
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from tqdm import tqdm
import argparse
from typing import Dict, List, Tuple, Optional
import re
from sklearn.metrics import accuracy_score, f1_score
import pandas as pd

from kskt_model import KSKTForCausalLM, KSKTConfig


class CharacterBenchEvaluator:
    """Evaluator for the CharacterBench dataset.

    Each sample is answered with greedy decoding and scored on a 1-5 scale by
    keyword heuristics chosen from the sample's ``metric`` field.
    """

    # Keywords used by ``_evaluate_emotion``, keyed by the sample's
    # ``expected_emotion`` value. All entries are lowercase.
    _EMOTION_KEYWORDS = {
        'happy': ['happy', 'joy', 'excited', 'pleased', 'glad'],
        'sad': ['sad', 'disappointed', 'sorry', 'regret'],
        'angry': ['angry', 'frustrated', 'annoyed'],
        'neutral': ['understand', 'see', 'think'],
        'supportive': ['support', 'help', 'encourage', 'care'],
    }

    def __init__(self, model: KSKTForCausalLM, tokenizer, device: str = 'cuda'):
        """Move ``model`` to ``device`` and switch it to eval mode.

        Args:
            model: KSKT model to evaluate.
            tokenizer: Tokenizer used to encode prompts and decode generations.
            device: Device for the model and all input tensors.
        """
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()

        # CharacterBench metrics
        self.metrics = [
            'Memory', 'Knowledge_FA', 'Knowledge_BC_K', 'Persona_AC_b', 'Persona_AC_h',
            'Persona_BC_b_P', 'Persona_BC_h_P', 'Emotion_ES', 'Emotion_ER',
            'Morality_MS', 'Morality_MR', 'Believability_HL', 'Believability_EG'
        ]

    def evaluate_character_bench(self, data_path: str, max_samples: Optional[int] = None) -> Dict[str, float]:
        """Evaluate on a CharacterBench JSON file.

        Args:
            data_path: Path to a UTF-8 JSON file holding a list of sample dicts.
            max_samples: If not None, only the first ``max_samples`` samples
                are evaluated.

        Returns:
            Mean score per metric in ``self.metrics``. Metrics with no samples
            report 0.0. There is also an ``'Overall'`` entry: the mean over
            only those metrics that had at least one sample, or 0.0 if none did.
        """
        with open(data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if max_samples is not None:
            data = data[:max_samples]

        results: Dict[str, List[float]] = {metric: [] for metric in self.metrics}

        for sample in tqdm(data, desc="Evaluating CharacterBench"):
            metric_type = sample.get('metric', '')
            # Skip unknown metrics before paying for generation.
            if metric_type not in results:
                continue

            response = self._generate_response(
                sample.get('character_profile', ''),
                sample.get('query', '')
            )
            score = self._compute_metric_score(
                response, sample.get('expected_response', ''), metric_type, sample
            )
            results[metric_type].append(score)

        avg_results = {
            metric: float(np.mean(scores)) if scores else 0.0
            for metric, scores in results.items()
        }

        # Average only metrics that were actually evaluated. The 0.0
        # placeholder for empty metrics lies below the 1-5 scale and would
        # otherwise drag the overall score down.
        evaluated = [avg_results[metric] for metric, scores in results.items() if scores]
        avg_results['Overall'] = float(np.mean(evaluated)) if evaluated else 0.0

        return avg_results

    def _generate_response(self, character_profile: str, query: str, max_length: int = 512) -> str:
        """Greedily generate the model's in-character reply to ``query``.

        Args:
            character_profile: Character description placed inside ``<role>`` tags.
            query: User message placed inside ``<user>`` tags.
            max_length: Maximum number of new tokens to generate.

        Returns:
            The generated tokens decoded with special tokens removed. Generation
            stops early once the tokenizer's EOS token is produced.
        """
        input_text = f"<role>{character_profile}</role>\n<user>{query}</user>\n<assistant>"

        # NOTE(review): right-side truncation at 2048 tokens can cut off the
        # trailing "<assistant>" cue for very long profiles/queries.
        inputs = self.tokenizer(
            input_text,
            return_tensors='pt',
            max_length=2048,
            truncation=True,
            padding=False
        ).to(self.device)

        current_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        role_mask = self._create_role_mask(input_text, current_ids)
        user_mask = self._create_user_mask(input_text, current_ids)

        # Generated tokens are attended to but belong to neither the role nor
        # the user span.
        not_in_span = torch.zeros(1, 1, dtype=torch.bool, device=self.device)
        attend = torch.ones(1, 1, dtype=attention_mask.dtype, device=self.device)

        generated_ids: List[int] = []
        with torch.no_grad():
            for _ in range(max_length):
                # No KV cache: the whole sequence is re-encoded every step.
                outputs = self.model.model(
                    input_ids=current_ids,
                    attention_mask=attention_mask,
                    role_mask=role_mask,
                    user_mask=user_mask
                )
                # Only the last position is needed. Project just that slice so
                # full-sequence vocabulary logits are not materialised each
                # step. Assumes lm_head acts position-wise (e.g. a Linear
                # layer) — TODO confirm.
                last_logits = self.model.lm_head(outputs['hidden_states'][:, -1:, :])
                next_id = torch.argmax(last_logits[0, -1, :]).view(1, 1)

                token_id = next_id.item()
                generated_ids.append(token_id)
                if token_id == self.tokenizer.eos_token_id:
                    break

                # Grow every per-token tensor together so that all of them
                # always have the same length as current_ids.
                current_ids = torch.cat([current_ids, next_id], dim=1)
                attention_mask = torch.cat([attention_mask, attend], dim=1)
                role_mask = torch.cat([role_mask, not_in_span], dim=1)
                user_mask = torch.cat([user_mask, not_in_span], dim=1)

        return self.tokenizer.decode(generated_ids, skip_special_tokens=True)

    def _create_role_mask(self, text: str, input_ids: torch.Tensor) -> torch.Tensor:
        """Return a boolean mask (same shape as ``input_ids``) for role tokens.

        Heuristic: when ``text`` contains both ``<role>`` and ``</role>``, the
        first third of the sequence is marked True. Otherwise the mask is all
        False. The tag positions are not mapped to actual token indices.
        """
        mask = torch.zeros_like(input_ids, dtype=torch.bool)
        if '<role>' in text and '</role>' in text:
            mask[:, :input_ids.size(1) // 3] = True
        return mask

    def _create_user_mask(self, text: str, input_ids: torch.Tensor) -> torch.Tensor:
        """Return a boolean mask (same shape as ``input_ids``) for user tokens.

        Heuristic: when ``text`` contains both ``<user>`` and ``</user>``, the
        middle third of the sequence is marked True. Otherwise the mask is all
        False. The tag positions are not mapped to actual token indices.
        """
        mask = torch.zeros_like(input_ids, dtype=torch.bool)
        if '<user>' in text and '</user>' in text:
            seq_len = input_ids.size(1)
            mask[:, seq_len // 3:2 * seq_len // 3] = True
        return mask

    def _compute_metric_score(self, response: str, expected: str, metric_type: str, sample: Dict) -> float:
        """Dispatch to the scorer for ``metric_type``; unknown types score 0.0.

        ``'Memory'`` is matched exactly. The other families are matched by
        name prefix (e.g. ``'Knowledge_FA'`` goes to the knowledge scorer).
        """
        if metric_type == 'Memory':
            return self._evaluate_memory(response, expected, sample)

        prefix_scorers = (
            ('Knowledge', self._evaluate_knowledge),
            ('Persona', self._evaluate_persona),
            ('Emotion', self._evaluate_emotion),
            ('Morality', self._evaluate_morality),
            ('Believability', self._evaluate_believability),
        )
        for prefix, scorer in prefix_scorers:
            if metric_type.startswith(prefix):
                return scorer(response, expected, sample)
        return 0.0

    @staticmethod
    def _ratio_to_score(ratio: float) -> float:
        """Map a ratio in [0, 1] linearly onto the 1-5 scale, clamped to [1, 5]."""
        return min(5.0, max(1.0, 1 + 4 * ratio))

    @staticmethod
    def _match_ratio(phrases, response: str) -> float:
        """Return the fraction of ``phrases`` found in ``response``.

        Matching is a case-insensitive substring test. ``phrases`` must be
        non-empty.
        """
        text = response.lower()
        return sum(1 for phrase in phrases if phrase.lower() in text) / len(phrases)

    def _evaluate_memory(self, response: str, expected: str, sample: Dict) -> float:
        """Score recall of ``sample['memory_elements']`` (1-5; 3.0 if none given)."""
        memory_elements = sample.get('memory_elements', [])
        if not memory_elements:
            return 3.0  # Default neutral score
        return self._ratio_to_score(self._match_ratio(memory_elements, response))

    def _evaluate_knowledge(self, response: str, expected: str, sample: Dict) -> float:
        """Score mentions of ``sample['expected_facts']`` (1-5; 3.0 if none given)."""
        expected_facts = sample.get('expected_facts', [])
        if not expected_facts:
            return 3.0
        return self._ratio_to_score(self._match_ratio(expected_facts, response))

    def _evaluate_persona(self, response: str, expected: str, sample: Dict) -> float:
        """Score persona consistency as the mean per-trait keyword score.

        For each trait in ``sample['personality_traits']``, the keywords are
        read from ``sample[f'{trait}_keywords']``. Traits without keywords are
        skipped. Returns 3.0 when no trait could be scored.
        """
        personality_traits = sample.get('personality_traits', [])
        if not personality_traits:
            return 3.0

        trait_scores = []
        for trait in personality_traits:
            trait_keywords = sample.get(f'{trait}_keywords', [])
            if trait_keywords:
                trait_scores.append(
                    self._ratio_to_score(self._match_ratio(trait_keywords, response))
                )

        return float(np.mean(trait_scores)) if trait_scores else 3.0

    def _evaluate_emotion(self, response: str, expected: str, sample: Dict) -> float:
        """Score keywords for ``sample['expected_emotion']`` (default 'neutral').

        Returns a 1-5 score, or 3.0 for an emotion with no keyword list.
        """
        expected_kws = self._EMOTION_KEYWORDS.get(sample.get('expected_emotion', 'neutral'), [])
        if not expected_kws:
            return 3.0
        return self._ratio_to_score(self._match_ratio(expected_kws, response))

    def _evaluate_morality(self, response: str, expected: str, sample: Dict) -> float:
        """Score mentions of ``sample['stance_keywords']`` (1-5; 3.0 if none given)."""
        stance_keywords = sample.get('stance_keywords', [])
        if not stance_keywords:
            return 3.0
        return self._ratio_to_score(self._match_ratio(stance_keywords, response))

    def _evaluate_believability(self, response: str, expected: str, sample: Dict) -> float:
        """Score believability; each character-breaking element found lowers it.

        Returns 4.0 when ``sample['character_breaking_elements']`` is empty or
        missing.
        """
        character_breaking = sample.get('character_breaking_elements', [])
        if not character_breaking:
            return 4.0  # Default good score if no breaking elements defined
        return self._ratio_to_score(1 - self._match_ratio(character_breaking, response))


class DualPerspectiveEvaluator:
    """Evaluator for dual-perspective reasoning in role-user conflict scenarios.

    The model is prompted in thinking mode (the prompt ends with ``<think>``).
    A continuation is generated greedily, and the text is scored for:
    - self-awareness: references to the character's own constraints;
    - other-awareness: references to the user's needs.
    """

    # Phrases indicating self-awareness (character constraints).
    _SELF_KEYWORDS = (
        'as a', 'my character', 'my role', 'character constraint', 'my belief',
        'according to my', 'from my perspective', 'in my view', 'my values'
    )
    # Phrases indicating other-awareness (user needs).
    _OTHER_KEYWORDS = (
        'user', 'they want', 'their need', 'help them', 'understand their',
        'what they', 'their perspective', 'they might', 'user seems', 'they are asking'
    )
    # Each phrase must begin at a word boundary, so 'as a' does not match
    # inside "has a" or "was a". A phrase may still be the prefix of a longer
    # word ('user' matches "users", 'their need' matches "their needs").
    _SELF_PATTERNS = tuple(re.compile(r'\b' + re.escape(kw)) for kw in _SELF_KEYWORDS)
    _OTHER_PATTERNS = tuple(re.compile(r'\b' + re.escape(kw)) for kw in _OTHER_KEYWORDS)
    # Whole-word contrastive connectives. A plain substring test would also
    # fire on words such as "attribute", "button" or "meanwhile".
    _CONTRAST_PATTERN = re.compile(r'\b(?:however|but|while)\b')

    def __init__(self, model: KSKTForCausalLM, tokenizer, device: str = 'cuda'):
        """Move ``model`` to ``device`` and switch it to eval mode.

        Args:
            model: KSKT model to evaluate.
            tokenizer: Tokenizer used to encode prompts and decode generations.
            device: Device for the model and all input tensors.
        """
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()

    def evaluate_conflict_scenarios(self, data_path: str) -> Dict[str, float]:
        """Evaluate on a JSON file of role-user conflict scenarios.

        Args:
            data_path: Path to a UTF-8 JSON file holding a list of sample dicts
                with ``scenario_type``, ``character_profile`` and ``user_query``.

        Returns:
            The mean of ``(self + other) / 2`` for each known scenario type,
            plus overall self- and other-awareness means. All values are on a
            0-1 scale, and entries with no samples report 0.0.
        """
        with open(data_path, 'r', encoding='utf-8') as f:
            conflict_data = json.load(f)

        results: Dict[str, List[float]] = {
            'knowledge_boundary': [],
            'value_system_conflict': [],
            'emotional_support': [],
            'expertise_boundary': [],
            'overall_self_awareness': [],
            'overall_other_awareness': []
        }

        for sample in tqdm(conflict_data, desc="Evaluating Conflict Scenarios"):
            scenario_type = sample.get('scenario_type', 'unknown')
            character_profile = sample.get('character_profile', '')
            user_query = sample.get('user_query', '')

            response, thinking_process = self._generate_with_thinking(character_profile, user_query)

            self_awareness, other_awareness = self._evaluate_dual_perspective(
                thinking_process, response, sample
            )

            # Unknown scenario types still count towards the overall scores.
            if scenario_type in results:
                results[scenario_type].append((self_awareness + other_awareness) / 2)

            results['overall_self_awareness'].append(self_awareness)
            results['overall_other_awareness'].append(other_awareness)

        return {
            scenario: float(np.mean(scores)) if scores else 0.0
            for scenario, scores in results.items()
        }

    def _generate_with_thinking(self, character_profile: str, query: str,
                                max_new_tokens: int = 512) -> Tuple[str, str]:
        """Generate a thinking-mode continuation and split it at ``</think>``.

        Args:
            character_profile: Character description placed inside ``<role>`` tags.
            query: User message placed inside ``<user>`` tags.
            max_new_tokens: Maximum number of tokens to generate.

        Returns:
            ``(response, thinking_process)``. If the decoded continuation
            contains no ``</think>``, the whole continuation is returned as the
            response and the thinking process is empty.
        """
        # Ending the prompt with <think> triggers thinking mode.
        input_text = f"<role>{character_profile}</role>\n<user>{query}</user>\n<think>"

        inputs = self.tokenizer(
            input_text,
            return_tensors='pt',
            max_length=2048,
            truncation=True,
            padding=False
        ).to(self.device)

        generated_ids = self._greedy_decode(input_text, inputs, max_new_tokens)
        # NOTE(review): if the tokenizer registers </think> as a *special*
        # token, it is stripped here and everything lands in `response`.
        # Scoring concatenates both parts, so scores are unaffected.
        generated = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

        thinking_process, separator, response = generated.partition('</think>')
        if not separator:
            return generated, ""

        # The prompt already opened <think>; drop it if the model repeats it.
        thinking_process = thinking_process.strip()
        if thinking_process.startswith('<think>'):
            thinking_process = thinking_process[len('<think>'):]
        return response.strip(), thinking_process.strip()

    def _greedy_decode(self, input_text: str, inputs, max_new_tokens: int) -> List[int]:
        """Greedily extend the tokenized prompt by up to ``max_new_tokens`` tokens.

        The role and user masks are built from the prompt and extended with
        False for each generated token. The attention mask is extended with 1.
        This keeps every per-token tensor the same length as the growing
        ``input_ids``. Stops right after the tokenizer's EOS token is produced.

        Returns:
            The generated token ids (the prompt is not included).
        """
        current_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        role_mask = self._create_role_mask(input_text, current_ids)
        user_mask = self._create_user_mask(input_text, current_ids)

        not_in_span = torch.zeros(1, 1, dtype=torch.bool, device=self.device)
        attend = torch.ones(1, 1, dtype=attention_mask.dtype, device=self.device)

        generated_ids: List[int] = []
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # No KV cache: the whole sequence is re-encoded every step.
                outputs = self.model.model(
                    input_ids=current_ids,
                    attention_mask=attention_mask,
                    role_mask=role_mask,
                    user_mask=user_mask
                )
                # Project only the last position. Assumes lm_head acts
                # position-wise (e.g. a Linear layer) — TODO confirm.
                last_logits = self.model.lm_head(outputs['hidden_states'][:, -1:, :])
                next_id = torch.argmax(last_logits[0, -1, :]).view(1, 1)

                token_id = next_id.item()
                generated_ids.append(token_id)
                if token_id == self.tokenizer.eos_token_id:
                    break

                current_ids = torch.cat([current_ids, next_id], dim=1)
                attention_mask = torch.cat([attention_mask, attend], dim=1)
                role_mask = torch.cat([role_mask, not_in_span], dim=1)
                user_mask = torch.cat([user_mask, not_in_span], dim=1)

        return generated_ids

    def _create_role_mask(self, text: str, input_ids: torch.Tensor) -> torch.Tensor:
        """Return a boolean mask (same shape as ``input_ids``) for role tokens.

        Heuristic: the first third of the sequence is marked True. ``text`` is
        not inspected.
        """
        mask = torch.zeros_like(input_ids, dtype=torch.bool)
        mask[:, :input_ids.size(1) // 3] = True
        return mask

    def _create_user_mask(self, text: str, input_ids: torch.Tensor) -> torch.Tensor:
        """Return a boolean mask (same shape as ``input_ids``) for user tokens.

        Heuristic: the middle third of the sequence is marked True. ``text`` is
        not inspected.
        """
        seq_len = input_ids.size(1)
        mask = torch.zeros_like(input_ids, dtype=torch.bool)
        mask[:, seq_len // 3:2 * seq_len // 3] = True
        return mask

    def _evaluate_dual_perspective(self, thinking: str, response: str, sample: Dict) -> Tuple[float, float]:
        """Score self- and other-awareness of the generated text on a 0-1 scale.

        Each score is the fraction of its indicator phrases found in the
        lower-cased ``thinking + " " + response``. Both scores get a +0.1 bonus
        (capped at 1.0) when a whole-word contrastive connective ("however",
        "but", "while") appears, which is taken to indicate weighing multiple
        perspectives. ``sample`` is currently unused.
        """
        combined_text = thinking.lower() + " " + response.lower()

        self_count = sum(1 for pattern in self._SELF_PATTERNS if pattern.search(combined_text))
        other_count = sum(1 for pattern in self._OTHER_PATTERNS if pattern.search(combined_text))

        self_awareness = min(1.0, self_count / len(self._SELF_PATTERNS))
        other_awareness = min(1.0, other_count / len(self._OTHER_PATTERNS))

        if self._CONTRAST_PATTERN.search(combined_text):
            self_awareness = min(1.0, self_awareness + 0.1)
            other_awareness = min(1.0, other_awareness + 0.1)

        return self_awareness, other_awareness


def main():
    """Command-line entry point.

    Loads a KSKT checkpoint and runs the requested evaluations: CharacterBench
    and/or dual-perspective conflict scenarios. Writes the results to
    ``--output_file`` as JSON.
    """
    parser = argparse.ArgumentParser(description="Evaluate KSKT model")
    parser.add_argument('--model_path', type=str, required=True, help='Path to trained KSKT model checkpoint')
    parser.add_argument('--character_bench_data', type=str, help='Path to CharacterBench evaluation data')
    parser.add_argument('--conflict_data', type=str, help='Path to conflict scenarios data')
    parser.add_argument('--output_file', type=str, default='evaluation_results.json', help='Output file for results')
    parser.add_argument('--max_samples', type=int, help='Maximum number of samples to evaluate')
    parser.add_argument('--tokenizer', type=str, default='Qwen3-4B-Thinking',
                        help='Tokenizer name or path (should match the one used in training)')

    args = parser.parse_args()

    if not args.character_bench_data and not args.conflict_data:
        print("Warning: neither --character_bench_data nor --conflict_data was given; "
              "no evaluation will be run.")

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load model
    print("Loading model...")
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    checkpoint = torch.load(args.model_path, map_location=device)

    # Accept both training checkpoints ({'model_state_dict': ...}) and bare
    # state dicts.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    # NOTE(review): the config is built with defaults; it must match the
    # configuration the checkpoint was trained with.
    config = KSKTConfig()
    model = KSKTForCausalLM(config)
    model.load_state_dict(state_dict)

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)

    print("Model loaded successfully!")

    all_results = {}

    # Evaluate on CharacterBench
    if args.character_bench_data:
        print("Evaluating on CharacterBench...")
        char_evaluator = CharacterBenchEvaluator(model, tokenizer, device)
        char_results = char_evaluator.evaluate_character_bench(
            args.character_bench_data,
            args.max_samples
        )
        all_results['character_bench'] = char_results

        print("CharacterBench Results:")
        for metric, score in char_results.items():
            print(f"  {metric}: {score:.3f}")

    # Evaluate on conflict scenarios
    if args.conflict_data:
        print("\nEvaluating dual-perspective reasoning...")
        dual_evaluator = DualPerspectiveEvaluator(model, tokenizer, device)
        dual_results = dual_evaluator.evaluate_conflict_scenarios(args.conflict_data)
        all_results['dual_perspective'] = dual_results

        print("Dual-Perspective Results:")
        for scenario, score in dual_results.items():
            print(f"  {scenario}: {score:.3f}")

    # Save results
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, ensure_ascii=False)

    print(f"\nResults saved to {args.output_file}")

    # Print summary
    if 'character_bench' in all_results:
        overall_cb = all_results['character_bench'].get('Overall', 0)
        print(f"\nCharacterBench Overall Score: {overall_cb:.3f}")

    if 'dual_perspective' in all_results:
        self_awareness = all_results['dual_perspective'].get('overall_self_awareness', 0)
        other_awareness = all_results['dual_perspective'].get('overall_other_awareness', 0)
        print(f"Self-Awareness Score: {self_awareness:.3f}")
        print(f"Other-Awareness Score: {other_awareness:.3f}")
        # Absolute gap between the two scores: 0.0 means perfectly balanced.
        print(f"Balance Score: {abs(self_awareness - other_awareness):.3f}")


if __name__ == "__main__":
    # Run the command-line evaluation only when executed as a script, not on import.
    main()
