import copy
import logging
import pdb
import torch
import os

from torch.cuda.random import initial_seed

from src.tlp_model_fusion.utils import memory_safe_methods
from src.tlp_model_fusion.utils import ot_algorithms
from src.tlp_rnn_fusion import rnn_models
from src.tlp_rnn_fusion import fuse_rnn_models
from src.tlp_model_fusion import fuse_models

class OTFusion:
    """
    Ad hoc OT fusion of several base models into a pre-initialized target model.

    This class implements the OT fusion technique from the following paper:
    https://arxiv.org/abs/1910.05653
    Layer by layer, every base model is aligned to the target model through an
    optimal-transport coupling, and the aligned base weights are averaged into
    the target model's weights (the target model is modified in place).
    """

    def __init__(self, args, train_init, base_models, target_model, data):
        # args: option namespace. Fields read by this class include
        #   ad_hoc_cost_choice, ad_hoc_ot_solver, ad_hoc_sinkhorn_regularization,
        #   target_diff_architecture, model_name, save_path, model_weights
        #   (optional), use_pre_activations, and input_dim/encoder (RNN only).
        # train_init: optional object whose ``model_path`` is the directory the
        #   permuted models are saved to; when None, ``args.save_path`` is used.
        # data: input batch used for activation-based costs; may be None.
        self.args = args
        self.train_init = train_init
        self.base_models = base_models
        self.target_model = target_model
        self.data = data

    def fuse(self):
        """Fuse ``self.base_models`` into ``self.target_model`` in place.

        Couplings are computed layer by layer, starting from the scaled
        identity coupling on the input dimension. Unless
        ``args.target_diff_architecture`` is set, the first two base models are
        additionally permuted with their couplings and saved to disk.
        """
        logging.info("Starting model fusion")
        logging.info("Sinkhorn reg {}".format(self.args.ad_hoc_sinkhorn_regularization))
        if torch.cuda.is_available():
            for model in self.base_models:
                model.cuda()
            self.target_model.cuda()
            if self.data is not None:
                self.data = self.data.cuda()

        # The input layer is never permuted: every base model starts from the
        # identity coupling. Encoder and non-encoder models both size it with
        # ``input_dim`` (a conv variant would use ``channels[0]`` instead).
        initial_dimension = self.target_model.input_dim
        prev_pi = [self._identity_coupling(initial_dimension) for _ in self.base_models]
        pis_model_1 = [prev_pi[0]]   # couplings for model 1
        pis_model_2 = [prev_pi[-1]]  # couplings for model 2

        for layer in range(1, self.target_model.num_layers + 1):
            cur_pi = self.fuse_single_layer(layer=layer, prev_pi=prev_pi)
            prev_pi = cur_pi
            if not self.args.target_diff_architecture:
                pis_model_1.append(cur_pi[0])
                pis_model_2.append(cur_pi[1])

        if not self.args.target_diff_architecture:
            self._save_permuted_base_models(pis_model_1, pis_model_2)

        logging.info('Model fusion completed.')

    def _save_permuted_base_models(self, pis_model_1, pis_model_2):
        """Permute base models 1 and 2 with their couplings and save both."""
        if self.train_init is not None:
            save_dir = self.train_init.model_path
        else:
            save_dir = self.args.save_path
        generators = {
            'FC': self.generate_permuted_model_fc,
            'vgg11': self.generate_permuted_model_vgg,
            'resnet18': self.generate_permuted_model_resnet,
        }
        for rnn_name in ('rnn', 'RNN', 'lstm', 'LSTM'):
            generators[rnn_name] = self.generate_permuted_model_rnn
        # Unsupported model names generate nothing (only the log lines).
        generate = generators.get(self.args.model_name)
        for index, pis in enumerate((pis_model_1, pis_model_2)):
            number = index + 1
            logging.info('Begin to generate permuted model {}.'.format(number))
            if generate is not None:
                path = os.path.join(save_dir, 'permuted_model_{}.pth'.format(number))
                generate(self.base_models[index], pis, path)
            logging.info('Finish generating permuted model {}.'.format(number))

    @staticmethod
    def _identity_coupling(n):
        """Return the n x n identity scaled by 1/n (on GPU when available)."""
        pi = torch.eye(n, dtype=torch.float) / (1.0 * n)
        if torch.cuda.is_available():
            pi = pi.cuda()
        return pi

    @staticmethod
    def _permute_weights(Ws, cur_pi, prev_pi):
        """Return ``k_l * k_l_prev * cur_pi^T @ Ws @ prev_pi``.

        ``k_l``/``k_l_prev`` are the row counts of ``cur_pi``/``prev_pi``. This
        is the same orientation that ``fuse_single_layer_helper`` uses to align
        base weights. 4-D (conv) weights are permuted over their first two axes
        by temporarily moving those axes last.
        """
        k_l = cur_pi.size(-2)
        k_l_prev = prev_pi.size(-2)
        is_conv = len(Ws.size()) == 4
        if is_conv:
            Ws = Ws.permute(2, 3, 0, 1)
        permuted = torch.matmul(cur_pi.transpose(0, 1), torch.matmul(Ws, k_l * k_l_prev * prev_pi))
        if is_conv:
            permuted = permuted.permute(2, 3, 0, 1)
        return permuted

    def _check_and_save_permuted_model(self, model, permuted_model, path):
        """Print per-parameter mean |difference| to ``model``, then save."""
        # Synthetic check on permuted model
        with torch.no_grad():
            for parameter_1, parameter_2 in zip(model.parameters(), permuted_model.parameters()):
                print('Difference between the weights of original model and permuted model:',
                      (parameter_1 - parameter_2).abs().mean())
        self.save_target_model(permuted_model, path)

    def generate_permuted_model_rnn(self, model, pis, path):
        """Save a copy of RNN/LSTM ``model`` permuted by the couplings ``pis``.

        Input-to-hidden weights of layer i become
        ``k_l * k_l_prev * pis[i]^T @ W @ pis[i-1]``, and hidden-to-hidden
        weights become ``k_l^2 * pis[i]^T @ H @ pis[i]``. Here k is a coupling's
        row count, presumably undoing the 1/k mass normalization of the couplings.
        """
        config = model.get_model_config()
        print(config)

        permuted_model = fuse_rnn_models.get_model(self.args.model_name, self.args.input_dim, config,
                                                   self.args.encoder)
        if torch.cuda.is_available():
            permuted_model.cuda()

        for i in range(1, len(pis)):
            prev_pi = pis[i - 1]
            cur_pi = pis[i]
            k_l = cur_pi.size(-2)
            k_l_prev = prev_pi.size(-2)

            Ws, Hs = model.get_layer_weights(i)
            Ws_permuted = torch.matmul(cur_pi.transpose(0, 1), torch.matmul(Ws, k_l * k_l_prev * prev_pi))
            if Hs is not None:
                Hs_permuted = torch.matmul(cur_pi.transpose(0, 1), torch.matmul(Hs, k_l * k_l * cur_pi))
            else:
                Hs_permuted = None

            if self.args.model_name in ['LSTM', 'lstm']:
                # LSTM weights are written back through the model's own API.
                permuted_model.update_layer_weights(i, Ws_permuted, Hs_permuted)
            else:
                permuted_model_Ws, permuted_model_Hs = permuted_model.get_layer_weights(i)
                permuted_model_Ws.data = Ws_permuted.data
                if permuted_model_Hs is not None:
                    permuted_model_Hs.data = Hs_permuted.data

        self._check_and_save_permuted_model(model, permuted_model, path)

    def generate_permuted_model_fc(self, model, pis, path):
        """Save a copy of fully-connected ``model`` permuted by ``pis``."""
        config = model.get_model_config()
        print(config)
        permuted_model = fuse_models.get_model(self.args.model_name, config)
        if torch.cuda.is_available():
            permuted_model.cuda()

        for i in range(1, len(pis)):
            Ws_permuted = self._permute_weights(model.get_layer_weights(i), pis[i], pis[i - 1])
            permuted_model_Ws = permuted_model.get_layer_weights(i)
            permuted_model_Ws.data = Ws_permuted.data

        self._check_and_save_permuted_model(model, permuted_model, path)

    def generate_permuted_model_vgg(self, model, pis, path):
        """Save a copy of VGG ``model`` permuted by ``pis``.

        The first linear layer is viewed as (out, -1, 7, 7). This lets the
        last conv layer's channel coupling apply to it, mirroring
        ``OTFusionVGG.fuse_single_layer``.
        """
        config = model.get_model_config()
        print(config)
        permuted_model = fuse_models.get_model(self.args.model_name, config)
        if torch.cuda.is_available():
            permuted_model.cuda()

        is_linear_layer = False
        for i in range(1, len(pis)):
            logging.info('Generate layer {} for permuted model'.format(i))
            Ws = model.get_layer_weights(i)
            permuted_model_Ws = permuted_model.get_layer_weights(i)

            if not is_linear_layer and len(Ws.size()) == 2:
                logging.info('First linear layer for VGG.')
                is_linear_layer = True
                Ws = Ws.view((Ws.size(0), -1, 7, 7))
                permuted_model_Ws = permuted_model_Ws.view((permuted_model_Ws.size(0), -1, 7, 7))

            Ws_permuted = self._permute_weights(Ws, pis[i], pis[i - 1])
            # copy_ writes through the 4-D view into the real parameter.
            permuted_model_Ws.data.copy_(Ws_permuted.data)

        self._check_and_save_permuted_model(model, permuted_model, path)

    def generate_permuted_model_resnet(self, model, pis, path):
        """Save a copy of ResNet ``model`` permuted by ``pis``.

        A layer with a previous "similar" layer reuses that layer's coupling,
        as ``OTFusionResnet`` does in 'pre' mode. The input coupling is taken
        from the layer's first previous layer.
        """
        config = model.get_model_config()
        print(config)
        permuted_model = fuse_models.get_model(self.args.model_name, config)
        if torch.cuda.is_available():
            permuted_model.cuda()

        for i in range(1, len(pis)):
            logging.info('Generate layer {} for permuted model'.format(i))

            prev_similar_layer = model.get_prev_similar_layer(layer_num=i)
            prev_layers_list = model.get_prev_layers(layer_num=i)
            cur_pi = pis[i] if prev_similar_layer is None else pis[prev_similar_layer]
            prev_pi = pis[prev_layers_list[0]]

            Ws = model.get_layer_weights(i)
            permuted_model_Ws = permuted_model.get_layer_weights(i)
            # Use the same orientation as fusion (cur_pi^T @ W @ prev_pi); the
            # previous cur_pi @ W @ prev_pi^T applied the inverse permutation.
            Ws_permuted = self._permute_weights(Ws, cur_pi, prev_pi)
            permuted_model_Ws.data.copy_(Ws_permuted.data)

        self._check_and_save_permuted_model(model, permuted_model, path)

    def save_target_model(self, model, path):
        """Save ``model``'s state dict and config to ``path`` with torch.save."""
        torch.save(
            {
              'model_state_dict': model.state_dict(),
              'config': model.get_model_config()
            },
            path
        )
        logging.info('permuted model saved at {}'.format(path))

    def _get_model_weights(self, num_models):
        """Return ``args.model_weights``, or uniform weights when it is unset/None."""
        model_weights = getattr(self.args, 'model_weights', None)
        if model_weights is None:
            return [1.0 / num_models] * num_models
        return model_weights

    def fuse_single_layer_helper(self, num_models, layer, is_last_layer, base_weights,
                                 target_weights, prev_pi, prev_similar_pi=None):
        """Align base layer weights to the target and write their weighted mean into it.

        Args:
            num_models: number of base models (number of identity couplings
                built for the last layer).
            layer: 1-based layer index (logging; selects the coupling cost).
            is_last_layer: when true, this layer's coupling is the scaled
                identity instead of an OT coupling.
            base_weights: list of per-model weight tensors. Its entries are
                replaced with the aligned weights.
            target_weights: target layer tensor (2-D linear, 3-D LSTM-style or
                4-D conv). It is overwritten in place with the fused weights;
                a view of a parameter updates the parameter.
            prev_pi: per-model couplings of the previous layer.
            prev_similar_pi: if given, reused as this layer's couplings.

        Returns:
            The list of per-model couplings for this layer.
        """
        logging.info('Helper Fusing layer {}'.format(layer))
        is_conv = len(target_weights.size()) == 4
        is_lstm_layer = len(target_weights.size()) == 3
        tmp_target_weights = copy.deepcopy(target_weights.data)
        if is_lstm_layer:
            # Move axis 0 last so axes 0/1 are the ones the couplings act on.
            tmp_target_weights = tmp_target_weights.permute(1, 2, 0)
        beta = tmp_target_weights.size(0)
        beta_prev = tmp_target_weights.size(1)

        # Align the input side of every base layer with the previous couplings.
        for i in range(len(base_weights)):
            if is_conv:
                aligned = torch.matmul(base_weights[i].permute(2, 3, 0, 1), prev_pi[i] * beta_prev)
                base_weights[i] = aligned.permute(2, 3, 0, 1)
            elif is_lstm_layer:
                aligned = torch.matmul(base_weights[i], prev_pi[i] * beta_prev)
                base_weights[i] = aligned.permute(1, 2, 0)
            else:
                base_weights[i] = torch.matmul(base_weights[i], prev_pi[i] * beta_prev)

        if is_last_layer:
            # The output layer is not permuted. Size the identity with beta,
            # which is correct for 3-D weights too (target_weights.size(0) is not).
            pi = [self._identity_coupling(beta) for _ in range(num_models)]
        elif prev_similar_pi is not None:
            pi = prev_similar_pi
        elif self.args.ad_hoc_cost_choice == 'activation':
            pi = self.get_activation_coupling(layer=layer)
        else:
            pi = self.get_weights_coupling(base_weights, tmp_target_weights, prev_pi, layer == 1)

        # Align the output side with this layer's couplings.
        for i in range(len(base_weights)):
            scaled_pi_t = beta * torch.transpose(pi[i], 0, 1)
            if is_conv:
                aligned = torch.matmul(scaled_pi_t, base_weights[i].permute(2, 3, 0, 1))
                base_weights[i] = aligned.permute(2, 3, 0, 1)
            elif is_lstm_layer:
                base_weights[i] = torch.matmul(scaled_pi_t, base_weights[i].permute(2, 0, 1))
            else:
                base_weights[i] = torch.matmul(scaled_pi_t, base_weights[i])

        model_weights = self._get_model_weights(len(pi))
        weights = None
        for i, base_weight in enumerate(base_weights):
            if weights is None:
                weights = model_weights[i] * base_weight
            else:
                weights += model_weights[i] * base_weight
        # copy_ (not rebinding .data) so that views, e.g. the 4-D view
        # OTFusionVGG makes of its first linear layer, update the parameter.
        target_weights.data.copy_(weights.data)
        return pi

    def fuse_single_layer(self, layer, prev_pi):
        """Fuse layer ``layer`` (1-based) given the previous couplings; return the new ones."""
        logging.info('Fusing layer {}'.format(layer))
        base_weights = [model.get_layer_weights(layer_num=layer) for model in self.base_models]
        target_weights = self.target_model.get_layer_weights(layer_num=layer)
        return self.fuse_single_layer_helper(len(self.base_models), layer,
                                             layer == self.target_model.num_layers,
                                             base_weights, target_weights, prev_pi)

    def get_weights_coupling(self, base_weights, target_weights, prev_pi,
                             is_first_layer=False):
        """Compute per-model couplings from base rows to target rows.

        The cost between base row a and target row b is the squared Euclidean
        distance of their weights (summed over every trailing axis). Iteration
        stops at the shorter of ``base_weights`` and ``prev_pi``. The
        ``prev_pi`` values and ``is_first_layer`` are otherwise unused.
        """
        pi = []
        is_conv = len(target_weights.size()) == 4
        is_lstm_layer = len(target_weights.size()) == 3
        with torch.no_grad():
            w = target_weights.unsqueeze(0)
            for w_i, _prev_pi_i in zip(base_weights, prev_pi):
                w_i = w_i.unsqueeze(1)
                try:
                    cost = (w - w_i).pow(2).sum(dim=-1)
                except RuntimeError as e:
                    if not str(e).startswith("CUDA out of memory."):
                        raise
                    # Retry one base row at a time to bound peak memory.
                    cost = torch.cat(
                        [(w - w_i_row.unsqueeze(0)).pow(2).sum(dim=-1) for w_i_row in w_i],
                        dim=0)
                if is_conv:
                    cost = cost.sum(-1).sum(-1)
                elif is_lstm_layer:
                    cost = cost.sum(-1)
                pi.append(self.ot_solver(cost))
            return pi

    def get_activation_coupling(self, layer):
        """Compute per-model couplings from activation costs on ``self.data``.

        Activations are flattened to (size(0), -1) and compared via
        ``memory_safe_methods.get_activation_cost``.
        """
        with torch.no_grad():
            target_model_activations = self.target_model.get_activations(self.data,
                                                                         layer_num=layer)
            target_model_activations = target_model_activations.reshape(target_model_activations.size(0), -1)
            pi = []
            for model in self.base_models:
                activations = model.get_activations(self.data, layer_num=layer,
                                                    pre_activations=self.args.use_pre_activations)
                activations = activations.reshape(activations.size(0), -1)
                cost = memory_safe_methods.get_activation_cost(activations, target_model_activations)
                pi.append(self.ot_solver(cost))
        return pi

    def ot_solver(self, cost):
        """Return the coupling for ``cost`` using ``args.ad_hoc_ot_solver``.

        'sinkhorn' calls ``ot_algorithms.sinkhorn_coupling`` with epsilon
        ``args.ad_hoc_sinkhorn_regularization`` and 100 iterations. 'emd' calls
        ``ot_algorithms.ot_emd_coupling``. Any other value raises
        NotImplementedError.
        """
        if self.args.ad_hoc_ot_solver == 'sinkhorn':
            epsilon = self.args.ad_hoc_sinkhorn_regularization
            pi, _ = ot_algorithms.sinkhorn_coupling(cost, epsilon=epsilon, niter=100)
            return pi
        elif self.args.ad_hoc_ot_solver == 'emd':
            pi, _ = ot_algorithms.ot_emd_coupling(cost)
            return pi
        else:
            raise NotImplementedError


