import torch 
from torch.utils.data import DataLoader

import transformers
from dataclasses import dataclass
# Sentinel written into `labels` at positions that must not be supervised.
# NOTE(review): -100 presumably matches the default `ignore_index` of
# torch.nn.CrossEntropyLoss -- confirm against the loss used by the trainer.
IGNORE_LABEL = -100

@dataclass
class LabeledStringDataCollator:
    """Turn prompt strings (and optional target strings) into a padded batch.

    Each prompt is optionally rendered through the tokenizer's chat template
    as a single user turn, concatenated with its target (when targets are
    given) and tokenized. With targets, a ``labels`` tensor is added in which
    only the target tokens keep their token ids; prompt and padding positions
    are set to ``IGNORE_LABEL``.

    Attributes:
        tokenizer: Tokenizer used for chat templating and encoding.
        system_prompt: Currently unused; no system message is added to the
            chat built in ``__call__``.
        skip_template: When True, prompts are tokenized as given instead of
            being rendered through the chat template.
        think_mode: Forwarded to ``apply_chat_template`` as
            ``enable_thinking``.
    """

    tokenizer: transformers.PreTrainedTokenizer
    system_prompt: str = "You are a helpful assistant."
    skip_template: bool = False
    think_mode: bool = True

    @staticmethod
    def get_tokenizer_args(tokenizer):
        """Return the keyword arguments used for every tokenizer call.

        Batches are padded, truncated to ``tokenizer.model_max_length``
        (``None`` if the attribute is missing) and returned as PyTorch
        tensors, together with the tokenizer's ``length`` field.
        """
        return dict(
            padding=True,
            truncation=True,
            max_length=getattr(tokenizer, "model_max_length", None),
            return_tensors="pt",
            return_length=True,
        )

    def __call__(self, prompts, targets=None):
        """Tokenize ``prompts`` (plus ``targets``) into model inputs.

        Args:
            prompts: Sequence of prompt strings.
            targets: Optional sequence of target strings, aligned with
                ``prompts``. When falsy, no ``labels`` entry is produced.

        Returns:
            The tokenizer output for the (templated) prompts concatenated
            with their targets, without the ``length`` field, and with a
            ``labels`` tensor when targets were supplied.
        """
        tokenizer_args = self.get_tokenizer_args(self.tokenizer)

        if not self.skip_template:
            # Original author's note: this step has to be adapted when the
            # model is Qwen 3. NOTE(review): the `enable_thinking` flag is
            # presumably that adaptation -- confirm for other model families.
            prompts = [
                self.tokenizer.apply_chat_template(
                    [{"role": "user", "content": prompt}],
                    tokenize=False,
                    add_generation_prompt=True,
                    enable_thinking=self.think_mode,
                )
                for prompt in prompts
            ]

        if targets:
            all_prompts = [p + t for p, t in zip(prompts, targets)]
        else:
            all_prompts = prompts

        inputs = self.tokenizer(all_prompts, **tokenizer_args)
        # "length" is only a tokenizer by-product; it is not a model input.
        inputs.pop("length", None)

        if targets:
            inputs["labels"] = self._build_labels(inputs, prompts, tokenizer_args)

        return inputs

    def _build_labels(self, inputs, prompts, tokenizer_args):
        """Return labels that supervise only the target tokens of each row.

        Token counts are taken from the attention masks rather than from the
        tokenizer's ``length`` field, which may already include padding
        depending on the tokenizer implementation. A token is treated as a
        target token when it is a real (non-padding) token whose rank among
        the row's real tokens exceeds the number of prompt tokens. This holds
        for left and right padding alike, and masks the whole row when the
        target contributes no tokens (empty or fully truncated target).
        """
        prompt_inputs = self.tokenizer(prompts, **tokenizer_args)

        full_mask = inputs["attention_mask"]
        # Number of real tokens in each prompt-only encoding, shape (batch, 1).
        prompt_token_counts = prompt_inputs["attention_mask"].sum(
            dim=-1, keepdim=True
        )
        # 1-based rank of every real token within its row.
        real_token_rank = full_mask.cumsum(dim=-1)
        is_target = (full_mask != 0) & (real_token_rank > prompt_token_counts)

        return inputs["input_ids"].masked_fill(~is_target, IGNORE_LABEL)
    

def get_num_workers(num_workers=4):
    """Scale a worker budget by the number of visible CUDA devices.

    Returns ``num_workers`` unchanged when no CUDA device is visible;
    otherwise returns ``num_workers`` divided by the device count, rounded up.
    """
    gpu_count = torch.cuda.device_count()
    if gpu_count != 0:
        # Ceiling division: add (divisor - 1) before flooring.
        return (num_workers + gpu_count - 1) // gpu_count
    return num_workers


def get_loader(dataset, batch_size=128, num_workers=4, accelerator=None, **kwargs):
    """Build a ``DataLoader`` for ``dataset``, optionally prepared by an accelerator.

    Args:
        dataset: Dataset handed to ``DataLoader``.
        batch_size: Batch size handed to ``DataLoader``.
        num_workers: Worker budget; passed through ``get_num_workers`` before
            being given to ``DataLoader``.
        accelerator: Optional object whose ``prepare`` method wraps the loader.
        **kwargs: Extra keyword arguments forwarded to ``DataLoader``.

    Returns:
        The loader, or the result of ``accelerator.prepare(loader)`` when an
        accelerator is supplied.
    """
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=get_num_workers(num_workers=num_workers),
        **kwargs,
    )
    if accelerator is None:
        return data_loader
    return accelerator.prepare(data_loader)