import torch
import torch.nn as nn
import random

from torch.utils.data import Dataset
import json
from torch.utils.data.dataloader import default_collate
import torch.nn.functional as F

import json
from datasets import load_from_disk
import os

import itertools, random
from pathlib import Path
from typing import Dict, List

import torch
from torch.utils.data import Dataset
from datasets import load_dataset, interleave_datasets

    
import glob, json, os
from torch.utils.data import Dataset
import random
from typing import List, Tuple, Optional
from torch.utils.data import Dataset
from datasets import load_dataset
from transformers import PreTrainedTokenizerBase

import os
import random
from typing import Optional, Dict, Any, List

import torch
from torch.utils.data import Dataset
from datasets import load_from_disk, Dataset as HFDataset

import copy
import torch




class SlimPajama(Dataset):
    """Domain-stratified subset of an on-disk dataset, served as a PyTorch ``Dataset``.

    On first use (or when ``overwrite`` is true) the dataset stored at
    ``source_dir`` is scanned, rows are grouped by the domain derived from
    ``meta["redpajama_set_name"]``, about ``k`` rows are sampled according to
    ``ratios`` and the result is saved to ``subset_dir``. Afterwards the
    subset is always loaded from ``subset_dir``.

    Items are ``{"text": ...}`` when no tokenizer is given, otherwise a dict
    with ``input_ids``, ``attention_mask`` and ``labels`` tensors.
    """

    # Relative sampling weights used when the caller passes ``ratios=None``.
    _DEFAULT_RATIOS: Dict[str, float] = {
        "CommonCrawl": 54.1,
        "C4": 28.7,
        "GitHub": 4.2,
        "Books": 3.7,
        "ArXiv": 3.4,
        "Wikipedia": 3.1,
        "StackExchange": 2.8,
    }

    # meta["redpajama_set_name"] value -> domain key used in ``ratios``.
    _RP_TO_DOMAIN: Dict[str, str] = {
        "RedPajamaCommonCrawl": "CommonCrawl",
        "RedPajamaC4": "C4",
        "RedPajamaGithub": "GitHub",
        "RedPajamaWikipedia": "Wikipedia",
        "RedPajamaStackExchange": "StackExchange",
        "RedPajamaArXiv": "ArXiv",
        "RedPajamaBook": "Books",
    }

    def __init__(
        self,
        source_dir: str,
        subset_dir: str,
        k: int = 30_000,
        seed: int = 42,
        ratios: Optional[Dict[str, float]] = None,
        tokenizer=None,
        max_length: Optional[int] = None,
        overwrite: bool = False,
    ):
        """Build the subset if needed, then open it.

        Args:
            source_dir: Directory of the full dataset (read with ``load_from_disk``).
            subset_dir: Directory where the sampled subset is stored / loaded.
            k: Desired number of rows in the subset; values <= 0 give an
                empty subset.
            seed: Seed for the sampling RNG.
            ratios: Relative weight per domain (need not sum to 1). Defaults
                to ``_DEFAULT_RATIOS``.
            tokenizer: Optional callable applied to the text in ``__getitem__``.
            max_length: ``max_length`` forwarded to the tokenizer.
            overwrite: Rebuild the subset even if ``subset_dir`` exists.

        Raises:
            ValueError: If no domain has both a positive ratio and at least
                one usable row.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length

        if ratios is None:
            ratios = dict(self._DEFAULT_RATIOS)

        if overwrite or not os.path.exists(subset_dir):
            self._build_subset(source_dir, subset_dir, k, seed, ratios)

        self.ds: "HFDataset" = self._open_disk(subset_dir)

    # ---- PyTorch Dataset API ----
    def __len__(self) -> int:
        """Return the number of rows in the subset."""
        return len(self.ds)

    def __getitem__(self, i: int) -> Dict[str, Any]:
        """Return row ``i`` as raw text or as tokenized tensors.

        Without a tokenizer the result is ``{"text": str}``. With one, the
        text is tokenized with truncation and ``padding="max_length"`` and the
        result holds ``input_ids``, ``attention_mask`` and ``labels``; labels
        copy the input ids with masked (padding) positions set to -100.
        """
        text = self.ds[i].get("text", "")
        if self.tokenizer is None:
            return {"text": text}

        # NOTE(review): when self.max_length is None the tokenizer decides the
        # pad/truncate length — confirm this is intended.
        enc = self.tokenizer(
            text,
            return_attention_mask=True,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
        )
        input_ids = torch.tensor(enc["input_ids"], dtype=torch.long)
        attention_mask = torch.tensor(
            enc.get("attention_mask", [1] * len(enc["input_ids"])),
            dtype=torch.long,
        )
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    # ---- helpers ----
    def _build_subset(
        self,
        source_dir: str,
        subset_dir: str,
        k: int,
        seed: int,
        ratios: Dict[str, float],
    ) -> None:
        """Sample a domain-stratified subset of ``source_dir`` into ``subset_dir``."""
        base = self._open_disk(source_dir)
        rng = random.Random(int(seed))

        groups = self._group_by_domain(base)
        caps = {dom: len(idxs) for dom, idxs in groups.items()}
        targets = self._allocate_targets(ratios, caps, k)

        picked: List[int] = []
        for dom, cnt in targets.items():
            if cnt <= 0:
                continue
            idxs = groups[dom]
            picked.extend(idxs if len(idxs) <= cnt else rng.sample(idxs, cnt))

        # Mix domains so the saved subset is not ordered by domain.
        rng.shuffle(picked)

        sub = base.select(picked)
        os.makedirs(subset_dir, exist_ok=True)
        sub.save_to_disk(subset_dir)

    def _group_by_domain(self, base) -> Dict[str, List[int]]:
        """Map each known domain to the indices of its usable rows in ``base``.

        Rows are skipped when ``text`` is not a string, ``meta`` is not a
        dict, or the set name does not map to a known domain.
        """
        groups: Dict[str, List[int]] = {}
        for idx, row in enumerate(base):
            if not isinstance(row.get("text"), str):
                continue
            meta = row.get("meta")
            if not isinstance(meta, dict):
                continue
            # e.g. "RedPajamaC4" -> "C4"
            dom = self._map_rp_to_domain(meta.get("redpajama_set_name"))
            if dom is None:
                continue
            groups.setdefault(dom, []).append(idx)
        return groups

    @staticmethod
    def _allocate_targets(
        ratios: Dict[str, float],
        caps: Dict[str, int],
        k: int,
    ) -> Dict[str, int]:
        """Split ``k`` samples across domains by ``ratios``, bounded by ``caps``.

        Domains with a non-positive ratio or no available rows are dropped.
        Quotas come from largest-remainder rounding; anything above a domain's
        cap is handed out one at a time to domains with room left, highest
        ratio first. The quotas sum to ``k`` unless every domain is exhausted,
        in which case they sum to the number of available rows.

        Raises:
            ValueError: If no domain remains after filtering.
        """
        weights = {
            dom: w for dom, w in ratios.items()
            if w > 0 and caps.get(dom, 0) > 0
        }
        if not weights:
            raise ValueError(
                "No rows matched a domain with a positive ratio; check the "
                "`ratios` keys and the dataset's meta['redpajama_set_name'] values."
            )
        weight_sum = sum(weights.values())
        shares = {dom: w / weight_sum for dom, w in weights.items()}

        # A non-positive request yields an empty allocation.
        k = max(int(k), 0)

        # Largest-remainder rounding: floor every quota, then give the
        # leftover units to the domains with the biggest fractional parts.
        targets = {dom: int(k * shares[dom]) for dom in shares}
        used = sum(targets.values())
        by_remainder = sorted(
            shares,
            key=lambda dom: (k * shares[dom]) - targets[dom],
            reverse=True,
        )
        for dom in by_remainder:
            if used >= k:
                break
            targets[dom] += 1
            used += 1

        # Clamp to what each domain can supply and remember the overflow.
        spare = 0
        for dom in targets:
            overflow = targets[dom] - caps[dom]
            if overflow > 0:
                spare += overflow
                targets[dom] = caps[dom]

        # Round-robin the overflow to domains that still have room.
        by_share = sorted(shares, key=shares.get, reverse=True)
        while spare > 0:
            progressed = False
            for dom in by_share:
                if targets[dom] < caps[dom]:
                    targets[dom] += 1
                    spare -= 1
                    progressed = True
                    if spare == 0:
                        break
            if not progressed:
                # Every domain is exhausted; fewer than k rows are available.
                break

        return targets

    def _open_disk(self, dir_path: str) -> "HFDataset":
        """Load a dataset from ``dir_path``, taking its ``"train"`` split if it has one."""
        ds = load_from_disk(dir_path)
        if hasattr(ds, "keys") and "train" in ds:
            ds = ds["train"]
        return ds

    def _map_rp_to_domain(self, rp: Optional[str]) -> Optional[str]:
        """Map HF meta['redpajama_set_name'] -> canonical domain key used in `ratios`.

        Returns ``None`` for non-string or unknown set names.
        """
        if not isinstance(rp, str):
            return None
        return self._RP_TO_DOMAIN.get(rp.strip())