class OTFusionVGG(OTFusion):
    """
    OT fusion for VGG networks (identical architectures only).

    The first linear layer of the classifier head comes right after an
    adaptive average pooling layer with a 7x7 output. Its weights are therefore
    viewed as (out, -1, 7, 7) tensors, so the channel couplings of the last
    convolution can be applied to them like any conv layer.
    """

    def __init__(self, args, train_init, base_models, target_model, data):
        super().__init__(args, train_init, base_models, target_model, data)
        # Becomes True once the first linear layer has been reshaped.
        self.is_linear_layer = False

    def fuse_single_layer(self, layer, prev_pi):
        """Fuse one layer, viewing the first linear layer as 4-D beforehand."""
        logging.info('Fusing layer {}'.format(layer))
        base_weights = [model.get_layer_weights(layer_num=layer) for model in self.base_models]
        target_weights = self.target_model.get_layer_weights(layer_num=layer)
        logging.info('Target weight dimensions {}'.format(str(target_weights.size())))

        first_linear = (not self.is_linear_layer) and target_weights.dim() == 2
        if first_linear:
            logging.info('First linear layer for VGG')
            self.is_linear_layer = True
            base_weights = [w.view((w.size(0), -1, 7, 7)) for w in base_weights]
            target_weights = target_weights.view((target_weights.size(0), -1, 7, 7))

        is_last = layer == self.target_model.num_layers
        return self.fuse_single_layer_helper(len(self.base_models), layer, is_last,
                                             base_weights, target_weights, prev_pi)


