from typing import Optional, Tuple, Union

import numpy as np
import torch
def procrustes_mapping_torch(
    corpus_emb_1: torch.Tensor,
    corpus_emb_2: torch.Tensor,
    overlap_index: torch.Tensor,
    transformed_candidate: torch.Tensor,
    apporximate: bool = True,
    q: int = 1500,
) -> np.ndarray:
    """Map candidate rows from embedding space 1 into space 2 via Procrustes.

    The rows selected by ``overlap_index`` in both corpora are used as paired
    anchors.  Each anchor set is centered on its mean and divided by the
    Frobenius norm of the (uncentered) anchor rows; an orthogonal matrix ``R``
    aligning the normalized origin anchors with the normalized target anchors
    is then estimated from the SVD of their cross-covariance.

    NOTE: a later definition with the same name in this module (taking an
    extra ``target_candidate`` argument) rebinds this name at import time.

    Args:
        corpus_emb_1: Origin embeddings; non-tensor input is converted to a
            float32 tensor.
        corpus_emb_2: Target embeddings; converted like ``corpus_emb_1``.
        overlap_index: Row indices of the paired anchors; non-tensor input is
            converted to an int64 tensor.
        transformed_candidate: 2-D batch of origin-space rows to map.
        apporximate: Use ``torch.svd_lowrank`` when true, otherwise the exact
            ``torch.linalg.svd``.  (The misspelling is kept because it is part
            of the keyword interface.)
        q: Rank estimate for ``torch.svd_lowrank``; clamped to the size of the
            covariance matrix.

    Returns:
        The mapped candidates as a NumPy array on the host.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _as_float_tensor(value):
        # Tensors are passed through untouched; anything else is materialized
        # as float32 on the preferred device.
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value, device=device, dtype=torch.float32)

    corpus_emb_1 = _as_float_tensor(corpus_emb_1)
    corpus_emb_2 = _as_float_tensor(corpus_emb_2)
    transformed_candidate = _as_float_tensor(transformed_candidate)
    if not isinstance(overlap_index, torch.Tensor):
        overlap_index = torch.tensor(overlap_index, device=device, dtype=torch.int64)

    origin = corpus_emb_1[overlap_index]
    target = corpus_emb_2[overlap_index]
    origin_centroid = torch.mean(origin, dim=0)
    target_centroid = torch.mean(target, dim=0)
    origin_norm = torch.norm(origin)
    target_norm = torch.norm(target)
    origin_centered = (origin - origin_centroid) / origin_norm
    target_centered = (target - target_centroid) / target_norm

    # For M = A^T B = U S V^T, the orthogonal R minimizing ||A R - B||_F is
    # R = U V^T.
    covariance_matrix = torch.mm(origin_centered.T, target_centered)
    if apporximate:
        # svd_lowrank returns V (not V^T); a rank estimate larger than the
        # matrix itself buys nothing, so clamp it.
        rank = min(q, *covariance_matrix.shape)
        U, _, V = torch.svd_lowrank(covariance_matrix, q=rank)
        rotation_matrix = torch.mm(U, V.T)
    else:
        # linalg.svd returns V^H, which must NOT be transposed again.
        U, _, Vh = torch.linalg.svd(covariance_matrix)
        rotation_matrix = torch.mm(U, Vh)

    # Apply the fitted map to row vectors: x -> ((x - c_o) / n_o) R * n_t + c_t.
    candidate_normalized = (transformed_candidate - origin_centroid) / origin_norm
    mapped = torch.mm(candidate_normalized, rotation_matrix) * target_norm + target_centroid
    return mapped.cpu().numpy()
def procrustes_mapping_torch(
    corpus_emb_1: torch.Tensor,
    corpus_emb_2: torch.Tensor,
    overlap_index: torch.Tensor,
    transformed_candidate: torch.Tensor,
    target_candidate: torch.Tensor,
    approximate: bool = True,
    q: int = 1500,
    with_rotation: bool = True,
) -> np.ndarray:
    """Map candidate rows from embedding space 1 into space 2 via Procrustes.

    The rows selected by ``overlap_index`` in both corpora are used as paired
    anchors.  Each anchor set is centered on its mean and divided by the
    Frobenius norm of the (uncentered) anchor rows.  When ``with_rotation`` is
    true an orthogonal matrix aligning the normalized anchors is estimated
    from the SVD of their cross-covariance; otherwise only the
    translation/scale part of the map is applied.

    As a diagnostic, the largest row-wise distance between candidates and
    ``target_candidate`` is printed before mapping (``beta_origin``), after
    mapping (``beta_target``) and in the normalized frames (``beta_center``).

    Args:
        corpus_emb_1: Origin embeddings; non-tensor input is converted to a
            float32 tensor.
        corpus_emb_2: Target embeddings; converted like ``corpus_emb_1``.
        overlap_index: Row indices of the paired anchors; non-tensor input is
            converted to an int64 tensor.
        transformed_candidate: 2-D batch of origin-space rows to map.
        target_candidate: Target-space rows paired with
            ``transformed_candidate``; used only for the printed diagnostics.
        approximate: Use ``torch.svd_lowrank`` when true, otherwise the exact
            ``torch.linalg.svd``.
        q: Rank estimate for ``torch.svd_lowrank``; clamped to the size of the
            covariance matrix.
        with_rotation: Skip the rotation (use the identity) when false.

    Returns:
        The mapped candidates as a NumPy array on the host.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _as_float_tensor(value):
        # Tensors are passed through untouched; anything else is materialized
        # as float32 on the preferred device.
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value, device=device, dtype=torch.float32)

    corpus_emb_1 = _as_float_tensor(corpus_emb_1)
    corpus_emb_2 = _as_float_tensor(corpus_emb_2)
    transformed_candidate = _as_float_tensor(transformed_candidate)
    target_candidate = _as_float_tensor(target_candidate)
    if not isinstance(overlap_index, torch.Tensor):
        overlap_index = torch.tensor(overlap_index, device=device, dtype=torch.int64)

    origin = corpus_emb_1[overlap_index]
    target = corpus_emb_2[overlap_index]
    origin_centroid = torch.mean(origin, dim=0)
    target_centroid = torch.mean(target, dim=0)
    origin_norm = torch.norm(origin)
    target_norm = torch.norm(target)
    origin_centered = (origin - origin_centroid) / origin_norm
    target_centered = (target - target_centroid) / target_norm

    if with_rotation:
        # For M = A^T B = U S V^T, the orthogonal R minimizing ||A R - B||_F
        # is R = U V^T.
        covariance_matrix = torch.mm(origin_centered.T, target_centered)
        if approximate:
            # svd_lowrank returns V (not V^T); clamp the rank estimate to the
            # matrix size.
            rank = min(q, *covariance_matrix.shape)
            U, _, V = torch.svd_lowrank(covariance_matrix, q=rank)
            rotation_matrix = torch.mm(U, V.T)
        else:
            # linalg.svd returns V^H, which must NOT be transposed again.
            U, _, Vh = torch.linalg.svd(covariance_matrix)
            rotation_matrix = torch.mm(U, Vh)
    else:
        # Build the identity next to the data so torch.mm never sees a
        # device or dtype mismatch.
        rotation_matrix = torch.eye(
            origin_centered.shape[1],
            device=origin_centered.device,
            dtype=origin_centered.dtype,
        )

    # Apply the fitted map to row vectors: x -> ((x - c_o) / n_o) R * n_t + c_t.
    candidate_normalized = (transformed_candidate - origin_centroid) / origin_norm
    mapped = torch.mm(candidate_normalized, rotation_matrix) * target_norm + target_centroid

    # Worst-case row distances, for eyeballing the effect of the alignment.
    beta_origin = torch.norm(transformed_candidate - target_candidate, dim=1).max()
    beta_target = torch.norm(mapped - target_candidate, dim=1).max()
    target_normalized = (target_candidate - target_centroid) / target_norm
    beta_center = torch.norm(candidate_normalized - target_normalized, dim=1).max()
    print(f"beta_origin: {beta_origin} => beta_target: {beta_target} [beta_center: {beta_center}]")
    return mapped.cpu().numpy()
