"""GP kernel functions and utilities."""

import bisect
import json
import math
from pathlib import Path
from typing import Optional
from typing import Iterator, Optional

import torch
from torch.utils.data import Dataset
from tqdm import tqdm

from src.utils import DataAttr


def sample_sobol(n_points, lb, ub, *, scramble=True, seed=None):
    """Draw ``n_points`` Sobol quasi-random points inside an axis-aligned box.

    Args:
        n_points: Number of points to draw.
        lb: Per-dimension lower bounds (sequence or 1-D tensor).
        ub: Per-dimension upper bounds; must have the same length as ``lb``.
        scramble: Whether to use a scrambled Sobol sequence.
        seed: Optional seed for the scrambling so draws are reproducible.
            ``None`` (the default) keeps the previous unseeded behavior.

    Returns:
        Tensor of shape ``(n_points, len(lb))`` in the default floating dtype,
        on ``lb``'s device when ``lb`` is a tensor (default device otherwise).

    Raises:
        ValueError: If ``lb``/``ub`` are not 1-D or differ in length.
    """
    # Read the device *before* converting: the legacy ``torch.Tensor(...)``
    # constructor always builds a tensor on the default device, so reading
    # ``.device`` afterwards lost the caller's device.
    device = lb.device if isinstance(lb, torch.Tensor) else None
    dtype = torch.get_default_dtype()
    lb = torch.as_tensor(lb, dtype=dtype, device=device)
    ub = torch.as_tensor(ub, dtype=dtype, device=lb.device)

    # Without this check a length-1 ``ub`` would silently broadcast.
    if lb.ndim != 1 or ub.shape != lb.shape:
        raise ValueError(
            "lb and ub must be 1-D with equal length, got shapes "
            f"{tuple(lb.shape)} and {tuple(ub.shape)}"
        )

    soboleng = torch.quasirandom.SobolEngine(
        dimension=lb.shape[0], scramble=scramble, seed=seed
    )
    # Points lie in the unit hypercube; rescale them affinely into [lb, ub].
    points = soboleng.draw(n_points, dtype=dtype).to(lb.device)
    return points * (ub - lb) + lb[None, :]


def generate_offline_batches(
    save_dir: Path,
    num_batches: int,
    batch_size: int,
    sampler_data: str = "gp",
    sampler_kwargs: Optional[dict] = None,
    chunk_size: int = 100,
    **generation_kwargs,
):
    """
    Generate and save sampler batches for offline use.

    Batches are saved in chunk files (``chunk_0000.pt``, ``chunk_0001.pt``, ...)
    to reduce I/O overhead. A ``metadata.json`` describing the layout is
    written after all chunk files have been saved.

    Args:
        save_dir: Directory to save batches (created if missing; ``str`` is accepted).
        num_batches: Number of batches to generate.
        batch_size: Size of each batch (fixed).
        sampler_data: Sampler to use: "gp", "sawtooth", "triangular" or "tabular".
        sampler_kwargs: Keyword arguments for the sampler constructor.
        chunk_size: Number of batches per file (default 100).
        **generation_kwargs: Arguments forwarded to ``generate_batch``; they are
            also recorded in metadata.json, so they must be JSON-serializable.

    Raises:
        KeyError: If ``sampler_data`` is not a known sampler name.
        ValueError: If ``chunk_size`` is not positive or ``num_batches`` is negative.
        TypeError: If ``generation_kwargs`` cannot be serialized to JSON. This is
            checked before any batch is generated.
    """
    from src.data.gp_sampler import GPSampler
    from src.data.sawtooth_sampler import SawtoothSampler
    from src.data.triangular_sampler import TriangularSampler
    from src.data.tabular_sampler import TabularSampler

    sampler_classes = {
        "gp": GPSampler,
        "sawtooth": SawtoothSampler,
        "triangular": TriangularSampler,
        "tabular": TabularSampler,
    }

    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if num_batches < 0:
        raise ValueError(f"num_batches must be non-negative, got {num_batches}")
    try:
        sampler_class = sampler_classes[sampler_data]
    except KeyError:
        raise KeyError(
            f"Unknown sampler_data {sampler_data!r}; "
            f"expected one of {sorted(sampler_classes)}"
        ) from None

    save_dir = Path(save_dir)
    num_chunks = (num_batches + chunk_size - 1) // chunk_size  # ceiling division

    metadata = {
        "num_batches": num_batches,
        "batch_size": batch_size,
        "chunk_size": chunk_size,
        "num_chunks": num_chunks,
        "sampler_data": sampler_data,
        "generation_kwargs": generation_kwargs,
    }
    # Serialize up front. A non-serializable generation kwarg then fails
    # immediately, instead of after all batches are generated. It also cannot
    # leave a truncated metadata.json behind, as the streaming json.dump could.
    metadata_text = json.dumps(metadata, indent=2)

    save_dir.mkdir(parents=True, exist_ok=True)
    sampler = sampler_class(**(sampler_kwargs or {}))

    batch_keys = ("xc", "yc", "xb", "yb", "xt", "yt")
    # Context manager closes the progress bar even if generation raises.
    with tqdm(total=num_batches, desc="Generating batches", unit="batch") as pbar:
        for chunk_idx in range(num_chunks):
            start_idx = chunk_idx * chunk_size
            end_idx = min(start_idx + chunk_size, num_batches)

            chunk_batches = []
            for _ in range(start_idx, end_idx):
                batch = sampler.generate_batch(batch_size, **generation_kwargs)
                # Store plain dicts of the six fields, not the DataAttr object.
                chunk_batches.append({key: getattr(batch, key) for key in batch_keys})
                pbar.update(1)

            torch.save(chunk_batches, save_dir / f"chunk_{chunk_idx:04d}.pt")
            pbar.set_postfix({"chunk": f"{chunk_idx+1}/{num_chunks}"})

    # Save metadata as JSON for easy inspection, after all chunks are written.
    with open(save_dir / "metadata.json", "w") as f:
        f.write(metadata_text)

    print(f"\nGenerated {num_batches} batches of size {batch_size} in {save_dir}")
    print(f"Saved in {num_chunks} chunks of up to {chunk_size} batches each")


