# OTFusion Library Functions
#
# Author: Moritz Imfeld <moimfeld@ethz.ch>


# Imports
import torch, logging, copy, ot
import numpy as np

from ground_metric import GroundMetric
from utils import matrix_stats, dict_get, dict_write, matrix_to_heatmap

#----------------#
# Encoder Fusion #
#----------------#
def encoder_fusion(args: dict, keys: dict, w_0: dict, w_1: dict, acts_0: dict, acts_1: dict, t_in: torch.Tensor, last_layer, device: torch.device, enc_key, log, alpha = 0.5, prev_out_acts = None):
    '''
    ## Description
    Performs OTFusion of two encoder layers: the layer norms, the self-attention block and the
    two fully-connected layers. The transportation map is propagated through the residual
    connections according to `args['fusion']['resid_policy']`.

    ------
    ## Parameters
    `args` Dictionary from YAML-based configuration file\\
    `keys`: Dictionary containing key lists to access data from the nested weight and acts dicts. The key list must be ordered in the order of the access of the nested dictionary.\\
    `w_0` Dictionary containing all weights of an encoder layer of model 0\\
    `w_1` Dictionary containing all weights of an encoder layer of model 1\\
    `acts_0` Dictionary containing all activations of an encoder layer of model 0\\
    `acts_1` Dictionary containing all activations of an encoder layer of model 1\\
    `t_in` Input transportation map (set to 'None' if previous layer had no permutations)\\
    `last_layer` Flag that indicates if this is the last layer that will be fused in the current experiment\\
    `device` torch.device()\\
    `enc_key` Used for logging\\
    `log` logging instance\\
    `alpha` Weighting parameter forwarded to every sub-fusion (weight of model 1 in the fused parameters)\\
    `prev_out_acts` Output activations of the previous layer (needed to weigh the residual transportation map)
    ------
    ## Outputs
    `w_fused` Fused weights (nested dict indexed by the same key lists as `w_0`)\\
    `t_out`   Transportation map of current layer
    '''

    # Retrieve key lists
    ln0_keys = keys['enc_ln0_keys']
    ln1_keys = keys['enc_ln1_keys']
    sa_keys  = keys['enc_sa_keys']
    ff0_keys = keys['enc_ff0_keys']
    ff1_keys = keys['enc_ff1_keys']

    is_bert   = args['model']['type'] in ('hf_bert_masked', 'hf_bert_class')
    fuse_norm = args['fusion']['fuse_norm']

    # Init w_fused
    w_fused = {}

    # save transportation map for residual connection
    t_resid = t_in

    # Non-BERT models: norm 0 sits before the self-attention.
    # BERT has no layernorm at the input (its norm 0 is fused after the self-attention below).
    if not is_bert and fuse_norm:
        log.info(' Fusing encoder {0}: norm 0'.format(enc_key))
        w_norm_0_fused, t_out = ln_fusion(args = args, keys = keys, t_in = t_in, w_0 = dict_get(ln0_keys, w_0), w_1 = dict_get(ln0_keys, w_1), device = device, alpha = alpha)
    else:
        if not is_bert:
            # Bug fix: keep model 0's *norm 0* parameters (previously copied norm 1 by mistake)
            w_norm_0_fused = copy.deepcopy(dict_get(ln0_keys, w_0))
        t_out = t_in

    # Fuse Self-Attention Layer
    if args['fusion']['fuse_sa']:
        log.info(' Fusing encoder {0}: self-attention'.format(enc_key))
        w_self_attention_fused, t_out = act_attention_fusion(args = args, keys = keys, w_sa_0 = dict_get(sa_keys, w_0), w_sa_1 = dict_get(sa_keys, w_1),
                                                             acts_sa_0 = dict_get(sa_keys, acts_0), acts_sa_1 = dict_get(sa_keys, acts_1),
                                                             t_q_in = t_out, t_k_in = t_out, t_v_in = t_out, device = device, log = log,
                                                             alpha = alpha, last_layer = last_layer and not args['fusion']['fuse_fc'])
    else:
        w_self_attention_fused = copy.deepcopy(dict_get(sa_keys, w_0))

    # Apply residual connection policy
    # Note: For the weighted policy, the activations generated by the self-attention are used to weight t_in and
    #       the activations generated by the previous layer are used to weight t_resid.
    if t_in is not None:
        t_out = resid_policy(policy = args.get('fusion').get('resid_policy'), t_resid = t_resid, t_in = t_out,
                             resid_acts = prev_out_acts, in_acts = dict_get(sa_keys + ['data'], acts_1), log = log)
    t_resid = t_out

    # Layer norm following the self-attention: norm 1 for non-BERT models, norm 0 for BERT
    if fuse_norm:
        if is_bert:
            log.info(' Fusing encoder {0}: norm 0'.format(enc_key))
            w_norm_0_fused, t_out = ln_fusion(args = args, keys = keys, t_in = t_out, w_0 = dict_get(ln0_keys, w_0), w_1 = dict_get(ln0_keys, w_1), device = device, alpha = alpha)
        else:
            log.info(' Fusing encoder {0}: norm 1'.format(enc_key))
            w_norm_1_fused, t_out = ln_fusion(args = args, keys = keys, t_in = t_out, w_0 = dict_get(ln1_keys, w_0), w_1 = dict_get(ln1_keys, w_1), device = device, alpha = alpha)
    elif is_bert:
        w_norm_0_fused = copy.deepcopy(dict_get(ln0_keys, w_0))
    else:
        w_norm_1_fused = copy.deepcopy(dict_get(ln1_keys, w_0))

    # Fuse Fully Connected Layers
    if args['fusion']['fuse_fc']:
        log.info(' Fusing encoder {0}: fully-connected; layer 0'.format(enc_key))
        w_ff0_fused, t_out = fc_fusion(args = args, keys = keys, t_in = t_out, w_0 = dict_get(ff0_keys, w_0),
                                       w_1 = dict_get(ff0_keys, w_1),
                                       act_0 = dict_get(ff0_keys, acts_0),
                                       act_1 = dict_get(ff0_keys, acts_1),
                                       device = device, log = log, alpha = alpha)
        log.info(' Fusing encoder {0}: fully-connected; layer 1'.format(enc_key))
        w_ff1_fused, t_out = fc_fusion(args = args, keys = keys, t_in = t_out, w_0 = dict_get(ff1_keys, w_0),
                                       w_1 = dict_get(ff1_keys, w_1),
                                       act_0 = dict_get(ff1_keys, acts_0),
                                       act_1 = dict_get(ff1_keys, acts_1),
                                       device = device, log = log, alpha = alpha, last_layer = last_layer)
    else:
        w_ff0_fused = copy.deepcopy(dict_get(ff0_keys, w_0))
        w_ff1_fused = copy.deepcopy(dict_get(ff1_keys, w_0))

    # Fuse second BERT layernorm (after the fully-connected block)
    if is_bert:
        if fuse_norm:
            log.info(' Fusing encoder {0}: norm 1'.format(enc_key))
            w_norm_1_fused, t_out = ln_fusion(args = args, keys = keys, t_in = t_out, w_0 = dict_get(ln1_keys, w_0), w_1 = dict_get(ln1_keys, w_1), device = device, alpha = alpha)
        else:
            w_norm_1_fused = copy.deepcopy(dict_get(ln1_keys, w_0))

    # Apply resid connection policy (first the activations that weight t_resid have to be calculated:
    # self-attention output plus the residual input from the previous layer)
    if not last_layer:
        if prev_out_acts is None:
            # nothing to add; only the weighted policies read resid_acts
            resid_acts = None
        else:
            resid_acts = torch.add(dict_get(sa_keys + ['data'], acts_1), prev_out_acts)
        t_out = resid_policy(policy = args.get('fusion').get('resid_policy'), t_resid = t_resid, t_in = t_out, resid_acts = resid_acts, in_acts = dict_get(ff1_keys, acts_1), log = log)

    # write weights to w_fused dict
    dict_write(w_fused, sa_keys, w_self_attention_fused)
    dict_write(w_fused, ln0_keys, w_norm_0_fused)
    dict_write(w_fused, ln1_keys, w_norm_1_fused)
    dict_write(w_fused, ff0_keys, w_ff0_fused)
    dict_write(w_fused, ff1_keys, w_ff1_fused)
    return w_fused, t_out

