import torch
from torch import nn

from utils import get_unique_num


class KoopmanLayer(nn.Module):
    """Linear Koopman dynamics fitted in closed form over a batch of latent sequences.

    ``forward`` solves a least-squares problem for a single ``k_dim x k_dim``
    operator ``Ct`` that maps each latent frame to the next, shared across the
    whole batch. ``loss`` scores the reconstruction/prediction and constrains the
    eigenvalues of ``Ct``: a "static" subset is pushed toward 1, and the real
    parts of the remaining "dynamic" eigenvalues are penalised above a threshold.
    """

    def __init__(self, args):
        """Read hyper-parameters from ``args``.

        Args:
            args: namespace-like object providing ``n_frames``, ``k_dim``,
                ``static_size`` and ``dynamic_thresh``.
        """
        super().__init__()

        self.args = args
        self.n_frames = args.n_frames

        # Koopman hyper-parameters
        self.k_dim = args.k_dim
        self.static = args.static_size

        # loss functions
        self.loss_func = nn.MSELoss()
        # Threshold(t, 0): values > t pass through unchanged, all others become 0
        self.dynamic_threshold_loss = nn.Threshold(args.dynamic_thresh, 0)

    def forward(self, Z):
        """Fit the Koopman operator on ``Z`` and return one-step predictions.

        Args:
            Z: latent codes laid out batch-major, then frame, then feature
                (the original layout note is ``(b * t, c, 1, 1)``). Only the
                element order and total count matter: the count must be
                divisible by ``n_frames * k_dim``.

        Returns:
            Tuple ``(Z2, Ct)``. ``Z2`` has ``Z``'s shape. Its first frame of every
            sequence is copied from ``Z``, and each later frame is ``Z``'s
            previous frame multiplied by ``Ct``. ``Ct`` is the
            ``(k_dim, k_dim)`` operator.

        Raises:
            RuntimeError: if the predicted frames contain NaN.
        """
        # reshape alone is enough; it only depends on element order/count
        Zr = Z.reshape(-1, self.n_frames, self.k_dim)

        # inputs are frames 0..T-2, targets are frames 1..T-1
        X, Y = Zr[:, :-1], Zr[:, 1:]

        # least-squares Ct minimising ||X Ct - Y|| over all (batch, time) pairs
        Ct = torch.linalg.pinv(X.reshape(-1, self.k_dim)) @ Y.reshape(-1, self.k_dim)

        # one-step predictions; X @ Ct broadcasts over the batch dimension
        Y2 = X @ Ct
        # explicit raise (not assert) so the check survives `python -O`
        if torch.isnan(Y2).any():
            raise RuntimeError("KoopmanLayer: predicted latents contain NaN")

        Z2 = torch.cat((X[:, :1], Y2), dim=1)
        return Z2.reshape(Z.shape), Ct

    def loss(self, X_dec, X_dec2, Z, Z2, Ct):
        """Compute the ambient, latent and eigenvalue losses.

        Args:
            X_dec, X_dec2: tensors compared with MSE (presumably decodings of
                ``Z`` and ``Z2`` -- confirm with caller).
            Z, Z2: latents and their predictions as returned by ``forward``.
            Ct: square Koopman operator returned by ``forward``.

        Returns:
            ``(E1, E2, E3)``. ``E1`` is the MSE between ``X_dec`` and ``X_dec2``,
            ``E2`` is the MSE between ``Z`` and ``Z2``, and ``E3`` is the
            static/dynamic eigenvalue penalty.
        """
        # predict ambient
        E1 = self.loss_func(X_dec, X_dec2)

        # predict latent
        E2 = self.loss_func(Z, Z2)

        # Koopman operator constraints (disentanglement)
        D = torch.linalg.eigvals(Ct)
        Dr = torch.real(D)
        # |lambda - 1|: distance of every eigenvalue from 1 in the complex plane
        Db = torch.sqrt((Dr - 1) ** 2 + torch.imag(D) ** 2)

        # eigenvalue indices ordered from closest to farthest from 1
        I = torch.argsort(Db)
        # how many of those are treated as static; decided by utils.get_unique_num
        # from the eigenvalues, the reversed ordering and the configured static size
        new_static_number = get_unique_num(D, torch.flip(I, dims=[0]), self.static)

        E3 = self._eigen_penalty(Db, Dr, I, new_static_number)
        return E1, E2, E3

    def _eigen_penalty(self, Db, Dr, I, n_static):
        """Return the static + dynamic eigenvalue penalty (E3).

        The first ``n_static`` indices of ``I`` are static. Their distances to 1
        (``Db``) are pushed toward 0 with MSE. The remaining indices are dynamic.
        Their real parts (``Dr``) that exceed ``dynamic_thresh`` are averaged as
        a penalty. An empty group contributes 0 rather than the NaN produced by
        reducing over zero elements.
        """
        Is, Id = I[:n_static], I[n_static:]

        Dbs = torch.index_select(Db, 0, Is)
        if Dbs.numel() > 0:
            # zeros_like keeps the target on Dbs' dtype and device
            E3_static = self.loss_func(Dbs, torch.zeros_like(Dbs))
        else:
            E3_static = Db.new_zeros(())

        Drd = torch.index_select(Dr, 0, Id)
        if Drd.numel() > 0:
            E3_dynamic = torch.mean(self.dynamic_threshold_loss(Drd))
        else:
            E3_dynamic = Dr.new_zeros(())

        return E3_static + E3_dynamic