class OTFusionResnet(OTFusion):
    """
    OT fusion for ResNet networks, with handling of skip connections.

    ``self.prev_pi_list[k]`` holds the couplings produced for layer k (index 0
    is the input coupling). Each layer takes its input couplings from its first
    previous layer (``get_prev_layers``). In 'pre' mode, a layer with a
    previous similar layer reuses that layer's couplings.
    """

    def __init__(self, args, train_init, base_models, target_model, data):
        super().__init__(args, train_init, base_models, target_model, data)
        self.prev_pi_list = []

    def fuse_single_layer(self, layer, prev_pi):
        """Fuse one layer and record its couplings in ``self.prev_pi_list``."""
        logging.info('Fusing layer {}'.format(layer))
        if layer == 1:
            # For the first layer the input coupling has not been recorded yet.
            self.prev_pi_list.append(prev_pi)
        base_weights = [model.get_layer_weights(layer_num=layer) for model in self.base_models]
        # Ad hoc OT fusion has no dedicated handling of the skip connections
        # beyond reusing couplings.
        target_weights = self.target_model.get_layer_weights(layer_num=layer)
        prev_layers_list = self.target_model.get_prev_layers(layer_num=layer)
        is_last_layer = layer == self.target_model.num_layers

        if self.args.resnet_skip_connection_handling == 'pre':
            prev_similar_layer = self.target_model.get_prev_similar_layer(layer_num=layer)
            if prev_similar_layer is not None:
                cur_pi = self.prev_pi_list[prev_similar_layer]
                self.fuse_single_layer_helper(len(self.base_models), layer, is_last_layer,
                                              base_weights, target_weights,
                                              self.prev_pi_list[prev_layers_list[0]],
                                              prev_similar_pi=cur_pi)
                self.prev_pi_list.append(cur_pi)
                return cur_pi

        logging.info('Target weight dimensions {}'.format(str(target_weights.size())))
        cur_pi = self.fuse_single_layer_helper(len(self.base_models), layer, is_last_layer,
                                               base_weights, target_weights,
                                               self.prev_pi_list[prev_layers_list[0]])
        # ``tlp_cost_choice`` may be absent from ad hoc argument namespaces;
        # treat a missing option as "not activation" instead of raising.
        if getattr(self.args, 'tlp_cost_choice', None) == 'activation':
            cur_pi = self.get_activation_coupling(layer=layer)
        self.prev_pi_list.append(cur_pi)
        return cur_pi