#------------------#
# Attention Fusion #
#------------------#
def act_attention_fusion(args: dict, keys: dict, w_sa_0: dict, w_sa_1: dict, acts_sa_0: dict, acts_sa_1: dict, t_q_in: torch.Tensor,
                         t_k_in: torch.Tensor, t_v_in: torch.Tensor, device: torch.device, log: logging.Logger, alpha = 0.5, last_layer = False):
    '''
    ## Description
    Performs OTFusion of two attention layers (W_Q, W_K, W_V and W_O).

    How W_Q and W_K are aligned is selected by `args['fusion']['qk_fusion']`:
    - 'separate' (default): an independent transportation map for W_Q and for W_K
    - 'eq_t_map': W_Q and W_K are each fused with a ground metric built from the activations
      (or, for weight-based fusion, the weights) of both projections
    - 'joint': W_Q and W_K are stacked row-wise, fused as one layer and split again
    The input map of W_O is the output map of W_V, unless `args['fusion']['fusion_t_w_in']` is
    set, in which case it is computed from `acts_sa_*['intermediate_attn']`.

    ------
    ## Parameters
    `args` Dictionary from YAML-based configuration file\\
    `keys`: Dictionary containing key lists to access data from the nested weight and acts dicts. The key list must be ordered in the order of the access of the nested dictionary.\\
    `w_sa_0` Dictionary containing all weights of an attention layer of model 0\\
    `w_sa_1` Dictionary containing all weights of an attention layer of model 1\\
    `acts_sa_0` Dictionary containing all activations of an attention layer of model 0\\
    `acts_sa_1` Dictionary containing all activations of an attention layer of model 1\\
    `t_q_in` Input transportation map for w_q (set to 'None' if previous layer had no permutations)\\
    `t_k_in` Input transportation map for w_k (set to 'None' if previous layer had no permutations)\\
    `t_v_in` Input transportation map for w_v (set to 'None' if previous layer had no permutations)\\
    `device` torch.device()\\
    `log` logging instance\\
    `alpha` Weighting parameter (weight of model 1), forwarded to every `fc_fusion` call\\
    `last_layer` If set, W_O's output is left unaligned (its input is still aligned)
    ------
    ## Outputs
    `w_fused` Fused weights (dict indexed by the same key lists as `w_sa_0`)\\
    `t_out`   Transportation map of current layer (output map of W_O)

    ## Raises
    `NotImplementedError` for an unknown `qk_fusion` mode
    '''

    # Retrieve key lists
    w_q_keys   = keys['w_q']
    w_k_keys   = keys['w_k']
    w_v_keys   = keys['w_v']
    w_o_keys   = keys['w_o']

    # init
    w_sa_fused = {}
    qk_fusion  = args.get('fusion', {}).get('qk_fusion', 'separate')

    # fuse W_Q and W_K
    if qk_fusion == 'separate':
        log.info(' Calculating separate t_map for W_Q and W_K')
        log.info(' Fusing W_Q')
        w_q_fused, _ = fc_fusion(args = args, keys = keys, t_in = t_q_in, w_0 = dict_get(w_q_keys, w_sa_0), w_1 = dict_get(w_q_keys, w_sa_1), act_0 = dict_get(w_q_keys, acts_sa_0), act_1 = dict_get(w_q_keys, acts_sa_1), device = device, log = log, alpha = alpha)

        log.info(' Fusing W_K')
        w_k_fused, _ = fc_fusion(args = args, keys = keys, t_in = t_k_in, w_0 = dict_get(w_k_keys, w_sa_0), w_1 = dict_get(w_k_keys, w_sa_1), act_0 = dict_get(w_k_keys, acts_sa_0), act_1 = dict_get(w_k_keys, acts_sa_1), device = device, log = log, alpha = alpha)
    elif qk_fusion == 'eq_t_map':
        log.info(' Calculating one single t_map for both W_Q and W_K')
        if args['fusion']['type'] == 'acts':
            # stack K and Q activations along dim 0 so both fusions see the same ground metric
            shared_0 = torch.cat((dict_get(w_k_keys, acts_sa_0), dict_get(w_q_keys, acts_sa_0)), dim = 0)
            shared_1 = torch.cat((dict_get(w_k_keys, acts_sa_1), dict_get(w_q_keys, acts_sa_1)), dim = 0)
            wts_eq_t_map = False
        else:
            # weights are passed as activations for the 'wts' + 'eq_t_map' configuration;
            # fc_fusion is told so through the wts_eq_t_map flag
            shared_0 = torch.cat((dict_get(w_k_keys, w_sa_0)['weight'], dict_get(w_q_keys, w_sa_0)['weight']), dim = 0)
            shared_1 = torch.cat((dict_get(w_k_keys, w_sa_1)['weight'], dict_get(w_q_keys, w_sa_1)['weight']), dim = 0)
            wts_eq_t_map = True
        log.info(' Fusing W_Q')
        w_q_fused, _ = fc_fusion(args = args, keys = keys, t_in = t_q_in, w_0 = dict_get(w_q_keys, w_sa_0), w_1 = dict_get(w_q_keys, w_sa_1), act_0 = shared_0, act_1 = shared_1, device = device, log = log, alpha = alpha, wts_eq_t_map = wts_eq_t_map)
        log.info(' Fusing W_K')
        w_k_fused, _ = fc_fusion(args = args, keys = keys, t_in = t_k_in, w_0 = dict_get(w_k_keys, w_sa_0), w_1 = dict_get(w_k_keys, w_sa_1), act_0 = shared_0, act_1 = shared_1, device = device, log = log, alpha = alpha, wts_eq_t_map = wts_eq_t_map)
    elif qk_fusion == 'joint':
        log.info(' Joint W_K and W_Q fusion')
        w_q0 = dict_get(w_q_keys, w_sa_0)
        w_q1 = dict_get(w_q_keys, w_sa_1)
        w_k0 = dict_get(w_k_keys, w_sa_0)
        w_k1 = dict_get(w_k_keys, w_sa_1)
        a_q0 = dict_get(w_q_keys, acts_sa_0)
        a_q1 = dict_get(w_q_keys, acts_sa_1)
        a_k0 = dict_get(w_k_keys, acts_sa_0)
        a_k1 = dict_get(w_k_keys, acts_sa_1)
        # Stack Q then K row-wise. The activations are concatenated along the neuron dim (dim 1) in the
        # SAME Q-then-K order as the weights, so row i of the transport map and row i of the stacked
        # weights refer to the same neuron (previously the activations were stacked K-then-Q).
        w_qk0_join = {'weight': torch.cat((w_q0['weight'], w_k0['weight']), dim = 0),
                      'bias':   torch.cat((w_q0['bias'], w_k0['bias']), dim = 0)}
        w_qk1_join = {'weight': torch.cat((w_q1['weight'], w_k1['weight']), dim = 0),
                      'bias':   torch.cat((w_q1['bias'], w_k1['bias']), dim = 0)}
        a_qk0_join = torch.cat((a_q0, a_k0), dim = 1)
        a_qk1_join = torch.cat((a_q1, a_k1), dim = 1)
        t_qk_in    = t_q_in # The assumption is that the transportation map is the same for both Q and K
        w_qk_fused, _ = fc_fusion(args = args, keys = keys, t_in = t_qk_in, w_0 = w_qk0_join, w_1 = w_qk1_join, act_0 = a_qk0_join, act_1 = a_qk1_join, device = device, log = log, alpha = alpha)
        w_q_fused = {}
        w_k_fused = {}
        w_q_fused['weight'], w_k_fused['weight'] = torch.chunk(w_qk_fused['weight'], chunks=2, dim=0)
        # fc_fusion only returns a bias when bias fusion is enabled
        if 'bias' in w_qk_fused:
            w_q_fused['bias'], w_k_fused['bias'] = torch.chunk(w_qk_fused['bias'], chunks=2, dim=0)
    else:
        raise NotImplementedError('Unknown qk_fusion mode: {0}'.format(qk_fusion))

    log.info(' Fusing W_V')
    w_v_fused, t_v_out = fc_fusion(args = args, keys = keys, t_in = t_v_in, w_0 = dict_get(w_v_keys, w_sa_0), w_1 = dict_get(w_v_keys, w_sa_1), act_0 = dict_get(w_v_keys, acts_sa_0), act_1 = dict_get(w_v_keys, acts_sa_1), device = device, log = log, alpha = alpha)
    if not args['fusion'].get('fusion_t_w_in'):
        t_w_o_in = t_v_out
    else:
        # W_O's *input* map is always needed (W_V's output was aligned above); last_layer only concerns
        # W_O's output, so it must not be forwarded here (doing so made fc_fusion skip the map).
        _, t_w_o_in = fc_fusion(args = args, keys = keys, t_in = None, w_0 = w_sa_0['3'], w_1 = w_sa_1['3'], act_0 = acts_sa_0['intermediate_attn'].squeeze(dim = 1), act_1 = acts_sa_1['intermediate_attn'].squeeze(dim = 1), device = device, log = log, alpha = alpha, last_layer = False, fusion_t_w_in=True)

    log.info(' Fusing W_O')
    w_o_fused, t_w_o_out = fc_fusion(args = args, keys = keys, t_in = t_w_o_in, w_0 = dict_get(w_o_keys, w_sa_0), w_1 = dict_get(w_o_keys, w_sa_1), act_0 = dict_get(w_o_keys, acts_sa_0), act_1 = dict_get(w_o_keys, acts_sa_1), device = device, log = log, alpha = alpha, last_layer = last_layer)

    dict_write(w_sa_fused, w_q_keys, w_q_fused)
    dict_write(w_sa_fused, w_k_keys, w_k_fused)
    dict_write(w_sa_fused, w_v_keys, w_v_fused)
    dict_write(w_sa_fused, w_o_keys, w_o_fused)
    return w_sa_fused, t_w_o_out

