import torch
import torch.nn as nn
from utils.matrix_processing import compute_pdist_sq

class Kernel(nn.Module):
    """
    Kernel function evaluated between the rows of two feature matrices.

    Supported ``kernel_type`` values (``x``/``y`` are rows of ``X1``/``X2``):

        - 'linear':     <x, y>
        - 'rbf':        exp(-sum_d (x_d - y_d)^2 / (2 * gamma_d))
        - 'polynomial': (sum_d gamma_d * x_d * y_d + coef0) ** degree
        - 'sigmoid':    tanh(sum_d gamma_d * x_d * y_d + coef0)
        - 'kronecker':  1 if x == y element-wise, 0 otherwise

    ``gamma`` may be a scalar (shared by all feature dimensions) or hold one
    value per feature dimension, e.g. with shape ``(1, d)``.

    Args:
        kernel_type (str): One of the kernel names listed above.
        gamma (float or nested list): Kernel scale. Stored as ``log_gamma``,
            so it has to be strictly positive.
        degree (float): Exponent of the polynomial kernel. Stored as
            ``log_degree``, so it has to be strictly positive.
        coef0 (float): Additive constant of the polynomial and sigmoid kernels.
        is_trainable (bool): Value of ``requires_grad`` for ``log_gamma``,
            ``log_degree`` and ``coef0``.
    """

    def __init__(self, kernel_type='linear', gamma=1.0, degree=3., coef0=1., is_trainable=False):
        super(Kernel, self).__init__()
        self.kernel_type = kernel_type
        # gamma and degree live in log-space, so exp() of the stored value is
        # always positive regardless of what an optimizer does to it.
        self.log_gamma = nn.Parameter(torch.log(torch.tensor(gamma, dtype=torch.float32)), requires_grad=is_trainable)
        self.log_degree = nn.Parameter(torch.log(torch.tensor(degree, dtype=torch.float32)), requires_grad=is_trainable)
        self.coef0 = nn.Parameter(torch.tensor(coef0, dtype=torch.float32), requires_grad=is_trainable)

    def _scaled_inner_product(self, X1, X2):
        """Return the gamma-weighted inner products between rows of X1 and X2."""
        gamma = torch.exp(self.log_gamma)
        if gamma.numel() == 1:
            # A single gamma scales the whole Gram matrix.
            return gamma * torch.matmul(X1, X2.T)
        # One gamma per feature dimension: weight the features *before* the
        # product. Multiplying the (n1, n2) Gram matrix by a (1, d) gamma would
        # either fail to broadcast or silently scale along the sample axis.
        return torch.matmul(X1 * gamma, X2.T)

    def forward(self, X1, X2=None):
        """
        Compute the kernel matrix between the rows of ``X1`` and ``X2``.

        Args:
            X1: Tensor of shape (n1, d).
            X2: Tensor of shape (n2, d). When None, ``X1`` is used for both
                arguments (no copy is made).

        Returns:
            Tensor of shape (n1, n2).

        Raises:
            ValueError: If ``self.kernel_type`` is not a supported kernel name.
        """
        if X2 is None:
            X2 = X1

        if self.kernel_type == 'linear':
            return torch.matmul(X1, X2.T)
        elif self.kernel_type == 'rbf':
            # Dividing the inputs by sqrt(gamma) turns the plain squared
            # Euclidean distance into sum_d (x_d - y_d)^2 / gamma_d.
            inv_length_scale = torch.exp(-self.log_gamma / 2)
            X1_scaled = X1 * inv_length_scale
            X2_scaled = X2 * inv_length_scale
            dists_sq = torch.cdist(X1_scaled, X2_scaled, p=2) ** 2
            return torch.exp(-dists_sq / 2)
        elif self.kernel_type == 'polynomial':
            # NOTE: a negative base raised to a non-integer degree yields NaN.
            return (self._scaled_inner_product(X1, X2) + self.coef0) ** torch.exp(self.log_degree)
        elif self.kernel_type == 'sigmoid':
            return torch.tanh(self._scaled_inner_product(X1, X2) + self.coef0)
        elif self.kernel_type == 'kronecker':
            # Kronecker delta kernel: 1 if every entry of the two rows matches,
            # 0 otherwise. Useful for discrete/categorical variables.
            return (X1[:, None, :] == X2[None, :, :]).all(dim=-1).float()
        else:
            raise ValueError(f"Unsupported kernel type: {self.kernel_type}")



