import torch
from torch.utils.data import Dataset, DataLoader
from typing import Dict, List, Any

class DistillationDataset(Dataset):
    """Torch dataset pairing a tokenized prompt with its human label distribution.

    Wraps a source dataset and, for every sample, returns the tokenized prompt
    together with the fraction of annotators that chose each candidate label.

    The wrapped ``dataset`` must provide:

    * ``__len__`` and ``__getitem__``
    * ``make_prompt(item)``: the text handed to the tokenizer
    * ``get_label(item)``: a sized iterable of annotator labels
    * ``label_mapping``: a mapping whose keys are the candidate labels
    * ``id_key``: key of the sample ID inside an item (only read when
      ``require_id`` is true)
    """

    def __init__(self, dataset, tokenizer, max_length: int = 512, require_id: bool = False):
        """
        Args:
            dataset: Dataset object containing original data
            tokenizer: Tokenizer for encoding text
            max_length: Maximum length of input sequence
            require_id: Whether to return sample ID
        """
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.require_id = require_id

    def __len__(self) -> int:
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    @staticmethod
    def _label_distribution(annotator_labels, choices: List[Any], idx) -> Dict[Any, float]:
        """Return the share of annotators that picked each label in ``choices``.

        Labels that are not in ``choices`` still count toward the denominator,
        so the returned fractions may sum to less than 1.

        Raises:
            ValueError: If ``annotator_labels`` is empty, since no distribution
                can be derived from zero annotations.
        """
        total = len(annotator_labels)
        if total == 0:
            # Fail with a descriptive error instead of a bare ZeroDivisionError.
            raise ValueError(
                f"Sample {idx} has no annotator labels; "
                "cannot build a human distribution."
            )
        return {
            label: sum(1 for given in annotator_labels if given == label) / total
            for label in choices
        }

    def __getitem__(self, idx) -> Dict[str, Any]:
        """Build the model inputs and human label distribution for one sample.

        Returns:
            A dict with ``input_ids`` and ``attention_mask`` (tokenizer outputs
            with dimension 0 squeezed), ``human_distribution`` (label ->
            fraction of annotators), ``choices`` (list of candidate labels)
            and, when ``require_id`` is set, ``id``.

        Raises:
            ValueError: If the sample has no annotator labels.
        """
        item = self.dataset[idx]
        prompt = self.dataset.make_prompt(item)

        # Pad/truncate every prompt to max_length so samples share one length.
        encoded_input = self.tokenizer(
            prompt,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Squeeze dimension 0 of each tensor; presumably the size-1 batch
        # dimension the tokenizer adds for a single prompt.
        encoded_input = {k: v.squeeze(0) for k, v in encoded_input.items()}

        annotator_labels = self.dataset.get_label(item)
        choices = list(self.dataset.label_mapping.keys())
        human_dist = self._label_distribution(annotator_labels, choices, idx)

        result = {
            "input_ids": encoded_input["input_ids"],
            "attention_mask": encoded_input["attention_mask"],
            "human_distribution": human_dist,
            "choices": choices,
        }

        if self.require_id:
            result["id"] = item[self.dataset.id_key]

        return result


def get_distillation_dataloader(dataset, tokenizer, batch_size=8, shuffle=True, max_length=512, require_id=False):
    """
    Wrap a dataset in a DistillationDataset and return a DataLoader over it

    Args:
        dataset: Original dataset
        tokenizer: Tokenizer for encoding text
        batch_size: Batch size
        shuffle: Whether to shuffle data
        max_length: Maximum length of input sequence, forwarded to DistillationDataset
        require_id: Whether each batch should also carry the sample IDs

    Returns:
        DataLoader: Data loader for distillation task
    """

    def _collate(samples):
        # Tensors are stacked into one batch tensor; the per-sample
        # distributions (dicts) and choices (lists) are kept as plain lists.
        collated = {
            "input_ids": torch.stack([sample["input_ids"] for sample in samples]),
            "attention_mask": torch.stack([sample["attention_mask"] for sample in samples]),
            "human_distribution": [sample["human_distribution"] for sample in samples],
            "choices": [sample["choices"] for sample in samples],
        }
        if require_id:
            collated["id"] = [sample["id"] for sample in samples]
        return collated

    wrapped = DistillationDataset(
        dataset,
        tokenizer,
        max_length=max_length,
        require_id=require_id,
    )
    return DataLoader(
        wrapped,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=_collate,
    )


if __name__ == "__main__":
    # Smoke test: build a dataloader over MTBenchDataset and print one batch.
    import os
    import sys

    # Append the parent of this file's directory to sys.path before the
    # project imports below (they presumably rely on it to resolve `dataset.*`).
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(project_root)

    from dataset.NliDataset import SNLIDataset
    from dataset.MTDataset import MTBenchDataset
    from dataset.SumDataset import SummEvalDataset
    from transformers import AutoTokenizer

    # Tokenizer files are loaded from a local directory relative to this file.
    model_dir = os.path.join(os.path.dirname(__file__), "..", "models", "Qwen", "Qwen2.5-7B-Instruct")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    mt_bench_dataset = MTBenchDataset()
    dataloader = get_distillation_dataloader(mt_bench_dataset, tokenizer, batch_size=4)

    # Only the first batch is inspected; an empty loader prints nothing.
    first_batch = next(iter(dataloader), None)
    if first_batch is not None:
        print("Input IDs shape:", first_batch["input_ids"].shape)
        print("Attention mask shape:", first_batch["attention_mask"].shape)
        print("Human distribution example:", first_batch["human_distribution"])
        print("Options example:", first_batch["choices"])