class OfflineBatchLoader(Dataset):
    """
    PyTorch Dataset for pre-generated batches saved by generate_offline_batches().

    Supports both chunked (``chunk_XXXX.pt``) and legacy one-file-per-batch
    (``batch_XXXXXX.pt``) storage formats. Compatible with PyTorch DataLoader
    for distributed training.

    IMPORTANT: This dataset returns pre-batched data with fixed batch sizes.
    When using with PyTorch DataLoader, set batch_size=None to prevent
    additional batching. If you need different batch sizes, you'll need
    to regenerate the data with generate_offline_batches().

    NOTE: files are read with ``torch.load``; only point this at directories
    you trust.

    Args:
        data_dir: Directory containing saved batches.
        device: Device to load tensors to (e.g., 'cuda', 'cpu').
        cache_chunks: Whether to cache loaded chunks in memory.
        chunk_subset: Optional list of chunk indices to use (e.g., [0] for the
            first chunk only). Only honoured for the chunked format. The final
            chunk may hold fewer than ``chunk_size`` batches, and the dataset
            length accounts for that.

    Raises:
        FileNotFoundError: If neither metadata.json nor metadata.pt exists.
        ValueError: If ``chunk_subset`` contains an index outside ``[0, num_chunks)``.
    """

    # The six tensors stored for every batch, in DataAttr keyword order.
    _KEYS = ("xc", "yc", "xb", "yb", "xt", "yt")

    def __init__(
        self,
        data_dir: Path,
        device: str = "cpu",
        cache_chunks: bool = True,
        chunk_subset: Optional[list] = None,
    ):
        self.data_dir = Path(data_dir)
        self.device = device
        self.cache_chunks = cache_chunks
        self.metadata = self._read_metadata()
        self.batch_size = self.metadata["batch_size"]

        # Chunked datasets record chunk_size in their metadata; legacy ones do not.
        self.chunked = "chunk_size" in self.metadata
        if not self.chunked:
            self.num_batches = self.metadata["num_batches"]
            return

        self.chunk_size = self.metadata["chunk_size"]
        self.num_chunks = self.metadata["num_chunks"]
        self._chunk_cache = {} if cache_chunks else None

        # Handle chunk subset (e.g. for a validation split).
        if chunk_subset is not None:
            self.chunk_indices = list(chunk_subset)
            invalid = [c for c in self.chunk_indices if not 0 <= c < self.num_chunks]
            if invalid:
                raise ValueError(
                    f"chunk_subset entries {invalid} out of range for {self.num_chunks} chunks"
                )
        else:
            self.chunk_indices = list(range(self.num_chunks))

        # _chunk_starts[i] is the dataset index of the first batch in
        # chunk_indices[i]. Sizes are summed per chunk because the last chunk
        # can be partial.
        self._chunk_starts = []
        total = 0
        for chunk_idx in self.chunk_indices:
            self._chunk_starts.append(total)
            total += self._chunk_length(chunk_idx)

        if chunk_subset is not None:
            self.num_batches = total
            print(f"Using chunk subset {chunk_subset}: {self.num_batches} batches")
        else:
            self.num_batches = self.metadata["num_batches"]

    def _read_metadata(self) -> dict:
        """Load metadata.json, falling back to metadata.pt for older datasets."""
        metadata_json = self.data_dir / "metadata.json"
        metadata_pt = self.data_dir / "metadata.pt"
        if metadata_json.exists():
            with open(metadata_json) as f:
                return json.load(f)
        if metadata_pt.exists():
            return torch.load(metadata_pt, map_location=self.device)
        raise FileNotFoundError(f"No metadata file found in {self.data_dir}")

    def _chunk_length(self, chunk_idx: int) -> int:
        """Number of batches stored in chunk file ``chunk_idx`` (the last may be short)."""
        remaining = self.metadata["num_batches"] - chunk_idx * self.chunk_size
        return max(0, min(self.chunk_size, remaining))

    def __len__(self) -> int:
        return self.num_batches

    def __getitem__(self, idx: int) -> DataAttr:
        """Get a single batch by index (negative indices count from the end)."""
        if idx < 0:
            idx = self.num_batches + idx

        if idx < 0 or idx >= self.num_batches:
            raise IndexError(
                f"Index {idx} out of range for dataset with {self.num_batches} batches"
            )
        return self._load_batch(idx)

    def __iter__(self) -> Iterator[DataAttr]:
        """Iterate through all batches in order."""
        for i in range(self.num_batches):
            yield self._load_batch(i)

    def _locate(self, idx: int) -> tuple:
        """Map a dataset index to (chunk file index, position within that chunk)."""
        if not 0 <= idx < self.num_batches:
            raise IndexError(
                f"Index {idx} out of range for dataset with {self.num_batches} batches"
            )
        pos = bisect.bisect_right(self._chunk_starts, idx) - 1
        return self.chunk_indices[pos], idx - self._chunk_starts[pos]

    def _load_chunk(self, chunk_idx: int) -> list:
        """Load one chunk file, serving it from / storing it in the cache if enabled."""
        cache = self._chunk_cache
        if cache is not None and chunk_idx in cache:
            return cache[chunk_idx]
        chunk = torch.load(
            self.data_dir / f"chunk_{chunk_idx:04d}.pt",
            map_location=self.device,
        )
        if cache is not None:
            cache[chunk_idx] = chunk
        return chunk

    @staticmethod
    def _to_float32(tensor):
        """Downcast float64 tensors to float32; leave every other dtype as-is."""
        return tensor.float() if tensor.dtype == torch.float64 else tensor

    def _load_batch(self, idx: int) -> DataAttr:
        """Load a single batch and return it as a DataAttr with all six components."""
        if self.chunked:
            chunk_idx, batch_idx = self._locate(idx)
            batch = self._load_chunk(chunk_idx)[batch_idx]
        else:
            # Legacy format: one file per batch.
            batch = torch.load(
                self.data_dir / f"batch_{idx:06d}.pt", map_location=self.device
            )

        # Data may be stored in float64 (used for numerical stability during
        # generation); downcast it to float32.
        return DataAttr(**{key: self._to_float32(batch[key]) for key in self._KEYS})

    def get_info(self) -> dict:
        """Get information about the dataset including batch size."""
        return {
            "num_batches": self.num_batches,
            "batch_size": self.batch_size,
            "chunked": self.chunked,
            "chunk_size": self.chunk_size if self.chunked else None,
            "device": self.device,
        }
