import torch
import torchcde
import math

def _kernel_gaussian(u, bandwidth):
    """Gaussian density kernel evaluated element-wise on ``u``.

    Computes ``exp(-0.5 * (u / bandwidth)**2) / (bandwidth * sqrt(2 * pi))``.
    """
    normaliser = 1 / (bandwidth * math.sqrt(2 * math.pi))
    scaled = u / bandwidth
    return normaliser * torch.exp(-0.5 * scaled**2)

def _kernel_deriv_gaussian(u, bandwidth):
    """Derivative of the Gaussian kernel with respect to ``u``.

    Computes ``-u / (bandwidth**3 * sqrt(2 * pi)) * exp(-0.5 * (u / bandwidth)**2)``.
    """
    bandwidth_sq = bandwidth**2
    denominator = bandwidth_sq * bandwidth * math.sqrt(2 * math.pi)
    envelope = torch.exp(-0.5 * (u / bandwidth)**2)
    return -u / denominator * envelope

def _kernel_laplacian(u, bandwidth):
    """Laplacian (double-exponential) kernel evaluated element-wise on ``u``.

    Computes ``(0.5 / bandwidth) * exp(-|u / bandwidth|)``.
    """
    amplitude = 0.5 / bandwidth
    decay = torch.exp(-torch.abs(u / bandwidth))
    return amplitude * decay

def _kernel_deriv_laplacian(u, bandwidth):
    """Derivative of the Laplacian kernel with respect to ``u``.

    Computes ``-0.5 * sign(u) / bandwidth**2 * exp(-|u / bandwidth|)``; the
    result is zero at ``u == 0`` because ``torch.sign(0)`` is zero.
    """
    direction = torch.sign(u)
    decay = torch.exp(-torch.abs(u / bandwidth))
    return -0.5 * direction / (bandwidth**2) * decay

def _kernel_cauchy(u, bandwidth):
    """Cauchy (Lorentzian) kernel evaluated element-wise on ``u``.

    Computes ``1 / (pi * bandwidth * (1 + (u / bandwidth)**2))``.
    """
    lorentz_term = 1 + (u / bandwidth)**2
    return 1 / (math.pi * bandwidth * lorentz_term)

def _kernel_deriv_cauchy(u, bandwidth):
    """Derivative of the Cauchy kernel with respect to ``u``.

    Computes ``-2 * u / (pi * bandwidth**3 * (1 + (u / bandwidth)**2)**2)``.
    """
    bandwidth_sq = bandwidth**2
    lorentz_term_sq = (1 + (u / bandwidth)**2)**2
    return -2 * u / (math.pi * bandwidth_sq * bandwidth * lorentz_term_sq)

def _kernel_epanechnikov(u, bandwidth):
    """Epanechnikov (parabolic) kernel evaluated element-wise on ``u``.

    Computes ``(0.75 / bandwidth) * max(1 - (u / bandwidth)**2, 0)``, so the
    kernel is zero wherever ``|u| >= bandwidth``.
    """
    abs_scaled = torch.abs(u / bandwidth)
    parabola = (1 - abs_scaled**2).clamp(min=0)
    return (0.75 / bandwidth) * parabola

def _kernel_deriv_epanechnikov(u, bandwidth):
    """Derivative of the Epanechnikov kernel with respect to ``u``.

    Inside the support (``|u / bandwidth| < 1``) the derivative is
    ``-1.5 * u / bandwidth**3``; elsewhere it is zero.

    The support is applied with ``torch.where`` rather than by multiplying
    with a float32 mask, so the result keeps the dtype of the computation
    (a float32 mask would promote float16/bfloat16 inputs to float32).
    """
    inside_support = torch.abs(u / bandwidth) < 1
    slope = -1.5 * u / (bandwidth**3)
    return torch.where(inside_support, slope, torch.zeros_like(slope))

def _kernel_tricube(u, bandwidth):
    """Tricube kernel evaluated element-wise on ``u``.

    Computes ``(70 / (81 * bandwidth)) * max(1 - |u / bandwidth|**3, 0)**3``,
    so the kernel is zero wherever ``|u| >= bandwidth``.
    """
    abs_scaled = torch.abs(u / bandwidth)
    core = (1 - abs_scaled**3).clamp(min=0)
    return (70 / (81 * bandwidth)) * core**3

