from typing import Any, Dict, List, Optional, Sequence, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from rel2abs.modules.anchor_selection import anchor_pruning
from rel2abs.modules.polar import (
    absolute_projection_list,
    anchor_augmentation_list,
    invert_anchors_list,
    relative_projection_list,
)


def svd_translation(A: torch.Tensor, B: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the orthogonal matrix that best maps the rows of ``A`` onto ``B``.

    Solves the orthogonal Procrustes problem ``min_R ||A @ R - B||_F`` with ``R``
    orthogonal, using the SVD of ``A.T @ B = U S V^T`` and taking ``R = U V^T``.

    Args:
        A: Source matrix, one sample per row.
        B: Target matrix with the same number of rows and columns as ``A``.

    Returns:
        A tuple ``(R, s)``: ``R`` is the ``(d, d)`` orthogonal matrix such that
        ``A @ R`` approximates ``B``; ``s`` holds the singular values of ``A.T @ B``.

    Raises:
        ValueError: If ``A`` and ``B`` have a different number of columns.
    """
    if A.size(1) != B.size(1):
        raise ValueError(f"A and B must have the same number of columns, got {A.size(1)} and {B.size(1)}")
    # torch.svd returns V itself (not V^H), so U @ V.T is the Procrustes solution.
    u, s, v = torch.svd(A.T @ B)
    R = u @ v.T
    return R, s


class LatentTranslation(nn.Module):
    """Translate vectors from an encoding space to a decoding space with a linear map.

    The map is estimated in ``fit`` from paired anchors (row ``i`` of the encoding
    anchors is paired with row ``i`` of the decoding anchors), after optional
    per-feature centering / std correction and optional row-wise L2 normalization.
    Supported ``method`` values:

    * ``"absolute"``: no fitting; ``transform`` returns its input unchanged.
    * ``"svd"``: orthogonal Procrustes via ``svd_translation``; the lower-dimensional
      space is zero-padded so both spaces share the same width.
    * ``"lstsq"``: unconstrained least-squares solution.
    * ``"lstsq+ortho"``: least-squares solution replaced by ``U @ V^T`` from its SVD.
    """

    def __init__(self, seed: int, centering: bool, std_correction: bool, l2_norm: bool, method: str) -> None:
        super().__init__()

        self.seed: int = seed  # stored for bookkeeping; not used by this class
        self.centering: bool = centering
        self.std_correction: bool = std_correction
        self.l2_norm: bool = l2_norm
        self.method: str = method
        # Number of singular values > 0.1 found by the "svd" method (set in fit).
        self.sigma_rank: Optional[int] = None

        # Declared for type checkers only; these are registered as buffers in fit().
        self.translation_matrix: Optional[torch.Tensor]
        self.mean_encoding_anchors: Optional[torch.Tensor]
        self.mean_decoding_anchors: Optional[torch.Tensor]
        self.std_encoding_anchors: Optional[torch.Tensor]
        self.std_decoding_anchors: Optional[torch.Tensor]
        self.encoding_norm: Optional[torch.Tensor]
        self.decoding_norm: Optional[torch.Tensor]

    @staticmethod
    def _pad_to_common_dim(
        encoding_anchors: torch.Tensor, decoding_anchors: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Zero-pad the narrower of the two anchor matrices to the width of the other."""
        encoding_dim, decoding_dim = encoding_anchors.size(1), decoding_anchors.size(1)
        if encoding_dim < decoding_dim:
            padded = torch.zeros_like(decoding_anchors)
            padded[:, :encoding_dim] = encoding_anchors
            encoding_anchors = padded
        elif encoding_dim > decoding_dim:
            padded = torch.zeros_like(encoding_anchors)
            padded[:, :decoding_dim] = decoding_anchors
            decoding_anchors = padded
        return encoding_anchors, decoding_anchors

    @torch.no_grad()
    def fit(self, encoding_anchors: torch.Tensor, decoding_anchors: torch.Tensor) -> None:
        """Estimate the translation matrix from paired anchors.

        Registers the normalization statistics and the translation matrix as
        buffers, and stores the normalized (and, for "svd", padded) anchors in
        ``self.encoding_anchors`` / ``self.decoding_anchors``.

        Args:
            encoding_anchors: Source-space anchors, expected shape ``(n, d_enc)``.
            decoding_anchors: Paired target-space anchors, expected shape ``(n, d_dec)``.

        Raises:
            NotImplementedError: If ``method`` is not a supported value.
        """
        if self.method == "absolute":
            return

        # First normalization: per-feature centering and std correction.
        if self.centering:
            mean_encoding_anchors: torch.Tensor = encoding_anchors.mean(dim=(0,))
            mean_decoding_anchors: torch.Tensor = decoding_anchors.mean(dim=(0,))
        else:
            mean_encoding_anchors = torch.as_tensor(0)
            mean_decoding_anchors = torch.as_tensor(0)

        if self.std_correction:
            std_encoding_anchors: torch.Tensor = encoding_anchors.std(dim=(0,))
            std_decoding_anchors: torch.Tensor = decoding_anchors.std(dim=(0,))
        else:
            std_encoding_anchors = torch.as_tensor(1)
            std_decoding_anchors = torch.as_tensor(1)

        self.encoding_dim: int = encoding_anchors.size(1)
        self.decoding_dim: int = decoding_anchors.size(1)

        self.register_buffer("mean_encoding_anchors", mean_encoding_anchors)
        self.register_buffer("mean_decoding_anchors", mean_decoding_anchors)
        self.register_buffer("std_encoding_anchors", std_encoding_anchors)
        self.register_buffer("std_decoding_anchors", std_decoding_anchors)

        # Mean row L2 norm of the *raw* anchors; used by transform() to rescale
        # L2-normalized outputs.
        self.register_buffer("encoding_norm", encoding_anchors.norm(p=2, dim=-1).mean())
        self.register_buffer("decoding_norm", decoding_anchors.norm(p=2, dim=-1).mean())

        encoding_anchors = (encoding_anchors - mean_encoding_anchors) / std_encoding_anchors
        decoding_anchors = (decoding_anchors - mean_decoding_anchors) / std_decoding_anchors

        # Second normalization: scaling
        if self.l2_norm:
            encoding_anchors = F.normalize(encoding_anchors, p=2, dim=-1)
            decoding_anchors = F.normalize(decoding_anchors, p=2, dim=-1)

        if self.method == "svd":
            # Procrustes needs both spaces to have the same width.
            encoding_anchors, decoding_anchors = self._pad_to_common_dim(encoding_anchors, decoding_anchors)

        # Keep the anchors the map is estimated from. This previously happened only
        # inside one padding branch, so the attributes were usually missing.
        self.encoding_anchors = encoding_anchors
        self.decoding_anchors = decoding_anchors

        if self.method == "svd":
            translation_matrix, sigma = svd_translation(A=encoding_anchors, B=decoding_anchors)
            # Count singular values not within 0.1 of zero.
            self.sigma_rank = (~sigma.isclose(torch.zeros_like(sigma), atol=1e-1)).sum().item()
        elif self.method == "lstsq":
            translation_matrix = torch.linalg.lstsq(encoding_anchors, decoding_anchors).solution
        elif self.method == "lstsq+ortho":
            translation_matrix = torch.linalg.lstsq(encoding_anchors, decoding_anchors).solution
            U, _, V = torch.svd(translation_matrix)
            translation_matrix = U @ V.T
        else:
            raise NotImplementedError

        translation_matrix = torch.as_tensor(translation_matrix, dtype=torch.float32, device=encoding_anchors.device)
        self.register_buffer("translation_matrix", translation_matrix)

    def transform(self, X: torch.Tensor, compute_info: bool = True) -> Dict[str, Any]:
        """Translate ``X`` from the encoding space into the decoding space.

        Args:
            X: Encoding-space vectors, one per row (expected 2-D).
            compute_info: Kept for interface compatibility; no info is computed yet.

        Returns:
            A dict with ``"source"`` (``X`` after encoding-side normalization and,
            for "svd", padding), ``"target"`` (the translated vectors with the
            decoding-space scale and center restored) and ``"info"`` (empty dict).
        """
        if self.method == "absolute":
            return {"source": X, "target": X, "info": {}}

        encoding_x = (X - self.mean_encoding_anchors) / self.std_encoding_anchors

        if self.l2_norm:
            encoding_x = F.normalize(encoding_x, p=2, dim=-1)

        if self.method == "svd" and self.encoding_dim < self.decoding_dim:
            # Mirror the zero-padding applied to the encoding anchors in fit().
            padded = torch.zeros(
                X.size(0), self.decoding_dim, device=X.device, dtype=self.translation_matrix.dtype
            )
            padded[:, : self.encoding_dim] = encoding_x
            encoding_x = padded

        # Drop padded output columns (only present for "svd" when d_enc > d_dec).
        decoding_x = (encoding_x @ self.translation_matrix)[:, : self.decoding_dim]

        # restore scale
        if self.l2_norm:
            decoding_x = decoding_x * self.decoding_norm

        # restore center
        decoding_x = (decoding_x * self.std_decoding_anchors) + self.mean_decoding_anchors

        # TODO: populate diagnostics when compute_info is True.
        info: Dict[str, Any] = {}

        return {"source": encoding_x, "target": decoding_x, "info": info}


class RelativeLatentTranslation:
    """Translate vectors between spaces through an anchor-relative representation.

    ``fit`` normalizes paired anchors, selects anchor subsets with ``anchor_pruning``
    (once per seed in ``range(omega)``), optionally augments them with
    ``anchor_augmentation_list``, and precomputes ``invert_anchors_list`` of the
    decoding-side subsets. ``transform`` projects inputs onto the encoding-side
    subsets (``relative_projection_list``), maps the relative coordinates back with
    the precomputed decoding inverses (``absolute_projection_list``), pools over
    subsets and restores the decoding-space scale and center.

    Args:
        seed: Stored for bookkeeping; not used by this class (pruning seeds are
            ``0 .. omega - 1``).
        delta: Passed to ``anchor_pruning`` as ``stop_distance``.
        omega: Number of pruning runs, i.e. of random seeds used for subspaces.
        centering: Subtract the per-feature anchor mean before fitting.
        std_correction: Divide by the per-feature anchor std before fitting.
        l2_norm: Row-wise L2-normalize the anchors after standardization.
        pooling_method: ``"mean"`` averages the reconstructions over dim 0; any
            other value leaves them unpooled.
        completion: Run ``anchor_augmentation_list`` on the pruned anchors.
    """

    def __init__(
        self,
        seed: int,
        delta: float,
        omega: int,
        centering: bool,
        std_correction: bool,
        l2_norm: bool,
        pooling_method: str,
        completion: bool = True,
    ) -> None:
        self.seed: int = seed
        self.delta: float = delta  # pruning stop distance
        self.omega: int = omega  # number of subspaces
        self.centering: bool = centering
        self.std_correction: bool = std_correction
        self.l2_norm: bool = l2_norm
        self.pooling_method: str = pooling_method
        self.completion: bool = completion

    @torch.no_grad()
    def fit(self, encoding_anchors: torch.Tensor, decoding_anchors: torch.Tensor) -> None:
        """Prepare the anchor subspaces and decoding-side inverses.

        Sets the normalization statistics, ``encoding_norm`` / ``decoding_norm``
        (mean row L2 norm of the raw anchors), the per-subspace lists
        ``encoding_anchors`` / ``decoding_anchors`` and ``decoding_anchors_inverse``.

        Args:
            encoding_anchors: Source-space anchors, one per row.
            decoding_anchors: Paired target-space anchors, one per row.
        """
        # First normalization: per-feature centering and std correction.
        # The scalar fallbacks (0 / 1) broadcast as no-ops.
        if self.centering:
            mean_encoding_anchors = encoding_anchors.mean(dim=(0,))
            mean_decoding_anchors = decoding_anchors.mean(dim=(0,))
        else:
            mean_encoding_anchors = 0
            mean_decoding_anchors = 0

        if self.std_correction:
            std_encoding_anchors = encoding_anchors.std(dim=(0,))
            std_decoding_anchors = decoding_anchors.std(dim=(0,))
        else:
            std_encoding_anchors = 1
            std_decoding_anchors = 1

        self.mean_encoding_anchors = mean_encoding_anchors
        self.mean_decoding_anchors = mean_decoding_anchors
        self.std_encoding_anchors = std_encoding_anchors
        self.std_decoding_anchors = std_decoding_anchors

        # Measured on the raw anchors; transform() rescales outputs by decoding_norm.
        self.encoding_norm = encoding_anchors.norm(p=2, dim=-1).mean()
        self.decoding_norm = decoding_anchors.norm(p=2, dim=-1).mean()

        encoding_anchors = (encoding_anchors - mean_encoding_anchors) / std_encoding_anchors
        decoding_anchors = (decoding_anchors - mean_decoding_anchors) / std_decoding_anchors

        # Second normalization: scaling
        if self.l2_norm:
            encoding_anchors = F.normalize(encoding_anchors, p=2, dim=-1)
            decoding_anchors = F.normalize(decoding_anchors, p=2, dim=-1)

        # Anchor pruning on the decoding side; every index selection it yields
        # (for each seed) defines one subspace, applied to both spaces.
        anchors_to_keep: List[torch.Tensor] = [
            pruned_space
            for subspace_seed in range(self.omega)
            for pruned_space in anchor_pruning(
                anchors=decoding_anchors, stop_distance=self.delta, random_seed=subspace_seed
            )
        ]
        sub_decoding_anchors: Sequence[torch.Tensor] = [
            decoding_anchors[sub_anchors_to_keep] for sub_anchors_to_keep in anchors_to_keep
        ]
        sub_encoding_anchors: Sequence[torch.Tensor] = [
            encoding_anchors[sub_anchors_to_keep] for sub_anchors_to_keep in anchors_to_keep
        ]

        # Anchor augmentation
        if self.completion:
            augmented_anchors = anchor_augmentation_list(
                encoding_anchors=sub_encoding_anchors, decoding_anchors=sub_decoding_anchors, centering=self.centering
            )
            sub_decoding_anchors = [aug_anchors["decoding"] for aug_anchors in augmented_anchors]
            sub_encoding_anchors = [aug_anchors["encoding"] for aug_anchors in augmented_anchors]

        self.decoding_anchors_inverse = invert_anchors_list(anchors=sub_decoding_anchors, normalize_anchors=False)

        self.encoding_anchors = sub_encoding_anchors
        self.decoding_anchors = sub_decoding_anchors

    def transform(self, X: torch.Tensor, decoding_l2: bool = True, compute_info: bool = True) -> Dict[str, Any]:
        """Translate ``X`` into the decoding space through the relative representation.

        Args:
            X: Encoding-space vectors, one per row.
            decoding_l2: Row-wise L2-normalize the pooled reconstruction before
                rescaling by ``decoding_norm``.
            compute_info: Kept for interface compatibility; info values are not
                computed yet (every key maps to an empty list).

        Returns:
            A dict with ``"source"`` (standardized ``X``), ``"relative"`` (output of
            ``relative_projection_list``), ``"target"`` (reconstruction with the
            decoding-space scale and center restored) and ``"info"``.
        """
        # Standardize with the encoding anchor statistics. Unlike
        # LatentTranslation, no L2 normalization is applied to X here.
        encoding_x = (X - self.mean_encoding_anchors) / self.std_encoding_anchors

        # Relative encoding, then back to absolute decoding coordinates.
        rel_encoding_x = relative_projection_list(
            x=encoding_x, anchors=self.encoding_anchors, rel_norm=False, abs_norm=True
        )
        rec_decoding_x = absolute_projection_list(
            rel_x=rel_encoding_x,
            anchor_inverse=self.decoding_anchors_inverse,
        )

        # Subspace pooling.
        # NOTE(review): any pooling_method other than "mean" leaves the
        # reconstructions unpooled; confirm callers expect that.
        if self.pooling_method == "mean":
            rec_decoding_x = rec_decoding_x.mean(dim=0)

        if decoding_l2:
            rec_decoding_x = F.normalize(rec_decoding_x, p=2, dim=-1)

        # Restore scale (applied regardless of decoding_l2), then center.
        rec_decoding_x = rec_decoding_x * self.decoding_norm
        rec_decoding_x = (rec_decoding_x * self.std_decoding_anchors) + self.mean_decoding_anchors

        # TODO: fill these diagnostics when compute_info is True.
        info: Dict[str, Any] = {
            k: []
            for k in (
                "post_pruning_anchors",
                "kept_anchors",
                "decoding_anchors_original_rank",
                "decoding_anchors_original_condition",
                "mean_subspace_condition_number",
            )
        }

        return {"source": encoding_x, "relative": rel_encoding_x, "target": rec_decoding_x, "info": info}
