#!/usr/bin/env python3
"""
Observed Adjacency Conditioned Molecular Sampling Script
========================================================

This script implements observed adjacency conditioning for molecular generation
using the DiGress framework. It allows conditioning the generation process on
specific adjacency matrix entries that are observed/known.

Following the approach of solver_guidance.py, this method applies masking
during the reverse diffusion process to fix certain edge entries to their ground-truth values.
It also provides fairness-guided sampling (FairAdjacencyDiffusion), in which the nodes of each
graph are randomly split into two sensitive groups that are passed to every reverse step.
"""

import os
import sys
import torch
import pickle
import numpy as np
from omegaconf import OmegaConf, open_dict

# Add parent directory to path
sys.path.append('..')
sys.path.append('../..')

# Import DiGress modules
import src.utils as utils
from src.diffusion_model_discrete import DiscreteDenoisingDiffusion
from src.datasets import qm9_dataset
from src.metrics.molecular_metrics import SamplingMolecularMetrics
from src.metrics.molecular_metrics_discrete import TrainMolecularMetricsDiscrete
from src.analysis.visualization import MolecularVisualization
from src.diffusion.extra_features import DummyExtraFeatures, ExtraFeatures
from src.diffusion.extra_features_molecular import ExtraMolecularFeatures
from src.diffusion import diffusion_utils


class ObservedAdjacencyConditionedDiffusion(DiscreteDenoisingDiffusion):
    """
    DiscreteDenoisingDiffusion variant that supports observed adjacency conditioning.

    After every reverse-diffusion step, the edge features at the observed
    (graph, i, j) positions -- and their mirrored (graph, j, i) positions -- are
    overwritten with the ground-truth edge class (hard clamping, similar to the
    approach in solver_guidance.py). The clamped tensors are also what the final
    collapse step sees, so the returned sample agrees with the observations.
    """

    def __init__(self, cfg, **kwargs):
        super().__init__(cfg, **kwargs)

    @torch.no_grad()
    def sample_batch_with_observed_adjacency(self, batch_id: int, batch_size: int, keep_chain: int,
                                            number_chain_steps: int, save_final: int, num_nodes=None,
                                            idx_observed=None, ground_truth_adj=None, guidance_lambda=0.0):
        """
        Sample molecules with observed adjacency entries conditioning.

        Args:
            batch_id: Batch identifier (used for logging).
            batch_size: Number of molecules to sample.
            keep_chain: Accepted for interface compatibility; chains are not
                recorded or returned by this method.
            number_chain_steps: Accepted for interface compatibility (see keep_chain).
            save_final: Accepted for interface compatibility; unused.
            num_nodes: None (node counts drawn from ``self.node_dist``), an int
                (same count for every graph) or a tensor of per-graph counts.
            idx_observed: Tuple of 3 tensors (batch_idx, i_idx, j_idx), e.g. from
                ``torch.where(mask == 1)``, giving the observed adjacency entries.
                batch_idx is local to this batch (in ``[0, batch_size)``).
            ground_truth_adj: Ground-truth adjacency, either integer edge classes of
                shape (bs, N, N) or one-hot edge features of shape (bs, N, N, de).
                Only the observed positions are read.
            guidance_lambda: Guidance strength forwarded to
                ``sample_p_zs_given_zt_guidance``.

        Returns:
            Tuple ``(X, E)`` taken from ``sampled_s.mask(node_mask, collapse=True)``
            after the last reverse step.
        """
        print(f"Sampling batch {batch_id} with observed adjacency conditioning...")

        if num_nodes is None:
            n_nodes = self.node_dist.sample_n(batch_size, self.device)
        elif isinstance(num_nodes, int):
            n_nodes = num_nodes * torch.ones(batch_size, device=self.device, dtype=torch.int)
        else:
            assert isinstance(num_nodes, torch.Tensor)
            n_nodes = num_nodes

        n_max = torch.max(n_nodes).item()

        # Node k of graph b exists iff k < n_nodes[b] (real nodes form a prefix).
        arange = torch.arange(n_max, device=self.device).unsqueeze(0).expand(batch_size, -1)
        node_mask = arange < n_nodes.unsqueeze(1)

        # Sample noise
        z_T = diffusion_utils.sample_discrete_feature_noise(limit_dist=self.limit_dist, node_mask=node_mask)
        X, E, y = z_T.X, z_T.E, z_T.y

        assert (E == torch.transpose(E, 1, 2)).all()

        conditioned = idx_observed is not None and ground_truth_adj is not None
        if conditioned:
            batch_idx, i_idx, j_idx = idx_observed
            print(f"Applying observed adjacency conditioning with {len(batch_idx)} observed entries")
            # min()/max() raise on empty tensors, so only report a range when one exists.
            if len(batch_idx) > 0:
                print(f"Observed indices range: batch[{batch_idx.min()}-{batch_idx.max()}], i[{i_idx.min()}-{i_idx.max()}], j[{j_idx.min()}-{j_idx.max()}]")

            # Ensure ground truth adjacency is on correct device
            ground_truth_adj = ground_truth_adj.to(self.device)

        # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
        for s_int in reversed(range(0, self.T)):
            s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
            t_array = s_array + 1
            s_norm = s_array / self.T
            t_norm = t_array / self.T

            # Sample z_s (the discrete/collapsed copy is not needed here)
            sampled_s, _ = self.sample_p_zs_given_zt_guidance(s_norm, t_norm, X, E, y, node_mask, guidance_lambda)
            X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

            # Clamp observed entries so the next step is conditioned on them.
            if conditioned:
                X, E = self._apply_observed_adjacency_mask(X, E, idx_observed, ground_truth_adj)

        # sampled_s still holds the unclamped output of the last step; write the
        # clamped tensors back so the observations survive the final collapse.
        sampled_s.X, sampled_s.E = X, E
        sampled_s = sampled_s.mask(node_mask, collapse=True)
        X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        return X, E

    def _apply_observed_adjacency_mask(self, X, E, idx_observed, ground_truth_adj):
        """
        Overwrite the observed edge entries of ``E`` with their ground-truth classes.

        Args:
            X: Current node features; returned unchanged.
            E: Current edge features (bs, n_max, n_max, de); the last axis holds
               the ``de`` edge-class entries.
            idx_observed: Tuple of 3 index tensors (batch_idx, i_idx, j_idx).
            ground_truth_adj: Integer edge classes (bs, N, N) or one-hot edge
               features (bs, N, N, de).

        Returns:
            Tuple ``(X, masked_E)`` where ``masked_E`` is a copy of ``E`` holding the
            one-hot ground truth at every observed (b, i, j) and at its mirror (b, j, i).
        """
        batch_idx, i_idx, j_idx = idx_observed
        gt_classes = ground_truth_adj[batch_idx, i_idx, j_idx]
        if gt_classes.dim() > 1:
            # One-hot ground truth (n_obs, de): recover the class index per entry.
            gt_classes = gt_classes.argmax(dim=-1)
        gt_one_hot = torch.nn.functional.one_hot(gt_classes.long(), num_classes=E.shape[-1]).to(E.dtype)

        masked_E = E.clone()
        masked_E[batch_idx, i_idx, j_idx] = gt_one_hot
        # Mirror every observation so the adjacency stays symmetric (undirected graphs).
        masked_E[batch_idx, j_idx, i_idx] = gt_one_hot

        return X, masked_E


