import torch
import torch.nn.functional as F
import numpy as np
import copy
from rich.progress import track
import math


def get_new_out_channel(reversed_channel_1,reversed_channel_2):
    """Merge two kept-channel index collections into their sorted, de-duplicated union.

    Pruned layers keep their surviving channels in ascending original order, so
    any weight slice built from this union must use the same ascending order.
    Plain ``list(set(...))`` iteration order is not sorted in general (in
    CPython ``list({8, 1})`` is ``[8, 1]``), which would permute channels.

    Args:
        reversed_channel_1: iterable of kept ("reserved") channel indices
            (list, range or 1-D array).
        reversed_channel_2: iterable of kept channel indices.

    Returns:
        ``(count, channels)``: the number of distinct channels and the
        ascending list of them.
    """
    merged = sorted(set(reversed_channel_1) | set(reversed_channel_2))
    return len(merged), merged


def get_best_channel(layer, channel, num_layer, idx_of_scores):
    """Record, per layer, the layer-local index of the last listed channel.

    ``layer[i]`` names the layer that owns flat channel index ``channel[i]``.
    The pairs are visited in order and each visit overwrites the owner's slot,
    so the last entry for a layer wins. When ``channel`` is ordered by
    ascending score (as ``sort_score`` passes it), that is the highest-scoring
    listed channel of the layer.

    Args:
        layer: per-entry owning layer ids.
        channel: 1-D array of flat channel indices (``.shape`` is used).
        num_layer: length of the returned array.
        idx_of_scores: start offset of each layer inside the flat index space.

    Returns:
        ``np.ndarray`` of dtype int64 and length ``num_layer``. Layers that
        never appear hold ``-1``.
    """
    best = np.full(num_layer, -1, dtype=np.int64)
    for pos in range(channel.shape[0]):
        owner = layer[pos]
        best[owner] = channel[pos] - idx_of_scores[owner]
    return best


def sort_score(score_list, pruning_scale, idx_of_scores, out_c_of_para):
    """Select the globally lowest-scoring output channels to prune.

    Six ``+inf`` sentinel scores are appended before ranking. Being infinite,
    they sort last. ``pruning_scale`` is applied to the padded length.

    Args:
        score_list: flat per-channel scores, laid out layer after layer.
        pruning_scale: fraction of the (padded) channels to prune.
        idx_of_scores: start offset of each layer's scores.
        out_c_of_para: output-channel count of each layer.

    Returns:
        ``(scores, channel_prune, channel_prune_in_each_layer, best_channels)``:
        - the padded score array;
        - the flat indices chosen for pruning, in ascending-score order;
        - a per-layer list holding either a sorted ``np.ndarray`` of
          layer-local indices or ``-1`` when the layer is untouched;
        - the per-layer fallback channel from ``get_best_channel``.
    """
    padded = np.array(list(score_list) + [math.inf] * 6)
    total = padded.shape[0]
    budget = int(total * pruning_scale)

    # Map each flat channel index to the layer that owns it.
    owner = np.zeros(total, dtype=np.int32)
    for layer_idx, start in enumerate(idx_of_scores):
        width = out_c_of_para[layer_idx]
        owner[start:start + width] = layer_idx * np.ones(width, dtype=np.int32)

    ranking = padded.argsort()
    channel_prune = ranking[:budget]
    layer_prune = owner[channel_prune]
    best_channels = get_best_channel(layer_prune, channel_prune, owner.max() + 1, idx_of_scores)

    # Group the pruned flat indices by layer as layer-local offsets.
    grouped = {}
    for flat_idx, layer_idx in zip(channel_prune, layer_prune):
        grouped.setdefault(layer_idx, []).append(flat_idx - idx_of_scores[layer_idx])

    channel_prune_in_each_layer = list(-1 * np.ones(len(idx_of_scores), dtype=np.int32))
    for layer_idx, local in grouped.items():
        channel_prune_in_each_layer[layer_idx] = np.sort(np.array(local))

    return padded, channel_prune, channel_prune_in_each_layer, best_channels