def _kernel_deriv_tricube(u, bandwidth):
    """Derivative of the tricube kernel with respect to ``u``.

    With ``v = u / bandwidth`` the kernel is
    ``(70 / (81 * bandwidth)) * (1 - |v|**3)**3`` on ``|v| < 1``. The chain
    rule gives

        d/du (1 - |v|**3)**3 = 3 * (1 - |v|**3)**2 * (-3 * |v|**2 * sign(u) / bandwidth)

    so the overall prefactor is ``70 * 9 / (81 * bandwidth)``. Both factors
    of 3 are required: one from the outer cube and one from ``|v|**3``.

    Outside the support the clamped term is zero, which already zeroes the
    result; no separate mask is applied, so the input dtype is preserved.
    """
    abs_scaled = torch.abs(u / bandwidth)
    core = torch.clamp(1 - abs_scaled**3, min=0)
    inner = abs_scaled**2 * torch.sign(u) / bandwidth
    return -(70 * 9 / (81 * bandwidth)) * inner * core**2

def _rbf_kernel_matrix(t1, t2, length_scale):
    """Pairwise RBF (squared-exponential) kernel between two sets of times.

    For ``t1`` of shape [N] and ``t2`` of shape [M], entry ``(i, j)`` of the
    [N, M] result is ``exp(-0.5 * ((t1[i] - t2[j]) / length_scale)**2)``.
    """
    pairwise = t1[:, None] - t2[None]
    scaled = pairwise / length_scale
    return torch.exp(-0.5 * scaled**2)

def _rbf_kernel_deriv_matrix(t_eval, t_grid, length_scale):
    """Derivative of the RBF kernel matrix with respect to the evaluation times.

    For ``t_eval`` of shape [N] and ``t_grid`` of shape [M], entry ``(i, j)``
    of the [N, M] result is ``K_ij * (-(t_eval[i] - t_grid[j]) / length_scale**2)``
    with ``K_ij = exp(-0.5 * ((t_eval[i] - t_grid[j]) / length_scale)**2)``.
    """
    pairwise = t_eval[:, None] - t_grid[None]  # [N, M]
    kernel = torch.exp(-0.5 * (pairwise / length_scale)**2)
    slope_factor = -pairwise / (length_scale**2)
    return kernel * slope_factor


