"""GP kernel functions and utilities."""

import math
import torch
from typing import Optional
from pathlib import Path
import json
from typing import Iterator, Optional
from torch.utils.data import Dataset
from src.utils import DataAttr


def sample_sobol(n_points, lb, ub, *, scramble=True, seed=None):
    """Draw quasi-random points from a Sobol sequence inside a box.

    Args:
        n_points: Number of points to draw.
        lb: Per-dimension lower bounds (sequence of numbers or 1-D tensor).
            Its length fixes the dimensionality of the points.
        ub: Per-dimension upper bounds, same length as ``lb``.
        scramble: Whether the Sobol engine scrambles the sequence.
        seed: Optional seed forwarded to the Sobol engine so that scrambled
            draws can be reproduced. ``None`` keeps the engine's default.

    Returns:
        Tensor of shape ``[n_points, len(lb)]`` on the device of ``lb``, with
        column ``d`` rescaled from the unit interval to ``[lb[d], ub[d]]``.
    """
    dim = len(lb)
    dtype = torch.get_default_dtype()

    # ``torch.as_tensor`` accepts lists, numpy arrays and tensors of any
    # numeric dtype, unlike the legacy ``torch.Tensor(...)`` constructor.
    lb = torch.as_tensor(lb, dtype=dtype)
    device = lb.device
    # Keep both bounds on the same device so the rescaling below cannot fail
    # with a device mismatch.
    ub = torch.as_tensor(ub, dtype=dtype, device=device)

    engine = torch.quasirandom.SobolEngine(
        dimension=dim, scramble=scramble, seed=seed
    )
    unit_points = engine.draw(n_points).to(device=device, dtype=dtype)

    # Affine map from the unit hypercube onto the requested box.
    return unit_points * (ub - lb) + lb[None, :]