class OTFusionRNN(OTFusion):
    """
    OT fusion of RNN units.

    Couplings are computed from the input-to-hidden weights. The
    hidden-to-hidden weights are then re-oriented with the same couplings on
    both sides and averaged with the model weights.
    """

    def fuse_single_layer(self, layer, prev_pi):
        """Fuse one recurrent layer and return its per-model couplings."""
        logging.info('Fusing layer {}'.format(layer))
        base_input_weights_arr = []
        base_hidden_weights_arr = []
        for model in self.base_models:
            base_input_weights, base_hidden_weights = model.get_layer_weights(layer_num=layer)
            base_input_weights_arr.append(base_input_weights)
            base_hidden_weights_arr.append(base_hidden_weights)
        target_weights, target_hidden_weights = self.target_model.get_layer_weights(layer_num=layer)
        pi = self.fuse_single_layer_helper(len(self.base_models), layer,
                                           layer == self.target_model.num_layers,
                                           base_input_weights_arr, target_weights, prev_pi)
        logging.info('Size of input to hidden weights: {}'.format(target_weights.size()))

        if target_hidden_weights is not None:
            # Orient each base hidden-to-hidden matrix with the input-to-hidden
            # coupling on both sides: k^2 * pi^T @ H @ pi. torch.matmul
            # broadcasts over a leading axis, so 2-D and 3-D weights share
            # this formula (the former rank test, len(tensor) == 3, checked
            # the first dimension, not the rank, and both branches were equal).
            model_weights = getattr(self.args, 'model_weights', None)
            if model_weights is None:
                model_weights = [1.0 / len(pi)] * len(pi)
            weights = None
            for i, (pi_i, w_i) in enumerate(zip(pi, base_hidden_weights_arr)):
                k_l_0 = pi_i.size(0)
                weights_i = torch.matmul(pi_i.transpose(0, 1), torch.matmul(w_i, k_l_0 * k_l_0 * pi_i))
                weights_i *= model_weights[i]
                if weights is None:
                    weights = weights_i
                else:
                    weights += weights_i
            target_hidden_weights.data = weights.data

        if isinstance(self.target_model, (rnn_models.LSTMWithDecoder, rnn_models.LSTMWithEncoderDecoder)):
            # LSTM weights are written back through the model's own API.
            self.target_model.update_layer_weights(layer, target_weights, target_hidden_weights)
        return pi


