import json
import os
import pathlib
from dataclasses import dataclass, field
from logging import Logger as LoggerType
from typing import Any, Dict, Optional, Union

import torch
from datasets import load_dataset
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase


@dataclass
class CalibrationDatasetArguments:
    """
    Arguments for preparing a calibration dataset from a Hugging Face dataset or local file.

    ``dataset_kwargs`` may be supplied either as a dict or as a JSON string
    (the form a command line can express); ``__post_init__`` normalizes it to
    a dict so it can always be unpacked with ``**`` by the consumer.
    """

    dataset_name: str = field(
        default='allenai/c4',
        metadata={"help": "The name of the dataset to use from Hugging Face hub (e.g., 'allenai/c4')."}
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={"help": "The configuration name or subset of the dataset (e.g., 'en' for 'c4/en')."}
    )
    split: Optional[str] = field(
        default="train",
        metadata={"help": "Dataset split to use (e.g., 'train', 'validation', 'test')."}
    )
    num_samples: int = field(
        default=256,
        metadata={"help": "Number of calibration samples to extract. (Default: 256)"}
    )
    max_seq_length: int = field(
        default=2048,
        metadata={"help": "Maximum token sequence length per sample. (Default: 2048)"}
    )
    batch_size: int = field(
        default=1,
        metadata={"help": "Batch size for the DataLoader. (Default: 1)"}
    )
    # NOTE: the calibration builders in this file join this value onto their
    # base directory, create it as a *directory*, and write
    # "calib_dataset.pt" inside it.
    output_path: str = field(
        default="calibration/calib_dataset.pt",
        metadata={"help": "Where to save the processed calibration dataset. (Default: calibration/calib_dataset.pt)"}
    )
    streaming: bool = field(
        default=True,
        metadata={"help": "Enable dataset streaming to reduce memory usage. (Default: True)"}
    )
    # NOTE(review): defaulting to True means dataset-provided code is trusted
    # unless the caller opts out.
    dataset_trust_remote_code: bool = field(
        default=True,
        metadata={"help": "Trust custom dataset code (only enable if you're sure it's safe). (Default: True)"}
    )
    dataset_kwargs: Union[str, Dict[str, Any]] = field(
        default_factory=dict,
        metadata={"help": "Extra arguments to pass to load_dataset. Use JSON string on CLI, or dict in YAML/JSON file."}
    )

    def __post_init__(self):
        """
        Normalize ``dataset_kwargs`` to a dict.

        Raises:
            ValueError: If ``dataset_kwargs`` is a string that is not valid
                JSON (``json.JSONDecodeError`` is a ``ValueError``), or that
                decodes to something other than a JSON object.
        """
        if self.dataset_kwargs is None:
            self.dataset_kwargs = {}
        elif isinstance(self.dataset_kwargs, str):
            raw = self.dataset_kwargs.strip()
            # An empty string means "no extra arguments".
            parsed = json.loads(raw) if raw else {}
            if not isinstance(parsed, dict):
                raise ValueError(
                    f"dataset_kwargs must decode to a JSON object, got {type(parsed).__name__}."
                )
            self.dataset_kwargs = parsed