class FairAdjacencyDiffusion(DiscreteDenoisingDiffusion):
    """
    Diffusion model that samples graphs with fairness guidance.

    The nodes of every sampled graph are randomly split into two sensitive
    groups (``Zs``), and each reverse-diffusion step is delegated to
    ``sample_p_zs_given_zt_fairness`` together with those groups and a guidance
    strength. The guidance itself is implemented by that method, which is not
    defined in this file.
    """
    def __init__(self, cfg, guidance_lambda: float = 1.0, **kwargs):
        super().__init__(cfg, **kwargs)
        # NOTE(review): stored for reference only; sample_batch_with_guidance uses
        # its own guidance_lambda argument instead of this attribute.
        self.guidance_lambda = guidance_lambda

    @torch.no_grad()
    def sample_batch_with_guidance(self, batch_id: int, batch_size: int, keep_chain: int,
                                   number_chain_steps: int, save_final: int, num_nodes=None,
                                   guidance_lambda: float = 1.0):
        """Sample graphs with fairness guidance applied at every reverse step.

        Args:
            batch_id: Batch identifier (used for logging).
            batch_size: Number of graphs to sample.
            keep_chain: Accepted for interface compatibility; chains are not
                recorded or returned.
            number_chain_steps: Accepted for interface compatibility (see keep_chain).
            save_final: Accepted for interface compatibility; unused.
            num_nodes: None (node counts drawn from ``self.node_dist``), an int
                (same count for every graph) or a tensor of per-graph counts.
            guidance_lambda: Guidance strength passed to
                ``sample_p_zs_given_zt_fairness``.

        Returns:
            Tuple ``(X, E, Zs)``: ``X`` and ``E`` come from
            ``sampled_s.mask(node_mask, collapse=True)`` after the last step;
            ``Zs`` is the float group one-hot of shape (batch_size, 2, n_max).
        """
        print(f"Guided sampling batch {batch_id} (lambda={guidance_lambda})...")

        if num_nodes is None:
            n_nodes = self.node_dist.sample_n(batch_size, self.device)
        elif isinstance(num_nodes, int):
            n_nodes = num_nodes * torch.ones(batch_size, device=self.device, dtype=torch.int)
        else:
            assert isinstance(num_nodes, torch.Tensor)
            n_nodes = num_nodes

        n_max = torch.max(n_nodes).item()
        arange = torch.arange(n_max, device=self.device).unsqueeze(0).expand(batch_size, -1)
        node_mask = arange < n_nodes.unsqueeze(1)

        # Sample prior noise
        z_T = diffusion_utils.sample_discrete_feature_noise(limit_dist=self.limit_dist, node_mask=node_mask)
        X, E, y = z_T.X, z_T.E, z_T.y

        ## Sensitive attributes
        # Split the node slots of each graph into two groups of (near-)equal size:
        # n_slots // 2 slots get group 0, the rest group 1, shuffled independently
        # per graph with NumPy's global RNG (not controlled by torch.manual_seed).
        # NOTE(review): slots beyond a graph's n_nodes (padding) are assigned too,
        # so the split is balanced over slots, not over real nodes -- confirm intent.
        n_slots = E.shape[1]
        n_group0 = n_slots // 2
        template_row = np.array([0] * n_group0 + [1] * (n_slots - n_group0))
        idxs_com = np.tile(template_row, (E.shape[0], 1))
        for i in range(E.shape[0]):
            np.random.shuffle(idxs_com[i])
        Zs = torch.nn.functional.one_hot(torch.tensor(idxs_com), num_classes=2).float().to(E.device)
        Zs = Zs.permute(0, 2, 1)  # (batch, 2, n_slots)

        for s_int in reversed(range(0, self.T)):
            s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
            t_array = s_array + 1
            s_norm = s_array / self.T
            t_norm = t_array / self.T

            # The discrete copy was only used for (never returned) chain frames.
            sampled_s, _ = self.sample_p_zs_given_zt_fairness(s_norm, t_norm, X, E, y, node_mask,
                                                              Zs, guidance_lambda=guidance_lambda)
            X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        sampled_s = sampled_s.mask(node_mask, collapse=True)
        X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        return X, E, Zs