def procrustes_mapping(
    corpus_emb_1: np.ndarray,
    corpus_emb_2: np.ndarray,
    overlap_index: np.ndarray,
    transformed_candidate: np.ndarray,
) -> np.ndarray:
    """Map candidate rows from embedding space 1 into space 2 (NumPy version).

    The rows selected by ``overlap_index`` in both corpora are used as paired
    anchors.  Each anchor set is centered on its mean and divided by the
    Frobenius norm of the (uncentered) anchor rows; an orthogonal matrix
    aligning the normalized anchors is estimated from the SVD of their
    cross-covariance.  Candidates are normalized with the statistics of the
    *origin anchors* (the frame the rotation was fitted in), so the result
    for a row does not depend on which other rows are in the batch.

    Args:
        corpus_emb_1: Origin embeddings, one row per item.
        corpus_emb_2: Target embeddings, row-aligned with ``corpus_emb_1``.
        overlap_index: Row indices of the paired anchors.
        transformed_candidate: Origin-space rows to map.

    Returns:
        The candidates expressed in the target space.
    """
    corpus_emb_1 = np.asarray(corpus_emb_1)
    corpus_emb_2 = np.asarray(corpus_emb_2)
    transformed_candidate = np.asarray(transformed_candidate)

    origin = corpus_emb_1[overlap_index]
    target = corpus_emb_2[overlap_index]
    origin_centroid = np.mean(origin, axis=0)
    target_centroid = np.mean(target, axis=0)
    origin_norm = np.linalg.norm(origin)
    target_norm = np.linalg.norm(target)
    origin_centered = (origin - origin_centroid) / origin_norm
    target_centered = (target - target_centroid) / target_norm

    # For M = A^T B = U S V^T, the orthogonal R minimizing ||A R - B||_F is
    # R = U V^T; np.linalg.svd already returns V^T.
    covariance_matrix = np.dot(origin_centered.T, target_centered)
    U, _, Vt = np.linalg.svd(covariance_matrix)
    rotation_matrix = np.dot(U, Vt)

    # Apply the fitted map to row vectors: x -> ((x - c_o) / n_o) R * n_t + c_t.
    candidate_normalized = (transformed_candidate - origin_centroid) / origin_norm
    return np.dot(candidate_normalized, rotation_matrix) * target_norm + target_centroid
