import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple
from dataclasses import dataclass
from enum import Enum


class TrainingPhase(Enum):
    """
    Stages of the KSKT three-phase training curriculum.

    The integer values are ordered on purpose: ``KSKTLoss.forward`` enables
    every auxiliary loss whose minimum phase value is <= the current one, so
    later phases keep all earlier objectives active.
    """

    # Phase 1: language modelling + character self-consistency.
    PHASE1_SELF_UNDERSTANDING = 1
    # Phase 2: additionally trains the user-understanding objective.
    PHASE2_OTHER_UNDERSTANDING = 2
    # Phase 3: additionally adds the model-reported balance loss.
    PHASE3_MUTUAL_INTEGRATION = 3


@dataclass
class TrainingConfig:
    """
    Hyper-parameters for KSKT training.

    The ``lambda_*`` weights scale the auxiliary terms added by ``KSKTLoss``;
    the ``lr_phase*`` and ``phase*_epochs`` values give each training phase its
    own peak learning rate and epoch count (read by ``KSKTTrainer``).
    """

    # Weights of the auxiliary loss terms.
    lambda_consistency: float = 0.1
    lambda_understanding: float = 0.2
    lambda_balance: float = 0.01

    # Peak learning rate for each training phase.
    lr_phase1: float = 2e-5
    lr_phase2: float = 1e-5
    lr_phase3: float = 5e-6

    # Optimisation settings.
    warmup_steps: int = 500       # linear LR warmup length, in optimizer steps
    weight_decay: float = 0.01    # AdamW weight decay
    gradient_clip: float = 1.0    # max total gradient norm
    batch_size: int = 128         # not read by KSKTTrainer
    sequence_length: int = 2048   # not read by KSKTTrainer

    # Number of epochs spent in each phase.
    phase1_epochs: int = 2
    phase2_epochs: int = 1
    phase3_epochs: int = 1


