import torch
import numpy as np
from torch.utils.data import DataLoader
import random
import pandas as pd
from chemprop import models, nn
from chemprop.nn import metrics
from chemprop.data.datapoints import MoleculeDatapoint
from chemprop.data.datasets import MoleculeDataset
from chemprop.data.collate import collate_batch
from chemprop.featurizers import SimpleMoleculeMolGraphFeaturizer
from chemprop.data import MulticomponentDataset, build_dataloader

def set_seed(seed):
    """Seed every random number generator this module relies on.

    Python's ``random`` module, NumPy's global generator and PyTorch are
    seeded with the same value. When CUDA is available, the current CUDA
    device and all CUDA devices are seeded as well.

    Parameters:
    - seed: value handed to each generator's seeding function
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Nothing more to do on machines without CUDA.
    if not torch.cuda.is_available():
        return

    # Seed the current device first, then every device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def sample_for_testing(df, n_samples=1000, seed=42):
    """
    Randomly sample rows from a dataframe for testing/debugging.

    Parameters:
    - df: pandas DataFrame to sample from
    - n_samples: maximum number of rows to return (default: 1000)
    - seed: value passed to pandas as ``random_state`` so the same rows are
      selected on every call

    Returns:
    - ``df`` itself (not a copy) when it has at most ``n_samples`` rows,
      otherwise a new dataframe holding ``n_samples`` sampled rows
    """
    if len(df) <= n_samples:
        return df

    # ``random_state`` alone makes the draw reproducible. The global
    # ``random`` module is intentionally left untouched so that calling this
    # helper does not reset the caller's random stream.
    return df.sample(n=n_samples, random_state=seed)

def define_binary_model():
    """
    Build the binary classification model.

    Returns:
    - ``models.MPNN`` assembled from bond message passing, mean aggregation
      and a binary classification FFN, with batch norm enabled and
      accuracy / AUROC as metrics
    """
    # Components are created in the same order in which they are handed to
    # the MPNN constructor below.
    message_passing = nn.message_passing.BondMessagePassing()
    aggregation = nn.agg.MeanAggregation()
    predictor = nn.predictors.BinaryClassificationFFN()
    tracked_metrics = [nn.metrics.BinaryAccuracy(), nn.metrics.BinaryAUROC()]

    return models.MPNN(
        message_passing=message_passing,
        agg=aggregation,
        predictor=predictor,
        batch_norm=True,
        metrics=tracked_metrics,
    )

def define_multicomponent_model():
    """
    Build the two-component model used for feature extraction.

    Returns:
    - ``models.MulticomponentMPNN`` made of one bond message passing block
      per component, mean aggregation and a binary classification FFN
      (``hidden_dim=4``, ``n_layers=2``), with batch norm disabled and
      accuracy / AUROC as metrics
    """
    n_components = 2

    # Each component gets its own, separately constructed message passing block.
    component_blocks = [nn.BondMessagePassing() for _ in range(n_components)]
    message_passing = nn.MulticomponentMessagePassing(
        blocks=component_blocks,
        n_components=n_components,
    )

    aggregation = nn.MeanAggregation()

    # The predictor's input width is read from the message passing module.
    predictor = nn.BinaryClassificationFFN(
        input_dim=message_passing.output_dim,
        hidden_dim=4,
        n_layers=2,
    )
    tracked_metrics = [metrics.BinaryAccuracy(), metrics.BinaryAUROC()]

    return models.MulticomponentMPNN(
        message_passing,
        aggregation,
        predictor,
        batch_norm=False,
        metrics=tracked_metrics,
    )

def load_model(model_type, checkpoint_path):
    """
    Load a model from a checkpoint and switch it to evaluation mode.

    Parameters:
    - model_type: Type of model ("binary" or "multicomponent")
    - checkpoint_path: Path to model checkpoint

    Returns:
    - Loaded model, after ``eval()`` has been called on it

    Raises:
    - ValueError: if ``model_type`` is neither "binary" nor "multicomponent"
    """
    # Only the class is needed: ``load_from_checkpoint`` is invoked on the
    # class itself, so no throwaway, randomly initialised model is built
    # just to look its class up.
    if model_type == "binary":
        model_cls = models.MPNN
    elif model_type == "multicomponent":
        model_cls = models.MulticomponentMPNN
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    model = model_cls.load_from_checkpoint(checkpoint_path)
    model.eval()

    return model

def extract_features(model, model_type, smiles_pairs, batch_size=32):
    """
    Extract hidden-layer features and scores from a model for a set of SMILES pairs.

    A temporary forward hook on ``model.predictor.ffn[1]`` records that
    layer's output for every batch. The hook is removed again before the
    function returns or raises, so repeated calls never stack hooks on the
    same model.

    Parameters:
    - model: PyTorch model exposing ``predictor.ffn``; it is moved to the CPU
      as a side effect
    - model_type: "binary" uses only the first SMILES of each pair; any other
      value is handled as the two-component ("multicomponent") case
    - smiles_pairs: Array of SMILES pairs (each item indexable with 0 and 1)
    - batch_size: Batch size for processing

    Returns:
    - Array of features: the hooked layer's outputs for all batches, stacked
      row-wise
    - Array of scores: ``1 - model output``, flattened to one dimension

    Raises:
    - ValueError: if the hook captured nothing (for example when
      ``smiles_pairs`` is empty)
    """
    captured = []

    def hook_fn(_, __, output):
        captured.append(output.detach().cpu().numpy())

    # Both model types are hooked at the same place.
    # NOTE(review): the original comments call ffn[1] the penultimate layer;
    # that depends on the predictor's layer layout - confirm for your model.
    hook_handle = model.predictor.ffn[1].register_forward_hook(hook_fn)

    all_features = []
    all_scores = []

    try:
        if model_type == "binary":
            # For binary model, use the first SMILES in each pair
            datapoints = [MoleculeDatapoint.from_smi(pair[0]) for pair in smiles_pairs]
            dataset = MoleculeDataset(datapoints)
            loader = DataLoader(
                dataset=dataset,
                batch_size=batch_size,
                shuffle=False,
                collate_fn=collate_batch,
            )

            def run_batch(batch):
                return model(batch.bmg, batch.V_d, batch.X_d)

        else:  # multicomponent model
            # One dataset per component, built from the matching pair member
            featurizer = SimpleMoleculeMolGraphFeaturizer()
            component_datapoints = [[], []]
            for pair in smiles_pairs:
                component_datapoints[0].append(MoleculeDatapoint.from_smi(pair[0]))
                component_datapoints[1].append(MoleculeDatapoint.from_smi(pair[1]))

            datasets = [
                MoleculeDataset(points, featurizer) for points in component_datapoints
            ]
            multi_dataset = MulticomponentDataset(datasets)
            multi_dataset.target_transform = None
            loader = build_dataloader(multi_dataset, batch_size=batch_size, shuffle=False)

            def run_batch(batch):
                return model(batch.bmgs, batch.V_ds, batch.X_d)

        # Make predictions
        model.cpu()
        with torch.no_grad():
            for batch in loader:
                # Start every batch with an empty capture list so the entry
                # read below belongs to this forward pass.
                captured.clear()
                output = run_batch(batch)
                all_scores.append(output.cpu().numpy())
                if captured:
                    all_features.append(captured[-1])
    finally:
        # Always detach the hook, even when inference fails part-way.
        hook_handle.remove()

    # Concatenate results
    if not all_features:
        raise ValueError("No features were extracted. Check the hook placement.")

    stacked_features = np.vstack(all_features)
    # Scores are returned as the complement of the raw model output.
    scores = 1 - np.concatenate(all_scores, axis=0)
    return stacked_features, scores.flatten()