def pruning(model, prune_type, get_module_type, ckpt_path, pciel, layer_after_att, best_channels):
    """Prune Conv2d output channels in a checkpoint and propagate the cuts to inputs.

    The checkpoint loaded from ``ckpt_path`` is walked in key order. For the
    i-th Conv2d weight, the channels in ``pciel[i]`` are dropped from the output
    dimension. If that would drop every channel, ``best_channels[i]`` is kept.
    For some convs (resblock ``in_layers.2`` / ``out_layers.3``,
    ``skip_connection`` and ``out.2``) the input dimension is also sliced to
    the channels kept by the producing layer(s), including the long-skip
    concatenation at ``output_blocks.*.0.in_layers.2``.

    Tensors that are not Conv2d (including Conv1d) are copied unchanged. A
    Conv2d bias is sliced with the kept channels of the weight visited just
    before it (biases starting with ``out.`` are copied unchanged).

    NOTE(review): relies on checkpoint key order (weight before bias, blocks
    in forward order) and on hard-coded UNet key names, including
    ``output_blocks.17.0``. Confirm these against the model before reuse.

    Args:
        model: model passed to ``get_module_type``.
        prune_type: must be ``'Conv2d'``.
        get_module_type: callable ``(model, key) -> (type_name, _)``.
        ckpt_path: checkpoint path, loaded with ``torch.load``.
        pciel: per-Conv2d entry, either an ``np.ndarray`` of pruned output
            channels or a non-array placeholder meaning "prune nothing".
        layer_after_att: names of convs whose input comes from an attention block.
        best_channels: per-layer channel to keep when all would be pruned.

    Returns:
        ``(new_ckpt, reverse_info, ckpt)``.
        - ``reverse_info[name]`` is a two-element list. The second element is
          the kept output channels. The first is the input-side channel list
          recorded for that conv.
        - ``ckpt`` is the original loaded checkpoint.
    """
    assert prune_type == 'Conv2d', "This function is designed for dealing with Conv2d layer."
    ckpt = torch.load(ckpt_path, map_location="cpu")
    new_ckpt = {}
    # "reversed" in these names means "reserved", i.e. kept channel indices.
    reversed_channels_last_layer = []
    reversed_channels_before_skip = []
    reverse_info = {}
    # Stack of kept-channel lists exposed by input blocks. Popped by
    # output_blocks.*.0.in_layers.2 for the long-skip concatenation.
    in_blocks_chs = []
    # Stack of original (unpruned) out-channel counts of middle/output resblocks,
    # used as the offset of the long-skip channels in the concatenated input.
    out_blocks_origin_ch = []
    long_skip_ch = None
    long_skip_origin_ch = -1

    layer_id = -1
    for k,v in track(ckpt.items()):
        type_k,_ = get_module_type(model,k)

        if type_k == 'Conv1d':
            ks = k.split('.')
            new_ckpt[k] = v.cpu().clone()

            # An input_blocks proj_out (seen at its bias) replaces the last
            # recorded long-skip entry with all of its output channels.
            if ks[0] == 'input_blocks' and ks[-1] == 'bias' and ks[-2]=='proj_out':
                in_blocks_chs.pop()
                in_blocks_chs.append(list(range(v.shape[0])))

            continue


        if type_k != prune_type:
            new_ckpt[k] = v.cpu().clone()
            continue

        if k.endswith('weight'):
            """
            The reverse_info[k] will contain two:
            reversed channels in last layer,
            reversed channels in this layer.
            """
            name = k[:-7]
            reverse_info[name] = []

            layer_id += 1

            if isinstance(pciel[layer_id],np.ndarray):
                pruned_channels = list(pciel[layer_id])
            else:
                pruned_channels = []

            # Never prune a layer down to zero channels: keep its best channel.
            if len(pruned_channels) == v.shape[0]:
                best_channel = best_channels[layer_id]
                pruned_channels.remove(best_channel)

            def need_reserve(n):
                return n not in pruned_channels
            # Kept output channels, in ascending original order.
            reserved_channels = list(filter(need_reserve,range(v.shape[0])))

            """
            Get the channels reserved in the previous layer.
            Also, if the previous block is Attention block,
            input channel is no need to prune.
            Also, long skip connection must be considered.
            """
            ks = k.split('.')

            
            #################
            # Push the original out-channel count of middle/output resblocks. A
            # resblock at sub-index '2' of an output block replaces the most
            # recent entry instead of adding a new one.
            if ks[0] == 'middle_block' and ks[1] == '2' and ks[-3] == 'out_layers':
                out_blocks_origin_ch.append(v.shape[0])
            if ks[0] == 'output_blocks' and ks[-3] == 'out_layers':
                if ks[2] == '2':
                    out_blocks_origin_ch.pop()
                out_blocks_origin_ch.append(v.shape[0])


            # Convs fed by an attention block see an unpruned input.
            # NOTE(review): range(v.shape[0]) uses this conv's out-channel count
            # as the input width; assumes they are equal here.
            if name in layer_after_att:
                if ks[-2] == 'skip_connection':
                    if ks[0] != 'output_blocks':
                        reversed_channels_before_skip = range(v.shape[0])
                else:
                    reversed_channels_last_layer = range(v.shape[0])
            if ks[-2] == 'skip_connection':
                # The skip conv consumes the resblock's input channels.
                reverse_info[name].append(reversed_channels_before_skip)
                # NOTE(review): if reversed_channels_before_skip == [], new_ckpt[k]
                # is never set here and the output slice below raises KeyError.
                if reversed_channels_before_skip != []:
                    new_ckpt[k] = v[:,reversed_channels_before_skip,...].cpu().clone()

            else:
                # First conv of a decoder block: its input is the concatenation of
                # the previous output (kept channels) and the long-skip channels,
                # offset by the original width of the previous block's output.
                if ks[0]=='output_blocks' and ks[2]=='0' and ks[-2]=='2' and ks[-3]=='in_layers':
                    long_skip_ch = in_blocks_chs.pop()
                    long_skip_origin_ch = out_blocks_origin_ch.pop()
                    reversed_channels_last_layer = list(reversed_channels_last_layer) + list(np.array(long_skip_ch) + long_skip_origin_ch)


                if ks[-3] != 'out_layers':
                    reverse_info[name].append(reversed_channels_last_layer)
                else:
                    reverse_info[name].append(list(range(v.shape[0])))
                new_ckpt[k] = v.cpu().clone()

                # @@@@@@@@@@@@@
                # Final output conv: input is the union of the channels kept by the
                # last resblock's out_layers.3 and skip_connection (the_last_ch is unused).
                if k == 'out.2.weight':
                    name1 = 'output_blocks.17.0.out_layers.3'
                    name2 = 'output_blocks.17.0.skip_connection'
                    the_last_ch, out_inch = get_new_out_channel(reverse_info[name1][1],reverse_info[name2][1])
                    new_ckpt[k] = new_ckpt[k][:,out_inch].cpu().clone()
                # in_layers.2: slice input with the channel list recorded above.
                if ks[-3]=='in_layers' and ks[-2]=='2':
                    chhh = reverse_info[name][0]
                    new_ckpt[k] = new_ckpt[k][:,chhh].cpu().clone()
                # out_layers.3: input is the output kept by the same resblock's in_layers.2.
                if ks[-3]=='out_layers' and ks[-2]=='3':
                    if ks[0] != 'middle_block':
                        name_ = ks[0]+'.'+ks[1]+'.'+ks[2]+'.in_layers.2'
                    else:
                        name_ = ks[0]+'.'+ks[1]+'.in_layers.2'
                    chhh_o = reverse_info[name_][1]
                    new_ckpt[k] = new_ckpt[k][:,chhh_o].cpu().clone()
                # NOTE(review): other Conv2d keys (e.g. input_blocks.0.0 or
                # up/downsample convs) keep their full input dimension here.
                # Confirm that this is intended.



            """
            Prune this layer's out channels.
            Long skip connection must be considered.
            """
            # The final 'out' conv's output channels are never pruned.
            if ks[0] != 'out':
                new_ckpt[k] = new_ckpt[k][reserved_channels].cpu().clone()
                reverse_info[name].append(reserved_channels)
            else:
                new_ckpt[k] = new_ckpt[k].cpu().clone()
                reverse_info[name].append(range(v.shape[0]))

            """
            Prepare reversed channels for skip conenction
            at the beginning of a resblock.
            """
            # Save the resblock's input channels before they are overwritten
            # below; the block's skip_connection conv consumes them later.
            if ks[-2]=='2' and ks[-3] == 'in_layers':
                reversed_channels_before_skip = reversed_channels_last_layer

            """
            Input channels of the next resblock is jointly determined
            by out channels from both conv layers and skip connection.
            """
            if ks[-2] != 'skip_connection':
                # Record what an input block exposes to its long skip. With a
                # skip_connection conv, the entry is replaced when that conv is
                # visited. Otherwise the identity skip yields
                # union(block input, kept outputs).
                if (ks[0] == 'input_blocks' and ks[-2]=='3' and ks[-3]=='out_layers') or (ks[0]=='input_blocks' and ks[1]=='0'):
                    if (ks[0]+'.'+ks[1]+'.'+ks[2]+'.'+'skip_connection.weight') in ckpt.keys():
                        in_blocks_chs.append(reserved_channels)
                    else:
                        _, reversed_channels_last_layer = get_new_out_channel(reversed_channels_before_skip,reserved_channels)
                        in_blocks_chs.append(reversed_channels_last_layer)

                # Resblock output. An identity skip adds the block input to the
                # conv path, so the live channels are the union of both.
                if (ks[0] == 'input_blocks' and ks[-3]=='out_layers') or (ks[0] == 'middle_block' and ks[-3]=='out_layers') or (ks[0] == 'output_blocks' and ks[-3]=='out_layers'):
                    if (ks[0]+'.'+ks[1]+'.'+ks[2]+'.'+'skip_connection.weight') in ckpt.keys():
                        reversed_channels_last_layer = reserved_channels
                    else:
                        _, reversed_channels_last_layer = get_new_out_channel(reversed_channels_before_skip,reserved_channels)
                else:
                    reversed_channels_last_layer = reserved_channels


            else:
                # skip_connection conv: block output = union(out_layers.3 kept, skip kept).
                _, reversed_channels_last_layer = get_new_out_channel(reversed_channels_last_layer,reserved_channels)
                if ks[0] == 'input_blocks':
                    in_blocks_chs.pop()
                    in_blocks_chs.append(reversed_channels_last_layer)
        # Relies on the bias following its weight: reserved_channels is the
        # value computed for that weight.
        # NOTE(review): Conv2d keys ending in neither 'weight' nor 'bias' are
        # dropped from new_ckpt.
        elif k.endswith('bias'):
            if not k.startswith('out.'):
                new_ckpt[k] = v[reserved_channels].cpu().clone()
            else:
                new_ckpt[k] = v.cpu().clone()

    return new_ckpt, reverse_info, ckpt


