import torch
import torch.nn.functional as F
import numpy as np
import random
from models import ConditionalVAE, EnergyModel, MixtureOfCVAE
from utils import encode_graph_condition


# Pin every random number generator used by this module (stdlib, NumPy and
# PyTorch) to the same fixed seed so repeated runs are reproducible.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Module-wide compute device: CUDA when PyTorch reports it available, else CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def morph_phase(solution_dataset, num_epochs=500, initial_experts=1, max_experts=9,
                expert_add_threshold=0.425, batch_size=32):
    """
    Implements the Morph phase to train an EBM-guided mixture of CVAEs

    Each mini-batch performs two updates: first the energy model is trained to
    assign low energy to dataset vectors and high energy to samples drawn from
    the mixture, then the mixture is trained on reconstruction + KL + the
    energy of its own samples. After epoch 50, on every 10th epoch, the
    mixture may grow by one expert per batch while it is below ``max_experts``.

    Args:
        solution_dataset: Output from the Forge phase; a sequence of
            ``(G, K, T, perturbation_dict, perturbation_vector)`` tuples.
            The sequence is not modified (a private copy is shuffled).
        num_epochs: Number of training epochs
        initial_experts: Initial number of experts in the mixture
        max_experts: Upper bound on the total number of experts in the
            mixture (growth stops once ``num_experts`` reaches it)
        expert_add_threshold: Energy value above which a generated sample
            counts as "high energy"; more than 10 such samples in a batch
            triggers adding an expert
        batch_size: Number of dataset entries per mini-batch

    Returns:
        ebm: Trained Energy-Based Model
        mix_cvae: Trained Mixture of CVAEs

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    if not solution_dataset:
        print("Warning: Solution dataset is empty. Cannot proceed with Morph phase.")
        # Return untrained placeholder models so callers still get a usable pair
        ebm = EnergyModel(input_dim=10, hidden_dim=512, num_layers=6).to(device)
        mix_cvae = MixtureOfCVAE(
            input_dim=10,
            condition_dim=10,
            num_experts=initial_experts,
            latent_dim=128,
            hidden_dim=512
        ).to(device)
        return ebm, mix_cvae

    # All perturbation vectors are padded to the longest one in the dataset
    max_edge_dim = max(len(vector) for _, _, _, _, vector in solution_dataset)

    # Encode one sample condition to discover the condition dimension
    sample_G, sample_K, sample_T, _, _ = solution_dataset[0]
    sample_condition = encode_graph_condition(sample_G, sample_K, sample_T)
    condition_dim = sample_condition.size(0)

    print(f"Maximum perturbation dimension: {max_edge_dim}, Condition dimension: {condition_dim}")

    # Initialize models
    ebm = EnergyModel(input_dim=max_edge_dim, hidden_dim=512, num_layers=6).to(device)
    mix_cvae = MixtureOfCVAE(
        input_dim=max_edge_dim,
        condition_dim=condition_dim,
        num_experts=initial_experts,
        latent_dim=128,
        hidden_dim=512
    ).to(device)

    # Set up optimizers
    ebm_optimizer = torch.optim.Adam(ebm.parameters(), lr=2e-4)
    mix_cvae_optimizer = torch.optim.Adam(mix_cvae.parameters(), lr=8e-4)

    # Loss weights for the mixture objective
    beta_kl = 0.1  # Weight for KL divergence term
    lambda_energy = 0.5  # Weight for energy guidance term

    # Shuffle a private copy so the caller's dataset order is left untouched
    training_data = list(solution_dataset)
    # Ceiling division: the real number of mini-batches per epoch
    num_batches = (len(training_data) + batch_size - 1) // batch_size

    for epoch in range(num_epochs):
        epoch_ebm_loss = 0.0
        epoch_mix_cvae_loss = 0.0

        random.shuffle(training_data)

        for batch_start in range(0, len(training_data), batch_size):
            batch_data = training_data[batch_start:batch_start + batch_size]

            # NOTE(review): encode_graph_condition is re-run for every sample
            # in every epoch; if it is deterministic it could be cached once.
            x_real = torch.stack([
                _pad_perturbation(vector, max_edge_dim)
                for _, _, _, _, vector in batch_data
            ]).to(device)
            condition = torch.stack([
                encode_graph_condition(G, K, T)
                for G, K, T, _, _ in batch_data
            ]).to(device)

            # 1. Update EBM
            ebm_optimizer.zero_grad()

            energy_real = ebm(x_real).mean()

            # Fake data comes from the mixture; no gradient flows into it here
            with torch.no_grad():
                x_fake = mix_cvae.sample(condition, num_samples=x_real.size(0))
            energy_fake = ebm(x_fake).mean()

            # Minimize energy of real data, maximize energy of fake data; the
            # squared terms keep both energies from drifting without bound
            energy_reg = (energy_real**2 + energy_fake**2)
            ebm_loss = energy_real - energy_fake + 1.0 * energy_reg

            ebm_loss.backward()
            ebm_optimizer.step()

            epoch_ebm_loss += ebm_loss.item()

            # 2. Update Mix-CVAE
            mix_cvae_optimizer.zero_grad()

            x_recon, mus, logvars, expert_weights = mix_cvae(x_real, condition)

            recon_loss = F.mse_loss(x_recon, x_real)

            # KL divergence of each expert, weighted by its mean gate weight
            kl_loss = 0.0
            for i, (mu, logvar) in enumerate(zip(mus, logvars)):
                expert_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
                expert_weight = expert_weights[:, i].mean()
                kl_loss += expert_weight * expert_kl

            # Energy guidance: penalize samples that land in high energy regions
            x_sample = mix_cvae.sample(condition, num_samples=x_real.size(0))
            energy_guidance = ebm(x_sample).mean()

            mix_cvae_loss = recon_loss + beta_kl * kl_loss + lambda_energy * energy_guidance
            mix_cvae_loss.backward()
            mix_cvae_optimizer.step()

            epoch_mix_cvae_loss += mix_cvae_loss.item()

            # 3. Grow the mixture when many of its samples still receive high
            # energy, i.e. the current experts do not capture those regions
            if epoch > 50 and epoch % 10 == 0 and mix_cvae.num_experts < max_experts:
                with torch.no_grad():
                    x_samples = mix_cvae.sample(condition, num_samples=x_real.size(0))
                    sample_energies = ebm(x_samples).squeeze()

                    high_energy_samples = (sample_energies > expert_add_threshold).sum().item()
                    if high_energy_samples > 10:
                        print(f"Epoch {epoch}: Adding new expert (current: {mix_cvae.num_experts})")
                        mix_cvae.add_expert()
                        # The optimizer only knows the parameters that existed
                        # when it was built; place the new expert on the
                        # training device and register it, otherwise it would
                        # never be updated.
                        mix_cvae.to(device)
                        _register_new_parameters(mix_cvae_optimizer, mix_cvae)

        # Log progress
        if epoch % 10 == 0:
            avg_ebm_loss = epoch_ebm_loss / num_batches
            avg_mix_cvae_loss = epoch_mix_cvae_loss / num_batches

            print(f"Epoch {epoch}/{num_epochs}")
            print(f"  EBM Loss: {avg_ebm_loss:.4f}")
            print(f"  Mix-CVAE Loss: {avg_mix_cvae_loss:.4f}")
            print(f"  Number of experts: {mix_cvae.num_experts}")

    return ebm, mix_cvae


def _pad_perturbation(perturbation_vector, target_dim):
    """Return ``perturbation_vector`` zero-padded (or sliced) to ``target_dim`` entries."""
    missing = target_dim - len(perturbation_vector)
    if missing > 0:
        padding = torch.zeros(missing, device=perturbation_vector.device)
        return torch.cat([perturbation_vector, padding])
    return perturbation_vector[:target_dim]


def _register_new_parameters(optimizer, module):
    """Add to ``optimizer`` every parameter of ``module`` it does not track yet."""
    tracked = {id(p) for group in optimizer.param_groups for p in group['params']}
    fresh = [p for p in module.parameters() if id(p) not in tracked]
    if fresh:
        # The new group inherits the optimizer's default hyper-parameters (lr, betas, ...)
        optimizer.add_param_group({'params': fresh})