import json
import os
import random
from datasets import concatenate_datasets

class PromptBuilder:
    """Build chat-style prompts from JSON templates.

    The template file maps each task name either to a single template or to
    a list of templates (one is picked at random per sample). A template is
    a dict with a ``"messages"`` list; every message has a ``"role"`` and a
    ``"content"`` string that is filled in with ``str.format``.

    Placeholders available in ``content``:
        * ``{mol_type}`` - the molecule representation name given at init.
        * ``{mol}`` - the sample's value in the column named ``mol_type``
          (empty string if that column is absent).
        * any other column of the sample, by column name.
    """

    def __init__(self, mol_type="SELFIES", template_file="prompt_templates.json"):
        """Load the templates.

        Args:
            mol_type: Name of the molecule column / representation.
            template_file: Path to the JSON template file. A bare filename
                is resolved relative to this module's directory.
        """
        self.mol_type = mol_type
        # If template_file is just a filename, look in the same directory as this module
        if not os.path.dirname(template_file):
            template_file = os.path.join(os.path.dirname(__file__), template_file)
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(template_file, 'r', encoding="utf-8") as f:
            self.templates = json.load(f)

    def _validate_tasks(self, tasks):
        """Raise ``ValueError`` for the first task without a template."""
        for task in tasks:
            if task not in self.templates:
                raise ValueError(f"Unknown task: {task}")

    def _map_task(self, data, task, is_generation):
        """Return ``data`` with a ``prompt`` column built for ``task``."""
        # fn_kwargs binds the values at call time, so this stays correct even
        # if ``map`` is evaluated lazily (unlike a lambda closing over a loop
        # variable).
        return data.map(
            self._create_prompts,
            batched=True,
            fn_kwargs={"task": task, "is_generation": is_generation},
        )

    def build_prompts(self, data, tasks, is_generation=False):
        """Build prompts for the given task(s).

        Args:
            data: Dataset to add a ``prompt`` column to.
            tasks: A task name or an iterable of task names.
            is_generation: If true, assistant messages are left out.

        Returns:
            The mapped dataset for a single task, otherwise the
            concatenation of one mapped copy of ``data`` per task.

        Raises:
            ValueError: If no task is given or a task has no template.
        """
        if isinstance(tasks, str):
            tasks = [tasks]
        tasks = list(tasks)
        if not tasks:
            raise ValueError("At least one task is required")
        # Fail before doing any mapping work.
        self._validate_tasks(tasks)

        datasets = [self._map_task(data, task, is_generation) for task in tasks]
        return datasets[0] if len(datasets) == 1 else concatenate_datasets(datasets)

    def build_prompts_per_dataset(self, data, is_generation=False):
        """Build prompts using the per-row ``dataset_tasks`` column.

        Rows are grouped by their (sorted) task set, and each group is
        mapped once per task in that set. The ``dataset_tasks`` column is
        removed from the result, and rows come out grouped rather than in
        their original order.

        Raises:
            ValueError: If the ``dataset_tasks`` column is missing or a task
                has no template.
        """
        if "dataset_tasks" not in data.column_names:
            raise ValueError("Dataset must have 'dataset_tasks' column for per-dataset prompting")

        # Group row indices by task set.
        task_groups = {}
        for i, tasks in enumerate(data["dataset_tasks"]):
            # A bare string is one task, not a sequence of characters.
            if isinstance(tasks, str):
                tasks = [tasks]
            task_groups.setdefault(tuple(sorted(tasks)), []).append(i)

        # Validate every task up front so nothing is mapped in vain.
        for task_key in task_groups:
            self._validate_tasks(task_key)

        datasets = []
        for task_key, indices in task_groups.items():
            # Remove dataset_tasks column before processing
            subset = data.select(indices).remove_columns(["dataset_tasks"])
            for task in task_key:
                datasets.append(self._map_task(subset, task, is_generation))

        return concatenate_datasets(datasets) if datasets else data.remove_columns(["dataset_tasks"])

    def _create_prompts(self, samples, task, is_generation):
        """Create prompts for one batch from the templates of ``task``.

        Args:
            samples: Batch as a mapping of column name to list of values.
            task: Key into ``self.templates``.
            is_generation: If true, stop at the first assistant message so
                the prompt ends where the model should continue.

        Returns:
            ``{"prompt": [...]}`` with one message list per row.
        """
        task_templates = self.templates[task]
        keys = list(samples.keys())
        num_rows = len(samples[keys[0]]) if keys else 0
        # Fetch each column once instead of once per row.
        columns = {key: samples[key] for key in keys if key != "dataset_tasks"}

        prompts = []
        for i in range(num_rows):
            row = {key: values[i] for key, values in columns.items()}
            # Merge into one dict so a column literally named "mol" or
            # "mol_type" cannot collide with the reserved placeholders
            # (which take precedence).
            fields = {
                **row,
                "mol_type": self.mol_type,
                "mol": row.get(self.mol_type, ""),
            }

            # Randomly select template if multiple available
            if isinstance(task_templates, list):
                template = random.choice(task_templates)
            else:
                template = task_templates

            # Build messages from template
            messages = []
            for msg_template in template["messages"]:
                role = msg_template["role"]
                # Check before formatting: the assistant template usually
                # references the label column, which may be absent at
                # generation time.
                if is_generation and role == "assistant":
                    break
                messages.append(
                    {"role": role, "content": msg_template["content"].format(**fields)}
                )

            prompts.append(messages)

        return {"prompt": prompts}

    def format_for_inference(self, data, tokenizer, thinking=None, is_generation=True):
        """Render the ``prompt`` column to text with the tokenizer's chat template.

        Args:
            data: Dataset with a ``prompt`` column of message lists.
            tokenizer: Object exposing ``apply_chat_template``.
            thinking: If not ``None``, forwarded as ``enable_thinking``;
                omitted otherwise.
            is_generation: Forwarded as ``add_generation_prompt``.

        Returns:
            ``data`` with an added ``messages`` column of rendered strings.
        """
        template_kwargs = {"tokenize": False, "add_generation_prompt": is_generation}
        if thinking is not None:
            template_kwargs["enable_thinking"] = thinking

        def format_messages(samples):
            texts = [
                tokenizer.apply_chat_template(msg, **template_kwargs)
                for msg in samples["prompt"]
            ]
            return {"messages": texts}

        return data.map(format_messages, batched=True)