def align_skip_channel(model, score_list, idx_of_scores, layer_after_att, get_module_type):
    """Average channel scores across Conv2d layers whose outputs must be pruned alike.

    Convs whose outputs meet in a residual addition must drop the same
    channels, so their per-channel scores are replaced (in place) by the
    element-wise mean over each tied set:

    * ``input_blocks.0.0`` and ``input_blocks.1..6.0.out_layers.3``;
    * ``input_blocks.7.0.skip_connection`` and ``input_blocks.7..10.0.out_layers.3``;
    * every other ``*out_layers.3`` conv (not in ``layer_after_att``), tied
      with the first matching rule:
      - the next conv, if it is a ``skip_connection``;
      - otherwise the convs two and three positions earlier, if the one two
        positions earlier is a ``skip_connection``;
      - otherwise the conv two positions earlier;
    * ``*out_layers.3`` convs in ``layer_after_att``, only with a directly
      following ``skip_connection``.

    Args:
        model: object exposing ``named_parameters()``. Only ``*.weight`` entries
            whose module type (per ``get_module_type``) is ``Conv2d`` count, in
            that order.
        score_list: flat per-channel scores, modified in place.
        idx_of_scores: start offset of each Conv2d's scores. Entry ``i + 1``
            is used as the end offset of layer ``i``.
        layer_after_att: names of convs that follow an attention block.
        get_module_type: callable ``(model, param_name) -> (type_name, _)``.

    Returns:
        ``score_list`` (the same object, after averaging).
    """
    convs = []
    name_to_layer_id = {}
    for k, _param in model.named_parameters():
        type_p, _ = get_module_type(model, k)
        if type_p == 'Conv2d' and k.endswith('weight'):
            name_to_layer_id[k[:-7]] = len(convs)
            convs.append(k[:-7])

    muti_align_1 = ['input_blocks.0.0'] + ['input_blocks.' + str(i + 1) + '.0.out_layers.3' for i in range(6)]
    muti_align_2 = ['input_blocks.7.0.skip_connection'] + ['input_blocks.' + str(i + 7) + '.0.out_layers.3' for i in range(4)]

    def average(layer_ids):
        # Replace every listed layer's score slice with the mean of all of them.
        # A None start (not `== []`) avoids comparing an ndarray with a list,
        # which raises on NumPy >= 1.25 for multi-channel layers.
        spans = [(idx_of_scores[l], idx_of_scores[l + 1]) for l in layer_ids]
        total = None
        for start, end in spans:
            chunk = np.array(score_list[start:end])
            total = chunk if total is None else total + chunk
        mean = list(total / len(spans))
        for start, end in spans:
            score_list[start:end] = mean

    average([name_to_layer_id[name] for name in muti_align_1])
    average([name_to_layer_id[name] for name in muti_align_2])

    for i, name in enumerate(convs):
        if not name.endswith('out_layers.3'):
            continue
        if name in layer_after_att:
            if convs[i + 1].endswith('skip_connection'):
                average([i, i + 1])
        elif name not in muti_align_1 and name not in muti_align_2:
            if convs[i + 1].endswith('skip_connection'):
                average([i, i + 1])
            elif convs[i - 2].endswith('skip_connection'):
                average([i, i - 2, i - 3])
            else:
                average([i, i - 2])

    return score_list


def check_skip_alignment(model, pciel, layer_after_att, get_module_type):
    """Make per-layer pruning decisions agree across residually-connected convs.

    ``pciel`` has one entry per Conv2d, in ``model.named_parameters()`` order.
    Each entry is an ``np.ndarray`` of pruned channels or a non-array
    placeholder meaning "nothing pruned". It is edited in place and returned.

    * Members of ``input_blocks.0.0`` + ``input_blocks.1..6.0.out_layers.3``
      adopt the entry of ``input_blocks.0.0``.
    * Members of ``input_blocks.7.0.skip_connection`` +
      ``input_blocks.7..10.0.out_layers.3`` adopt the entry of that
      ``skip_connection``.
    * Every other ``*out_layers.3`` conv is tied to:
      - the next conv, if it is a ``skip_connection``;
      - otherwise the convs two and three positions earlier (when the one two
        earlier is a ``skip_connection``);
      - otherwise the conv two positions earlier.
      Convs in ``layer_after_att`` use only the first rule. Within a tie, the
      entry with the smaller ``get_len`` (fewer pruned channels; ``-1`` for
      none) is adopted by the others.
    """
    conv_names = []
    for param_name, _param in model.named_parameters():
        kind, _ = get_module_type(model, param_name)
        if kind == 'Conv2d' and param_name.endswith('weight'):
            conv_names.append(param_name[:-7])
    position = {name: i for i, name in enumerate(conv_names)}

    group_1 = ['input_blocks.0.0'] + ['input_blocks.%d.0.out_layers.3' % n for n in range(1, 7)]
    group_2 = ['input_blocks.7.0.skip_connection'] + ['input_blocks.%d.0.out_layers.3' % n for n in range(7, 11)]

    # Each group follows the decision of its first member.
    for group in (group_1, group_2):
        leader = pciel[position[group[0]]]
        for member in group:
            pciel[position[member]] = leader

    def adopt_lighter(a, b):
        # Copy the entry that prunes fewer channels onto the other position
        # (on a tie, position a's entry wins).
        if get_len(pciel[a]) > get_len(pciel[b]):
            pciel[a] = pciel[b]
        else:
            pciel[b] = pciel[a]

    for i, name in enumerate(conv_names):
        if not name.endswith('out_layers.3'):
            continue
        if name in layer_after_att:
            if conv_names[i + 1].endswith('skip_connection'):
                adopt_lighter(i, i + 1)
        elif name not in group_1 and name not in group_2:
            if conv_names[i + 1].endswith('skip_connection'):
                adopt_lighter(i, i + 1)
            elif conv_names[i - 2].endswith('skip_connection'):
                here, skip, before = pciel[i], pciel[i - 2], pciel[i - 3]
                if get_len(here) > get_len(skip):
                    winner = skip if get_len(before) > get_len(skip) else before
                else:
                    winner = here if get_len(before) > get_len(here) else before
                pciel[i] = pciel[i - 2] = pciel[i - 3] = winner
            else:
                adopt_lighter(i, i - 2)

    return pciel


def get_len(a):
    """Return the leading dimension of an ndarray, or ``-1`` for anything else.

    Entries of the per-layer prune lists are either arrays of pruned channels
    or a placeholder meaning "nothing pruned"; ``-1`` ranks the latter below
    any real array.
    """
    return a.shape[0] if isinstance(a, np.ndarray) else -1