def load_config(base_config_path):
    """Load and complete the configuration for QM9 (no-H) sampling.

    Reads ``<base_config_path>/experiment/qm9_no_h.yaml``. If that file cannot
    be read, it falls back to ``<base_config_path>/config.yaml``. Model, general
    and dataset defaults are then merged in on a best-effort basis: a failure
    only prints a warning. Finally, the settings this script relies on are
    forced or defaulted.

    Args:
        base_config_path: Directory holding the DiGress config tree
            (``experiment/``, ``model/``, ``general/``, ``dataset/``).

    Returns:
        The completed OmegaConf config.

    Raises:
        OSError: if neither the experiment config nor ``config.yaml`` can be read.
    """
    try:
        # Try to load the experiment config first
        config_path = os.path.join(base_config_path, "experiment", "qm9_no_h.yaml")
        cfg = OmegaConf.load(config_path)
        print(f"Loaded experiment config from {config_path}")
    except OSError:
        # Fall back to the main config only when the experiment file cannot be
        # read; interrupts and YAML errors are no longer silently swallowed.
        config_path = os.path.join(base_config_path, "config.yaml")
        cfg = OmegaConf.load(config_path)
        print(f"Loaded main config from {config_path}")

    # Load and merge default configurations (best effort: warn and continue)
    try:
        # Load model defaults
        model_config = OmegaConf.load(os.path.join(base_config_path, "model", "discrete.yaml"))
        if 'model' not in cfg:
            cfg.model = {}
        cfg.model = OmegaConf.merge(model_config, cfg.model)

        # Load general defaults
        general_config = OmegaConf.load(os.path.join(base_config_path, "general", "general_default.yaml"))
        if 'general' not in cfg:
            cfg.general = {}
        cfg.general = OmegaConf.merge(general_config, cfg.general)

        # Load dataset defaults
        dataset_config = OmegaConf.load(os.path.join(base_config_path, "dataset", "qm9.yaml"))
        if 'dataset' not in cfg:
            cfg.dataset = {}
        cfg.dataset = OmegaConf.merge(dataset_config, cfg.dataset)

        print("Successfully merged default configurations")
    except Exception as e:
        print(f"Warning: Could not load some default configs: {e}")
        print("Creating minimal config manually...")

    # Set up for QM9 dataset sampling with manual fallbacks
    with open_dict(cfg):
        # Dataset settings
        if 'dataset' not in cfg:
            cfg.dataset = {}
        cfg.dataset.name = 'qm9'
        cfg.dataset.remove_h = True
        cfg.dataset.pin_memory = True
        cfg.dataset.num_workers = 4
        if 'datadir' not in cfg.dataset:
            cfg.dataset.datadir = '../data/qm9'

        # Model settings
        if 'model' not in cfg:
            cfg.model = {}
        cfg.model.type = 'discrete'
        cfg.model.extra_features = 'all'  # Critical: set this explicitly
        cfg.model.diffusion_steps = getattr(cfg.model, 'diffusion_steps', 500)
        cfg.model.diffusion_noise_schedule = getattr(cfg.model, 'diffusion_noise_schedule', 'cosine')
        cfg.model.transition = getattr(cfg.model, 'transition', 'marginal')
        cfg.model.n_layers = getattr(cfg.model, 'n_layers', 9)

        # Add missing model parameters with defaults
        if 'hidden_mlp_dims' not in cfg.model:
            cfg.model.hidden_mlp_dims = {'X': 256, 'E': 128, 'y': 128}
        if 'hidden_dims' not in cfg.model:
            cfg.model.hidden_dims = {'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 128}
        if 'lambda_train' not in cfg.model:
            cfg.model.lambda_train = [5, 0]

        # General settings
        if 'general' not in cfg:
            cfg.general = {}
        cfg.general.name = 'guided_sampling'
        cfg.general.test_only = None
        cfg.general.resume = None
        cfg.general.gpus = 1 if torch.cuda.is_available() else 0
        cfg.general.samples_to_generate = 10
        cfg.general.samples_to_save = 10
        cfg.general.chains_to_save = 5
        cfg.general.number_chain_steps = 100
        cfg.general.sample_every_val = getattr(cfg.general, 'sample_every_val', 1)
        cfg.general.final_model_samples_to_generate = getattr(cfg.general, 'final_model_samples_to_generate', 100)
        cfg.general.final_model_samples_to_save = getattr(cfg.general, 'final_model_samples_to_save', 100)
        cfg.general.final_model_chains_to_save = getattr(cfg.general, 'final_model_chains_to_save', 10)

        # Train settings (needed for model initialization)
        if 'train' not in cfg:
            cfg.train = {}
        cfg.train.batch_size = getattr(cfg.train, 'batch_size', 32)
        cfg.train.num_workers = 4
        cfg.train.lr = getattr(cfg.train, 'lr', 0.0001)
        cfg.train.n_epochs = getattr(cfg.train, 'n_epochs', 1000)
        cfg.train.save_model = getattr(cfg.train, 'save_model', True)

    print("Configuration loaded successfully:")
    print(f"  - model.extra_features: {cfg.model.extra_features}")
    print(f"  - model.type: {cfg.model.type}")
    print(f"  - dataset.name: {cfg.dataset.name}")

    return cfg


def setup_dataset_and_models(cfg):
    """Build the QM9 datamodule, its metadata and the kwargs for a DiGress model.

    Args:
        cfg: Configuration (as produced by ``load_config``).

    Returns:
        Tuple ``(datamodule, dataset_infos, model_kwargs)``. ``model_kwargs``
        contains the dataset infos, train/sampling metrics, visualization tools
        and the extra/domain feature extractors.
    """
    # QM9 data and derived statistics
    datamodule = qm9_dataset.QM9DataModule(cfg)
    dataset_infos = qm9_dataset.QM9infos(datamodule=datamodule, cfg=cfg)
    datamodule.prepare_data()
    train_smiles = qm9_dataset.get_train_smiles(
        cfg=cfg,
        train_dataloader=datamodule.train_dataloader(),
        dataset_infos=dataset_infos,
        evaluate_dataset=False,
    )

    # Feature extractors: real ones unless extra_features is disabled (None / 'null').
    feature_mode = getattr(cfg.model, 'extra_features', 'all')
    features_disabled = feature_mode is None or feature_mode == 'null'
    if features_disabled:
        extra_features, domain_features = DummyExtraFeatures(), DummyExtraFeatures()
    else:
        extra_features = ExtraFeatures(feature_mode, dataset_info=dataset_infos)
        domain_features = ExtraMolecularFeatures(dataset_infos=dataset_infos)

    dataset_infos.compute_input_output_dims(
        datamodule=datamodule,
        extra_features=extra_features,
        domain_features=domain_features,
    )

    # Metrics and visualization helpers are built in the same order as before.
    model_kwargs = dict(
        dataset_infos=dataset_infos,
        train_metrics=TrainMolecularMetricsDiscrete(dataset_infos),
        sampling_metrics=SamplingMolecularMetrics(dataset_infos, train_smiles),
        visualization_tools=MolecularVisualization(cfg.dataset.remove_h, dataset_infos=dataset_infos),
        extra_features=extra_features,
        domain_features=domain_features,
    )

    return datamodule, dataset_infos, model_kwargs


def load_pretrained_model(cfg, model_kwargs, checkpoint_path):
    """Create an ObservedAdjacencyConditionedDiffusion model and load its weights.

    Args:
        cfg: Model configuration.
        model_kwargs: Extra constructor kwargs (see ``setup_dataset_and_models``).
        checkpoint_path: Path to a checkpoint whose ``'state_dict'`` entry holds
            the weights. If it is falsy or does not exist, a warning is printed and
            the randomly initialized model is kept.

    Returns:
        The model, switched to eval mode.
    """
    model = ObservedAdjacencyConditionedDiffusion(cfg=cfg, **model_kwargs)

    has_checkpoint = bool(checkpoint_path) and os.path.exists(checkpoint_path)
    if not has_checkpoint:
        print("Warning: No checkpoint provided, using randomly initialized model")
    else:
        print(f"Loading checkpoint from {checkpoint_path}")
        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        state = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        model.load_state_dict(state['state_dict'])

    model.eval()
    return model


def load_pretrained_guided_fairness_model(cfg, model_kwargs, checkpoint_path, guidance_lambda: float = 1.0):
    """Load a FairAdjacencyDiffusion model, optionally from a checkpoint.

    Args:
        cfg: Model configuration.
        model_kwargs: Extra constructor kwargs (see ``setup_dataset_and_models``).
        checkpoint_path: Checkpoint path. If the loaded object is a dict with a
            ``'state_dict'`` key, that entry is used; otherwise the loaded object
            itself is treated as the state dict. If the path is falsy or missing,
            a warning is printed and the random initialization is kept.
        guidance_lambda: Stored on the model as ``guidance_lambda``.

    Returns:
        The model, switched to eval mode.

    Raises:
        Whatever ``load_state_dict`` raises on a key/shape mismatch. That error
        now surfaces directly instead of being masked by a second load attempt.
    """
    model = FairAdjacencyDiffusion(cfg=cfg, guidance_lambda=guidance_lambda, **model_kwargs)

    # Load state dict from checkpoint if available
    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f"Loading checkpoint into guided model from {checkpoint_path}")
        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        # Fall back to the raw object only when there is no 'state_dict' wrapper,
        # rather than on any exception.
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
    else:
        print("Warning: No checkpoint provided for guided model, using randomly initialized model")

    model.eval()
    return model


