import logging
import torch

class AvgFusion:
    """Fuse several models into a target model by plain (vanilla) weight averaging.

    Every layer ``1..target_model.num_layers`` of ``target_model`` is set to the
    element-wise mean of the corresponding layer weights of ``base_models``.
    The base models are only read; their weights are never modified.

    Args:
        args: configuration object; only ``args.model_name`` is read. The names
            ``'rnn'`` and ``'lstm'`` (case-insensitive) select the recurrent path.
        base_models: sized sequence of models exposing
            ``get_layer_weights(layer_num=...)``.
        target_model: model receiving the averaged weights. It must expose
            ``num_layers`` and ``get_layer_weights``, and for LSTMs also
            ``update_layer_weights(layer, Ws, Hs)``.
    """

    def __init__(self, args, base_models, target_model):
        self.args = args
        self.base_models = base_models
        self.target_model = target_model

    def fuse(self):
        """Average every layer of the base models into the target model, in place.

        Raises:
            ValueError: if ``base_models`` is empty and there is at least one layer.
        """
        logging.info('Starting model fusion')
        # The model name is constant across layers, so decide the code path once.
        recurrent = self._model_kind() in ('rnn', 'lstm')
        # Layers are addressed 1-based by get_layer_weights / update_layer_weights.
        for layer in range(1, self.target_model.num_layers + 1):
            if recurrent:
                self.fuse_single_layer_rnn(layer)
            else:
                self.fuse_single_layer(layer)
        logging.info('Finish vanilla averaging model fusion.')

    def fuse_single_layer(self, layer):
        """Set ``layer`` of the target model to the mean of that layer over the base models.

        Args:
            layer: 1-based layer index passed to ``get_layer_weights``.

        Raises:
            ValueError: if there are no base models.
        """
        self._require_base_models()
        with torch.no_grad():
            avg_weight = self._sum_div(
                [model.get_layer_weights(layer_num=layer) for model in self.base_models],
                len(self.base_models),
            )

        target_weights = self.target_model.get_layer_weights(layer_num=layer)
        # Rebinding .data makes the target parameter use the freshly computed
        # average tensor, which (thanks to the clone in _sum_div) is not shared
        # with any base model.
        target_weights.data = avg_weight.data

    def fuse_single_layer_rnn(self, layer):
        """Average one recurrent layer of the base models into the target model.

        For recurrent models ``get_layer_weights`` returns a pair ``(Ws, Hs)``,
        where ``Hs`` may be ``None`` (presumably input-to-hidden and
        hidden-to-hidden weights; TODO confirm against the model classes).
        Present ``Hs`` tensors are summed and divided by the total number of
        base models. LSTM targets are updated through
        ``update_layer_weights``; other recurrent targets get their tensors'
        ``.data`` replaced directly.

        Args:
            layer: 1-based layer index passed to ``get_layer_weights``.

        Raises:
            ValueError: if there are no base models.
        """
        self._require_base_models()
        count = len(self.base_models)
        with torch.no_grad():
            pairs = [model.get_layer_weights(layer_num=layer) for model in self.base_models]
            avg_ws = self._sum_div([ws for ws, _ in pairs], count)
            avg_hs = self._sum_div([hs for _, hs in pairs if hs is not None], count)

            if self._model_kind() == 'lstm':
                self.target_model.update_layer_weights(layer, avg_ws, avg_hs)
            else:
                target_ws, target_hs = self.target_model.get_layer_weights(layer_num=layer)
                target_ws.data = avg_ws.data
                if avg_hs is not None:
                    target_hs.data = avg_hs.data

    def _model_kind(self):
        """Return ``args.model_name`` lower-cased (non-string values are returned unchanged)."""
        name = self.args.model_name
        return name.lower() if isinstance(name, str) else name

    def _require_base_models(self):
        """Raise ValueError when there is nothing to average."""
        if not self.base_models:
            raise ValueError('AvgFusion needs at least one base model to fuse')

    @staticmethod
    def _sum_div(tensors, divisor):
        """Return ``sum(tensors) / divisor``, or ``None`` if ``tensors`` is empty.

        The first tensor is cloned before accumulating, so no input tensor
        (which may be a live parameter of a base model) is modified in place.
        """
        total = None
        for tensor in tensors:
            if total is None:
                total = tensor.detach().clone()
            else:
                total += tensor
        if total is not None:
            total /= divisor
        return total











