"""
RQ3: Scalability to Complex Domains - Molecular Generation with DATE-GFN

This experiment addresses scalability limitations by implementing DATE-GFN
for high-dimensional molecular generation using MPNN-based architectures.

Research Question: Can the DATE-GFN framework generalize to complex, structured
state spaces and enhance SOTA architectures in molecular generation tasks?
"""

import argparse
import os
import random
import time
from collections import deque, defaultdict
from typing import List, Tuple, Dict, Optional, Union

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# Was `import torch.geometric`, which is not a module; the package is `torch_geometric`.
import torch_geometric
import wandb
from torch_geometric.data import Data, Batch
from torch_geometric.nn import MessagePassing, global_mean_pool
# Removed RDKit dependencies - simplified molecular generation

# Seed the PyTorch, NumPy and stdlib `random` generators with a fixed value
# at import time so that repeated runs of this script are reproducible.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

class MolecularEnvironment:
    """Sequential molecule-building environment backed by a NetworkX graph.

    An episode starts from an empty graph. Each non-terminal action applies
    one entry of ``self.fragments``; choosing the stop token (or acting once
    ``max_atoms`` has been reached) ends the episode and yields the reward
    from :meth:`compute_molecular_reward`. States are exposed as
    ``Data`` graphs with 16 node features and 4 edge features.

    Attributes:
        max_atoms: Atom budget; once reached, the next step is terminal.
        fragments: Fragment vocabulary, indexed by action id.
        vocab_size: Number of entries in ``fragments``.
        stop_token: Action id that terminates generation.
        mol_graph: ``nx.Graph`` holding atoms (nodes) and bonds (edges).
        atom_count: Number of atoms added so far; also the next node id.
        fragment_sequence: Action ids that actually modified the molecule.
    """

    STOP_FRAGMENT = '[STOP]'
    NODE_FEATURE_DIM = 16
    EDGE_FEATURE_DIM = 4

    # Fragments that map to a single atom of the same symbol.
    _SINGLE_ATOMS = ('C', 'N', 'O', 'S', 'F', 'Cl')
    # Order defines the one-hot slots 0..6 of the atom feature vector.
    _ATOM_TYPES = ('C', 'N', 'O', 'S', 'F', 'Cl', 'P')
    # Order defines the one-hot slots 0..3 of the bond feature vector.
    _BOND_TYPES = ('SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC')

    def __init__(self, max_atoms: int = 20, fragment_library: Optional[List[str]] = None):
        """Create the environment and reset it to the empty molecule.

        Args:
            max_atoms: Maximum number of atoms a molecule may contain.
            fragment_library: Optional custom vocabulary. When omitted, a
                small built-in library (simplified for demo) is used.
        """
        self.max_atoms = max_atoms

        # Chemical fragment library (simplified for demo)
        if fragment_library is None:
            self.fragments = [
                'C',      # Carbon
                'N',      # Nitrogen
                'O',      # Oxygen
                'S',      # Sulfur
                'F',      # Fluorine
                'Cl',     # Chlorine
                'C=C',    # Double bond
                'C#C',    # Triple bond
                'c1ccccc1', # Benzene ring
                'C(=O)',  # Carbonyl
                'N(C)C',  # Tertiary amine
                'O-C',    # Ether
                'S(=O)(=O)', # Sulfonyl
                'P(=O)',  # Phosphoryl
                '[STOP]'  # Stop token
            ]
        else:
            self.fragments = fragment_library

        self.vocab_size = len(self.fragments)
        self.stop_token = self._resolve_stop_token(self.fragments)

        # Current molecule state
        self.reset()

    @staticmethod
    def _resolve_stop_token(fragments: List[str]) -> int:
        """Return the action id of the stop token.

        Uses the position of ``'[STOP]'`` when the library contains it, so a
        custom library is not required to put it last; otherwise falls back
        to the last index, as before.
        """
        if MolecularEnvironment.STOP_FRAGMENT in fragments:
            return fragments.index(MolecularEnvironment.STOP_FRAGMENT)
        return len(fragments) - 1

    def reset(self):
        """Reset to the empty molecule and return its state representation."""
        self.mol_graph = nx.Graph()
        self.atom_count = 0
        self.fragment_sequence = []
        return self.get_state_representation()

    def get_state_representation(self):
        """Convert the current molecule into a ``Data`` graph.

        Returns:
            ``Data`` with ``x`` of shape ``(num_atoms, 16)``, ``edge_index``
            of shape ``(2, 2 * num_bonds)`` (each undirected bond is emitted
            in both directions) and ``edge_attr`` of shape
            ``(2 * num_bonds, 4)``. The empty molecule is represented by one
            all-zero dummy node and no edges.
        """
        if self.atom_count == 0:
            # Empty molecule - return dummy graph
            return Data(
                x=torch.zeros(1, self.NODE_FEATURE_DIM),
                edge_index=torch.empty(2, 0, dtype=torch.long),
                edge_attr=torch.empty(0, self.EDGE_FEATURE_DIM),
            )

        # Node features (atom properties)
        node_rows = [self._get_atom_features(node) for node in self.mol_graph.nodes()]

        # Edge features (bond properties); one entry per direction.
        index_pairs = []
        bond_rows = []
        for src, dst in self.mol_graph.edges():
            bond_features = self._get_bond_features((src, dst))
            index_pairs.extend(([src, dst], [dst, src]))
            bond_rows.extend((bond_features, bond_features))

        if node_rows:
            node_features = torch.stack(node_rows)
        else:
            node_features = torch.zeros(1, self.NODE_FEATURE_DIM)

        if index_pairs:
            # Explicit long dtype and contiguous layout for the (2, E) index.
            edge_index = torch.tensor(index_pairs, dtype=torch.long).t().contiguous()
            edge_attr = torch.stack(bond_rows)
        else:
            edge_index = torch.empty(2, 0, dtype=torch.long)
            edge_attr = torch.empty(0, self.EDGE_FEATURE_DIM)

        return Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr)

    def _get_atom_features(self, node_id: int) -> torch.Tensor:
        """Return the 16-dimensional feature vector of one atom.

        Slots 0-6 one-hot encode the atom type, slot 7 is the normalized
        degree, slots 8-11 hold charge, hybridization, aromaticity and ring
        membership read from the node attributes (0 when absent); the
        remaining slots are unused and stay 0.
        """
        # Simplified atom features (in practice, use RDKit descriptors)
        node_data = self.mol_graph.nodes[node_id]
        atom_type = node_data.get('atom_type', 'C')

        features = [0.0] * self.NODE_FEATURE_DIM

        # Atom type encoding
        if atom_type in self._ATOM_TYPES:
            features[self._ATOM_TYPES.index(atom_type)] = 1.0

        # Degree is taken from the graph itself: a stored node attribute
        # would go stale as soon as another bond is attached to the atom.
        features[7] = self.mol_graph.degree[node_id] / 4.0  # Normalized degree
        features[8] = node_data.get('formal_charge', 0) / 2.0  # Normalized charge
        features[9] = node_data.get('hybridization', 0) / 3.0  # Hybridization state
        features[10] = node_data.get('aromatic', 0)  # Aromaticity
        features[11] = node_data.get('in_ring', 0)   # Ring membership

        return torch.tensor(features, dtype=torch.float32)

    def _get_bond_features(self, edge: Tuple[int, int]) -> torch.Tensor:
        """Return the one-hot bond-type vector (length 4) of one bond."""
        bond_type = self.mol_graph.edges[edge].get('bond_type', 'SINGLE')

        features = [0.0] * self.EDGE_FEATURE_DIM
        if bond_type in self._BOND_TYPES:
            features[self._BOND_TYPES.index(bond_type)] = 1.0

        return torch.tensor(features, dtype=torch.float32)

    def get_valid_actions(self):
        """Return the list of currently allowed action ids.

        Only the stop token is allowed once the atom budget is used up;
        otherwise every fragment is considered valid (simplified).
        """
        if self.atom_count >= self.max_atoms:
            return [self.stop_token]

        return list(range(self.vocab_size))

    def step(self, action: int):
        """Apply one action.

        Args:
            action: Index into ``self.fragments``.

        Returns:
            ``(state, reward, done, info)``. Non-terminal steps return a
            reward of 0.0 and an empty ``info``; the terminal step returns
            the molecular reward and ``info`` with ``'smiles'`` and
            ``'valid'``.
        """
        if action == self.stop_token or self.atom_count >= self.max_atoms:
            # Terminal action
            reward = self.compute_molecular_reward()
            info = {'smiles': self._to_smiles(), 'valid': self._is_valid_molecule()}
            return self.get_state_representation(), reward, True, info

        # Record the action only if it actually changed the molecule, so
        # no-ops do not count towards the fragment-diversity reward.
        if self._add_fragment(self.fragments[action]):
            self.fragment_sequence.append(action)

        return self.get_state_representation(), 0.0, False, {}

    def _add_fragment(self, fragment: str) -> bool:
        """Apply ``fragment`` to the molecule.

        Returns:
            True if the molecule was modified, False if the fragment could
            not be applied (e.g. a ring that does not fit the atom budget).
        """
        if fragment == 'C=C':
            return self._add_double_bond()
        if fragment == 'c1ccccc1':
            return self._add_benzene_ring()

        # Simplified fragment addition logic: every other fragment is
        # approximated by a single atom (carbon unless it names an element).
        atom_type = fragment if fragment in self._SINGLE_ATOMS else 'C'
        self._add_atom(atom_type)
        return True

    def _add_atom(self, atom_type: str):
        """Append one atom, single-bonded to the previously added atom."""
        new_id = self.atom_count
        self.mol_graph.add_node(new_id, atom_type=atom_type)

        # Connect to existing atoms (simplified connectivity)
        if new_id > 0:
            self.mol_graph.add_edge(new_id - 1, new_id, bond_type='SINGLE')

        self.atom_count += 1

    def _add_double_bond(self) -> bool:
        """Mark the last edge yielded by the graph's edge view as DOUBLE.

        Returns:
            True if a bond was updated, False if there is no bond yet.
        """
        if self.atom_count < 2:
            return False

        edges = list(self.mol_graph.edges())
        if not edges:
            return False

        self.mol_graph.edges[edges[-1]]['bond_type'] = 'DOUBLE'
        return True

    def _add_benzene_ring(self) -> bool:
        """Append a six-carbon aromatic ring if it fits the atom budget.

        Returns:
            True if the ring was added, False if it would exceed
            ``max_atoms`` (the molecule is then left untouched).
        """
        if self.atom_count + 6 > self.max_atoms:
            return False

        ring_start = self.atom_count

        # Add 6 carbon atoms in ring
        for offset in range(6):
            self.mol_graph.add_node(ring_start + offset, atom_type='C', aromatic=1, in_ring=1)

        # Add ring bonds
        for offset in range(6):
            self.mol_graph.add_edge(
                ring_start + offset,
                ring_start + (offset + 1) % 6,
                bond_type='AROMATIC',
            )

        # Connect ring to existing molecule
        if ring_start > 0:
            self.mol_graph.add_edge(ring_start - 1, ring_start, bond_type='SINGLE')

        self.atom_count += 6
        return True

    def _to_smiles(self) -> str:
        """Return a pseudo-SMILES string: the atom symbols concatenated.

        Bonds and rings are not encoded (oversimplified), so different
        molecules can share the same string.
        """
        if self.atom_count == 0:
            return ""

        return ''.join(
            self.mol_graph.nodes[node].get('atom_type', 'C')
            for node in self.mol_graph.nodes()
        )

    def _is_valid_molecule(self) -> bool:
        """Simplified validity check: non-empty, within budget, built from fragments."""
        return (self.atom_count > 0 and
                self.atom_count <= self.max_atoms and
                len(self.fragment_sequence) > 0)

    def compute_molecular_reward(self) -> float:
        """Compute the heuristic reward of the current molecule.

        The reward is a weighted sum of a size score (peaks at 15 atoms), a
        fragment-diversity score, a ring score and a connectivity score,
        clamped to at most 1.0. The empty molecule scores 0.0.

        NOTE(review): as weighted here the sum is bounded by
        0.2 + 0.3 + 0.06 + 0.12 = 0.68, so the clamp never binds; confirm
        whether the ring/connectivity sub-scores were meant to be scaled
        twice.
        """
        if self.atom_count == 0:
            return 0.0

        # Size score (prefer moderate sizes)
        size_score = max(0, 1 - abs(self.atom_count - 15) / 10.0)

        # Complexity score (prefer some complexity)
        complexity = len(set(self.fragment_sequence)) / len(self.fragments)

        # Ring score (rings are often important)
        has_ring = any(
            self.mol_graph.nodes[node].get('in_ring', 0)
            for node in self.mol_graph.nodes()
        )
        ring_score = 0.3 if has_ring else 0.1

        # Connectivity score (prefer connected molecules)
        if self.atom_count > 1:
            connectivity_score = 0.4 if nx.is_connected(self.mol_graph) else 0.1
        else:
            connectivity_score = 0.4

        scores = [
            size_score * 0.2,
            complexity * 0.3,
            ring_score * 0.2,
            connectivity_score * 0.3,
        ]
        return min(1.0, sum(scores))

