from types import SimpleNamespace
import torch
from data_utils.dataset import TimeSeriesDataset
from data_utils.utils import MinMaxArgs, MinMaxScaler
from torchaudio import transforms

class TrendResampler:
    """
    A deterministic linear resampler that can downsample to any target length.

    Downsampling finds the least-squares "control points" (the downsampled
    signal) such that, when they are linearly interpolated back to the
    original length, the result is as close as possible to the original signal.

    Upsampling is a fixed linear map (a product with the interpolation
    matrix), so it is deterministic. It reproduces the original signal exactly
    only when that signal lies in the span of the interpolation basis;
    otherwise it is the least-squares approximation.
    """

    def __init__(self, device=None):
        """
        Args:
            device: Device (``torch.device``, string or anything accepted by
                ``torch.device``) the resampler works on. Defaults to CUDA
                when available, else CPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Always store a real torch.device so that later comparisons against
        # ``tensor.device`` behave the same for strings and device objects.
        self.device = torch.device(device)

        # Cache for interpolation matrices to avoid re-computation.
        # Keys are (in_len, out_len, dtype, device).
        self.interpolation_matrices = {}

    def _check_device(self, tensor: torch.Tensor) -> None:
        """
        Raise ``AssertionError`` if `tensor` is not on the resampler's device.

        A device without an explicit index (e.g. ``cuda``) matches any index of
        the same type (e.g. ``cuda:0``); strict equality would reject it.
        An explicit raise is used so the check survives ``python -O``.
        """
        expected = torch.device(self.device)
        actual = tensor.device
        index_matches = (
            expected.index is None
            or actual.index is None
            or expected.index == actual.index
        )
        if actual.type != expected.type or not index_matches:
            raise AssertionError("Tensor must be on the device of the resampler.")

    def _get_interpolation_matrix(
        self, in_len: int, out_len: int, dtype=None, device=None
    ) -> torch.Tensor:
        """
        Creates or retrieves a cached interpolation matrix M of shape [out_len, in_len].
        Left-multiplying a vector of length `in_len` by M linearly interpolates
        it (align_corners=True) to length `out_len`: M @ x_in = x_out.

        Args:
            in_len: Length of the input vector.
            out_len: Length of the interpolated vector.
            dtype: Floating dtype of the matrix; defaults to torch's default dtype.
            device: Device of the matrix; defaults to the resampler's device.
        """
        if dtype is None:
            dtype = torch.get_default_dtype()
        device = torch.device(self.device if device is None else device)

        key = (in_len, out_len, dtype, device)
        matrix = self.interpolation_matrices.get(key)
        if matrix is None:
            # Interpolating the rows of an identity matrix yields the linear
            # operator itself: [in_len, 1, in_len] -> [in_len, 1, out_len].
            identity = torch.eye(in_len, device=device, dtype=dtype).unsqueeze(1)
            interpolated = torch.nn.functional.interpolate(
                identity, size=out_len, mode="linear", align_corners=True
            )
            # Squeeze and transpose to get the desired [out_len, in_len] matrix
            matrix = interpolated.squeeze(1).T
            self.interpolation_matrices[key] = matrix
        return matrix

    def downsample(
        self, trend_tensor: torch.Tensor, target_length: int
    ) -> torch.Tensor:
        """
        Downsamples by projecting the signal onto the interpolation basis.

        Args:
            trend_tensor (torch.Tensor): The high-resolution trend, shape [..., L].
            target_length (int): The desired downsampled length, D.

        Returns:
            torch.Tensor: The optimal control points, shape [..., D], in the
            dtype of `trend_tensor`.

        Raises:
            AssertionError: If `trend_tensor` is not on the resampler's device.
        """
        self._check_device(trend_tensor)
        L = trend_tensor.shape[-1]

        # Upsampling matrix M of shape [L, D], built in the input's dtype so
        # that lstsq does not fail on a dtype mismatch (e.g. float64 inputs).
        M = self._get_interpolation_matrix(
            target_length, L, dtype=trend_tensor.dtype, device=trend_tensor.device
        )

        # Reshape input to 2D for lstsq: [..., L] -> [N, L]
        trend_2d = trend_tensor.reshape(-1, L)

        # Solve the least squares problem M @ x = trend for x (the control points).
        # lstsq expects the right-hand side as [L, N]; the solution is [D, N].
        control_points_2d = torch.linalg.lstsq(M, trend_2d.T).solution.T  # [N, D]

        # Reshape control points back to the original batch/feature shape
        downsampled_shape = trend_tensor.shape[:-1] + (target_length,)
        return control_points_2d.reshape(downsampled_shape)

    def upsample(
        self, control_points_tensor: torch.Tensor, target_length: int
    ) -> torch.Tensor:
        """
        Upsamples deterministically from the control points via matrix multiplication.

        Args:
            control_points_tensor (torch.Tensor): The downsampled signal, shape [..., D].
            target_length (int): The desired upsampled length, L.

        Returns:
            torch.Tensor: The reconstructed high-resolution trend, shape [..., L],
            in the dtype of `control_points_tensor`.

        Raises:
            AssertionError: If `control_points_tensor` is not on the resampler's device.
        """
        self._check_device(control_points_tensor)
        D = control_points_tensor.shape[-1]

        # Upsampling matrix M of shape [L, D], matching the input's dtype/device.
        M = self._get_interpolation_matrix(
            D,
            target_length,
            dtype=control_points_tensor.dtype,
            device=control_points_tensor.device,
        )

        # Reshape control points to 2D for matmul: [..., D] -> [N, D]
        control_points_2d = control_points_tensor.reshape(-1, D)

        # M: [L, D], control_points.T: [D, N] -> [L, N]; transpose back to [N, L]
        upsampled_2d = (M @ control_points_2d.T).T

        # Reshape back to original batch/feature shape
        upsampled_shape = control_points_tensor.shape[:-1] + (target_length,)
        return upsampled_2d.reshape(upsampled_shape)
    
def kth_quantile(x, q):
    """
    Return the empirical q-quantile of `x` along its first dimension.

    The quantile is the k-th smallest value over dim 0 with
    ``k = int(q * N)`` clamped to ``[1, N]`` (no interpolation between
    order statistics).

    Args:
        x (torch.Tensor): Shape (N, ...) with N >= 1, e.g. (N, F, K).
        q (float): Quantile level; values that map outside [1, N] are clamped.

    Returns:
        torch.Tensor: Shape ``x.shape[1:]`` (a 0-dim tensor for 1-D input).
    """
    N = x.shape[0]
    # Clamp on both sides: k < 1 and k > N are both invalid for kthvalue.
    k = min(N, max(1, int(q * N)))
    vals, _ = x.reshape(N, -1).kthvalue(k, dim=0)  # (prod(x.shape[1:]),)
    # reshape (not view(*shape)) so that 1-D inputs, whose trailing shape is
    # empty, return a scalar tensor instead of raising.
    return vals.reshape(x.shape[1:])

# STFT DATASET WITH TREND SCALED SEPARATELY FROM REAL
class STFTDataset(TimeSeriesDataset):
    """
    Time-series dataset whose samples are STFT "videos" of shape
    (B, T, C, F, K): T STFT frames, F frequency bins, K covariates and
    C channels ordered [real, imag], or [trend, real, imag] when
    ``self.decompose`` is truthy.

    The complex STFT coefficients are scaled by a per-(F, K) radius quantile,
    while the trend channel is min-max scaled separately.

    NOTE(review): ``decompose``, ``seq_len``, ``device``, ``data_name``,
    ``logger`` and ``timefreq_data`` are not defined here; they are presumably
    provided by ``TimeSeriesDataset`` -- confirm against the base class.
    """

    def __init__(
        self,
        args: SimpleNamespace,
        logger=None,
    ):
        """
        Args:
            args: Configuration namespace; ``n_fft`` and ``hop_length`` are read
                here, the rest is forwarded to the base class.
            logger: Forwarded to the base class.
        """
        self.n_fft = args.n_fft
        self.hop_length = args.hop_length

        # Normalization / cached state, filled lazily by the methods below.
        self.cplx_scale = None
        self.min_trend, self.max_trend = None, None
        self.trend_resampler = None
        self.cov_adj = None
        self.freq_adj = None

        # Attributes above are set first; presumably the base constructor
        # already calls methods of this class that rely on them -- TODO confirm.
        super().__init__(args, logger)

    def class_params_tosave_and_load(self):
        """Extend the base list with the resampler and the cached adjacency matrices."""
        return super().class_params_tosave_and_load() + ["trend_resampler"] + ["cov_adj", "freq_adj"]

    def norm_params_tosave_and_load(self):
        """Extend the base list with the trend min/max and the complex scale."""
        base_list = super().norm_params_tosave_and_load()
        return base_list + [
            "min_trend",
            "max_trend",
            "cplx_scale",
        ]

    def compute_norm_params(self, real: torch.Tensor, imag: torch.Tensor, trend : torch.Tensor = None):
        """
        Compute and store normalization parameters, unless already set.

        Nothing happens when ``self.cplx_scale`` is already populated.

        Args:
            real, imag: Real / imaginary STFT parts, shape (B, T, F, K).
            trend: Trend channel, shape (B, T, F, K); required when
                ``self.decompose`` is truthy.
        """
        if self.cplx_scale is not None:
            return

        self.logger.info(
            f"Computing {self.data_name} normalization parameters from scratch"
        )

        _, _, F, K = real.shape
        if self.decompose:
            # reshape (instead of view) also works for non-contiguous slices.
            _, min_trend, max_trend = MinMaxScaler(trend.reshape(-1, F, K), True)
            self.min_trend = min_trend
            self.max_trend = max_trend

        rr = real.reshape(-1, F, K)
        ii = imag.reshape(-1, F, K)

        # Scale by the 0.995 quantile of the complex magnitude per (F, K) bin
        # rather than the maximum, so that outliers do not dominate the scale.
        radius = torch.sqrt(rr**2 + ii**2)  # (N, F, K)
        scale = kth_quantile(radius, 0.995)
        # Avoid division by (near) zero for silent bins.
        self.cplx_scale = torch.clamp(scale, min=1e-8)

    def normalize_timefreq_data(self, torch_data: torch.Tensor):
        """
        Normalize an STFT video of shape (B, T, C, F, K).

        The trend (if present) is min-max scaled and mapped via ``(x - 0.5) * 2``;
        real/imag are divided by ``cplx_scale`` and clamped to [-1, 1].
        Normalization parameters are computed on first use.
        """
        real, imag, trend = self._extract_channels(torch_data)
        B, T, F, K = real.shape
        self.compute_norm_params(real, imag, trend)
        if trend is not None:
            trend = (MinMaxArgs(trend.reshape(-1, F, K), self.min_trend, self.max_trend) - 0.5) * 2
            trend = trend.reshape(B, T, F, K)
        s = self.cplx_scale  # (F, K)
        real = (real.reshape(-1, F, K) / s).clamp(-1, 1).view(B, T, F, K)
        imag = (imag.reshape(-1, F, K) / s).clamp(-1, 1).view(B, T, F, K)
        return self._prepare_video(real, imag, trend)  # B * T * C * F * K

    def unnormalize_timefreq_data(self, torch_data: torch.Tensor):
        """
        Invert `normalize_timefreq_data` (up to the clamping applied there).

        Raises:
            ValueError: If the normalization parameters have not been computed.
        """
        if self.cplx_scale is None:
            raise ValueError(
                "Normalization parameters are not set. Please compute them first."
            )

        real, imag, trend = self._extract_channels(torch_data)
        if trend is not None:
            # Undo ``(x - 0.5) * 2`` and then the min-max scaling.
            trend = (trend / 2 + 0.5) * (self.max_trend - self.min_trend) + self.min_trend

        real = real * self.cplx_scale
        imag = imag * self.cplx_scale
        return self._prepare_video(real, imag, trend)  # B * T * C * F * K

    def _extract_channels(self, timefreq_data):
        """
        Split a (B, T, C, F, K) video into ``(real, imag, trend)``, each (B, T, F, K).

        `trend` is None when ``self.decompose`` is falsy; otherwise it is
        channel 0 and real/imag are channels 1 and 2.
        """
        real_idx = 0 if not self.decompose else 1
        imag_idx = 1 if not self.decompose else 2

        trend = timefreq_data[:, :, 0, :, :] if self.decompose else None
        real = timefreq_data[:, :, real_idx, :, :]
        imag = timefreq_data[:, :, imag_idx, :, :]
        return real, imag, trend  # B, T, F, K

    def _prepare_video(self, real, imag, trend=None):
        """
        Stack (B, T, F, K) channels into a (B, T, C, F, K) video.

        Channel order is [real, imag], or [trend, real, imag] when `trend` is given.
        """
        tensor_list = [real, imag] if trend is None else [trend, real, imag]
        video = torch.stack(tensor_list, dim=2)
        return video

    def compute_timestamp_data(self, timecovariates):
        """
        Compute time covariates via the base class and linearly interpolate
        them along the sequence axis to the number of STFT frames
        (``self.timefreq_data.shape[1]``).

        Returns a tensor laid out as (n_samples, stft_frames, n_covariates),
        assuming the base implementation returns
        (n_samples, seq_len, n_covariates) -- TODO confirm.
        """
        timecovariates = super().compute_timestamp_data(
            timecovariates
        )  # (n_samples, seq_len, n_covariates)
        timecovariates = timecovariates.permute(
            0, 2, 1
        )  # (n_samples, n_covariates, seq_len)
        # Interpolate the time covariates to match the STFT sequence length.
        stft_L = self.timefreq_data.shape[1]
        return torch.nn.functional.interpolate(
            timecovariates, size=stft_L, mode="linear", align_corners=True
        ).permute(
            0, 2, 1
        )  # (n_samples, stft_L, n_covariates)

    def get_frequencies_centers(self):
        """
        Returns the one-sided frequency bin centers as a 1D tensor of
        ``n_fft // 2 + 1`` values in normalized units (cycles/sample, in [0, 0.5]).
        """
        d = 1.0  # sample spacing for torch.fft.rfftfreq -> cycles/sample
        # one sided is True by default in torchaudio.transforms.Spectrogram
        f = torch.fft.rfftfreq(self.n_fft, d=d)
        return f

    def _compute_freq_adj(
            self,
            X: torch.Tensor,
            use_log: bool = True,
            zscore: bool = True,
            shrink_lambda: float = 0.05,
            eps: float = 1e-8,
        ) -> torch.Tensor:
        """
        Build a (F,F) frequency affinity matrix from STFT data and cache it
        in ``self.freq_adj``.

        Args
        ----
        X : (B, T, C, F, K)   where C = [trend, real, imag] if ``self.decompose``
            else [real, imag]
        use_log : if True, use log1p(magnitude); else use magnitude
        zscore : if True, per-frequency z-normalize over samples (B*T*K)
        shrink_lambda : diagonal shrinkage toward I (stabilizes small/imbalanced sets)
        eps : numeric stability

        Returns
        -------
        aff : (F, F) symmetric tensor, clamped to [-1, 1].
        """
        assert X.dim() == 5, "X must be (B, T, C, F, K)"
        B, T, C, Freq, K = X.shape

        # The trend channel only exists in decomposed mode; requiring 3
        # channels unconditionally would reject every non-decomposed dataset.
        if self.decompose:
            assert C >= 3, "Expecting channels: trend, real, imag"
            real = X[:, :, 1, :, :]  # (B, T, F, K)
            imag = X[:, :, 2, :, :]
        else:
            assert C >= 2, "Expecting channels: real, imag"
            real = X[:, :, 0, :, :]  # (B, T, F, K)
            imag = X[:, :, 1, :, :]

        mag = torch.sqrt(real * real + imag * imag + eps)  # (B, T, F, K)

        if use_log:
            mag = torch.log1p(mag)  # log(1 + |X|)

        # One row per (sample, frame, covariate), one column per frequency.
        Z = mag.permute(0, 1, 3, 2).reshape(-1, Freq)  # (B*T*K, F)

        if zscore:
            mu = Z.mean(dim=0, keepdim=True)
            sd = Z.std(dim=0, keepdim=True)
            Z = (Z - mu) / (sd + eps)

        N = max(1, Z.shape[0] - 1)
        C_ff = (Z.T @ Z) / N  # (F, F)
        # e.g. std of a single row is NaN; replace non-finite entries by 0.
        C_ff = torch.nan_to_num(C_ff, nan=0.0, posinf=0.0, neginf=0.0)
        C_ff = 0.5 * (C_ff + C_ff.T)  # ensure symmetry

        if shrink_lambda > 0:
            eye = torch.eye(Freq, device=C_ff.device, dtype=C_ff.dtype)
            C_ff = (1.0 - shrink_lambda) * C_ff + shrink_lambda * eye

        C_ff = C_ff.clamp(min=-1.0, max=1.0)

        self.freq_adj = C_ff
        return C_ff

    def _compute_covariate_adj(self, X, eps=1e-6, power_weight=True):
        """
        Build a (K, K) covariate adjacency from the magnitude-squared coherence
        of the STFT coefficients and cache it in ``self.cov_adj``.

        X: (B, T, C, F, K) with C = [trend, real, imag] if ``self.decompose``
        else [real, imag]  -> returns A: (K, K) in [0,1]

        eps: lower bound for the coherence denominator and the weight sum.
        power_weight: if True, average over frequencies weighted by the mean
            power per frequency; otherwise use a plain mean.
        """
        assert X.dim() == 5
        T = X.shape[1]

        # complex STFT coefficients (the trend channel, if any, is skipped)
        if self.decompose:
            Re = X[:, :, 1].to(torch.float32)          # (B,T,F,K)
            Im = X[:, :, 2].to(torch.float32)          # (B,T,F,K)
        else:
            Re = X[:, :, 0].to(torch.float32)          # (B,T,F,K)
            Im = X[:, :, 1].to(torch.float32)          # (B,T,F,K)

        Z = torch.complex(Re, Im)                     # (B,T,F,K)

        # remove time mean to reduce bias
        Z = Z - Z.mean(dim=1, keepdim=True)        # (B,T,F,K)

        # Reorder for clean batched matmul: (B,F,K,T)
        Z_bfkt = Z.permute(0, 2, 3, 1).contiguous()

        # Cross-spectral density S_xy(f) ≈ (1/T) * Z_x(f) @ Z_y(f)^H
        # Z_bfkt: (B,F,K,T) ; Z_bfkt.conj().transpose(-1,-2): (B,F,T,K)
        S = Z_bfkt @ Z_bfkt.conj().transpose(-1, -2)   # (B,F,K,K)
        S = S / float(T)

        # Power spectral densities P_xx(f) = mean_t |Z_x|^2
        P = (Z_bfkt.conj() * Z_bfkt).real.mean(dim=-1) # (B,F,K)

        denom = (P.unsqueeze(-1) * P.unsqueeze(-2)).clamp_min(eps)  # (B,F,K,K)
        MSC   = (S.abs()**2 / denom).clamp(0.0, 1.0)                # (B,F,K,K)

        # Average over frequencies (optionally power-weighted)
        if power_weight:
            w = P.mean(dim=-1)                                      # (B,F)
            w = (w / (w.sum(dim=-1, keepdim=True) + eps)).unsqueeze(-1).unsqueeze(-1)  # (B,F,1,1)
            A = (MSC * w).sum(dim=1)                                # (B,K,K)
        else:
            A = MSC.mean(dim=1)                                     # (B,K,K)

        # Batch average, symmetrize, clip
        A = A.mean(dim=0)                                           # (K,K)
        A = 0.5 * (A + A.transpose(-1, -2))
        A = A.clamp(-1.0, 1.0)

        self.cov_adj = A
        return A

    def get_covariate_adj(self):
        """Return the cached covariate adjacency, computing it from ``self.timefreq_data`` if needed."""
        if self.cov_adj is None:
            return self._compute_covariate_adj(self.timefreq_data)
        return self.cov_adj

    def get_frequency_adj(self):
        """Return the cached frequency affinity, computing it from ``self.timefreq_data`` if needed."""
        if self.freq_adj is None:
            return self._compute_freq_adj(self.timefreq_data)
        return self.freq_adj

    def compute_timefreq_decomposed(self, torch_data: torch.Tensor):
        """
        Build the [trend, real, imag] video from decomposed data.

        Args:
            torch_data: Shape (2, B, L, K); index 0 along the first axis is the
                trend, index 1 the seasonal+residual part.

        Returns:
            Tensor of shape (B, T, 3, F, K). The STFT is applied to the
            seasonal+residual part; the trend is least-squares downsampled to
            the T STFT frames and repeated over the F frequency bins.

        Side effects:
            Replaces ``self.trend_resampler`` with a resampler on the data's device.
        """
        _, B, L, K = torch_data.shape
        torch_data = torch_data.permute(1, 3, 0, 2)  # B, K, C, L

        trend, season_resid = torch.split(torch_data, split_size_or_sections=1, dim=2)
        trend = trend.squeeze(2)  # B, K, L
        season_resid = season_resid.squeeze(2)  # B, K, L

        # NOTE(review): the inverse transform derives win_length from
        # self.seq_len; this matches only if L == self.seq_len -- confirm.
        win_length = min(L, self.n_fft)
        spec = transforms.Spectrogram(
            n_fft=self.n_fft, hop_length=self.hop_length, win_length=win_length, center=True, power=None,
        ).to(torch_data.device)
        transformed_season_resid = spec(season_resid)
        real, imag = (
            transformed_season_resid.real,
            transformed_season_resid.imag,
        )  # B, K, F, T

        stft_F, stft_T = (
            transformed_season_resid.shape[-2],
            transformed_season_resid.shape[-1],
        )

        self.trend_resampler = TrendResampler(device=trend.device)
        trend = self.trend_resampler.downsample(trend, target_length=stft_T)  # B, K, T
        trend = trend.unsqueeze(2)  # B, K, 1, T
        trend = trend.repeat(1, 1, stft_F, 1)  # B, K, F, T

        perm_tuple = (0, 3, 2, 1)  # B, T, F, K
        return self._prepare_video(real.permute(perm_tuple), imag.permute(perm_tuple), trend.permute(perm_tuple))  # B * T * C * F * K

    def compute_timefreq_single(self, torch_data: torch.Tensor):
        """
        Build the [real, imag] video from raw data of shape (B, L, K).

        Returns:
            Tensor of shape (B, T, 2, F, K).
        """
        torch_data = torch.permute(torch_data, (0, 2, 1))  # B, K, L
        win_length = min(self.seq_len, self.n_fft)
        spec = transforms.Spectrogram(
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=win_length,
            center=True,
            power=None,
        ).to(torch_data.device)
        transformed_data = spec(torch_data)
        real, imag = transformed_data.real, transformed_data.imag  # B, K, F, T
        perm_tuple = (0, 3, 2, 1)  # B, T, F, K
        return self._prepare_video(real.permute(perm_tuple), imag.permute(perm_tuple))  # B * T * C * F * K

    def inverse_timefreq_decomposed(self, timefreq_data: torch.Tensor):
        """
        Invert `compute_timefreq_decomposed`.

        Args:
            timefreq_data: Video of shape (B, T, 3, F, K) in channel order
                [trend, real, imag].

        Returns:
            ``(trend, season_resid)``, each of shape (B, seq_len, K). The trend
            is averaged over the frequency axis and upsampled with
            ``self.trend_resampler`` (so it lives on the resampler's device);
            the seasonal+residual part comes from the inverse STFT.
        """
        real, imag, trend = self._extract_channels(timefreq_data)  # B, T, F, K
        perm_tuple = (0, 3, 2, 1)
        real = real.permute(perm_tuple)  # B, K, F, T
        imag = imag.permute(perm_tuple)  # B, K, F, T
        trend = trend.permute(perm_tuple)  # B, K, F, T

        n_fft = self.n_fft
        hop_length, length = self.hop_length, self.seq_len

        # The forward pass repeated the trend over F; collapse it again.
        trend = trend.mean(dim=2)  # B, K, T
        coeffs = torch.complex(real, imag)  # B, K, F, T

        # -- inverse stft --
        win_length = min(self.seq_len, n_fft)
        # The window must live on the same device as the coefficients, as in
        # the forward transforms; self.device may differ from the input's.
        ispec = transforms.InverseSpectrogram(
            n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=True
        ).to(coeffs.device)
        season_resid = ispec(coeffs, length)  # B, K, L
        season_resid = season_resid.permute(0, 2, 1)  # B, L, K

        trend = self.trend_resampler.upsample(
            trend.to(self.trend_resampler.device), target_length=self.seq_len
        )
        trend = trend.permute(0, 2, 1)  # B, L, K
        return trend, season_resid

    def inverse_timefreq_single(self, timefreq_data: torch.Tensor):
        """
        Invert `compute_timefreq_single`.

        Args:
            timefreq_data: Video of shape (B, T, C, F, K); only the real and
                imaginary channels are used.

        Returns:
            Time series of shape (B, seq_len, K).
        """
        real, imag, _ = self._extract_channels(timefreq_data)  # B, T, F, K
        perm_tuple = (0, 3, 2, 1)
        real = real.permute(perm_tuple)  # B, K, F, T
        imag = imag.permute(perm_tuple)  # B, K, F, T
        n_fft = self.n_fft
        hop_length, length = self.hop_length, self.seq_len

        coeffs = torch.complex(real, imag)  # B, K, F, T

        # -- inverse stft --
        win_length = min(length, n_fft)
        # The window must live on the same device as the coefficients, as in
        # the forward transforms; self.device may differ from the input's.
        ispec = transforms.InverseSpectrogram(
            n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=True
        ).to(coeffs.device)

        x_time_series = ispec(coeffs, length)  # B, K, L
        x_time_series = x_time_series.permute(0, 2, 1)  # B, L, K
        return x_time_series