import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import Dataset as HFDataset
import logging

logger = logging.getLogger(__name__)

class MolecularPropertyDataset(Dataset):
    """
    Dataset for molecular properties with property-only inputs.

    Each item is built on the fly. The row's property values are formatted
    into a text prompt that ends with ``[START_SMILES]`` (the SMILES string
    itself is not included). The prompt is run through ``model``, and the
    hidden state at the ``[START_SMILES]`` position is returned together
    with the row's property values.

    NOTE(review): every ``__getitem__`` call performs a full model forward
    pass. Inputs are moved to ``self.device``, so the model is presumably
    expected to live on that device as well. Confirm with callers.
    """

    # Marker that ends every prompt; its hidden state is the representation.
    START_SMILES_TOKEN = "[START_SMILES]"

    def __init__(self, data_path, tokenizer, model, property_cols, layer_idx=-1,
                 max_length=512, device=None):
        """
        Initialize the dataset.

        Args:
            data_path: Path to the parquet file (must contain a 'smiles'
                column and every column in ``property_cols``).
            tokenizer: Tokenizer for the language model
            model: Pre-trained language model
            property_cols: List of column names for property values
            layer_idx: Index into the model's ``hidden_states`` to extract
                representations from (default: last entry)
            max_length: Maximum sequence length passed to the tokenizer
            device: Device to move tokenized inputs to; defaults to "cuda"
                when available, otherwise "cpu"
        """
        self.tokenizer = tokenizer
        self.model = model
        self.layer_idx = layer_idx
        self.max_length = max_length
        self.property_cols = property_cols

        # Set device
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        # Load data
        logger.info(f"Loading data from {data_path}")
        self.data = pd.read_parquet(data_path)
        logger.info(f"Loaded {len(self.data)} samples")

        # Convert properties to tensor (missing values stay NaN here).
        self.properties = torch.tensor(
            self.data[property_cols].values,
            dtype=torch.float32
        )
        self.smiles_input = self.data['smiles'].values

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Get a sample from the dataset.

        Returns:
            representation: Hidden state at the [START_SMILES] token
            properties: float32 tensor of this row's ``property_cols`` values
        """
        # Parquet nulls arrive as NaN. Map them to None so the property is
        # omitted from the prompt; round(NaN) would otherwise raise ValueError.
        property_values = {
            col: self._missing_to_none(self.data[col].iloc[idx])
            for col in self.property_cols
        }

        smiles = self.smiles_input[idx]

        # Create formatted input with only properties + START_SMILES token
        input_text = self.prepare_property_input(
            wavelength=property_values.get('wavelength'),
            f_osc=property_values.get('f_osc'),
            qed=property_values.get('qed'),
            logp=property_values.get('logp'),
            smiles=smiles,
        )

        representation = self._extract_representation(input_text)

        return representation, self.properties[idx]

    @staticmethod
    def _missing_to_none(value):
        """Return None for missing scalars (None/NaN/NA); otherwise return *value* unchanged."""
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return None
        return value

    @staticmethod
    def _find_last_token_pos(input_ids, token_id):
        """
        Return the index of the last occurrence of *token_id* in the 1-D
        tensor *input_ids*, or -1 if *token_id* is None or not present.
        """
        if token_id is None:
            return -1
        matches = (input_ids == token_id).nonzero(as_tuple=True)[0]
        if matches.numel() == 0:
            return -1
        return int(matches[-1])

    def _start_smiles_token_id(self):
        """
        Return the tokenizer id of the [START_SMILES] marker, or None if the
        tokenizer cannot map it to a single known token.
        """
        convert = getattr(self.tokenizer, "convert_tokens_to_ids", None)
        if convert is None:
            return None
        token_id = convert(self.START_SMILES_TOKEN)
        unk_id = getattr(self.tokenizer, "unk_token_id", None)
        # An unknown marker maps to the UNK id, which could match unrelated tokens.
        if token_id is None or (unk_id is not None and token_id == unk_id):
            return None
        return token_id

    def _extract_representation(self, text):
        """
        Extract representation from the model at the [START_SMILES] token position.

        The marker is located by token id, because the tokenizer may append
        special tokens (e.g. EOS) after it. If the marker cannot be found
        (not a single vocabulary token, or removed by truncation), this falls
        back to the final position (-1).

        NOTE(review): the model is not switched to eval() here. If it is still
        in training mode and uses dropout, representations are not
        deterministic. Confirm that the caller calls model.eval().
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length
        ).to(self.device)

        if 'token_type_ids' in inputs:
            inputs.pop('token_type_ids')

        # Single prompt, so batch row 0 holds its token ids.
        start_smiles_pos = self._find_last_token_pos(
            inputs["input_ids"][0], self._start_smiles_token_id()
        )

        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)

        # Extract hidden states from specified layer
        hidden_states = outputs.hidden_states[self.layer_idx]

        return hidden_states[0, start_smiles_pos, :]

    @staticmethod
    def prepare_property_input(wavelength=None, f_osc=None, qed=None, logp=None, smiles=None):
        """
        Prepare property-only input with special tokens, followed by START_SMILES.

        Properties passed as None are omitted. QED and LogP are rounded to the
        nearest integer before formatting.

        Args:
            wavelength: Wavelength value (optional)
            f_osc: Oscillator strength (optional)
            qed: QED value (optional)
            logp: LogP value (optional)
            smiles: Accepted for call-site symmetry but intentionally unused;
                the prompt stops at [START_SMILES].

        Returns:
            text: Formatted text with property values and [START_SMILES] token
        """
        text = ""

        if wavelength is not None:
            text += f"[WAVELENGTH]{wavelength}[/WAVELENGTH] "

        if f_osc is not None:
            text += f"[F_OSC]{f_osc}[/F_OSC] "

        if qed is not None:
            # NOTE(review): the leading "</s>" appears only before [QED], and
            # round() collapses QED (normally in [0, 1]) to 0 or 1. Both look
            # accidental. They are kept because they may match the format the
            # model was trained on; confirm before changing.
            text += f"</s>[QED]{round(qed)}[/QED] "

        if logp is not None:
            text += f"[LOGP]{round(logp)}[/LOGP] "

        # Add START_SMILES without the actual SMILES
        text += MolecularPropertyDataset.START_SMILES_TOKEN

        return text
    

