from pathlib import Path
from typing import Any, Dict, List, Optional, cast
import json

from torch.utils.data import DataLoader, DistributedSampler, Dataset

from ..aliases import PathOrStr
from ..config import DataConfig, TrainConfig
from ..exceptions import OLMoConfigurationError
from ..torch_util import barrier, get_global_rank, get_world_size
from .collator import DataCollator
from .iterable_dataset import IterableDataset
from .memmap_dataset import MemMapDataset
from ..tokenizer import Tokenizer

# Names exported by a star-import of this module. Note that
# `build_memmap_dataset`, `build_custom_dataloader` and `CustomDataset` are
# defined in this module but are not listed here, so they must be imported
# by name.
__all__ = ["MemMapDataset", "DataCollator", "IterableDataset", "build_eval_dataloader", "build_train_dataloader"]


def build_memmap_dataset(
    train_config: TrainConfig, data_config: DataConfig, include_instance_metadata: bool = True
) -> MemMapDataset:
    """
    Construct a :class:`MemMapDataset` from the sources named in ``data_config``.

    Exactly one of ``data_config.paths`` (a flat list of files) or
    ``data_config.datasets`` (a mapping of label -> list of files) must be set.
    One metadata dict is produced per file: ``{"path": ...}`` for flat paths,
    ``{"label": ...}`` for labelled datasets (labels are visited in sorted order).

    :param train_config: Supplies the model's ``max_sequence_length`` (used as the
        chunk size) and ``pad_token_id``.
    :param data_config: Supplies the data sources, ``generate_attention_mask`` and
        ``label_mask_paths``.
    :param include_instance_metadata: Forwarded unchanged to ``MemMapDataset``.
    :raises OLMoConfigurationError: If both or neither of ``paths`` / ``datasets``
        are provided.
    """
    flat_paths = data_config.paths
    labelled_sets = data_config.datasets

    if flat_paths and labelled_sets:
        raise OLMoConfigurationError("DataConfig.paths is mutually exclusive with DataConfig.datasets")

    source_paths: List[str]
    per_file_metadata: List[Dict[str, Any]]
    if flat_paths:
        source_paths = flat_paths
        per_file_metadata = [{"path": str(file_path)} for file_path in source_paths]
    elif labelled_sets:
        source_paths = []
        per_file_metadata = []
        for label in sorted(labelled_sets.keys()):
            members = labelled_sets[label]
            source_paths.extend(members)
            # Every file of a label refers to the same metadata dict object.
            per_file_metadata.extend([{"label": label}] * len(members))
    else:
        raise OLMoConfigurationError("One of DataConfig.paths or DataConfig.datasets is required")

    model_config = train_config.model
    return MemMapDataset(
        *source_paths,
        chunk_size=model_config.max_sequence_length,
        metadata=per_file_metadata,
        include_instance_metadata=include_instance_metadata,
        pad_token_id=model_config.pad_token_id,
        generate_attention_mask=data_config.generate_attention_mask,
        label_mask_paths=cast(Optional[List[PathOrStr]], data_config.label_mask_paths),
    )


def build_eval_dataloader(
    train_config: TrainConfig,
    data_config: DataConfig,
    batch_size: int,
    shuffle: bool = True,
) -> DataLoader:
    """
    Build the evaluation :class:`DataLoader` for one evaluation data config.

    The dataset is sharded across ranks with a :class:`DistributedSampler` that
    is created with ``shuffle=False`` and ``drop_last=False``.

    :param train_config: Supplies the model's ``pad_token_id`` and the fallback seed.
    :param data_config: Supplies the data sources and the loader settings
        (workers, pinning, prefetching, timeout, padding direction).
    :param batch_size: Requested per-device batch size. When
        ``data_config.drop_last`` is set it is clamped to the number of samples
        available per device.
    :param shuffle: Accepted for interface compatibility; it is currently not
        forwarded to the sampler, which never shuffles.
    """
    pad_token_id = train_config.model.pad_token_id
    eval_dataset = build_memmap_dataset(train_config, data_config, include_instance_metadata=True)
    batch_collator = DataCollator(pad_direction=data_config.pad_direction, pad_token_id=pad_token_id)

    if data_config.drop_last:
        # A batch may not be larger than one device's share of the dataset.
        per_device_samples = len(eval_dataset) // get_world_size()
        batch_size = min(batch_size, per_device_samples)
        assert batch_size > 0, f"dataset for {data_config.paths} is too small"

    # The data-specific seed wins; otherwise fall back to the global training seed.
    sampler_seed = train_config.seed if data_config.seed is None else data_config.seed
    rank_sampler = DistributedSampler(
        eval_dataset,
        drop_last=False,
        shuffle=False,
        num_replicas=get_world_size(),
        rank=get_global_rank(),
        seed=sampler_seed,
    )

    # Prefetching and persistent workers only apply when worker processes exist.
    uses_workers = data_config.num_workers != 0
    return DataLoader(
        eval_dataset,
        batch_size=batch_size,
        collate_fn=batch_collator,
        num_workers=data_config.num_workers,
        sampler=rank_sampler,
        pin_memory=data_config.pin_memory,
        prefetch_factor=data_config.prefetch_factor if uses_workers else None,
        persistent_workers=data_config.persistent_workers if uses_workers else False,
        timeout=data_config.timeout,
    )


