import torch
from einops import rearrange
from typing import List, Optional, Tuple
import matplotlib.pyplot as plt
from src.utils import DataAttr
from src.data.utils import sample_sobol
import random
import numpy as np
from pathlib import Path


class SawtoothSampler:
    """
    Batched sawtooth- or triangular-wave generator.

    Each latent output is a periodic wave of a random 1-D projection of the
    inputs::

        phase = frac(freq * (dir . x - offset))      # in [0, 1)
        f(x)  = shape(phase)                         # controlled by `peak`

    where ``dir`` is a random unit vector in input space, ``freq`` is drawn
    uniformly from ``freq_range`` and ``offset`` is a random phase shift of
    up to one period.

    Args:
        freq_range: Tuple[min, max] for uniform sampling of frequency per latent dim.
        x_range: Bounds for input domain as ``[[min_1, ..., min_d], [max_1, ..., max_d]]``.
            Defaults to ``[[-2.0], [2.0]]`` (1-D inputs on [-2, 2]).
        noise_range: Tuple[min, max] for uniform sampling of the observation
            noise std (one std per batch element).
        num_latents: Number of latent output dimensions.
        peak: Float in [0, 1], fraction of period at which the wave peaks.
              peak=1.0 → rising sawtooth; peak=0.5 → symmetric triangular wave;
              peak=0.0 → falling (reverse) sawtooth.
        jitter: Not used for sawtooth but kept for API compatibility.
        device: Computation device.
        dtype: Tensor dtype.

    Raises:
        ValueError: If ``peak`` is outside [0, 1].
    """

    def __init__(
        self,
        freq_range: Tuple[float, float] = (3.0, 5.0),
        x_range: Optional[List[List[float]]] = None,
        noise_range: Tuple[float, float] = (0.0, 0.1),
        num_latents: int = 1,
        peak: float = 0.0,
        jitter: float = 0.0,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
    ):
        if x_range is None:
            x_range = [[-2.0], [2.0]]  # Default 1D

        peak = float(peak)
        # Written as `not (lo <= p <= hi)` so NaN is rejected as well.
        if not 0.0 <= peak <= 1.0:
            raise ValueError(f"`peak` must lie in [0, 1], got {peak}.")

        self.freq_min, self.freq_max = freq_range
        # Row 0 holds the lower bounds, row 1 the upper bounds.
        self.x_bounds = torch.tensor(x_range, device=device, dtype=dtype)
        self.noise_min, self.noise_max = noise_range
        self.num_latents = num_latents
        self.peak = peak
        self.jitter = jitter  # unused; kept for API compatibility
        self.device = device
        self.dtype = dtype

    def _sample_num_context(self, context_range):
        """Draw the number of context points.

        A length-2 ``context_range`` is read as an inclusive ``[low, high]``
        integer range. A longer sequence is read as an explicit list of
        candidates, one of which is chosen uniformly. Uses the global torch RNG.

        Raises:
            ValueError: If ``low > high`` or ``context_range`` has fewer than 2 entries.
        """
        n = len(context_range)
        if n == 2:
            low, high = context_range
            if low > high:
                raise ValueError("Invalid `context_range`: low > high.")
            # randint's upper bound is exclusive, hence high + 1.
            return torch.randint(low, high + 1, (1,)).item()

        if n > 2:
            choices = torch.as_tensor(context_range)
            idx = torch.randint(0, choices.numel(), (1,))
            return choices[idx].item()
        raise ValueError("`context_range` must have length 2 or > 2.")

    def _shape_waveform(self, phase: torch.Tensor) -> torch.Tensor:
        """Map a phase in [0, 1) to the wave value according to ``self.peak``.

        The wave rises linearly from 0 to 1 on ``[0, peak)`` and falls linearly
        from 1 back towards 0 on ``[peak, 1)``. The degenerate cases
        ``peak == 1`` (rising sawtooth) and ``peak == 0`` (falling sawtooth)
        are returned directly. ``torch.where`` evaluates both branches, so
        dividing by ``peak`` or ``1 - peak`` there would produce inf/NaN
        values that poison gradients even though they are never selected.
        """
        if self.peak >= 1.0:
            return phase
        if self.peak <= 0.0:
            return 1.0 - phase
        rising = phase / self.peak
        falling = (1.0 - phase) / (1.0 - self.peak)
        return torch.where(phase < self.peak, rising, falling)

    def generate_batch(
        self,
        batch_size: int,
        num_context: Optional[int] = None,
        num_buffer: int = 50,
        num_target: int = 50,
        context_range: Tuple[int, int] = (3, 47),
    ):
        """
        Sample a batch of wave functions and split their points into
        context / buffer / target sets.

        Args:
            batch_size: Number of independent functions to sample.
            num_context: Number of context points. If None, it is drawn with
                ``_sample_num_context(context_range)`` and shared by the whole batch.
            num_buffer: Number of buffer points per function.
            num_target: Number of target points per function.
            context_range: Used only when ``num_context`` is None.

        Returns:
            DataAttr with fields xc, yc, xb, yb, xt, yt. The ``x*`` tensors have
            shape ``(batch_size, n, x_dim)`` and the ``y*`` tensors
            ``(batch_size, n, num_latents)``, where ``n`` is the size of the
            corresponding split. Every function in the batch uses the same
            random assignment of point indices to splits.
        """
        if num_context is None:
            num_context = self._sample_num_context(context_range)

        num_total = num_context + num_buffer + num_target
        x_dim = self.x_bounds.shape[1]

        # Sample inputs using Sobol sequence; each batch element receives a
        # contiguous chunk of num_total points from one long sequence.
        x_all = sample_sobol(
            batch_size * num_total,
            list(self.x_bounds[0].cpu().numpy()),
            list(self.x_bounds[1].cpu().numpy()),
        )
        x = torch.tensor(x_all, device=self.device, dtype=self.dtype)
        x = rearrange(x, "(b n) d -> b n d", b=batch_size, n=num_total)

        # Sample frequencies per batch and latent: (b, latent)
        freqs = self.freq_min + (self.freq_max - self.freq_min) * torch.rand(
            batch_size, self.num_latents, device=self.device, dtype=self.dtype
        )

        # Sample random unit directions in input space: (b, latent, x_dim)
        dirs = torch.randn(
            batch_size, self.num_latents, x_dim, device=self.device, dtype=self.dtype
        )
        dirs = dirs / dirs.norm(dim=2, keepdim=True)

        # Offsets in [0, 1/freq): a phase shift of up to one full period.
        offsets = torch.rand(
            batch_size, self.num_latents, 1, device=self.device, dtype=self.dtype
        ) / freqs.unsqueeze(-1)

        # Raw sawtooth phase in [0, 1) (float remainder takes the divisor's sign).
        proj = torch.matmul(dirs, x.transpose(1, 2))  # -> (b, latent, num_total)
        f_saw = (freqs.unsqueeze(-1) * (proj - offsets)) % 1.0

        # Shape waveform by peak location
        f = self._shape_waveform(f_saw)

        # Add Gaussian noise with one std per batch element.
        noise_std = self.noise_min + (self.noise_max - self.noise_min) * torch.rand(
            batch_size, device=self.device, dtype=self.dtype
        )
        noise = noise_std.view(batch_size, 1, 1) * torch.randn_like(f)
        y = f + noise
        y = y.permute(0, 2, 1)  # -> (b, num_total, latent)

        # Permute and split indices (one permutation shared across the batch)
        perm = torch.randperm(num_total, device=self.device)
        ctx_idx = perm[:num_context]
        buf_idx = perm[num_context : num_context + num_buffer]
        tar_idx = perm[num_context + num_buffer :]

        xc = x[:, ctx_idx]
        yc = y[:, ctx_idx]
        xb = x[:, buf_idx]
        yb = y[:, buf_idx]
        xt = x[:, tar_idx]
        yt = y[:, tar_idx]

        return DataAttr(xc=xc, yc=yc, xb=xb, yb=yb, xt=xt, yt=yt)