class BaseModel(nn.Module):
    """
    Base model class for kernel functions with feature extraction.

    Inputs are passed through `feature_extractor` and the resulting features
    are fed to a `Kernel`. Training features and their kernel matrix can be
    cached with `set_kernel_matrix`.

    Args:
        kernel_type (str): Type of kernel function to use.
        gamma (float): Scale parameter handed to the kernel. Its
            interpretation is left to the kernel.
        gamma_dim (int): When different from 1, `gamma` is replicated into a
            `(1, gamma_dim)` row so the kernel holds one value per feature
            dimension.
        degree (float): Degree handed to the kernel.
        coef0 (float): Zero coefficient handed to the kernel.
        ridge_lambda (float): Initial value of the `ridge_lambda` parameter.
            It is only stored here; this class does not use it itself.
        feature_extractor (nn.Module, optional): Module applied to the inputs
            before the kernel. Defaults to `nn.Identity()`.
        is_trainable (bool): Value of `requires_grad` for `ridge_lambda` and
            the kernel parameters.
        gamma_init_method (str, optional): How to initialize the RBF bandwidth
            from data in `set_kernel_matrix`: None keeps the configured
            `gamma`; 'variance', 'std' and 'median' use the corresponding
            per-dimension statistic (see `set_gamma_from_data`).
        **kwargs: Accepted and ignored.
    """
    def __init__(self, kernel_type='linear', gamma=1.0, gamma_dim=1, degree=3., coef0=1., ridge_lambda=1e-4,
                 feature_extractor=None, is_trainable=False, gamma_init_method=None, **kwargs):
        super(BaseModel, self).__init__()
        # One gamma per feature dimension is laid out as a (1, gamma_dim) row.
        gamma = gamma if gamma_dim == 1 else [[gamma]*gamma_dim]
        self.ridge_lambda = nn.Parameter(torch.tensor(ridge_lambda, dtype=torch.float32), requires_grad=is_trainable)
        self.gamma_init_method = gamma_init_method
        self.kernel = Kernel(kernel_type=kernel_type, gamma=gamma, degree=degree, coef0=coef0,
                             is_trainable=is_trainable)
        self.feature_extractor = feature_extractor if feature_extractor is not None \
                                    else nn.Identity()
        self.is_trainable = is_trainable
        # Caches filled by `set_kernel_matrix`.
        self._train_feature = None
        self._kernel_matrix = None

    @property
    def train_feature(self):
        """Cached training features; raises ValueError until they are set."""
        if self._train_feature is None:
            raise ValueError("Training features have not been set. Call `set_kernel_matrix` first.")
        return self._train_feature

    @property
    def kernel_matrix(self):
        """Cached training kernel matrix; raises ValueError until it is set."""
        if self._kernel_matrix is None:
            raise ValueError("Kernel matrix has not been computed. Call `set_kernel_matrix` first.")
        return self._kernel_matrix

    def set_kernel_matrix(self, train_X):
        """
        Sets the training features and kernel matrix.
        For trainable models, this function should be called after training.

        When `gamma_init_method` is set, the kernel is 'rbf' and the model is
        not trainable, the bandwidth is first re-estimated from the extracted
        training features. Both cached tensors are detached from the graph.
        """
        self._train_feature = self.feature_extractor(train_X).detach()
        if self.gamma_init_method is not None and self.kernel.kernel_type == 'rbf' and not self.is_trainable:
            self.set_gamma_from_data(self._train_feature)
        self._kernel_matrix = self.kernel(self._train_feature).detach()

    def set_gamma_from_data(self, features):
        """
        Set the RBF kernel bandwidth (gamma) based on data statistics.
        Sets gamma per dimension for multidimensional data.

        The statistic is chosen by `self.gamma_init_method`:
            - 'variance': gamma = var(features) per dimension
            - 'std': gamma = std(features) per dimension
            - 'median': gamma = squared median of the pairwise absolute
              differences within each dimension
        Each value is clamped to at least 1e-6 and its logarithm is written
        into `self.kernel.log_gamma`.

        Does nothing when the kernel is not 'rbf'. With fewer than two samples
        none of the statistics is defined, so a warning is issued and the
        current gamma is kept.

        Args:
            features: Tensor of shape (n, d); a 1-D tensor is treated as (n, 1).

        Raises:
            ValueError: If `self.gamma_init_method` is not one of the methods
                listed above.
        """
        if self.kernel.kernel_type != 'rbf':
            return  # Only applicable for RBF kernel

        if self.gamma_init_method not in ('variance', 'std', 'median'):
            raise ValueError(f"Unknown method: {self.gamma_init_method}. Use 'variance', 'std', or 'median'.")

        with torch.no_grad():
            # Ensure features is 2D
            if features.dim() == 1:
                features = features.unsqueeze(1)

            n_samples, n_dims = features.shape[0], features.shape[1]

            if n_samples < 2:
                # var/std would be NaN and there are no pairwise distances, so
                # keep the configured gamma instead of poisoning the kernel.
                import warnings
                warnings.warn(
                    "set_gamma_from_data needs at least two samples; keeping the current gamma.",
                    stacklevel=2,
                )
                return

            if self.gamma_init_method == 'variance':
                gamma = features.var(dim=0)
            elif self.gamma_init_method == 'std':
                gamma = features.std(dim=0)
            else:
                # Median heuristic per dimension. Only the strict upper
                # triangle is used so every pair is counted once and the zero
                # diagonal is excluded.
                triu_indices = torch.triu_indices(n_samples, n_samples, offset=1, device=features.device)
                gamma_list = []
                for d in range(n_dims):
                    col = features[:, d:d+1]  # (n, 1)
                    dists = torch.cdist(col, col, p=2)  # (n, n)
                    pairwise_dists = dists[triu_indices[0], triu_indices[1]]
                    gamma_list.append(torch.median(pairwise_dists) ** 2)
                gamma = torch.stack(gamma_list)

            # Avoid log(0) for constant dimensions.
            gamma = torch.clamp(gamma, min=1e-6)

            # Keep the parameter layout stable: reuse the existing shape when
            # the element count matches, otherwise use the same (1, d) row
            # layout the constructor creates for per-dimension gammas.
            log_gamma = torch.log(gamma)
            if log_gamma.numel() == self.kernel.log_gamma.numel():
                log_gamma = log_gamma.reshape(self.kernel.log_gamma.shape)
            else:
                log_gamma = log_gamma.reshape(1, -1)
            self.kernel.log_gamma.data = log_gamma

    def forward(self, X1, X2=None):
        """
        Compute the kernel matrix for the input data.

        Args:
            X1: Inputs for the rows of the kernel matrix.
            X2: Inputs for the columns. When None, the cached training
                features are used instead, giving K(X1, train_X).

        Returns:
            The kernel evaluated on the extracted features of X1 and X2.

        Raises:
            ValueError: If X2 is None and no training features are cached.
        """
        features_X1 = self.feature_extractor(X1)
        if X2 is not None:
            # K(X1, X2)
            features_X2 = self.feature_extractor(X2)
        elif self._train_feature is not None:
            # K(X1, train_X)
            features_X2 = self.train_feature
        else:
            raise ValueError("X2 is None and training features are not set. Cannot compute kernel matrix.")

        return self.kernel(features_X1, features_X2)

