from logging import Logger
from typing import Literal, Union,Callable,TypedDict
from numpy.typing import ArrayLike

import numpy as np
import torch

from src.utils import sqrtmh


def augment(X: ArrayLike, memory_length: int = 1, backend: Literal['torch', 'numpy'] = 'numpy') -> ArrayLike:
    """Build a delay (time-lagged) embedding of a trajectory.

    Row ``t`` of the result is the concatenation of the flattened samples
    ``X[t], X[t + 1], ..., X[t + memory_length]``, so the output has
    ``X.shape[0] - memory_length`` rows and ``(memory_length + 1)`` times the
    number of per-sample features as columns. Both backends produce the same
    layout.

    Args:
        X: Trajectory whose first axis indexes time. A NumPy-convertible
            array for the ``'numpy'`` backend, a ``torch.Tensor`` for the
            ``'torch'`` backend.
        memory_length: Number of additional consecutive samples appended to
            each row.
        backend: ``'numpy'`` or ``'torch'``.

    Returns:
        The 2-D delay-embedded array/tensor. For an unsupported backend a
        warning is logged and ``X`` is returned unchanged.
    """
    if backend == 'numpy':
        X = np.asarray(X)
        window = X.shape[0] - memory_length
        # Shape (memory_length + 1, *features, window): entry [k, ..., t] is X[t + k].
        lagged = np.lib.stride_tricks.sliding_window_view(X, window_shape=window, axis=0)
        # Transpose + Fortran-order reshape puts time first and lays the lags
        # out one after another within each row.
        return lagged.T.reshape(window, -1, order='F')

    if backend == 'torch':
        window = X.shape[0] - memory_length
        # Same (memory_length + 1, *features, window) layout as the numpy view.
        lagged = X.unfold(0, window, 1)
        # Bring the time axis to the front, then flatten (lag, *features)
        # row-major so the result matches the numpy branch element for element.
        return lagged.movedim(-1, 0).reshape(window, -1)

    # Local import: the module only imports the Logger class, which cannot be
    # used to emit records without an instance.
    import logging

    logging.getLogger(__name__).warning("Warning: Backend %s is not supported.", backend)
    return X

def polynomial_feature_map(X, order=5):
    """Stack element-wise powers of the flattened samples.

    Each sample ``X[i]`` is flattened to a vector and the feature vector is
    ``[x, x**2, ..., x**order]`` concatenated along the feature axis, so the
    result has ``order`` times as many columns as there are per-sample
    entries.

    Args:
        X: Array whose first axis indexes samples.
        order: Highest power to include. Values below 1 are treated as 1,
            i.e. only the flattened input is returned.

    Returns:
        2-D array of shape ``(X.shape[0], order * features)``.
    """
    flat = X.reshape(X.shape[0], -1)
    # Degrees 1..order, each exactly once (degree 1 is the input itself).
    degrees = range(1, max(order, 1) + 1)
    return np.concatenate([flat ** degree for degree in degrees], axis=1)


class LinearOperator(TypedDict):
    """Dictionary bundle describing a linear operator through inner products."""

    # Human-readable label of the operator.
    name: str
    # Scalar used to scale the singular values and normalise the loss terms.
    norm: float
    # Callable invoked as f(U) or f(U, V); returns a matrix of inner products.
    metric_inner_product: Callable
    # Callable invoked as f(U, V); returns one value per feature.
    operator_inner_product: Callable

