import torch
import copy
import torch.nn as nn

from models.model_wrapper import ModelWrapper

def reset_weights(m):
    """Re-initialise the direct children of ``m`` that define ``reset_parameters``.

    Only immediate children are visited; ``m`` itself is left untouched. Pass
    this function to ``nn.Module.apply`` to reach every nested submodule (the
    root module handed to ``apply`` is still not reset).

    Args:
        m: Module whose direct children should be re-initialised.
    """
    # Lazy generator so each child is checked and reset in the same order as
    # iterating ``m.children()`` directly.
    resettable = (child for child in m.children()
                  if hasattr(child, 'reset_parameters'))
    for child in resettable:
        child.reset_parameters()

class EnsembleWrapper(ModelWrapper):
    """Ensemble of independently re-initialised copies of a template model.

    Each member is a ``copy.deepcopy`` of ``model`` in which every module that
    defines ``reset_parameters`` (including the root) is re-initialised.
    Layers without ``reset_parameters`` keep the template's copied values.
    Every parameter of member ``i`` is tagged with ``param.optimizer = i``.
    Presumably the training code uses this tag to give each member its own
    optimizer; TODO confirm against the caller.

    In training mode ``forward`` returns the member outputs stacked along a new
    leading dimension. In eval mode it returns their mean over that dimension.
    """

    def __init__(self, model, ensemble_size=10, **kwargs):
        """Build ``ensemble_size`` re-initialised copies of ``model``.

        Args:
            model: Template ``nn.Module``. It is deep-copied, never modified.
            ensemble_size: Number of ensemble members; must be at least 1.
            **kwargs: Ignored; not forwarded to ``ModelWrapper``.

        Raises:
            ValueError: If ``ensemble_size`` is less than 1 (an empty ensemble
                cannot produce an output in ``forward``).
        """
        if ensemble_size < 1:
            raise ValueError(f"ensemble_size must be at least 1, got {ensemble_size}")

        super(EnsembleWrapper, self).__init__()

        self.models = nn.ModuleList()

        for i in range(ensemble_size):
            new_model = copy.deepcopy(model)

            # `apply` visits every module and `reset_weights` resets each visited
            # module's direct children, so together they cover every submodule
            # except the root itself.
            new_model.apply(reset_weights)
            # Reset the root as well. Otherwise a bare layer (e.g. nn.Linear)
            # passed as `model` would yield `ensemble_size` identical members.
            if hasattr(new_model, 'reset_parameters'):
                new_model.reset_parameters()

            for param in new_model.parameters():
                param.optimizer = i

            self.models.append(new_model)

    def forward(self, input):
        """Run every ensemble member on ``input``.

        Each member must return a tensor, and all members' outputs must have
        the same shape (required by ``torch.stack``).

        Returns:
            In training mode, the member outputs stacked along a new dim 0 of
            size ``len(self.models)``. In eval mode, the mean of those outputs
            over dim 0.
        """
        outputs = torch.stack([model(input) for model in self.models], dim=0)

        if self.models.training:
            return outputs
        return outputs.mean(dim=0)

    def __getattr__(self, name):
        """Resolve attributes that ordinary lookup did not find.

        ``nn.Module`` keeps registered parameters, buffers and submodules in the
        ``_parameters``, ``_buffers`` and ``_modules`` dicts rather than in
        ``__dict__``. This override replaces ``nn.Module.__getattr__``, so it
        must consult those dicts itself. Any other name is forwarded to the
        ensemble members in order; the first member that has it wins.

        Raises:
            AttributeError: If neither the wrapper nor any member has ``name``.
        """
        # Read the registries straight from __dict__. Going through
        # `self._modules` / `self.models` would re-enter this method when they do
        # not exist yet (e.g. a lookup made before `self.models` is assigned in
        # __init__). That would recurse until RecursionError instead of raising
        # AttributeError.
        state = self.__dict__
        for registry_name in ('_parameters', '_buffers', '_modules'):
            registry = state.get(registry_name)
            if registry is not None and name in registry:
                return registry[name]

        members = state.get('_modules', {}).get('models')
        if members is not None:
            for submodule in members:
                try:
                    return getattr(submodule, name)
                except AttributeError:
                    # This member lacks `name`; try the next one. Other
                    # exceptions are real errors and propagate.
                    continue

        # If we get here, no submodules had the attribute
        raise AttributeError(f"'{type(self).__name__}' object nor its wrapped module have attribute '{name}'")