def pruning_conv1d(model, get_module_type, ckpt, reverse_info, in_pciel, out_pciel, in_best_channels, out_best_channels):
    """Prune the input and output channels of every Conv1d tensor in ``ckpt``.

    Conv1d weights are visited in checkpoint order. The i-th one uses
    ``in_pciel[i]`` / ``out_pciel[i]``: an ``np.ndarray`` of channel indices to
    drop, or a non-array placeholder meaning "drop nothing". If a list would
    drop every channel of its dimension, ``in_best_channels[i]`` /
    ``out_best_channels[i]`` is kept instead. A Conv1d bias is sliced with the
    output channels kept by the weight visited just before it. All non-Conv1d
    tensors are copied unchanged.

    Args:
        model: model passed to ``get_module_type``.
        get_module_type: callable ``(model, key) -> (type_name, _)``.
        ckpt: mapping of parameter name -> tensor (iterated in order).
        reverse_info: dict extended in place with
            ``[kept input channels, kept output channels]`` per Conv1d name.
        in_pciel, out_pciel: per-layer pruned input / output channel lists.
        in_best_channels, out_best_channels: per-layer fallback channels.

    Returns:
        ``(new_ckpt, reverse_info)``.

    Raises:
        AssertionError: if a Conv1d name is already present in ``reverse_info``.
    """
    prune_type = 'Conv1d'
    new_ckpt = {}

    def kept_channels(pruned_entry, best_list, idx, size):
        # Ascending indices of range(size) that survive pruning.
        pruned = list(pruned_entry) if isinstance(pruned_entry, np.ndarray) else []
        if len(pruned) == size:
            # Never empty a dimension: keep the designated best channel.
            pruned.remove(best_list[idx])
        pruned = set(pruned)  # O(1) membership instead of a list scan per channel
        return [c for c in range(size) if c not in pruned]

    out_reserved_channels = []

    layer_id = -1
    for k, v in track(ckpt.items()):
        type_k, _ = get_module_type(model, k)

        if type_k != prune_type:
            new_ckpt[k] = v.cpu().clone()
            continue

        # NOTE: Conv1d keys ending in neither 'weight' nor 'bias' are dropped.
        if k.endswith('weight'):
            layer_id += 1
            name = k[:-7]
            assert name not in reverse_info
            reverse_info[name] = []

            out_reserved_channels = kept_channels(out_pciel[layer_id], out_best_channels, layer_id, v.shape[0])
            in_reserved_channels = kept_channels(in_pciel[layer_id], in_best_channels, layer_id, v.shape[1])

            new_ckpt[k] = v[:, in_reserved_channels][out_reserved_channels].cpu().clone()

            reverse_info[name].append(in_reserved_channels)
            reverse_info[name].append(out_reserved_channels)

        elif k.endswith('bias'):
            # Relies on the bias following its weight in checkpoint order.
            new_ckpt[k] = v[out_reserved_channels].cpu().clone()

    return new_ckpt, reverse_info


def pruning_lin(model, get_module_type, ckpt, reverse_info, in_pciel, out_pciel, in_best_channels, out_best_channels):
    """Prune the input and output features of every Linear tensor in ``ckpt``.

    Linear weights are visited in checkpoint order. The i-th one uses
    ``in_pciel[i]`` / ``out_pciel[i]``: an ``np.ndarray`` of feature indices to
    drop, or a non-array placeholder meaning "drop nothing". If a list would
    drop every feature of its dimension, ``in_best_channels[i]`` /
    ``out_best_channels[i]`` is kept instead. A Linear bias is sliced with the
    output features kept by the weight visited just before it. All non-Linear
    tensors are copied unchanged.

    Args:
        model: model passed to ``get_module_type``.
        get_module_type: callable ``(model, key) -> (type_name, _)``.
        ckpt: mapping of parameter name -> tensor (iterated in order).
        reverse_info: dict extended in place with
            ``[kept input features, kept output features]`` per Linear name.
        in_pciel, out_pciel: per-layer pruned input / output index lists.
        in_best_channels, out_best_channels: per-layer fallback indices.

    Returns:
        ``(new_ckpt, reverse_info)``.

    Raises:
        AssertionError: if a Linear name is already present in ``reverse_info``.
    """
    prune_type = 'Linear'
    new_ckpt = {}

    def kept_channels(pruned_entry, best_list, idx, size):
        # Ascending indices of range(size) that survive pruning.
        pruned = list(pruned_entry) if isinstance(pruned_entry, np.ndarray) else []
        if len(pruned) == size:
            # Never empty a dimension: keep the designated best index.
            pruned.remove(best_list[idx])
        pruned = set(pruned)  # O(1) membership instead of a list scan per index
        return [c for c in range(size) if c not in pruned]

    out_reserved_channels = []

    layer_id = -1
    for k, v in track(ckpt.items()):
        type_k, _ = get_module_type(model, k)

        if type_k != prune_type:
            new_ckpt[k] = v.cpu().clone()
            continue

        # NOTE: Linear keys ending in neither 'weight' nor 'bias' are dropped.
        if k.endswith('weight'):
            layer_id += 1
            name = k[:-7]
            assert name not in reverse_info
            reverse_info[name] = []

            out_reserved_channels = kept_channels(out_pciel[layer_id], out_best_channels, layer_id, v.shape[0])
            in_reserved_channels = kept_channels(in_pciel[layer_id], in_best_channels, layer_id, v.shape[1])

            new_ckpt[k] = v[:, in_reserved_channels][out_reserved_channels].cpu().clone()

            reverse_info[name].append(in_reserved_channels)
            reverse_info[name].append(out_reserved_channels)

        elif k.endswith('bias'):
            # Relies on the bias following its weight in checkpoint order.
            new_ckpt[k] = v[out_reserved_channels].cpu().clone()

    return new_ckpt, reverse_info


def sort_conv1d_score(in_score_list, out_score_list, pruning_scale, in_idx_of_scores, in_c_of_para, out_idx_of_scores, out_c_of_para):
    """Pick the lowest-scoring input and output channels of Conv1d layers.

    The overall budget ``pruning_scale * (num_in + num_out)`` is split between
    the input and output sides using the ratio of their summed scores. Each
    side then prunes its globally lowest-scoring channels.

    Args:
        in_score_list, out_score_list: flat per-channel scores for each side.
        pruning_scale: overall fraction of channels to prune.
        in_idx_of_scores, out_idx_of_scores: start offset of each layer.
        in_c_of_para, out_c_of_para: channel count of each layer.

    Returns:
        ``(in_pciel, out_pciel, in_best_channels, out_best_channels)``.
        - ``*_pciel`` hold, per layer, a sorted ``np.ndarray`` of layer-local
          indices, or ``-1`` when the layer is untouched.
        - ``*_best_channels`` come from ``get_best_channel``. Both are sized by
          the number of input-side layers.
    """
    scores_in = np.array(in_score_list)
    scores_out = np.array(out_score_list)
    n_in = scores_in.shape[0]
    n_out = scores_out.shape[0]
    total_in = np.sum(scores_in)
    total_out = np.sum(scores_out)

    budget_in = int(n_in * pruning_scale * (n_in + n_out) / (total_in/total_out*n_out + n_in))
    budget_out = int(n_out * pruning_scale * (n_in + n_out) / (total_out/total_in*n_in + n_out))

    def owners(size, offsets, widths):
        # Map each flat channel index to the layer that owns it.
        table = np.zeros(size, dtype=np.int32)
        for layer_idx in range(len(offsets)):
            start = offsets[layer_idx]
            width = widths[layer_idx]
            table[start:start + width] = layer_idx * np.ones(width, dtype=np.int32)
        return table

    def per_layer(flat_idx, layer_of, offsets):
        # Group pruned flat indices by layer as sorted layer-local offsets.
        buckets = {}
        for ch, layer_idx in zip(flat_idx, layer_of):
            buckets.setdefault(layer_idx, []).append(ch - offsets[layer_idx])
        grouped = list(-1 * np.ones(len(offsets), dtype=np.int32))
        for layer_idx, local in buckets.items():
            grouped[layer_idx] = np.sort(np.array(local))
        return grouped

    in_owner = owners(n_in, in_idx_of_scores, in_c_of_para)
    out_owner = owners(n_out, out_idx_of_scores, out_c_of_para)

    in_prune = scores_in.argsort()[:budget_in]
    out_prune = scores_out.argsort()[:budget_out]
    in_layers = in_owner[in_prune]
    out_layers = out_owner[out_prune]

    n_layers = in_owner.max() + 1
    in_best_channels = get_best_channel(in_layers, in_prune, n_layers, in_idx_of_scores)
    out_best_channels = get_best_channel(out_layers, out_prune, n_layers, out_idx_of_scores)

    in_pciel = per_layer(in_prune, in_layers, in_idx_of_scores)
    out_pciel = per_layer(out_prune, out_layers, out_idx_of_scores)
    return in_pciel, out_pciel, in_best_channels, out_best_channels