########## TESTS ############

def test_ad_hoc_fusion_fuse_runs():
    """Smoke test: fuse five FC base models (two widths) into an FC target."""
    from src.tlp_model_fusion import model
    input_dim = 10
    output_dim = 10
    hidden_dims = [11, 12, 13]
    hidden_dims_2 = [12, 13, 14]
    sample_size = 8

    target_model = model.FCModel(input_dim=input_dim, hidden_dims=hidden_dims, output_dim=output_dim)
    base_models = []
    num_models = 5
    for i in range(num_models):
        new_model = model.FCModel(input_dim=input_dim, hidden_dims=hidden_dims if i % 2 == 0 else hidden_dims_2,
                                  output_dim=output_dim)
        base_models.append(new_model)
    data = torch.rand(sample_size, input_dim)

    if torch.cuda.is_available():
        data = data.cuda()

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--ad_hoc_cost_choice', type=str, default=None)
    parser.add_argument('--ad_hoc_ot_solver', type=str, default='sinkhorn')
    parser.add_argument('--ad_hoc_sinkhorn_regularization', type=float, default=0.1)
    # Options OTFusion reads that the original parser did not define.
    parser.add_argument('--encoder', action='store_true')
    parser.add_argument('--use_pre_activations', action='store_true')
    # Base widths differ from the target, so do not save permuted copies.
    parser.set_defaults(target_diff_architecture=True)
    args = parser.parse_args("")

    # train_init is a required constructor argument; None saves to args.save_path.
    fusion = OTFusion(args=args, train_init=None, target_model=target_model,
                      base_models=base_models, data=data)
    fusion.fuse()