def procrustes_no_norm_scale_with_param(
    corpus_emb_1: torch.Tensor,
    corpus_emb_2: torch.Tensor,
    overlap_index: torch.Tensor,
    transformed_candidate: Optional[torch.Tensor] = None,
    target_candidate: Optional[torch.Tensor] = None,
    apporximate: bool = True,
    q: int = 1500,
    with_rotation: bool = True,
    params: Optional[dict] = None,
) -> Tuple[np.ndarray, dict]:
    """Fit (or reuse) a scaled Procrustes map and apply it to candidates.

    The rows selected by ``overlap_index`` in both corpora are centered on
    their means (no norm division).  Unless a complete ``params`` dict is
    supplied, an orthogonal matrix ``R`` and a scalar ``k`` are fitted so that
    ``k * (origin - c_o) @ R + c_t`` approximates the target anchors in the
    least-squares sense; the fitted values are returned for reuse.

    Args:
        corpus_emb_1: Origin embeddings.
        corpus_emb_2: Target embeddings, row-aligned with ``corpus_emb_1``.
        overlap_index: Row indices of the paired anchors.
        transformed_candidate: Origin-space rows to map (a 1-D vector is
            treated as one row).  Defaults to the origin anchors.
        target_candidate: Accepted for signature parity with the sibling
            mapping functions; this function does not read it.
        apporximate: Use ``torch.svd_lowrank`` when true, otherwise the exact
            ``torch.linalg.svd``.  (The misspelling is kept because it is part
            of the keyword interface.)
        q: Rank estimate for ``torch.svd_lowrank``; clamped to the size of the
            covariance matrix.
        with_rotation: Currently ignored by this function.
        params: Optional dict with keys ``'rotation_matrix'``,
            ``'scaling_factor'`` and ``'target_centroid'``.  When all three
            are present they are used instead of fitting.  The origin
            centroid is always recomputed from
            ``corpus_emb_1[overlap_index]``.

    Returns:
        A ``(mapped, params)`` tuple: the mapped candidates as a host NumPy
        array, and the parameter dict that was used.  After a fit,
        ``'scaling_factor'`` holds a 0-d tensor, or ``1`` if the fit
        produced NaN.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _as_float_tensor(value):
        # Tensors are passed through untouched; anything else is materialized
        # as float32 on the preferred device.
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value, device=device, dtype=torch.float32)

    origin = corpus_emb_1[overlap_index]
    target = corpus_emb_2[overlap_index]
    if transformed_candidate is None:
        transformed_candidate = origin

    origin = _as_float_tensor(origin)
    target = _as_float_tensor(target)
    transformed_candidate = _as_float_tensor(transformed_candidate)
    if transformed_candidate.dim() == 1:
        transformed_candidate = transformed_candidate.unsqueeze(0)

    origin_centroid = torch.mean(origin, dim=0)
    target_centroid = torch.mean(target, dim=0)
    origin_centered = origin - origin_centroid
    target_centered = target - target_centroid

    required_keys = ('rotation_matrix', 'scaling_factor', 'target_centroid')
    if params is not None and all(key in params for key in required_keys):
        # Place cached parameters next to the data so the matmul below cannot
        # hit a device or dtype mismatch.
        rotation_matrix = torch.tensor(
            params['rotation_matrix'], device=origin.device, dtype=origin.dtype
        )
        k = params['scaling_factor']
        target_centroid = torch.tensor(
            params['target_centroid'], device=origin.device, dtype=origin.dtype
        )
    else:
        # For M = A^T B = U S V^T, the orthogonal R minimizing ||A R - B||_F
        # is R = U V^T.
        covariance_matrix = torch.mm(origin_centered.T, target_centered)
        if apporximate:
            # svd_lowrank returns V (not V^T); clamp the rank estimate to the
            # matrix size.
            rank = min(q, *covariance_matrix.shape)
            U, S, V = torch.svd_lowrank(covariance_matrix, q=rank)
            rotation_matrix = torch.mm(U, V.T)
        else:
            # linalg.svd returns V^H, which must NOT be transposed again.
            U, S, Vh = torch.linalg.svd(covariance_matrix)
            rotation_matrix = torch.mm(U, Vh)
        # Least-squares scale: sum of singular values over ||A||_F^2.  The
        # squared norm equals trace(A A^T) without building the n x n Gram
        # matrix.
        k = S.sum() / (origin_centered ** 2).sum()
        if torch.isnan(k):
            # 0 / 0 happens when the centered anchors are all zero.
            k = 1
        params = {
            'rotation_matrix': rotation_matrix.cpu().numpy(),
            'scaling_factor': k,
            'target_centroid': target_centroid.cpu().numpy()
        }

    candidate_centered = transformed_candidate - origin_centroid
    transformed_candidate_scaled = k * torch.mm(candidate_centered, rotation_matrix) + target_centroid
    return transformed_candidate_scaled.cpu().numpy(), params
def procrustes_pca_mapping(
    corpus_emb_1: torch.Tensor,
    corpus_emb_2: torch.Tensor,
    overlap_index: torch.Tensor,
    transformed_candidate: torch.Tensor,
    target_candidate: torch.Tensor,
    apporximate: bool = True,
    q: int = 1500,
    with_rotation: bool = True,
    reduced_dim: int = 5,
) -> np.ndarray:
    """Procrustes-map candidates inside PCA-reduced spaces.

    A separate PCA is fitted on the overlap rows of each corpus.  Both
    corpora and the candidates are projected into their reduced spaces, the
    candidates are aligned there with ``procrustes_no_norm_scale_with_param``,
    and the aligned points are lifted back to the full target space with the
    target PCA's ``inverse_transform``.

    Args:
        corpus_emb_1: Origin embeddings.
        corpus_emb_2: Target embeddings, row-aligned with ``corpus_emb_1``.
        overlap_index: Row indices of the paired anchors.
        transformed_candidate: Origin-space rows to map.
        target_candidate: Forwarded unchanged to
            ``procrustes_no_norm_scale_with_param``.
        apporximate: Forwarded to the Procrustes fit (misspelling kept
            because it is part of the keyword interface).
        q: Forwarded to the Procrustes fit.
        with_rotation: Forwarded to the Procrustes fit.
        reduced_dim: Requested number of principal components; capped by the
            number of overlap rows and by both feature dimensions.

    Returns:
        The mapped candidates in the full target space.
    """
    # NOTE(review): inputs are annotated as tensors but are handed to
    # scikit-learn directly; presumably callers pass NumPy arrays or CPU
    # tensors - confirm.
    from sklearn.decomposition import PCA

    overlap_1 = corpus_emb_1[overlap_index]
    overlap_2 = corpus_emb_2[overlap_index]

    # PCA cannot extract more components than min(n_samples, n_features).
    reduced_dim = min(
        len(overlap_index),
        reduced_dim,
        np.shape(overlap_1)[1],
        np.shape(overlap_2)[1],
    )
    pca_corpus_emb_1 = PCA(n_components=reduced_dim, svd_solver='randomized')
    pca_corpus_emb_2 = PCA(n_components=reduced_dim, svd_solver='randomized')
    pca_corpus_emb_1.fit(overlap_1)
    pca_corpus_emb_2.fit(overlap_2)

    corpus_emb_1 = pca_corpus_emb_1.transform(corpus_emb_1)
    corpus_emb_2 = pca_corpus_emb_2.transform(corpus_emb_2)
    transformed_candidate = pca_corpus_emb_1.transform(transformed_candidate)

    # The Procrustes helper returns (mapped_array, params); only the array
    # can be fed to inverse_transform.
    mapped_reduced, _ = procrustes_no_norm_scale_with_param(
        corpus_emb_1,
        corpus_emb_2,
        overlap_index,
        transformed_candidate,
        target_candidate,
        apporximate,
        q,
        with_rotation,
    )
    return pca_corpus_emb_2.inverse_transform(mapped_reduced)