class MPNNLayer(MessagePassing):
    """Message Passing Neural Network layer with mean aggregation.

    For every edge a message is computed by an MLP from the two endpoint
    embeddings and the edge features; messages are averaged per receiving
    node and a second MLP combines them with the node's own embedding.

    Args:
        node_dim: Width of the incoming (and outgoing) node embeddings.
        edge_dim: Width of the edge feature vectors.
        hidden_dim: Width of the messages and of the MLP hidden layers.
    """

    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int):
        super().__init__(aggr='mean')

        # `MessagePassing` already owns an attribute called `node_dim` (the
        # tensor axis along which nodes are indexed during propagation).
        # Storing the feature width under that name would overwrite it, so
        # the width is kept under a different attribute name.
        self.node_feat_dim = node_dim
        self.edge_dim = edge_dim
        self.hidden_dim = hidden_dim

        # Message function
        self.message_mlp = nn.Sequential(
            nn.Linear(2 * node_dim + edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Update function
        self.update_mlp = nn.Sequential(
            nn.Linear(node_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, node_dim)
        )

    def forward(self, x, edge_index, edge_attr):
        """Run one round of message passing and return the updated node embeddings."""
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        """Build the per-edge message.

        ``x_i`` holds the embedding of the receiving (target) node and
        ``x_j`` that of the sending (source) node of each edge.
        """
        msg_input = torch.cat([x_i, x_j, edge_attr], dim=-1)
        return self.message_mlp(msg_input)

    def update(self, aggr_out, x):
        """Combine each node's embedding with its aggregated messages."""
        update_input = torch.cat([x, aggr_out], dim=-1)
        return self.update_mlp(update_input)

class MPNNPolicyNetwork(nn.Module):
    """MPNN-based policy network for molecular generation.

    The network embeds the atoms of a molecule graph, refines them with
    residual message-passing layers, mean-pools them into one vector per
    graph and maps that vector to log-probabilities over the fragment
    vocabulary. It also carries the fixed "teacher" prior used for the
    DATE-GFN distillation term.

    Args:
        node_dim: Number of input features per atom.
        edge_dim: Number of features per bond.
        hidden_dim: Width of the internal embeddings.
        num_layers: Number of message-passing layers.
        vocab_size: Number of actions (fragments incl. the stop token).
    """

    # Hand-written preference score per default fragment, in vocabulary
    # order. Prefer: C, N, O, rings, functional groups.
    _DEFAULT_TEACHER_PRIOR = (
        1.0,  # C
        0.8,  # N
        0.7,  # O
        0.6,  # S
        0.5,  # F
        0.4,  # Cl
        0.9,  # C=C
        0.7,  # C#C
        1.2,  # Benzene (important!)
        0.8,  # C=O
        0.7,  # N(C)C
        0.6,  # O-C
        0.5,  # S(=O)(=O)
        0.4,  # P(=O)
        0.1,  # STOP
    )

    def __init__(self,
                 node_dim: int = 16,
                 edge_dim: int = 4,
                 hidden_dim: int = 128,
                 num_layers: int = 3,
                 vocab_size: int = 15):
        super().__init__()

        self.node_dim = node_dim
        self.vocab_size = vocab_size

        # Initial node embedding
        self.node_embedding = nn.Linear(node_dim, hidden_dim)

        # MPNN layers
        self.mpnn_layers = nn.ModuleList([
            MPNNLayer(hidden_dim, edge_dim, hidden_dim)
            for _ in range(num_layers)
        ])

        # GRU for sequence modeling
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)

        # Output layers
        self.output_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, vocab_size)
        )

        # Teacher components for DATE-GFN
        self.teacher_knowledge = None  # created lazily by get_teacher_guidance
        self.distillation_buffer = deque(maxlen=100)  # recent KL values

    def forward(self, batch_data, action_mask=None):
        """Return action log-probabilities for each graph in ``batch_data``.

        Args:
            batch_data: A single graph or a ``Batch`` of graphs.
            action_mask: Optional tensor broadcastable to
                ``(num_graphs, vocab_size)`` with 1 for allowed and 0 for
                forbidden actions.

        Returns:
            Tensor of shape ``(num_graphs, vocab_size)`` with log-softmax
            scores; masked actions receive a very large negative logit.
        """
        # Handle single graph
        if not isinstance(batch_data, Batch):
            batch_data = Batch.from_data_list([batch_data])

        edge_index = batch_data.edge_index
        edge_attr = batch_data.edge_attr

        # Initial embedding
        h = self.node_embedding(batch_data.x)

        # MPNN message passing with residual connections
        for mpnn_layer in self.mpnn_layers:
            h = h + mpnn_layer(h, edge_index, edge_attr)

        # Global pooling to get graph-level representation
        graph_repr = global_mean_pool(h, batch_data.batch)

        # Pass through GRU (treating as sequence of length 1)
        gru_out, _ = self.gru(graph_repr.unsqueeze(1))
        graph_features = gru_out.squeeze(1)

        # Generate action logits
        logits = self.output_mlp(graph_features)

        # Apply action mask
        if action_mask is not None:
            logits = logits + (action_mask - 1) * 1e9

        return F.log_softmax(logits, dim=-1)

    def get_teacher_guidance(self, batch_data, action_mask, step=0):
        """Return the teacher's action distribution.

        The teacher is a state-independent prior over fragments whose
        scores are scaled up linearly with ``step`` and passed through a
        softmax. ``batch_data`` is accepted for interface symmetry but not
        used.

        Args:
            batch_data: Unused.
            action_mask: Optional ``(batch, vocab_size)`` tensor, 1 for
                allowed actions and 0 for forbidden ones.
            step: Training step used to scale the prior.

        Returns:
            Tensor of shape ``(batch, vocab_size)`` (batch is 1 without a
            mask) whose rows sum to 1; forbidden actions get probability 0.
        """
        if self.teacher_knowledge is None:
            # Initialize with chemical knowledge
            prior = torch.tensor(self._DEFAULT_TEACHER_PRIOR, dtype=torch.float32)
            if prior.numel() != self.vocab_size:
                # The hand-written prior only matches the default library;
                # fall back to a uniform teacher for any other vocabulary.
                prior = torch.ones(self.vocab_size, dtype=torch.float32)
            self.teacher_knowledge = prior

        # Evolve based on training progress
        evolution_factor = 1.0 + (step / 10000) * 0.3
        teacher_logits = (self.teacher_knowledge * evolution_factor).unsqueeze(0)

        if action_mask is not None:
            teacher_logits = teacher_logits.to(action_mask.device)
            # Mask additively, exactly as `forward` does. Multiplying by the
            # mask would only zero the logit, and softmax(0) still assigns
            # probability mass to a forbidden action.
            teacher_logits = teacher_logits + (action_mask - 1) * 1e9

        return F.softmax(teacher_logits, dim=1)

    def distill_from_teacher(self, batch_data, action_mask, teacher_probs, lambda_param):
        """Compute the weighted KL distillation loss towards the teacher.

        Runs a forward pass, evaluates ``KL(teacher || student)`` with
        ``batchmean`` reduction, records the raw value in
        ``distillation_buffer`` and returns it scaled by ``lambda_param``.
        """
        student_log_probs = self.forward(batch_data, action_mask)
        kl_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')

        self.distillation_buffer.append(kl_loss.item())
        return lambda_param * kl_loss

