# models/reducer.py
import torch
import torch.nn as nn
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from umap import UMAP
from joblib import dump, load
import warnings

class Autoencoder(nn.Module):
    """Fully connected autoencoder with a symmetric bottleneck.

    The encoder maps ``input_dim -> 4*latent_dim -> 2*latent_dim -> latent_dim``
    and the decoder mirrors those widths back up to ``input_dim``.

    Args:
        input_dim: Size of the last dimension of the input.
        latent_dim: Size of the bottleneck representation.
        dropout: Drop probability used by the dropout layer in each half.
    """

    def __init__(self, input_dim, latent_dim, dropout=0.2):
        super().__init__()
        wide = latent_dim * 4
        narrow = latent_dim * 2

        # Layers are listed in execution order; their position in the list is
        # also their index inside the Sequential container.
        encoder_layers = [
            nn.Linear(input_dim, wide),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            nn.Linear(wide, narrow),
            nn.Tanh(),
            nn.Linear(narrow, latent_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(latent_dim, narrow),
            nn.LeakyReLU(0.2),
            nn.Linear(narrow, wide),
            nn.Dropout(dropout),
            nn.Linear(wide, input_dim),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``(latent, reconstruction)`` for the input ``x``."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return latent, reconstruction

class DimensionalityReducer(nn.Module):
    """Dimensionality-reduction module supporting PCA, UMAP, t-SNE and an autoencoder.

    The class offers two independent ways of reducing data:

    * ``forward`` acts as a torch layer. For ``'autoencoder'`` it applies the
      built-in encoder. For ``'pca'``, ``'umap'`` and ``'t-sne'`` it applies a
      randomly initialised linear layer as a shape-compatible placeholder
      (a warning is emitted when that layer is first created).
    * ``fit`` / ``_reduce_batch`` / ``evaluate`` / ``save`` / ``load`` use a
      real estimator (scikit-learn ``PCA`` / ``TSNE`` or ``UMAP``) stored in
      ``self.model``. For ``'autoencoder'`` they use the built-in
      encoder/decoder instead and ``self.model`` stays ``None``.

    Attributes set in ``__init__``:
        model: Fitted estimator for the non-autoencoder methods, else ``None``.
        optimizer: Optional optimizer; ``None`` unless the caller assigns one.
            Its state is included in autoencoder checkpoints when present.
        criterion: Loss used by ``evaluate`` for the autoencoder (MSE).
        is_fitted: Whether ``fit`` or ``load`` has completed.
    """

    _ESTIMATOR_METHODS = ('pca', 'umap', 't-sne')

    def __init__(self, method='pca', input_dim=128, latent_dim=32, device='cpu'):
        """Create the reducer.

        Args:
            method: One of ``'pca'``, ``'umap'``, ``'t-sne'``, ``'autoencoder'``.
            input_dim: Feature size expected by the autoencoder.
            latent_dim: Size of the reduced representation.
            device: Device on which layers are created and inputs are placed.

        Raises:
            ValueError: If ``method`` is not supported.
        """
        super().__init__()
        self.method = method
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.device = device

        # State shared by fit/_reduce_batch/evaluate/save/load. Defining it
        # here guarantees those methods never hit a missing attribute.
        self.model = None
        self.optimizer = None
        self.criterion = nn.MSELoss()
        self.is_fitted = False

        if self.method == 'autoencoder':
            self.encoder = nn.Sequential(
                nn.Linear(self.input_dim, self.latent_dim * 2),
                nn.ReLU(),
                nn.Linear(self.latent_dim * 2, self.latent_dim),
                nn.ReLU()
            ).to(self.device)
            self.decoder = nn.Sequential(
                nn.Linear(self.latent_dim, self.latent_dim * 2),
                nn.ReLU(),
                nn.Linear(self.latent_dim * 2, self.input_dim)
            ).to(self.device)
        elif self.method in self._ESTIMATOR_METHODS:
            # No torch layers are predefined for these methods. The estimator
            # is built on first fit, and the forward placeholder is created on
            # first call so it can match the actual input width.
            pass
        else:
            raise ValueError(f"Unsupported dimensionality reduction method: {method}")

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _to_numpy(X):
        """Convert a tensor (detached, on CPU) or array-like to a numpy array."""
        if isinstance(X, torch.Tensor):
            return X.detach().cpu().numpy()
        return np.array(X)

    @staticmethod
    def _as_tensor(X):
        """Return ``X`` unchanged if it is a tensor, else a float32 tensor copy."""
        if isinstance(X, torch.Tensor):
            return X
        return torch.as_tensor(np.asarray(X), dtype=torch.float32)

    def _build_estimator(self):
        """Instantiate the estimator that corresponds to ``self.method``."""
        if self.method == 'pca':
            return PCA(n_components=self.latent_dim)
        if self.method == 'umap':
            return UMAP(n_components=self.latent_dim)
        if self.method == 't-sne':
            # scikit-learn's default Barnes-Hut solver rejects 4 or more
            # output dimensions, so fall back to the exact solver for those.
            solver = 'barnes_hut' if self.latent_dim < 4 else 'exact'
            return TSNE(n_components=self.latent_dim, method=solver)
        raise ValueError(f"Unsupported dimensionality reduction method: {self.method}")

    def _lazy_linear(self, attr_name, in_features, message):
        """Return the placeholder linear layer, creating it (with a warning) once."""
        layer = getattr(self, attr_name, None)
        if layer is None:
            layer = nn.Linear(in_features, self.latent_dim).to(self.device)
            setattr(self, attr_name, layer)
            warnings.warn(message)
        return layer

    # ------------------------------------------------------------------
    # Layer interface
    # ------------------------------------------------------------------
    def forward(self, x):
        """Reduce the last dimension of ``x`` to ``latent_dim``.

        Args:
            x: Tensor of shape ``[batch, seq_len, features]`` or
                ``[num_samples, features]``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and
            ``latent_dim`` as the last dimension.

        Raises:
            ValueError: If ``x`` is not 2D or 3D, or the method is unsupported.
        """
        original_shape = x.shape

        # Flatten 3D input to [batch * seq_len, features] for the projection.
        if len(original_shape) == 3:
            x_flat = x.reshape(-1, original_shape[-1])
        elif len(original_shape) == 2:
            x_flat = x
        else:
            raise ValueError(f"Unsupported input shape for reducer: {original_shape}. Expected 2D or 3D.")

        x_flat = x_flat.to(self.device)

        if self.method == 'autoencoder':
            compressed_flat = self.encoder(x_flat)
        elif self.method == 'pca':
            # Placeholder only: this is NOT real PCA, just a linear map that
            # yields the right output width. The input is deliberately not
            # converted to numpy here, which would fail for tensors that
            # require grad and would break the autograd graph.
            projection = self._lazy_linear(
                '_pca_linear',
                x_flat.shape[-1],
                "Using a random linear layer to simulate PCA. For real PCA, train/load a PCA model.",
            )
            compressed_flat = projection(x_flat)
        elif self.method in ['umap', 't-sne']:
            # UMAP / t-SNE need a fit step and are not differentiable, so the
            # forward pass uses a linear placeholder with matching dimensions.
            projection = self._lazy_linear(
                '_placeholder_linear',
                x_flat.shape[-1],
                f"Using a random linear layer to simulate {self.method}. For real {self.method}, consider its applicability in forward pass.",
            )
            compressed_flat = projection(x_flat)
        else:
            raise ValueError(f"Unsupported dimensionality reduction method: {self.method}")

        # Restore the [batch, seq_len, latent_dim] layout for 3D input.
        if len(original_shape) == 3:
            return compressed_flat.reshape(original_shape[0], original_shape[1], self.latent_dim)
        return compressed_flat

    # ------------------------------------------------------------------
    # Estimator interface
    # ------------------------------------------------------------------
    def fit(self, X):
        """Fit the underlying estimator on ``X`` and mark the reducer as fitted.

        For ``'autoencoder'`` no training is performed here; the reducer is
        only marked as fitted.

        Args:
            X: Tensor or array-like of samples.
        """
        if self.method != 'autoencoder':
            if self.model is None:
                self.model = self._build_estimator()
            self.model.fit(self._to_numpy(X))

        self.is_fitted = True

    def _reduce_batch(self, X):
        """Reduce ``X`` with the fitted reducer and return a CPU float tensor.

        The reducer is fitted on ``X`` first if it has not been fitted yet.
        t-SNE has no out-of-sample transform, so it always embeds ``X`` from
        scratch via ``fit_transform``.

        Args:
            X: Tensor or array-like of samples.

        Returns:
            Tensor holding the reduced representation of ``X``.
        """
        if self.method == 'autoencoder':
            if not self.is_fitted:
                self.fit(X)
            inputs = self._as_tensor(X).to(self.device)
            # Encode in eval mode, then restore whatever mode was active so
            # an enclosing training loop is not silently affected.
            was_training = self.training
            self.eval()
            try:
                with torch.no_grad():
                    encoded = self.encoder(inputs)
            finally:
                self.train(was_training)
            return encoded.cpu()

        X_np = self._to_numpy(X)

        if self.method == 't-sne':
            # Skip the separate fit: fit_transform would recompute everything.
            if self.model is None:
                self.model = self._build_estimator()
            embedding = self.model.fit_transform(X_np)
            self.is_fitted = True
            return torch.from_numpy(embedding).float()

        if not self.is_fitted:
            self.fit(X)
        return torch.from_numpy(self.model.transform(X_np)).float()

    def evaluate(self, X):
        """Return a method-specific quality score for the reduction.

        * autoencoder: reconstruction loss of ``X`` under ``self.criterion``.
        * pca: sum of the explained variance ratios of the fitted model.
        * umap: mean per-dimension span (max - min) of the embedding of ``X``.
        * t-sne: always ``0.0`` (no metric is defined here).

        Args:
            X: Tensor or array-like of samples.

        Returns:
            The score as a Python float.

        Raises:
            RuntimeError: If a PCA/UMAP reducer has no estimator yet.
        """
        if self.method == 'autoencoder':
            inputs = self._as_tensor(X).to(self.device)
            with torch.no_grad():
                reconstructed = self.decoder(self.encoder(inputs))
                # Compare against the device-resident copy of the input.
                return self.criterion(reconstructed, inputs).item()

        if self.method == 't-sne':
            return 0.0

        if self.method in ('pca', 'umap') and self.model is None:
            raise RuntimeError(
                f"The {self.method} reducer has no fitted model; call fit() or load() first."
            )

        if self.method == 'pca':
            return float(np.sum(self.model.explained_variance_ratio_))
        if self.method == 'umap':
            emb = self.model.transform(self._to_numpy(X))
            return float((emb.max(axis=0) - emb.min(axis=0)).mean())
        return 0.0

    def save(self, path):
        """Persist the reducer to ``path``.

        The autoencoder is stored with ``torch.save`` as a dict holding the
        module state under ``'model'`` and, when an optimizer has been
        assigned, its state under ``'optimizer'``. Other methods store the
        estimator with joblib.

        Raises:
            RuntimeError: If a non-autoencoder reducer has no estimator yet.
        """
        if self.method == 'autoencoder':
            checkpoint = {'model': self.state_dict()}
            if self.optimizer is not None:
                checkpoint['optimizer'] = self.optimizer.state_dict()
            torch.save(checkpoint, path)
        else:
            if self.model is None:
                raise RuntimeError(
                    f"The {self.method} reducer has no fitted model to save; call fit() first."
                )
            dump(self.model, path)

    def load(self, path):
        """Restore a reducer previously written by ``save`` and mark it fitted.

        NOTE: the non-autoencoder branch uses joblib, which unpickles the
        file; only load files from trusted sources.
        """
        if self.method == 'autoencoder':
            checkpoint = torch.load(path, map_location=self.device)
            self.load_state_dict(checkpoint['model'])
            optimizer_state = checkpoint.get('optimizer')
            if self.optimizer is not None and optimizer_state is not None:
                self.optimizer.load_state_dict(optimizer_state)
        else:
            self.model = load(path)
        # Without this, the next _reduce_batch would refit and discard the
        # state that was just loaded.
        self.is_fitted = True

    def to(self, device):
        """Move all registered layers to ``device`` and remember it for inputs.

        This applies to every method, because the non-autoencoder methods may
        own placeholder linear layers created by ``forward``.
        """
        super().to(device)
        if isinstance(device, (str, torch.device)):
            self.device = device
        return self

# Usage example
if __name__ == "__main__":
    n_samples, n_features, n_components = 1000, 128, 5

    # Build a t-SNE reducer for the synthetic feature width.
    reducer = DimensionalityReducer(
        method='t-sne',
        input_dim=n_features,
        latent_dim=n_components,
    )

    # Random Gaussian samples stand in for real data.
    dummy_data = torch.randn(n_samples, n_features)

    # Fit on the full synthetic set.
    reducer.fit(dummy_data)

    # Reduce only the first 40 samples and report the resulting shape.
    head = dummy_data[:40]
    reduced_data = reducer._reduce_batch(head)
    print(f"降维结果尺寸：{reduced_data.shape}")

    # Report the reducer's quality score for the full set.
    quality = reducer.evaluate(dummy_data)
    print(f"降维质量指标：{quality:.4f}")

    # Persisting the reducer is intentionally left disabled in this demo;
    # the class exposes `save(path)` for that purpose.