class LinearModel(BaseModel):
    """Model using the 'linear' kernel with no feature extractor supplied."""

    def __init__(self, ridge_lambda=1e-4, is_trainable=False, **kwargs):
        # Extra keyword arguments are accepted but not forwarded to the base.
        base_config = dict(
            kernel_type='linear',
            ridge_lambda=ridge_lambda,
            feature_extractor=None,
            is_trainable=is_trainable,
        )
        super().__init__(**base_config)

class KroneckerModel(BaseModel):
    """Model using the 'kronecker' kernel with no feature extractor supplied."""

    def __init__(self, ridge_lambda=1e-4, is_trainable=False, **kwargs):
        # Extra keyword arguments are accepted but not forwarded to the base.
        base_config = dict(
            kernel_type='kronecker',
            ridge_lambda=ridge_lambda,
            feature_extractor=None,
            is_trainable=is_trainable,
        )
        super().__init__(**base_config)
class RBFModel(BaseModel):
    """Model using the 'rbf' kernel with `input_dim` passed as `gamma_dim`."""

    def __init__(self, input_dim, gamma=1.0, ridge_lambda=1e-4, is_trainable=False, gamma_init_method=None, **kwargs):
        # Extra keyword arguments are accepted but not forwarded to the base.
        base_config = dict(
            kernel_type='rbf',
            gamma=gamma,
            gamma_dim=input_dim,
            ridge_lambda=ridge_lambda,
            feature_extractor=None,
            is_trainable=is_trainable,
            gamma_init_method=gamma_init_method,
        )
        super().__init__(**base_config)