def test_ad_hoc_fusion_fuse_runs_for_resnet():
    """Smoke test: fuse two ResNet-18 models with 'pre' and 'post' skip handling."""
    from src.tlp_model_fusion import resnet_models
    target_model = resnet_models.resnet18(num_classes=10)
    base_models = []
    num_models = 2
    for i in range(num_models):
        new_model = resnet_models.resnet18(num_classes=10)
        base_models.append(new_model)

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--ad_hoc_cost_choice', type=str, default=None)
    parser.add_argument('--ad_hoc_ot_solver', type=str, default='sinkhorn')
    parser.add_argument('--ad_hoc_sinkhorn_regularization', type=float, default=0.1)
    # A single string default: the option takes one of the listed choices.
    parser.add_argument('--resnet_skip_connection_handling', type=str, default='pre',
                        choices=['pre', 'post'],
                        help='Pre means use pis from previously similar layer, post means handle later')
    # Options read during fusion that the original parser did not define.
    parser.add_argument('--tlp_cost_choice', type=str, default=None)
    parser.add_argument('--encoder', action='store_true')
    parser.add_argument('--use_pre_activations', action='store_true')
    # Smoke test only: do not write permuted models to disk.
    parser.set_defaults(target_diff_architecture=True)
    args = parser.parse_args("")

    args.resnet_skip_connection_handling = 'pre'
    fusion = OTFusionResnet(args=args, train_init=None, target_model=target_model,
                            base_models=base_models, data=None)
    fusion.fuse()

    args.resnet_skip_connection_handling = 'post'
    target_model = resnet_models.resnet18(num_classes=10)
    fusion = OTFusionResnet(args=args, train_init=None, target_model=target_model,
                            base_models=base_models, data=None)
    fusion.fuse()

    print('Ad hoc OT fusion runs for Resnet')