def create_dataloaders(dataset, batch_size, train_ratio=0.8, val_ratio=0.1, seed=42):
    """
    Split dataset into train, validation, and test sets and create DataLoaders.

    Split sizes are ``int(ratio * len(dataset))`` for train and validation.
    The test split receives every remaining sample.

    Args:
        dataset: The full dataset
        batch_size: Batch size for the DataLoaders
        train_ratio: Proportion of data for training (>= 0)
        val_ratio: Proportion of data for validation (>= 0)
        seed: Random seed for reproducibility

    Returns:
        train_loader, val_loader, test_loader: DataLoaders for each split
        (only the training loader shuffles).

    Raises:
        ValueError: If a ratio is negative or ``train_ratio + val_ratio > 1``.
    """
    # random_split only checks that the lengths sum to len(dataset). A negative
    # test_size would silently yield truncated or empty splits, so reject it here.
    if train_ratio < 0 or val_ratio < 0:
        raise ValueError(
            f"Ratios must be non-negative, got train_ratio={train_ratio}, "
            f"val_ratio={val_ratio}"
        )
    # Small tolerance so that e.g. 0.7 + 0.3 is not rejected by float error.
    if train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            f"train_ratio + val_ratio must not exceed 1, got "
            f"{train_ratio} + {val_ratio} = {train_ratio + val_ratio}"
        )

    # Seeds the global RNG. This fixes the split permutation and also the
    # training loader's shuffle order, so it is deliberately kept global.
    torch.manual_seed(seed)

    dataset_size = len(dataset)
    train_size = int(train_ratio * dataset_size)
    val_size = int(val_ratio * dataset_size)
    test_size = dataset_size - train_size - val_size

    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size]
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    logger.info(
        "Created DataLoaders with %d training, %d validation, and %d test samples",
        train_size, val_size, test_size,
    )

    return train_loader, val_loader, test_loader