#!/usr/bin/env python3
"""
KSKT Inference Script
Generate responses using a trained KSKT model.
"""

import torch
import argparse
from transformers import AutoTokenizer
from kskt_model import KSKTForCausalLM, KSKTConfig
import json


class KSKTGenerator:
    """Generate role-play responses from a trained KSKT checkpoint.

    Wraps a ``KSKTForCausalLM`` and its tokenizer, samples new tokens with
    top-p filtering, and returns the decoded response together with a simple
    keyword/model-output based "dual perspective" analysis.
    """

    # Prompt markup tags; registered with the tokenizer as additional
    # special tokens in __init__.
    _SPECIAL_TOKENS = ('<role>', '</role>', '<user>', '</user>',
                       '<assistant>', '</assistant>', '<think>', '</think>')

    def __init__(self, model_path: str, device: str = 'cuda',
                 tokenizer_name: str = 'Qwen3-4B-Thinking'):
        """Load the model checkpoint and the tokenizer.

        Args:
            model_path: Path to a checkpoint whose ``'model_state_dict'``
                entry holds the model weights.
            device: Torch device string the model and inputs are placed on.
            tokenizer_name: Name or path passed to
                ``AutoTokenizer.from_pretrained``.
        """
        self.device = device

        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=device)

        # NOTE(review): the model is built from a default KSKTConfig rather than
        # any config stored in the checkpoint; confirm the defaults match training.
        config = KSKTConfig()
        self.model = KSKTForCausalLM(config)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(device)
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.tokenizer.add_special_tokens(
            {'additional_special_tokens': list(self._SPECIAL_TOKENS)})

    def generate_response(
        self,
        character_profile: str,
        user_query: str,
        max_length: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        use_thinking: bool = True
    ) -> dict:
        """Generate a response, optionally preceded by a thinking section.

        Args:
            character_profile: Text placed inside the ``<role>`` tags.
            user_query: Text placed inside the ``<user>`` tags.
            max_length: Maximum number of new tokens to generate.
            temperature: Softmax temperature; ``<= 0`` selects greedy decoding.
            top_p: Nucleus-sampling cumulative-probability threshold.
            use_thinking: If True the prompt ends with ``<think>`` and the
                output is split into thinking and response; otherwise the
                prompt ends with ``<assistant>``.

        Returns:
            dict with keys ``'response'``, ``'thinking_process'`` (None when
            empty), ``'character_profile'``, ``'user_query'`` and
            ``'analysis'`` (from ``_analyze_dual_perspective``).
        """
        opener = '<think>' if use_thinking else '<assistant>'
        input_text = f"<role>{character_profile}</role>\n<user>{user_query}</user>\n{opener}"

        # NOTE(review): with the tokenizer's default right-side truncation, a
        # prompt longer than 2048 tokens loses its trailing opener tag.
        inputs = self.tokenizer(
            input_text,
            return_tensors='pt',
            max_length=2048,
            truncation=True,
            padding=False
        ).to(self.device)
        input_ids = inputs['input_ids']

        role_mask, user_mask = self._create_masks(input_text, input_ids)

        generated_ids = []
        with torch.no_grad():
            outputs = self.model.model(
                input_ids=input_ids,
                role_mask=role_mask,
                user_mask=user_mask
            )
            # Auxiliary outputs come from the prompt-only forward pass.
            fusion_weights = outputs.get('fusion_weights', [])
            routing_probs = outputs.get('routing_probs', [])
            logits = self.model.lm_head(outputs['hidden_states'])

            current_ids = input_ids
            # NOTE(review): no KV cache is used; every step re-runs the model
            # over the whole sequence.
            for step in range(max_length):
                next_id = self._sample_next_token(logits[0, -1, :], temperature, top_p)
                generated_ids.append(next_id.item())

                # Stop on EOS; on the final step skip a forward pass whose
                # logits would never be used.
                if generated_ids[-1] == self.tokenizer.eos_token_id or step == max_length - 1:
                    break

                # Grow ids and both masks together so their lengths stay equal;
                # generated tokens belong to neither the role nor the user span.
                new_col = torch.zeros(1, 1, dtype=torch.bool, device=self.device)
                current_ids = torch.cat([current_ids, next_id.view(1, 1)], dim=1)
                role_mask = torch.cat([role_mask, new_col], dim=1)
                user_mask = torch.cat([user_mask, new_col], dim=1)

                outputs = self.model.model(
                    input_ids=current_ids,
                    role_mask=role_mask,
                    user_mask=user_mask
                )
                logits = self.model.lm_head(outputs['hidden_states'])

        # Decode only the new tokens and keep special tokens so the
        # </think> / <assistant> markers survive for parsing.
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=False)
        if use_thinking:
            # The prompt already opened <think>; re-add it so the parser sees the full span.
            thinking, response = self._parse_thinking_response(opener + generated_text)
        else:
            thinking, response = "", generated_text
        thinking = self._strip_special_tokens(thinking)
        response = self._strip_special_tokens(response)

        analysis = self._analyze_dual_perspective(thinking, response, fusion_weights, routing_probs)

        return {
            'response': response.strip(),
            'thinking_process': thinking.strip() or None,
            'character_profile': character_profile,
            'user_query': user_query,
            'analysis': analysis
        }

    @staticmethod
    def _sample_next_token(next_logits: torch.Tensor, temperature: float,
                           top_p: float) -> torch.Tensor:
        """Choose the next token id from 1-D vocabulary logits.

        ``temperature <= 0`` returns the argmax. Otherwise the logits are
        temperature-scaled and sampled after top-p (nucleus) filtering, which
        always keeps the single most likely token. Returns a 1-element tensor.
        """
        if temperature <= 0:
            return torch.argmax(next_logits, dim=-1, keepdim=True)

        scaled = next_logits / temperature
        sorted_logits, sorted_indices = torch.sort(scaled, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        # Shift right by one so the token that crosses top_p is kept too.
        remove = cumulative_probs > top_p
        remove[1:] = remove[:-1].clone()
        remove[0] = False
        scaled[sorted_indices[remove]] = float('-inf')

        probs = torch.softmax(scaled, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def _strip_special_tokens(self, text: str) -> str:
        """Remove every tokenizer special-token string from ``text``."""
        # Longest first so a token that contains another is removed whole.
        for token in sorted(self.tokenizer.all_special_tokens, key=len, reverse=True):
            if token:
                text = text.replace(token, '')
        return text

    def _create_masks(self, input_text: str, input_ids: torch.Tensor):
        """Build boolean role/user masks with the same shape as ``input_ids``.

        Positional heuristic: the first third of positions is marked as role
        and the middle third as user; ``input_text`` is not consulted.
        NOTE(review): this ignores where the <role>/<user> tags actually are;
        confirm it matches how masks were built during training.
        """
        seq_len = input_ids.size(1)

        role_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        user_mask = torch.zeros_like(input_ids, dtype=torch.bool)

        role_mask[:, :seq_len//3] = True
        user_mask[:, seq_len//3:2*seq_len//3] = True

        return role_mask, user_mask

    def _parse_thinking_response(self, full_text: str):
        """Split ``<think>T</think>[...<assistant>]R`` into ``(T, R)``.

        If there is no ``<think>`` followed later by ``</think>``, returns
        ``("", full_text)``. Text after ``</think>`` is used as the response,
        starting after ``<assistant>`` when that tag is present.
        """
        open_tag, close_tag, answer_tag = '<think>', '</think>', '<assistant>'

        think_start = full_text.find(open_tag)
        if think_start == -1:
            return "", full_text
        body_start = think_start + len(open_tag)
        # Only accept a closing tag that comes after the opening one.
        think_end = full_text.find(close_tag, body_start)
        if think_end == -1:
            return "", full_text

        thinking = full_text[body_start:think_end]
        remaining = full_text[think_end + len(close_tag):]
        _, sep, after = remaining.partition(answer_tag)
        response = after if sep else remaining
        return thinking, response

    def _analyze_dual_perspective(self, thinking: str, response: str, fusion_weights: list, routing_probs: list):
        """Score self/other awareness and summarise fusion and routing outputs.

        - ``self_awareness_score`` / ``other_awareness_score``: fraction of a
          fixed phrase list found (case-insensitively) in thinking + response.
        - ``fusion_balance``: |mean(alpha) - mean(beta)| of the last entry in
          ``fusion_weights``, which is unpacked as an ``(alpha, beta)`` pair.
        - ``expert_routing``: mean routing probability per named expert from
          the last entry in ``routing_probs``.
        """
        analysis = {
            'self_awareness_score': 0.0,
            'other_awareness_score': 0.0,
            'fusion_balance': 0.0,
            'expert_routing': {},
            'reasoning_analysis': ""
        }

        combined_text = (thinking + " " + response).lower()

        self_indicators = ['as a', 'my character', 'my role', 'my belief', 'from my perspective']
        other_indicators = ['user', 'they want', 'help them', 'their need', 'they are asking']

        self_count = sum(1 for indicator in self_indicators if indicator in combined_text)
        other_count = sum(1 for indicator in other_indicators if indicator in combined_text)

        analysis['self_awareness_score'] = min(1.0, self_count / len(self_indicators))
        analysis['other_awareness_score'] = min(1.0, other_count / len(other_indicators))

        if fusion_weights:
            alpha, beta = fusion_weights[-1]  # last layer
            analysis['fusion_balance'] = abs(alpha.mean().item() - beta.mean().item())

        if routing_probs:
            last_routing = routing_probs[-1]  # last layer
            expert_names = ['Personality', 'Knowledge', 'Emotional', 'Capability']
            # Average over all leading dims so (experts,), (tokens, experts) and
            # (batch, seq, experts) all work. Assumes the expert axis is last
            # (TODO confirm against the model).
            routing_mean = last_routing.reshape(-1, last_routing.shape[-1]).mean(dim=0)
            for name, prob in zip(expert_names, routing_mean.tolist()):
                analysis['expert_routing'][name] = prob

        return analysis


def main():
    """Command-line entry point for single-query or batch generation.

    Batch mode (``--input_file``) reads a JSON list of objects with
    ``character_profile`` and ``user_query`` keys. Otherwise both
    ``--character_profile`` and ``--user_query`` are required. Results are
    written as JSON to ``--output_file`` when it is given.
    """
    parser = argparse.ArgumentParser(description="Generate responses with KSKT")
    parser.add_argument('--model_path', type=str, required=True, help='Path to trained model checkpoint')
    parser.add_argument('--character_profile', type=str, help='Character profile description')
    parser.add_argument('--user_query', type=str, help='User query')
    parser.add_argument('--input_file', type=str, help='Input file with multiple queries')
    parser.add_argument('--output_file', type=str, help='Output file for results')
    parser.add_argument('--max_length', type=int, default=512, help='Maximum generation length')
    parser.add_argument('--temperature', type=float, default=0.7, help='Generation temperature')
    parser.add_argument('--top_p', type=float, default=0.9, help='Top-p sampling parameter')
    parser.add_argument('--no_thinking', action='store_true', help='Disable thinking mode')
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Torch device (default: cuda if available, else cpu)')

    args = parser.parse_args()

    # Validate before the expensive model load; parser.error exits with status 2.
    if not args.input_file and not (args.character_profile and args.user_query):
        parser.error("Please provide --character_profile and --user_query for single query mode")

    generator = KSKTGenerator(args.model_path, device=args.device)

    gen_kwargs = {
        'max_length': args.max_length,
        'temperature': args.temperature,
        'top_p': args.top_p,
        'use_thinking': not args.no_thinking,
    }

    results = []

    if args.input_file:
        with open(args.input_file, 'r', encoding='utf-8') as f:
            queries = json.load(f)

        for query in queries:
            result = generator.generate_response(
                character_profile=query['character_profile'],
                user_query=query['user_query'],
                **gen_kwargs
            )
            results.append(result)
            print(f"Generated response for: {query['user_query'][:50]}...")

    else:
        result = generator.generate_response(
            character_profile=args.character_profile,
            user_query=args.user_query,
            **gen_kwargs
        )
        results.append(result)

        print("="*50)
        print("CHARACTER PROFILE:")
        print(result['character_profile'])
        print("\nUSER QUERY:")
        print(result['user_query'])
        if result['thinking_process']:
            print("\nTHINKING PROCESS:")
            print(result['thinking_process'])
        print("\nRESPONSE:")
        print(result['response'])
        print("\nANALYSIS:")
        analysis = result['analysis']
        print(f"Self-awareness: {analysis['self_awareness_score']:.3f}")
        print(f"Other-awareness: {analysis['other_awareness_score']:.3f}")
        print(f"Fusion balance: {analysis['fusion_balance']:.3f}")
        if analysis['expert_routing']:
            print("Expert routing:")
            for expert, prob in analysis['expert_routing'].items():
                print(f"  {expert}: {prob:.3f}")

    if args.output_file:
        with open(args.output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"\nResults saved to {args.output_file}")


if __name__ == '__main__':
    # Run the command-line interface only when executed as a script.
    main()