class MPNNCritic(nn.Module):
    """Graph-level value estimator built from MPNN layers.

    Atoms are embedded, refined by residual message-passing layers,
    mean-pooled per graph and fed through an MLP that outputs one scalar
    per graph.

    Args:
        node_dim: Number of input features per atom.
        edge_dim: Number of features per bond.
        hidden_dim: Width of the internal embeddings.
        num_layers: Number of message-passing layers.
    """

    def __init__(self,
                 node_dim: int = 16,
                 edge_dim: int = 4,
                 hidden_dim: int = 128,
                 num_layers: int = 3):
        super().__init__()

        # Projects raw atom features into the hidden space.
        self.node_embedding = nn.Linear(node_dim, hidden_dim)

        # Stack of message-passing layers operating in the hidden space.
        layers = []
        for _ in range(num_layers):
            layers.append(MPNNLayer(hidden_dim, edge_dim, hidden_dim))
        self.mpnn_layers = nn.ModuleList(layers)

        # Three-layer MLP producing the scalar value.
        self.value_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, batch_data):
        """Predict one value per graph.

        Args:
            batch_data: A single graph or a ``Batch`` of graphs.

        Returns:
            Tensor of shape ``(num_graphs, 1)``.
        """
        # A lone graph is wrapped into a batch of size one.
        if isinstance(batch_data, Batch):
            graphs = batch_data
        else:
            graphs = Batch.from_data_list([batch_data])

        hidden = self.node_embedding(graphs.x)

        # Residual message passing.
        for layer in self.mpnn_layers:
            hidden = hidden + layer(hidden, graphs.edge_index, graphs.edge_attr)

        # Pool atoms into a graph vector, then score it.
        pooled = global_mean_pool(hidden, graphs.batch)
        return self.value_head(pooled)