def sample_with_observed_adjacency(model, idx_observed, ground_truth_adj,
                                  n_samples=10, batch_size=5, guidance_lambda: float = 0.0):
    """
    Sample molecules conditioned on observed adjacency entries, in batches.

    For the batch covering samples ``[start, end)``, the observations whose graph
    index lies in that range are kept and rebased to ``[0, end - start)``. The
    rows ``ground_truth_adj[start:end]`` are passed alongside them.

    Args:
        model: Object exposing ``sample_batch_with_observed_adjacency`` and
            ``cfg.general.number_chain_steps`` (e.g. ObservedAdjacencyConditionedDiffusion).
        idx_observed: Tuple of 3 tensors (batch_idx, i_idx, j_idx); batch_idx is the
            global sample index in ``[0, n_samples)``.
        ground_truth_adj: Ground-truth adjacency with one leading row per sample.
        n_samples: Total number of molecules to sample.
        batch_size: Batch size for sampling.
        guidance_lambda: Guidance strength for sampling.

    Returns:
        Tuple ``(X, E)`` of per-sample tensors stacked along a new first axis.
        Batches may sample different node counts. Smaller samples are therefore
        right-padded with -1 to the largest size (assumed to match the -1 fill
        used for masked-out nodes in collapsed DiGress outputs -- TODO confirm).
    """
    def _stack_padded(tensors, pad_value):
        # Right-pad every dimension to the largest size seen, then stack.
        target = [max(dims) for dims in zip(*(t.shape for t in tensors))]
        padded = []
        for t in tensors:
            pad = []
            for have, want in zip(reversed(t.shape), reversed(target)):
                pad.extend((0, want - have))
            padded.append(torch.nn.functional.pad(t, pad, value=pad_value))
        return torch.stack(padded)

    print(f"Sampling {n_samples} molecules with observed adjacency conditioning")
    print(f"Found {len(idx_observed[0])} observed adjacency entries")

    Xs = []
    Es = []
    n_batches = (n_samples + batch_size - 1) // batch_size

    for batch_idx in range(n_batches):
        start = batch_idx * batch_size
        current_batch_size = min(batch_size, n_samples - start)
        end = start + current_batch_size

        print(f"Sampling batch {batch_idx + 1}/{n_batches} (size: {current_batch_size})")

        # Keep observations for graphs in [start, end) (inclusive start) and
        # rebase their graph index to the batch-local range.
        in_batch = (idx_observed[0] >= start) & (idx_observed[0] < end)
        selected = [idx[in_batch] for idx in idx_observed]
        batch_idx_observed = (selected[0] - start, *selected[1:])
        batch_ground_truth = ground_truth_adj[start:end]

        # Sample with observed adjacency conditioning
        X, E = model.sample_batch_with_observed_adjacency(
            batch_id=batch_idx,
            batch_size=current_batch_size,
            keep_chain=0,  # Don't save chains for efficiency
            number_chain_steps=model.cfg.general.number_chain_steps,
            save_final=current_batch_size,
            num_nodes=None,  # Let model decide
            idx_observed=batch_idx_observed,
            ground_truth_adj=batch_ground_truth,
            guidance_lambda=guidance_lambda
        )

        Xs.extend(X)
        Es.extend(E)

    print(f"Successfully sampled {len(Xs)} molecules!")
    return _stack_padded(Xs, -1), _stack_padded(Es, -1)