class KernelInterpolation(torchcde.interpolation_base.InterpolationBase):
    """Kernel-weighted average of observations exposed as a torchcde interpolation.

    At a query time ``t`` the data channels are
    ``sum_i K(t - t_i) * x_i / sum_i K(t - t_i)``, where ``K`` is the kernel
    selected through ``kernel_params`` and ``t_i`` are the grid points.

    Args:
        coeffs: Observations. A 2-D tensor is given a leading batch dimension.
            With ``t=None`` channel 0 of batch element 0 is used as the time
            grid and the remaining channels are the data; otherwise every
            channel is data.
        t: Optional time grid shared by the whole batch (expected to be 1-D,
            ``[seq_len]``).
        kernel_params: Optional mapping with ``'bandwidth'`` (default 0.1) and
            ``'kernel'`` (default ``'gaussian'``; must be a key built in
            ``_define_kernels``).
        include_time: When true, ``evaluate`` prepends the query time and
            ``derivative`` prepends a constant 1 as the first output channel.
        **kwargs: Forwarded to the base-class constructor.

    Raises:
        KeyError: If ``kernel_params['kernel']`` is not a registered kernel.
    """

    def __init__(self, coeffs, t=None, kernel_params=None, include_time=True, **kwargs):
        super(KernelInterpolation, self).__init__(**kwargs)
        # ``None`` replaces the former shared mutable ``{}`` default.
        kernel_params = {} if kernel_params is None else kernel_params

        if coeffs.dim() == 2:
            coeffs = coeffs.unsqueeze(0)

        self.bandwidth = kernel_params.get('bandwidth', 0.1)
        kernel = kernel_params.get('kernel', 'gaussian')
        self.include_time = include_time

        self._define_kernels()
        self._kernel_fn = self._kernel_fns[kernel]
        self._kernel_deriv_fn = self._kernel_deriv_fns[kernel]

        if t is not None:
            # Separated mode: coeffs holds data channels only.
            self.register_buffer('_t', t)
            self.register_buffer('_x', coeffs)
        else:
            # Embedded mode: channel 0 is time, the rest is data.
            self.register_buffer('_t', coeffs[0, :, 0])  # [seq_len]
            self.register_buffer('_x', coeffs[..., 1:])  # [batch, seq_len, feat]

    def _define_kernels(self):
        """Build the name -> function tables for kernels and their derivatives."""
        self._kernel_fns = {
            'gaussian': _kernel_gaussian,
            'laplacian': _kernel_laplacian,
            'cauchy': _kernel_cauchy,
            'epanechnikov': _kernel_epanechnikov,
            'tricube': _kernel_tricube,
        }
        self._kernel_deriv_fns = {
            'gaussian': _kernel_deriv_gaussian,
            'laplacian': _kernel_deriv_laplacian,
            'cauchy': _kernel_deriv_cauchy,
            'epanechnikov': _kernel_deriv_epanechnikov,
            'tricube': _kernel_deriv_tricube,
        }

    @property
    def grid_points(self):
        """The stored time grid."""
        return self._t

    @property
    def interval(self):
        """Two-element tensor holding the first and last grid point."""
        return torch.stack([self._t[0], self._t[-1]])

    def _evaluate_or_derivative(self, t_eval, get_derivative=False):
        """Return the smoothed value (or its time derivative) at ``t_eval``.

        ``t_eval`` is either a 0-dim tensor, which is expanded to one copy per
        batch element, or a tensor already carrying one time per batch element.
        The result is ``[batch, feat]`` with one extra leading channel when
        ``include_time`` is set.
        """
        # t_eval_b: [batch]
        t_eval_b = t_eval.expand(self._x.size(0)) if t_eval.dim() == 0 else t_eval
        time_diff = t_eval_b.unsqueeze(-1) - self._t  # [batch, seq_len]

        # Kernel weights, normaliser and weighted sum are shared by both modes.
        K = self._kernel_fn(time_diff, self.bandwidth)
        # Clamp so that a query far from every grid point does not divide by zero.
        D = K.sum(dim=-1, keepdim=True).clamp(min=1e-9)
        N = (K.unsqueeze(-1) * self._x).sum(dim=1)

        if not get_derivative:
            out = N / D
            time_channel = t_eval_b.unsqueeze(-1)
        else:
            K_deriv = self._kernel_deriv_fn(time_diff, self.bandwidth)
            D_deriv = K_deriv.sum(dim=-1, keepdim=True)
            N_deriv = (K_deriv.unsqueeze(-1) * self._x).sum(dim=1)
            # Quotient rule for N / D.
            out = (N_deriv * D - N * D_deriv) / D.pow(2).clamp(min=1e-10)
            # d(time)/dt is identically one.
            time_channel = torch.ones_like(t_eval_b).unsqueeze(-1)

        if self.include_time:
            return torch.cat([time_channel, out], dim=-1)
        return out

    def evaluate(self, t_eval):
        """Smoothed value at ``t_eval``; see ``_evaluate_or_derivative``."""
        return self._evaluate_or_derivative(t_eval, False)

    def derivative(self, t_eval):
        """Time derivative of the smoothed value at ``t_eval``."""
        return self._evaluate_or_derivative(t_eval, True)

    def get_debug_info(self, t_eval):
        """
        Returns interpolated values, normalization factor (denominator), and kernel values.

        ``t_eval`` is a vector of query times ``[Steps]`` shared by every batch
        element. Returns ``(result, D, info)`` where ``result`` is
        ``[Batch, Steps, Feats]`` (plus a leading time channel when
        ``include_time`` is set), ``D`` is the unclamped normaliser
        ``[Batch, Steps, 1]`` and ``info['kernel_values']`` holds the kernel
        weights ``[Batch, Steps, SeqLen]``.
        """
        batch = self._x.size(0)

        # t_eval: [Steps] -> [Batch, Steps, 1]
        t_eval_expanded = t_eval.unsqueeze(0).unsqueeze(-1).expand(batch, -1, 1)
        # self._t: [SeqLen] -> [1, 1, SeqLen]
        t_grid = self._t.unsqueeze(0).unsqueeze(0)

        # time_diff: [Batch, Steps, SeqLen]
        time_diff = t_eval_expanded - t_grid

        # weights: [Batch, Steps, SeqLen]
        weights = self._kernel_fn(time_diff, self.bandwidth)

        # D: [Batch, Steps, 1]; the raw value is returned, the clamped one is
        # only used for the division.
        D = weights.sum(dim=-1, keepdim=True)
        D_clamped = D.clamp(min=1e-9)

        # self._x: [Batch, SeqLen, Feats] -> [Batch, 1, SeqLen, Feats]
        x_expanded = self._x.unsqueeze(1)

        # [Batch, Steps, SeqLen, 1] * [Batch, 1, SeqLen, Feats], summed over SeqLen
        N = (weights.unsqueeze(-1) * x_expanded).sum(dim=2)

        result = N / D_clamped

        if self.include_time:
            # t_res: [Batch, Steps, 1]
            t_res = t_eval.unsqueeze(0).expand(batch, -1).unsqueeze(-1)
            result = torch.cat([t_res, result], dim=-1)

        return result, D, {'kernel_values': weights}