class MolecularDATEGFN:
    """DATE-GFN agent for molecular generation with MPNN architectures.

    Owns an :class:`MPNNPolicyNetwork` (and optionally an
    :class:`MPNNCritic`) together with their Adam optimizers. The training
    objective implemented here is a policy-gradient loss using the terminal
    reward with an optional critic baseline, combined with a bonus that
    grows as the policy gets closer to the teacher distribution.

    Args:
        vocab_size: Number of actions exposed by the environment.
        max_atoms: Atom budget of the environment (stored for reference).
        lambda_param: Weight of the teacher/distillation terms.
        lr: Learning rate for both optimizers.
        use_critic: Whether to learn and use a value baseline.
    """

    def __init__(self,
                 vocab_size: int,
                 max_atoms: int,
                 lambda_param: float = 0.1,
                 lr: float = 1e-3,
                 use_critic: bool = True):

        self.vocab_size = vocab_size
        self.max_atoms = max_atoms
        self.lambda_param = lambda_param
        self.use_critic = use_critic

        # Initialize networks
        self.policy = MPNNPolicyNetwork(vocab_size=vocab_size)
        self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)

        if use_critic:
            self.critic = MPNNCritic()
            self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)

        # Training metrics
        self.step_count = 0
        self.performance_buffer = deque(maxlen=100)

    def sample_trajectory(self, env: MolecularEnvironment):
        """Roll out one episode with the current policy.

        Args:
            env: Environment to act in; it is reset first.

        Returns:
            ``(trajectory, total_reward, info)`` where ``trajectory`` is a
            list of per-step dicts (``state``, ``action``, ``action_mask``,
            ``log_prob``), ``total_reward`` is the summed reward and
            ``info`` is the dict returned by the terminal ``env.step``.
        """
        env.reset()
        trajectory = []
        total_reward = 0.0

        while True:
            state = env.get_state_representation()
            valid_actions = env.get_valid_actions()

            # Create action mask
            action_mask = torch.zeros(self.vocab_size)
            action_mask[valid_actions] = 1.0

            # Sampling needs no gradients; they are recomputed in train_step.
            with torch.no_grad():
                log_probs = self.policy(state, action_mask.unsqueeze(0))
                probs = torch.exp(log_probs)
                action = torch.multinomial(probs, 1).item()

            # Store trajectory step
            trajectory.append({
                'state': state,
                'action': action,
                'action_mask': action_mask,
                'log_prob': log_probs[0, action].item()
            })

            # Execute action
            _, reward, done, info = env.step(action)
            total_reward += reward

            if done:
                break

        return trajectory, total_reward, info

    def train_step(self, env: MolecularEnvironment, num_trajectories: int = 16):
        """Sample a batch of trajectories and update policy (and critic).

        For every visited state the policy loss is
        ``-log pi(a|s) * (R - baseline)``, where ``R`` is the trajectory's
        total reward and the baseline is the critic's (detached) prediction
        or 0. The critic is regressed onto ``R`` with an MSE loss. The
        policy is optimized on
        ``policy_loss - lambda * exp(-mean_distillation_loss)``.

        Args:
            env: Environment used for sampling.
            num_trajectories: Number of episodes per update; must be > 0.

        Returns:
            Dict with ``avg_reward``, ``policy_loss``, ``critic_loss``,
            ``distillation_loss`` and ``total_loss`` as Python/NumPy floats.

        Raises:
            ValueError: If ``num_trajectories`` is not positive.
        """
        if num_trajectories <= 0:
            raise ValueError("num_trajectories must be a positive integer")

        trajectories = []
        rewards = []

        # Sample trajectories
        for _ in range(num_trajectories):
            trajectory, reward, _ = self.sample_trajectory(env)
            trajectories.append(trajectory)
            rewards.append(reward)

        # Compute losses
        policy_loss = 0.0
        critic_loss = 0.0
        total_distillation_loss = 0.0

        for trajectory, reward in zip(trajectories, rewards):
            for step_data in trajectory:
                state = step_data['state']
                action = step_data['action']
                mask_batch = step_data['action_mask'].unsqueeze(0)

                # Policy gradient
                log_prob = self.policy(state, mask_batch)[0, action]

                # Baseline from critic if available
                baseline = 0.0
                if self.use_critic:
                    # One critic pass serves both the baseline (as a plain
                    # float, so no gradient reaches the critic through the
                    # policy loss) and the critic's own regression loss.
                    critic_pred = self.critic(state)
                    baseline = critic_pred.item()

                    # Target shaped like the prediction to avoid relying on
                    # implicit broadcasting inside mse_loss.
                    critic_target = torch.full_like(critic_pred, reward)
                    critic_loss = critic_loss + F.mse_loss(critic_pred, critic_target)

                # Policy loss with baseline
                advantage = reward - baseline
                policy_loss = policy_loss - log_prob * advantage

                # Distillation loss (already scaled by lambda_param)
                teacher_probs = self.policy.get_teacher_guidance(
                    state, mask_batch, self.step_count
                )
                distillation_loss = self.policy.distill_from_teacher(
                    state, mask_batch, teacher_probs, self.lambda_param
                )
                total_distillation_loss = total_distillation_loss + distillation_loss

        # Average losses over all visited states (every trajectory has at
        # least one step, so num_steps > 0 here).
        num_steps = sum(len(traj) for traj in trajectories)
        policy_loss = policy_loss / num_steps
        if self.use_critic:
            critic_loss = critic_loss / num_steps
        total_distillation_loss = total_distillation_loss / num_steps

        # Combined loss with teacher bonus
        teacher_bonus = self.lambda_param * torch.exp(-total_distillation_loss)
        total_loss = policy_loss - teacher_bonus

        # Update policy
        self.policy_optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
        self.policy_optimizer.step()

        # Update critic
        if self.use_critic and critic_loss > 0:
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 1.0)
            self.critic_optimizer.step()

        self.step_count += 1

        return {
            'avg_reward': np.mean(rewards),
            'policy_loss': policy_loss.item(),
            'critic_loss': critic_loss.item() if self.use_critic else 0.0,
            'distillation_loss': total_distillation_loss.item(),
            'total_loss': total_loss.item()
        }

    def evaluate_performance(self, env: MolecularEnvironment, num_episodes: int = 100):
        """Sample episodes with the current policy and summarize them.

        Args:
            env: Environment used for sampling.
            num_episodes: Number of episodes to sample.

        Returns:
            Dict with ``avg_reward``, ``reward_std``, ``validity_rate``
            (valid episodes / episodes), ``unique_molecules`` (distinct
            SMILES among valid molecules), ``high_reward_modes`` (valid
            episodes with reward > 0.75) and ``diversity`` (unique / valid).
        """
        rewards = []
        valid_molecules = 0
        unique_smiles = set()
        high_reward_modes = 0

        for _ in range(num_episodes):
            _, reward, info = self.sample_trajectory(env)
            rewards.append(reward)

            if not info.get('valid', False):
                continue

            valid_molecules += 1
            unique_smiles.add(info.get('smiles', ''))

            # NOTE(review): this counts high-reward *samples*, duplicates
            # included; also confirm the 0.75 threshold is reachable on the
            # environment's reward scale.
            if reward > 0.75:  # High-quality threshold
                high_reward_modes += 1

        return {
            'avg_reward': np.mean(rewards),
            'reward_std': np.std(rewards),
            # Guarded so that num_episodes == 0 does not divide by zero.
            'validity_rate': valid_molecules / max(1, num_episodes),
            'unique_molecules': len(unique_smiles),
            'high_reward_modes': high_reward_modes,
            'diversity': len(unique_smiles) / max(1, valid_molecules)
        }