if __name__ == "__main__":
    # Simple visualization and a smoke test of the offline data pipeline.
    seed = 999
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    def plot_batch(batch, save=False):
        """Plot every sample of ``batch``: a line through all points plus
        markers distinguishing context, buffer and target points.

        Assumes 1-D inputs and a single latent output (arrays are squeezed
        before plotting). If ``save`` is true, figure ``i`` is written to
        ``sawtooth_{i}.png`` in the working directory.
        """
        xc, yc = batch.xc, batch.yc
        xb, yb = batch.xb, batch.yb
        xt, yt = batch.xt, batch.yt
        n_c, n_b = xc.shape[1], xb.shape[1]

        for i in range(xc.shape[0]):
            # Concatenate in the order context, buffer, target so the segment
            # masks below can be derived from the segment lengths alone.
            x_all = torch.cat([xc[i], xb[i], xt[i]], dim=0).cpu().numpy().squeeze()
            y_all = torch.cat([yc[i], yb[i], yt[i]], dim=0).cpu().numpy().squeeze()

            # Segment masks as NumPy arrays, matching the arrays they index.
            mask_c = np.zeros(x_all.shape, dtype=bool)
            mask_b = np.zeros(x_all.shape, dtype=bool)
            mask_t = np.zeros(x_all.shape, dtype=bool)
            mask_c[:n_c] = True
            mask_b[n_c : n_c + n_b] = True
            mask_t[n_c + n_b :] = True

            # Sort by x for continuous line
            order = np.argsort(x_all)
            x_sorted = x_all[order]
            y_sorted = y_all[order]

            fig = plt.figure(figsize=(8, 4))
            # Continuous line connecting all points
            plt.plot(x_sorted, y_sorted, "-", label="sawtooth")
            # Overlay markers with styles
            plt.scatter(x_all[mask_c], y_all[mask_c], marker="o", label="context")
            plt.scatter(x_all[mask_b], y_all[mask_b], marker="x", label="buffer")
            plt.scatter(
                x_all[mask_t], y_all[mask_t], marker="+", label="target", alpha=0.5
            )
            plt.legend()
            plt.title(f"Sawtooth Sample {i}")
            if save:
                plt.savefig(f"sawtooth_{i}.png")
            plt.show()
            # Release the figure; otherwise figures accumulate whenever
            # plt.show() does not block (e.g. non-interactive backends).
            plt.close(fig)

    visualize = True
    generate_offline = True

    if visualize:
        sampler = SawtoothSampler(
            freq_range=(3.0, 5.0), noise_range=(0.0, 0.0), num_latents=1, device="cpu"
        )
        batch = sampler.generate_batch(
            batch_size=10, num_context=10, num_buffer=20, num_target=1000
        )

        plot_batch(batch, save=True)

    if generate_offline:
        from src.data.utils import OfflineBatchLoader
        from torch.utils.data import DataLoader
        from src.data.utils import generate_offline_batches

        outputdir = Path("data/sawtooth_test")
        generate_offline_batches(
            save_dir=outputdir,
            num_batches=2,
            batch_size=1,
            sampler_data="sawtooth",
            num_context=None,
            num_buffer=3,
            num_target=1000,
            context_range=[5, 10],
            chunk_size=5,
        )
        print("Offline datagen OK")

        dataset = OfflineBatchLoader(outputdir)
        # batch_size=None: the loader yields pre-batched items as-is.
        dataloader = DataLoader(dataset, batch_size=None)

        for batch in dataloader:
            plot_batch(batch, save=False)