class WeightedKernelInterpolation(torchcde.interpolation_base.InterpolationBase):
    """Kernel smoother whose per-observation kernel weights are scaled by ``weights``.

    At a query time ``t`` the data channels are
    ``sum_i w_i K(t - t_i) x_i / sum_i w_i K(t - t_i)``.

    Args:
        coeffs: Observations, ``[batch_size * num_heads, seq_len, channels]``.
            With ``t=None`` channel 0 of batch element 0 is used as the time
            grid and the remaining channels are the data; otherwise every
            channel is data.
        weights: Per-observation weights, ``[batch_size * num_heads, seq_len]``.
        t: Optional time grid shared by the whole batch.
        kernel_params: Optional mapping with ``'bandwidth'`` (default 0.1; a
            number, a 0-dim tensor, or a 1-D tensor with one entry per batch
            row) and ``'kernel'`` (default ``'gaussian'``).
        include_time: When true, ``evaluate`` prepends the query time and
            ``derivative`` prepends a constant 1 as the first output channel.
        **kwargs: Forwarded to the base-class constructor.

    Raises:
        ValueError: If ``kernel_params['kernel']`` is not a registered kernel.
    """

    def __init__(self, coeffs, weights, t=None, kernel_params=None, include_time=True, **kwargs):
        super(WeightedKernelInterpolation, self).__init__(**kwargs)
        # ``None`` replaces the former shared mutable ``{}`` default.
        kernel_params = {} if kernel_params is None else kernel_params

        # coeffs: [batch_size * num_heads, seq_len, input_channels]
        # weights: [batch_size * num_heads, seq_len]
        self.coeffs = coeffs
        self.weights = weights.unsqueeze(-1)  # -> [batch_size * num_heads, seq_len, 1]
        self.include_time = include_time

        self.bandwidth = kernel_params.get('bandwidth', 0.1)
        if not isinstance(self.bandwidth, torch.Tensor):
            self.bandwidth = torch.tensor(self.bandwidth, device=coeffs.device, dtype=coeffs.dtype)

        # A 0-dim (global) bandwidth broadcasts as-is; a 1-D (per-row)
        # bandwidth becomes [batch_size * num_heads, 1] so it broadcasts
        # against [batch_size * num_heads, seq_len] time differences.
        if self.bandwidth.dim() == 1:
            self.bandwidth = self.bandwidth.unsqueeze(-1)

        kernel = kernel_params.get('kernel', 'gaussian')
        self._define_kernels()
        if kernel not in self._kernel_fns:
            raise ValueError(f"Unknown kernel: {kernel}. Supported kernels are {list(self._kernel_fns.keys())}")
        self._kernel_fn = self._kernel_fns[kernel]
        self._kernel_deriv_fn = self._kernel_deriv_fns[kernel]

        if t is not None:
            # Separated mode: coeffs are pure features.
            self.register_buffer('_t', t)
            self._x = coeffs
        else:
            # Embedded mode: channel 0 is time, the rest is data.
            self.register_buffer('_t', coeffs[0, :, 0])
            self._x = coeffs[..., 1:]

    def _define_kernels(self):
        """Build the name -> function tables for kernels and their derivatives."""
        self._kernel_fns = {
            'gaussian': _kernel_gaussian,
            'laplacian': _kernel_laplacian,
            'cauchy': _kernel_cauchy,
            'epanechnikov': _kernel_epanechnikov,
            'tricube': _kernel_tricube,
        }
        self._kernel_deriv_fns = {
            'gaussian': _kernel_deriv_gaussian,
            'laplacian': _kernel_deriv_laplacian,
            'cauchy': _kernel_deriv_cauchy,
            'epanechnikov': _kernel_deriv_epanechnikov,
            'tricube': _kernel_deriv_tricube,
        }

    @property
    def grid_points(self):
        """The stored time grid."""
        return self._t

    @property
    def interval(self):
        """Two-element tensor holding the first and last grid point."""
        return torch.stack([self._t[0], self._t[-1]])

    def evaluate(self, t_eval):
        """Weighted kernel average at ``t_eval``.

        ``t_eval`` is a 0-dim tensor (expanded to one copy per batch row) or a
        tensor with one time per batch row. Returns
        ``[batch_size * num_heads, data_channels]`` with one extra leading
        time channel when ``include_time`` is set.
        """
        # t_obs_grid: [seq_len]
        t_obs_grid = self._t
        # x_obs: [batch_size * num_heads, seq_len, data_channels]
        x_obs = self._x

        # If scalar, expand to the parallel batch size.
        t_eval_b = t_eval.expand(self.coeffs.size(0)) if t_eval.dim() == 0 else t_eval

        # time_diff: [batch_size * num_heads, seq_len]
        time_diff = t_eval_b.unsqueeze(-1) - t_obs_grid

        # kernel_values: [batch_size * num_heads, seq_len]
        kernel_values = self._kernel_fn(time_diff, self.bandwidth)

        # combined_weights: [batch_size * num_heads, seq_len, 1]
        combined_weights = self.weights * kernel_values.unsqueeze(-1)

        # weight_sum (D): [batch_size * num_heads, 1, 1]; clamped to avoid
        # dividing by zero when every combined weight vanishes.
        weight_sum = combined_weights.sum(dim=1, keepdim=True).clamp(min=1e-9)

        # numerator (N): [batch_size * num_heads, data_channels]
        numerator = (combined_weights * x_obs).sum(dim=1)

        # result: [batch_size * num_heads, data_channels]
        result = numerator / weight_sum.squeeze(1)

        if self.include_time:
            return torch.cat([t_eval_b.unsqueeze(-1), result], dim=-1)
        return result

    def derivative(self, t_eval):
        """Time derivative of the weighted kernel average at ``t_eval``.

        Applies the quotient rule to ``N / D`` with ``N = sum_i w_i K_i x_i``
        and ``D = sum_i w_i K_i``. Returns the same layout as ``evaluate``; the
        time channel (if any) is a constant 1.
        """
        # t_obs_grid: [seq_len]
        t_obs_grid = self._t
        # x_obs: [batch_size * num_heads, seq_len, data_channels]
        x_obs = self._x

        t_eval_b = t_eval.expand(self.coeffs.size(0)) if t_eval.dim() == 0 else t_eval

        # time_diff: [batch_size * num_heads, seq_len]
        time_diff = t_eval_b.unsqueeze(-1) - t_obs_grid

        # K, K_deriv: [batch_size * num_heads, seq_len]
        K = self._kernel_fn(time_diff, self.bandwidth)
        K_deriv = self._kernel_deriv_fn(time_diff, self.bandwidth)

        # W: [batch_size * num_heads, seq_len]
        W = self.weights.squeeze(-1)
        WK = W * K
        WK_deriv = W * K_deriv

        # D, D_deriv: [batch_size * num_heads, 1]
        D = WK.sum(dim=-1, keepdim=True)
        D_deriv = WK_deriv.sum(dim=-1, keepdim=True)

        # N, N_deriv: [batch_size * num_heads, data_channels]
        N = (WK.unsqueeze(-1) * x_obs).sum(dim=-2)
        N_deriv = (WK_deriv.unsqueeze(-1) * x_obs).sum(dim=-2)

        denominator_sq = D.pow(2).clamp(min=1e-20)

        # deriv_x: [batch_size * num_heads, data_channels]
        deriv_x = (N_deriv * D - N * D_deriv) / denominator_sq

        if self.include_time:
            # deriv_t: [batch_size * num_heads, 1]. Built from deriv_x so it
            # shares its dtype and device; a default-dtype (float32) tensor
            # would change the dtype of the concatenated result for
            # non-float32 data.
            deriv_t = deriv_x.new_ones((t_eval_b.size(0), 1))
            return torch.cat([deriv_t, deriv_x], dim=-1)
        return deriv_x

    def get_debug_info(self, t_eval):
        """Return interpolated values, the normaliser and the combined weights.

        ``t_eval`` is a vector of query times ``[Steps]`` shared by every batch
        row. Returns ``(result, D, info)`` where ``result`` is
        ``[BH, Steps, data_channels]`` (plus a leading time channel when
        ``include_time`` is set), ``D`` is the unclamped normaliser
        ``[BH, Steps, 1]`` and ``info['combined_weights']`` is
        ``[BH, Steps, seq_len]``.
        """
        t_obs_grid = self._t
        x_obs = self._x
        BH = self.coeffs.size(0)

        # [Steps] -> [BH, Steps, 1]
        t_eval_expanded = t_eval.unsqueeze(0).unsqueeze(-1).expand(BH, -1, 1)
        # [seq_len] -> [1, 1, seq_len]
        t_grid = t_obs_grid.unsqueeze(0).unsqueeze(0)
        # [BH, Steps, seq_len]
        time_diff = t_eval_expanded - t_grid

        # Extra trailing axis so a per-row bandwidth [BH, 1] broadcasts over
        # both Steps and seq_len.
        bw_expanded = self.bandwidth.unsqueeze(-1)
        kernel_vals = self._kernel_fn(time_diff, bw_expanded)

        # [BH, seq_len, 1] -> [BH, 1, seq_len]
        weights_expanded = self.weights.transpose(1, 2)
        combined = weights_expanded * kernel_vals

        # The raw normaliser is returned; the clamped one is only used for
        # the division.
        D = combined.sum(dim=-1, keepdim=True)
        D_clamped = D.clamp(min=1e-9)

        # [BH, seq_len, C] -> [BH, 1, seq_len, C]
        x_expanded = x_obs.unsqueeze(1)
        N = (combined.unsqueeze(-1) * x_expanded).sum(dim=2)
        result = N / D_clamped

        if self.include_time:
            t_res = t_eval.unsqueeze(0).expand(BH, -1).unsqueeze(-1)
            result = torch.cat([t_res, result], dim=-1)
        return result, D, {'combined_weights': combined}

