from typing import Iterable, List, Tuple

import torch

class DebugContainer(torch.nn.Module):
    """
    Collect snapshots of the Z, W and mask tensors recorded during the ADMM process.

    For every tracked parameter name the container keeps three parallel histories
    (Z, W and mask). Each history is a ``torch.nn.ParameterList`` that grows by one
    entry per ``addTensor`` call. The histories are registered as submodules, so the
    recorded tensors appear in ``state_dict()`` and ``parameters()`` and follow
    ``.to()`` / ``.cpu()``. This lets the PyTorch module machinery handle load/save.

    NOTE(review): ``load_state_dict`` into a freshly constructed container does not
    grow its (empty) ParameterLists. Pickling the whole module with ``torch.save`` is
    presumably the intended round-trip — confirm with callers.
    """

    def __init__(self, listOfParameters: Iterable[str]):
        """
        :param listOfParameters: names of the parameters whose history is recorded.
            The names are copied, so later changes to the caller's collection do not
            affect this container. Any iterable of strings is accepted.
        """
        super().__init__()

        # Own copy: avoids aliasing the caller's list and allows one-shot iterables.
        self._tracked_parameters = list(listOfParameters)
        unique_names = list(dict.fromkeys(self._tracked_parameters))

        # Name -> history lookup tables. These stay plain dicts because parameter
        # names such as "layer1.weight" contain dots, which nn.Module rejects as
        # submodule names.
        self._z_groups = {name: torch.nn.ParameterList() for name in unique_names}
        self._w_groups = {name: torch.nn.ParameterList() for name in unique_names}
        self._mask_groups = {name: torch.nn.ParameterList() for name in unique_names}

        # Register the very same ParameterList objects positionally. nn.Module cannot
        # see through a plain dict, so without this the recorded tensors would be
        # missing from state_dict()/parameters() and ignored by .to().
        self._z_modules = torch.nn.ModuleList(self._z_groups.values())
        self._w_modules = torch.nn.ModuleList(self._w_groups.values())
        self._mask_modules = torch.nn.ModuleList(self._mask_groups.values())

    def forward(self, x):
        """Do nothing and return None; this module exists only to hold tensors."""
        pass

    @staticmethod
    def _snapshot(tensor: torch.Tensor) -> torch.nn.Parameter:
        """Return a non-trainable Parameter holding an independent copy of ``tensor``."""
        # detach + clone breaks storage sharing with the caller's (possibly live)
        # tensor, so later in-place updates cannot rewrite recorded history.
        return torch.nn.Parameter(tensor.detach().clone(), requires_grad=False)

    def addTensor(self, name: str, z_tensor: torch.Tensor, w_tensor: torch.Tensor, mask_tensor: torch.Tensor):
        """
        Append one (Z, W, mask) snapshot to the history of ``name``.

        The tensors are detached and cloned before they are stored.

        :param name: a tracked parameter name.
        :param z_tensor: Z value to record.
        :param w_tensor: W value to record.
        :param mask_tensor: mask value to record.
        :raises KeyError: if ``name`` is not tracked by this container.
        """
        if name not in self._z_groups:
            raise KeyError(name)

        # Build every snapshot before appending any of them, so a failure part-way
        # cannot leave the three histories with different lengths.
        z_entry = self._snapshot(z_tensor)
        w_entry = self._snapshot(w_tensor)
        mask_entry = self._snapshot(mask_tensor)

        self._z_groups[name].append(z_entry)
        self._w_groups[name].append(w_entry)
        self._mask_groups[name].append(mask_entry)

    def getLists(self, name: str) -> Tuple[torch.nn.ParameterList, torch.nn.ParameterList, torch.nn.ParameterList]:
        """
        Return the ``(z_list, w_list, mask_list)`` histories recorded for ``name``.

        The returned ParameterLists are the container's own objects, not copies.

        :raises KeyError: if ``name`` is not tracked by this container.
        """
        if name not in self._z_groups:
            raise KeyError(name)
        return (self._z_groups[name], self._w_groups[name], self._mask_groups[name])

    def validateStructure(self) -> bool:
        """
        Validate that we have collected all of the appropriate Tensors for each parameter in tracked_parameters.

        Return True when every tracked name has Z, W and mask histories of one common
        length, shared by all tracked names. Otherwise return False.
        """
        seen_lengths = set()

        for name in self._tracked_parameters:
            histories = (
                self._z_groups.get(name),
                self._w_groups.get(name),
                self._mask_groups.get(name),
            )
            if any(history is None for history in histories):
                return False

            lengths = {len(history) for history in histories}
            if len(lengths) != 1:
                # Z, W and mask histories of this name disagree.
                return False
            seen_lengths |= lengths

        # At most one distinct length across all names (an empty container is valid).
        return len(seen_lengths) <= 1

    @property
    def debugParameters(self) -> List[str]:
        """The tracked parameter names, in the order given at construction."""
        return self._tracked_parameters