def build_train_dataloader(train_config: TrainConfig) -> DataLoader:
    """
    Build the training :class:`DataLoader` from ``train_config.data``.

    Rank 0 prepares the ``<save_folder>/train_data`` working directory, then all
    ranks synchronize on a barrier before the loader is created. The memory-mapped
    dataset is wrapped in an :class:`IterableDataset` whose seed is offset by the
    current epoch.

    :param train_config: Full training configuration;
        ``device_train_batch_size`` must be set.
    :raises OLMoConfigurationError: On rank 0, if the working directory already
        exists and ``save_overwrite`` is not enabled.
    """
    assert train_config.device_train_batch_size is not None
    data_cfg = train_config.data

    batch_collator = DataCollator(
        pad_direction=data_cfg.pad_direction, pad_token_id=train_config.model.pad_token_id
    )
    memmap_dataset = build_memmap_dataset(train_config, data_cfg, include_instance_metadata=False)

    scratch_dir = Path(train_config.save_folder) / "train_data"
    if get_global_rank() == 0:
        # Only one rank touches the filesystem; the others wait at the barrier.
        if scratch_dir.is_dir() and not train_config.save_overwrite:
            raise OLMoConfigurationError(
                "train data working directory already exists, use --save_overwrite to overwrite"
            )
        scratch_dir.mkdir(exist_ok=True, parents=True)
    barrier()

    # The data-specific seed wins; otherwise fall back to the global training seed.
    base_seed = train_config.seed if data_cfg.seed is None else data_cfg.seed
    epoch_offset = train_config.epoch or 0
    training_stream = IterableDataset(
        memmap_dataset,  # type: ignore
        train_config.global_train_batch_size,
        seed=base_seed + epoch_offset,
        shuffle=train_config.data_shuffling,
        drop_last=data_cfg.drop_last,
        work_dir=scratch_dir,
    )

    # Prefetching and persistent workers only apply when worker processes exist.
    uses_workers = data_cfg.num_workers != 0
    return DataLoader(
        training_stream,
        batch_size=train_config.device_train_batch_size,
        drop_last=data_cfg.drop_last,
        collate_fn=batch_collator,
        num_workers=data_cfg.num_workers,
        pin_memory=data_cfg.pin_memory,
        prefetch_factor=data_cfg.prefetch_factor if uses_workers else None,
        persistent_workers=data_cfg.persistent_workers if uses_workers else False,
        timeout=data_cfg.timeout,
    )
    
    
