"""Main class of our algorithmic framework.
It performs a forward pass through the layers specified in its ``__init__``. """

import torch
from torch import nn

from spodnet.models.perturbation_layers import UpdateTheta_Perturbations


class SpodNet(nn.Module):
    """Unrolled algorithm: applies one update layer ``K`` times in sequence.

    Args:
        K: number of unrolled iterations performed by ``forward``.
        p: matrix dimension, passed to ``UpdateTheta_Perturbations``.
        layer_type: one of ``'UBG_masks'``, ``'PNP_masks'`` or
            ``'E2E_masks'``; selects the ``theta_12_generator`` passed to
            ``UpdateTheta_Perturbations`` (``'UBG'``, ``'PNP'``, ``'E2E'``).
        zeta: scalar stored as a 1-element tensor on ``device``; it is not
            used inside this class (presumably consumed elsewhere — verify).
        device: device for ``zeta``; also passed to the update layer.

    Raises:
        ValueError: if ``layer_type`` is not one of the supported values.
    """

    # Maps each accepted ``layer_type`` to its ``theta_12_generator`` name.
    _GENERATORS = {
        'UBG_masks': 'UBG',
        'PNP_masks': 'PNP',
        'E2E_masks': 'E2E',
    }

    def __init__(self, K, p, layer_type, zeta, device):
        super().__init__()

        self.K = K
        self.p = p

        self.layer_type = layer_type

        self.device = device

        # NOTE(review): a plain attribute, not a registered buffer, so
        # ``module.to(...)`` will not move it to another device.
        self.zeta = torch.tensor([zeta]).to(device)

        try:
            generator = self._GENERATORS[layer_type]
        except KeyError:
            # Fail fast: an unknown value would otherwise leave
            # ``forward_stack`` unset and only fail later inside ``forward``.
            raise ValueError(
                f"Unknown layer_type {layer_type!r}; expected one of "
                f"{sorted(self._GENERATORS)}.") from None

        print(f"Learning in {generator} mode.")
        self.forward_stack = UpdateTheta_Perturbations(
            self.p, theta_12_generator=generator, device=device)

    def forward(self, S):
        """Run ``K`` unrolled update steps starting from ``pinv(S + I)``.

        Args:
            S: tensor whose last two dimensions form square matrices
                (leading batch dimensions are broadcast over); presumably
                symmetric empirical covariance matrices — TODO confirm.

        Returns:
            ``(Theta, Theta_list)``: the final iterate and the list of all
            ``K`` iterates. When ``K == 0`` the list is empty and ``Theta``
            is the initial estimate ``pinv(S + I)``.
        """
        # Build the identity directly with S's dtype/device (no CPU
        # allocation + transfer), broadcast over any leading batch dims.
        eye = torch.eye(
            S.shape[-1], dtype=S.dtype, device=S.device).expand_as(S)
        Theta = torch.linalg.pinv(S + eye, hermitian=True)
        # W is taken as the pseudo-inverse of Theta (rather than S + I),
        # presumably so that W and Theta are numerically consistent.
        W = torch.linalg.pinv(Theta, hermitian=True)

        # The update layer reads W and S as fixed, gradient-free inputs.
        self.forward_stack.W = W.clone().detach()
        self.forward_stack.S = S.clone().detach()

        Theta_list = []
        for _ in range(self.K):
            Theta = self.forward_stack(Theta)
            Theta_list.append(Theta)

        return Theta, Theta_list
