import warnings
from typing import (Any, Callable, Iterable, Iterator, Optional, Sequence,
                    Tuple, Union)

import numpy as np
import torch
import torch.utils.data
import torch_kmeans
from tqdm import notebook as tqdm


class SphericalSoftKMeans(torch_kmeans.SoftKMeans):
    """Soft k-means whose cluster centers are constrained to the unit sphere.

    Affinities between samples and centers are dot products (cosine similarities for
    unit-norm inputs), and every center update is re-normalized to unit L2 norm.
    Throughout this class ``dist`` holds these *similarities*: larger values mean
    closer. Assignments take ``softmax``/``argmax`` over them, and zeroed-out
    (non-existing) centers are recognised by a similarity of exactly 0.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): overrides an attribute presumably set by the base class;
        # similarities are computed by the `_pairwise_distance` override instead.
        self.distance = None

    @torch.no_grad
    def _cluster_iter(self, x: torch.Tensor, centers: torch.Tensor) -> torch.Tensor:
        """Run one soft k-means update and project the new centers onto the sphere.

        Autograd is disabled by the decorator, including when this is called from
        the `torch.enable_grad()` section of `_cluster`.

        Args:
            x: Samples, (BS, N, D).
            centers: Current centers, (BS, num_init, K, D).

        Returns:
            Updated centers, (BS, num_init, K, D), with unit L2 norm (up to
            ``self.eps``); centers detected as non-existing are zero.
        """
        sim = self._pairwise_distance(x, centers)  # (BS, num_init, N, K)
        # Zeroed (non-existing) centers have similarity exactly 0 with every sample:
        # give them zero probability.
        msk = sim == 0
        sim = sim.clone()
        sim[msk] = float("-inf")
        # Soft cluster assignments and the total weight per cluster.
        c_assign = torch.softmax(self.temp * sim, dim=-1)  # (BS, num_init, N, K)
        per_cluster = c_assign.sum(dim=-2)  # (BS, num_init, K)
        # Assignment-weighted sum of samples per cluster: (BS, num_init, K, D).
        cluster_mean = torch.einsum("bink,bnd->bikd", c_assign, x)
        # Row-wise scaling by 1 / total weight; same values as multiplying by
        # diag_embed(1 / (per_cluster + eps)) without building a (K, K) matrix.
        centers = cluster_mean * (1.0 / (per_cluster + self.eps)).unsqueeze(-1)
        centers[msk.any(dim=-2)] = 0

        # Project back onto the unit sphere.
        centers /= torch.norm(centers, p=2, dim=-1, keepdim=True) + self.eps

        return centers

    def _pairwise_distance(self, x: torch.Tensor, centers: torch.Tensor, **kwargs):
        """Return the similarity (dot product) of every sample with every center.

        Despite the inherited name this is a *similarity* (larger = closer), which
        is what the softmax/argmax logic and the ``== 0`` masking in this class
        expect. It equals cosine similarity only for unit-norm rows: centers are
        normalized in `_cluster_iter`; x is presumably normalized by the base
        class's ``normalize`` option -- TODO confirm.

        Args:
            x: (BS, N, D)
            centers: (BS, num_init, K, D)

        Returns:
            (BS, num_init, N, K)
        """
        # The feature axis must carry the same letter in both operands so einsum
        # contracts it (a dot product) instead of summing each operand separately.
        return torch.einsum("bnd,bikd->bink", x, centers)

    @staticmethod
    @torch.jit.script
    def _calculate_shift(
        centers: torch.Tensor, old_centers: torch.Tensor, p: int = 2
    ) -> torch.Tensor:
        """Calculate center shift w.r.t. centers from last iteration."""
        # calculate euclidean distance while replacing inf with 0 in sum
        d = torch.norm((centers - old_centers), p=p, dim=-1)
        d[d == float("inf")] = 0
        # sum(d, dim=-1)**2 -> use mean to be independent of number of points
        return torch.mean(d, dim=-1)

    def _cluster(self, x, centers, k, **kwargs):
        """
        Run soft version of Lloyd's k-means algorithm on the unit sphere.

        Args:
            x: (BS, N, D)
            centers: (BS, num_init, k_max, D)
            k: (BS, ) number of clusters per instance

        Returns:
            Tuple ``(c_assign, centers, dist, soft_assignment)``:
            - hard assignments (BS, N);
            - centers of the selected init (BS, k_max, D);
            - similarities (BS, N, k_max);
            - soft assignments (BS, N, k_max).
            The recorded per-cluster shifts are stored in
            ``self._cluster_shift_trajectories``.
        """
        bs, n, d = x.size()
        # Mask centers whose index is >= k (instances with fewer than k_max clusters).
        k_max = torch.max(k).cpu().item()
        k_max_range = torch.arange(k_max, device=x.device)[None, :].expand(bs, -1)
        k_mask = k_max_range >= k[:, None]
        k_mask = k_mask[:, None, :].expand(bs, self.num_init, -1)

        # Per-cluster shift of the first instance / first init at every iteration
        # (recorded only when `self.tol` is set).
        all_shift_per_cluster = []
        shift = None
        shift_per_cluster = None
        converged = False

        # run soft k-means to convergence
        if self.verbose:
            range_ = tqdm.tqdm(range(self.max_iter), total=self.max_iter)
        else:
            range_ = range(self.max_iter)
        with torch.no_grad():
            for i in range_:
                centers[k_mask] = 0
                old_centers = centers.clone()

                # update
                centers = self._cluster_iter(x, centers)
                # calculate center shift
                if self.tol is not None:
                    shift = self._calculate_shift(centers, old_centers, p=self.p_norm)
                    shift_per_cluster = torch.norm(
                        (centers - old_centers), p=self.p_norm, dim=-1
                    )[0, 0]
                    all_shift_per_cluster.append(shift_per_cluster)
                    if (shift < self.tol).all():
                        converged = True
                        if self.verbose:
                            print(
                                f"Full batch converged at iteration "
                                f"{i + 1}/{self.max_iter} "
                                f"with center shifts = "
                                f"{shift.view(-1, self.num_init).mean(-1)}."
                            )
                        break

        # Report non-convergence only when a shift was measured: with tol=None or
        # max_iter=0 there is nothing to report, and an early break means success.
        if self.verbose and not converged and shift is not None:
            print(
                f"Full batch did not converge after {self.max_iter} "
                f"maximum iterations."
                f"\nThere were some center shifts in last iteration "
                f"larger than specified threshold {self.tol}: "
                f"\n{shift.view(-1, self.num_init).mean(-1)}"
            )
            masked_shift_per_cluster = shift_per_cluster.clone()
            masked_shift_per_cluster[masked_shift_per_cluster < self.tol] = 0
            print("Masked shifts per cluster:", masked_shift_per_cluster)
            print(shift_per_cluster.mean(-1), shift)

        if self.num_init > 1:
            # Keep, per instance, the init with the largest total similarity.
            centers[k_mask] = 0
            dist = self._pairwise_distance(x, centers)
            dist[k_mask[:, :, None, :].expand(bs, self.num_init, n, -1)] = float("-inf")
            best_init = torch.argmax(dist.sum(-1).sum(-1), dim=-1)
            b_idx = torch.arange(bs, device=x.device)
            centers = centers[b_idx, best_init].unsqueeze(1)
            k_mask = k_mask[b_idx, best_init].unsqueeze(1)

        # enable (approx.) grad computation in final iteration
        # NOTE: `_cluster_iter` is decorated with `torch.no_grad`, so gradients can
        # only flow through the final `_pairwise_distance` call below (w.r.t. x).
        with torch.enable_grad():
            centers[k_mask] = 0
            centers = self._cluster_iter(x, centers.detach().clone())
            centers[k_mask] = 0
            dist = self._pairwise_distance(x, centers)
            dist = dist.clone()
            # mask probability for non-existing centers
            dist[k_mask[:, :, None, :].expand(bs, 1, n, -1)] = float("-inf")
            soft_assignment = torch.softmax(self.temp * dist, dim=-1)

        dist = dist.squeeze(1)
        centers = centers.squeeze(1)
        soft_assignment = soft_assignment.squeeze(1)

        # Hard assignment via argmax of similarity. `dist` is already (BS, N, K); an
        # extra squeeze here would drop the N axis when N == 1.
        c_assign = torch.argmax(dist, dim=-1)
        all_same = (c_assign == c_assign[:, 0].unsqueeze(-1)).all(-1)
        if all_same.any():
            warnings.warn(
                f"Distance to all cluster centers is the same for instance(s) "
                f"with idx: {all_same.nonzero().squeeze().cpu().numpy().tolist()}. "
                f"Assignment will be random!"
            )
            same_dist = dist[all_same]
            if self.seed is not None:
                gen = torch.Generator(device=x.device)
                gen.manual_seed(self.seed)
            else:
                gen = None
            c_assign[all_same] = torch.randint(
                low=0,
                high=k_max,
                size=same_dist.shape[:-1],
                generator=gen,
                device=x.device,
            )
        # (iterations, K); empty when no shifts were recorded (tol=None or max_iter=0).
        self._cluster_shift_trajectories = (
            torch.stack(all_shift_per_cluster)
            if all_shift_per_cluster
            else x.new_empty((0, centers.shape[-2]))
        )
        return c_assign, centers, dist, soft_assignment


class SparseAutoEncoder(torch.nn.Module):
    """A stack of ``n_inits`` independent single-hidden-layer ReLU autoencoders.

    The leading axis of every parameter indexes the autoencoder:
        _E: encoder weights, (n_inits, dim, hidden_dim)
        _D: decoder weights, (n_inits, hidden_dim, dim)
        _bias_E: encoder bias, (n_inits, hidden_dim)
        _bias_D: decoder bias, (n_inits, dim); it is also subtracted from the input
            before encoding.
    """

    def __init__(self, n_inits: int, dim: int, hidden_dim: int) -> None:
        super().__init__()

        def gaussian_param(shape, divisor_dim, scale=1.0):
            # Standard normal, divided by sqrt(divisor_dim), then multiplied by scale.
            values = torch.randn(*shape, dtype=torch.float32) / np.sqrt(divisor_dim)
            return torch.nn.Parameter(data=values * scale, requires_grad=True)

        # The draw order (E, D, bias_E, bias_D) fixes which random numbers each
        # parameter receives under a given seed.
        self._E = gaussian_param((n_inits, dim, hidden_dim), dim)
        self._D = gaussian_param((n_inits, hidden_dim, dim), hidden_dim)
        self._bias_E = gaussian_param((n_inits, hidden_dim), dim, scale=1e-6)
        self._bias_D = gaussian_param((n_inits, dim), hidden_dim, scale=1e-6)

    def encode(
        self, x: torch.Tensor, include_pre_relu: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Map inputs to non-negative codes.

        Args:
            x: Either (B, dim), shared by all autoencoders, or (B, n_inits, dim).
            include_pre_relu: Also return the pre-activation values.

        Returns:
            Codes (B, n_inits, hidden_dim), or ``(codes, pre_activations)``.

        Raises:
            ValueError: If ``x`` is neither 2D nor 3D.
        """
        if x.ndim not in (2, 3):
            raise ValueError("Invalid shape.")
        if x.ndim == 2:
            # Insert the autoencoder axis so the input broadcasts over all of them.
            x = x.unsqueeze(1)

        centered = x - self._bias_D
        pre_activation = torch.einsum("bid,idh->bih", centered, self._E) + self._bias_E
        codes = torch.relu(pre_activation)
        return (codes, pre_activation) if include_pre_relu else codes

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Map codes back to input space.

        Args:
            x: Either (B, hidden_dim), shared by all autoencoders, or
                (B, n_inits, hidden_dim).

        Returns:
            Reconstructions, (B, n_inits, dim).

        Raises:
            ValueError: If ``x`` is neither 2D nor 3D.
        """
        subscripts = {2: "bh,ihd->bid", 3: "bih,ihd->bid"}.get(x.ndim)
        if subscripts is None:
            raise ValueError("Invalid shape.")
        return torch.einsum(subscripts, x, self._D) + self._bias_D

    def forward(
        self, x: torch.Tensor, include_pre_relu: bool = False
    ) -> Union[
        Tuple[torch.Tensor, torch.Tensor],
        Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    ]:
        """Encode then decode ``x``.

        Returns:
            ``(reconstruction, codes)``, or ``(reconstruction, (codes, pre_activations))``
            when ``include_pre_relu`` is set.
        """
        encoded = self.encode(x, include_pre_relu)
        if not include_pre_relu:
            return self.decode(encoded), encoded
        codes, pre_activation = encoded
        return self.decode(codes), (codes, pre_activation)


class InfiniteIterator:
    """Iterate over ``iterable`` endlessly, restarting it whenever it is exhausted.

    An empty iterable ends iteration immediately. If ``iterable`` is a one-shot
    iterator, iteration ends after its first pass, because ``iter()`` returns the
    same exhausted object.
    """

    def __init__(self, iterable: Iterable) -> None:
        self.iterable = iterable
        self.iterator = iter(iterable)

    def __iter__(self) -> Iterator:
        return self

    def __next__(self) -> Any:
        exhausted = object()
        item = next(self.iterator, exhausted)
        if item is exhausted:
            # Start a new pass; a StopIteration here (nothing to yield) propagates.
            self.iterator = iter(self.iterable)
            item = next(self.iterator)
        return item


class SparseAutoEncoderTrainer:
    """Train a stack of :class:`SparseAutoEncoder` instances with an L1 penalty.

    All autoencoders see the same mini-batches. Each has its own L1 weight, and the
    sum of their individual losses is minimized with a single optimizer.
    """

    def __init__(
        self,
        hidden_dim: int,
        n_steps: int,
        batch_size: int = -1,
        l1_weight: Union[float, Sequence[float]] = 1.0,
        n_inits: Optional[int] = None,
        get_optimizer: Optional[
            Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer]
        ] = torch.optim.Adam,
        reanimation_frequency: int = 10,
        verbose: bool = True,
    ) -> None:
        """
        Args:
            hidden_dim: Number of hidden units of each autoencoder.
            n_steps: Number of optimization steps performed by :meth:`fit`.
            batch_size: Mini-batch size; ``-1`` uses the whole dataset per step.
            l1_weight: Weight of the L1 sparsity penalty. A sequence trains one
                autoencoder per entry; a scalar is shared by all ``n_inits``
                autoencoders (one if ``n_inits`` is not given).
            n_inits: Number of autoencoders. Defaults to the length of
                ``l1_weight``; for a sequence it must match that length.
            get_optimizer: Factory called with the model parameters to build the
                optimizer.
            reanimation_frequency: Every this many steps, re-initialize hidden units
                that are inactive on the whole current batch. ``0`` disables it.
            verbose: Show a progress bar with the current losses.

        Raises:
            ValueError: If ``l1_weight`` is not a scalar or 1D sequence, its length
                does not match ``n_inits``, or ``batch_size`` is invalid.
        """
        self.hidden_dim = hidden_dim

        # as_tensor uniformly handles Python scalars, sequences, NumPy arrays and
        # tensors, including 0-d tensors (which have __len__ but no length).
        l1_weight = torch.as_tensor(l1_weight, dtype=torch.float32).detach().clone()
        if l1_weight.ndim > 1:
            raise ValueError("l1_weight must be a scalar or a 1D sequence.")
        if l1_weight.ndim == 0:
            # A scalar weight is shared by all autoencoders.
            l1_weight = l1_weight.reshape(1).repeat(1 if n_inits is None else n_inits)

        if n_inits is None:
            n_inits = l1_weight.shape[0]
        elif l1_weight.shape[0] != n_inits:
            raise ValueError("Shape of l1_weight does not match n_inits.")

        if batch_size < -1 or batch_size == 0:
            raise ValueError("Invalid batch_size. Must be -1 or > 0.")

        self.l1_weight = l1_weight
        self.n_inits = n_inits
        self.n_steps = n_steps
        self.verbose = verbose
        self.batch_size = batch_size
        self.get_optimizer = get_optimizer
        self.reanimation_frequency = reanimation_frequency
        self.model: Optional[SparseAutoEncoder] = None

    def fit(self, x: torch.Tensor, finetune: bool = False) -> None:
        """Train the autoencoders on ``x`` for ``n_steps`` optimization steps.

        Args:
            x: Training data, (N, dim); the last axis is used as the input dimension.
            finetune: Continue training the existing model (keeping all of its
                parameters) instead of creating a new one. A new optimizer is
                created either way.

        Raises:
            ValueError: If ``finetune`` is requested before any model was trained.
        """
        # Fail fast, before a progress bar or data loader is created.
        if finetune and self.model is None:
            raise ValueError("Cannot finetune without a model. Call fit() first.")

        batch_size = self.batch_size if self.batch_size != -1 else x.shape[0]

        pbar = range(self.n_steps)
        if self.verbose:
            pbar = tqdm.tqdm(pbar, desc="Training", leave=False)

        data_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(x), batch_size=batch_size, shuffle=True
        )
        data_iter = InfiniteIterator(data_loader)

        if finetune:
            self.model = self.model.to(x.device)
        else:
            self.model = SparseAutoEncoder(
                self.n_inits, x.shape[-1], self.hidden_dim
            ).to(x.device)
            # Start the decoder bias at the data mean. Done only for fresh models so
            # that finetuning keeps the learned bias.
            self.model._bias_D.data = x.mean(0)[None].repeat(self.n_inits, 1)

        optimizer = self.get_optimizer(self.model.parameters())

        l1_weight = self.l1_weight.to(x.device)

        for i, (x_batch,) in zip(pbar, data_iter):
            x_hat_batch, z_batch = self.model(x_batch)
            # Per-autoencoder squared reconstruction error, averaged over the batch.
            reconstruction_losses = (
                ((x_hat_batch - x_batch[:, None]) ** 2).sum(-1).mean(0)
            )
            # Codes are ReLU outputs (non-negative), so this is their mean L1 norm.
            sparsity_losses = z_batch.sum(-1).mean(0)

            # Every `reanimation_frequency` steps, reset units dead on this batch.
            if self.reanimation_frequency and i % self.reanimation_frequency == 0:
                self._reanimate_dead_units(z_batch)

            losses = reconstruction_losses + l1_weight * sparsity_losses
            loss = losses.sum()
            self.model.zero_grad()
            loss.backward()
            optimizer.step()

            # Only build the description (which syncs values to the host) when shown.
            if self.verbose:
                pbar.set_description(
                    f"Loss: ({self._format_values(losses)}). "
                    f"Reconstruction: ({self._format_values(reconstruction_losses)}). "
                    f"Sparsity: ({self._format_values(sparsity_losses)})"
                )

    def _reanimate_dead_units(self, z_batch: torch.Tensor) -> None:
        """Re-initialize hidden units whose code is zero for every sample of a batch.

        Overwrites, in place, the matching encoder columns, decoder rows and encoder
        biases with small Gaussian values.

        Args:
            z_batch: Codes of the current batch, (B, n_inits, hidden_dim).
        """
        dead_units_mask = z_batch.mean(0) == 0  # (n_inits, hidden_dim)
        if not torch.any(dead_units_mask):
            return
        print(f"Warning: {dead_units_mask.sum(-1).cpu().numpy()} units are dead.")
        # TODO(zimmerrol): Check whether we really have to re-initialize the
        #   weights of the encoder.
        model = self.model
        n_dead = dead_units_mask.sum()
        device = z_batch.device
        model._E.data.transpose(1, 2)[dead_units_mask] = (
            torch.randn(n_dead, model._E.shape[1], dtype=torch.float32, device=device)
            / np.sqrt(model._E.shape[2])
            * 1e-4
        )
        model._D.data[dead_units_mask] = (
            torch.randn(n_dead, model._D.shape[2], dtype=torch.float32, device=device)
            / np.sqrt(model._D.shape[1])
            * 1e-4
        )
        model._bias_E.data[dead_units_mask] = (
            torch.randn(n_dead, dtype=torch.float32, device=device)
            / np.sqrt(model._E.shape[2])
            * 1e-10
        )

    @staticmethod
    def _format_values(values: torch.Tensor) -> str:
        """Format a 1D tensor as comma-separated numbers with two decimals."""
        return ", ".join(f"{v:.2f}" for v in values.tolist())