def sample_with_guided_adjacency(model, idx_observed, ground_truth_adj,
                                 n_samples=10, batch_size=5, guidance_lambda: float = 1.0):
    """Sample molecules in batches through ``model.sample_batch_with_guidance``.

    For the batch covering samples ``[start, end)``, the observations whose graph
    index lies in that range are kept and rebased to ``[0, end - start)``. The
    rows ``ground_truth_adj[start:end]`` are passed alongside them.

    Args:
        model: Object exposing ``cfg.general.number_chain_steps`` and a
            ``sample_batch_with_guidance`` that accepts ``idx_observed`` /
            ``ground_truth_adj`` keywords and returns ``(X, E)``.
            NOTE(review): FairAdjacencyDiffusion.sample_batch_with_guidance in this
            file accepts neither keyword and returns three values, so it does not
            fit this function.
        idx_observed: Tuple of 3 tensors (batch_idx, i_idx, j_idx); batch_idx is the
            global sample index in ``[0, n_samples)``.
        ground_truth_adj: Ground-truth adjacency with one leading row per sample.
        n_samples: Total number of molecules to sample.
        batch_size: Sampling batch size.
        guidance_lambda: Guidance strength (passed to model.sample_batch_with_guidance).

    Returns:
        Tuple ``(X, E)`` of per-sample tensors stacked along a new first axis.
        Smaller samples are right-padded with -1 to the largest size across batches
        (assumed to match the masked-node fill of collapsed outputs -- TODO confirm).
    """
    def _stack_padded(tensors, pad_value):
        # Right-pad every dimension to the largest size seen, then stack.
        target = [max(dims) for dims in zip(*(t.shape for t in tensors))]
        padded = []
        for t in tensors:
            pad = []
            for have, want in zip(reversed(t.shape), reversed(target)):
                pad.extend((0, want - have))
            padded.append(torch.nn.functional.pad(t, pad, value=pad_value))
        return torch.stack(padded)

    print(f"Sampling {n_samples} molecules with L2 guidance (lambda={guidance_lambda})")

    print(f"Found {len(idx_observed[0])} observed adjacency entries")

    Xs = []
    Es = []
    n_batches = (n_samples + batch_size - 1) // batch_size

    for batch_idx in range(n_batches):
        start = batch_idx * batch_size
        current_batch_size = min(batch_size, n_samples - start)
        end = start + current_batch_size
        print(f"Sampling batch {batch_idx + 1}/{n_batches} (size: {current_batch_size})")

        # Keep observations for graphs in [start, end) and rebase to batch-local indices.
        in_batch = (idx_observed[0] >= start) & (idx_observed[0] < end)
        selected = [idx[in_batch] for idx in idx_observed]
        batch_idx_observed = (selected[0] - start, *selected[1:])
        batch_ground_truth = ground_truth_adj[start:end]

        X, E = model.sample_batch_with_guidance(
            batch_id=batch_idx,
            batch_size=current_batch_size,
            keep_chain=0,
            number_chain_steps=model.cfg.general.number_chain_steps,
            save_final=current_batch_size,
            num_nodes=None,
            idx_observed=batch_idx_observed,
            ground_truth_adj=batch_ground_truth,
            guidance_lambda=guidance_lambda
        )

        Xs.extend(X)
        Es.extend(E)

    print(f"Successfully sampled {len(Xs)} molecules with guidance!")
    return _stack_padded(Xs, -1), _stack_padded(Es, -1)