#------------------------------#
# Fully Connected Layer Fusion #
#------------------------------#
def fc_fusion(args: dict, keys: dict, t_in: torch.Tensor , w_0: torch.Tensor, w_1: torch.Tensor, act_0: torch.Tensor, act_1: torch.Tensor, device: torch.device, log: logging.Logger, alpha = 0.5, last_layer = False, is_embed = False, is_vit_fc = False, is_vit_embed = False, fusion_t_w_in=False, wts_eq_t_map = False):
    '''
    ## Description
    Performs OTFusion of two fully connected layers.

    1. align weights w.r.t. transportation map of previous layer
    2. compute mu and nu (uniform marginals)
    3. compute ground metric
    4. compute transportation map (`t_out`)
    5. normalize `t_out` with marginals
    6. align weights w.r.t. current layer (`t_out`)
    7. fusion
    ------
    ## Parameters
    `args`   Dictionary from YAML-based configuration file\\
    `keys`  Dictionary containing key lists to access data from the nested weight and acts dicts. The key list must be ordered in the order of the access of the nested dictionary.\\
    `t_in`   Transportation map of the previous layer (set to 'None' if previous layer had no permutations)\\
    `w_0`    Weights of current layer model 0 (dict; raw tensor when `is_vit_embed`)\\
    `w_1`    Weights of current layer model 1 (dict; raw tensor when `is_vit_embed`)\\
    `act_0`  Activations of current layer model 0\\
    `act_1`  Activations of current layer model 1\\
    `alpha` Weighting parameter: fused = (1 - alpha) * aligned model 0 + alpha * model 1\\
    `device` torch.device()\\
    `last_layer` Flag that indicates if this is the last layer that will be fused in the current experiment (no output alignment, `t_out` is None)\\
    `is_embed` Flag to indicate that embeddings are fused (need to transpose weight matrix)\\
    `is_vit_fc` Flag for ViT fully-connected layers (not referenced in the current implementation)\\
    `is_vit_embed` `w_0`/`w_1` are tensors that are squeezed along dim 1 instead of being looked up with `keys['weights']`\\
    `fusion_t_w_in` Only compute the transportation map and return `(None, t_out)`\\
    `wts_eq_t_map` Weight-based 'eq_t_map' Q/K fusion: `act_0`/`act_1` hold two weight matrices stacked along dim 0, and the ground metric is built from both halves side by side
    ------
    ## Outputs
    `w_fused` Fused weights (dict with 'weight' and, when bias fusion is enabled, 'bias')\\
    `t_out`   Transportation map of current layer (None when `last_layer`)

    ## Raises
    `NotImplementedError` for an unknown `ot_solver`
    '''

    # Retrieve key lists
    w_keys  = keys['weights']
    b_keys  = keys['bias']
    fusion_cfg = args.get('fusion', {})

    # Init
    t_out     = None
    w_fused   = {}
    gm        = GroundMetric(args)
    fuse_bias = args['fusion']['fuse_bias'] and not is_embed
    if fuse_bias:
        bias_0 = dict_get(b_keys, w_0)
        bias_1 = dict_get(b_keys, w_1)
    if not is_vit_embed:
        w_0 = dict_get(w_keys, w_0)
        w_1 = dict_get(w_keys, w_1)
    else:
        w_0 = w_0.data.squeeze(dim = 1)
        w_1 = w_1.data.squeeze(dim = 1)

    # align weights with t_in (if t_in is not None)
    if (not is_embed) or (args['fusion']['type'] == 'acts'):
        w_0_aligned = w_0 if t_in is None else torch.matmul(w_0, t_in)
    else:
        w_0_aligned = w_0 if t_in is None else torch.matmul(w_0, t_in.t())

    if not last_layer:
        # mu and nu: uniform marginals over the neurons that are matched
        if not is_embed:
            mu_cardinality = w_0.shape[0]
            nu_cardinality = w_1.shape[0]
        elif fusion_t_w_in:
            # NOTE(review): only reached when is_embed is also set; for fusion_t_w_in on a
            # non-embedding layer the row count of w_0 is used instead (first branch)
            mu_cardinality = act_0.shape[-1]
            nu_cardinality = act_1.shape[-1]
        else:
            mu_cardinality = w_0.shape[-1]
            nu_cardinality = w_1.shape[-1]
        mu = np.divide(np.ones(mu_cardinality), mu_cardinality)
        nu = np.divide(np.ones(nu_cardinality), nu_cardinality)

        if args['fusion']['type'] == 'acts':
            # process activations with PCA if set in config (basis taken from model 1)
            if fusion_cfg.get('pca', False) == True:
                PCA_k = int(fusion_cfg.get('pca_k', '1000'))
                projection_1 = _get_projection_pca(act_1, PCA_k)
                mean_act_1 = torch.mean(act_1, dim=0)
                act_0 = torch.matmul(projection_1.t(), act_0 - mean_act_1)
                act_1 = torch.matmul(projection_1.t(), act_1 - mean_act_1)

            M0 = gm.process(act_0.t(), act_1.t())

        elif is_embed:
            M0 = gm.process(w_0_aligned.t(), w_1.t())
        elif wts_eq_t_map:
            # act_0/act_1 carry two weight matrices stacked along dim 0 (see the 'eq_t_map' Q/K fusion).
            # 1. align both halves of model 0 w.r.t. the incoming t_map (skipped when there is none)
            wts_0_a, wts_0_b = torch.chunk(act_0, 2, dim = 0)
            wts_1_a, wts_1_b = torch.chunk(act_1, 2, dim = 0)
            if t_in is not None:
                wts_0_a = torch.matmul(wts_0_a, t_in)
                wts_0_b = torch.matmul(wts_0_b, t_in)
            # 2. generate ground metric from the halves placed side by side
            M0 = gm.process(torch.cat((wts_0_a, wts_0_b), dim = 1), torch.cat((wts_1_a, wts_1_b), dim = 1))
        else:
            M0 = gm.process(w_0_aligned, w_1)

        ot_solver = fusion_cfg.get('ot_solver', 'emd')
        if ot_solver == 'sinkhorn_for_widening':
            # sinkhorn only when this layer is wider than the incoming map, emd otherwise
            widened = t_in is not None and (M0.shape[0] > t_in.shape[0] or M0.shape[1] > t_in.shape[1])
            ot_solver = 'sinkhorn' if widened else 'emd'
        if ot_solver == 'emd':
            log.info(' Using emd solver to calculate t_map for this layer')
            t_numpy = ot.emd(mu, nu, M0)
        elif ot_solver == 'sinkhorn':
            log.info(' Using sinkhorn solver to calculate t_map for this layer')
            t_numpy = ot.bregman.sinkhorn(mu, nu, M0, reg = float(fusion_cfg.get('sinkhorn_reg', 1e-2)))
        else:
            raise NotImplementedError('Unknown ot_solver: {0}'.format(ot_solver))
        t_out = torch.from_numpy(t_numpy).float().to(w_0_aligned.device)

        # normalize t_out with beta (marginals): scale by the row count so rows sum to ~1 under uniform mu
        beta  = 1 / t_out.shape[0]
        t_out = torch.mul(t_out, 1 / beta)

        if fusion_t_w_in:
            return None, t_out

        # align weights with t_out
        if not is_embed and not is_vit_embed:
            w_0_aligned = torch.matmul(t_out.t(), w_0_aligned)
        else:
            w_0_aligned = torch.matmul(t_out.t(), w_0_aligned.t()).t()

    # fuse aligned weights
    w_fused['weight'] = torch.add(w_0_aligned * (1-alpha), w_1 * alpha)

    # align and fuse bias (no output alignment on the last layer)
    if fuse_bias:
        bias_0_aligned = bias_0 if last_layer else torch.matmul(t_out.t(), bias_0)
        w_fused['bias'] = torch.add(bias_0_aligned * (1-alpha), bias_1 * alpha)

    # debug statements
    if t_out is not None:
        log.debug(matrix_stats(M0, 'ground metric'))
        log.debug(matrix_stats(t_out, 't_out'))

    return w_fused, t_out


