#!/usr/bin/env python3
"""

Usage:
    from utils.data_loader import DataLoader
    
    loader = DataLoader()
    train_dataset, test_dataset = loader.load_tofu_dataset("forget05")
    cifar_train, cifar_test = loader.load_cifar_dataset("cifar10")
"""

import os
import sys
import json
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union, Any
import yaml

import torch
from torch.utils.data import Dataset, DataLoader as TorchDataLoader
import torchvision.datasets as tv_datasets
from transformers import AutoTokenizer
import numpy as np


class TOFUDataset(Dataset):
    """TOFU dataset for language model unlearning.

    Each record is rendered as ``"Question: ...\\nAnswer: ..."`` and tokenized
    to a fixed ``max_length`` for causal-LM training or evaluation.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: "AutoTokenizer",
        max_length: int = 2048,
        split: str = "both"  # "forget", "retain", or "both"
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(data_path, split)

    def load_data(self, data_path: str, split: str) -> List[Dict]:
        """Load TOFU records from a JSON file.

        Args:
            data_path: Path to a JSON object with optional ``"forget_set"`` and
                ``"retain_set"`` lists of QA records.
            split: ``"forget"`` or ``"retain"`` to load one subset; any other
                value loads both (forget records first).

        Returns:
            The selected records, each tagged with ``item["split"]`` set to
            ``"forget"`` or ``"retain"``.
        """
        with open(data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        forget_set = data.get("forget_set", [])
        retain_set = data.get("retain_set", [])
        # Tag records in every mode (not only "both") so __getitem__ never
        # reports "unknown" for a record whose subset is known.
        for item in forget_set:
            item["split"] = "forget"
        for item in retain_set:
            item["split"] = "retain"

        if split == "forget":
            return forget_set
        if split == "retain":
            return retain_set
        return forget_set + retain_set

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Tokenize one QA record.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels`` (1-D
            tensors) plus string metadata ``split``, ``entity`` and
            ``category``. Padding positions in ``labels`` are -100 so the
            causal-LM loss ignores them.
        """
        item = self.data[idx]

        question = item.get("question", "")
        answer = item.get("answer", "")
        input_text = f"Question: {question}\nAnswer: {answer}"

        encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        # squeeze(0) drops only the batch dimension; a bare squeeze() would
        # also collapse a length-1 sequence into a 0-d tensor.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        # Clone so masking does not write through to input_ids, then hide
        # padding from the loss (-100 is the ignore index).
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,  # For causal LM
            "split": item.get("split", "unknown"),
            "entity": item.get("entity", ""),
            "category": item.get("category", "")
        }


class WMDPDataset(Dataset):
    """WMDP dataset for safety-critical unlearning.

    Each record is rendered as a lettered multiple-choice question followed by
    the correct letter, then tokenized to a fixed ``max_length``.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: "AutoTokenizer",
        max_length: int = 2048,
        domain: str = "bio"
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.domain = domain
        self.data = self.load_data(data_path)

    def load_data(self, data_path: str) -> List[Dict]:
        """Load WMDP records from a JSON file (returned as parsed)."""
        with open(data_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Tokenize one multiple-choice record.

        ``item["answer"]`` must be an integer choice index, because it is
        turned into a letter with ``chr(65 + answer)``.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels`` (1-D
            tensors; padding in ``labels`` is -100), ``answer_idx`` as a
            tensor, and string metadata ``domain`` and ``category``.
        """
        item = self.data[idx]

        question = item.get("question", "")
        choices = item.get("choices", [])
        answer_idx = item.get("answer", 0)

        # Choices are labelled "A.", "B.", ... in order.
        choices_text = "\n".join([f"{chr(65+i)}. {choice}" for i, choice in enumerate(choices)])
        input_text = f"Question: {question}\n{choices_text}\nAnswer: {chr(65+answer_idx)}"

        encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        # Drop only the batch dimension (bare squeeze() breaks max_length=1).
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        # Separate tensor so masking padding (-100 = ignore) leaves inputs intact.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "answer_idx": torch.tensor(answer_idx),
            "domain": self.domain,
            "category": item.get("category", "")
        }


