import os
import re
import random
import numpy as np
import torch
import torch.nn as nn
import datetime
import math
from transformers import (
    T5Tokenizer,
    T5Config,
    DataCollatorWithPadding,
    DataCollatorForSeq2Seq,
)
from transformers.models.t5.modeling_t5 import (
    T5PreTrainedModel,
    T5Block,
    T5Model,
    T5EncoderModel,
    T5ForConditionalGeneration,
    T5Attention,
    T5LayerSelfAttention,
    T5LayerCrossAttention,
    T5LayerFF
)
from models.modeling_t5 import (
    T5ForSequenceClassification,
    ReduT5ForSequenceClassification,
    ReduT5ForConditionalGeneration,
    ReduT5Attention,
    ReduT5LayerSelfAttention,
    ReduT5LayerCrossAttention,
    ReduT5LayerFF,
    ReduT5Block
)
from datasets import Dataset, DatasetDict, load_from_disk, load_metric
from trainer import T5GlueTrainer, T5SquadTrainer
from typing import Optional, Dict, List, Callable, Any, Union, Tuple
from collections import defaultdict
from loguru import logger
import pickle


class T5Utils:
    """Helpers for compressing a T5 model into a reduced-width ``ReduT5`` model.

    ``collect`` hooks a full-size T5 model, records activations while
    ``trainer.collect()`` runs, and derives projection matrices from them.
    ``load_model_params`` folds those projections into the weights of a
    ``ReduT5ForConditionalGeneration``.
    """

    def __init__(self) -> None:
        # Number of spill files written so far per collector key; the loader in
        # ``collect`` reads files ``0 .. file_count[name] - 1`` back.
        self.file_count = defaultdict(lambda: 0)
        # Buffered entries per key before they are pickled to disk (use_disk).
        self.LIMIT: int = 250
        # File count assumed per key when reusing an existing ``tmp/`` (use_tmp).
        self.MAX_FILE_COUNT: int = 8

    def collect(self,
                data_args,
                model_args,
                model: T5ForConditionalGeneration,
                trainer: Union[T5GlueTrainer, T5SquadTrainer]
                ):
        """Record activations of ``model`` and derive compression projections.

        Forward hooks are attached to the inputs of every RMSNorm, the outputs
        of the q/k projections and the input of the o projection of every
        attention layer, and the input of every FFN ``wo``. ``trainer.collect()``
        presumably drives forward passes over the collection data. The hooks are
        removed again before this method computes anything.

        Args:
            data_args: reads ``per_sample_token_num`` (rows kept per hook call),
                ``token_sample_num`` (rows kept per key for the SVDs),
                ``use_disk`` (spill buffers to ``tmp/*.pkl``) and ``use_tmp``
                (skip collection and reuse existing ``tmp/*.pkl`` files).
            model_args: reads ``r_model``, ``r_kv`` and ``comp_mode``
                (0: hidden-size projections only, 1: also attention/FFN data).
            model: the full-size T5 model to probe.
            trainer: object whose ``collect()`` runs the forward passes.

        Returns:
            dict with keys ``encoder_proj``, ``decoder_proj``, ``attn_proj``,
            ``wo_proj`` and ``act_value``; the last three are ``None`` when
            ``comp_mode == 0``.

        Raises:
            ValueError: if ``model_args.comp_mode`` is neither 0 nor 1.
        """
        config: T5Config = model.config
        handlers = []
        norm_dict = defaultdict(list)
        attn_dict = defaultdict(list)
        act_dict = defaultdict(list)
        wo_dict = defaultdict(list)

        if data_args.use_disk:
            # Spill files live under "tmp/"; open(..., "wb") fails if it is missing.
            os.makedirs("tmp", exist_ok=True)

        def tmp_file_name(name: str, count: int) -> str:
            return "tmp/{}_{}.pkl".format(name, count)

        def spill(name: str, collector: Dict) -> None:
            """Pickle ``collector[name]`` to its next spill file and clear the buffer."""
            count = self.file_count[name]
            self.file_count[name] += 1
            with open(tmp_file_name(name, count), "wb") as f:
                pickle.dump(collector[name], f)
            collector[name] = []

        def hooked_tensor(inputs, outputs, collect_input: bool) -> torch.Tensor:
            """Return the hooked module's first input, or its output, on the CPU."""
            if collect_input:
                # ``inputs`` is the positional-argument tuple; squeeze(0) drops the
                # batch axis (assumes batch size 1 during collection — TODO confirm).
                return inputs[0].detach().cpu().squeeze(0)
            # Output hooks are only placed on nn.Linear, which returns a bare tensor,
            # so outputs[0] already selects the first batch row ([L, H]). Squeezing
            # again would collapse the sequence axis on single-token decoding steps.
            return outputs[0].detach().cpu()

        def sample_rows(data: torch.Tensor) -> torch.Tensor:
            """Keep at most ``per_sample_token_num`` randomly chosen rows of ``data``."""
            size = min(data_args.per_sample_token_num, data.shape[0])
            indices: List[int] = np.random.choice(data.shape[0], size, replace=False).tolist()
            return data[indices, ...]

        def init_fn(name, collector, collect_input=True):
            """Build a hook that stores randomly sampled token rows under ``name``."""
            # Pre-create the key so it exists even if the hook never fires (e.g. with
            # ``use_tmp``); the dump_* helpers iterate over the collectors' keys.
            if name not in collector:
                collector[name] = []

            def fn(module, inputs, outputs):
                collector[name].append(sample_rows(hooked_tensor(inputs, outputs, collect_input)))
                if len(collector[name]) == self.LIMIT and data_args.use_disk:
                    spill(name, collector)

            return fn

        def init_wo_fn(name, collector):
            """Build a hook recording the input of an attention output projection ``o``.

            It behaves exactly like the other row-sampling hooks: the key is
            registered up front and the buffer is spilled once ``LIMIT`` is reached.
            """
            return init_fn(name, collector, collect_input=True)

        def init_act_fn(name, collector, collect_input=True):
            """Build a hook storing per-channel mean |x| and mean x^2 for each call."""
            if name not in collector:
                collector[name] = []

            def fn(module, inputs, outputs):
                data = hooked_tensor(inputs, outputs, collect_input)
                # The token count is kept so dump_act can merge calls as a weighted mean.
                collector[name].append((data.shape[0], data.abs().mean(dim=0), data.pow(2).mean(dim=0)))
                if len(collector[name]) == self.LIMIT and data_args.use_disk:
                    spill(name, collector)

            return fn

        def dump_value(collector: Dict):
            """Flush every non-empty buffer of ``collector`` to disk (use_disk only)."""
            if data_args.use_disk:
                for name in collector.keys():
                    if len(collector[name]) > 0:
                        spill(name, collector)

        def load_value(name: str) -> List:
            """Return every entry collected under ``name``.

            Raises:
                ValueError: in in-memory mode, if ``name`` was never registered.
            """
            if data_args.use_disk:
                value = []
                for count in range(self.file_count[name]):
                    file_name = tmp_file_name(name, count)
                    if os.path.exists(file_name):
                        with open(file_name, "rb") as f:
                            # NOTE(security): pickle can execute arbitrary code on load;
                            # only point this at a tmp/ directory written by a trusted run.
                            value.extend(pickle.load(f))
                return value
            for collector in (norm_dict, attn_dict, wo_dict, act_dict):
                if name in collector:
                    return collector[name]
            raise ValueError("no collected values for {!r}".format(name))

        def hook_norm(key: str, norm_module) -> None:
            handlers.append(norm_module.register_forward_hook(init_fn(key, norm_dict)))

        def hook_attention(prefix: str, attn) -> None:
            """Hook q/k outputs and the o input of the attention module at ``prefix``."""
            handlers.append(attn.q.register_forward_hook(init_fn(prefix + ".q", attn_dict, collect_input=False)))
            handlers.append(attn.k.register_forward_hook(init_fn(prefix + ".k", attn_dict, collect_input=False)))
            handlers.append(attn.o.register_forward_hook(init_wo_fn(prefix + ".o", wo_dict)))

        def hook_ffn(key: str, ffn) -> None:
            handlers.append(ffn.DenseReluDense.wo.register_forward_hook(init_act_fn(key, act_dict)))

        encoder_block_pat = re.compile(r"encoder\.block\.(\d+)")
        decoder_block_pat = re.compile(r"decoder\.block\.(\d+)")

        def register_hooks() -> None:
            """Attach all collection hooks; every handle is appended to ``handlers``."""
            for name, module in model.named_modules():
                if encoder_block_pat.fullmatch(name):
                    assert isinstance(module, T5Block)
                    self_attn: T5LayerSelfAttention = module.layer[0]
                    ffn: T5LayerFF = module.layer[1]

                    # RMSNorm
                    hook_norm(name + ".layer.0.layer_norm", self_attn.layer_norm)
                    hook_norm(name + ".layer.1.layer_norm", ffn.layer_norm)
                    # SelfMHA
                    hook_attention(name + ".layer.0.SelfAttention", self_attn.SelfAttention)
                    # FFN
                    hook_ffn(name + ".layer.1.DenseReluDense.wo", ffn)

                elif decoder_block_pat.fullmatch(name):
                    assert isinstance(module, T5Block)
                    self_attn: T5LayerSelfAttention = module.layer[0]
                    cross_attn: T5LayerCrossAttention = module.layer[1]
                    ffn: T5LayerFF = module.layer[2]

                    # RMSNorm
                    hook_norm(name + ".layer.0.layer_norm", self_attn.layer_norm)
                    hook_norm(name + ".layer.1.layer_norm", cross_attn.layer_norm)
                    hook_norm(name + ".layer.2.layer_norm", ffn.layer_norm)
                    # SelfMHA
                    hook_attention(name + ".layer.0.SelfAttention", self_attn.SelfAttention)
                    # CrossMHA
                    hook_attention(name + ".layer.1.EncDecAttention", cross_attn.EncDecAttention)
                    # FFN
                    hook_ffn(name + ".layer.2.DenseReluDense.wo", ffn)

            # Final Layer Norm
            hook_norm("encoder.final_layer_norm", model.encoder.final_layer_norm)
            hook_norm("decoder.final_layer_norm", model.decoder.final_layer_norm)

        try:
            register_hooks()
            if not data_args.use_tmp:
                trainer.collect()
            else:
                # Skip collection and read spill files of a previous run instead;
                # up to MAX_FILE_COUNT files are looked up per key.
                self.file_count = defaultdict(lambda: self.MAX_FILE_COUNT)
        finally:
            # The hooks reference the buffers above; remove them so later forward
            # passes of ``model`` do not keep collecting (or writing spill files).
            for handler in handlers:
                handler.remove()

        for collector in (norm_dict, act_dict, attn_dict, wo_dict):
            dump_value(collector)

        def left_singular_vectors(matrix: torch.Tensor) -> torch.Tensor:
            """Return U of ``matrix = U S Vh`` (batched over leading dimensions).

            With many columns N, ``full_matrices=True`` would also build an
            [N, N] right factor. Requesting the full decomposition only when
            N < rows keeps U's shape equal to the full decomposition's
            ([rows, rows]) while Vh stays small.
            """
            rows, cols = matrix.shape[-2], matrix.shape[-1]
            u, _, _ = torch.linalg.svd(matrix, full_matrices=cols < rows)
            return u

        def norm_projection(prefix: str, token_sample_num: int) -> torch.Tensor:
            """Top-``r_model`` left singular vectors of all RMSNorm inputs under ``prefix``."""
            hiddens = []
            for key in norm_dict.keys():
                assert isinstance(key, str)
                if not key.startswith(prefix):
                    continue  # only load the buffers/spill files this side needs
                logger.info("key={}", key)
                hidden = torch.concat(load_value(key), dim=0).transpose(0, 1)  # [D, N]
                size = min(hidden.shape[1], token_sample_num)
                indices = np.random.choice(hidden.shape[1], size=size, replace=False)
                hiddens.append(hidden[:, indices])
            hidden = torch.cat(hiddens, dim=-1)
            return left_singular_vectors(hidden)[:, :model_args.r_model]

        def dump_norm():
            token_sample_num = int(math.ceil(data_args.token_sample_num / config.num_layers))
            encoder_proj = norm_projection('encoder', token_sample_num)
            decoder_proj = norm_projection('decoder', token_sample_num)
            return encoder_proj, decoder_proj

        def shape(states):
            """Reshape [N, num_heads * d_kv] rows into [num_heads, d_kv, N]."""
            return states.view(-1, config.num_heads, config.d_kv).permute(1, 2, 0)

        def dump_attn():
            """Per-head rank-``r_kv`` q/k projections from the collected q/k outputs."""
            attn_value = {}
            r_kv = model_args.r_kv
            # Drop the trailing ".q"/".k" to get one entry per attention module.
            keys = sorted({x.rsplit('.', 1)[0] for x in attn_dict.keys()})
            for key in keys:
                logger.info("key={}", key)
                hidden_query = torch.concat(load_value(key + '.q'), dim=0)  # [N_q, H]
                hidden_key = torch.concat(load_value(key + '.k'), dim=0)  # [N_k, H]

                size = min(hidden_query.shape[0], data_args.token_sample_num)
                q_indices = np.random.choice(hidden_query.shape[0], size=size, replace=False)
                if hidden_key.shape[0] == hidden_query.shape[0]:
                    # Self-attention: query/key rows come from the same tokens.
                    k_indices = q_indices
                else:
                    # Cross-attention: queries come from the decoder and keys from the
                    # encoder, so the row counts differ; sample keys on their own.
                    k_size = min(hidden_key.shape[0], data_args.token_sample_num)
                    k_indices = np.random.choice(hidden_key.shape[0], size=k_size, replace=False)
                hidden_query = shape(hidden_query[q_indices, :])
                hidden_key = shape(hidden_key[k_indices, :])

                # Reduced SVD: U has exactly len(S) columns, as the diag(1 / S) products require.
                Uq, Sq, _ = torch.linalg.svd(hidden_query, full_matrices=False)
                Uk, Sk, _ = torch.linalg.svd(hidden_key, full_matrices=False)

                M = torch.diag_embed(Sq) @ Uq.transpose(1, 2) @ Uk @ torch.diag_embed(Sk)
                Um, Sm, VmT = torch.linalg.svd(M)
                UT = Uq @ torch.diag_embed(1.0 / Sq) @ Um[..., :r_kv] @ torch.diag_embed(torch.sqrt(Sm[:, :r_kv]))
                V = torch.diag_embed(torch.sqrt(Sm[:, :r_kv])) @ VmT[:, :r_kv, :] @ torch.diag_embed(
                    1.0 / Sk) @ Uk.transpose(1, 2)
                attn_value[key + '.q_proj'] = UT.transpose(1, 2)
                attn_value[key + '.k_proj'] = V
            return attn_value

        def dump_wo():
            """Per-head top-``r_kv`` left singular vectors of each o-projection input."""
            wo_value = {}
            for key in wo_dict.keys():
                logger.info("key={}", key)
                hidden = torch.concat(load_value(key), dim=0)  # [N, H]
                size = min(hidden.shape[0], data_args.token_sample_num)
                indices = np.random.choice(hidden.shape[0], size=size, replace=False)
                hidden = shape(hidden[indices, :])  # [num_heads, d_kv, N]
                wo_value[key] = left_singular_vectors(hidden)[..., :model_args.r_kv]
            return wo_value

        def dump_act():
            """Per-channel importance of each FFN ``wo`` input: (mean|x| + std) * ||W_row||."""
            act_value = {}
            model_state_dict = model.state_dict()
            for key in act_dict.keys():
                logger.info("key={}", key)
                value: List[Tuple] = load_value(key)
                ws, hs, h2s = zip(*value)
                ws = torch.tensor(ws).unsqueeze(-1)
                hs = torch.stack(hs)
                h2s = torch.stack(h2s)

                # Token-count-weighted means over all recorded calls.
                h_mean = (ws * hs).sum(dim=0) / ws.sum()
                h2_mean = (ws * h2s).sum(dim=0) / ws.sum()
                h_std = (h2_mean - h_mean * h_mean + 1e-5).sqrt()
                weight: torch.Tensor = model_state_dict[key + ".weight"].transpose(0, 1).detach().cpu()
                norm = weight.norm(dim=1)
                h_value: torch.Tensor = (h_mean + h_std) * norm
                act_value[key] = h_value
            return act_value

        encoder_proj, decoder_proj = dump_norm()
        if model_args.comp_mode == 0:
            attn_proj = None
            wo_proj = None
            act_value = None
        elif model_args.comp_mode == 1:
            attn_proj = dump_attn()
            wo_proj = dump_wo()
            act_value = dump_act()
        else:
            raise ValueError("unsupported comp_mode: {!r}".format(model_args.comp_mode))

        t5_comp_params = {
            "encoder_proj": encoder_proj,
            "decoder_proj": decoder_proj,
            "attn_proj": attn_proj,
            "wo_proj": wo_proj,
            "act_value": act_value,
        }
        return t5_comp_params

    @torch.no_grad()
    def load_model_params(self,
                          model: ReduT5ForConditionalGeneration,
                          state_dict: Dict[str, torch.Tensor],
                          compression_params: Any,
                          ):
        """Initialise the reduced ``model`` in place from a full T5 ``state_dict``.

        Parameters whose shape did not change are copied as-is. The attention
        and FFN weights of every block are rebuilt by folding the original
        RMSNorm scale and the hidden-size projection (``encoder_proj`` /
        ``decoder_proj``) into the linear layers. Per-head q/k/v/o projections
        and FFN channel pruning are applied when ``compression_params`` provides
        them. Because the norm scales now live in the following linear layers,
        the reduced model's RMSNorm weights are reset to ones.

        Args:
            model: target reduced model; its parameters are overwritten.
            state_dict: parameters of the original full-size T5 model.
            compression_params: the dict returned by ``collect``.
        """
        encoder_proj: torch.Tensor = compression_params["encoder_proj"]
        decoder_proj: torch.Tensor = compression_params["decoder_proj"]
        attn_proj: Optional[Dict[str, torch.Tensor]] = compression_params["attn_proj"]
        wo_proj: Optional[Dict[str, torch.Tensor]] = compression_params["wo_proj"]
        act_value: Optional[Dict[str, torch.Tensor]] = compression_params["act_value"]

        # comp_mode == 0 stores None here; a None-returning lookup disables the
        # per-head projections and FFN pruning below.
        if attn_proj is None:
            attn_proj = defaultdict(lambda: None)
        if wo_proj is None:
            wo_proj = defaultdict(lambda: None)
        if act_value is None:
            act_value = defaultdict(lambda: None)

        config: T5Config = model.config
        encoder_block_pat = re.compile(r"encoder\.block\.(\d+)")
        decoder_block_pat = re.compile(r"decoder\.block\.(\d+)")

        def get_norm_params(prefix: str, block_id: int, layer_id: int):
            path = "{}.block.{}.layer.{}.layer_norm".format(prefix, block_id, layer_id)
            return state_dict[path + ".weight"]

        def get_final_norm_params(prefix: str):
            path = "{}.final_layer_norm".format(prefix)
            return state_dict[path + ".weight"]

        def load_qk_params(
                name: str,
                linear: torch.nn.Linear,
                norm_w: torch.Tensor,
                norm_proj: torch.Tensor,
                qk_proj: Optional[torch.Tensor] = None,
        ):
            """Fold norm scale and hidden projection (and optional per-head q/k projection) into ``linear``."""
            num_heads = config.num_heads
            d_kv = config.d_kv
            r_kv = config.r_kv

            lin_w = state_dict[name + '.weight']
            # NOTE(review): unlike load_afternorm_linear, no sqrt(d_model / r_model)
            # factor is applied here — confirm the reduced attention compensates.
            n_lin_w = lin_w @ torch.diag(norm_w) @ norm_proj
            if qk_proj is not None:
                n_lin_w = n_lin_w.view(num_heads, d_kv, config.r_model)
                n_lin_w = (qk_proj @ n_lin_w).view(num_heads * r_kv, config.r_model)

            linear.weight.copy_(n_lin_w)

        def prune_ffn_weight(
                w: torch.Tensor,
                retain_indices: torch.Tensor,
                prune_dim: int,
        ):
            """Keep only the ``retain_indices`` FFN channels (in ascending order) along ``prune_dim``."""
            mask = torch.zeros(config.d_ff, dtype=torch.bool)
            mask[retain_indices] = True
            return w[mask, :] if prune_dim == 0 else w[:, mask]

        def load_afternorm_linear(
                name: str,
                linear: torch.nn.Linear,
                norm_w: torch.Tensor,
                norm_proj: torch.Tensor,
                retain_indices: Optional[torch.Tensor] = None,
                prune_dim: Optional[int] = None,
                v_proj: Optional[torch.Tensor] = None
        ):
            """Load a linear layer that consumes an RMSNorm output in the reduced space."""
            lin_w = state_dict[name + '.weight']
            n_lin_w = (lin_w @ torch.diag(norm_w) @ norm_proj) * math.sqrt(config.d_model / config.r_model)

            if retain_indices is not None:
                n_lin_w = prune_ffn_weight(n_lin_w, retain_indices, prune_dim)

            if v_proj is not None:
                n_lin_w = n_lin_w.view(config.num_heads, config.d_kv, config.r_model)
                n_lin_w = torch.matmul(v_proj.transpose(1, 2), n_lin_w)
                n_lin_w = n_lin_w.reshape(config.num_heads * config.r_kv, config.r_model)

            linear.weight.copy_(n_lin_w)

        def load_beforenorm_linear(
                name: str,
                linear: torch.nn.Linear,
                norm_proj: torch.Tensor,
                retain_indices: Optional[torch.Tensor] = None,
                prune_dim: Optional[int] = None,
                o_proj: Optional[torch.Tensor] = None
        ):
            """Load a linear layer that writes into the (projected) residual stream."""
            lin_w = state_dict[name + '.weight']
            n_lin_w = norm_proj.T @ lin_w

            if retain_indices is not None:
                n_lin_w = prune_ffn_weight(n_lin_w, retain_indices, prune_dim)

            if o_proj is not None:
                n_lin_w = n_lin_w.view(config.r_model, config.num_heads, config.d_kv).permute(1, 0, 2)
                n_lin_w = torch.matmul(n_lin_w, o_proj)
                n_lin_w = n_lin_w.permute(1, 0, 2).reshape(config.r_model, config.num_heads * config.r_kv)

            linear.weight.copy_(n_lin_w)

        def load_act_indices(
                name: str,
        ):
            """Indices of the ``r_ff`` highest-scoring FFN channels, or None without act data."""
            h_value = act_value[name]
            act_indices = h_value.sort()[1][-config.r_ff:] if h_value is not None else None
            return act_indices

        def load_attention(prefix, attn, q_norm_w, q_basis, kv_norm_w, kv_basis, out_basis):
            """Load q/k/v/o (and relative bias) of the attention module at ``prefix``.

            The query side uses ``q_norm_w``/``q_basis``, the key/value side uses
            ``kv_norm_w``/``kv_basis`` (they differ for cross-attention), and the
            output projection writes into ``out_basis``.
            """
            load_qk_params(prefix + ".q", attn.q, q_norm_w, q_basis, attn_proj[prefix + ".q_proj"])
            load_qk_params(prefix + ".k", attn.k, kv_norm_w, kv_basis, attn_proj[prefix + ".k_proj"])
            o_proj = wo_proj[prefix + ".o"]
            load_afternorm_linear(prefix + ".v", attn.v, kv_norm_w, kv_basis, v_proj=o_proj)
            load_beforenorm_linear(prefix + ".o", attn.o, out_basis, o_proj=o_proj)
            if attn.has_relative_attention_bias:
                emb_weight = state_dict[prefix + '.relative_attention_bias.weight']
                attn.relative_attention_bias.weight.copy_(emb_weight)

        def load_ffn(prefix, ffn, norm_w, basis):
            """Load (and optionally prune) the ``wi``/``wo`` pair of the FFN at ``prefix``."""
            indices = load_act_indices(prefix + ".wo")
            load_afternorm_linear(prefix + ".wi", ffn.DenseReluDense.wi, norm_w, basis,
                                  retain_indices=indices, prune_dim=0)
            load_beforenorm_linear(prefix + ".wo", ffn.DenseReluDense.wo, basis,
                                   retain_indices=indices, prune_dim=1)

        def reset_norm(layer) -> None:
            layer.layer_norm.weight.copy_(torch.ones_like(layer.layer_norm.weight))

        for n, p in model.named_parameters():
            if n in state_dict and p.shape == state_dict[n].shape:
                p.copy_(state_dict[n])

        encoder_final_norm_w = get_final_norm_params("encoder")
        decoder_final_norm_w = get_final_norm_params("decoder")
        for name, module in model.named_modules():
            if 'teacher' in name:
                continue  # never load into submodules whose path contains "teacher"
            enc_match_result = encoder_block_pat.fullmatch(name)
            dec_match_result = decoder_block_pat.fullmatch(name)
            if enc_match_result is not None:
                assert isinstance(module, ReduT5Block)
                self_attn: ReduT5LayerSelfAttention = module.layer[0]
                ffn: ReduT5LayerFF = module.layer[1]

                block_id = int(enc_match_result.group(1))
                att_norm_w = get_norm_params("encoder", block_id, 0)
                ffn_norm_w = get_norm_params("encoder", block_id, 1)

                # Self MHA
                load_attention(name + ".layer.0.SelfAttention", self_attn.SelfAttention,
                               att_norm_w, encoder_proj, att_norm_w, encoder_proj, encoder_proj)
                # FFN
                load_ffn(name + ".layer.1.DenseReluDense", ffn, ffn_norm_w, encoder_proj)

                for layer in (self_attn, ffn):
                    reset_norm(layer)

            elif dec_match_result is not None:
                assert isinstance(module, ReduT5Block)
                self_attn: ReduT5LayerSelfAttention = module.layer[0]
                cross_attn: ReduT5LayerCrossAttention = module.layer[1]
                ffn: ReduT5LayerFF = module.layer[2]

                block_id = int(dec_match_result.group(1))
                self_att_norm_w = get_norm_params("decoder", block_id, 0)
                cross_att_norm_w = get_norm_params("decoder", block_id, 1)
                ffn_norm_w = get_norm_params("decoder", block_id, 2)

                # Self MHA
                load_attention(name + ".layer.0.SelfAttention", self_attn.SelfAttention,
                               self_att_norm_w, decoder_proj, self_att_norm_w, decoder_proj, decoder_proj)
                # Cross MHA: keys/values read the encoder output (after its final norm).
                load_attention(name + ".layer.1.EncDecAttention", cross_attn.EncDecAttention,
                               cross_att_norm_w, decoder_proj, encoder_final_norm_w, encoder_proj, decoder_proj)
                # FFN
                load_ffn(name + ".layer.2.DenseReluDense", ffn, ffn_norm_w, decoder_proj)

                for layer in (self_attn, cross_attn, ffn):
                    reset_norm(layer)

        model.encoder.final_layer_norm.weight.copy_(torch.ones_like(model.encoder.final_layer_norm.weight))
        model.decoder.final_layer_norm.weight.copy_(torch.ones_like(model.decoder.final_layer_norm.weight))

        model.encoder.down_linear.weight.copy_(encoder_proj.T)
        model.decoder.down_linear.weight.copy_(decoder_proj.T)
        # The decoder's final norm scale is folded into up_linear instead.
        model.up_linear.weight.copy_(torch.diag(decoder_final_norm_w) @ decoder_proj)