def sample_with_guided_fair_adjacency(model, n_samples=10, batch_size=5, guidance_lambda: float = 1.0):
    """Sample graphs in batches through ``model.sample_batch_with_guidance`` (fairness guidance).

    Args:
        model: Object exposing ``cfg.general.number_chain_steps`` and a
            ``sample_batch_with_guidance`` returning ``(X, E, Z)``
            (e.g. FairAdjacencyDiffusion).
        n_samples: Total number of graphs to sample.
        batch_size: Sampling batch size.
        guidance_lambda: Guidance strength (passed to model.sample_batch_with_guidance).

    Returns:
        Tuple ``(X, E, Z)`` of per-graph tensors stacked along a new first axis.
        Batches may sample different node counts, so smaller samples are
        right-padded to the largest size. ``X``/``E`` are padded with -1 (assumed
        to match the masked-node fill of collapsed outputs -- TODO confirm).
        ``Z`` is padded with 0, i.e. padded slots belong to no group.
    """
    def _stack_padded(tensors, pad_value):
        # Right-pad every dimension to the largest size seen, then stack.
        target = [max(dims) for dims in zip(*(t.shape for t in tensors))]
        padded = []
        for t in tensors:
            pad = []
            for have, want in zip(reversed(t.shape), reversed(target)):
                pad.extend((0, want - have))
            padded.append(torch.nn.functional.pad(t, pad, value=pad_value))
        return torch.stack(padded)

    print(f"Sampling {n_samples} fair graphs with L2 guidance (lambda={guidance_lambda})")

    Xs = []
    Es = []
    Zs = []
    n_batches = (n_samples + batch_size - 1) // batch_size

    for batch_idx in range(n_batches):
        current_batch_size = min(batch_size, n_samples - batch_idx * batch_size)
        print(f"Sampling batch {batch_idx + 1}/{n_batches} (size: {current_batch_size})")

        X, E, Z = model.sample_batch_with_guidance(
            batch_id=batch_idx,
            batch_size=current_batch_size,
            keep_chain=0,
            number_chain_steps=model.cfg.general.number_chain_steps,
            save_final=current_batch_size,
            num_nodes=None,
            guidance_lambda=guidance_lambda
        )

        Xs.extend(X)
        Es.extend(E)
        Zs.extend(Z)

    print(f"Successfully sampled {len(Xs)} molecules with guidance!")
    return _stack_padded(Xs, -1), _stack_padded(Es, -1), _stack_padded(Zs, 0)