def align_lin(in_score_list, out_score_list, in_c_of_para, out_c_of_para):
    """Tie the shared channels of the first two Linear layers and pad their gaps.

    The first ``out_c_of_para[0]`` entries of ``out_score_list`` and the first
    ``in_c_of_para[1]`` entries of ``in_score_list`` describe the same channels;
    their lengths are asserted equal. Both are replaced by their element-wise
    mean so both sides make the same pruning decision. Then:
    - ``in_c_of_para[0]`` ``+inf`` scores are prepended to the input list;
    - ``out_c_of_para[1]`` ``+inf`` scores are inserted right after the shared
      block of the output list.
    These padded slots are never selected for pruning.

    NOTE(review): this reads as if ``in_score_list`` omits layer 0's input
    scores and ``out_score_list`` omits layer 1's output scores. Confirm with
    the caller.

    Returns:
        ``(in_score_list, out_score_list)`` as new ``np.ndarray`` objects.

    Raises:
        AssertionError: if the two shared blocks differ in length.
    """
    shared = in_c_of_para[1]
    in_shared = np.array(in_score_list[:shared])
    out_shared = np.array(out_score_list[:out_c_of_para[0]])
    in_rest = np.array(in_score_list[shared:])
    out_rest = np.array(out_score_list[out_c_of_para[0]:])
    assert in_shared.shape[0] == out_shared.shape[0]

    # Both sides see the same channels, so both contribute to the shared score.
    mean = (in_shared + out_shared) / 2

    in_pad = np.full(in_c_of_para[0], math.inf)
    out_pad = np.full(out_c_of_para[1], math.inf)

    new_in = np.concatenate((in_pad, mean, in_rest), axis=0)
    new_out = np.concatenate((mean, out_pad, out_rest), axis=0)
    return new_in, new_out

    
def sort_lin_score(in_score_list, out_score_list, pruning_scale, in_idx_of_scores, in_c_of_para, out_idx_of_scores, out_c_of_para):
    """Choose input and output features of Linear layers to prune.

    The first two layers occupy score offsets ``[0, *_idx_of_scores[2])`` and
    are handled separately: exactly ``int(pruning_scale * width)`` lowest-score
    entries are taken from that region on each side, with
    ``out_c_of_para[0] == in_c_of_para[1]`` asserted. For all later layers,
    the budget is split between the input and output sides via
    ``in_scale`` / ``out_scale``, computed from the summed scores of those
    later layers.

    For output layers with index >= 2, pruning feature ``c`` also prunes
    ``c + out_c_of_para[layer]``.
    NOTE(review): this assumes those layers' real output width is twice
    ``out_c_of_para[layer]`` (e.g. a gated projection). Confirm against the model.

    Returns:
        ``(in_pciel, out_pciel, in_best_channels, out_best_channels)``.
        - ``*_pciel`` hold, per layer, a sorted ``np.ndarray`` of layer-local
          indices, or ``-1`` when the layer is untouched.
        - ``*_best_channels`` come from ``get_best_channel``, both sized by the
          number of input-side layers.
    """
    in_scores = np.array(in_score_list)
    out_scores = np.array(out_score_list)

    # Scores of every layer after the first two (handled separately below).
    inscores_other = in_scores[in_idx_of_scores[2]:]
    outscores_other = out_scores[out_idx_of_scores[2]:]
    loss_in = np.sum(inscores_other)
    num_in = in_scores.shape[0]
    loss_out = np.sum(outscores_other)
    num_out = out_scores.shape[0]

    # Split the budget between the input and output sides by their score ratio.
    # NOTE(review): num_in/num_out count all entries, while loss_in/loss_out sum
    # only the later layers. Confirm that this mix is intended.
    in_scale = float(pruning_scale * (num_in + num_out) / (loss_in/loss_out*num_out + num_in))
    out_scale = float(pruning_scale * (num_in + num_out) / (loss_out/loss_in*num_in + num_out))

    # Map each flat index to its owning layer.
    in_ch_to_layer = np.zeros(in_scores.shape[0], dtype=np.int32)
    out_ch_to_layer = np.zeros(out_scores.shape[0], dtype=np.int32)

    for j in range(len(in_idx_of_scores)):
        start = in_idx_of_scores[j]
        end = in_idx_of_scores[j] + in_c_of_para[j]
        length = in_c_of_para[j]
        in_ch_to_layer[start:end] = j * np.ones(length, dtype=np.int32)
    for j in range(len(out_idx_of_scores)):
        start = out_idx_of_scores[j]
        end = out_idx_of_scores[j] + out_c_of_para[j]
        length = out_c_of_para[j]
        out_ch_to_layer[start:end] = j * np.ones(length, dtype=np.int32)
    
    # Only used below for its max(), i.e. the number of input-side layers.
    in_sorted_idx_list = in_scores.argsort()
    total_layer_to_prune = in_ch_to_layer[in_sorted_idx_list]


    # First two layers: prune a fixed fraction of the shared width on each side.
    inscores_first_2 = in_scores[:in_idx_of_scores[2]]
    outscores_first_2 = out_scores[:out_idx_of_scores[2]]
    in_sorted_first_2 = inscores_first_2.argsort()
    out_sorted_first_2 = outscores_first_2.argsort()

    assert out_c_of_para[0] == in_c_of_para[1]
    first_2_scale = pruning_scale
    in_ch_prune_first_2 = in_sorted_first_2[:int(first_2_scale*in_c_of_para[1])]
    out_ch_prune_first_2 = out_sorted_first_2[:int(first_2_scale*out_c_of_para[0])]

    in_ch_prune_ = list(in_ch_prune_first_2)
    out_ch_prune_ = list(out_ch_prune_first_2)

    # Adding is necessary.
    # (argsort of the tail gives tail-relative indices; shift them back to flat indices.)
    in_sorted_idx_list = inscores_other.argsort() + in_idx_of_scores[2]
    out_sorted_idx_list = outscores_other.argsort() + out_idx_of_scores[2]


    in_ch_prune = in_sorted_idx_list[:int(in_scale*inscores_other.shape[0])]
    out_ch_prune = out_sorted_idx_list[:int(out_scale*outscores_other.shape[0])]

    # Flat indices to prune: first-two-layer picks followed by the later-layer picks.
    in_ch_prune = in_ch_prune_ + list(in_ch_prune)
    out_ch_prune = out_ch_prune_ + list(out_ch_prune)

    in_layer_prune = in_ch_to_layer[in_ch_prune]
    out_layer_prune = out_ch_to_layer[out_ch_prune]

    in_best_channels = get_best_channel(in_layer_prune, np.array(in_ch_prune), total_layer_to_prune.max()+1, in_idx_of_scores)
    out_best_channels = get_best_channel(out_layer_prune, np.array(out_ch_prune), total_layer_to_prune.max()+1, out_idx_of_scores)

    # -1 marks "nothing pruned in this layer"; replaced by a list on first hit.
    in_pciel = list(-1 * np.ones(len(in_idx_of_scores), dtype=np.int32))
    out_pciel = list(-1 * np.ones(len(out_idx_of_scores), dtype=np.int32))

    for i, layer in enumerate(in_layer_prune):
        if in_pciel[layer] == -1:
            in_pciel[layer] = [in_ch_prune[i]-in_idx_of_scores[layer]]
        else:
            in_pciel[layer].append(in_ch_prune[i]-in_idx_of_scores[layer])
    # For output layers >= 2, also prune the paired feature c + out_c_of_para[layer].
    for i, layer in enumerate(out_layer_prune):
        if out_pciel[layer] == -1:
            out_pciel[layer] = [out_ch_prune[i]-out_idx_of_scores[layer]]
            if layer >= 2:
                out_pciel[layer].append(out_ch_prune[i]-out_idx_of_scores[layer]+int(out_c_of_para[layer]))
        else:
            out_pciel[layer].append(out_ch_prune[i]-out_idx_of_scores[layer])
            if layer >= 2:
                out_pciel[layer].append(out_ch_prune[i]-out_idx_of_scores[layer]+int(out_c_of_para[layer]))
    
    for i in range(len(in_pciel)):
        if in_pciel[i] != -1:
            tmp = np.array(in_pciel[i])
            in_pciel[i] = np.sort(tmp)
    for i in range(len(out_pciel)):
        if out_pciel[i] != -1:
            tmp = np.array(out_pciel[i])
            out_pciel[i] = np.sort(tmp)
    
    return in_pciel, out_pciel, in_best_channels, out_best_channels


