import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from tqdm import tqdm
from typing import Dict, List, Tuple, Optional
import os
import json
from datetime import datetime

from model.vae_models.vae import VAE, vae_loss


class VAETrainer:
    """Trainer class for VAE models.

    Runs the optimisation loop for either an ELBO or an IWAE objective.
    It also handles validation, TensorBoard logging, early stopping on
    validation loss, and checkpointing.
    """

    def __init__(self, model: "VAE", config: Dict, device: torch.device):
        """Set up the optimizer, logging directories and training state.

        Args:
            model: The VAE to train; it is moved to ``device``.
            config: Nested configuration.
                ``config['training']`` must provide ``epochs``,
                ``learning_rate``, ``weight_decay``, ``beta``,
                ``early_stopping_patience`` and ``save_frequency``.
                It may also provide ``objective`` (default ``'elbo'``) and
                ``iwae_k`` (default 5).
                ``config['logging']`` must provide ``log_dir`` and
                ``model_dir``; both directories are created if missing.
            device: Device the model and every batch are moved to.
        """
        self.model = model.to(device)
        self.config = config
        self.device = device

        training_cfg = config['training']

        # Training parameters
        self.epochs = training_cfg['epochs']
        self.learning_rate = training_cfg['learning_rate']
        self.weight_decay = training_cfg['weight_decay']
        self.beta = training_cfg['beta']
        self.early_stopping_patience = training_cfg['early_stopping_patience']
        self.save_frequency = training_cfg['save_frequency']

        # Optimizer
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )

        # Logging
        self.log_dir = config['logging']['log_dir']
        self.model_dir = config['logging']['model_dir']
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.model_dir, exist_ok=True)

        self.writer = SummaryWriter(self.log_dir)

        # Objective: 'iwae' selects the importance-weighted bound.
        # Any other value falls back to the ELBO path.
        self.objective = training_cfg.get('objective', 'elbo')
        # Number of importance samples K. May be an int, or a per-layer
        # list of ints. A list is passed through unchanged to models that
        # implement compute_iwae_loss; the built-in path uses its maximum.
        self.iwae_k = training_cfg.get('iwae_k', 5)

        # Training history (one entry per completed epoch)
        self.train_losses = []
        self.val_losses = []
        self.recon_losses = []
        self.kl_losses = []

        # Early stopping
        self.best_val_loss = float('inf')
        self.patience_counter = 0

    def _num_importance_samples(self) -> int:
        """Return K for the built-in IWAE path.

        For a list/tuple ``iwae_k`` the maximum is used.

        Raises:
            ValueError: If the resulting K is smaller than 1. With K < 1
                the bound would be NaN.
        """
        if isinstance(self.iwae_k, (list, tuple)):
            k = int(max(self.iwae_k))
        else:
            k = int(self.iwae_k)
        if k < 1:
            raise ValueError(f"iwae_k must be >= 1, got {self.iwae_k!r}")
        return k

    def _compute_iwae_batch(self, data: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute the IWAE loss and diagnostics for one batch.

        If the model defines ``compute_iwae_loss`` it is used directly.
        Otherwise a single-latent-layer Gaussian IWAE is computed here.
        This path assumes:
        - ``model.encode`` returns ``(mu, logvar)`` of shape ``[B, D]``;
        - ``model.decode`` returns probabilities in [0, 1], as required by
          ``F.binary_cross_entropy``, with ``B*K * prod(data.shape[1:])``
          elements in total.

        Returns:
            Dict with tensor entries:
            - ``'loss'``: negative IWAE bound, averaged over the batch;
            - ``'recon_loss'``: importance-weighted per-feature BCE;
            - ``'kl_loss'``: analytic KL(q(z|x) || N(0, I)).
            Missing diagnostics from a model-provided loss are filled with NaN.
        """
        # Delegate to a hierarchical-aware implementation when available.
        if hasattr(self.model, 'compute_iwae_loss'):
            out = self.model.compute_iwae_loss(data, k=self.iwae_k)
            loss = out['loss']
            missing = torch.full_like(loss.detach(), float('nan'))
            return {
                'loss': loss,
                'recon_loss': out.get('recon_loss', missing),
                'kl_loss': out.get('kl_loss', missing),
            }

        mu, logvar = self.model.encode(data)
        std = torch.exp(0.5 * logvar)
        var = torch.exp(logvar)

        batch_size = data.size(0)
        k = self._num_importance_samples()
        latent_dim = mu.size(1)

        # Draw K reparameterised latents per data point: z is [B, K, D].
        eps = torch.randn(batch_size, k, latent_dim, device=mu.device, dtype=mu.dtype)
        z = mu.unsqueeze(1) + std.unsqueeze(1) * eps

        # Decode all B*K samples at once. Row b*K + j of z_flat is z[b, j].
        recon_flat = self.model.decode(z.reshape(batch_size * k, latent_dim))

        # Give every decoded sample its own input. expand() to [B, K, ...]
        # and then flatten, so row b*K + j of x_rep is data[b] for any input
        # rank. Tensor.repeat would prepend dimensions and tile in [K, B]
        # order for 3-D inputs.
        x_rep = data.unsqueeze(1).expand(batch_size, k, *data.shape[1:]).reshape(recon_flat.shape)

        # Per-element Bernoulli negative log-likelihood, grouped as [B, K, features].
        bce_elem = F.binary_cross_entropy(recon_flat, x_rep, reduction='none').reshape(batch_size, k, -1)
        log_px_z = -bce_elem.sum(dim=2)  # [B, K]

        # Gaussian log-densities with the -D/2*log(2*pi) constant dropped.
        # The constant appears in both terms and cancels in log_w.
        log_pz = -0.5 * (z ** 2).sum(dim=2)  # [B, K]
        log_qz_x = (
            -0.5 * (((z - mu.unsqueeze(1)) ** 2) / var.unsqueeze(1)).sum(dim=2)
            - 0.5 * logvar.sum(dim=1, keepdim=True)
        )  # [B, K]

        log_w = log_px_z + log_pz - log_qz_x  # [B, K]
        # IWAE bound per data point: log(1/K * sum_j w_j) = logsumexp(log_w) - log K
        iwae_bound = torch.logsumexp(log_w, dim=1) - torch.log(log_w.new_tensor(float(k)))
        loss = -iwae_bound.mean()

        # Diagnostics only (not part of the gradient):
        # - per-feature mean BCE, weighted by the normalised importance weights;
        # - analytic KL to the standard normal prior.
        w_tilde = torch.softmax(log_w, dim=1).detach()
        recon_loss = (w_tilde * bce_elem.mean(dim=2)).sum(dim=1).mean()
        kl_loss = (0.5 * (var + mu ** 2 - 1.0 - logvar).sum(dim=1)).mean()

        return {
            'loss': loss,
            'recon_loss': recon_loss,
            'kl_loss': kl_loss,
        }

    def _compute_elbo_batch(self, data: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute the ELBO-based loss for one batch.

        Uses ``model.compute_elbo_loss`` when the model defines it.
        Otherwise calls the model's forward pass followed by ``vae_loss``.
        Both are expected to return ``(total, recon, kl)``.
        """
        if hasattr(self.model, 'compute_elbo_loss'):
            total, recon, kl = self.model.compute_elbo_loss(data, beta=self.beta)
            return {'loss': total, 'recon_loss': recon, 'kl_loss': kl}
        # Fallback to standard VAE loss
        recon_batch, mu, logvar = self.model(data)
        loss, recon_loss, kl_loss = vae_loss(recon_batch, data, mu, logvar, self.beta)
        return {'loss': loss, 'recon_loss': recon_loss, 'kl_loss': kl_loss}

    def _compute_batch(self, data: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Dispatch one batch to the configured objective ('iwae' or ELBO)."""
        if self.objective == 'iwae':
            return self._compute_iwae_batch(data)
        return self._compute_elbo_batch(data)

    @staticmethod
    def _average(totals: Dict[str, float], num_batches: int, split: str) -> Dict[str, float]:
        """Turn summed per-batch metrics into per-batch means.

        Raises:
            ValueError: If the loader produced no batches. Without this
                guard the division would raise ZeroDivisionError.
        """
        if num_batches == 0:
            raise ValueError(f"{split} loader yielded no batches; cannot compute average losses")
        return {key: value / num_batches for key, value in totals.items()}

    def train_epoch(self, train_loader: DataLoader) -> Dict[str, float]:
        """Train for one epoch.

        Each batch is expected to be a ``(data, label)`` pair; labels are ignored.

        Returns:
            Mean over batches of ``'loss'``, ``'recon_loss'`` and ``'kl_loss'``.

        Raises:
            ValueError: If ``train_loader`` yields no batches.
        """
        self.model.train()
        totals = {'loss': 0.0, 'recon_loss': 0.0, 'kl_loss': 0.0}
        num_batches = 0

        pbar = tqdm(train_loader, desc="Training")
        for data, _ in pbar:
            data = data.to(self.device)

            self.optimizer.zero_grad()
            out = self._compute_batch(data)

            # Backward pass
            out['loss'].backward()
            self.optimizer.step()

            # float() accepts both one-element tensors and plain Python
            # numbers, whichever a model-provided loss returns.
            batch_stats = {key: float(out[key]) for key in totals}
            for key, value in batch_stats.items():
                totals[key] += value
            num_batches += 1

            pbar.set_postfix({
                'Loss': f"{batch_stats['loss']:.4f}",
                'Recon': f"{batch_stats['recon_loss']:.4f}",
                'KL': f"{batch_stats['kl_loss']:.4f}"
            })

        return self._average(totals, num_batches, 'Training')

    def validate(self, val_loader: DataLoader) -> Dict[str, float]:
        """Evaluate the model on ``val_loader`` without gradient tracking.

        Returns:
            Mean over batches of ``'loss'``, ``'recon_loss'`` and ``'kl_loss'``.

        Raises:
            ValueError: If ``val_loader`` yields no batches.
        """
        self.model.eval()
        totals = {'loss': 0.0, 'recon_loss': 0.0, 'kl_loss': 0.0}
        num_batches = 0

        with torch.no_grad():
            for data, _ in val_loader:
                out = self._compute_batch(data.to(self.device))
                for key in totals:
                    totals[key] += float(out[key])
                num_batches += 1

        return self._average(totals, num_batches, 'Validation')

    def train(self, train_loader: DataLoader, val_loader: DataLoader) -> Dict[str, List[float]]:
        """Run up to ``self.epochs`` epochs with early stopping.

        Each epoch:
        - logs its metrics;
        - saves ``best_model.pth`` whenever validation loss improves;
        - saves ``checkpoint_epoch_<n>.pth`` every ``save_frequency`` epochs
          (a falsy ``save_frequency`` disables this).
        Training stops after ``early_stopping_patience`` epochs without
        improvement.

        Returns:
            The per-epoch loss histories. Recon/KL histories are training-set values.
        """
        print(f"Starting training for {self.epochs} epochs...")

        for epoch in range(self.epochs):
            print(f"\nEpoch {epoch+1}/{self.epochs}")

            train_metrics = self.train_epoch(train_loader)
            val_metrics = self.validate(val_loader)

            self.log_metrics(epoch, train_metrics, val_metrics)

            # Record history before checkpointing so saved files include this epoch.
            self.train_losses.append(train_metrics['loss'])
            self.val_losses.append(val_metrics['loss'])
            self.recon_losses.append(train_metrics['recon_loss'])
            self.kl_losses.append(train_metrics['kl_loss'])

            # Early stopping bookkeeping
            if val_metrics['loss'] < self.best_val_loss:
                self.best_val_loss = val_metrics['loss']
                self.patience_counter = 0
                self.save_model('best_model.pth')
            else:
                self.patience_counter += 1

            # Periodic checkpoint. A 0/None frequency disables it instead of
            # raising ZeroDivisionError in the modulo.
            if self.save_frequency and (epoch + 1) % self.save_frequency == 0:
                self.save_model(f'checkpoint_epoch_{epoch+1}.pth')

            if self.patience_counter >= self.early_stopping_patience:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break

        return {
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'recon_losses': self.recon_losses,
            'kl_losses': self.kl_losses
        }

    def log_metrics(self, epoch: int, train_metrics: Dict[str, float],
                    val_metrics: Dict[str, float]) -> None:
        """Write train/validation scalars to TensorBoard and print a summary line."""
        # Training metrics
        self.writer.add_scalar('Loss/Train', train_metrics['loss'], epoch)
        self.writer.add_scalar('Reconstruction_Loss/Train', train_metrics['recon_loss'], epoch)
        self.writer.add_scalar('KL_Loss/Train', train_metrics['kl_loss'], epoch)

        # Validation metrics
        self.writer.add_scalar('Loss/Validation', val_metrics['loss'], epoch)
        self.writer.add_scalar('Reconstruction_Loss/Validation', val_metrics['recon_loss'], epoch)
        self.writer.add_scalar('KL_Loss/Validation', val_metrics['kl_loss'], epoch)

        print(f"Train Loss: {train_metrics['loss']:.4f}, Val Loss: {val_metrics['loss']:.4f}")

    def save_model(self, filename: str) -> None:
        """Save model/optimizer state, config and loss histories to ``model_dir/filename``."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': self.config,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            # Saved so a resumed run keeps all four histories aligned.
            'recon_losses': self.recon_losses,
            'kl_losses': self.kl_losses,
            'best_val_loss': self.best_val_loss
        }

        save_path = os.path.join(self.model_dir, filename)
        torch.save(checkpoint, save_path)
        print(f"Model saved to {save_path}")

    def load_model(self, filename: str) -> None:
        """Restore state written by ``save_model`` from ``model_dir/filename``.

        History keys absent from older checkpoints default to empty lists;
        a missing best validation loss defaults to infinity.
        """
        load_path = os.path.join(self.model_dir, filename)
        checkpoint = torch.load(load_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.train_losses = checkpoint.get('train_losses', [])
        self.val_losses = checkpoint.get('val_losses', [])
        self.recon_losses = checkpoint.get('recon_losses', [])
        self.kl_losses = checkpoint.get('kl_losses', [])
        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))

        print(f"Model loaded from {load_path}")

    def close(self) -> None:
        """Close the TensorBoard writer."""
        self.writer.close()