class MUSEDataset(Dataset):
    """MUSE dataset for comprehensive unlearning evaluation.

    Each record is rendered as a prompt about ``name`` followed by its
    ``description`` and, if present, its ``facts``, then tokenized to a fixed
    ``max_length``.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: "AutoTokenizer",
        max_length: int = 2048,
        category: Optional[str] = None
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(data_path, category)

    def load_data(self, data_path: str, category: Optional[str] = None) -> List[Dict]:
        """Load MUSE records from a JSON list, optionally keeping one category.

        A falsy ``category`` (``None`` or ``""``) disables filtering.
        """
        with open(data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if category:
            return [item for item in data if item.get("category") == category]
        return data

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Tokenize one MUSE record.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels`` (1-D
            tensors; padding in ``labels`` is -100) plus string metadata
            ``category`` and ``name``.
        """
        item = self.data[idx]

        name = item.get("name", "")
        description = item.get("description", "")
        facts = item.get("facts", [])

        input_text = f"What can you tell me about {name}?\n{description}"
        if facts:
            # str() so non-string facts (e.g. numbers) do not break the join.
            input_text += f"\nFacts: {' '.join(str(fact) for fact in facts)}"

        encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        # Drop only the batch dimension (bare squeeze() breaks max_length=1).
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        # Separate tensor so masking padding (-100 = ignore) leaves inputs intact.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "category": item.get("category", ""),
            "name": name
        }


class IDKDataset(Dataset):
    """IDK dataset for 'I don't know' evaluation.

    Each JSONL record is rendered as ``"Question: ...\\nAnswer: ..."`` where
    the answer is ``expected_response`` (default ``"I don't know"``).
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: "AutoTokenizer",
        max_length: int = 2048
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(data_path)

    def load_data(self, data_path: str) -> List[Dict]:
        """Load IDK records from a JSONL file, one JSON object per line.

        Blank or whitespace-only lines (e.g. a trailing empty line) are skipped
        instead of failing to parse.
        """
        data = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue
                data.append(json.loads(stripped))
        return data

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Tokenize one IDK record.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels`` (1-D
            tensors; padding in ``labels`` is -100) plus string ``category``.
        """
        item = self.data[idx]

        question = item.get("question", "")
        expected_response = item.get("expected_response", "I don't know")

        input_text = f"Question: {question}\nAnswer: {expected_response}"

        encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        # Drop only the batch dimension (bare squeeze() breaks max_length=1).
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        # Separate tensor so masking padding (-100 = ignore) leaves inputs intact.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "category": item.get("category", "unknown")
        }


class CIFARUnlearningDataset(Dataset):
    """CIFAR dataset wrapper for unlearning experiments.

    Wraps torchvision's CIFAR10/CIFAR100 and exposes the samples whose class
    is in ``forget_classes`` ("forget"), the remaining samples ("retain"), or
    both, with forget samples first. Each item is
    ``(image, target, split_label)``.
    """

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform=None,
        target_transform=None,
        download: bool = True,
        dataset_name: str = "CIFAR10",
        forget_classes: Optional[List[int]] = None,
        split: str = "both"  # "forget", "retain", or "both"
    ):
        self.dataset_name = dataset_name
        self.forget_classes = forget_classes or []
        self.split = split

        # Load base dataset
        if dataset_name == "CIFAR10":
            self.base_dataset = tv_datasets.CIFAR10(
                root=root, train=train, transform=transform,
                target_transform=target_transform, download=download
            )
            self.num_classes = 10
        elif dataset_name == "CIFAR100":
            self.base_dataset = tv_datasets.CIFAR100(
                root=root, train=train, transform=transform,
                target_transform=target_transform, download=download
            )
            self.num_classes = 100
        else:
            raise ValueError(f"Unsupported dataset: {dataset_name}")

        # Filter data based on split
        self.indices = self.get_split_indices()

    def get_split_indices(self) -> List[int]:
        """Return base-dataset indices for ``self.split``.

        Membership is decided from the raw class labels in
        ``base_dataset.targets``. No image is loaded or transformed, and
        ``target_transform`` cannot change which samples are forgotten. With
        no forget classes, "forget" is empty and "retain"/"both" cover every
        sample in order.
        """
        forget_set = set(self.forget_classes)
        forget_indices = []
        retain_indices = []

        for idx, target in enumerate(self.base_dataset.targets):
            if target in forget_set:
                forget_indices.append(idx)
            else:
                retain_indices.append(idx)

        if self.split == "forget":
            return forget_indices
        if self.split == "retain":
            return retain_indices
        return forget_indices + retain_indices  # "both" (and any other value)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> Tuple[Any, int, str]:
        """Return ``(image, target, split_label)`` for the ``idx``-th selected sample.

        ``image``/``target`` come from the base dataset (transforms applied);
        ``split_label`` is derived from the raw class label so it agrees with
        ``get_split_indices``.
        """
        actual_idx = self.indices[idx]
        image, target = self.base_dataset[actual_idx]

        raw_target = self.base_dataset.targets[actual_idx]
        split_label = "forget" if raw_target in self.forget_classes else "retain"

        return image, target, split_label