def TransferOperator(name:str = 'Transfer Operator')->LinearOperator:
    """Build the ``LinearOperator`` description of a transfer operator.

    The returned dictionary has unit ``norm`` and two closures:

    * ``metric_inner_product(U, V=None)`` - the matrix ``U^T V / n`` (with
      ``V`` defaulting to ``U``), where ``n = U.shape[0]``.
    * ``operator_inner_product(U, V)`` - for every feature column, the
      average product of ``U`` at row ``t`` with ``V`` at row ``t + 1``.

    Args:
        name: Label stored under the ``'name'`` key.

    Returns:
        The populated ``LinearOperator`` dictionary.
    """

    def metric_inner_product(U:ArrayLike,V:Union[ArrayLike,None]=None)->ArrayLike:
        # Gram matrix of U with itself unless a second operand is supplied.
        right = U if V is None else V
        return torch.einsum('ki,kj->ij', U, right) / U.shape[0]

    def operator_inner_product(U:ArrayLike,V:ArrayLike)->ArrayLike:
        # Pair each row of U with the following row of V; there are n - 1 pairs.
        n_pairs = U.shape[0] - 1
        return torch.einsum('ik,ik->k', U[:-1, :], V[1:, :]) / n_pairs

    return {
        'name': name,
        'norm': 1.,
        'metric_inner_product': metric_inner_product,
        'operator_inner_product': operator_inner_product,
    }


class SimpleMLP(torch.nn.Module):
    """Fully connected embedding whose first output feature is the constant 1.

    The network maps inputs of width ``layer_dims[0]`` through
    ``len(layer_dims) - 1`` bias-free linear layers, then a final linear layer
    with bias to ``feature_dim`` outputs. Every linear layer, including the
    last one, is followed by ``activation()``.

    Args:
        feature_dim: Number of output features.
        layer_dims: Widths of the network, starting with the input width. Any
            sequence (list, tuple, ...) with at least one entry.
        data_shape: Not used by this class; accepted so the constructor takes
            the same keyword arguments as the other embeddings in this module.
        activation: Zero-argument factory (typically a ``torch.nn.Module``
            subclass) producing the non-linearity.
    """

    def __init__(
        self, feature_dim: int, layer_dims: list[int], data_shape: list[int], activation=torch.nn.LeakyReLU
    ):
        super().__init__()
        self.activation = activation
        # list(...) so tuples and other sequences are accepted as well.
        widths = list(layer_dims) + [feature_dim]

        layers = []
        # Hidden layers: consecutive width pairs, excluding the output layer.
        for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
            layers.append(torch.nn.Linear(fan_in, fan_out, bias=False))
            layers.append(activation())

        # Output layer is the only one carrying a bias term.
        layers.append(torch.nn.Linear(widths[-2], widths[-1], bias=True))
        layers.append(activation())

        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x):
        """Embed ``x`` and overwrite the first output feature with 1.

        Args:
            x: Tensor whose last axis has size ``layer_dims[0]``; any number of
                leading (batch) axes is allowed.

        Returns:
            Tensor with the same leading axes and ``feature_dim`` entries on
            the last axis, the first of which is identically 1.
        """
        for layer in self.layers:
            x = layer(x)

        # ones_like keeps the dtype and device of the activations, so the
        # concatenation never upcasts or moves the output.
        constant = torch.ones_like(x[..., :1])
        return torch.cat([constant, x[..., 1:]], dim=-1)