class GPInterpolation(torchcde.interpolation_base.InterpolationBase):
    """Gaussian-process posterior mean (RBF kernel) exposed as a torchcde interpolation.

    The constructor solves ``(K(T, T) + noise_std**2 * I) alpha = X`` once;
    evaluation at a time ``t`` is then ``K(t, T) @ alpha``.

    Args:
        coeffs: Observations, ``[batch, seq_len, channels]``. With ``t=None``
            channel 0 of batch element 0 is used as the time grid and the
            remaining channels are the data; otherwise every channel is data.
        t: Optional time grid shared by the whole batch.
        gp_params: Optional mapping with ``'length_scale'`` (default 1.0) and
            ``'noise_std'`` (default 1e-2).
        include_time: When true, ``evaluate`` prepends the query time and
            ``derivative`` prepends a constant 1 as the first output channel.
        **kwargs: Forwarded to the base-class constructor.
    """

    def __init__(self, coeffs, t=None, gp_params=None, include_time=True, **kwargs):
        super(GPInterpolation, self).__init__(**kwargs)
        # ``None`` replaces the former shared mutable ``{}`` default.
        gp_params = {} if gp_params is None else gp_params

        # coeffs: [batch, seq_len, channels]
        self.length_scale = gp_params.get('length_scale', 1.0)
        self.noise_std = gp_params.get('noise_std', 1e-2)
        self.include_time = include_time

        if t is not None:
            self.register_buffer('_t', t)
            self.register_buffer('_x', coeffs)
        else:
            self.register_buffer('_t', coeffs[0, :, 0])
            self.register_buffer('_x', coeffs[..., 1:])

        # K(T, T) with the observation noise added on the diagonal.
        K_tt = _rbf_kernel_matrix(self._t, self._t, self.length_scale)
        identity = torch.eye(self._t.size(0), device=self._t.device, dtype=K_tt.dtype)
        K_tt = K_tt + (self.noise_std**2) * identity

        # Solve (K + sigma^2 I) alpha = X along the sequence axis. K_tt is
        # [Seq, Seq] and is broadcast over the batch of [Seq, Feat]
        # right-hand sides, so alpha is [Batch, Seq, Feat].
        self.alpha = torch.linalg.solve(K_tt, self._x)

    @property
    def grid_points(self):
        """The stored time grid."""
        return self._t

    @property
    def interval(self):
        """Two-element tensor holding the first and last grid point."""
        return torch.stack([self._t[0], self._t[-1]])

    def _evaluate_or_derivative(self, t_eval, get_derivative=False):
        """Return the posterior mean (or its time derivative) at ``t_eval``.

        ``t_eval`` is a 0-dim tensor (expanded to one copy per batch element)
        or a tensor with one time per batch element. The result is
        ``[Batch, Feat]`` with one extra leading channel when ``include_time``
        is set.
        """
        # t_eval_b: [Batch]
        t_eval_b = t_eval.expand(self._x.size(0)) if t_eval.dim() == 0 else t_eval

        # diff, K_star: [Batch, Seq]
        diff = t_eval_b.unsqueeze(1) - self._t.unsqueeze(0)
        K_star = torch.exp(-0.5 * (diff / self.length_scale)**2)

        if not get_derivative:
            weights = K_star
            time_channel = t_eval_b.unsqueeze(-1)
        else:
            # dK/dt = K * (-diff / l^2)
            weights = K_star * (-diff / (self.length_scale**2))
            # Derivative of the time channel is 1.
            time_channel = torch.ones_like(t_eval_b).unsqueeze(-1)

        # [Batch, 1, Seq] @ [Batch, Seq, Feat] -> [Batch, 1, Feat] -> [Batch, Feat]
        out = torch.matmul(weights.unsqueeze(1), self.alpha).squeeze(1)

        if self.include_time:
            return torch.cat([time_channel, out], dim=-1)
        return out

    def evaluate(self, t_eval):
        """Posterior mean at ``t_eval``; see ``_evaluate_or_derivative``."""
        return self._evaluate_or_derivative(t_eval, False)

    def derivative(self, t_eval):
        """Time derivative of the posterior mean at ``t_eval``."""
        return self._evaluate_or_derivative(t_eval, True)

    def get_debug_info(self, t_eval):
        """Return the posterior mean on a vector of query times.

        ``t_eval`` is a vector of query times ``[Steps]`` shared by every batch
        element (a 0-dim tensor is treated as a single step). Returns
        ``(val, denom, info)`` where ``val`` is ``[Batch, Steps, Feat]`` (plus
        a leading time channel when ``include_time`` is set), ``denom`` is a
        tensor of ones ``[Batch, Steps, 1]`` (the GP mean has no normalising
        denominator) and ``info`` is an empty dict.

        ``evaluate`` is not used here because it expects one time per batch
        element, not a vector of steps.
        """
        t_steps = t_eval.reshape(-1)

        # K_star: [Steps, Seq]
        diff = t_steps.unsqueeze(1) - self._t.unsqueeze(0)
        K_star = torch.exp(-0.5 * (diff / self.length_scale)**2)

        # [Steps, Seq] @ [Batch, Seq, Feat] -> [Batch, Steps, Feat]
        val = torch.matmul(K_star, self.alpha)

        if self.include_time:
            # t_res: [Batch, Steps, 1]
            t_res = t_steps.unsqueeze(0).expand(val.size(0), -1).unsqueeze(-1)
            val = torch.cat([t_res, val], dim=-1)

        denom = val.new_ones((val.shape[0], val.shape[1], 1))
        return val, denom, {}
    

    