def run_molecular_experiment(config: Dict):
    """Train and evaluate one agent, logging to Weights & Biases.

    Args:
        config: Experiment settings. Required keys: ``max_atoms``,
            ``lambda_param``, ``use_critic``, ``num_steps``, ``batch_size``
            and ``eval_frequency``. An optional ``method_name`` overrides
            the default run name.

    Returns:
        The metrics dict of the final 500-episode evaluation.
    """
    # Initialize wandb
    method_name = (config.get('method_name')
                   or f"DATE-GFN-MPNN (max_atoms={config['max_atoms']})")
    wandb.init(
        project="DATE_GFN_Molecular_Scalability",
        name=method_name,
        config=config
    )

    try:
        # Setup environment and agent
        env = MolecularEnvironment(max_atoms=config['max_atoms'])
        agent = MolecularDATEGFN(
            vocab_size=env.vocab_size,
            max_atoms=config['max_atoms'],
            lambda_param=config['lambda_param'],
            use_critic=config['use_critic']
        )

        print(f"🚀 Starting {method_name}")
        print(f"   Configuration: {config}")

        # Training loop
        for step in range(config['num_steps']):
            # Training step
            metrics = agent.train_step(env, num_trajectories=config['batch_size'])

            # Periodic evaluation
            if step % config['eval_frequency'] != 0:
                continue

            performance = agent.evaluate_performance(env)

            # Log metrics
            wandb.log({
                'step': step,
                'train_avg_reward': metrics['avg_reward'],
                'policy_loss': metrics['policy_loss'],
                'critic_loss': metrics['critic_loss'],
                'distillation_loss': metrics['distillation_loss'],
                'total_loss': metrics['total_loss'],
                'eval_avg_reward': performance['avg_reward'],
                'eval_reward_std': performance['reward_std'],
                'validity_rate': performance['validity_rate'],
                'unique_molecules': performance['unique_molecules'],
                'high_reward_modes': performance['high_reward_modes'],
                'diversity': performance['diversity'],
                'max_atoms': config['max_atoms']
            })

            print(f"  Step {step:4d}: Reward={performance['avg_reward']:.3f}, "
                  f"Valid={performance['validity_rate']:.2f}, "
                  f"Unique={performance['unique_molecules']}, "
                  f"Modes={performance['high_reward_modes']}")

        # Final evaluation
        final_performance = agent.evaluate_performance(env, num_episodes=500)

        wandb.log({
            'final_avg_reward': final_performance['avg_reward'],
            'final_validity_rate': final_performance['validity_rate'],
            'final_unique_molecules': final_performance['unique_molecules'],
            'final_high_reward_modes': final_performance['high_reward_modes'],
            'final_diversity': final_performance['diversity']
        })

        print(f"✅ {method_name} completed")
        print(f"   Final: Reward={final_performance['avg_reward']:.3f}, "
              f"Modes={final_performance['high_reward_modes']}, "
              f"Diversity={final_performance['diversity']:.3f}")
    except BaseException:
        # Close the run (marked as failed) so that a crash or Ctrl-C does
        # not leave it open when several experiments run back to back.
        wandb.finish(exit_code=1)
        raise

    wandb.finish()

    return final_performance