def test_ad_hoc_rnn_fusion_fuse_runs():
    """Smoke test: fuse two ImageRNN models into an ImageRNN target."""
    from src.tlp_model_fusion import model
    input_dim = 10
    output_dim = 10
    hidden_dims = [10]
    n_steps = 28

    target_model = model.ImageRNN(n_inputs=input_dim, n_neurons=hidden_dims,
                                  n_outputs=output_dim, n_steps=n_steps)
    base_models = []
    num_models = 2
    for i in range(num_models):
        new_model = model.ImageRNN(n_inputs=input_dim, n_neurons=hidden_dims,
                                   n_outputs=output_dim, n_steps=n_steps)
        base_models.append(new_model)
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--ad_hoc_cost_choice', type=str, default=None)
    parser.add_argument('--ad_hoc_ot_solver', type=str, default='sinkhorn')
    parser.add_argument('--ad_hoc_sinkhorn_regularization', type=float, default=0.1)
    # Options OTFusion reads that the original parser did not define.
    parser.add_argument('--encoder', action='store_true')
    parser.add_argument('--use_pre_activations', action='store_true')
    # Smoke test only: do not write permuted models to disk.
    parser.set_defaults(target_diff_architecture=True)
    args = parser.parse_args("")

    fusion = OTFusionRNN(args=args, train_init=None, target_model=target_model,
                         base_models=base_models, data=None)
    fusion.fuse()


if __name__ == "__main__":
    # Run the RNN smoke test with INFO logging. The FC and ResNet smoke tests
    # (test_ad_hoc_fusion_fuse_runs, test_ad_hoc_fusion_fuse_runs_for_resnet)
    # can be invoked here as well.
    logging.basicConfig(level=logging.INFO)
    test_ad_hoc_rnn_fusion_fuse_runs()