class FCModel(BaseModel):
    """Model whose feature extractor is a single `nn.Linear(input_dim, hidden_dim)`."""

    def __init__(self, kernel_type, input_dim, hidden_dim, gamma=1.0, degree=3., coef0=1., ridge_lambda=1e-4,
                 is_trainable=False, **kwargs):
        # The linear projection is wrapped in a Sequential, so its parameters
        # appear under index "0" of the feature extractor.
        projection = nn.Sequential(nn.Linear(input_dim, hidden_dim))
        # NOTE(review): the projection's parameters are left trainable even
        # when is_trainable is False — confirm this is intended.
        super().__init__(
            kernel_type=kernel_type,
            gamma=gamma,
            gamma_dim=hidden_dim,
            degree=degree,
            coef0=coef0,
            ridge_lambda=ridge_lambda,
            feature_extractor=projection,
            is_trainable=is_trainable,
        )


class MLPModel(BaseModel):
    """
    Model whose feature extractor is a multi-layer perceptron.

    Each entry of `hidden_dim` adds a Linear -> ReLU -> Dropout stage; the
    stack ends with a Linear layer to `output_dim` followed by Dropout.
    `output_dim` is also passed to the base class as `gamma_dim`.
    """

    def __init__(
        self,
        kernel_type, input_dim, hidden_dim, output_dim, gamma=1.0, degree=3., coef0=1.,
        ridge_lambda=1e-4, dropout=0.5, is_trainable=False, **kwargs):
        # Layer widths from the input through every hidden layer.
        widths = [input_dim, *hidden_dim]

        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)]

        # Output projection, again followed by dropout (no activation).
        modules += [nn.Linear(widths[-1], output_dim), nn.Dropout(dropout)]
        mlp = nn.Sequential(*modules)

        # A non-trainable model keeps the extractor weights frozen.
        if not is_trainable:
            mlp.requires_grad_(False)

        super().__init__(
            kernel_type=kernel_type,
            gamma=gamma,
            gamma_dim=output_dim,
            degree=degree,
            coef0=coef0,
            ridge_lambda=ridge_lambda,
            feature_extractor=mlp,
            is_trainable=is_trainable,
        )