def convert_to_smiles(samples, dataset_infos):
    """Convert sampled graphs to SMILES strings.

    Args:
        samples: Iterable of per-molecule items indexed as ``sample[0]`` (atom
            types) and ``sample[1]`` (edge types).
        dataset_infos: Dataset metadata providing ``atom_decoder`` and,
            optionally, ``charges_dic`` (an empty dict is used if absent).

    Returns:
        List with one entry per sample: the SMILES string, or None when the
        molecule could not be built or converted (errors are printed, not raised).
    """
    from src.analysis.rdkit_functions import build_molecule_with_partial_charges, mol2smiles

    # Loop-invariant: fall back to an empty charge table if the dataset lacks one.
    charges_dic = getattr(dataset_infos, 'charges_dic', {})

    smiles_list = []
    for i, sample in enumerate(samples):
        smiles = None
        try:
            atom_types = sample[0]
            edge_types = sample[1]

            # Convert to molecule
            mol = build_molecule_with_partial_charges(
                atom_types, edge_types,
                dataset_infos.atom_decoder,
                charges_dic
            )

            if mol is not None:
                # Empty / falsy SMILES count as failures.
                smiles = mol2smiles(mol) or None
        except Exception as e:
            print(f"Error converting molecule {i}: {e}")
            smiles = None
        smiles_list.append(smiles)

    valid_count = sum(s is not None for s in smiles_list)
    total = len(smiles_list)
    # Guard against ZeroDivisionError when no samples were given.
    pct = 100 * valid_count / total if total else 0.0
    print(f"Valid molecules: {valid_count}/{total} ({pct:.1f}%)")
    return smiles_list


def sample_digress(idx_observed, ground_truth_adj, n_samples, config_path, ckpt_path, device='cpu', guidance_lambda: float = 0.0):
    """
    Load a pretrained DiGress QM9 model and sample molecules with observed adjacency conditioning.

    Args:
        idx_observed: Tuple of 3 tensors (batch_idx, i_idx, j_idx) from torch.where(mask == 1);
            batch_idx is the global sample index in ``[0, n_samples)``.
        ground_truth_adj: Ground-truth adjacency with one leading row per sample,
            read at the observed positions.
        n_samples: Total number of samples to generate.
        config_path: Config directory passed to ``load_config``.
        ckpt_path: Checkpoint path passed to ``load_pretrained_model``.
        device: Device the model is moved to before sampling.
        guidance_lambda: Guidance strength forwarded to the sampler.

    Returns:
        Tuple ``(X, E)`` produced by ``sample_with_observed_adjacency``.
    """
    # Load configuration and build the dataset-dependent model arguments
    cfg = load_config(config_path)
    _datamodule, _dataset_infos, model_kwargs = setup_dataset_and_models(cfg)

    # Load pretrained model and move it to the requested device
    model = load_pretrained_model(cfg, model_kwargs, ckpt_path)
    model = model.to(device)

    # Sample with observed adjacency conditioning
    X, E = sample_with_observed_adjacency(
        model=model,
        idx_observed=idx_observed,
        ground_truth_adj=ground_truth_adj,
        n_samples=n_samples,
        batch_size=1000,  # Use a reasonable batch size
        guidance_lambda=guidance_lambda
    )

    return X, E


def sample_guided_digress(idx_observed, ground_truth_adj, n_samples, config_path, ckpt_path,
                          device='cpu', guidance_lambda: float = 1.0):
    """Run guided DiGress sampling conditioned on observed adjacency entries.

    Builds the config and dataset-dependent model arguments, then loads the
    guided model through ``load_pretrained_guided_model`` with
    ``guidance_lambda`` and moves it to ``device``. Sampling goes through
    ``sample_with_guided_adjacency`` with a batch size of 1000.

    NOTE(review): ``load_pretrained_guided_model`` is not defined in this file
    (only ``load_pretrained_guided_fairness_model`` is), so calling this raises
    NameError unless that name is provided elsewhere.

    Returns:
        Tuple ``(X, E)`` from ``sample_with_guided_adjacency``.
    """
    config = load_config(config_path)
    _, _, model_kwargs = setup_dataset_and_models(config)

    guided = load_pretrained_guided_model(config, model_kwargs, ckpt_path,
                                          guidance_lambda=guidance_lambda)
    guided = guided.to(device)

    samples_X, samples_E = sample_with_guided_adjacency(
        model=guided,
        idx_observed=idx_observed,
        ground_truth_adj=ground_truth_adj,
        n_samples=n_samples,
        batch_size=1000,
        guidance_lambda=guidance_lambda,
    )
    return samples_X, samples_E

def sample_fair_digress(n_samples, config_path, ckpt_path,
                        device='cpu', guidance_lambda: float = 1.0):
    """Sample graphs with a FairAdjacencyDiffusion model and fairness guidance.

    Mirrors ``sample_digress``, but loads the model through
    ``load_pretrained_guided_fairness_model`` and samples through
    ``sample_with_guided_fair_adjacency`` (batch size 1000). No observed
    adjacency entries are involved.

    Args:
        n_samples: Total number of graphs to sample.
        config_path: Config directory passed to ``load_config``.
        ckpt_path: Checkpoint path for the fairness model.
        device: Device the model is moved to before sampling.
        guidance_lambda: Guidance strength for model construction and sampling.

    Returns:
        Tuple ``(X, E, Z)`` from ``sample_with_guided_fair_adjacency``; ``Z`` holds
        the per-node sensitive-group assignments.
    """
    cfg = load_config(config_path)
    _datamodule, _dataset_infos, model_kwargs = setup_dataset_and_models(cfg)

    # Load the fairness-guided model
    model = load_pretrained_guided_fairness_model(cfg, model_kwargs, ckpt_path, guidance_lambda=guidance_lambda)
    model = model.to(device)

    X, E, Z = sample_with_guided_fair_adjacency(
        model=model,
        n_samples=n_samples,
        batch_size=1000,
        guidance_lambda=guidance_lambda
    )

    return X, E, Z