class ConsistencyLoss(nn.Module):
    """
    Character Consistency Loss (L_consistency).

    Projects every token hidden state and a pooled character reference (the
    mean of ``R_proc`` over dim 1) through one shared linear layer, then
    penalises low cosine similarity between them:

        L = 1 - mean_{valid positions} cos(proj(h_t), proj(mean(R_proc)))
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        # Shared projection applied to both token states and the character reference.
        self.char_proj = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        R_proc: torch.Tensor,
        character_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute consistency loss.

        Args:
            hidden_states: (batch, seq_len, hidden) token representations.
            R_proc: character representations averaged over dim 1; presumably
                (batch, num_ref, hidden) — TODO confirm against the model.
            character_mask: optional mask broadcastable to (batch, seq_len);
                only positions with nonzero/True entries enter the average.
                May be bool/int/float and on any device.

        Returns:
            Scalar loss. Without a mask it lies in [0, 2] (0 = perfectly
            aligned). With an all-zero mask it is 1.
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Project token states into the character feature space, then normalise
        # so the dot product below is a cosine similarity.
        char_features = F.normalize(self.char_proj(hidden_states), dim=-1)

        # One reference vector per sequence; keepdim keeps it broadcastable
        # against every token position.
        char_ref = self.char_proj(R_proc.mean(dim=1, keepdim=True))
        char_ref = F.normalize(char_ref, dim=-1)

        similarity = (char_features * char_ref).sum(dim=-1)  # (batch, seq_len)

        if character_mask is not None:
            # Callers may hand the mask over straight from a CPU batch; move it
            # to the similarity's device. Its dtype is kept so the count of
            # valid positions stays exact for bool/int masks.
            character_mask = character_mask.to(similarity.device)
            similarity = similarity * character_mask
            num_valid = character_mask.sum() + 1e-8  # epsilon avoids 0/0 for empty masks
        else:
            num_valid = batch_size * seq_len

        return 1.0 - similarity.sum() / num_valid


class UnderstandingLoss(nn.Module):
    """
    User Understanding Loss (L_understanding).

    Pools the token hidden states with attention weights derived from the mean
    user representation, classifies the pooled vector into intent classes, and
    returns either the supervised cross-entropy (when labels are given) or the
    mean prediction entropy (self-supervised fallback).
    """

    def __init__(self, hidden_size: int, num_intent_classes: int = 8):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_intent_classes = num_intent_classes

        # Two-layer MLP intent classifier.
        self.intent_classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, num_intent_classes)
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        U_proc: torch.Tensor,
        intent_labels: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute understanding loss.

        Args:
            hidden_states: (batch, seq_len, hidden) token representations.
            U_proc: user representations averaged over dim 1; presumably
                (batch, num_user, hidden) — TODO confirm against the model.
            intent_labels: optional (batch,) class indices. They are moved to
                the logits' device, so a CPU batch tensor is accepted.

        Returns:
            Scalar loss: cross-entropy if ``intent_labels`` is given, otherwise
            the mean entropy of the predicted intent distribution (in
            [0, log(num_intent_classes)]).
        """
        U_mean = U_proc.mean(dim=1)  # (batch, hidden)

        # Dot-product score of every token against the pooled user vector.
        # NOTE(review): scores are not scaled by 1/sqrt(hidden) and padding
        # positions are not masked; confirm this is intended.
        attn_scores = torch.matmul(
            hidden_states, U_mean.unsqueeze(-1)
        ).squeeze(-1)  # (batch, seq_len)
        attn_weights = F.softmax(attn_scores, dim=-1)

        # Attention-weighted sum of the token states.
        pooled = torch.matmul(attn_weights.unsqueeze(1), hidden_states).squeeze(1)  # (batch, hidden)

        intent_logits = self.intent_classifier(pooled)  # (batch, num_intent_classes)

        if intent_labels is not None:
            # Supervised loss with ground-truth labels.
            return F.cross_entropy(intent_logits, intent_labels.to(intent_logits.device))

        # Self-supervised proxy: encourage confident predictions by minimising
        # entropy. log_softmax is used instead of log(softmax + eps) for
        # numerical stability.
        log_probs = F.log_softmax(intent_logits, dim=-1)
        entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
        return entropy.mean()