def get_rep_score(rep_path, idx_of_scores):
    """Score every output channel of each compactor (``*pwc*``) tensor.

    The checkpoint at ``rep_path`` is walked in order. Each 4-D tensor without
    ``pwc`` in its key is remembered as the current conv. Each following
    ``pwc`` tensor must belong to that conv: a ``skip_com.pwc`` key maps to
    ``skip_connection``; otherwise the index before ``pwc`` is decremented and
    ``pwc`` is dropped from the key. It must also be square with the conv's
    output width. The score of a channel is the L2 norm of its compactor row.
    The best row of each compactor is forced to ``+inf`` so every conv keeps
    at least one channel.

    Args:
        rep_path: checkpoint path, loaded with ``torch.load``.
        idx_of_scores: layer offsets; its last entry must equal the total count.

    Returns:
        Flat list of per-channel scores.

    Raises:
        AssertionError: on any structural mismatch described above.
    """
    rep_ckpt = torch.load(rep_path, map_location="cpu")

    prev_conv_key = ''
    prev_out_ch = 0
    scores = []
    for key, tensor in rep_ckpt.items():
        if 'pwc' not in key:
            if len(tensor.shape) == 4:
                prev_out_ch = int(tensor.shape[0])
                prev_conv_key = key
            continue

        if 'skip_com' in key:
            conv_key = key.replace('skip_com.pwc', 'skip_connection')
        else:
            parts = key.split('.')
            parts[-3] = str(int(parts[-3]) - 1)
            del parts[-2]
            conv_key = '.'.join(parts)

        in_ch = int(tensor.shape[1])
        out_ch = int(tensor.shape[0])
        assert key.endswith('weight')
        assert conv_key == prev_conv_key
        assert in_ch == prev_out_ch
        assert out_ch == prev_out_ch

        weights = tensor.detach().cpu().numpy()
        row_norms = np.sqrt((weights ** 2).sum(axis=(1, 2, 3)))

        # Each Conv2d reserves at least 1 channel.
        row_norms[np.argmax(row_norms)] = math.inf

        scores.extend(row_norms)

    assert len(scores) == idx_of_scores[-1]
    del rep_ckpt

    return scores


def sort_rep_score(model, rep_model, model_path, rep_path, score_list, idx_of_scores, out_c_of_para, thresh, get_module_type):
    """Prune compactor channels by threshold and fold each compactor into its conv.

    1. Scores below ``thresh`` are selected for pruning. Six ``+inf`` sentinels
       are appended first and are never selected. The selections are grouped
       per layer into ``cpiel``: a sorted Python list of layer-local indices,
       or ``-1`` for untouched layers.
    2. The checkpoint at ``rep_path`` is walked in order. Every Conv2d tensor
       (per ``get_module_type(rep_model, key)``) is remembered. Each following
       ``*pwc*`` compactor has its pruned output rows deleted and is merged with
       the preceding conv: the kernel via ``F.conv2d``, the bias via
       ``compactor @ bias``.
       NOTE(review): the bias fold reads ``pwc[:, :, 0, 0]``, so this assumes
       1x1 compactors without a bias of their own. Confirm with the model.

    Args:
        model, model_path: unused; kept for API compatibility.
        rep_model: model passed to ``get_module_type`` for the rep checkpoint.
        rep_path: rep checkpoint path, loaded with ``torch.load``.
        score_list: flat per-channel scores.
        idx_of_scores: start offset of each layer's scores.
        out_c_of_para: output-channel count of each layer.
        thresh: channels with score strictly below this are pruned.
        get_module_type: callable ``(model, key) -> (type_name, _)``.

    Returns:
        ``(cpiel, best_channels, pruned_kernel, pruned_bias)``. The last two
        are lists of tensors, one per compactor.
    """
    rep_ckpt = torch.load(rep_path, map_location="cpu")

    '''--------------------------------------- Sort the rep score list --------------------------------------'''
    scores = np.array(list(score_list) + [math.inf] * 6)

    # Map each flat channel index to the layer that owns it.
    channel_to_layer = np.zeros(scores.shape[0], dtype=np.int32)
    for j in range(len(idx_of_scores)):
        start = idx_of_scores[j]
        length = out_c_of_para[j]
        channel_to_layer[start:start + length] = j * np.ones(length, dtype=np.int32)

    layer_to_prune = channel_to_layer[scores.argsort()]

    channel_prune = np.where(scores < thresh)[0]
    layer_prune = channel_to_layer[list(channel_prune)]
    best_channels = get_best_channel(layer_prune, channel_prune, layer_to_prune.max() + 1, idx_of_scores)

    cpiel = list(-1 * np.ones(len(idx_of_scores), dtype=np.int32))
    for i, layer in enumerate(layer_prune):
        if cpiel[layer] == -1:
            cpiel[layer] = [channel_prune[i] - idx_of_scores[layer]]
        else:
            cpiel[layer].append(channel_prune[i] - idx_of_scores[layer])

    for i in range(len(cpiel)):
        if cpiel[i] != -1:
            cpiel[i] = list(np.sort(np.array(cpiel[i])))

    '''--------------------------------------- Convert the rep Conv2d ---------------------------------------'''
    rep_conv_weight = []
    rep_conv_bias = []
    idx = -1
    pruned_kernel = []
    pruned_bias = []
    for k, p in rep_ckpt.items():
        if 'pwc' in k:
            weight = rep_conv_weight[idx]
            bias = rep_conv_bias[idx]

            pwc = p.detach().cpu().numpy()
            if cpiel[idx] != -1:
                pwc = np.delete(pwc, cpiel[idx], axis=0)
            # Treat the conv's input channels as the batch so that convolving
            # with the compactor mixes the conv's output channels.
            pruned_k = F.conv2d(
                torch.from_numpy(weight).permute(1, 0, 2, 3),
                torch.from_numpy(pwc),
                padding=(0, 0)
            ).permute(1, 0, 2, 3)
            assert weight.size >= pruned_k.numel()

            pruned_b = np.array([bias.dot(row) for row in pwc[:, :, 0, 0]], dtype=np.float64)
            # The folded bias can only shrink (rows are removed, never added).
            assert pruned_b.shape[0] <= bias.shape[0]

            # Contiguous copy without the torch.tensor(tensor) copy-construct warning.
            pruned_kernel.append(pruned_k.detach().clone(memory_format=torch.contiguous_format))
            pruned_bias.append(torch.tensor(pruned_b))

        else:
            type_p, _ = get_module_type(rep_model, k)
            if type_p == 'Conv2d':
                pp = p.detach().cpu().numpy()
                if k.endswith('weight'):
                    idx += 1
                    rep_conv_weight.append(pp)
                else:
                    assert k.endswith('bias')
                    rep_conv_bias.append(pp)

    del rep_ckpt

    return cpiel, best_channels, pruned_kernel, pruned_bias