def get_hybrid_calibration_dataset(
    logger: Any,
    tokenizer: PreTrainedTokenizerBase,
    args: Any,
    output_path: pathlib.Path,
    cache_dir: Optional[str] = None
):
    """
    Build a fixed-length calibration set and save it to disk.

    Documents whose tokenization is at least ``args.max_seq_length`` tokens
    long are truncated to that length and used directly. Shorter documents are
    concatenated (each followed by an end-of-document token) into a buffer; if
    fewer than ``args.num_samples`` long documents are found, the remainder is
    cut from that buffer in ``max_seq_length``-sized rows with an all-ones
    attention mask.

    The result is written as a dict with keys ``"input_ids"`` and
    ``"attention_mask"`` to ``output_path / args.output_path / "calib_dataset.pt"``
    (``args.output_path`` is created as a directory). Nothing is returned.

    Args:
        logger: Object exposing ``info(str)``.
        tokenizer: Tokenizer used to encode the ``"text"`` field of each example.
        args: Object exposing the fields of ``CalibrationDatasetArguments``
            that are read here (dataset_name, dataset_config_name, split,
            streaming, dataset_trust_remote_code, dataset_kwargs, num_samples,
            max_seq_length, output_path).
        output_path: Base directory the relative ``args.output_path`` is joined to.
        cache_dir: Optional cache directory forwarded to ``load_dataset``.

    Raises:
        ValueError: If the tokenizer has no EOS/SEP/PAD token id, or if the
            short documents do not contain enough tokens to fill the
            remaining samples.
    """
    # dataset_kwargs may arrive as a JSON string (CLI) or be unset; normalize
    # so it can be unpacked with ** below.
    dataset_kwargs = args.dataset_kwargs
    if dataset_kwargs is None:
        dataset_kwargs = {}
    elif isinstance(dataset_kwargs, str):
        dataset_kwargs = json.loads(dataset_kwargs) if dataset_kwargs.strip() else {}

    # Load dataset
    dataset = load_dataset(
        args.dataset_name,
        args.dataset_config_name if args.dataset_config_name else None,
        split=args.split,
        streaming=args.streaming,
        trust_remote_code=args.dataset_trust_remote_code,
        cache_dir=cache_dir,
        **dataset_kwargs
    )

    # Pick the first token id that is set. Compare against None explicitly:
    # a token id of 0 is valid but falsy, so an `or` chain would skip it.
    eos_token_id = None
    for attr_name in ("eos_token_id", "sep_token_id", "pad_token_id"):
        candidate = getattr(tokenizer, attr_name, None)
        if candidate is not None:
            eos_token_id = candidate
            break
    if eos_token_id is None:
        raise ValueError("No EOS/SEP/PAD token found in tokenizer")

    input_ids_list = []
    attention_mask_list = []
    concat_buffer = []

    # At most this many buffered tokens can ever be consumed (the case where
    # no long sample is found), so there is no reason to hold more in memory.
    max_buffer_tokens = args.num_samples * args.max_seq_length

    logger.info(f"Searching for natural samples ≥ {args.max_seq_length} tokens...")

    with Progress(
        SpinnerColumn(), "[progress.description]{task.description}", BarColumn(),
        TaskProgressColumn(), "Progress:", MofNCompleteColumn(),
        "Elapsed:", TimeElapsedColumn(), "Remaining:", TimeRemainingColumn()
    ) as progress:
        task = progress.add_task("[blue]Collecting...", total=args.num_samples)

        for example in dataset:
            # Validate the type before calling str methods; missing or
            # non-string text is skipped rather than raising.
            text = example.get("text")
            if not isinstance(text, str):
                continue
            text = text.strip()
            if len(text) < 10:
                continue

            enc = tokenizer(text, return_tensors="pt", add_special_tokens=False)
            input_ids = enc["input_ids"][0]  # safer than .squeeze()
            attn_mask = enc["attention_mask"][0]

            if input_ids.size(0) >= args.max_seq_length:
                input_ids_list.append(input_ids[:args.max_seq_length].unsqueeze(0))
                attention_mask_list.append(attn_mask[:args.max_seq_length].unsqueeze(0))
                progress.update(task, advance=1)
            elif len(concat_buffer) < max_buffer_tokens:
                # Separate documents with the end-of-document token.
                concat_buffer.extend(input_ids.tolist() + [eos_token_id])

            if len(input_ids_list) >= args.num_samples:
                break

        remaining = args.num_samples - len(input_ids_list)
        if remaining > 0:
            logger.info(f"Found {len(input_ids_list)} long samples. Creating {remaining} more from short samples...")
            total_needed = remaining * args.max_seq_length
            if len(concat_buffer) < total_needed:
                raise ValueError("Not enough tokens in short samples to synthesize additional sequences.")
            concat_tensor = torch.tensor(concat_buffer[:total_needed], dtype=torch.long).reshape(remaining, args.max_seq_length)
            input_ids_list.append(concat_tensor)
            attention_mask_list.append(torch.ones_like(concat_tensor))
            progress.update(task, advance=remaining)

    input_ids = torch.cat(input_ids_list, dim=0)[:args.num_samples]
    attention_mask = torch.cat(attention_mask_list, dim=0)[:args.num_samples]

    # args.output_path is treated as a directory; the tensor file goes inside it.
    save_path = output_path / args.output_path
    save_path.mkdir(parents=True, exist_ok=True)

    torch.save(
        {"input_ids": input_ids, "attention_mask": attention_mask},
        str(save_path / "calib_dataset.pt")
    )

    logger.info(f"Saved {args.num_samples} samples to {save_path}")