def main():
    """Run observed adjacency conditioned sampling on random observations.

    The steps are:

    1. Load the config, dataset and the ``../qm9.ckpt`` checkpoint.
    2. For 3, 5 and 8 observed edges, sample 10 molecules and convert them to SMILES.
    3. Pickle all results to ``observed_adjacency_sampling_results.pkl``.
    4. Print a validity summary.

    NOTE(review): ``create_random_observed_adjacency`` is not defined in this
    file. It must be provided elsewhere and is expected to return
    ``(mask, ground_truth)``, with ``mask`` None to skip a configuration.
    """

    print("DiGress Observed Adjacency Conditioned Molecular Sampling")
    print("=" * 70)

    # Load configuration (load_config requires the config directory)
    print("Loading configuration...")
    # NOTE(review): assumed location of the DiGress config tree, relative to the
    # working directory like "../qm9.ckpt" below -- confirm.
    config_dir = "../configs"
    cfg = load_config(config_dir)

    # Setup dataset and models
    print("Setting up dataset and models...")
    datamodule, dataset_infos, model_kwargs = setup_dataset_and_models(cfg)

    # Load pretrained model
    checkpoint_path = "../qm9.ckpt"
    print("Loading pretrained model...")
    model = load_pretrained_model(cfg, model_kwargs, checkpoint_path)

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    print(f"Using device: {device}")

    # Test observed adjacency conditioning with random data
    all_results = {}

    print(f"\n{'='*60}")
    print("Testing OBSERVED ADJACENCY conditioning")
    print(f"{'='*60}")

    batch_size = 5
    max_nodes = 15

    # Test with different numbers of observed edges
    edge_counts = [3, 5, 8]

    for num_edges in edge_counts:
        print(f"\n{'='*40}")
        print(f"Testing with {num_edges} observed edges")
        print(f"{'='*40}")

        # Create random observed adjacency entries
        random_mask, random_ground_truth = create_random_observed_adjacency(
            batch_size=batch_size, max_nodes=max_nodes, num_edges=num_edges, dataset_infos=dataset_infos
        )

        if random_mask is not None:
            # Move to device
            random_mask = random_mask.to(device)
            random_ground_truth = random_ground_truth.to(device)

            # Sample with random observed adjacency; returns stacked (X, E) tensors.
            X, E = sample_with_observed_adjacency(
                model=model,
                idx_observed=(random_mask.nonzero(as_tuple=True)),
                ground_truth_adj=random_ground_truth,
                n_samples=10,
                batch_size=batch_size
            )

            # convert_to_smiles expects one (atom_types, edge_types) pair per molecule.
            # Real nodes form a prefix; padded/masked slots are assumed to be -1
            # (TODO confirm), so trim each molecule to its real node count.
            random_samples = []
            for atoms, edges in zip(X, E):
                n_real = int((atoms != -1).sum())
                random_samples.append((atoms[:n_real], edges[:n_real, :n_real]))

            # Convert to SMILES
            random_smiles = convert_to_smiles(random_samples, dataset_infos)

            all_results[f'observed_{num_edges}_edges'] = {
                'samples': random_samples,
                'smiles': random_smiles,
                'valid_smiles': [s for s in random_smiles if s is not None]
            }

            valid_count = len(all_results[f'observed_{num_edges}_edges']['valid_smiles'])
            print(f"\nObserved adjacency ({num_edges} edges): {valid_count}/{len(random_samples)} valid molecules")
            for i, smiles in enumerate(all_results[f'observed_{num_edges}_edges']['valid_smiles'][:3]):
                print(f"  {i+1}. {smiles}")

    # Save results
    output_file = "observed_adjacency_sampling_results.pkl"
    with open(output_file, 'wb') as f:
        pickle.dump(all_results, f)
    print(f"\nResults saved to {output_file}")

    # Summary comparison
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")

    for key in all_results.keys():
        valid_count = len(all_results[key]['valid_smiles'])
        total_count = len(all_results[key]['samples'])
        num_edges = key.split('_')[1]
        print(f"Observed {num_edges} edges -> {valid_count:2d}/{total_count:2d} valid molecules ({100*valid_count/total_count:4.1f}%)")

    print("\nObserved adjacency sampling completed!")
    print("\nThis approach fixes specific edge entries during sampling,")
    print("allowing for flexible conditioning on partial graph information.")


if __name__ == "__main__":
    # Set environment variable to avoid OpenMP issues
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

    try:
        main()
    except Exception as e:
        print(f"Error during sampling: {e}")
        import traceback
        traceback.print_exc()
        # Report the failure to the shell / caller instead of exiting with status 0.
        sys.exit(1)