def generate_offline_batches(
    save_dir: Path,
    num_batches: int,
    batch_size: int,
    sampler_data: str = "gp",
    sampler_kwargs: Optional[dict] = None,
    chunk_size: int = 100,
    **generation_kwargs,
):
    """
    Generate batches with a sampler and save them for offline use.

    Batches are saved in chunks to reduce I/O overhead: each
    ``chunk_XXXX.pt`` file holds a list of up to ``chunk_size`` dicts with the
    keys ``xc``, ``yc``, ``xb``, ``yb``, ``xt``, ``yt``. A ``metadata.json``
    file describing the run is written next to the chunks.

    Args:
        save_dir: Directory to save batches (created if missing); a plain
            string path is accepted as well
        num_batches: Number of batches to generate
        batch_size: Size of each batch (fixed)
        sampler_data: Which sampler to use; one of "gp", "sawtooth",
            "triangular", "bav"
        sampler_kwargs: Keyword arguments for the sampler constructor
        chunk_size: Number of batches per file (default 100)
        **generation_kwargs: Arguments for generate_batch; they are also
            written to metadata.json and therefore must be JSON-serialisable

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
        KeyError: If ``sampler_data`` does not name a known sampler.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    from src.data.gp_sampler import GPSampler
    from src.data.sawtooth_sampler import SawtoothSampler
    from src.data.triangular_sampler import TriangularSampler
    from src.data.bav_samplerv2 import BavSampler

    sampler_classes = {
        "gp": GPSampler,
        "sawtooth": SawtoothSampler,
        "triangular": TriangularSampler,
        "bav": BavSampler,
    }

    # Validate before touching the filesystem so a typo leaves no empty
    # directory behind. KeyError is kept as the exception type for callers.
    if sampler_data not in sampler_classes:
        raise KeyError(
            f"Unknown sampler_data {sampler_data!r}; "
            f"expected one of {sorted(sampler_classes)}"
        )

    # Accept str paths too, mirroring OfflineBatchLoader's handling.
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    sampler = sampler_classes[sampler_data](**(sampler_kwargs or {}))

    # Ceiling division: the last chunk may hold fewer than chunk_size batches.
    num_chunks = (num_batches + chunk_size - 1) // chunk_size
    field_names = ("xc", "yc", "xb", "yb", "xt", "yt")

    for chunk_idx in range(num_chunks):
        first = chunk_idx * chunk_size
        last = min(first + chunk_size, num_batches)

        chunk_batches = []
        for _ in range(first, last):
            batch = sampler.generate_batch(batch_size, **generation_kwargs)
            # Store plain dicts rather than the batch object itself.
            chunk_batches.append(
                {name: getattr(batch, name) for name in field_names}
            )

        torch.save(chunk_batches, save_dir / f"chunk_{chunk_idx:04d}.pt")

    metadata = {
        "num_batches": num_batches,
        "batch_size": batch_size,
        "chunk_size": chunk_size,
        "num_chunks": num_chunks,
        "sampler_data": sampler_data,
        "generation_kwargs": generation_kwargs,
    }

    # Save metadata as JSON for easy inspection
    with open(save_dir / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    print(f"Generated {num_batches} batches of size {batch_size} in {save_dir}")
    print(f"Saved in {num_chunks} chunks of up to {chunk_size} batches each")


class OfflineBatchLoader(Dataset):
    """
    PyTorch Dataset for pre-generated batches stored on disk.

    Supports both chunked storage (``chunk_XXXX.pt`` files, each a list of
    batch dicts) and the legacy layout with one ``batch_XXXXXX.pt`` file per
    batch. Compatible with PyTorch DataLoader for distributed training.

    IMPORTANT: This dataset returns pre-batched data with fixed batch sizes.
    When using with PyTorch DataLoader, set batch_size=None to prevent
    additional batching. If you need different batch sizes, you'll need
    to regenerate the data with generate_offline_batches().

    NOTE: files are read with ``torch.load``, which unpickles their content;
    only point this loader at directories you trust.

    Args:
        data_dir: Directory containing saved batches
        device: Device to load tensors to (e.g., 'cuda', 'cpu')
        cache_chunks: Whether to cache loaded chunks in memory (the cache is
            unbounded: every chunk that has been read stays resident)
        max_buffer_size: If given, buffer tensors ``xb``/``yb`` are truncated
            to at most this many points along dimension 1

    Raises:
        FileNotFoundError: If neither metadata.json nor metadata.pt exists in
            ``data_dir``.
    """

    # Names of the six tensors stored for every batch.
    _FIELDS = ("xc", "yc", "xb", "yb", "xt", "yt")

    def __init__(self, data_dir: Path, device: str = "cpu", cache_chunks: bool = True, max_buffer_size: Optional[int] = None):
        self.data_dir = Path(data_dir)
        self.device = device
        self.cache_chunks = cache_chunks
        self.max_buffer_size = max_buffer_size

        # Load metadata from JSON (with fallback to .pt for backward compatibility)
        metadata_json = self.data_dir / "metadata.json"
        metadata_pt = self.data_dir / "metadata.pt"

        if metadata_json.exists():
            with open(metadata_json, encoding="utf-8") as f:
                self.metadata = json.load(f)
        elif metadata_pt.exists():
            self.metadata = torch.load(metadata_pt, map_location=device)
        else:
            raise FileNotFoundError(f"No metadata file found in {self.data_dir}")

        self.num_batches = self.metadata["num_batches"]
        self.batch_size = self.metadata["batch_size"]

        # The presence of "chunk_size" in the metadata marks the chunked format.
        self.chunked = "chunk_size" in self.metadata
        # Defined in both layouts so attribute access never fails.
        self.chunk_size = None
        self.num_chunks = None
        self._chunk_cache = None
        if self.chunked:
            self.chunk_size = self.metadata["chunk_size"]
            self.num_chunks = self.metadata["num_chunks"]
            if cache_chunks:
                self._chunk_cache = {}

    def __len__(self) -> int:
        """Return the number of stored batches."""
        return self.num_batches

    def __getitem__(self, idx: int) -> DataAttr:
        """Get a single batch by index (negative indices count from the end).

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        if idx < 0:
            idx = self.num_batches + idx

        if idx < 0 or idx >= self.num_batches:
            raise IndexError(
                f"Index {idx} out of range for dataset with {self.num_batches} batches"
            )
        return self._load_batch(idx)

    def __iter__(self) -> Iterator[DataAttr]:
        """Iterate through all batches in order."""
        for i in range(self.num_batches):
            yield self._load_batch(i)

    @staticmethod
    def _as_float32(tensor):
        """Downcast float64 tensors to float32; pass anything else (incl. None) through."""
        if tensor is not None and tensor.dtype == torch.float64:
            return tensor.float()
        return tensor

    def _load_chunk(self, chunk_idx: int):
        """Return the list of batch dicts stored in chunk ``chunk_idx``, using the cache if enabled."""
        cache = self._chunk_cache
        if cache is not None and chunk_idx in cache:
            return cache[chunk_idx]

        chunk = torch.load(
            self.data_dir / f"chunk_{chunk_idx:04d}.pt",
            map_location=self.device,
        )
        if cache is not None:
            cache[chunk_idx] = chunk
        return chunk

    def _load_batch(self, idx: int) -> DataAttr:
        """Load a single batch and return as DataAttr with all six components."""
        if self.chunked:
            # Chunk index and position of the batch inside that chunk.
            chunk_idx, batch_idx = divmod(idx, self.chunk_size)
            stored = self._load_chunk(chunk_idx)[batch_idx]
        else:
            # Legacy format: one file per batch
            stored = torch.load(
                self.data_dir / f"batch_{idx:06d}.pt", map_location=self.device
            )

        # Work on a shallow copy so truncation never rewrites a cached entry.
        fields = {name: stored[name] for name in self._FIELDS}

        limit = self.max_buffer_size
        xb = fields["xb"]
        if limit is not None and xb is not None and xb.shape[1] > limit:
            # Truncate buffer points if exceeding max_buffer_size
            fields["xb"] = xb[:, :limit, :]
            if fields["yb"] is not None:
                fields["yb"] = fields["yb"][:, :limit, :]

        # Data may have been generated in float64; hand out float32 instead.
        # Fields stored as None (e.g. no buffer) are passed through unchanged.
        return DataAttr(
            xc=self._as_float32(fields["xc"]),
            yc=self._as_float32(fields["yc"]),
            xb=self._as_float32(fields["xb"]),
            yb=self._as_float32(fields["yb"]),
            xt=self._as_float32(fields["xt"]),
            yt=self._as_float32(fields["yt"]),
        )

    def get_info(self) -> dict:
        """Get information about the dataset including batch size."""
        return {
            "num_batches": self.num_batches,
            "batch_size": self.batch_size,
            "chunked": self.chunked,
            "chunk_size": self.chunk_size if self.chunked else None,
            "device": self.device,
        }