class DataLoaderManager:
    """Centralized data loading manager for OFMU experiments.

    Resolves dataset files under ``data_root`` and reads optional dataset
    metadata from ``config/datasets.yaml`` located two directories above this
    module.
    """

    # Fallback tokenizer for the TOFU/WMDP loaders when the caller passes none.
    DEFAULT_TOKENIZER = "microsoft/Phi-3.5-mini-instruct"

    def __init__(self, data_root: str = "./data", cache_dir: str = "./cache"):
        self.data_root = Path(data_root)
        # NOTE(review): cache_dir is stored but not read by any method here.
        self.cache_dir = Path(cache_dir)

        self.setup_logging()
        self.load_data_configs()

    def setup_logging(self):
        """Setup logging configuration and ``self.logger``.

        NOTE(review): ``basicConfig`` configures the root logger process-wide
        (it does nothing if the root logger already has handlers).
        """
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def load_data_configs(self):
        """Load dataset metadata into ``self.config``.

        Falls back to ``{"datasets": {}}`` when the YAML file is missing, empty,
        or does not hold a mapping at its top level.
        """
        config_path = Path(__file__).parent.parent / "config" / "datasets.yaml"
        if not config_path.exists():
            self.logger.warning("Dataset config not found, using defaults")
            self.config = {"datasets": {}}
            return

        with open(config_path, 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)
        if isinstance(loaded, dict):
            self.config = loaded
        else:
            # safe_load returns None for an empty file; any non-mapping would
            # break the .get() lookups in this class.
            self.logger.warning(
                "Dataset config %s is empty or not a mapping, using defaults", config_path
            )
            self.config = {"datasets": {}}

    def _default_tokenizer(self):
        """Load the fallback tokenizer named by ``DEFAULT_TOKENIZER``."""
        return AutoTokenizer.from_pretrained(self.DEFAULT_TOKENIZER)

    def _datasets_section(self) -> Dict:
        """Return the config's ``datasets`` mapping, or ``{}`` if absent or null."""
        return self.config.get("datasets") or {}

    def load_tofu_dataset(
        self,
        scenario: str = "forget05",
        tokenizer: Optional[AutoTokenizer] = None,
        max_length: int = 2048,
        split: str = "both"
    ) -> TOFUDataset:
        """
        Load TOFU dataset for the specified scenario.

        Args:
            scenario: TOFU scenario (forget01, forget05, forget10); read from
                ``<data_root>/tofu/<scenario>.json``
            tokenizer: Tokenizer for text processing (``DEFAULT_TOKENIZER`` if None)
            max_length: Maximum sequence length
            split: Data split to load ("forget", "retain", "both")

        Returns:
            TOFU dataset

        Raises:
            FileNotFoundError: If the scenario file does not exist.
        """
        data_path = self.data_root / "tofu" / f"{scenario}.json"

        if not data_path.exists():
            raise FileNotFoundError(f"TOFU data not found: {data_path}")

        if tokenizer is None:
            tokenizer = self._default_tokenizer()

        return TOFUDataset(str(data_path), tokenizer, max_length, split)

    def load_wmdp_dataset(
        self,
        domain: str = "bio",
        tokenizer: Optional[AutoTokenizer] = None,
        max_length: int = 2048
    ) -> WMDPDataset:
        """
        Load WMDP dataset for the specified domain.

        Args:
            domain: WMDP domain (bio, cyber, chem); read from
                ``<data_root>/wmdp/wmdp_<domain>_test.json``
            tokenizer: Tokenizer for text processing (``DEFAULT_TOKENIZER`` if None)
            max_length: Maximum sequence length

        Returns:
            WMDP dataset

        Raises:
            FileNotFoundError: If the domain file does not exist.
        """
        data_path = self.data_root / "wmdp" / f"wmdp_{domain}_test.json"

        if not data_path.exists():
            raise FileNotFoundError(f"WMDP data not found: {data_path}")

        if tokenizer is None:
            tokenizer = self._default_tokenizer()

        return WMDPDataset(str(data_path), tokenizer, max_length, domain)

    def load_cifar_dataset(
        self,
        dataset_name: str = "CIFAR10",
        train: bool = True,
        transform=None,
        forget_classes: Optional[List[int]] = None,
        split: str = "both"
    ) -> CIFARUnlearningDataset:
        """
        Load CIFAR dataset for unlearning experiments.

        The data is stored under ``data_root`` and downloaded if missing.

        Args:
            dataset_name: Dataset name (CIFAR10 or CIFAR100)
            train: Whether to load training set
            transform: Image transforms
            forget_classes: List of classes to forget
            split: Data split ("forget", "retain", "both")

        Returns:
            CIFAR unlearning dataset
        """
        return CIFARUnlearningDataset(
            root=str(self.data_root),
            train=train,
            transform=transform,
            dataset_name=dataset_name,
            forget_classes=forget_classes,
            split=split,
            download=True
        )

    def create_dataloader(
        self,
        dataset: Dataset,
        batch_size: int = 32,
        shuffle: bool = True,
        num_workers: int = 4,
        **kwargs
    ) -> TorchDataLoader:
        """
        Create PyTorch DataLoader with standard settings.

        Args:
            dataset: Dataset to wrap
            batch_size: Batch size
            shuffle: Whether to shuffle data
            num_workers: Number of worker processes
            **kwargs: Additional DataLoader arguments. ``pin_memory`` defaults
                to ``torch.cuda.is_available()`` but may be overridden here.

        Returns:
            PyTorch DataLoader
        """
        # setdefault (instead of a fixed keyword) lets callers pass their own
        # pin_memory without a "multiple values for keyword" TypeError.
        kwargs.setdefault("pin_memory", torch.cuda.is_available())
        return TorchDataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            **kwargs
        )

    def get_dataset_info(self, dataset_name: str) -> Dict:
        """Get configured metadata for a dataset (``{}`` if unknown)."""
        return self._datasets_section().get(dataset_name, {})

    def list_available_datasets(self) -> List[str]:
        """List the dataset names present in the config."""
        return list(self._datasets_section().keys())