def build_custom_dataloader(
    train_config: TrainConfig,
) -> DataLoader:
    """
    Build a :class:`DataLoader` over the probe dataset at ``train_config.probe_dataset``.

    The probe file is a JSON list. Each entry must provide ``train_context`` plus
    three parallel input/target list pairs: ``mem_input``/``mem_target``,
    ``gen_input``/``gen_target`` and ``hard_gen_input``/``hard_gen_target``.

    Every probe becomes one item ``{"input_ids": ..., "metadata": (probe_id, target_ids)}``:

    * ``input_ids`` is the tokenized ``"<input> <target>"`` string (or the
      ``train_context`` text for a definition),
    * ``probe_id`` is ``"<kind>-<entry index>-<probe index>"`` with kind one of
      ``mem``, ``gen``, ``hard_gen`` or ``def``,
    * ``target_ids`` is the tokenized ``" <target>"`` string (or the
      ``train_context`` text for a definition).

    Items are ordered by kind (all ``mem``, then ``gen``, then ``hard_gen``, then
    ``def``) and sharded across ranks without shuffling. Batches use
    ``train_config.device_eval_batch_size``.

    :param train_config: Full training configuration; ``probe_dataset`` must be set.
    :raises ValueError: If an entry's input and target lists differ in length.
    """
    # Fail before paying for tokenizer construction if no probe file is configured.
    assert train_config.probe_dataset is not None
    tokenizer = Tokenizer.from_train_config(train_config)
    with open(train_config.probe_dataset, "r", encoding="utf-8") as f:
        raw_dataset = json.load(f)

    probe_kinds = ("mem", "gen", "hard_gen")
    probes_by_kind: Dict[str, List[Any]] = {kind: [] for kind in probe_kinds}
    targets_by_kind: Dict[str, List[str]] = {kind: [] for kind in probe_kinds}
    definitions: List[Any] = []

    for idx, entry in enumerate(raw_dataset):
        definitions.append((entry["train_context"], f"def-{idx}-0"))
        for kind in probe_kinds:
            prompts = entry[f"{kind}_input"]
            answers = entry[f"{kind}_target"]
            # Probes and targets are paired purely by position, so a length
            # mismatch would silently misalign every later item.
            if len(prompts) != len(answers):
                raise ValueError(
                    f"probe dataset entry {idx}: '{kind}_input' has {len(prompts)} items "
                    f"but '{kind}_target' has {len(answers)}"
                )
            for i, (prompt, answer) in enumerate(zip(prompts, answers)):
                probes_by_kind[kind].append((prompt + " " + answer, f"{kind}-{idx}-{i}"))
                targets_by_kind[kind].append(" " + answer)

    all_data = [probe for kind in probe_kinds for probe in probes_by_kind[kind]] + definitions
    # The target of a definition is taken to be its text. Only the string is
    # passed on (not the (text, id) tuple), so every target handed to the
    # tokenizer is a plain string like the probe targets.
    all_targets = [target for kind in probe_kinds for target in targets_by_kind[kind]] + [
        text for text, _ in definitions
    ]
    all_data_tokenized = tokenizer.encode_batch([text for text, _ in all_data], add_special_tokens=False)
    all_targets_tokenized = tokenizer.encode_batch(all_targets, add_special_tokens=False)

    dataset = CustomDataset(
        [
            {"input_ids": input_ids, "metadata": (probe_id, target_ids)}
            for (_, probe_id), input_ids, target_ids in zip(all_data, all_data_tokenized, all_targets_tokenized)
        ]
    )
    collator = DataCollator(
        pad_direction=train_config.data.pad_direction, pad_token_id=train_config.model.pad_token_id
    )
    sampler = DistributedSampler(
        dataset,
        drop_last=False,
        shuffle=False,
        num_replicas=get_world_size(),
        rank=get_global_rank(),
        seed=train_config.seed,
    )

    # Prefetching and persistent workers only apply when worker processes exist.
    uses_workers = train_config.data.num_workers != 0
    return DataLoader(
        dataset,
        batch_size=train_config.device_eval_batch_size,
        collate_fn=collator,
        num_workers=train_config.data.num_workers,
        sampler=sampler,
        pin_memory=train_config.data.pin_memory,
        prefetch_factor=train_config.data.prefetch_factor if uses_workers else None,
        persistent_workers=train_config.data.persistent_workers if uses_workers else False,
        timeout=train_config.data.timeout,
        drop_last=False,
    )
    
    
class CustomDataset(Dataset):
    """Map-style dataset that serves items from an in-memory, indexable collection."""

    def __init__(self, data) -> None:
        """Keep a reference to ``data`` and record its length at construction time."""
        self.data = data
        self.length = len(data)

    def __len__(self) -> int:
        """Return the number of items recorded when the dataset was created."""
        return self.length

    def __getitem__(self, index) -> Any:
        """Return the item stored at ``index`` in the underlying collection."""
        return self.data[index]