def _get_projection_pca(matrix: torch.Tensor, k: int) -> torch.Tensor:
    '''Return the first `k` left singular vectors of `matrix` after subtracting its column means.'''
    centered = matrix - torch.mean(matrix, dim=0)
    U, _, _ = torch.linalg.svd(centered)
    return U[:, :k]

#-------------------#
# Layer Norm Fusion #
#-------------------#
def ln_fusion(args: dict, keys: dict, t_in: torch.Tensor , w_0: dict, w_1: dict, device: torch.device, alpha = 0.5):
    '''
    ## Description
    Fuses the two parameter vectors (selected by `keys['a']` and `keys['b']`) of two
    layer-normalization layers.

    For each vector, model 0's copy is first aligned with `t_in` (when given) and then
    blended with model 1's copy as (1 - alpha) * model 0 + alpha * model 1.
    ------
    ## Parameters
    `args`   Dictionary from YAML-based configuration file (unused here)\\
    `keys`   Dictionary containing key lists to access data from the nested weight dicts; `keys['a']` and `keys['b']` select the two vectors\\
    `t_in`   Transportation map of the previous layer (set to 'None' if previous layer had no permutations)\\
    `w_0`    Layer normalization weight dictionary of model 0\\
    `w_1`    Layer normalization weight dictionary of model 1\\
    `device` torch.device() (unused here)\\
    `alpha`  Weight of model 1 in the blend
    ------
    ## Outputs
    `w_fused` Fused weights (dictionary holding both fused vectors)\\
    `t_out`   Transportation map of current layer (the incoming `t_in`, passed through unchanged)
    '''
    fused = {}
    for param_keys in (keys['a'], keys['b']):
        param_0 = dict_get(param_keys, w_0)
        if t_in is not None:
            # bring model 0's vector into model 1's neuron order
            param_0 = torch.matmul(param_0, t_in)
        blended = torch.add(param_0 * (1 - alpha), dict_get(param_keys, w_1) * alpha)
        dict_write(fused, param_keys, blended)

    return fused, t_in