class WeightedGPInterpolation(torchcde.interpolation_base.InterpolationBase):
    """GP posterior mean (RBF kernel) with per-observation noise scaled by ``1 / weights``.

    The constructor solves ``(K(T, T) + diag(noise_std**2 / (w + 1e-5))) alpha = X``
    per batch row, so observations with a larger weight receive a smaller
    noise term. Evaluation at a time ``t`` is then ``K(t, T) @ alpha``.

    Args:
        coeffs: Observations, ``[Batch * Heads, Seq, channels]``. With
            ``t=None`` channel 0 of batch element 0 is used as the time grid
            and the remaining channels are the data; otherwise every channel
            is data.
        weights: Per-observation weights, ``[Batch * Heads, Seq]``.
        t: Optional time grid shared by the whole batch.
        gp_params: Optional mapping with ``'length_scale'`` (default 1.0; a
            tensor is reshaped to ``[-1, 1, 1]`` so it can vary per batch row)
            and ``'noise_std'`` (default 1e-2).
        include_time: When true, ``evaluate`` prepends the query time and
            ``derivative`` prepends a constant 1 as the first output channel.
        **kwargs: Forwarded to the base-class constructor.
    """

    def __init__(self, coeffs, weights, t=None, gp_params=None, include_time=True, **kwargs):
        super(WeightedGPInterpolation, self).__init__(**kwargs)
        # ``None`` replaces the former shared mutable ``{}`` default.
        gp_params = {} if gp_params is None else gp_params

        # coeffs: [Batch * Heads, Seq, Feat]
        # weights: [Batch * Heads, Seq]
        self.length_scale = gp_params.get('length_scale', 1.0)
        self.base_noise_std = gp_params.get('noise_std', 1e-2)
        self.include_time = include_time

        if isinstance(self.length_scale, torch.Tensor):
            # One length scale per batch row, broadcastable against [*, Seq, Seq].
            self.length_scale = self.length_scale.view(-1, 1, 1)

        if t is not None:
            self.register_buffer('_t', t)
            self._x = coeffs
        else:
            self.register_buffer('_t', coeffs[0, :, 0])
            self._x = coeffs[..., 1:]

        diff = self._t.unsqueeze(1) - self._t.unsqueeze(0)  # [Seq, Seq]
        diff_sq = diff.pow(2)  # [Seq, Seq]

        # [1, Seq, Seq], or [Batch, Seq, Seq] with a per-row length scale.
        K_tt = torch.exp(-0.5 * diff_sq.unsqueeze(0) / (self.length_scale**2))

        # epsilon keeps the noise finite when a weight is zero.
        epsilon = 1e-5
        noise_diag_vals = (self.base_noise_std ** 2) / (weights + epsilon)  # [Batch, Seq]
        noise_mat = torch.diag_embed(noise_diag_vals)  # [Batch, Seq, Seq]

        LHS = K_tt + noise_mat

        # alpha: [Batch, Seq, Feat]
        self.alpha = torch.linalg.solve(LHS, self._x)

    @property
    def grid_points(self):
        """The stored time grid."""
        return self._t

    @property
    def interval(self):
        """Two-element tensor holding the first and last grid point."""
        return torch.stack([self._t[0], self._t[-1]])

    def _cross_kernel(self, t_eval):
        """Return ``(t_eval_b, diff, K_star)`` for query times ``t_eval``.

        ``t_eval_b`` is ``[Batch]`` (a 0-dim ``t_eval`` is expanded), while
        ``diff`` and ``K_star`` are ``[Batch, 1, Seq]``.
        """
        t_eval_b = t_eval.expand(self._x.size(0)) if t_eval.dim() == 0 else t_eval
        diff = t_eval_b.unsqueeze(1).unsqueeze(2) - self._t.view(1, 1, -1)
        K_star = torch.exp(-0.5 * diff.pow(2) / (self.length_scale**2))
        return t_eval_b, diff, K_star

    def evaluate(self, t_eval):
        """Posterior mean at ``t_eval`` (0-dim, or one time per batch row).

        Returns ``[Batch, Feat]`` with one extra leading time channel when
        ``include_time`` is set.
        """
        t_eval_b, _, K_star = self._cross_kernel(t_eval)

        # [Batch, 1, Seq] @ [Batch, Seq, Feat] -> [Batch, 1, Feat] -> [Batch, Feat]
        mu = torch.matmul(K_star, self.alpha).squeeze(1)

        if self.include_time:
            return torch.cat([t_eval_b.unsqueeze(-1), mu], dim=-1)
        return mu

    def derivative(self, t_eval):
        """Time derivative of the posterior mean at ``t_eval``.

        Returns the same layout as ``evaluate``; the time channel (if any) is
        a constant 1.
        """
        t_eval_b, diff, K_star = self._cross_kernel(t_eval)

        # dK/dt = K * (-diff / l^2)
        dK_star = K_star * (-diff / (self.length_scale**2))

        d_mu = torch.matmul(dK_star, self.alpha).squeeze(1)

        if self.include_time:
            dt = torch.ones_like(t_eval_b).unsqueeze(-1)
            return torch.cat([dt, d_mu], dim=-1)
        return d_mu