def main():
    """Main experiment launcher for RQ3.

    Parses the command line and dispatches to one of three modes:
    ``single`` (one run), ``scalability`` (sweep over ``max_atoms``) or
    ``comparison`` (three method configurations at ``max_atoms=20``).
    """
    parser = argparse.ArgumentParser(description='RQ3: Molecular Scalability Experiments')
    parser.add_argument('--mode', choices=['single', 'scalability', 'comparison'], default='scalability')
    parser.add_argument('--max_atoms', type=int, default=15)
    parser.add_argument('--use_critic', action='store_true', default=True)
    # `--use_critic` defaults to True and can only set True, so on its own
    # it could never disable the critic; this flag provides the off switch.
    parser.add_argument('--no_critic', dest='use_critic', action='store_false',
                        help='disable the critic baseline (single mode)')
    args = parser.parse_args()

    base_config = {
        'num_steps': 3000,
        'batch_size': 16,
        'eval_frequency': 100,
        'lambda_param': 0.1
    }

    if args.mode == 'single':
        # Single experiment
        config = {
            **base_config,
            'max_atoms': args.max_atoms,
            'use_critic': args.use_critic
        }
        run_molecular_experiment(config)
    elif args.mode == 'scalability':
        _run_scalability_analysis(base_config)
    elif args.mode == 'comparison':
        _run_method_comparison(base_config)


