from datasets import load_dataset, Dataset, concatenate_datasets, load_from_disk
from utils.data_utils import smiles2selfies
import pandas as pd

class DataLoader:
    """Shared data loading module for molecular datasets.

    A dataset is addressed by a string in one of three forms:

    * a key of ``self.dataset_configs`` (e.g. ``"chebi"``),
    * ``"<key>,<split>[,<subset>]"`` to pick a split (and optional subset)
      of a registered dataset,
    * a local path starting with ``"./"``, optionally followed by
      ``",<split>"``.
    """

    def __init__(self, mol_type="SMILES"):
        """Build the dataset registry.

        Args:
            mol_type: Name of the column that receives the generated
                molecule when the ``"swap_gen_mol"`` processing step runs.
        """
        # Each entry provides:
        #   path              - identifier handed to ``datasets.load_dataset``
        #   split             - split used when the caller does not pass one
        #   columns           - original column name -> standardized name
        #   custom_processing - name of a row transform understood by
        #                       ``_apply_custom_processing`` (or None)
        # ``source`` is recorded for reference only; no method in this class
        # reads it.
        self.dataset_configs = {
            "chebi": {
                "source": "huggingface",
                "path": "liupf/ChEBI-20-MM",
                "split": "train",
                "columns": {"SMILES": "SMILES", "description": "desc", "SELFIES": "SELFIES"},
                "custom_processing": None,
            },
            "lm": {
                "source": "huggingface",
                "path": "language-plus-molecules/LPM-24_train",
                "split": "split_train",
                "columns": {"molecule": "SMILES", "caption": "desc"},
                "custom_processing": None,
            },
            "pubchem": {
                "source": "huggingface",
                "path": "sagawa/pubchem-10m-canonicalized",
                "split": "train",
                "columns": {"smiles": "SMILES"},
                "custom_processing": None,
            },
            "chemnlp": {
                "source": "huggingface",
                "path": "kjappelbaum/chemnlp_iupac_smiles",
                "split": "train",
                "columns": {"SMILES": "SMILES", "Preferred": "desc"},
                "custom_processing": None,
            },
            "zinc": {
                "source": "huggingface",
                "path": "yzimmermann/tokenized-ZINC-1B",
                "split": "train",
                "columns": {"smiles": "SMILES"},
                "custom_processing": None,
            },
            "qa": {
                "source": "huggingface",
                "path": "sentence-transformers/natural-questions",
                "split": "train",
                "columns": {"query": "question"},
                "custom_processing": None,
            },
            "bioqa": {
                "source": "huggingface",
                "path": "qiaojin/PubMedQA",
                "split": "train",
                "columns": {},
                "custom_processing": "combine_question_context_nested",
            },
        }
        self.mol_type = mol_type

    def load_dataset(self, dataset_name, split=None, limit=None, processing=None):
        """Load dataset by name/identifier with standardized columns.

        Args:
            dataset_name: Registry key, ``"key,split[,subset]"`` spec, or a
                local path starting with ``"./"`` (optionally
                ``"./path,split"``).
            split: Split to use for a plain registry key; falls back to the
                split stored in the registry entry. Ignored for the other
                two addressing forms.
            limit: If truthy, keep only the first ``limit`` rows.
            processing: Name of a row transform overriding the registry's
                ``custom_processing`` entry.

        Returns:
            The loaded, truncated, processed and column-standardized dataset.

        Raises:
            ValueError: If ``dataset_name`` matches none of the supported
                forms.
        """
        parts = dataset_name.split(",")
        source = parts[0]

        if source.startswith("./"):
            data = self._load_local(source)
            if len(parts) == 2:
                # "./path,split": pick one entry out of the loaded object.
                data = data[parts[1]]
        elif len(parts) > 1:
            # "key,split[,subset]" maps straight onto the helper's arguments.
            data = self._load_hf_dataset(*parts)
        elif dataset_name in self.dataset_configs:
            config = self.dataset_configs[dataset_name]
            data = self._load_hf_dataset(dataset_name, split or config["split"])
        else:
            raise ValueError(f"Unknown dataset: {dataset_name}")

        # Truncate before the row-wise processing so the map only touches
        # the rows that are kept; the resulting rows are the same either way.
        if limit:
            data = data.select(range(min(limit, len(data))))

        data = self._apply_custom_processing(data, dataset_name, processing)
        return self._standardize_columns(data, dataset_name)

    @staticmethod
    def _load_local(path):
        """Load a local dataset, choosing the reader from the file extension.

        ``.json`` and ``.csv`` files go through the matching ``Dataset``
        constructor; any other path is handed to ``load_from_disk``.
        """
        if path.endswith(".json"):
            return Dataset.from_json(path)
        if path.endswith(".csv"):
            return Dataset.from_csv(path)
        return load_from_disk(path)

    def _load_hf_dataset(self, dataset_name, split, subset=None):
        """Load one split of a registered dataset from HuggingFace.

        Args:
            dataset_name: Key of ``self.dataset_configs``.
            split: Name of the split to return.
            subset: Optional second positional argument forwarded to
                ``datasets.load_dataset``.

        Raises:
            ValueError: If ``dataset_name`` is not a registry key.
        """
        if dataset_name not in self.dataset_configs:
            raise ValueError(f"Unknown dataset config: {dataset_name}")

        path = self.dataset_configs[dataset_name]["path"]
        # ``load_dataset`` here is the module-level ``datasets`` function,
        # not the method of the same name on this class.
        dataset = load_dataset(path, subset) if subset else load_dataset(path)
        return dataset[split]

    def _standardize_columns(self, data, dataset_name):
        """Standardize column names to SMILES, SELFIES, desc.

        Local datasets (``"./..."``) and names that are not in the registry
        are returned untouched. For registered datasets the configured
        renames are applied and then only the known standard columns that
        are present are kept.
        """
        if dataset_name.startswith("./"):
            return data

        base_name = dataset_name.split(",")[0]
        if base_name not in self.dataset_configs:
            return data

        for old_name, new_name in self.dataset_configs[base_name]["columns"].items():
            if old_name != new_name and old_name in data.column_names:
                data = data.rename_column(old_name, new_name)

        standard_cols = ["SMILES", "SELFIES", "desc", "question", "answer", "equa", "prod", "mol", "class"]
        relevant_cols = [col for col in standard_cols if col in data.column_names]
        # If none of the standard columns exist, leave the dataset as it is
        # rather than selecting an empty column set.
        if relevant_cols:
            data = data.select_columns(relevant_cols)
        return data

    def _apply_custom_processing(self, data, dataset_name, processing=None):
        """Apply a named row transform to the dataset.

        Args:
            data: Dataset exposing ``map``.
            dataset_name: Dataset identifier; the part before the first
                comma selects the registry entry.
            processing: Transform name. When ``None`` the registry's
                ``custom_processing`` value for the dataset is used.

        Returns:
            ``data.map(transform)`` for a known transform name, otherwise
            ``data`` unchanged.
        """
        base_name = dataset_name.split(",")[0]

        if processing is not None:
            processing_func = processing
        elif base_name in self.dataset_configs:
            processing_func = self.dataset_configs[base_name].get("custom_processing")
        else:
            processing_func = None

        row_transforms = {
            "combine_question_context_nested": lambda x: {
                **x,
                "question": " ".join(x["context"]["contexts"]) + " " + x["question"],
                "answer": x["long_answer"],
            },
            "combine_question_context": lambda x: {
                **x,
                "question": x["context"] + " " + x["question"],
                "answer": x["answers"]["text"][0],
            },
            "swap_gen_prod": lambda x: {**x, "prod": x["gen_prod"]},
            "swap_gen_equa": lambda x: {**x, "equa": x["gen_equa"]},
            "swap_gen_mol": lambda x: {**x, self.mol_type: x["gen_mol"]},
            "swap_gen_desc": lambda x: {**x, "desc": x["gen_desc"]},
            "prod_smiles": lambda x: {**x, "SMILES": x["prod"]},
            "equa_smiles": lambda x: {**x, "SMILES": x["equa"]},
        }

        # Only strings can name a transform; anything else (including
        # unknown names) leaves the data untouched.
        transform = row_transforms.get(processing_func) if isinstance(processing_func, str) else None
        if transform is None:
            return data
        return data.map(transform)

    def load_multiple_datasets(self, dataset_names, limits=None, processing=None):
        """Load and concatenate multiple datasets.

        Args:
            dataset_names: Iterable of dataset identifiers accepted by
                ``load_dataset``.
            limits: Optional mapping ``identifier -> max rows``.
            processing: Optional mapping ``identifier -> transform name``.

        Returns:
            The concatenation of all loaded datasets, in input order.
        """
        limits = limits or {}
        processing = processing or {}

        # ``load_dataset`` performs the truncation itself.
        datasets = [
            self.load_dataset(name, limit=limits.get(name), processing=processing.get(name))
            for name in dataset_names
        ]
        return concatenate_datasets(datasets)

    def load_datasets_with_tasks(self, dataset_configs, limits=None, processing=None):
        """Load datasets with individual task specifications.

        Each dataset is loaded once, shuffled with a fixed seed and
        optionally truncated; one copy per task is then tagged with a
        ``dataset_tasks`` column holding ``[task]`` for every row.

        Args:
            dataset_configs: Iterable of dicts with keys ``"name"`` (dataset
                identifier) and ``"tasks"`` (iterable of task labels).
            limits: Optional mapping ``identifier -> max rows per task``.
            processing: Optional mapping ``identifier -> transform name``.

        Returns:
            The concatenation of all per-task datasets.
        """
        datasets = []
        limits = limits or {}
        processing = processing or {}

        for config in dataset_configs:
            dataset_name = config["name"]
            tasks = config["tasks"]
            limit = limits.get(dataset_name)

            # The limit is applied after shuffling, so load everything here.
            full_data = self.load_dataset(dataset_name, processing=processing.get(dataset_name))

            # NOTE(review): the seed is the same for every task, so all tasks
            # of a dataset share one sample. The sample therefore only needs
            # to be drawn once per dataset.
            sampled = full_data.shuffle(seed=43)
            if limit:
                sampled = sampled.select(range(min(limit, len(sampled))))

            for task in tasks:
                tagged = sampled.add_column("dataset_tasks", [[task]] * len(sampled))
                datasets.append(tagged)

        return concatenate_datasets(datasets)