# Convenience functions for direct use
def load_tofu_dataset(scenario: str = "forget05", **kwargs) -> TOFUDataset:
    """Load a TOFU scenario through a freshly created DataLoaderManager.

    Keyword arguments are forwarded unchanged to
    ``DataLoaderManager.load_tofu_dataset``.
    """
    return DataLoaderManager().load_tofu_dataset(scenario, **kwargs)


def load_wmdp_dataset(domain: str = "bio", **kwargs) -> WMDPDataset:
    """Load a WMDP domain through a freshly created DataLoaderManager.

    Keyword arguments are forwarded unchanged to
    ``DataLoaderManager.load_wmdp_dataset``.
    """
    return DataLoaderManager().load_wmdp_dataset(domain, **kwargs)


def load_cifar_dataset(dataset_name: str = "CIFAR10", **kwargs) -> CIFARUnlearningDataset:
    """Load a CIFAR unlearning dataset through a freshly created DataLoaderManager.

    Keyword arguments are forwarded unchanged to
    ``DataLoaderManager.load_cifar_dataset``.
    """
    return DataLoaderManager().load_cifar_dataset(dataset_name, **kwargs)


if __name__ == "__main__":
    # Test the data loader
    loader = DataLoaderManager()
    
    print("Available datasets:")
    datasets = loader.list_available_datasets()
    for dataset in datasets:
        print(f"  - {dataset}")
    
    # Test CIFAR loading (doesn't require tokenizer)
    try:
        cifar_dataset = loader.load_cifar_dataset("CIFAR10", train=True, forget_classes=[0, 1])
        print(f"\nCIFAR-10 dataset loaded: {len(cifar_dataset)} samples")
        
        # Test sample
        image, target, split = cifar_dataset[0]
        print(f"Sample: target={target}, split={split}, image_shape={image.size if hasattr(image, 'size') else 'unknown'}")
        
    except Exception as e:
        print(f"Failed to load CIFAR dataset: {e}")