def _run_scalability_analysis(base_config: Dict) -> None:
    """Run one experiment per molecule size and print a scaling summary."""
    # Scalability analysis across molecular complexity
    max_atoms_values = [10, 15, 20, 25]
    results = []

    for max_atoms in max_atoms_values:
        config = {
            **base_config,
            'max_atoms': max_atoms,
            'use_critic': True
        }
        results.append({
            'max_atoms': max_atoms,
            'performance': run_molecular_experiment(config)
        })

    # Scalability summary
    print("\n" + "="*70)
    print("MOLECULAR SCALABILITY ANALYSIS")
    print("="*70)
    print(f"{'Max Atoms':<12} {'Avg Reward':<12} {'Modes':<8} {'Diversity':<12} {'Validity':<10}")
    print("-"*70)

    for result in results:
        perf = result['performance']
        print(f"{result['max_atoms']:<12} {perf['avg_reward']:<12.3f} "
              f"{perf['high_reward_modes']:<8d} {perf['diversity']:<12.3f} "
              f"{perf['validity_rate']:<10.2f}")

    # Check scaling behavior
    rewards = [r['performance']['avg_reward'] for r in results]
    modes = [r['performance']['high_reward_modes'] for r in results]

    print(f"\n📊 SCALABILITY METRICS:")
    if rewards[0]:
        print(f"  Reward degradation: {(rewards[0] - rewards[-1]) / rewards[0]:.1%}")
    else:
        # A zero baseline would make the relative change undefined.
        print("  Reward degradation: n/a (baseline reward is 0)")
    print(f"  Mode discovery scaling: {modes[-1] / max(1, modes[0]):.2f}x")

    if rewards[-1] >= 0.5 and modes[-1] >= modes[0]:
        print("✅ SUCCESS: DATE-GFN scales well to complex molecules!")
    else:
        print("⚠️  Scalability challenges detected")