class SimpleCNN(torch.nn.Module):
    """Small convolutional embedding followed by one linear projection.

    Each entry of ``layer_dims`` adds one block:

    * ``'1d'``: ``Conv1d(kernel_size=3, stride=1, padding=1)`` + activation.
    * ``'2d'``: ``Conv2d(kernel_size=3, stride=1, padding=1)`` + activation +
      ``MaxPool2d(kernel_size=2, stride=2)``.

    The output of the last block is flattened and mapped to ``feature_dim``
    features by a linear layer.
    """

    def __init__(self, feature_dim: int, layer_dims: list[int], data_shape: list[int], conv_type: str = "1d",
                 activation=torch.nn.ReLU):
        """
        Args:
            feature_dim (int): Size of the output embedding.
            layer_dims (list): Number of filters in each convolutional layer.
            data_shape (list): Shape of one sample, without the batch axis:
                ``(in_channels, length)`` for '1d' and
                ``(in_channels, height, width)`` for '2d'.
            conv_type (str): Type of convolution ('1d' or '2d'). Defaults to '1d'.
            activation: Zero-argument factory for the activation (defaults to ReLU).

        Raises:
            ValueError: If ``conv_type`` is neither '1d' nor '2d'.
        """
        super().__init__()
        # Validate up front so an invalid value is rejected even when
        # layer_dims is empty and no block is ever built.
        if conv_type not in ("1d", "2d"):
            raise ValueError("conv_type must be either '1d' or '2d'")
        self.data_shape = data_shape
        self.conv_type = conv_type

        # Convolutional blocks; channels chain from one block to the next.
        self.conv_layers = torch.nn.ModuleList()
        in_channels = data_shape[0]
        for out_channels in layer_dims:
            self.conv_layers.append(self._make_conv_block(in_channels, out_channels, activation))
            in_channels = out_channels

        # Size of the flattened feature map, measured with a dummy sample.
        self.flattened_size = self._get_flattened_size()

        # Fully connected layer producing the embedding.
        self.fc = torch.nn.Linear(self.flattened_size, feature_dim)

    def _make_conv_block(self, in_channels, out_channels, activation):
        """Return one conv block for the configured ``conv_type``."""
        if self.conv_type == "1d":
            return torch.nn.Sequential(
                torch.nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                activation(),
            )
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            activation(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def _get_flattened_size(self):
        """
        Helper function to calculate the size of the flattened feature map.

        Runs a single all-zero sample of shape ``data_shape`` through the
        convolutional blocks and returns the number of resulting elements.
        """
        # Only the leading entries of data_shape that the conv type consumes.
        n_sample_dims = 2 if self.conv_type == "1d" else 3
        with torch.no_grad():
            probe = torch.zeros(1, *self.data_shape[:n_sample_dims])
            for layer in self.conv_layers:
                probe = layer(probe)
            return torch.flatten(probe, start_dim=1).shape[1]

    def forward(self, x):
        """
        Forward pass for the CNN.

        Args:
            x (torch.Tensor): Input of shape ``(batch_size, *data_shape)``, i.e.
                ``(batch_size, in_channels, length)`` for '1d' or
                ``(batch_size, in_channels, height, width)`` for '2d'.

        Returns:
            torch.Tensor: Output embedding of shape (batch_size, feature_dim).
        """
        # Apply convolutional blocks.
        for layer in self.conv_layers:
            x = layer(x)

        # Flatten everything but the batch axis; flatten also copes with
        # non-contiguous tensors, which view() would reject.
        x = torch.flatten(x, start_dim=1)

        # Linear projection to the embedding.
        return self.fc(x)

class SingularValueEmbedding(torch.nn.Module):
    """Two embeddings ``u`` and ``v`` trained as singular functions of an operator.

    ``u`` and ``v`` are built from the same ``embedding`` factory (and are the
    very same module when ``symmetric`` is true). ``regularized_loss`` provides
    the training objective; ``fit`` post-processes a trajectory into singular
    values and linear maps that ``input_fmap`` / ``output_fmap`` apply on top
    of ``v`` / ``u``.
    """

    def __init__(self,
                 embedding: Callable[..., torch.nn.Module], operator: "LinearOperator",
                 feature_dim: int, layer_dims: list[int], data_shape: list[int], activation=torch.nn.Tanh,
                 centered: bool = False, symmetric: bool = False, learnable_sigma: bool = True, whitening: bool = True
    ):
        """
        Args:
            embedding: Factory called with the keyword arguments
                ``feature_dim``, ``layer_dims``, ``data_shape`` and
                ``activation``; returns the embedding module.
            operator: ``LinearOperator`` dictionary supplying ``'norm'``,
                ``'metric_inner_product'`` and ``'operator_inner_product'``.
            feature_dim: Number of features produced by each embedding.
            layer_dims: Forwarded to ``embedding``.
            data_shape: Forwarded to ``embedding``.
            activation: Forwarded to ``embedding``.
            centered: Subtract the batch mean from the embedded features.
            symmetric: Share one module between ``u`` and ``v``.
            learnable_sigma: Learn the singular values as a parameter instead
                of reading them from ``operator_inner_product``.
            whitening: Whether ``fit`` whitens the features and solves for
                singular vectors (otherwise identity maps are stored).
        """
        super().__init__()
        self.embedding = embedding
        self.feature_dim = feature_dim
        self.centered = centered
        self.operator = operator
        self.learnable_sigma = learnable_sigma
        self.symmetric = symmetric
        self.whitening = whitening
        # Populated by fit(); None marks the model as not fitted yet.
        self.lsvecs = None
        self.rsvecs = None

        def build():
            return embedding(feature_dim=feature_dim, layer_dims=layer_dims,
                             data_shape=data_shape, activation=activation)

        self.u = build()
        self.v = self.u if symmetric else build()
        if learnable_sigma:
            self.log_sqrt_sigma = torch.nn.Parameter(torch.zeros(feature_dim, dtype=torch.float32))

    def sigma(self):
        """Return the learnable singular values, ``exp(-log_sqrt_sigma**2) * norm``.

        Only available when the model was built with ``learnable_sigma=True``
        (the parameter does not exist otherwise).
        """
        return torch.exp(-self.log_sqrt_sigma**2) * self.operator['norm']

    def _orthonormality_penalty(self, Z, perm, correction):
        """Return the split-batch estimate of ``||Z^*Z - I||_F^2`` for features ``Z``."""
        gram = self.operator['metric_inner_product']
        # Permute once and cut into two disjoint halves; using independent
        # halves for the two Gram factors keeps the quadratic term unbiased.
        shuffled = Z[perm]
        split = perm.shape[0] // 2
        gram_a = gram(shuffled[:split])
        gram_b = gram(shuffled[split:])

        penalty = torch.einsum('ij,ji->', gram_a, gram_b) * (correction**2)
        penalty = penalty - (torch.trace(gram_a) + torch.trace(gram_b)) * correction
        penalty = penalty + self.feature_dim
        if self.centered:
            # Inner product of every feature with the constant function;
            # new_ones keeps the device and dtype of Z.
            penalty = penalty + (gram(Z, Z.new_ones(Z.shape[0], 1))**2).sum()
        return penalty

    def regularized_loss(self, x: torch.Tensor, on_weigth: float = 0.1, requires_grad=False):
        """Compute the training loss on a batch.

        Args:
            x: Batch of inputs; the first axis indexes samples. Its
                ``requires_grad`` flag is overwritten with ``requires_grad``.
            on_weigth: Weight of the orthonormality penalty in the total loss.
            requires_grad: Value assigned to ``x.requires_grad``.

        Returns:
            Dict with keys ``"total"`` (``svd + on_weigth * ortho-normal``),
            ``"svd"`` and ``"ortho-normal"``.
        """
        x.requires_grad = requires_grad
        # Embed data.
        U = self.u(x)
        V = self.v(x)
        correction = 1.
        if self.centered:
            U = U - torch.mean(U, dim=0, keepdim=True)
            V = V - torch.mean(V, dim=0, keepdim=True)
            # n / (n - 1) factor applied to the centered second moments.
            correction = U.shape[0] / (U.shape[0] - 1)

        gram = self.operator['metric_inner_product']
        norm = self.operator['norm']
        cross = self.operator['operator_inner_product'](U, V)
        s = self.sigma() if self.learnable_sigma else cross

        # tr(Sigma U^*U Sigma V^*V)
        loss_1 = torch.einsum('i,ij,j,ji->', s, gram(U), s, gram(V)) * (correction**2) / norm

        # -2 tr(V^* Operator U Sigma)
        loss_2 = -2 * (cross * s).sum() * correction / norm

        # Orthonormality term ||U^*U - I||_F^2 + ||V^*V - I||_F^2; the same
        # permutation is used for U and V.
        perm = torch.randperm(x.shape[0])
        loss_3 = self._orthonormality_penalty(U, perm, correction)
        if not self.symmetric:
            loss_3 = loss_3 + self._orthonormality_penalty(V, perm, correction)

        return {"total": loss_1 + loss_2 + on_weigth*loss_3, "svd": loss_1 + loss_2, "ortho-normal": loss_3}

    def fit(self, x: torch.Tensor, requires_grad=False):
        """Estimate singular values and feature maps from a trajectory.

        ``u`` is evaluated on ``x[:-1]`` and ``v`` on ``x[1:]``. The results
        are stored on the instance as ``U_mean``, ``V_mean``, ``svals``,
        ``lsvecs`` and ``rsvecs``.

        Args:
            x: Trajectory; the first axis indexes consecutive samples. Its
                ``requires_grad`` flag is overwritten with ``requires_grad``.
            requires_grad: Value assigned to ``x.requires_grad``.
        """
        x.requires_grad = requires_grad
        # Embed data.
        U = self.u(x[:-1])
        V = self.v(x[1:])
        correction = 1.
        if self.centered:
            self.U_mean = torch.mean(U, dim=0, keepdim=True)
            U = U - self.U_mean
            self.V_mean = torch.mean(V, dim=0, keepdim=True)
            V = V - self.V_mean
            correction = U.shape[0] / (U.shape[0] - 1)
        else:
            # new_zeros keeps the device and dtype of the embeddings.
            self.U_mean = U.new_zeros((1, U.shape[1]))
            self.V_mean = V.new_zeros((1, V.shape[1]))

        if self.learnable_sigma:
            s = self.sigma()
        else:
            # NOTE(review): U and V are already shifted by one step here; an
            # operator_inner_product that shifts internally (as the one built
            # by TransferOperator does) then pairs samples two steps apart -
            # confirm this is intended.
            s = self.operator['operator_inner_product'](U, V)

        if not self.whitening:
            self.svals = s
            self.lsvecs = torch.eye(U.shape[1], dtype=U.dtype, device=U.device)
            self.rsvecs = torch.eye(V.shape[1], dtype=V.dtype, device=V.device)
            return

        # Scale every feature by sqrt(s) (same as right-multiplying by diag).
        scale = torch.sqrt(s)
        U = U * scale
        V = V * scale

        Cu = correction * (U.T @ U) / U.shape[0]
        Cv = correction * (V.T @ V) / V.shape[0]
        Cuv = correction * (U.T @ V) / U.shape[0]

        # Whitened cross-covariance; sqrtmh is the helper imported from src.utils.
        Cu_sqrt_inv = torch.linalg.pinv(sqrtmh(Cu))
        Cv_sqrt_inv = torch.linalg.pinv(sqrtmh(Cv))
        M = Cu_sqrt_inv @ Cuv @ Cv_sqrt_inv

        # Direct SVD yields the singular values themselves (the eigenvalues of
        # M M^T are their squares) and needs no division by small values.
        left, svals, right_h = torch.linalg.svd(M)
        # svd returns descending order; flip to ascending, the order in which
        # eigh-based code reports the spectrum.
        self.svals = svals.flip(0)
        self.lsvecs = Cu_sqrt_inv @ left.flip(1)
        self.rsvecs = Cv_sqrt_inv @ right_h.T.flip(1)

    def _require_fitted(self):
        """Raise if ``fit`` has not populated the singular vectors yet."""
        if self.lsvecs is None or self.rsvecs is None:
            raise RuntimeError("SingularValueEmbedding.fit must be called before evaluating feature maps.")

    def input_fmap(self, x):
        """Return ``(v(x) - V_mean) @ rsvecs`` using the quantities stored by ``fit``."""
        self._require_fitted()
        return (self.v(x) - self.V_mean) @ self.rsvecs

    def output_fmap(self, x):
        """Return ``(u(x) - U_mean) @ lsvecs`` using the quantities stored by ``fit``."""
        self._require_fitted()
        return (self.u(x) - self.U_mean) @ self.lsvecs