def resid_policy(policy, t_resid, t_in, resid_acts, in_acts, log):
    '''
    ## Description
    Combines the transportation map of the residual branch (`t_resid`) with the map produced by
    the sub-layer (`t_in`) according to `policy`.

    ------
    ## Parameters
    `policy`     One of None, 'no_resid', 'only_resid', 'resid_as_identity', 'mean',
                 'weighted_scalar', 'weighted_matrix'\\
    `t_resid`    Transportation map of the residual branch\\
    `t_in`       Transportation map produced by the sub-layer\\
    `resid_acts` Activations used to weigh `t_resid` (read only by the weighted policies)\\
    `in_acts`    Activations used to weigh `t_in` (read only by the weighted policies)\\
    `log`        logging instance
    ------
    ## Outputs
    `t_out` Combined transportation map

    ## Raises
    `NotImplementedError` for an unknown policy
    '''
    if policy is None:
        log.info(' No residual connection policy defined; defaulting to "no_resid" policy')
        return t_in

    if policy == 'no_resid':
        log.info(' "no_resid" residual connection policy used; propagating t_out from MHA')
        return t_in

    if policy == 'only_resid':
        log.info(' "only_resid" residual connection policy used; propagating t_resid from residual connection')
        return t_resid

    if policy == 'resid_as_identity':
        log.info(' "resid_as_identity" residual connection policy used; propagating identity connection')
        identity = torch.zeros_like(t_resid)
        identity.fill_diagonal_(1)
        return identity

    if policy == 'mean':
        log.info(' "mean" residual connection policy used; propagating the average of t_resid and t_out from MHA')
        return torch.div(torch.add(t_resid, t_in), 2)

    if policy == 'weighted_scalar':
        # one weight per branch, proportional to the mean absolute activation of that branch
        resid_mag = torch.mean(torch.abs(resid_acts))
        in_mag    = torch.mean(torch.abs(in_acts))
        norm      = resid_mag + in_mag
        w_resid, w_in = resid_mag / norm, in_mag / norm
        log.info(' "weighted_scalar" residual connection policy used; propagating {0:.4} * t_resid and {1:.4} * t_out from MHA'.format(w_resid, w_in))
        return torch.add(torch.mul(w_resid, t_resid), torch.mul(w_in, t_in))

    if policy == 'weighted_matrix':
        # one weight per column, from the per-column mean absolute activations (mean over dim 0)
        resid_mag = torch.mean(torch.abs(resid_acts), dim = 0)
        in_mag    = torch.mean(torch.abs(in_acts), dim = 0)
        norm      = torch.add(resid_mag, in_mag)
        resid_scale = torch.diag(torch.div(resid_mag, norm))
        in_scale    = torch.diag(torch.div(in_mag, norm))
        log.info(' "weighted_matrix" residual connection policy used')
        return torch.add(torch.matmul(t_resid, resid_scale), torch.matmul(t_in, in_scale))

    raise NotImplementedError
