'''
=====
- Associated publication:
url: 
doi: 
github: 
=====
'''
import torch
import torch.nn as nn
import logging
import numpy as np
from typing import Optional
from .embedding_model import EmbeddingModel
from torch.autograd import Variable

class CylinderPCAEmbedding(EmbeddingModel):
    """PCA embedding model for the 2D flow around a cylinder system

    The state is flattened, normalized with the ``mu``/``std`` buffers and
    projected onto ``n_pca`` principal components stored in the ``_v`` buffer.
    The principal components are initialized to zero and are expected to be
    assigned through the :attr:`pc` property before the model is used.

    Args:
        config (:class:`config.configuration_phys.PhysConfig`) Configuration class with transformer/embedding parameters
        n_pca (int, optional): Number of principal components to save, defaults to config.n_embd
    """
    model_name = "embedding_cylinder_pca"

    def __init__(self, config, n_pca: Optional[int] = None) -> None:
        """Constructor method
        """
        super().__init__(config)
        # Cast to a plain int; np.prod returns a NumPy scalar
        self.n_dim = int(np.prod(config.state_dims))
        if n_pca is None:
            n_pca = config.n_embd
        self.n_pca = n_pca
        # Pytorch and Numpy x, y are reversed: x = ncols, y = nrows, so the
        # mesh arrays (and the mask) have shape [64, 128]
        X, Y = np.meshgrid(np.linspace(-2, 14, 128), np.linspace(-4, 4, 64))
        # True for grid points strictly inside the unit circle at the origin
        # (presumably the cylinder interior). Kept as a plain attribute so the
        # state_dict layout is unchanged; it is moved to the right device on use.
        self.mask = torch.tensor(np.sqrt(X ** 2 + Y ** 2) < 1, dtype=torch.bool)
        # Principal components
        self.register_buffer("_v", torch.zeros(self.n_dim, n_pca))  # [n x k]

        # Normalization occurs inside the model
        self.register_buffer('mu', torch.tensor(0.))
        self.register_buffer('std', torch.tensor(1.))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass; embeds the state variables

        Args:
            x (torch.Tensor): [B, 3, H, W] Input feature tensor

        Returns:
            (torch.Tensor): [B, n_pca] Koopman observables
        """
        return self.embed(x)

    def embed(self, x: torch.Tensor) -> torch.Tensor:
        """Embeds tensor of state variables using PCA

        Args:
            x (torch.Tensor): [B, 3, H, W] Input feature tensor

        Returns:
            (torch.Tensor): [B, n_pca] Koopman observables
        """
        # reshape (rather than view) also accepts non-contiguous inputs
        x = x.reshape(x.size(0), -1)
        assert x.shape[-1] == self.n_dim, 'State dimensions of the input tensor not correct.'
        x = self._normalize(x)
        # Project the whole batch at once: [B, n] x [n, k] -> [B, k]
        return torch.mm(x, self._v)

    def recover(self, g: torch.Tensor) -> torch.Tensor:
        """Recovers approximate feature tensor from PCA projection

        Args:
            g (torch.Tensor): [B, n_pca] Koopman observables

        Returns:
            (torch.Tensor): [B, 3, H, W] Physical feature tensor

        Note:
            See this `stack-exchange answer <https://stats.stackexchange.com/questions/229092/how-to-reverse-pca-and-reconstruct-original-variables-from-several-principal-com>`_ for a good explanation if you are not completely familiar with PCA.
        """
        g = g.reshape(g.size(0), -1)
        assert g.shape[-1] == self.n_pca, 'Embedded dimensions of the input tensor not correct.'
        # Back-project the whole batch at once: [B, k] x [k, n] -> [B, n]
        out = torch.mm(g, self._v.T)
        # Reshape to flow domain; input_dims comes from the base class
        # (not visible here), list() guards against it being a tuple
        out = out.reshape([-1] + list(self.input_dims))
        out = self._unnormalize(out)
        # Apply mask: the [H, W] mask broadcasts over batch and channel dims.
        # The mask is not a registered buffer, so move it to the output device.
        mask = self.mask.to(out.device)
        return out.masked_fill(mask, 0)

    @property
    def pc(self) -> torch.Tensor:
        """Principal components

        Returns:
            torch.Tensor: Current principal components, shape [n_dim, n_pca]
        """
        return self._v

    @pc.setter
    def pc(self, v0: torch.Tensor) -> None:
        assert v0.size(0) == self.n_dim, 'Mismatch of the state dimension of the principle components.'
        assert v0.size(1) == self.n_pca, 'Mismatch of the number of principle components.'
        # Copy the values into the existing buffer instead of rebinding its
        # storage: this keeps the buffer's device/dtype and avoids aliasing
        # the tensor owned by the caller.
        with torch.no_grad():
            self._v.copy_(v0)

    def _normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Shifts and scales ``x`` with the ``mu`` and ``std`` buffers"""
        x = (x - self.mu) / self.std
        return x

    def _unnormalize(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of :meth:`_normalize`"""
        return self.std*x + self.mu