def pruning_rep(
    model, prune_type, get_module_type, ckpt_path, pciel, layer_after_att, best_channels,
    pruned_kernel, pruned_bias, rep_path
):
    """Build a channel-pruned state dict for Conv2d layers from reparameterized weights.

    Both checkpoints are loaded onto the CPU. Every entry of the checkpoint at
    ``ckpt_path`` is then overwritten by the same-keyed entry of the checkpoint
    at ``rep_path``, so ``rep_path`` must contain every key of ``ckpt_path``;
    otherwise a ``KeyError`` is raised. Entries are then walked in checkpoint
    order:

    * ``Conv1d`` entries, and entries whose module type is not ``prune_type``,
      are copied unchanged.
    * For each ``Conv2d`` weight, a running ``layer_id`` counter is advanced.
      The weight is replaced by ``pruned_kernel[layer_id]`` (except for keys
      whose first component is ``out``). Its input channels are then sliced to
      match what earlier layers kept. Only skip connections, ``in_layers.2``,
      ``out_layers.3`` and ``out.2`` are sliced on the input side.
    * For each ``Conv2d`` bias, the tensor is replaced by
      ``pruned_bias[layer_id]`` (except for keys starting with ``out.``).

    Args:
        model: Passed to ``get_module_type`` to look up each key's module type.
        prune_type: Must be ``'Conv2d'`` (asserted).
        get_module_type: Callable ``(model, key) -> (type_name, _)``. Only the
            first element of the returned pair is used.
        ckpt_path: Checkpoint whose keys, and key order, drive the walk.
        pciel: Per-Conv2d-layer pruned output-channel indices, indexed by
            ``layer_id``. Each entry is a ``list``, or ``-1`` when nothing is
            pruned in that layer.
            NOTE(review): when a layer would lose every channel, the list
            stored here is mutated in place (``remove(best_channel)``), so the
            caller's ``pciel`` is modified by this call.
        layer_after_att: Module names (weight key without ``.weight``) whose
            incoming channel set is reset to the full ``range(v.shape[0])``.
            The in-code note says these follow an attention block.
        best_channels: Per-layer channel index that is kept when ``pciel``
            would prune every output channel of that layer.
        pruned_kernel: Per-layer replacement weight tensors, indexed by ``layer_id``.
        pruned_bias: Per-layer replacement bias tensors, indexed by ``layer_id``.
        rep_path: Reparameterized checkpoint whose tensors replace ``ckpt``'s.

    Returns:
        ``(new_ckpt, reverse_info)``.

        ``new_ckpt`` maps checkpoint keys to CPU tensors.

        ``reverse_info`` maps each Conv2d module name to a two-element list,
        ``[kept input-channel indices, kept output-channel indices]``. Each
        element may be a ``list`` or a ``range``.
    """
    assert prune_type == 'Conv2d', "This function is designed for dealing with Conv2d layer."
    ckpt = torch.load(ckpt_path, map_location="cpu")
    rep_ckpt = torch.load(rep_path, map_location="cpu")
    # Keep ckpt's key order, but take every tensor from the rep checkpoint.
    for k,p in ckpt.items():
        ckpt[k] = rep_ckpt[k]

    new_ckpt = {}
    # Channel indices kept by the most recently processed layer. They become
    # the input-channel selection of the next conv.
    reversed_channels_last_layer = []
    # Input channel set of the current resblock, captured at its in_layers.2
    # conv. Used for the skip_connection conv / identity residual.
    reversed_channels_before_skip = []
    # name -> [kept input-channel indices, kept output-channel indices]
    reverse_info = {}
    # Stack of channel sets pushed by input_blocks convs. Popped by
    # output_blocks.*.0.in_layers.2 to account for the long skip connection.
    in_blocks_chs = []
    # Stack of original (unpruned) output widths of the middle_block.2 and
    # output_blocks out_layers convs. Used to offset the long-skip indices.
    out_blocks_origin_ch = []
    long_skip_ch = None
    long_skip_origin_ch = -1

    # Index into pciel / best_channels / pruned_kernel / pruned_bias.
    # Advanced once per Conv2d weight.
    layer_id = -1
    for k,v in track(ckpt.items()):
        type_k,_ = get_module_type(model,k)

        if type_k == 'Conv1d':
            ks = k.split('.')
            new_ckpt[k] = v.cpu().clone()

            # An input_blocks proj_out Conv1d (presumably the attention output
            # projection) is not pruned. Replace the top of in_blocks_chs with
            # all v.shape[0] of its output channels.
            if ks[0] == 'input_blocks' and ks[-1] == 'bias' and ks[-2]=='proj_out':
                in_blocks_chs.pop()
                in_blocks_chs.append(list(range(v.shape[0])))

            continue


        # Only prune_type (Conv2d) layers are pruned; everything else is copied.
        if type_k != prune_type:
            new_ckpt[k] = v.cpu().clone()
            continue

        if k.endswith('weight'):
            """
            The reverse_info[k] will contain two:
            reversed channels in last layer,
            reversed channels in this layer.
            """
            name = k[:-7]  # strip the trailing '.weight'
            reverse_info[name] = []

            layer_id += 1

            '''if isinstance(pciel[layer_id],np.ndarray):
                pruned_channels = list(pciel[layer_id])
            else:
                pruned_channels = []'''
            # pciel[layer_id] is either a list of pruned output-channel
            # indices, or -1 when nothing is pruned in this layer.
            pruned_channels = pciel[layer_id]
            if pciel[layer_id] == -1:
                pruned_channels = []
            assert isinstance(pruned_channels, list)

            # Never prune a layer down to zero outputs: keep its best channel.
            # NOTE(review): this removes from the list object stored in pciel,
            # so the caller's pciel is mutated in place.
            if len(pruned_channels) == v.shape[0]:
                best_channel = best_channels[layer_id]
                pruned_channels.remove(best_channel)

            # Output channels that survive pruning, in ascending order.
            # The membership test is against a list, so it is
            # O(out_channels * len(pruned_channels)).
            def need_reserve(n):
                return n not in pruned_channels
            reserved_channels = list(filter(need_reserve,range(v.shape[0])))

            """
            Get the channels reserved in the previous layer.
            Also, if the previous block is Attention block,
            input channel is no need to prune.
            Also, long skip connection must be considered.
            """
            ks = k.split('.')
            if ks[0] != 'out':
                ################################################################################
                # new_ckpt[k] = new_ckpt[k][reserved_channels].cpu().clone()
                # Use weights from rep train.
                # NOTE(review): assumes pruned_kernel has one entry per Conv2d
                # weight, in this same order, already at the pruned output
                # width. Confirm it agrees with reserved_channels, in particular
                # with the best-channel re-inclusion above.
                new_ckpt[k] = pruned_kernel[layer_id].cpu().clone()
                ################################################################################
            else:
                # Keys under 'out' keep their rep weight and all output channels.
                new_ckpt[k] = v.cpu().clone()


            #################
            # Record original output widths (see out_blocks_origin_ch above).
            # For output_blocks.*.2, the most recently pushed width is replaced.
            if ks[0] == 'middle_block' and ks[1] == '2' and ks[-3] == 'out_layers':
                out_blocks_origin_ch.append(v.shape[0])
            if ks[0] == 'output_blocks' and ks[-3] == 'out_layers':
                if ks[2] == '2':
                    out_blocks_origin_ch.pop()
                out_blocks_origin_ch.append(v.shape[0])


            # Layers in layer_after_att take their full input. Reset the
            # incoming channel set to range(v.shape[0]); for skip connections
            # outside output_blocks, reset the pre-skip set instead.
            # NOTE(review): v.shape[0] is this conv's *output* width. It is the
            # right input set only when the in and out widths match. Confirm.
            if name in layer_after_att:
                if ks[-2] == 'skip_connection':
                    if ks[0] != 'output_blocks':
                        reversed_channels_before_skip = range(v.shape[0])
                else:
                    reversed_channels_last_layer = range(v.shape[0])
            if ks[-2] == 'skip_connection':
                # The skip_connection conv's input is the resblock's input set.
                reverse_info[name].append(reversed_channels_before_skip)
                # A range never compares equal to [], so slicing is skipped only
                # when the value is literally the empty list.
                if reversed_channels_before_skip != []:
                    # new_ckpt[k] = v[:,reversed_channels_before_skip,...].cpu().clone()
                    new_ckpt[k] = new_ckpt[k][:,reversed_channels_before_skip,...].cpu().clone()

            else:
                # output_blocks.*.0.in_layers.2: its input also carries the long-skip
                # features. Pop the most recent input-block channel set and shift
                # it by the original width of the preceding features.
                # (Assumes the preceding features come first in the concatenation.
                # Confirm against the model's forward.)
                if ks[0]=='output_blocks' and ks[2]=='0' and ks[-2]=='2' and ks[-3]=='in_layers':
                    long_skip_ch = in_blocks_chs.pop()
                    long_skip_origin_ch = out_blocks_origin_ch.pop()
                    reversed_channels_last_layer = list(reversed_channels_last_layer) + list(np.array(long_skip_ch) + long_skip_origin_ch)


                # Record this conv's kept input channels.
                # out_layers convs record all v.shape[0] indices, even though
                # their weight is sliced below by in_layers.2's kept outputs.
                # NOTE(review): confirm consumers of reverse_info do not rely on
                # entry [0] of out_layers convs.
                # Convs not matched by the slicing branches below (e.g. 'op' /
                # 'conv' layers) keep their full input width in new_ckpt, even
                # though reverse_info records reversed_channels_last_layer.
                if ks[-3] != 'out_layers':
                    reverse_info[name].append(reversed_channels_last_layer)
                else:
                    reverse_info[name].append(list(range(v.shape[0])))
                # new_ckpt[k] = v.cpu().clone()
                new_ckpt[k] = new_ckpt[k].cpu().clone()

                # @@@@@@@@@@@@@
                # Final out.2 conv: its input is the union of the kept outputs of
                # output_blocks.17.0's out_layers.3 and skip_connection convs
                # (the returned count is unused).
                # NOTE(review): block 17 is hard-coded. If either name was never
                # recorded, reverse_info raises KeyError here.
                # NOTE(review): get_new_out_channel builds its list via
                # list(set(...)). Set order is not guaranteed to be ascending, so
                # slicing with it (here and with the other unions below) may
                # permute channels. Confirm, or sort.
                if k == 'out.2.weight':
                    name1 = 'output_blocks.17.0.out_layers.3'
                    name2 = 'output_blocks.17.0.skip_connection'
                    the_last_ch, out_inch = get_new_out_channel(reverse_info[name1][1],reverse_info[name2][1])
                    new_ckpt[k] = new_ckpt[k][:,out_inch].cpu().clone()
                # in_layers.2: slice inputs to the set recorded just above.
                if ks[-3]=='in_layers' and ks[-2]=='2':
                    chhh = reverse_info[name][0]
                    new_ckpt[k] = new_ckpt[k][:,chhh].cpu().clone()
                # out_layers.3: the input is the kept outputs of the sibling
                # in_layers.2 conv with the same key prefix.
                if ks[-3]=='out_layers' and ks[-2]=='3':
                    if ks[0] != 'middle_block':
                        name_ = ks[0]+'.'+ks[1]+'.'+ks[2]+'.in_layers.2'
                    else:
                        name_ = ks[0]+'.'+ks[1]+'.in_layers.2'
                    chhh_o = reverse_info[name_][1]
                    new_ckpt[k] = new_ckpt[k][:,chhh_o].cpu().clone()



            """
            Prune this layer's out channels.
            Long skip connection must be considered.
            """
            # Record this conv's kept output channels (all of them for 'out').
            if ks[0] != 'out':
                # new_ckpt[k] = new_ckpt[k][reserved_channels].cpu().clone()
                reverse_info[name].append(reserved_channels)
            else:
                reverse_info[name].append(range(v.shape[0]))

            """
            Prepare reversed channels for skip conenction
            at the beginning of a resblock.
            """
            # After in_layers.2, remember its input set for the resblock's
            # skip path (skip_connection conv or identity residual).
            if ks[-2]=='2' and ks[-3] == 'in_layers':
                reversed_channels_before_skip = reversed_channels_last_layer

            """
            Input channels of the next resblock is jointly determined
            by out channels from both conv layers and skip connection.
            """
            if ks[-2] != 'skip_connection':
                # input_blocks out_layers.3 convs and input_blocks.0 convs push
                # their output channel set onto in_blocks_chs. If the block has a
                # skip_connection conv, this entry is replaced once that conv is
                # seen. Otherwise it is the union of the residual input and this
                # conv's kept channels.
                if (ks[0] == 'input_blocks' and ks[-2]=='3' and ks[-3]=='out_layers') or (ks[0]=='input_blocks' and ks[1]=='0'):
                    if (ks[0]+'.'+ks[1]+'.'+ks[2]+'.'+'skip_connection.weight') in ckpt.keys():
                        in_blocks_chs.append(reserved_channels)
                    else:
                        _, reversed_channels_last_layer = get_new_out_channel(reversed_channels_before_skip,reserved_channels)
                        in_blocks_chs.append(reversed_channels_last_layer)

                # For out_layers convs, the next layer sees the resblock output:
                # this conv's kept channels, unioned with the identity residual
                # when there is no skip_connection conv. For any other conv, the
                # next layer sees just this conv's kept channels.
                # NOTE(review): for middle_block keys the constructed name includes
                # ks[2] (e.g. 'out_layers'), so it never matches and the union
                # branch is always taken there.
                if (ks[0] == 'input_blocks' and ks[-3]=='out_layers') or (ks[0] == 'middle_block' and ks[-3]=='out_layers') or (ks[0] == 'output_blocks' and ks[-3]=='out_layers'):
                    if (ks[0]+'.'+ks[1]+'.'+ks[2]+'.'+'skip_connection.weight') in ckpt.keys():
                        reversed_channels_last_layer = reserved_channels
                    else:
                        _, reversed_channels_last_layer = get_new_out_channel(reversed_channels_before_skip,reserved_channels)
                else:
                    reversed_channels_last_layer = reserved_channels


            else:
                # skip_connection conv: the resblock output is the union of the
                # current reversed_channels_last_layer and this conv's kept
                # channels. Under the usual key order, the former is the
                # preceding out_layers.3 conv's kept outputs (assumption).
                # Also refresh the input_blocks stack entry.
                _, reversed_channels_last_layer = get_new_out_channel(reversed_channels_last_layer,reserved_channels)
                if ks[0] == 'input_blocks':
                    in_blocks_chs.pop()
                    in_blocks_chs.append(reversed_channels_last_layer)
        elif k.endswith('bias'):
            # Uses the current layer_id. This relies on each bias key coming
            # after its weight key in checkpoint order.
            if not k.startswith('out.'):
                ################################################################################
                # new_ckpt[k] = v[reserved_channels].cpu().clone()
                # Use weights from rep train.
                new_ckpt[k] = pruned_bias[layer_id].cpu().clone()
                ################################################################################
            else:
                new_ckpt[k] = v.cpu().clone()

    return new_ckpt, reverse_info