def get_hybrid_calibration_dataset1(logger: LoggerType, tokenizer: PreTrainedTokenizerBase, args: CalibrationDatasetArguments, output_path: pathlib.Path, cache_dir: Optional[str] = None):
    """
    Build a fixed-length calibration set and save it to disk (no document separators).

    Documents whose tokenization is at least ``args.max_seq_length`` tokens
    long are truncated to that length and used directly. Shorter documents are
    concatenated back to back, with no separator token, into a buffer; if
    fewer than ``args.num_samples`` long documents are found, the remainder is
    cut from that buffer in ``max_seq_length``-sized rows with an all-ones
    attention mask.

    The result is written as a dict with keys ``"input_ids"`` and
    ``"attention_mask"`` to ``output_path / args.output_path / "calib_dataset.pt"``
    (``args.output_path`` is created as a directory). Nothing is returned.

    Args:
        logger: Logger used for progress messages.
        tokenizer: Tokenizer used to encode the ``"text"`` field of each example.
        args: Dataset and sampling configuration.
        output_path: Base directory the relative ``args.output_path`` is joined to.
        cache_dir: Optional cache directory forwarded to ``load_dataset``.

    Raises:
        KeyError: If an example has no ``"text"`` field.
        ValueError: If the short documents do not contain enough tokens to
            fill the remaining samples.
    """
    # dataset_kwargs may arrive as a JSON string (CLI) or be unset; normalize
    # so it can be unpacked with ** below.
    dataset_kwargs = args.dataset_kwargs
    if dataset_kwargs is None:
        dataset_kwargs = {}
    elif isinstance(dataset_kwargs, str):
        dataset_kwargs = json.loads(dataset_kwargs) if dataset_kwargs.strip() else {}

    # Load dataset; the config name is only passed positionally when set.
    load_kwargs = dict(
        split=args.split,
        streaming=args.streaming,
        trust_remote_code=args.dataset_trust_remote_code,
        cache_dir=cache_dir,
        **dataset_kwargs,
    )
    if args.dataset_config_name:
        dataset = load_dataset(args.dataset_name, args.dataset_config_name, **load_kwargs)
    else:
        dataset = load_dataset(args.dataset_name, **load_kwargs)

    input_ids_list = []
    attention_mask_list = []
    concat_buffer = []

    # At most this many buffered tokens can ever be consumed (the case where
    # no long sample is found), so there is no reason to hold more in memory.
    max_buffer_tokens = args.num_samples * args.max_seq_length

    logger.info(f"Searching for natural samples ≥ {args.max_seq_length} tokens...")

    with Progress(
        SpinnerColumn(), "[progress.description]{task.description}", BarColumn(),
        TaskProgressColumn(), "Progress:", MofNCompleteColumn(),
        "Elapsed:", TimeElapsedColumn(), "Remaining:", TimeRemainingColumn()
    ) as progress:
        task = progress.add_task("[blue]Collecting...", total=args.num_samples)

        for example in dataset:
            text = example["text"]
            # Skip null / non-string entries instead of crashing on .strip().
            if not isinstance(text, str):
                continue
            text = text.strip()
            if not text:
                continue

            enc = tokenizer(text, return_tensors="pt", add_special_tokens=False)
            # Index the batch dimension rather than .squeeze(): squeezing a
            # one-token encoding yields a 0-dim tensor, which breaks shape[0]
            # and makes tolist() return a bare int.
            input_ids = enc["input_ids"][0]
            attn_mask = enc["attention_mask"][0]

            if input_ids.shape[0] >= args.max_seq_length:
                input_ids_list.append(input_ids[:args.max_seq_length].unsqueeze(0))
                attention_mask_list.append(attn_mask[:args.max_seq_length].unsqueeze(0))
                progress.update(task, advance=1)
            elif len(concat_buffer) < max_buffer_tokens:
                concat_buffer.extend(input_ids.tolist())

            if len(input_ids_list) >= args.num_samples:
                break

        remaining = args.num_samples - len(input_ids_list)
        if remaining > 0:
            logger.info(f"Found {len(input_ids_list)} long samples. Creating {remaining} more from short samples...")
            total_needed = remaining * args.max_seq_length
            if len(concat_buffer) < total_needed:
                raise ValueError("Not enough tokens in small samples.")
            concat_tensor = torch.tensor(concat_buffer[:total_needed], dtype=torch.long).reshape(remaining, args.max_seq_length)
            input_ids_list.append(concat_tensor)
            attention_mask_list.append(torch.ones_like(concat_tensor))
            progress.update(task, advance=remaining)

    input_ids = torch.cat(input_ids_list, dim=0)[:args.num_samples]
    attention_mask = torch.cat(attention_mask_list, dim=0)[:args.num_samples]

    # args.output_path is treated as a directory; the tensor file goes inside it.
    save_dir = output_path / args.output_path
    save_dir.mkdir(parents=True, exist_ok=True)
    torch.save({"input_ids": input_ids, "attention_mask": attention_mask}, str(save_dir / "calib_dataset.pt"))

    logger.info(f"Saved {args.num_samples} samples to {args.output_path}")


def get_calibration_dataloader(logger, path, batch_size=1, num_workers=16):
    """
    Load a saved calibration dataset and wrap it in a DataLoader.

    Args:
        logger: Object exposing ``info(str)``.
        path: Path to a file readable by ``torch.load`` containing a dict with
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes. ``0`` loads in the
            calling process.

    Returns:
        A non-shuffling ``DataLoader`` over ``(input_ids, attention_mask)`` pairs.
    """
    # NOTE: torch.load unpickles the file; only load calibration files from a
    # trusted source.
    data = torch.load(path)
    logger.info(f"Loaded calibration dataset with shape: {data['input_ids'].shape}")

    dataset = TensorDataset(data["input_ids"], data["attention_mask"])

    loader_kwargs = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": num_workers,
        "pin_memory": True,
    }
    # DataLoader rejects persistent_workers / prefetch_factor when there are
    # no worker processes, so only pass them for multi-process loading.
    if num_workers > 0:
        loader_kwargs["persistent_workers"] = True
        loader_kwargs["prefetch_factor"] = 4

    return DataLoader(dataset, **loader_kwargs)
