import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
import logging
from tqdm import tqdm
import json
import os

from bridge_core import BRIDGE, BRIDGEConfig, MemoryTiers, count_parameters

@dataclass
class TrainingConfig:
    """Hyperparameter bundle read by the BRIDGE trainers in this module.

    Field order is part of the interface (positional construction), so new
    fields should be appended rather than inserted.
    """

    # --- Optimization -----------------------------------------------------
    # Passed to AdamW as ``lr``.
    learning_rate: float = 1e-5
    # Passed to AdamW as ``weight_decay``.
    weight_decay: float = 0.01
    # Not referenced by the trainer code visible in this module; presumably
    # consumed by whoever builds the DataLoader.
    batch_size: int = 32
    # Not referenced by the trainer code visible in this module; presumably
    # consumed by the outer epoch loop.
    num_epochs: int = 50
    # Used as ``T_max`` of the cosine learning-rate schedule.
    max_steps: int = 50000
    # NOTE(review): declared but not referenced by the trainer code visible
    # in this module -- confirm whether warmup is applied elsewhere.
    warmup_steps: int = 1000
    # Max gradient norm handed to ``clip_grad_norm_``.
    gradient_clip: float = 1.0

    # --- Loss weights -----------------------------------------------------
    # Forwarded to ``compute_total_loss`` as ``lambda_1``.
    lambda_cycle: float = 0.1
    # Forwarded to ``compute_total_loss`` as ``lambda_2``.
    lambda_persona: float = 0.5

    # --- Logging cadence (in optimizer steps) -------------------------------
    # A log line is emitted whenever the step counter is a multiple of this.
    log_interval: int = 100
    # Not referenced by the trainer code visible in this module.
    eval_interval: int = 1000
    # A checkpoint is written whenever the step counter is a multiple of this.
    save_interval: int = 5000

    # --- Paths --------------------------------------------------------------
    # Directory that receives checkpoints and the log file.
    output_dir: str = "./checkpoints"
    # Log file name, created inside ``output_dir``.
    log_file: str = "training.log"

class PersonaDialogueDataset(Dataset):
    """
    Dataset for persona-grounded dialogue training.

    The JSON file is expected to hold a list of records. Each record must
    provide ``persona`` and ``response``; ``history`` (a list of
    ``{'user': ..., 'response': ...}`` turns) and ``current_user`` are
    optional.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int = 2048,
        max_target_length: int = 512
    ):
        """
        Args:
            data_path: Path to the JSON file with the dialogue records.
            tokenizer: Callable invoked with ``max_length``, ``truncation``
                and ``return_tensors='pt'``; its result must expose
                ``input_ids`` and ``attention_mask`` (assumes the Hugging
                Face tokenizer call convention -- TODO confirm).
            max_length: Truncation limit applied to the prompt.
            max_target_length: Truncation limit applied to the response.
                Defaults to 512, the value that used to be hard-coded.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_target_length = max_target_length

        # JSON is UTF-8 by specification; relying on the locale default
        # would garble or reject non-ASCII personas on some platforms.
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

    def __len__(self):
        """Return the number of dialogue records."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Build and tokenize the prompt/response pair for record ``idx``.

        Returns:
            Dict with 1-D ``input_ids`` and ``attention_mask`` for the
            prompt, 1-D ``labels`` holding the response token ids, and the
            raw ``persona`` string.

        NOTE(review): ``labels`` are tokenized independently of the prompt,
        so they are not position-aligned with ``input_ids``, and no padding
        is requested, so items may differ in length -- confirm that the
        collate function / loss used downstream accounts for both.
        """
        item = self.data[idx]

        persona = item['persona']
        # ``or []`` also tolerates an explicit ``"history": null``.
        history = item.get('history') or []
        response = item['response']

        # Assemble the prompt in one join instead of repeated string +=.
        segments = [f"[Persona]\n{persona}\n\n[Dialogue]\n"]
        segments.extend(
            f"User: {turn['user']}\nAssistant: {turn['response']}\n"
            for turn in history
        )
        segments.append(f"User: {item.get('current_user', '')}\nAssistant:")
        prompt = "".join(segments)

        inputs = self.tokenizer(
            prompt,
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt'
        )

        targets = self.tokenizer(
            response,
            max_length=self.max_target_length,
            truncation=True,
            return_tensors='pt'
        )

        # The tokenizer returns a leading batch axis of size 1; drop it so
        # each item is a flat sequence.
        return {
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'labels': targets['input_ids'].squeeze(0),
            'persona': persona
        }


