import os
from itertools import zip_longest

import hydra
import pandas as pd
import torch
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

from multiguide.helpers import PROJECT_ROOT
from multiguide.dataset.helpers import smi_tokenizer, get_vocab_from_trained_model

class MoleculeDataset(Dataset):
    """In-memory dataset of tokenized reaction strings with one property each.

    Every reaction string is tokenized once at construction time, wrapped in
    start/end tokens and converted to a ``torch.long`` index tensor. Sequences
    are NOT padded here; the per-sample length is returned by ``__getitem__``
    so that batching/bucketing code can deal with variable lengths.

    Args:
        rxns: Iterable of reaction strings accepted by ``smi_tokenizer``.
        properties: Iterable with one property value per reaction.
        full_lengths: Iterable with one integer length per reaction. 2 is
            added to each value to account for the start and end tokens.
            Presumably the token count of the complete reaction, as opposed
            to the (possibly shorter) string in ``rxns`` -- confirm with the
            caller.
        config: Configuration object providing
            ``classifier_guidance.dataset.vocab_file`` and
            ``classifier_guidance.onmt_checkpoint_path``.

    Raises:
        ValueError: If ``rxns``, ``properties`` and ``full_lengths`` do not
            contain the same number of items.
        KeyError: If the vocabulary lacks one of the special tokens
            ``<blank>``, ``<s>``, ``</s>`` or ``<unk>``.
    """

    def __init__(self, rxns, properties, full_lengths, config):
        super().__init__()
        self.rxns = rxns
        self.properties = properties
        # Stored lengths include the start and end tokens added by _tokenize.
        self.full_lengths = [f + 2 for f in full_lengths]
        self.all_chars_path = os.path.join(
            PROJECT_ROOT,
            'data',
            config.classifier_guidance.dataset.vocab_file,
        )
        # Filled by _preload_data(): one (seq_id, prop, full_length) tuple per
        # sample, plus the tokenized source length used for bucketing.
        self.processed_data = []
        self.src_lengths = []
        # The vocabulary comes from the trained model checkpoint so that token
        # indices match the model; _load_chars() can read the vocabulary from
        # all_chars_path instead but is not used by the constructor.
        self.vocab = get_vocab_from_trained_model(
            config.classifier_guidance.onmt_checkpoint_path
        )
        # Token -> index lookup; the index is the position in the vocabulary.
        self.char_to_idx = {c: i for i, c in enumerate(self.vocab)}
        # Indices of the special tokens.
        self.pad_idx = self.char_to_idx['<blank>']
        self.start_idx = self.char_to_idx['<s>']
        self.end_idx = self.char_to_idx['</s>']
        self.unk_idx = self.char_to_idx['<unk>']
        self._preload_data()

    def _preload_data(self):
        """Tokenize every reaction and cache the results in memory.

        Rebuilds ``self.processed_data`` and ``self.src_lengths`` from
        scratch, so calling it again does not duplicate samples.

        Raises:
            ValueError: If the three input collections differ in length.
                A plain ``zip`` would silently drop the trailing samples.
        """
        missing = object()  # sentinel marking an exhausted input
        processed = []
        src_lengths = []
        rows = zip_longest(
            self.rxns, self.properties, self.full_lengths, fillvalue=missing
        )
        for rxn, prop, full_length in rows:
            if rxn is missing or prop is missing or full_length is missing:
                raise ValueError(
                    "rxns, properties and full_lengths must have the same length"
                )
            tokens_str, _ = smi_tokenizer(rxn)
            tokens = tokens_str.split(" ")
            # +2 for the start and end tokens; used for dynamic bucketing.
            src_lengths.append(len(tokens) + 2)
            processed.append((self._tokenize(tokens), prop, full_length))
        self.processed_data = processed
        self.src_lengths = src_lengths

    def _load_chars(self):
        """Read the vocabulary file, one stripped entry per line."""
        with open(self.all_chars_path, 'r', encoding='utf-8') as f:
            all_chars = [line.strip() for line in f]
        return all_chars

    def _tokenize(self, tokens):
        """Convert tokens to a 1-D ``torch.long`` tensor of vocabulary indices.

        The start index is always prepended and the end index always
        appended. Tokens missing from the vocabulary map to the unknown
        index. No padding is applied.
        """
        indices = [self.start_idx]
        indices.extend(self.char_to_idx.get(token, self.unk_idx) for token in tokens)
        indices.append(self.end_idx)
        return torch.tensor(indices, dtype=torch.long)

    def __len__(self) -> int:
        """Return the number of preloaded samples."""
        return len(self.processed_data)

    def get_length_percentage(self, idx):
        """Return the source length of sample ``idx`` as a percentage of its full length."""
        return (self.src_lengths[idx] / self.full_lengths[idx]) * 100

    def __getitem__(self, idx):
        """Return ``(seq_id, prop, full_length, src_length)`` for sample ``idx``.

        ``seq_id`` is the cached index tensor, ``prop`` and ``full_length``
        are wrapped in new one-element tensors, and ``src_length`` is the
        tokenized source length (including start/end tokens) for bucketing.
        """
        seq_id, prop, full_length = self.processed_data[idx]

        return seq_id, torch.tensor([prop]), torch.tensor([full_length]), self.src_lengths[idx]
