import torch
import numpy as np

class LabTransDiscreteTime:
    """Discretize continuous durations into a fixed number of time bins.

    ``fit`` learns ``num_durations + 1`` bin edges from training durations,
    either at evenly spaced quantiles (``scheme='quantile'``) or evenly
    spaced between the minimum and maximum (``scheme='uniform'``).
    ``transform`` maps each duration to a 1-based bin index in
    ``[1, num_durations + extra_bins]``.

    Args:
        num_durations: Number of regular bins to learn.
        scheme: ``'quantile'`` or ``'uniform'``.
        extra_bins: Number of bins appended after the regular ones. Indices
            produced by ``transform`` are clamped to
            ``num_durations + extra_bins``, and the one-hot encodings have
            that many columns.

    Attributes:
        bin_edges_: Sorted, unique 1-D tensor of bin edges after ``fit``;
            ``None`` before fitting.
    """

    def __init__(self, num_durations, scheme='quantile', extra_bins=1):
        if scheme not in ['quantile', 'uniform']:
            raise ValueError("scheme must be 'quantile' or 'uniform'")
        self.num_durations = num_durations
        self.scheme = scheme
        self.extra_bins = extra_bins
        self.bin_edges_ = None

    def _check_fitted(self):
        """Raise a clear error if ``fit`` has not been called yet."""
        if self.bin_edges_ is None:
            raise RuntimeError(
                "LabTransDiscreteTime is not fitted; call fit() first."
            )

    def fit(self, durations):
        """Learn bin edges from ``durations``.

        Args:
            durations: torch.Tensor, numpy.ndarray or other array-like of
                durations; it is flattened before use.

        Returns:
            self, to allow chaining.

        Raises:
            ValueError: if ``durations`` is empty, or if the learned edges
                contain fewer than ``num_durations + 1`` distinct values.
        """
        if isinstance(durations, torch.Tensor):
            d = durations.detach()
        else:
            d = torch.as_tensor(durations)
        d = d.flatten()
        if d.numel() == 0:
            raise ValueError("durations must contain at least one value")
        # torch.quantile only supports float32/float64. Keep float64 input in
        # float64 so edges that coincide with data values stay exact; anything
        # else (ints, bools, half precision) is computed in float32.
        if d.dtype not in (torch.float32, torch.float64):
            d = d.to(torch.float32)

        n_edges = self.num_durations + 1
        if self.scheme == 'quantile':
            # q must share the input's dtype, or torch.quantile raises.
            q = torch.linspace(0, 1, n_edges, dtype=d.dtype, device=d.device)
            edges = torch.quantile(d, q)
        else:  # 'uniform' (validated in __init__)
            lo = d.min().item()
            hi = d.max().item()
            # The small offset nudges the last edge past the maximum so it
            # falls in the last regular bin; it only survives in float64.
            edges = torch.linspace(
                lo, hi + 1e-8, n_edges, dtype=d.dtype, device=d.device
            )

        # Duplicate edges (e.g. heavily tied durations) would create empty
        # bins, so they are rejected rather than silently merged.
        unique_edges = torch.unique(edges)
        if unique_edges.numel() - 1 < self.num_durations:
            raise ValueError("Not enough unique bin edges. Reduce num_durations.")
        self.bin_edges_ = unique_edges

        return self

    def transform(self, durations):
        """Map durations to 1-based bin indices.

        A duration ``d`` gets index ``1 + (number of right edges <= d)``,
        where the right edges are ``bin_edges_[1:]``. Durations below
        ``bin_edges_[1]`` get index 1, and durations at or above the last
        edge get ``len(bin_edges_)``. The result is clamped to
        ``[1, num_durations + extra_bins]``.

        Args:
            durations: torch.Tensor or numpy.ndarray; it is flattened.

        Returns:
            int64 indices of the same container type as ``durations``
            (a tensor on the input's device, or a numpy array).

        Raises:
            RuntimeError: if the transformer has not been fitted.
            TypeError: if ``durations`` is neither a tensor nor an ndarray.
        """
        self._check_fitted()
        max_bin = self.num_durations + self.extra_bins

        # PyTorch path
        if isinstance(durations, torch.Tensor):
            d = durations.flatten()
            # Integer durations are compared in float so fractional edges
            # are not truncated by a cast to the integer dtype.
            if not d.is_floating_point():
                d = d.to(torch.float64)
            bin_rights = torch.as_tensor(self.bin_edges_[1:]).to(
                device=d.device, dtype=d.dtype
            )
            duration_idx = (d.unsqueeze(1) >= bin_rights).sum(dim=1) + 1
            return duration_idx.clamp(1, max_bin)

        # NumPy path
        if isinstance(durations, np.ndarray):
            d = durations.reshape(-1)
            if not np.issubdtype(d.dtype, np.floating):
                d = d.astype(np.float64)
            edges = self.bin_edges_
            if isinstance(edges, torch.Tensor):
                # Edges may live on an accelerator; numpy needs host memory.
                edges = edges.detach().cpu().numpy()
            bin_rights = np.asarray(edges[1:], dtype=d.dtype)
            duration_idx = (d[:, None] >= bin_rights).sum(axis=1) + 1
            return np.clip(duration_idx, 1, max_bin).astype(np.int64)

        raise TypeError("durations must be a torch.Tensor or numpy.ndarray")

    def transform_one_hot(self, durations):
        """Return ``one_hot(transform(durations))``."""
        duration_idx = self.transform(durations)
        return self.one_hot(duration_idx)

    def one_hot(self, duration_idx):
        """One-hot encode 1-based bin indices.

        Args:
            duration_idx: 1-D tensor or array of 1-based indices.

        Returns:
            Long tensor of shape ``(n, num_durations + extra_bins)``, where
            row ``i`` has a 1 in column ``duration_idx[i] - 1``. Rows whose
            index is outside ``[1, num_durations + extra_bins]`` are all zero.
        """
        idx = torch.as_tensor(duration_idx).reshape(-1).long()
        max_bins = self.num_durations + self.extra_bins
        out = torch.zeros(
            (idx.numel(), max_bins), dtype=torch.long, device=idx.device
        )
        valid = (idx >= 1) & (idx <= max_bins)
        rows = torch.arange(idx.numel(), device=idx.device)[valid]
        out[rows, idx[valid] - 1] = 1
        return out

    def cumulative_one_hot(self, duration_idx):
        """Encode 1-based bin indices as cumulative (prefix) indicators.

        Args:
            duration_idx: 1-D tensor or array of 1-based indices.

        Returns:
            Long tensor of shape ``(n, num_durations + extra_bins)``, where
            row ``i`` has ones in the first ``duration_idx[i]`` columns.
            Non-positive indices give all-zero rows, and indices beyond the
            column count give all-one rows.
        """
        idx = torch.as_tensor(duration_idx).reshape(-1).long()
        max_bins = self.num_durations + self.extra_bins
        columns = torch.arange(1, max_bins + 1, device=idx.device)
        return (columns.unsqueeze(0) <= idx.unsqueeze(1)).long()

    def inverse_transform(self, duration_idx):
        """Map 1-based bin indices back to each bin's right edge.

        Indices are clamped to ``[1, num_durations]``, so extra-bin indices
        map to the last edge.

        Args:
            duration_idx: tensor or array of 1-based indices.

        Returns:
            Tensor of ``bin_edges_[idx]`` values on the edges' device.

        Raises:
            RuntimeError: if the transformer has not been fitted.
        """
        self._check_fitted()
        edges = torch.as_tensor(self.bin_edges_)
        idx = torch.as_tensor(duration_idx, device=edges.device).long()
        idx = idx.clamp(1, self.num_durations)
        return edges[idx]

    def fit_transform(self, durations):
        """Fit on ``durations`` and return their bin indices."""
        return self.fit(durations).transform(durations)

    def get_bin_edges(self):
        """Return the learned bin edges (``None`` before ``fit``)."""
        return self.bin_edges_
