#!/usr/bin/env python3
"""
Unit tests for KSKT model components
"""

import unittest
import torch
import torch.nn.functional as F
import numpy as np
import sys
import os

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from kskt_model import (
    KSKTConfig, KSKTForCausalLM, KSKTModel, KSKTLayer,
    DualStreamAxialAttention, BipolarReasoningModule, 
    SelfAwarenessMoE, MutualUnderstandingPositionEmbedding
)


class TestKSKTComponents(unittest.TestCase):
    """Test individual KSKT components"""
    
    def setUp(self):
        """Set up test fixtures"""
        self.config = KSKTConfig()
        self.batch_size = 2
        self.seq_len = 128
        self.hidden_size = self.config.hidden_size
        
        # Create sample inputs
        self.sample_hidden_states = torch.randn(self.batch_size, self.seq_len, self.hidden_size)
        self.sample_role_context = torch.randn(self.batch_size, self.seq_len // 3, self.hidden_size)
        self.sample_user_context = torch.randn(self.batch_size, self.seq_len // 3, self.hidden_size)
        self.sample_input_ids = torch.randint(0, self.config.vocab_size, (self.batch_size, self.seq_len))
    
    def test_dual_stream_attention_forward(self):
        """Test dual-stream axial attention forward pass"""
        dsaa = DualStreamAxialAttention(self.config)
        
        output, fusion_weights = dsaa(
            self.sample_hidden_states,
            self.sample_role_context,
            self.sample_user_context
        )
        
        # Check output shape
        self.assertEqual(output.shape, self.sample_hidden_states.shape)
        
        # Check fusion weights
        alpha, beta = fusion_weights
        self.assertEqual(alpha.shape, self.sample_hidden_states.shape)
        self.assertEqual(beta.shape, self.sample_hidden_states.shape)
        
        # Check fusion weights sum to 1 (approximately)
        fusion_sum = alpha + beta
        self.assertTrue(torch.allclose(fusion_sum, torch.ones_like(fusion_sum), atol=1e-6))
        
        # Check fusion weights are non-negative
        self.assertTrue(torch.all(alpha >= 0))
        self.assertTrue(torch.all(beta >= 0))
        
        print("✓ DualStreamAxialAttention forward pass test passed")
    
    def test_mutual_understanding_pe(self):
        """Test mutual understanding position encoding"""
        mupe = MutualUnderstandingPositionEmbedding(self.config)
        
        position_embeds = mupe(
            self.sample_input_ids,
            self.sample_role_context,
            self.sample_user_context
        )
        
        # Check output shape
        expected_shape = (self.batch_size, self.seq_len, self.hidden_size)
        self.assertEqual(position_embeds.shape, expected_shape)
        
        # Check that position encoding changes with context
        position_embeds_no_context = mupe(self.sample_input_ids)
        self.assertFalse(torch.allclose(position_embeds, position_embeds_no_context))
        
        print("✓ MutualUnderstandingPositionEmbedding test passed")
    
    def test_bipolar_reasoning_module(self):
        """Test bipolar reasoning module"""
        brm = BipolarReasoningModule(self.config)
        
        output = brm(
            self.sample_hidden_states,
            self.sample_role_context,
            self.sample_user_context
        )
        
        # Check output shape
        self.assertEqual(output.shape, self.sample_hidden_states.shape)
        
        # Check that output is different from input (processing occurred)
        self.assertFalse(torch.allclose(output, self.sample_hidden_states))
        
        print("✓ BipolarReasoningModule test passed")
    
    def test_self_awareness_moe(self):
        """Test self-awareness mixture of experts"""
        samoe = SelfAwarenessMoE(self.config)
        
        output, load_balance_loss, routing_probs = samoe(
            self.sample_hidden_states,
            self.sample_role_context
        )
        
        # Check output shape
        self.assertEqual(output.shape, self.sample_hidden_states.shape)
        
        # Check load balance loss
        self.assertIsInstance(load_balance_loss, torch.Tensor)
        self.assertEqual(load_balance_loss.dim(), 0)  # Scalar
        
        # Check routing probabilities
        self.assertEqual(routing_probs.shape, (self.batch_size, 4))  # 4 experts
        
        # Check routing probabilities sum to 1
        routing_sum = torch.sum(routing_probs, dim=-1)
        self.assertTrue(torch.allclose(routing_sum, torch.ones_like(routing_sum), atol=1e-6))
        
        # Check routing probabilities are non-negative
        self.assertTrue(torch.all(routing_probs >= 0))
        
        print("✓ SelfAwarenessMoE test passed")
    
    def test_kskt_layer(self):
        """Test complete KSKT layer"""
        layer = KSKTLayer(self.config)
        
        output, fusion_weights, load_balance_loss, routing_probs = layer(
            self.sample_hidden_states,
            self.sample_role_context,
            self.sample_user_context
        )
        
        # Check output shape
        self.assertEqual(output.shape, self.sample_hidden_states.shape)
        
        # Check all auxiliary outputs
        self.assertIsNotNone(fusion_weights)
        self.assertIsInstance(load_balance_loss, torch.Tensor)
        self.assertIsNotNone(routing_probs)
        
        print("✓ KSKTLayer test passed")
    
    def test_full_model_forward(self):
        """Test full KSKT model forward pass"""
        model = KSKTForCausalLM(self.config)
        
        # Create sample masks
        role_mask = torch.zeros(self.batch_size, self.seq_len, dtype=torch.bool)
        role_mask[:, :self.seq_len//3] = True
        
        user_mask = torch.zeros(self.batch_size, self.seq_len, dtype=torch.bool)
        user_mask[:, self.seq_len//3:2*self.seq_len//3] = True
        
        # Forward pass
        outputs = model(
            input_ids=self.sample_input_ids,
            labels=self.sample_input_ids,
            role_mask=role_mask,
            user_mask=user_mask
        )
        
        # Check outputs
        self.assertIn('loss', outputs)
        self.assertIn('logits', outputs)
        self.assertIn('auxiliary_losses', outputs)
        
        # Check logits shape
        expected_logits_shape = (self.batch_size, self.seq_len, self.config.vocab_size)
        self.assertEqual(outputs['logits'].shape, expected_logits_shape)
        
        # Check loss is scalar
        self.assertEqual(outputs['loss'].dim(), 0)
        
        print("✓ Full KSKT model forward pass test passed")
    
    def test_gradient_flow(self):
        """Test gradient flow through model"""
        model = KSKTForCausalLM(self.config)
        
        # Create sample data
        role_mask = torch.zeros(self.batch_size, self.seq_len, dtype=torch.bool)
        role_mask[:, :self.seq_len//3] = True
        
        user_mask = torch.zeros(self.batch_size, self.seq_len, dtype=torch.bool)
        user_mask[:, self.seq_len//3:2*self.seq_len//3] = True
        
        # Forward and backward pass
        outputs = model(
            input_ids=self.sample_input_ids,
            labels=self.sample_input_ids,
            role_mask=role_mask,
            user_mask=user_mask
        )
        
        loss = outputs['loss']
        loss.backward()
        
        # Check that gradients exist and are finite
        has_gradients = False
        for name, param in model.named_parameters():
            if param.grad is not None:
                has_gradients = True
                self.assertTrue(torch.isfinite(param.grad).all(), f"Non-finite gradient in {name}")
        
        self.assertTrue(has_gradients, "No gradients found")
        
        print("✓ Gradient flow test passed")


class TestKSKTTraining(unittest.TestCase):
    """Loss-related checks on a full ``KSKTForCausalLM`` forward pass.

    The global torch RNG is seeded in ``setUp`` so that the model weights and
    the random batch are the same on every run.
    """

    def setUp(self):
        """Create a fresh model and one small random training batch."""
        # Seed first: both the model construction and the batch below draw
        # from the global generator.
        torch.manual_seed(0)

        self.config = KSKTConfig()
        self.model = KSKTForCausalLM(self.config)

        # Sample training data
        self.batch_size = 2
        self.seq_len = 64
        token_shape = (self.batch_size, self.seq_len)
        self.sample_batch = {
            'input_ids': torch.randint(0, self.config.vocab_size, token_shape),
            'labels': torch.randint(0, self.config.vocab_size, token_shape),
            'role_mask': torch.zeros(self.batch_size, self.seq_len, dtype=torch.bool),
            'user_mask': torch.zeros(self.batch_size, self.seq_len, dtype=torch.bool),
        }

        # First third of the positions is role, second third is user
        third = self.seq_len // 3
        two_thirds = 2 * self.seq_len // 3
        self.sample_batch['role_mask'][:, :third] = True
        self.sample_batch['user_mask'][:, third:two_thirds] = True

    def test_loss_computation(self):
        """Test loss computation with different lambda values"""
        # One (lambda_consistency, lambda_understanding) pair per phase
        phase_configs = [
            {'lambda_consistency': 0.2, 'lambda_understanding': 0.0},  # Phase 1
            {'lambda_consistency': 0.1, 'lambda_understanding': 0.3},  # Phase 2
            {'lambda_consistency': 0.1, 'lambda_understanding': 0.2},  # Phase 3
        ]

        for phase_no, phase in enumerate(phase_configs, start=1):
            # subTest reports which phase failed and lets the remaining
            # phases run instead of aborting on the first failure.
            with self.subTest(phase=phase_no, **phase):
                self.model.lambda_consistency = phase['lambda_consistency']
                self.model.lambda_understanding = phase['lambda_understanding']

                outputs = self.model(**self.sample_batch)
                loss = outputs['loss']

                # Loss must be finite and strictly positive
                self.assertTrue(torch.isfinite(loss), f"Non-finite loss: {loss}")
                self.assertGreater(loss.item(), 0)

                print(f"✓ Phase {phase_no} loss computation test passed")

    def test_auxiliary_loss_components(self):
        """Test individual auxiliary loss components"""
        outputs = self.model(**self.sample_batch)

        aux_losses = outputs['auxiliary_losses']

        # Load balance loss is present and finite
        self.assertIn('load_balance_loss', aux_losses)
        self.assertTrue(torch.isfinite(aux_losses['load_balance_loss']))

        # Fusion weights are collected in a list
        self.assertIn('fusion_weights', aux_losses)
        self.assertIsInstance(aux_losses['fusion_weights'], list)

        # Routing probabilities are collected in a list
        self.assertIn('routing_probs', aux_losses)
        self.assertIsInstance(aux_losses['routing_probs'], list)

        print("✓ Auxiliary loss components test passed")


def run_all_tests():
    """Run both KSKT test cases and report the outcome.

    Prints a banner, runs ``TestKSKTComponents`` followed by
    ``TestKSKTTraining`` with a verbose text runner, then prints a one-line
    summary.

    Returns:
        True when every test passed, False otherwise.
    """
    print("Running KSKT Model Tests")
    print("=" * 50)

    # Collect the test cases into one suite, in a fixed order
    test_loader = unittest.TestLoader()
    full_suite = unittest.TestSuite()
    for test_case in (TestKSKTComponents, TestKSKTTraining):
        full_suite.addTests(test_loader.loadTestsFromTestCase(test_case))

    outcome = unittest.TextTestRunner(verbosity=2).run(full_suite)

    # Summarise: failure branch first, success falls through
    if not outcome.wasSuccessful():
        print(f"\n✗ {len(outcome.failures)} failures, {len(outcome.errors)} errors")
        return False

    print("\n✓ All tests passed!")
    return True


if __name__ == "__main__":
    # Turn the overall result into the process exit status (0 = all passed).
    # sys.exit is used instead of the bare exit() helper, which is injected by
    # the site module for interactive sessions and is missing under `python -S`.
    success = run_all_tests()
    sys.exit(0 if success else 1)