class SamplePermutationHelper:
    """Static helpers that shuffle the order of target points and undo the shuffle.

    One random permutation of the target axis is drawn per sample and shared
    by every element of the batch. The ``perm_info`` tuples returned by the
    permute methods hold ``(dim_sample, dim_batch, deperm_ids)``, three index
    tensors of shape [num_samples, batch_size, num_target] that the unpermute
    methods use to restore the original order.
    """

    @staticmethod
    def _build_permutation(num_samples: int, batch_size: int, num_target: int, device):
        """Draw one target permutation per sample and build the gather indices.

        Returns:
            ``(dim_sample, dim_batch, perm_ids, deperm_ids)``, each of shape
            [num_samples, batch_size, num_target]. ``deperm_ids`` is the
            inverse of ``perm_ids`` along the last dimension.
        """
        shape = (num_samples, batch_size, num_target)

        # Sorting random scores yields a random permutation. The sort is done
        # once per sample and then copied across the batch; the copies are
        # materialised so the returned index tensors are contiguous.
        scores = torch.rand(num_samples, num_target, device=device)
        perm = torch.argsort(scores, dim=-1)
        deperm = torch.argsort(perm, dim=-1)
        perm_ids = perm.unsqueeze(1).expand(shape).contiguous()
        deperm_ids = deperm.unsqueeze(1).expand(shape).contiguous()

        dim_sample = torch.arange(num_samples, device=device).view(-1, 1, 1).expand(shape)
        dim_batch = torch.arange(batch_size, device=device).view(1, -1, 1).expand(shape)
        return dim_sample, dim_batch, perm_ids, deperm_ids

    @staticmethod
    def _expand_leading(tensor: torch.Tensor, num_samples: int, batch_size: int, num_target: int):
        """Broadcast ``tensor`` to [num_samples, batch_size, num_target, D] when required.

        A tensor whose first three dimensions are already large enough to be
        indexed is returned untouched, so inputs that worked before behave
        exactly as before. Anything else (e.g. [1, B, T, D] or [B, T, D]) is
        expanded; ``expand`` raises RuntimeError if that is impossible.
        """
        lead = (num_samples, batch_size, num_target)
        if tensor.dim() >= 3 and all(
            have >= want for have, want in zip(tensor.shape, lead)
        ):
            return tensor
        return tensor.expand(*lead, tensor.shape[-1])

    @staticmethod
    def permute_targets(xt: torch.Tensor, yt: torch.Tensor, num_samples: int, batch_size: int, num_target: int):
        """
        Permute target order for each sample.

        Args:
            xt: Target inputs [num_samples, batch_size, num_target, Dx], or can be expanded to this shape
            yt: Target outputs [num_samples, batch_size, num_target, Dy], or can be expanded to this shape;
                may be None, in which case None is returned in its place
            num_samples: Number of samples
            batch_size: Size of the batch
            num_target: Number of target points

        Returns:
            ``(xt_perm, yt_perm, perm_info)`` where ``perm_info`` can be passed
            to ``unpermute_targets`` to restore the original target order.
        """
        helper = SamplePermutationHelper
        dim_sample, dim_batch, perm_ids, deperm_ids = helper._build_permutation(
            num_samples, batch_size, num_target, xt.device
        )

        xt_full = helper._expand_leading(xt, num_samples, batch_size, num_target)
        xt_perm = xt_full[dim_sample, dim_batch, perm_ids]

        yt_perm = None
        if yt is not None:
            yt_full = helper._expand_leading(yt, num_samples, batch_size, num_target)
            yt_perm = yt_full[dim_sample, dim_batch, perm_ids]

        perm_info = (dim_sample, dim_batch, deperm_ids)
        return xt_perm, yt_perm, perm_info

    @staticmethod
    def unpermute_targets(tensor1: torch.Tensor, tensor2: Optional[torch.Tensor], perm_info: tuple):
        """Unpermute targets back to original order.

        Args:
            tensor1: Tensor whose first three dimensions are
                [num_samples, batch_size, num_target] in permuted order.
            tensor2: Optional second tensor with the same leading layout.
            perm_info: Tuple returned by ``permute_targets``.

        Returns:
            ``(tensor1_unperm, tensor2_unperm)``; the second entry is None
            when ``tensor2`` is None.
        """
        dim_sample, dim_batch, deperm_ids = perm_info

        tensor1_unperm = tensor1[dim_sample, dim_batch, deperm_ids]
        if tensor2 is None:
            return tensor1_unperm, None

        return tensor1_unperm, tensor2[dim_sample, dim_batch, deperm_ids]

    @staticmethod
    def repeat_and_permute_batch(batch: DataAttr, num_samples: int):
        """
        Repeat all data at the batch dimension for num_samples times (no additional dimension), permute target points for individual batch & repetition.

        The output batch dimension is sample-major: row ``s * batch_size + b``
        holds repetition ``s`` of original batch element ``b``.

        Args:
            batch.xc: Context inputs [batch_size, num_context, Dx]
            batch.yc: Context outputs [batch_size, num_context, Dy]
            batch.xb: Buffer points [batch_size, num_buffer, Dx] (may be None)
            batch.yb: Buffer outputs [batch_size, num_buffer, Dy] (may be None)
            batch.xt: Target inputs [batch_size, num_target, Dx]
            batch.yt: Target outputs [batch_size, num_target, Dy] (may be None)
            num_samples: Number of samples (repeat for this number)

        Returns:
            Tuple of augmented batch and permutation info.

        """
        batch_size, num_target = batch.xt.shape[0], batch.xt.shape[1]
        dim_sample, dim_batch, perm_ids, deperm_ids = (
            SamplePermutationHelper._build_permutation(
                num_samples, batch_size, num_target, batch.xt.device
            )
        )

        def shuffle(targets):
            # [B, T, D] -> [num_samples*B, T, D] with a per-sample target order.
            if targets is None:
                return None
            stacked = targets.unsqueeze(0).expand(num_samples, -1, -1, -1)
            shuffled = stacked[dim_sample, dim_batch, perm_ids]
            return shuffled.reshape(-1, num_target, targets.shape[-1])

        def tile(points):
            # [B, N, D] -> [num_samples*B, N, D]; None is passed through.
            return None if points is None else points.repeat(num_samples, 1, 1)

        batch_out = DataAttr(
            xc=batch.xc.repeat(num_samples, 1, 1),
            yc=batch.yc.repeat(num_samples, 1, 1),
            xb=tile(batch.xb),
            yb=tile(batch.yb),
            xt=shuffle(batch.xt),
            yt=shuffle(batch.yt),
        )

        perm_info = (dim_sample, dim_batch, deperm_ids)
        return batch_out, perm_info

    @staticmethod
    def unpermute_batch(batch: DataAttr, perm_info: tuple):
        """Unpermute the batch back to the original order.

        Only ``xt`` and ``yt`` are reordered; context and buffer tensors are
        returned as they are.

        Args:
            batch: Batch whose ``xt``/``yt`` have shape
                [num_samples*batch_size, num_target, D] in permuted order.
            perm_info: Tuple returned by ``repeat_and_permute_batch``.

        Returns:
            A new DataAttr with the targets in their original order.
        """
        dim_sample, dim_batch, deperm_ids = perm_info
        num_samples, batch_size, num_target = deperm_ids.shape

        def restore(targets):
            if targets is None:
                return None
            # Use the tensor's own feature width rather than the context's, so
            # targets whose last dimension differs from xc/yc still reshape.
            width = targets.shape[-1]
            stacked = targets.reshape(num_samples, batch_size, num_target, width)
            return stacked[dim_sample, dim_batch, deperm_ids].reshape(
                -1, num_target, width
            )

        batch_out = DataAttr(
            xc=batch.xc,
            yc=batch.yc,
            xb=batch.xb,
            yb=batch.yb,
            xt=restore(batch.xt),
            yt=restore(batch.yt),
        )

        return batch_out
