import torch
import torch.nn as nn


class EMADummyModel(nn.Module):
    """Module that exposes nested groups of tensors as frozen parameters.

    Each inner iterable of ``parameters_list`` becomes one ``nn.ParameterList``.
    The groups are held in an ``nn.ModuleList``, so state-dict keys take the
    form ``param_sets.<group>.<index>``.

    NOTE: ``p.detach()`` does not copy data, so every registered parameter
    shares storage with the corresponding input tensor. In-place writes on
    either side are visible through the other; this includes
    ``load_state_dict``, which copies into parameters in place. Presumably this
    aliasing is intentional, so that externally maintained EMA weights can be
    saved and restored through this module. Confirm with callers before adding
    a ``.clone()``.
    """

    def __init__(self, parameters_list):
        """Register ``parameters_list`` as frozen parameter groups.

        Args:
            parameters_list: An iterable of iterables of tensors. Each inner
                iterable is registered as one parameter group, in order. All
                registered parameters have ``requires_grad=False``.
        """
        super().__init__()
        # The outer container holds ParameterList *submodules*, so it must be a
        # ModuleList. On older PyTorch releases nn.ParameterList only accepts
        # nn.Parameter entries, and nesting ParameterLists raises TypeError.
        # ModuleList registers children under the same "0", "1", ... names, so
        # the state-dict keys are identical either way.
        self.param_sets = nn.ModuleList([
            nn.ParameterList([nn.Parameter(p.detach(), requires_grad=False) for p in params])
            for params in parameters_list
        ])

    def get_param_sets(self):
        """Return the registered parameters as a list of lists, one per group.

        The ``nn.Parameter`` objects returned are the module's own, not copies.
        Only the outer and inner Python lists are newly created.
        """
        return [list(param_set) for param_set in self.param_sets]

