import torch
import torch.nn as nn
import numpy as np

class Tora(nn.Module):
    """
    Keeps the leading principal components of word representations and discards
    the trailing ones, with the cut-off chosen by an elbow-finding (chord) method.

    Processes data in batch mode: shape [B, T, D]. Call ``fit`` first to estimate
    the principal directions and the elbow index; until then the module returns
    its input unchanged.
    """

    # Class-wide switch set by ``activate`` and read by ``isactive``. Nothing in
    # this class consults it; presumably callers check it before using the module.
    isactive_ = False

    @classmethod
    def activate(cls):
        """Set the class-wide ``isactive_`` flag to True."""
        cls.isactive_ = True

    @classmethod
    def isactive(cls):
        """Return the class-wide ``isactive_`` flag."""
        return cls.isactive_

    def __init__(self, smooth_window=1, log_transform=False):
        """
        Args:
            smooth_window (int): Width of the moving-average filter applied to the
                singular values before elbow detection; values <= 1 disable it.
            log_transform (bool): Whether to apply log to the singular values
                before elbow detection.
        """
        super().__init__()
        self.smooth_window = smooth_window
        self.log_transform = log_transform
        self.U = None  # left singular vectors, [B*T, K] with K = min(B*T, D)
        self.S = None  # singular values, [K], in descending order
        self.V = None  # right singular vectors as columns, [D, K]
        self.elbow = None  # number of leading components kept by forward()

    def fit(self, x, contrast_vector=None):
        """
        Fit PCA on the input and find the elbow of the singular-value curve.

        Each sequence is centered by its own mean over the T axis, all B*T token
        vectors are stacked, and an SVD of the stacked matrix gives the principal
        directions.

        Args:
            x (torch.Tensor): Input tensor of shape [B, T, D].
            contrast_vector (torch.Tensor, optional): Currently ignored; accepted
                for interface compatibility.

        Returns:
            Tora: ``self``, to allow ``Tora().fit(x)`` chaining.
        """
        # Fitting is a statistics step: detach so the stored factors carry no
        # autograd graph. Without this, ``.numpy()`` below fails for inputs that
        # require grad, and forward() would backprop into a freed SVD graph.
        x = x.detach()

        # Center each sequence by its mean over the T axis.
        x_mean = x.mean(dim=1, keepdim=True)
        x_centered = x - x_mean

        # Reshape to 2D for SVD: [B*T, D] (unpacking enforces a 3-D input).
        _, _, dim = x.shape
        x_flat = x_centered.reshape(-1, dim)

        # SVD kernels do not support half precision; decompose in float32.
        if x_flat.dtype in (torch.float16, torch.bfloat16):
            x_flat = x_flat.float()

        # Reduced SVD: x_flat = U diag(S) Vh. Store V = Vh^T so its columns are
        # the principal directions (same layout the deprecated torch.svd gave).
        U, S, Vh = torch.linalg.svd(x_flat, full_matrices=False)
        self.U, self.S, self.V = U, S, Vh.transpose(-2, -1)

        singular_values = self.S.cpu().numpy()

        # Apply log transform if specified (epsilon guards log(0)).
        if self.log_transform:
            singular_values = np.log(singular_values + 1e-10)

        self.elbow = self._find_elbow_chord(singular_values)

        return self

    def _find_elbow_chord(self, singular_values):
        """
        Find the elbow of a curve using the chord (max distance to chord) method.

        The curve is the set of points (i, singular_values[i]). The elbow is the
        point farthest from the straight line joining the first and last points.
        Distances are measured in raw (index, value) units, so the result depends
        on the relative scale of the two axes.

        Args:
            singular_values (numpy.ndarray): 1-D array of singular values.

        Returns:
            int: Index of the elbow point (0 for curves with fewer than 3 points,
            where no interior point exists).
        """
        values = np.asarray(singular_values, dtype=float)
        n = values.shape[0]

        # With fewer than 3 points there is no interior point to pick, and a
        # single point would make the chord length zero (division by zero).
        if n < 3:
            return 0

        # Moving-average smoothing. Pad with edge values (not zeros) so the chord
        # endpoints are not dragged toward 0, and use 'valid' so the output has
        # exactly n samples for any window size. The padding split matches the
        # centering of np.convolve(..., mode='same').
        if self.smooth_window > 1:
            window = int(self.smooth_window)
            padded = np.pad(values, (window // 2, (window - 1) // 2), mode="edge")
            values = np.convolve(padded, np.ones(window) / window, mode="valid")

        points = np.column_stack((np.arange(n, dtype=float), values))

        # Unit vector along the chord from the first to the last point.
        start, end = points[0], points[-1]
        line_vec = end - start
        unit_line_vec = line_vec / np.linalg.norm(line_vec)

        # Perpendicular distance = |v - (v . u) u| for each offset v from start.
        vec_from_start = points - start
        scalar_proj = vec_from_start @ unit_line_vec
        perpendicular = vec_from_start - np.outer(scalar_proj, unit_line_vec)
        distances = np.linalg.norm(perpendicular, axis=1)

        return int(np.argmax(distances))

    def forward(self, x):
        """
        Keep the first ``self.elbow`` principal components and discard the rest.

        Each sequence is centered by its own mean over the T axis, projected onto
        the leading ``self.elbow`` columns of ``self.V``, projected back, and then
        has its mean added back. Before ``fit`` is called, ``x`` is returned as-is.

        Args:
            x (torch.Tensor): Input tensor of shape [B, T, D].

        Returns:
            torch.Tensor: Tensor of the same shape and dtype as ``x`` with the
            trailing principal components removed.
        """
        if self.elbow is None:
            return x

        # Center the data per sequence.
        x_mean = x.mean(dim=1, keepdim=True)
        x_centered = x - x_mean

        # Reshape to 2D: [B*T, D]
        batch_size, seq_len, dim = x.shape
        x_flat = x_centered.reshape(-1, dim)

        # Keep only the leading components. Slicing V is equivalent to zeroing
        # the trailing projection coefficients, without the full-size mask.
        # Match x's device/dtype in case the fit happened elsewhere.
        kept = self.V[:, : self.elbow].to(device=x.device, dtype=x.dtype)
        x_filtered_flat = (x_flat @ kept) @ kept.transpose(0, 1)

        # Reshape back to [B, T, D] and add the mean back.
        x_filtered = x_filtered_flat.reshape(batch_size, seq_len, dim)
        return x_filtered + x_mean