def _run_method_comparison(base_config: Dict) -> None:
    """Run the three method configurations and print a comparison table."""
    # Compare against baselines
    methods = [
        {'name': 'GFN-FM', 'use_critic': False, 'lambda_param': 0.0},
        {'name': 'EGFN', 'use_critic': True, 'lambda_param': 0.0},
        {'name': 'DATE-GFN', 'use_critic': True, 'lambda_param': 0.1}
    ]

    results = []
    max_atoms = 20  # Complex molecules

    for method in methods:
        config = {
            **base_config,
            'max_atoms': max_atoms,
            'use_critic': method['use_critic'],
            'lambda_param': method['lambda_param'],
            # Label travels with the config so that each run is recorded
            # under the method it actually evaluates.
            'method_name': f"{method['name']}-MPNN (max_atoms={max_atoms})"
        }

        results.append({
            'method': method['name'],
            'performance': run_molecular_experiment(config)
        })

    # Comparison summary
    print("\n" + "="*80)
    print("MOLECULAR GENERATION METHODS COMPARISON")
    print("="*80)
    print(f"{'Method':<15} {'Avg Reward':<12} {'Modes':<8} {'Unique':<8} {'Diversity':<12}")
    print("-"*80)

    for result in results:
        perf = result['performance']
        print(f"{result['method']:<15} {perf['avg_reward']:<12.3f} "
              f"{perf['high_reward_modes']:<8d} {perf['unique_molecules']:<8d} "
              f"{perf['diversity']:<12.3f}")

    # Find best method (ties resolve to the earliest entry)
    best_result = max(results, key=lambda x: x['performance']['high_reward_modes'])
    date_gfn_result = next(r for r in results if r['method'] == 'DATE-GFN')

    print(f"\n🏆 Best Method: {best_result['method']}")

    if date_gfn_result['method'] == best_result['method']:
        print("✅ SUCCESS: DATE-GFN achieves best molecular generation performance!")
    else:
        # Ratio against the first listed method, which serves as baseline.
        modes_improvement = (date_gfn_result['performance']['high_reward_modes'] /
                           max(1, results[0]['performance']['high_reward_modes']))
        print(f"📊 DATE-GFN improvement: {modes_improvement:.2f}x modes over baseline")

# Launch the experiments only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
