import torch
import numpy as np
from typing import Any, Dict, List
from functools import partial
from datasets import concatenate_datasets, load_dataset, DatasetDict

# Passed as `load_from_cache_file` to the dataset map/filter calls in DataLoader.
LOAD_FROM_CACHE = True
# Maps dataset-specific role labels onto 'user'/'assistant'; normalize_roles
# leaves any label that is not listed here unchanged.
role_map = {"HUMAN": "user", "ASSISTANT": "assistant"}
# Passed as `num_proc` to the tokenization map and to the filters.
NUM_PROC = 8


class DataLoader:
    def __init__(self, args, tokenizer, test_size=0.1, seed=42):
        """Store the configuration used by the loading pipeline.

        Args:
            args: Run configuration object; the other methods read attributes
                such as ``data_path``, ``granularity``, ``min_prob`` and
                ``max_prob`` from it.
            tokenizer: Tokenizer used to turn chat messages into token ids.
            test_size: Fraction held out for validation when the loaded dataset
                has neither a 'validation' nor a 'test' split.
            seed: Seed used for shuffling/splitting and as the base of the
                per-row ``seed`` column.
        """
        self.args, self.tokenizer = args, tokenizer
        self.test_size, self.seed = test_size, seed

    # ================================================================
    # PUBLIC
    # ================================================================

    def load_data(self):
        """
        Main entry: run the whole pipeline and return the train/validation datasets.

        Steps, in order: load and split the raw data, tokenize it into chat
        format, drop filtered-out examples, attach the metadata columns, print
        token statistics, and switch both splits to torch formatting.

        Returns:
            Tuple ``(train, validation)``. Both splits are formatted with
            ``output_all_columns=True``, so every column (including
            'block_ranges' and 'resp_lengths') stays accessible; the columns
            explicitly formatted as torch are:
            - train: ['input_ids', 'prompt_lengths', 'seed']
            - validation: ['input_ids', 'prompt_lengths', 't', 'seed']
        """
        tokenize_fn = partial(self.to_chat_batched, dataset_args=self.args, tokenizer=self.tokenizer)

        data = self.load_dataset_pair()
        source_columns = data["train"].column_names
        data = data.map(
            tokenize_fn,
            batched=True,
            remove_columns=source_columns,
            desc="Tokenization",
            load_from_cache_file=LOAD_FROM_CACHE,
            num_proc=NUM_PROC,
        )
        data = self.apply_filtering(data)

        # Identity map: workaround for a `datasets` bug where filtering can
        # de-sync table rows before add_column.
        data = data.map(lambda x: x, num_proc=1)

        data = self.add_metadata_columns(data)

        # Token statistics for the training split (read before torch formatting).
        train_split = data["train"]
        prompt_tokens = np.sum(train_split["prompt_lengths"])
        response_tokens = np.sum(train_split["resp_lengths"])
        print(
            f"Train tokens: prompt {int(prompt_tokens)}, "
            f"response {int(response_tokens)}, "
            f"total {int(prompt_tokens + response_tokens)}"
        )

        train_split.set_format(
            type="torch",
            columns=["input_ids", "prompt_lengths", "seed"],
            output_all_columns=True,
        )
        data["validation"].set_format(
            type="torch",
            columns=["input_ids", "prompt_lengths", "t", "seed"],
            output_all_columns=True,
        )

        print("Train size:", len(train_split))
        print("Eval  size:", len(data["validation"]))

        return data["train"], data["validation"]

    # ================================================================
    # FIRST-LEVEL HELPERS
    # ================================================================

    def load_dataset_pair(self):
        """Load the configured dataset(s) and return a capped train/validation pair.

        When ``args.data_path`` is "mix", the two hard-coded datasets are each
        converted to train/validation, concatenated per split and shuffled.
        Otherwise ``args.data_path`` (with ``args.data_subset`` when set) is
        loaded and converted. In both cases each split is then shuffled with
        ``self.seed`` and truncated: train to ``args.max_train_size`` rows,
        validation to 500 rows.
        """
        cfg = self.args
        shuffle_seed = self.seed

        if cfg.data_path != "mix":
            if cfg.data_subset:
                loaded = load_dataset(cfg.data_path, cfg.data_subset)
            else:
                loaded = load_dataset(cfg.data_path)
            pair = self.to_train_val(loaded)
        else:
            first = self.to_train_val(load_dataset("allenai/tulu-3-sft-mixture"))
            second = self.to_train_val(load_dataset("HuggingFaceTB/smoltalk", "all"))
            merged_train = concatenate_datasets([first["train"], second["train"]])
            merged_val = concatenate_datasets([first["validation"], second["validation"]])
            pair = DatasetDict(
                train=merged_train.shuffle(shuffle_seed),
                validation=merged_val.shuffle(shuffle_seed),
            )

        # Shuffle, then keep at most the configured number of rows per split.
        train_cap = min(cfg.max_train_size, len(pair["train"]))
        pair["train"] = pair["train"].shuffle(shuffle_seed).select(range(train_cap))

        val_cap = min(500, len(pair["validation"]))
        pair["validation"] = pair["validation"].shuffle(shuffle_seed).select(range(val_cap))
        return pair

    def to_chat_batched(self, batch, dataset_args, tokenizer):
        """Tokenize a batch of examples into chat format with block metadata.

        Args:
            batch: Column-oriented batch (dict of lists).
            dataset_args: Config object; ``is_messages`` selects the 'messages'
                column, otherwise ``prompt_column``/``response_column`` name the
                columns that are paired into a two-turn conversation.
            tokenizer: Tokenizer forwarded to ``process_example``.

        Returns:
            Dict of equal-length lists with keys 'input_ids', 'block_ranges',
            'prompt_lengths' and 'resp_lengths'.
        """
        # Pick the conversation source according to the data format.
        if dataset_args.is_messages:
            conversations = batch["messages"]
        else:
            questions = batch[dataset_args.prompt_column]
            answers = batch[dataset_args.response_column]
            conversations = [
                [{"role": "user", "content": q}, {"role": "assistant", "content": a}]
                for q, a in zip(questions, answers)
            ]

        all_ids, all_blocks, all_prompt_lens, all_resp_lens = [], [], [], []
        for conversation in conversations:
            turns = self.normalize_roles(conversation)
            # The last message is the response; everything before it is the prompt.
            p_len, ids, blocks, r_len = self.process_example(turns[:-1], turns[-1:], tokenizer)
            all_ids.append(ids)
            all_blocks.append(blocks)
            all_prompt_lens.append(p_len)
            all_resp_lens.append(r_len)

        return dict(
            input_ids=all_ids,
            block_ranges=all_blocks,
            prompt_lengths=all_prompt_lens,
            resp_lengths=all_resp_lens,
        )

    def apply_filtering(self, raw):
        """Filter the 'train' and 'validation' splits in a single pass each.

        An example is dropped when it has more than 2048 input ids, when its
        'resp_lengths' value is below 5, or when its input ids contain token id
        151657 or 151658 (presumably the tokenizer's tool-call tokens — confirm
        against the tokenizer vocabulary).
        """
        def _keep(ex):
            token_ids = ex["input_ids"]
            rejected = (
                len(token_ids) > 2048
                or ex["resp_lengths"] < 5
                or 151657 in token_ids
                or 151658 in token_ids
            )
            return not rejected

        for split_name in ("train", "validation"):
            raw[split_name] = raw[split_name].filter(
                _keep,
                desc=f"filter {split_name}",
                load_from_cache_file=LOAD_FROM_CACHE,
                num_proc=NUM_PROC,
            )
        return raw

    def add_metadata_columns(self, raw):
        """Add scheduler value 't' and deterministic random seed column to datasets."""
        args, seed = self.args, self.seed
        n_val = len(raw["validation"])
        n_train = len(raw["train"])

        raw["validation"] = raw["validation"].add_column("t", np.linspace(args.min_prob, args.max_prob, n_val).tolist())
        raw["validation"] = raw["validation"].add_column("seed", [seed + i for i in range(n_val)])
        raw["train"]      = raw["train"].add_column("seed", [seed + i for i in range(n_train)])
        return raw

    # ================================================================
    # SECOND-LEVEL HELPERS
    # ================================================================

    def to_train_val(self, ds):
        """Return a train/validation DatasetDict built from ``ds``.

        An existing 'validation' split is preferred, then an existing 'test'
        split; if neither exists, 'train' is split using ``self.test_size`` and
        ``self.seed``.
        """
        for held_out in ("validation", "test"):
            if held_out in ds:
                return DatasetDict(train=ds["train"], validation=ds[held_out])

        parts = ds["train"].train_test_split(test_size=self.test_size, seed=self.seed)
        return DatasetDict(train=parts["train"], validation=parts["test"])

    # ---------------------------------------------------------------

    def process_example(self, prompt, response, tokenizer):
        """Unifies prompt → full_ids → block ranges logic."""
        prompt_len = self.get_prompt_length(prompt, tokenizer)
        full_text = (tokenizer.apply_chat_template(prompt + response, tokenize=False)
                     if getattr(tokenizer, "chat_template", None)
                     else "".join(f"{m['role']}: {m['content']}\n" for m in prompt + response))

        # tokenize response with granularity logic
        # Append EOS to response text before splitting into blocks to ensure
        # downstream logic sees EOS once and avoids double-appending
        resp_text = response[-1]["content"]
        pad_count = int(np.random.randint(10, 20))
        resp_text = resp_text + "<|im_end|>" + tokenizer.eos_token * pad_count
        # resp_text = resp_text + "<|im_end|>\n"
        resp_ids, block_ranges = self.split_in_blocks(resp_text, tokenizer, prompt_len)

        # combine prompt prefix + response ids
        prompt_ids = tokenizer(full_text, add_special_tokens=True)["input_ids"][:prompt_len]
        full_ids = prompt_ids + resp_ids

        return prompt_len, full_ids, block_ranges, len(resp_ids)

    def get_prompt_length(self, prompt, tokenizer):
        """Return the number of tokens the prompt occupies.

        With a chat template, the prompt is tokenized through the template
        including the generation prompt. Without one, the messages are rendered
        as newline-joined "role: content" lines and tokenized with special tokens.
        """
        if not getattr(tokenizer, "chat_template", None):
            rendered_lines = [f"{m['role']}: {m['content']}" for m in prompt]
            encoded = tokenizer("\n".join(rendered_lines), add_special_tokens=True)
            return len(encoded["input_ids"])

        templated_ids = tokenizer.apply_chat_template(
            prompt, tokenize=True, add_generation_prompt=True
        )
        return len(templated_ids)

    def split_in_blocks(self, response_text, tokenizer, prompt_len):
        """Split response text into blocks and return token ranges."""
        g = self.args.granularity
        
        # No granularity - single block
        if g is None:
            ids = tokenizer(response_text, add_special_tokens=False)["input_ids"]
            return ids, [(prompt_len, prompt_len + len(ids))]

        # Sentence granularity - split on delimiters
        if g == "sentence":
            sep_chars = ["\n"]
            blocks = []
            start = 0
            text = response_text

            # Text-based splitting
            for i in range(len(text)):
                for sep in sep_chars:
                    if text.startswith(sep, i):
                        blocks.append(text[start: i + len(sep)])
                        start = i + len(sep)

            if start < len(text):
                blocks.append(text[start:])

            input_ids = []
            ranges = []
            running_token_offset = 0

            for block in blocks:
                block_ids = tokenizer(block, add_special_tokens=False)["input_ids"]
                n_tok = len(block_ids)

                start_tok = prompt_len + running_token_offset
                end_tok = start_tok + n_tok

                ranges.append((start_tok, end_tok))
                input_ids.extend(block_ids)

                running_token_offset += n_tok
            
            return input_ids, ranges
        
        # Integer granularity - blocks of specified size
        try:
            block_size = int(g)
            ids = tokenizer(response_text, add_special_tokens=False)["input_ids"]
            ranges = [
                (prompt_len + i, prompt_len + min(i + block_size, len(ids)))
                for i in range(0, len(ids), block_size)
            ]
            return ids, ranges
        except (ValueError, TypeError):
            raise ValueError(f"Unknown granularity {g}")

    # ---------------------------------------------------------------

    def normalize_roles(self, msgs):
        """Return new message dicts whose roles are mapped through ``role_map``.

        Roles without an entry in ``role_map`` are kept as-is; only the 'role'
        and 'content' keys are carried over.
        """
        normalized = []
        for message in msgs:
            original_role = message["role"]
            normalized.append(
                {
                    "role": role_map.get(original_role, original_role),
                    "content": message["content"],
                }
            )
        return normalized