class BRIDGETrainer:
    """
    Trainer for BRIDGE with frozen LLM backbone.
    
    Implements parameter-efficient fine-tuning where only BRIDGE
    modules are updated while backbone remains frozen.
    """
    
    def __init__(
        self,
        model: BRIDGE,
        backbone: nn.Module,  # Frozen LLM
        train_config: TrainingConfig,
        bridge_config: BRIDGEConfig,
        device: torch.device
    ):
        self.model = model.to(device)
        self.backbone = backbone.to(device)
        self.train_config = train_config
        self.bridge_config = bridge_config
        self.device = device
        
        # Freeze backbone
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()
        
        # Setup optimizer
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=train_config.learning_rate,
            weight_decay=train_config.weight_decay
        )
        
        # Setup scheduler
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=train_config.max_steps,
            eta_min=1e-7
        )
        
        # Logging
        self.setup_logging()
        
        # Metrics
        self.global_step = 0
        self.best_eval_score = float('-inf')
    
    def setup_logging(self):
        """Setup logging configuration."""
        os.makedirs(self.train_config.output_dir, exist_ok=True)
        
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(
                    os.path.join(self.train_config.output_dir, self.train_config.log_file)
                ),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)
    
    def get_backbone_hidden_states(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Get frozen backbone hidden states.
        """
        with torch.no_grad():
            outputs = self.backbone(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True
            )
            # Use last layer hidden states
            H_t = outputs.hidden_states[-1]
        return H_t
    
    def encode_persona(
        self,
        persona_texts: List[str],
        tokenizer
    ) -> torch.Tensor:
        """
        Encode persona descriptions.
        """
        inputs = tokenizer(
            persona_texts,
            padding=True,
            truncation=True,
            return_tensors='pt'
        ).to(self.device)
        
        with torch.no_grad():
            outputs = self.backbone(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                output_hidden_states=True
            )
            # Mean pooling
            hidden = outputs.hidden_states[-1]
            mask = inputs['attention_mask'].unsqueeze(-1)
            persona_encoding = (hidden * mask).sum(1) / mask.sum(1)
        
        return persona_encoding
    
    def compute_lm_loss(
        self,
        H_hat_t: torch.Tensor,
        labels: torch.Tensor,
        lm_head: nn.Module
    ) -> torch.Tensor:
        """
        Compute language modeling loss through frozen LM head.
        """
        # Project conditioned hidden states through LM head
        logits = lm_head(H_hat_t)
        
        # Shift for next-token prediction
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        
        # Compute cross-entropy loss
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100
        )
        
        return loss
    
    def train_step(
        self,
        batch: Dict[str, torch.Tensor],
        m_current: MemoryTiers,
        m_anchor: MemoryTiers,
        tokenizer,
        lm_head: nn.Module
    ) -> Tuple[Dict[str, float], MemoryTiers]:
        """
        Execute one training step.
        """
        self.model.train()
        
        # Move batch to device
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        labels = batch['labels'].to(self.device)
        
        # Get frozen backbone hidden states
        H_t = self.get_backbone_hidden_states(input_ids, attention_mask)
        
        # Encode response for persona classifier
        response_encoding = H_t.mean(dim=1)  
        
        # BRIDGE forward pass
        outputs = self.model(
            H_t=H_t,
            m_prev=m_current,
            m_anchor=m_anchor,
            response_encoding=response_encoding
        )
        
        # Compute LM loss
        lm_loss = self.compute_lm_loss(outputs['H_hat_t'], labels, lm_head)
        
        # Compute total loss
        total_loss = self.model.compute_total_loss(
            lm_loss=lm_loss,
            bridge_losses=outputs['losses'],
            lambda_1=self.train_config.lambda_cycle,
            lambda_2=self.train_config.lambda_persona
        )
        
        # Backward pass
        self.optimizer.zero_grad()
        total_loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.train_config.gradient_clip
        )
        
        # Update
        self.optimizer.step()
        self.scheduler.step()
        
        # Collect metrics
        metrics = {
            'total_loss': total_loss.item(),
            'lm_loss': lm_loss.item(),
            'cycle_loss': outputs['losses']['L_cycle'].item(),
            'persona_loss': outputs['losses'].get('L_persona', torch.tensor(0.0)).item(),
            'lyapunov_energy': outputs['losses']['V_mt'].item(),
            'learning_rate': self.scheduler.get_last_lr()[0]
        }
        
        return metrics, outputs['m_updated']
    
    def train_epoch(
        self,
        train_loader: DataLoader,
        tokenizer,
        lm_head: nn.Module
    ) -> Dict[str, float]:
        """
        Train for one epoch.
        """
        epoch_metrics = {
            'total_loss': 0.0,
            'lm_loss': 0.0,
            'cycle_loss': 0.0,
            'persona_loss': 0.0,
            'lyapunov_energy': 0.0
        }
        num_batches = 0
        
        pbar = tqdm(train_loader, desc="Training")
        
        for batch in pbar:
            # Initialize memory from persona
            persona_encoding = self.encode_persona(
                batch['persona'], tokenizer
            )
            m_current, m_anchor = self.model.initialize_memory(persona_encoding)
            
            # Train step
            metrics, m_updated = self.train_step(
                batch, m_current, m_anchor, tokenizer, lm_head
            )
            
            # Accumulate metrics
            for k in epoch_metrics:
                if k in metrics:
                    epoch_metrics[k] += metrics[k]
            num_batches += 1
            
            # Update progress bar
            pbar.set_postfix({
                'loss': f"{metrics['total_loss']:.4f}",
                'V(m)': f"{metrics['lyapunov_energy']:.4f}"
            })
            
            # Logging
            self.global_step += 1
            if self.global_step % self.train_config.log_interval == 0:
                self.logger.info(
                    f"Step {self.global_step}: "
                    f"loss={metrics['total_loss']:.4f}, "
                    f"lm_loss={metrics['lm_loss']:.4f}, "
                    f"cycle_loss={metrics['cycle_loss']:.4f}, "
                    f"V(m)={metrics['lyapunov_energy']:.4f}"
                )
            
            # Save checkpoint
            if self.global_step % self.train_config.save_interval == 0:
                self.save_checkpoint()
        
        # Average metrics
        for k in epoch_metrics:
            epoch_metrics[k] /= max(num_batches, 1)
        
        return epoch_metrics
    
    def save_checkpoint(self, suffix: str = ""):
        """Save model checkpoint."""
        checkpoint_path = os.path.join(
            self.train_config.output_dir,
            f"bridge_step{self.global_step}{suffix}.pt"
        )
        
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'global_step': self.global_step,
            'config': self.bridge_config
        }, checkpoint_path)
        
        self.logger.info(f"Saved checkpoint to {checkpoint_path}")
    
    def load_checkpoint(self, checkpoint_path: str):
        """Load model checkpoint."""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.global_step = checkpoint['global_step']
        
        self.logger.info(f"Loaded checkpoint from {checkpoint_path}")


class MultiTurnTrainer(BRIDGETrainer):
    """
    Extended trainer that maintains memory across dialogue turns.
    """
    
    def train_conversation(
        self,
        conversation: List[Dict], 
        persona_encoding: torch.Tensor,
        tokenizer,
        lm_head: nn.Module
    ) -> Dict[str, float]:
        """
        Train on a full conversation with memory continuity.
        """
        # Initialize memory
        m_current, m_anchor = self.model.initialize_memory(persona_encoding)
        
        conv_metrics = {
            'total_loss': 0.0,
            'lm_loss': 0.0,
            'cycle_loss': 0.0,
            'persona_loss': 0.0,
            'lyapunov_energy': 0.0
        }
        
        for turn_idx, turn in enumerate(conversation):
            # Prepare turn batch
            batch = self.prepare_turn_batch(turn, tokenizer)
            
            # Train step with memory continuity
            metrics, m_current = self.train_step(
                batch, m_current, m_anchor, tokenizer, lm_head
            )
            
            # Accumulate metrics
            for k in conv_metrics:
                if k in metrics:
                    conv_metrics[k] += metrics[k]
        
        # Average over turns
        num_turns = len(conversation)
        for k in conv_metrics:
            conv_metrics[k] /= max(num_turns, 1)
        
        return conv_metrics
    
    def prepare_turn_batch(
        self,
        turn: Dict,
        tokenizer
    ) -> Dict[str, torch.Tensor]:
        """Prepare a single turn as a batch."""
        inputs = tokenizer(
            turn['input'],
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        
        labels = tokenizer(
            turn['response'],
            max_length=512,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        
        return {
            'input_ids': inputs['input_ids'],
            'attention_mask': inputs['attention_mask'],
            'labels': labels['input_ids'],
            'persona': [turn.get('persona', '')]
        }
