#!/usr/bin/env python3
"""
KSKT Training Script
Implements the three-phase training strategy described in the paper.
"""

import os
import json
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from torch.optim import AdamW
import wandb
from tqdm import tqdm
import argparse
from typing import Dict, List, Optional
import numpy as np

# Import KSKT model (assuming it's in kskt_model.py)
from kskt_model import KSKTForCausalLM, KSKTConfig


class RolePlayingDataset(Dataset):
    """Dataset of role-playing conversations with role/user token masks.

    The JSON file at ``data_path`` is read as a list of dict items. Each item
    may carry ``type``, ``role_description`` and ``conversation`` (a list of
    turns, each with ``speaker`` and ``content``). Items are rendered as
    ``<role>…</role>``, ``<user>…</user>`` and ``<assistant>…</assistant>``
    segments joined by newlines, then tokenized to a fixed ``max_length``.

    The full corpus is kept in ``_all_data`` and ``data`` holds the subset for
    the current ``phase``, so the phase can be switched repeatedly without
    losing items.
    """

    # Item ``type`` that each restricted phase trains on. Phases not listed
    # here (e.g. "mutual_understanding") use every item.
    _PHASE_ITEM_TYPES = {
        "self_understanding": "character_profile",
        "other_understanding": "instruction_following",
    }

    # Default ``ignore_index`` of torch's cross-entropy loss.
    IGNORE_INDEX = -100

    def __init__(self, data_path: str, tokenizer, max_length: int = 2048, phase: str = "self_understanding"):
        """Load ``data_path`` and select the items for ``phase``.

        Args:
            data_path: Path to a UTF-8 JSON file containing a list of items.
            tokenizer: Callable tokenizer used in ``__getitem__``.
            max_length: Length every example is truncated/padded to.
            phase: Training phase used to filter the items.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.phase = phase

        # Keep the unfiltered corpus: filtering must always start from it,
        # otherwise switching phase would filter an already-filtered subset.
        with open(data_path, 'r', encoding='utf-8') as f:
            self._all_data = json.load(f)

        self.data = self._filter_data_by_phase()

    def _filter_data_by_phase(self):
        """Return the items of the full corpus that belong to ``self.phase``.

        The result is a new list; ``self.data`` is not modified.
        """
        wanted_type = self._PHASE_ITEM_TYPES.get(self.phase)
        if wanted_type is None:
            # Mutual-understanding (and any unknown phase) uses all data.
            return list(self._all_data)
        return [item for item in self._all_data if item.get('type') == wanted_type]

    def __len__(self):
        """Return the number of items selected for the current phase."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize item ``idx`` and build its labels and segment masks.

        Returns:
            Dict with ``input_ids``, ``attention_mask``, ``labels``,
            ``role_mask`` and ``user_mask``, each a 1-D tensor of length
            ``max_length``.

        Raises:
            KeyError: If a conversation turn lacks ``speaker`` or ``content``.
        """
        item = self.data[idx]

        role_description = item.get('role_description', '')
        conversation = item.get('conversation', [])

        # Render the example as tagged segments.
        text_parts = []
        if role_description:
            text_parts.append(f"<role>{role_description}</role>")
        for turn in conversation:
            tag = 'user' if turn['speaker'] == 'user' else 'assistant'
            text_parts.append(f"<{tag}>{turn['content']}</{tag}>")
        full_text = "\n".join(text_parts)

        encode_kwargs = {
            'max_length': self.max_length,
            'truncation': True,
            'padding': 'max_length',
            'return_tensors': 'pt',
        }
        # Character offsets give exact token/segment alignment, but only fast
        # tokenizers provide them.
        if getattr(self.tokenizer, 'is_fast', False):
            encode_kwargs['return_offsets_mapping'] = True
        tokenized = self.tokenizer(full_text, **encode_kwargs)

        input_ids = tokenized['input_ids'].squeeze(0)
        attention_mask = tokenized['attention_mask'].squeeze(0)
        offsets = None
        if 'offset_mapping' in tokenized:
            offsets = tokenized['offset_mapping'].squeeze(0)

        # Do not train on padding. The attention mask is used rather than the
        # pad id because the pad token may be the same as EOS.
        # NOTE(review): assumes the model's loss skips label -100 (torch
        # cross-entropy default) — confirm against KSKTForCausalLM.
        labels = input_ids.clone()
        labels[attention_mask == 0] = self.IGNORE_INDEX

        role_mask = self._create_role_mask(full_text, input_ids, offsets)
        user_mask = self._create_user_mask(full_text, input_ids, offsets)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'role_mask': role_mask,
            'user_mask': user_mask
        }

    def _create_role_mask(self, text: str, input_ids: torch.Tensor, offsets: Optional[torch.Tensor] = None):
        """Return a bool mask over ``input_ids`` for the first ``<role>`` segment.

        Args:
            text: The rendered example text.
            input_ids: 1-D token ids of ``text``.
            offsets: Optional ``(seq_len, 2)`` character offsets per token.
        """
        spans = self._find_tagged_spans(text, '<role>', '</role>')[:1]
        return self._build_mask(text, input_ids, spans, offsets)

    def _create_user_mask(self, text: str, input_ids: torch.Tensor, offsets: Optional[torch.Tensor] = None):
        """Return a bool mask over ``input_ids`` for every ``<user>`` segment.

        Args:
            text: The rendered example text.
            input_ids: 1-D token ids of ``text``.
            offsets: Optional ``(seq_len, 2)`` character offsets per token.
        """
        spans = self._find_tagged_spans(text, '<user>', '</user>')
        return self._build_mask(text, input_ids, spans, offsets)

    def _build_mask(self, text, input_ids, spans, offsets):
        """Mask the tokens of ``spans`` using offsets when available."""
        if not spans:
            return torch.zeros_like(input_ids, dtype=torch.bool)
        if offsets is not None:
            return self._mask_from_offsets(offsets, spans)
        return self._mask_from_token_search(text, input_ids, spans)

    @staticmethod
    def _find_tagged_spans(text: str, open_tag: str, close_tag: str):
        """Return ``(start, end)`` character spans of tagged segments, tags included."""
        spans = []
        pos = 0
        while True:
            start = text.find(open_tag, pos)
            if start == -1:
                break
            end = text.find(close_tag, start)
            if end == -1:
                break
            end += len(close_tag)
            spans.append((start, end))
            pos = end
        return spans

    @staticmethod
    def _mask_from_offsets(offsets: torch.Tensor, spans):
        """Mask tokens whose character range overlaps any of ``spans``.

        Tokens with an empty character range (padding and tokenizer-inserted
        special tokens report ``(0, 0)``) are never selected.
        """
        tok_start = offsets[:, 0]
        tok_end = offsets[:, 1]
        has_text = tok_end > tok_start
        mask = torch.zeros(offsets.shape[0], dtype=torch.bool)
        for span_start, span_end in spans:
            mask |= has_text & (tok_start < span_end) & (tok_end > span_start)
        return mask

    def _mask_from_token_search(self, text: str, input_ids: torch.Tensor, spans):
        """Fallback alignment: locate each span's own tokenization in ``input_ids``.

        Spans are searched left to right. A span whose token ids are not found
        (for instance because it was truncated or tokenized differently in
        context) is left unmasked.
        """
        mask = torch.zeros_like(input_ids, dtype=torch.bool)
        ids = input_ids.tolist()
        cursor = 0
        for span_start, span_end in spans:
            span_ids = list(self.tokenizer(text[span_start:span_end], add_special_tokens=False)['input_ids'])
            width = len(span_ids)
            if width == 0:
                continue
            hit = next(
                (i for i in range(cursor, len(ids) - width + 1) if ids[i:i + width] == span_ids),
                -1,
            )
            if hit == -1:
                continue
            mask[hit:hit + width] = True
            cursor = hit + width
        return mask


class KSKTTrainer:
    """Trainer that runs the KSKT model through three training phases.

    Each phase selects a subset of the training data, sets the model's
    ``lambda_consistency`` / ``lambda_understanding`` loss weights, and scales
    the peak learning rate. One optimizer and one linear warmup/decay
    scheduler are shared by all phases.
    """

    # (phase name, banner title, epochs, peak-LR factor, loss weights)
    PHASE_SCHEDULE = (
        (
            "self_understanding",
            "PHASE 1: SELF-UNDERSTANDING PRE-TRAINING",
            2,
            1.0,
            {'lambda_consistency': 0.2, 'lambda_understanding': 0.0},
        ),
        (
            "other_understanding",
            "PHASE 2: OTHER-UNDERSTANDING FINE-TUNING",
            1,
            0.5,
            {'lambda_consistency': 0.1, 'lambda_understanding': 0.3},
        ),
        (
            "mutual_understanding",
            "PHASE 3: MUTUAL UNDERSTANDING ALIGNMENT",
            1,
            0.25,
            {'lambda_consistency': 0.1, 'lambda_understanding': 0.2},
        ),
    )

    def __init__(
        self,
        model: "KSKTForCausalLM",
        tokenizer,
        train_dataset: Dataset,
        val_dataset: Dataset,
        config: Dict,
        device: str = 'cuda'
    ):
        """Move the model to ``device`` and build optimizer and scheduler.

        Args:
            model: Model to train; moved to ``device``.
            tokenizer: Stored on the trainer; not used by the trainer itself.
            train_dataset: Training data. If it exposes ``phase`` and
                ``_filter_data_by_phase`` it is re-filtered per phase.
            val_dataset: Validation data, evaluated after every epoch.
            config: Needs ``learning_rate``, ``weight_decay``, ``num_epochs``,
                ``batch_size``, ``warmup_steps``, ``max_grad_norm``,
                ``logging_steps`` and ``output_dir``; ``num_workers`` is
                optional (default 4).
            device: Device the model and batches are moved to.
        """
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.config = config
        self.device = device

        self.optimizer = AdamW(
            self.model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )

        # The decay horizon must cover every optimizer step of all phases;
        # a horizon that is too short drives the LR to zero before training ends.
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=config['warmup_steps'],
            num_training_steps=self._estimate_total_steps()
        )

        # Training state
        self.global_step = 0
        self.best_loss = float('inf')

    def _estimate_total_steps(self) -> int:
        """Return the number of optimizer steps the whole run will take.

        For a phase-aware dataset the steps of every entry in
        ``PHASE_SCHEDULE`` are summed using that phase's item count. Otherwise
        the dataset length and ``config['num_epochs']`` are used. The result
        is at least 1.
        """
        batch_size = self.config['batch_size']
        dataset = self.train_dataset

        def steps_per_epoch(num_items: int) -> int:
            # The loaders keep the last partial batch, hence ceiling division.
            return (num_items + batch_size - 1) // batch_size

        phase_aware = hasattr(dataset, 'phase') and hasattr(dataset, '_filter_data_by_phase')
        if not phase_aware:
            return max(1, steps_per_epoch(len(dataset)) * self.config['num_epochs'])

        original_phase = dataset.phase
        total = 0
        try:
            for name, _title, epochs, _factor, _weights in self.PHASE_SCHEDULE:
                dataset.phase = name
                total += steps_per_epoch(len(dataset._filter_data_by_phase())) * epochs
        finally:
            # Only ``phase`` was touched; put it back for the caller.
            dataset.phase = original_phase
        return max(1, total)

    def _set_peak_lr(self, peak_lr: float):
        """Make ``peak_lr`` the scheduler's base learning rate.

        The scheduler recomputes every group's ``lr`` from its ``base_lrs`` on
        each step, so writing only ``param_group['lr']`` would be undone by
        the next ``scheduler.step()``. The current warmup/decay multiplier is
        preserved.
        """
        base_lrs = getattr(self.scheduler, 'base_lrs', None)
        for index, group in enumerate(self.optimizer.param_groups):
            if base_lrs is not None and base_lrs[index] > 0:
                group['lr'] = group['lr'] * peak_lr / base_lrs[index]
            else:
                group['lr'] = peak_lr
        if base_lrs is not None:
            self.scheduler.base_lrs = [peak_lr] * len(base_lrs)

    def train_phase(self, phase: str, num_epochs: int, phase_config: Dict):
        """Train one phase for ``num_epochs`` epochs.

        Args:
            phase: Phase name. For the three known phases the model's loss
                weights are set; other names leave them unchanged.
            num_epochs: Number of passes over the current training data.
            phase_config: Optional overrides for ``lambda_consistency`` and
                ``lambda_understanding``.

        After each epoch the model is validated, and a ``best_{phase}``
        checkpoint is written whenever the validation loss improves within
        this phase. A phase with no training data is skipped.
        """
        print(f"Starting {phase} training phase...")

        # Update model loss weights for this phase
        default_weights = {
            name: weights for name, _title, _epochs, _factor, weights in self.PHASE_SCHEDULE
        }.get(phase)
        if default_weights is not None:
            self.model.lambda_consistency = phase_config.get(
                'lambda_consistency', default_weights['lambda_consistency'])
            self.model.lambda_understanding = phase_config.get(
                'lambda_understanding', default_weights['lambda_understanding'])

        if len(self.train_dataset) == 0:
            # Nothing to iterate; avoids an empty loader and a 0/0 average.
            print(f"No training data for {phase}; skipping phase.")
            return

        # Loss weights differ between phases, so validation losses are only
        # comparable within one phase: track the best per phase.
        self.best_loss = float('inf')

        train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config['batch_size'],
            shuffle=True,
            num_workers=self.config.get('num_workers', 4)
        )

        for epoch in range(num_epochs):
            self.model.train()
            epoch_loss = 0.0
            epoch_steps = 0

            progress_bar = tqdm(train_loader, desc=f"{phase} Epoch {epoch+1}/{num_epochs}")

            for batch in progress_bar:
                batch = {k: v.to(self.device) for k, v in batch.items()}

                # Forward pass
                outputs = self.model(**batch)
                loss = outputs['loss']

                # Backward pass with gradient clipping
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['max_grad_norm'])

                # Optimizer step
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

                # Update tracking (read the loss value once per step)
                step_loss = loss.item()
                epoch_loss += step_loss
                epoch_steps += 1
                self.global_step += 1

                progress_bar.set_postfix({
                    'loss': step_loss,
                    'lr': self.scheduler.get_last_lr()[0],
                    'step': self.global_step
                })

                # Log to wandb
                if self.global_step % self.config['logging_steps'] == 0:
                    self._log_training_step(outputs, phase)

            val_loss = self._validate()
            train_loss = epoch_loss / epoch_steps

            print(f"{phase} Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

            # Save checkpoint
            if val_loss < self.best_loss:
                self.best_loss = val_loss
                self._save_checkpoint(f"best_{phase}")

            # Log epoch results
            wandb.log({
                f'{phase}/epoch_train_loss': train_loss,
                f'{phase}/epoch_val_loss': val_loss,
                'epoch': epoch,
                'phase': phase
            })

    def _validate(self) -> float:
        """Return the mean batch loss over the validation set.

        The model is switched to eval mode and gradients are disabled.
        Returns ``inf`` when the validation set yields no batches, so an
        empty set can never register as an improvement.
        """
        self.model.eval()

        val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.config['batch_size'],
            shuffle=False,
            num_workers=self.config.get('num_workers', 4)
        )

        total_loss = 0.0
        total_steps = 0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                batch = {k: v.to(self.device) for k, v in batch.items()}

                outputs = self.model(**batch)
                total_loss += outputs['loss'].item()
                total_steps += 1

        if total_steps == 0:
            print("Validation set is empty; reporting infinite validation loss.")
            return float('inf')
        return total_loss / total_steps

    def _log_training_step(self, outputs: Dict, phase: str):
        """Log the loss, auxiliary losses and routing statistics of one step to wandb.

        Args:
            outputs: Model outputs; reads ``loss`` and ``auxiliary_losses``
                (with ``load_balance_loss``, ``fusion_weights`` and
                ``routing_probs``).
            phase: Prefix for the metric names.
        """
        aux_losses = outputs['auxiliary_losses']

        log_dict = {
            f'{phase}/step_loss': outputs['loss'].item(),
            f'{phase}/load_balance_loss': aux_losses['load_balance_loss'].item(),
            'learning_rate': self.scheduler.get_last_lr()[0],
            'global_step': self.global_step
        }

        # Log fusion weights statistics if available
        if aux_losses['fusion_weights']:
            alpha, beta = aux_losses['fusion_weights'][-1]  # Last entry
            log_dict.update({
                f'{phase}/alpha_mean': alpha.mean().item(),
                f'{phase}/beta_mean': beta.mean().item(),
                f'{phase}/fusion_balance': (alpha.mean() - beta.mean()).abs().item()
            })

        # Log expert routing statistics if available
        if aux_losses['routing_probs']:
            routing_probs = aux_losses['routing_probs'][-1]  # Last entry
            expert_usage = routing_probs.mean(dim=0)
            for i, usage in enumerate(expert_usage):
                log_dict[f'{phase}/expert_{i}_usage'] = usage.item()

        wandb.log(log_dict)

    def _save_checkpoint(self, checkpoint_name: str):
        """Write model, optimizer, scheduler and trainer state to ``output_dir``.

        The file is ``<output_dir>/<checkpoint_name>.pt``; the directory is
        created if needed.
        """
        checkpoint_dir = self.config['output_dir']
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint_path = os.path.join(checkpoint_dir, f"{checkpoint_name}.pt")

        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'global_step': self.global_step,
            'best_loss': self.best_loss,
            'config': self.config
        }, checkpoint_path)

        print(f"Checkpoint saved: {checkpoint_path}")

    def three_phase_training(self):
        """Run every phase of ``PHASE_SCHEDULE`` in order.

        For each phase the training dataset is re-filtered, the peak learning
        rate is scaled by the phase's factor, and ``train_phase`` is called
        with the phase's loss weights.
        """
        banner = "=" * 50

        for phase, title, epochs, lr_factor, loss_weights in self.PHASE_SCHEDULE:
            print(banner)
            print(title)
            print(banner)

            # Update dataset for this phase
            self.train_dataset.phase = phase
            self.train_dataset.data = self.train_dataset._filter_data_by_phase()

            # Later phases train with a reduced peak learning rate.
            self._set_peak_lr(self.config['learning_rate'] * lr_factor)

            self.train_phase(phase, epochs, dict(loss_weights))

        print(banner)
        print("THREE-PHASE TRAINING COMPLETED!")
        print(banner)


def main():
    """Parse CLI arguments, build tokenizer, model and datasets, and train.

    Starts a wandb run, registers the role-playing tag tokens with the
    tokenizer, sizes the model vocabulary to the tokenizer, and runs the
    three-phase training of ``KSKTTrainer``.
    """
    parser = argparse.ArgumentParser(description="Train KSKT model")
    parser.add_argument('--train_data', type=str, required=True, help='Path to training data JSON file')
    parser.add_argument('--val_data', type=str, required=True, help='Path to validation data JSON file')
    parser.add_argument('--output_dir', type=str, default='./checkpoints', help='Output directory for checkpoints')
    parser.add_argument('--base_model', type=str, default='Qwen3-4B-Thinking', help='Base model name')
    parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
    parser.add_argument('--max_length', type=int, default=2048, help='Maximum sequence length')
    parser.add_argument('--weight_decay', type=float, default=0.01, help='Weight decay')
    parser.add_argument('--warmup_steps', type=int, default=500, help='Warmup steps')
    parser.add_argument('--logging_steps', type=int, default=100, help='Logging frequency')
    parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
    parser.add_argument('--wandb_project', type=str, default='kskt-training', help='Wandb project name')

    args = parser.parse_args()

    # Initialize wandb
    wandb.init(
        project=args.wandb_project,
        config=vars(args)
    )

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Add special tokens for role-playing
    special_tokens = ['<role>', '</role>', '<user>', '</user>', '<assistant>', '</assistant>']
    tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})

    # Initialize model config and model. The vocabulary size is taken after
    # the special tokens were added, so it already accounts for them.
    config = KSKTConfig()
    config.vocab_size = len(tokenizer)

    model = KSKTForCausalLM(config)

    # Grow the embedding table only if it is still smaller than the tokenizer.
    # Appending rows unconditionally would leave it larger than the vocabulary.
    # NOTE(review): presumably the output projection is sized from
    # config.vocab_size as well — confirm in kskt_model.
    embed_tokens = model.model.embed_tokens
    missing_rows = len(tokenizer) - embed_tokens.weight.shape[0]
    if missing_rows > 0:
        old_weight = embed_tokens.weight.detach()
        new_rows = torch.randn(
            missing_rows,
            old_weight.shape[1],
            dtype=old_weight.dtype,
            device=old_weight.device,
        ) * 0.02
        embed_tokens.weight = torch.nn.Parameter(torch.cat([old_weight, new_rows], dim=0))
        if hasattr(embed_tokens, 'num_embeddings'):
            embed_tokens.num_embeddings = embed_tokens.weight.shape[0]

    # Initialize datasets
    train_dataset = RolePlayingDataset(
        args.train_data,
        tokenizer,
        max_length=args.max_length,
        phase="self_understanding"
    )

    val_dataset = RolePlayingDataset(
        args.val_data,
        tokenizer,
        max_length=args.max_length,
        phase="mutual_understanding"  # Use all data for validation
    )

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    # Initialize trainer
    training_config = {
        'learning_rate': args.learning_rate,
        'batch_size': args.batch_size,
        'weight_decay': args.weight_decay,
        'warmup_steps': args.warmup_steps,
        'max_grad_norm': args.max_grad_norm,
        'logging_steps': args.logging_steps,
        'output_dir': args.output_dir,
        'num_epochs': 4  # Total across all phases
    }

    trainer = KSKTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        config=training_config,
        device=device
    )

    # Execute three-phase training
    trainer.three_phase_training()

    print("Training completed successfully!")


# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