class KSKTLoss(nn.Module):
    """
    Complete KSKT Loss Function.

    Combines the causal language-modelling loss with phase-gated auxiliary
    terms:

        total = L_CLM
              + lambda_consistency   * L_consistency    (phase >= 1)
              + lambda_understanding * L_understanding  (phase >= 2)
              + lambda_balance       * balance_loss     (phase >= 3)

    The active phase starts at phase 1 and is changed with ``set_phase``.
    """

    def __init__(self, config: TrainingConfig, hidden_size: int):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size

        # Auxiliary loss modules (each owns trainable layers).
        self.consistency_loss = ConsistencyLoss(hidden_size)
        self.understanding_loss = UnderstandingLoss(hidden_size)

        self.current_phase = TrainingPhase.PHASE1_SELF_UNDERSTANDING

    def set_phase(self, phase: TrainingPhase):
        """Select which auxiliary loss terms are active (see class docstring)."""
        self.current_phase = phase

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        hidden_states: torch.Tensor,
        R_proc: torch.Tensor,
        U_proc: torch.Tensor,
        balance_loss: torch.Tensor,
        intent_labels: Optional[torch.Tensor] = None,
        character_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute total loss based on current training phase.

        Args:
            logits: (..., seq_len, vocab) next-token logits.
            labels: (..., seq_len) token ids; -100 marks ignored positions.
            hidden_states, R_proc, character_mask: forwarded to ConsistencyLoss.
            U_proc, intent_labels: forwarded to UnderstandingLoss.
            balance_loss: tensor or plain number, used from phase 3 on.

        Returns:
            ``(total_loss, components)`` where ``components`` maps
            'l_clm', each active auxiliary term, and 'total' to Python numbers.
        """
        components = {}

        # Causal LM: the logit at position t is scored against the token at t+1.
        vocab_size = logits.size(-1)
        clm = F.cross_entropy(
            logits[..., :-1, :].contiguous().view(-1, vocab_size),
            labels[..., 1:].contiguous().view(-1),
            ignore_index=-100
        )
        components['l_clm'] = clm.item()
        running_total = clm

        # (minimum phase, log key, weight, thunk producing the term), evaluated
        # in order so later phases keep every earlier term.
        gated_terms = (
            (TrainingPhase.PHASE1_SELF_UNDERSTANDING, 'l_consistency',
             self.config.lambda_consistency,
             lambda: self.consistency_loss(hidden_states, R_proc, character_mask)),
            (TrainingPhase.PHASE2_OTHER_UNDERSTANDING, 'l_understanding',
             self.config.lambda_understanding,
             lambda: self.understanding_loss(hidden_states, U_proc, intent_labels)),
            (TrainingPhase.PHASE3_MUTUAL_INTEGRATION, 'l_balance',
             self.config.lambda_balance,
             lambda: balance_loss),
        )

        for min_phase, key, weight, compute_term in gated_terms:
            if self.current_phase.value < min_phase.value:
                continue
            term = compute_term()
            # balance_loss may arrive as a plain number rather than a tensor.
            components[key] = term.item() if isinstance(term, torch.Tensor) else term
            running_total = running_total + weight * term

        components['total'] = running_total.item()

        return running_total, components


class BudgetSupervisor(nn.Module):
    """
    Thinking Budget Supervision.

    Derives a pseudo ground-truth "conflict score" per example from fusion
    weights and expert-routing probabilities, discretises it into
    ``num_budgets`` bins, and trains ``budget_logits`` to predict that bin with
    cross-entropy.
    """

    def __init__(
        self,
        lambda_fusion: float = 0.6,
        lambda_entropy: float = 0.4,
        num_budgets: int = 5
    ):
        """
        Args:
            lambda_fusion: weight of the fusion-divergence component.
            lambda_entropy: weight of the normalised routing-entropy component.
            num_budgets: number of budget classes (width of ``budget_logits``).

        Raises:
            ValueError: if ``num_budgets`` < 1.
        """
        super().__init__()
        if num_budgets < 1:
            raise ValueError(f"num_budgets must be >= 1, got {num_budgets}")
        self.lambda_fusion = lambda_fusion
        self.lambda_entropy = lambda_entropy
        self.num_budgets = num_budgets

        # Interior bin edges k / num_budgets for k = 1 .. num_budgets - 1, so the
        # discretised labels always lie in [0, num_budgets). For the default of
        # 5 budgets this is exactly [0.2, 0.4, 0.6, 0.8] (T ∈ {2, 4, 6, 8, 10}).
        # Scores below the first / above the last edge fall into the first /
        # last bin.
        self.register_buffer(
            'boundaries',
            torch.tensor([k / num_budgets for k in range(1, num_budgets)])
        )

    def compute_conflict_score(
        self,
        alpha: torch.Tensor,
        beta: torch.Tensor,
        routing_probs: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute conflict score for budget supervision.

        Args:
            alpha, beta: fusion weights; the difference is averaged over dims
                1 and 2, so presumably (batch, seq_len, ...) — TODO confirm.
            routing_probs: probability distribution over experts in the last
                dim; assumed to reduce to (batch,) after summing dim -1.

        Returns:
            lambda_fusion * mean|alpha - beta|
            + lambda_entropy * entropy(routing_probs) / log(num_experts).
        """
        # Fusion weight divergence.
        fusion_divergence = (alpha - beta).abs().mean(dim=[1, 2])

        # Expert routing entropy (epsilon guards log(0)).
        entropy = -(routing_probs * torch.log(routing_probs + 1e-8)).sum(dim=-1)

        num_experts = routing_probs.shape[-1]
        if num_experts > 1:
            # Normalise by the maximum possible entropy, log(num_experts).
            max_entropy = torch.log(
                torch.tensor(num_experts, dtype=torch.float32, device=entropy.device)
            )
            entropy_normalized = entropy / max_entropy
        else:
            # A single expert has no routing uncertainty; log(1) = 0 would
            # otherwise turn the score into NaN.
            entropy_normalized = torch.zeros_like(entropy)

        return (
            self.lambda_fusion * fusion_divergence +
            self.lambda_entropy * entropy_normalized
        )

    def discretize_budget(self, conflict_scores: torch.Tensor) -> torch.Tensor:
        """
        Map conflict scores to integer budget labels in [0, num_budgets).

        Uses ``torch.bucketize`` with ``right=False``: a score equal to an edge
        goes to the lower bin.
        """
        return torch.bucketize(conflict_scores, self.boundaries)

    def forward(
        self,
        budget_logits: torch.Tensor,
        alpha: torch.Tensor,
        beta: torch.Tensor,
        routing_probs: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute budget supervision loss.

        Args:
            budget_logits: (batch, num_budgets) unnormalised budget predictions.
            alpha, beta, routing_probs: see ``compute_conflict_score``.

        Returns:
            Scalar cross-entropy between ``budget_logits`` and the discretised
            conflict-score labels.
        """
        conflict_scores = self.compute_conflict_score(alpha, beta, routing_probs)
        budget_labels = self.discretize_budget(conflict_scores)
        return F.cross_entropy(budget_logits, budget_labels)


class KSKTTrainer:
    """
    KSKT Trainer implementing three-phase training protocol.

    Each phase (see ``TrainingPhase``) switches the active loss terms in
    ``KSKTLoss`` and trains with a fresh AdamW optimizer whose learning rate
    follows a linear-warmup + cosine-decay schedule.
    """

    def __init__(
        self,
        model: nn.Module,
        config: TrainingConfig,
        device: torch.device = torch.device('cuda')
    ):
        """
        Args:
            model: network called as ``model(input_ids, R_proc, U_proc,
                attention_mask=..., return_metrics=True)`` and returning
                ``(logits, metrics)``; must expose ``config.hidden_size``.
            config: training hyper-parameters.
            device: device the model, the loss modules and every batch use.
        """
        self.config = config
        self.device = device

        # train_step moves every batch to ``device``; the model and the loss
        # modules (which own trainable layers/buffers) must live there as well,
        # otherwise the forward pass fails with a device mismatch.
        self.model = model.to(device)

        hidden_size = model.config.hidden_size
        self.loss_fn = KSKTLoss(config, hidden_size).to(device)

        # Budget supervisor (not used by train_step yet).
        self.budget_supervisor = BudgetSupervisor().to(device)

        # Created per phase by _setup_optimizer().
        self.optimizer: Optional[torch.optim.Optimizer] = None
        self.scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None

        # Training state
        self.current_phase = TrainingPhase.PHASE1_SELF_UNDERSTANDING
        self.global_step = 0

    def _trainable_parameters(self) -> list:
        """Return the model's parameters followed by the loss modules' parameters."""
        # KSKTLoss owns trainable layers (ConsistencyLoss.char_proj and
        # UnderstandingLoss.intent_classifier); they must be optimised too.
        # NOTE(review): a trainable char_proj can lower L_consistency by
        # collapsing its output toward a constant direction; freeze it if that
        # is observed.
        return list(self.model.parameters()) + list(self.loss_fn.parameters())

    def _optional_to_device(self, batch: Dict[str, torch.Tensor], key: str) -> Optional[torch.Tensor]:
        """Return ``batch[key]`` moved to ``self.device``, or None if absent."""
        value = batch.get(key)
        return None if value is None else value.to(self.device)

    def _setup_optimizer(self, lr: float, num_training_steps: int):
        """
        Create a fresh AdamW optimizer and LR scheduler for the current phase.

        Args:
            lr: peak learning rate, reached at the end of warmup.
            num_training_steps: total optimizer steps planned for the phase.
        """
        import math

        self.optimizer = torch.optim.AdamW(
            self._trainable_parameters(),
            lr=lr,
            weight_decay=self.config.weight_decay
        )

        warmup_steps = self.config.warmup_steps
        # At least 1 so num_training_steps <= warmup_steps cannot divide by zero.
        decay_steps = max(1, num_training_steps - warmup_steps)

        def lr_lambda(step: int) -> float:
            """Linear warmup from 0, then cosine decay to 0 over decay_steps."""
            if step < warmup_steps:
                return step / warmup_steps
            # Clamp so stepping past the planned length keeps the LR at 0
            # instead of letting the cosine climb back up.
            progress = min(1.0, (step - warmup_steps) / decay_steps)
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

    def train_step(
        self,
        batch: Dict[str, torch.Tensor]
    ) -> Dict[str, float]:
        """
        Execute single training step.

        Args:
            batch: tensors 'input_ids', 'labels', 'R_proc', 'U_proc', and
                optionally 'attention_mask', 'intent_labels', 'character_mask'.
                All of them are moved to ``self.device``.

        Returns:
            KSKTLoss components, plus 'lr' and every int/float model metric.

        Raises:
            RuntimeError: if no optimizer exists yet (use ``train_phase`` or
                call ``_setup_optimizer`` first).
        """
        if self.optimizer is None or self.scheduler is None:
            raise RuntimeError(
                "KSKTTrainer.train_step() called before an optimizer was created; "
                "use train_phase() or call _setup_optimizer() first."
            )

        self.model.train()
        self.loss_fn.train()

        # Move batch to device
        input_ids = batch['input_ids'].to(self.device)
        labels = batch['labels'].to(self.device)
        R_proc = batch['R_proc'].to(self.device)
        U_proc = batch['U_proc'].to(self.device)
        attention_mask = self._optional_to_device(batch, 'attention_mask')
        intent_labels = self._optional_to_device(batch, 'intent_labels')
        character_mask = self._optional_to_device(batch, 'character_mask')

        # Forward pass
        logits, metrics = self.model(
            input_ids, R_proc, U_proc,
            attention_mask=attention_mask,
            return_metrics=True
        )

        # NOTE(review): placeholder. These are zeros, not the model's hidden
        # states, and carry no gradient, so L_consistency / L_understanding
        # give the model itself no training signal (only their own
        # projection/classifier layers receive gradients). Replace with the
        # model's final hidden states once the model exposes them.
        hidden_states = self.model.layers[-1].post_attention_layernorm.weight.new_zeros(
            logits.shape[0], logits.shape[1], self.model.config.hidden_size
        )

        # Average every metric whose key contains 'balance_loss' (0.0 if none).
        balance_terms = [v for k, v in metrics.items() if 'balance_loss' in k]
        balance_loss = sum(balance_terms) / max(1, len(balance_terms))

        # Compute total loss
        total_loss, loss_dict = self.loss_fn(
            logits=logits,
            labels=labels,
            hidden_states=hidden_states,
            R_proc=R_proc,
            U_proc=U_proc,
            balance_loss=balance_loss,
            intent_labels=intent_labels,
            character_mask=character_mask
        )

        # Backward pass
        self.optimizer.zero_grad()
        total_loss.backward()

        # Clip the same parameter set the optimizer updates.
        torch.nn.utils.clip_grad_norm_(
            self._trainable_parameters(),
            self.config.gradient_clip
        )

        # Optimizer step
        self.optimizer.step()
        self.scheduler.step()

        self.global_step += 1

        # Add metrics to loss dict
        loss_dict['lr'] = self.scheduler.get_last_lr()[0]
        loss_dict.update({k: v for k, v in metrics.items() if isinstance(v, (int, float))})

        return loss_dict

    def train_phase(
        self,
        phase: TrainingPhase,
        dataloader: torch.utils.data.DataLoader,
        num_epochs: int
    ):
        """
        Train for one phase.

        Activates ``phase`` in the loss, builds a fresh optimizer/scheduler
        with the phase's learning rate, and runs ``num_epochs`` passes over
        ``dataloader`` (one optimizer step per batch).
        """
        self.current_phase = phase
        self.loss_fn.set_phase(phase)

        # Get learning rate for this phase
        lr_map = {
            TrainingPhase.PHASE1_SELF_UNDERSTANDING: self.config.lr_phase1,
            TrainingPhase.PHASE2_OTHER_UNDERSTANDING: self.config.lr_phase2,
            TrainingPhase.PHASE3_MUTUAL_INTEGRATION: self.config.lr_phase3,
        }
        lr = lr_map[phase]

        # Setup optimizer
        num_training_steps = len(dataloader) * num_epochs
        self._setup_optimizer(lr, num_training_steps)

        print(f"\n{'='*60}")
        print(f"Starting {phase.name}")
        print(f"Learning rate: {lr}, Epochs: {num_epochs}")
        print(f"{'='*60}\n")

        for epoch in range(num_epochs):
            epoch_losses = []

            for batch_idx, batch in enumerate(dataloader):
                loss_dict = self.train_step(batch)
                epoch_losses.append(loss_dict['total'])

                if batch_idx % 100 == 0:
                    recent = epoch_losses[-100:]
                    avg_loss = sum(recent) / len(recent)
                    print(f"  Epoch {epoch+1}, Step {batch_idx}: Loss = {avg_loss:.4f}")

            # An empty dataloader would otherwise divide by zero here.
            if not epoch_losses:
                print(f"\nEpoch {epoch+1} completed. No batches were produced.\n")
                continue

            avg_epoch_loss = sum(epoch_losses) / len(epoch_losses)
            print(f"\nEpoch {epoch+1} completed. Average loss: {avg_epoch_loss:.4f}\n")

    def train(self, dataloader: torch.utils.data.DataLoader):
        """
        Complete training following three-phase protocol.

        Runs phases 1 -> 2 -> 3 in order on the same dataloader, each for the
        epoch count configured in ``TrainingConfig``.

        Args:
            dataloader: Training data loader
        """
        # Phase 1: Self-Understanding Foundation (config.phase1_epochs, default 2)
        self.train_phase(
            TrainingPhase.PHASE1_SELF_UNDERSTANDING,
            dataloader,
            self.config.phase1_epochs
        )

        # Phase 2: Other-Understanding Addition (config.phase2_epochs, default 1)
        self.train_phase(
            TrainingPhase.PHASE2_OTHER_UNDERSTANDING,
            dataloader,
            self.config.phase2_epochs
        )

        # Phase 3: Mutual Integration (config.phase3_epochs, default 1)
        self.train_phase(
            TrainingPhase.PHASE3_MUTUAL_INTEGRATION,
            dataloader,
            self.config.phase3_epochs
        )

        print("\n" + "="*60)
        print("Training completed!")
        print("="*60)
# ==========================================================================
# Smoke test: build a small model and run one training step on random data.
# ==========================================================================

if __name__ == "__main__":
    from kskt_core import KSKTModel, KSKTConfig

    # Setup: a small configuration for a quick CPU run.
    config = KSKTConfig(
        hidden_size=256,
        num_attention_heads=8,
        num_kv_heads=2,
        head_dim=32,
        intermediate_size=512,
        num_layers=4,
    )

    train_config = TrainingConfig()

    model = KSKTModel(config)
    trainer = KSKTTrainer(model, train_config, device=torch.device('cpu'))

    # train_step() needs an optimizer and scheduler; train_phase() normally
    # creates them, so set them up explicitly for this single step. (Warmup
    # starts from 0, so this first step runs at learning rate 0.)
    trainer._setup_optimizer(train_config.lr_phase1, num_training_steps=1)

    # Create dummy batch
    batch = {
        'input_ids': torch.randint(0, 1000, (4, 64)),
        'labels': torch.randint(0, 1000, (4, 64)),
        'R_proc': torch.randn(4, 16, config.hidden_size),
        'U_proc': torch.randn(4, 32, config.hidden_size),
    }

    # Single training step
    loss_dict = trainer.train_step(batch)
    print(f"Loss components: {loss_dict}")