import torch
import wandb
import os
import transformers

assert transformers.__version__ == "4.41.2"

from torch import nn
from torch.nn.functional import linear, embedding
from transformers.models.qwen2.modeling_qwen2 import *
from transformers.modeling_outputs import ModelOutput
from tools.log import main_logger
from dataclasses import dataclass
from tools.global_state import hyper_params, data_cls_reversed_dict, ban_losses, ban_layers
from accelerate import Accelerator


accelerator = Accelerator()


class BigValueFirstLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mseloss = nn.MSELoss(reduction="none")

    def forward(self, output, target):
        return torch.mean(torch.abs(target + 1e-2) * self.mseloss(output, target))


class MSELossV2(nn.Module):
    def __init__(self):
        super().__init__()
        self.mseloss = nn.MSELoss(reduction="none")
    
    def forward(self, output, target):
        return self.mseloss(output, target).sum(dim=-1).mean()


class L1LossV2(nn.Module):
    def forward(self, output, target):
        return ((output - target).abs().sum(dim=-1)/50.0).mean()


LOSS_DICT = {
    "mseloss": nn.MSELoss,
    "mseloss_v2": MSELossV2,
    "l1loss": nn.L1Loss,
    "l1loss_v2": L1LossV2,
    # "big_value_first": BigValueFirstLoss
}


class CustomConfig(Qwen2Config):
    def set_custom_kwargs(self, **kwargs):
        # required
        self.target_hidden_size = kwargs["target_hidden_size"]
        self.use_attn_map = kwargs.get("use_attn_map", False)
        self.target_rms_norm_eps = kwargs.get("target_rms_norm_eps", self.rms_norm_eps)
        self.use_aux_loss = kwargs.get("use_aux_loss", True)
        self.use_std_like_attn = kwargs.get("use_std_like_attn", False)
        self.use_logits_loss = kwargs.get("use_logits_loss", True)
        self.use_ntp_loss = kwargs.get("use_ntp_loss", True)
        self.check_data_cls_loss = kwargs.get("check_data_cls_loss", False)
        self.kl_temperature = kwargs.get("kl_temperature", 10.0)
        self.aux_loss_type = kwargs.get("aux_loss_type", "mseloss")
        self.student_attn_from_scratch = kwargs.get("student_attn_from_scratch", False)
        self.tie_word_emb_proj = kwargs.get("tie_word_emb_proj", False)
        self.del_layers = kwargs.get("del_layers", [])
        self.use_in_out_mlp = kwargs.get("use_in_out_mlp", False)
        self.use_all_attn = kwargs.get("use_all_attn", False)


class AllAttn(Qwen2FlashAttention2):
    def __init__(self, config: CustomConfig, layer_idx = None):
        super().__init__(config, layer_idx)
        self.config = config
        self.zoom_q = nn.Linear(config.target_hidden_size, self.hidden_size, bias=False)
        self.zoom_k = nn.Linear(config.target_hidden_size, self.hidden_size, bias=False)
        self.zoom_v = nn.Linear(config.target_hidden_size, self.hidden_size, bias=False)
        self.zoom_down = nn.Linear(self.hidden_size, config.target_hidden_size, bias=False)
        self.mseloss = LOSS_DICT[config.aux_loss_type]()
        self.layer_idx = layer_idx

    def part_forward(self, query_states, key_states, value_states, bsz, q_len, position_ids,
                     past_key_value=None, attention_mask=None):
        output_attentions = False

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        # Because the input can be padded, the absolute sequence length depends on the max position id.
        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        use_sliding_windows = (
            getattr(self.config, "sliding_window", None) is not None
            and kv_seq_len > self.config.sliding_window
            and self.config.use_sliding_window
        )
        if use_sliding_windows:
            raise NotImplementedError

        if past_key_value is not None:
            raise NotImplementedError

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reashape to the expected shape for Flash Attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_sliding_windows=use_sliding_windows,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def forward(
        self,
        hidden_states,
        compressed_hidden_states,
        loss_dict,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        **kwargs,
    ):
        assert past_key_value is None
        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        _up_hs_q = self.q_proj(self.zoom_q(compressed_hidden_states))
        _up_hs_k = self.k_proj(self.zoom_k(compressed_hidden_states))
        _up_hs_v = self.v_proj(self.zoom_v(compressed_hidden_states))

        raw_out, raw_attn_map, _ = self.part_forward(query_states, key_states, value_states, bsz, q_len, position_ids,
                                                     attention_mask=None)
        out, attn_map, _ = self.part_forward(_up_hs_q, _up_hs_k, _up_hs_v, bsz, q_len, position_ids,
                                             attention_mask=None)
        compressed_hidden_states = self.zoom_down(out)

        if "attn-q-sim-loss" not in ban_losses and self.layer_idx not in ban_layers:
            loss_dict["attn-q-sim-loss"] = self.mseloss(_up_hs_q, query_states)
        if "attn-k-sim-loss" not in ban_losses and self.layer_idx not in ban_layers:
            loss_dict["attn-k-sim-loss"] = self.mseloss(_up_hs_k, key_states)
        if "attn-v-sim-loss" not in ban_losses and self.layer_idx not in ban_layers:
            loss_dict["attn-v-sim-loss"] = self.mseloss(_up_hs_v, value_states)
        if "attn-k-sim-loss" not in ban_losses and self.layer_idx not in ban_layers:
            loss_dict["attn-out-sim-loss"] = self.mseloss(out, raw_out)

        return raw_out, compressed_hidden_states, raw_attn_map, _, loss_dict
    
    def merge_weight(self):
        self.q_proj.weight.data = (self.q_proj.weight.data @ self.zoom_q.weight.data).contiguous()
        self.k_proj.weight.data = (self.k_proj.weight.data @ self.zoom_k.weight.data).contiguous()
        self.v_proj.weight.data = (self.v_proj.weight.data @ self.zoom_v.weight.data).contiguous()
        self.o_proj.weight.data = (self.zoom_down.weight.data @ self.o_proj.weight.data).contiguous()


class Attn(Qwen2FlashAttention2):
    def __init__(self, config: CustomConfig, layer_idx = None):
        super().__init__(config, layer_idx)
        self.config = config
        self.zoom_up = nn.Linear(config.target_hidden_size, self.hidden_size, bias=False)
        self.zoom_down = nn.Linear(self.hidden_size, config.target_hidden_size, bias=False)
        self.mseloss = LOSS_DICT[config.aux_loss_type]()
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states,
        compressed_hidden_states,
        loss_dict,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        cache_position=None,
        **kwargs,
    ):
        output_attentions = self.config.use_attn_map
        if output_attentions:
            raise NotImplementedError
        assert attention_mask is None
        assert past_key_value is None
        
        raw_out, raw_attn_map, _ = super().forward(
            hidden_states,
            attention_mask,
            position_ids,
            past_key_value,
            output_attentions,
            use_cache,
            # cache_position,
            # **kwargs,
        )
        # assert not torch.isnan(compressed_hidden_states).any(), f"NaN detected in model output in a"
        zoomed_hs = self.zoom_up(compressed_hidden_states)
        out, attn_map, _ = super().forward(
            zoomed_hs,
            attention_mask,
            position_ids,
            past_key_value,
            output_attentions,
            use_cache,
            # cache_position,
            # **kwargs,
        )
        compressed_hidden_states = self.zoom_down(out)

        # assert not torch.isnan(compressed_hidden_states).any(), f"NaN detected in model output in b"
        if "attn-in-sim-loss" not in ban_losses and self.layer_idx not in ban_layers:
            loss_dict["attn-in-sim-loss"] = self.mseloss(zoomed_hs, hidden_states)
        if "attn-out-sim-loss" not in ban_losses and self.layer_idx not in ban_layers:
            loss_dict["attn-out-sim-loss"] = self.mseloss(out, raw_out)
        return raw_out, compressed_hidden_states, raw_attn_map, _, loss_dict
    
    def merge_weight(self):
        self.q_proj.weight.data = (self.q_proj.weight.data @ self.zoom_up.weight.data).contiguous()
        self.k_proj.weight.data = (self.k_proj.weight.data @ self.zoom_up.weight.data).contiguous()
        self.v_proj.weight.data = (self.v_proj.weight.data @ self.zoom_up.weight.data).contiguous()
        self.o_proj.weight.data = (self.zoom_down.weight.data @ self.o_proj.weight.data).contiguous()


class MLP(Qwen2MLP):
    def __init__(self, config: CustomConfig, layer_idx=None):
        super().__init__(config)
        self.zoom_up = nn.Linear(self.hidden_size, config.target_hidden_size, bias=False)
        self.zoom_gate = nn.Linear(self.hidden_size, config.target_hidden_size, bias=False)
        self.zoom_down = nn.Linear(self.hidden_size, config.target_hidden_size, bias=False)
        self.mseloss = LOSS_DICT[config.aux_loss_type]()
        self.layer_idx = layer_idx

    def small_forward(self, compressed_x, raw_gate, raw_act_gate, raw_up, raw_x, raw_out, loss_dict: dict):
        Wup = self.zoom_up(self.up_proj.weight)
        Wgate = self.zoom_gate(self.gate_proj.weight)
        Wdown = self.zoom_down(self.down_proj.weight.T).T
        gate = linear(compressed_x, Wgate)
        act_gate = self.act_fn(gate)
        up = linear(compressed_x, Wup)
        down = linear(act_gate * up, Wdown)

        # calculate loss
        if "mlp-gate-loss" not in ban_losses and self.layer_idx not in ban_layers:
            loss_dict[f"mlp-gate-loss"] = self.mseloss(gate, raw_gate)
        # loss_dict[f"mlp-act-gate-loss"] = self.mseloss(act_gate, raw_act_gate)
        if "mlp-up-loss" not in ban_losses and self.layer_idx not in ban_layers:
            loss_dict[f"mlp-up-loss"] = self.mseloss(up, raw_up)
        # loss_dict[f"mlp-in-loss"] = self.mseloss(compressed_x, self.zoom(raw_x))
        if "mlp-out-loss" not in ban_losses and self.layer_idx not in ban_layers:
            loss_dict[f"mlp-out-loss"] = self.mseloss(down, self.zoom_down(raw_out))

        # print("debug 2", loss_dict)
        return down

    def forward(self, x, compressed_x, loss_dict: dict):
        gate = self.gate_proj(x)
        act_gate = self.act_fn(gate)
        up = self.up_proj(x)
        down = self.down_proj(act_gate * up)

        return down, self.small_forward(compressed_x, gate, act_gate, up, x, down, loss_dict), loss_dict
    
    def merge_weight(self):
        self.gate_proj.weight.data = self.zoom_gate(self.gate_proj.weight.data).contiguous()
        self.up_proj.weight.data = self.zoom_up(self.up_proj.weight.data).contiguous()
        self.down_proj.weight.data = self.zoom_down(self.down_proj.weight.data.T).T.contiguous()


class DebugLlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        DebugLlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        # self.first = True

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        mm = torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = hidden_states * mm
        out = self.weight * hidden_states.to(input_dtype)
        # if self.first:
        #     assert torch.abs(torch.mean(self.weight) - 1) < 1e-3, f"{self.weight}"
        #     self.first = False
        return out
    

# def reinit_weight(module: nn.Module):
#     if type(module) == nn.Linear:
#         if module.weight.requires_grad:
#             module.weight.data.normal_(mean=0.0, std=0.02)
#             if module.bias is not None:
#                 module.bias.data.zero_()
#     if type(module) == DebugLlamaRMSNorm:
#         if module.weight.requires_grad:
#             module.weight.data.fill_(1.0)


def reinit_weight(module: nn.Module):
    if not hasattr(module, 'weight') or not module.weight.requires_grad:
        return

    #important_dims = [0, 1, 2, 3, 8, 11, 12, 14, 16, 18, 19, 20, 22, 24, 26, 27, 28, 31, 33, 34, 35, 37, 38, 40, 42, 43, 46, 47, 49, 57, 58, 63, 64, 65, 66, 70, 74, 77, 78, 79, 80, 81, 84, 85, 87, 88, 89, 93, 99, 103, 105, 107, 108, 109, 110, 112, 119, 120, 121, 122, 123, 124, 125, 127, 128, 132, 133, 137, 138, 141, 142, 143, 145, 147, 148, 150, 151, 156, 158, 160, 161, 163, 165, 166, 167, 174, 176, 177, 180, 183, 184, 187, 191, 194, 195, 196, 197, 201, 203, 204, 205, 207, 212, 215, 216, 217, 218, 219, 221, 223, 224, 227, 230, 232, 233, 234, 235, 236, 237, 238, 239, 240, 242, 243, 245, 246, 247, 251, 254, 256, 260, 262, 263, 264, 265, 267, 272, 273, 274, 275, 277, 278, 279, 280, 281, 282, 287, 289, 291, 292, 294, 298, 300, 301, 303, 304, 307, 308, 310, 311, 312, 314, 319, 322, 323, 326, 328, 329, 331, 332, 333, 335, 338, 343, 348, 349, 350, 351, 353, 354, 355, 356, 357, 359, 360, 364, 366, 368, 370, 372, 373, 375, 376, 377, 379, 380, 384, 385, 388, 389, 393, 398, 399, 400, 403, 404, 409, 411, 412, 413, 415, 416, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 431, 433, 435, 436, 437, 438, 441, 443, 446, 447, 448, 453, 454, 456, 457, 458, 459, 460, 462, 466, 469, 472, 476, 478, 480, 481, 482, 484, 485, 486, 487, 488, 489, 492, 495, 496, 497, 500, 501, 507, 509, 511, 512, 513, 514, 517, 518, 519, 520, 522, 524, 525, 527, 528, 530, 531, 533, 535, 538, 539, 540, 542, 554, 555, 557, 559, 562, 563, 564, 567, 568, 570, 571, 573, 574, 575, 580, 581, 584, 585, 586, 588, 589, 591, 593, 595, 596, 600, 605, 606, 608, 611, 613, 615, 616, 617, 621, 624, 627, 628, 629, 630, 634, 642, 643, 647, 649, 650, 651, 653, 658, 661, 662, 663, 664, 666, 667, 668, 669, 671, 672, 674, 676, 677, 678, 685, 687, 688, 689, 690, 692, 696, 698, 700, 702, 703, 704, 705, 708, 710, 712, 713, 714, 717, 718, 719, 720, 723, 726, 727, 728, 730, 732, 733, 737, 738, 739, 742, 743, 744, 745, 746, 747, 749, 752, 753, 755, 756, 757, 758, 759, 760, 763, 764, 766, 767, 769, 770, 775, 776, 777, 778, 779, 780, 781, 785, 786, 787, 790, 793, 794, 795, 796, 798, 801, 803, 806, 808, 810, 811, 814, 816, 817, 820, 822, 823, 824, 825, 827, 830, 832, 834, 835, 836, 838, 839, 840, 842, 843, 844, 849, 851, 852, 853, 854, 855, 856, 859, 860, 864, 865, 866, 868, 869, 870, 873, 875, 876, 878, 881, 882, 883, 884, 887, 891, 893, 895, 896, 897, 898, 900, 901, 902, 906, 907, 909, 913, 914, 917, 919, 921, 922, 925, 927, 928, 929, 931, 932, 933, 934, 936, 938, 940, 945, 946, 952, 953, 954, 956, 957, 958, 960, 962, 963, 965, 967, 968, 972, 973, 976, 977, 981, 983, 984, 985, 987, 993, 994, 996, 997, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1016, 1018, 1023, 1024, 1026, 1027, 1031, 1032, 1034, 1035, 1037, 1041, 1043, 1044, 1045, 1047, 1048, 1049, 1050, 1052, 1053, 1055, 1056, 1057, 1058, 1060, 1069, 1072, 1076, 1077, 1079, 1082, 1083, 1084, 1085, 1086, 1087, 1088, 1090, 1095, 1097, 1099, 1102, 1103, 1104, 1105, 1106, 1109, 1110, 1111, 1113, 1114, 1115, 1119, 1122, 1123, 1129, 1131, 1133, 1136, 1139, 1144, 1145, 1146, 1147, 1151, 1154, 1155, 1156, 1158, 1161, 1162, 1166, 1167, 1170, 1171, 1172, 1175, 1176, 1177, 1179, 1181, 1184, 1185, 1188, 1189, 1194, 1195, 1197, 1200, 1203, 1205, 1208, 1209, 1210, 1212, 1215, 1216, 1217, 1218, 1220, 1221, 1223, 1224, 1225, 1226, 1227, 1228, 1229, 1230, 1234, 1235, 1236, 1237, 1238, 1239, 1241, 1243, 1244, 1249, 1251, 1252, 1253, 1254, 1255, 1259, 1264, 1266, 1268, 1270, 1271, 1272, 1277, 1278, 1279, 1280, 1281, 1283, 1285, 1286, 1288, 1290, 1291, 1292, 1293, 1295, 1299, 1302, 1304, 1305, 1306, 1307, 1308, 1310, 1313, 1314, 1315, 1316, 1319, 1320, 1326, 1332, 1333, 1335, 1336, 1342, 1345, 1346, 1347, 1348, 1349, 1354, 1355, 1356, 1358, 1361, 1362, 1363, 1364, 1367, 1368, 1369, 1372, 1374, 1375, 1376, 1378, 1379, 1380, 1382, 1389, 1391, 1392, 1394, 1395, 1398, 1399, 1401, 1402, 1403, 1404, 1407, 1408, 1409, 1410, 1416, 1417, 1422, 1424, 1425, 1428, 1431, 1433, 1435, 1437, 1439, 1440, 1443, 1444, 1447, 1448, 1453, 1454, 1456, 1457, 1458, 1459, 1465, 1466, 1467, 1470, 1471, 1474, 1476, 1477, 1479, 1480, 1482, 1484, 1485, 1487, 1488, 1493, 1499, 1500, 1503, 1505, 1507, 1508, 1510, 1515, 1519, 1522, 1523, 1524, 1525, 1526, 1527, 1528, 1529, 1531, 1533, 1534, 1535, 1536, 1537, 1539, 1540, 1541, 1542, 1543, 1544, 1545, 1546, 1552, 1553, 1554, 1557, 1558, 1559, 1560, 1562, 1563, 1564, 1566, 1568, 1571, 1572, 1573, 1576, 1578, 1579, 1582, 1584, 1585, 1586, 1589, 1591, 1595, 1597, 1602, 1603, 1604, 1605, 1608, 1609, 1610, 1614, 1616, 1617, 1619, 1622, 1624, 1627, 1629, 1630, 1632, 1634, 1638, 1639, 1640, 1641, 1643, 1644, 1646, 1647, 1648, 1649, 1650, 1651, 1653, 1654, 1655, 1657, 1659, 1670, 1673, 1674, 1675, 1677, 1678, 1680, 1681, 1683, 1685, 1686, 1687, 1688, 1690, 1692, 1693, 1694, 1696, 1698, 1705, 1709, 1710, 1711, 1712, 1713, 1714, 1715, 1716, 1717, 1719, 1722, 1724, 1726, 1727, 1728, 1729, 1730, 1731, 1739, 1742, 1743, 1744, 1745, 1746, 1747, 1755, 1756, 1760, 1761, 1762, 1765, 1774, 1777, 1778, 1781, 1784, 1786, 1788, 1789, 1791, 1792, 1793, 1794, 1795, 1796, 1797, 1798, 1800, 1803, 1804, 1805, 1806, 1807, 1808, 1809, 1810, 1812, 1814, 1815, 1817, 1820, 1823, 1825, 1826, 1827, 1830, 1831, 1837, 1840, 1842, 1844, 1845, 1847, 1848, 1851, 1852, 1853, 1857, 1859, 1861, 1862, 1866, 1867, 1874, 1878, 1879, 1882, 1886, 1887, 1888, 1889, 1894, 1895, 1897, 1899, 1901, 1902, 1903, 1904, 1905, 1906, 1909, 1910, 1911, 1914, 1915, 1916, 1917, 1924, 1925, 1926, 1927, 1928, 1930, 1932, 1934, 1935, 1940, 1941, 1943, 1947, 1948, 1950, 1951, 1955, 1956, 1958, 1964, 1965, 1967, 1968, 1973, 1974, 1976, 1978, 1980, 1981, 1982, 1983, 1986, 1988, 1991, 1992, 1994, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2007, 2008, 2010, 2011, 2015, 2016, 2018, 2019, 2021, 2022, 2024, 2026, 2027, 2029, 2030, 2033, 2035, 2036, 2041, 2042, 2044, 2045, 2046, 2047, 2048, 2049, 2052, 2053, 2054, 2056, 2059, 2065, 2066, 2067, 2071, 2072, 2073, 2076, 2079, 2080, 2083, 2086, 2087, 2090, 2092, 2094, 2095, 2096, 2098, 2099, 2100, 2101, 2102, 2104, 2107, 2109, 2112, 2113, 2115, 2117, 2119, 2121, 2122, 2124, 2125, 2126, 2127, 2130, 2132, 2136, 2139, 2141, 2143, 2145, 2147, 2151, 2152, 2154, 2156, 2161, 2165, 2166, 2168, 2171, 2172, 2176, 2177, 2178, 2181, 2183, 2184, 2185, 2188, 2189, 2192, 2193, 2195, 2196, 2197, 2198, 2200, 2202, 2203, 2204, 2205, 2209, 2210, 2212, 2214, 2216, 2217, 2218, 2219, 2223, 2224, 2225, 2227, 2228, 2231, 2233, 2234, 2236, 2237, 2239, 2240, 2243, 2244, 2245, 2247, 2250, 2254, 2255, 2257, 2258, 2262, 2263, 2264, 2265, 2267, 2268, 2270, 2271, 2273, 2275, 2279, 2280, 2287, 2288, 2289, 2292, 2294, 2297, 2299, 2303, 2305, 2307, 2308, 2312, 2317, 2320, 2322, 2326, 2327, 2328, 2329, 2330, 2331, 2336, 2338, 2339, 2344, 2345, 2346, 2348, 2352, 2354, 2355, 2356, 2357, 2358, 2359, 2360, 2362, 2363, 2364, 2366, 2369, 2370, 2371, 2372, 2373, 2374, 2378, 2379, 2380, 2381, 2383, 2386, 2388, 2392, 2393, 2395, 2398, 2400, 2403, 2407, 2411, 2412, 2420, 2423, 2425, 2428, 2430, 2432, 2433, 2434, 2435, 2437, 2444, 2445, 2446, 2448, 2452, 2453, 2454, 2458, 2461, 2462, 2465, 2469, 2472, 2473, 2474, 2476, 2478, 2480, 2482, 2484, 2486, 2487, 2489, 2493, 2497, 2498, 2502, 2507, 2509, 2510, 2513, 2515, 2517, 2528, 2530, 2531, 2535, 2536, 2537, 2540, 2541, 2544, 2551, 2552, 2558, 2560, 2563, 2574, 2576, 2577, 2579, 2581, 2582, 2584, 2585, 2587, 2588, 2589, 2590, 2593, 2599, 2602, 2603, 2607, 2609, 2610, 2612, 2613, 2614, 2615, 2618, 2619, 2621, 2626, 2629, 2630, 2632, 2634, 2637, 2638, 2647, 2648, 2655, 2656, 2658, 2661, 2663, 2668, 2669, 2670, 2671, 2673, 2676, 2678, 2679, 2680, 2681, 2682, 2683, 2685, 2686, 2688, 2689, 2690, 2692, 2696, 2697, 2698, 2701, 2702, 2703, 2704, 2707, 2709, 2712, 2713, 2714, 2715, 2717, 2720, 2721, 2722, 2723, 2724, 2726, 2727, 2734, 2739, 2741, 2748, 2749, 2750, 2752, 2758, 2759, 2760, 2761, 2764, 2766, 2771, 2773, 2774, 2775, 2776, 2777, 2778, 2779, 2781, 2783, 2784, 2786, 2787, 2789, 2791, 2792, 2795, 2797, 2799, 2800, 2801, 2802, 2803, 2805, 2807, 2808, 2809, 2814, 2817, 2822, 2823, 2824, 2825, 2834, 2837, 2839, 2841, 2844, 2848, 2849, 2852, 2854, 2856, 2858, 2859, 2860, 2864, 2868, 2873, 2874, 2875, 2876, 2878, 2879, 2881, 2883, 2884, 2885, 2888, 2892, 2895, 2896, 2898, 2902, 2905, 2912, 2914, 2917, 2920, 2921, 2922, 2923, 2924, 2926, 2927, 2930, 2931, 2932, 2934, 2936, 2937, 2938, 2939, 2940, 2941, 2943, 2945, 2946, 2949, 2952, 2953, 2954, 2962, 2963, 2965, 2966, 2970, 2971, 2973, 2986, 2988, 2990, 2992, 2993, 2994, 2995, 2998, 2999, 3000, 3004, 3005, 3006, 3008, 3009, 3010, 3018, 3020, 3026, 3027, 3029, 3031, 3033, 3034, 3036, 3037, 3039, 3046, 3049, 3050, 3052, 3053, 3057, 3061, 3062, 3064, 3065, 3066, 3067, 3069, 3070]
    important_dims = [3, 5, 8, 9, 11, 13, 14, 17, 19, 20, 23, 24, 26, 27, 28, 29, 31, 32, 33, 36, 37, 39, 46, 49, 51, 52, 53, 55, 59, 60, 62, 66, 67, 68, 69, 70, 71, 73, 76, 77, 78, 83, 84, 85, 89, 95, 96, 97, 99, 100, 101, 104, 107, 108, 109, 111, 112, 113, 114, 115, 116, 119, 120, 121, 124, 125, 127, 128, 129, 130, 131, 132, 134, 135, 136, 138, 140, 143, 144, 145, 152, 154, 155, 156, 157, 160, 161, 162, 163, 164, 165, 166, 167, 168, 171, 172, 173, 177, 179, 180, 183, 184, 186, 187, 188, 189, 190, 191, 192, 193, 198, 199, 200, 201, 203, 206, 207, 209, 211, 212, 216, 217, 219, 220, 221, 222, 226, 227, 228, 229, 233, 234, 236, 238, 239, 244, 245, 252, 253, 255, 257, 258, 259, 265, 266, 268, 270, 273, 274, 277, 278, 280, 282, 284, 286, 289, 291, 292, 293, 294, 295, 296, 297, 300, 301, 302, 306, 307, 308, 310, 311, 314, 316, 319, 321, 322, 323, 324, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 339, 340, 341, 342, 343, 344, 348, 351, 355, 358, 359, 362, 363, 364, 365, 366, 368, 369, 373, 374, 376, 377, 382, 385, 387, 388, 389, 392, 393, 394, 396, 397, 399, 400, 402, 404, 405, 409, 410, 418, 420, 421, 422, 423, 424, 425, 428, 429, 431, 432, 433, 435, 436, 438, 439, 440, 444, 445, 446, 451, 453, 456, 457, 458, 459, 461, 463, 464, 466, 467, 468, 469, 470, 471, 475, 476, 477, 481, 482, 484, 489, 490, 495, 497, 500, 501, 502, 503, 504, 505, 507, 509, 510, 515, 516, 517, 518, 520, 523, 525, 528, 529, 534, 536, 537, 538, 540, 541, 542, 543, 544, 547, 548, 550, 553, 554, 555, 556, 557, 559, 561, 562, 563, 564, 566, 569, 571, 572, 573, 574, 575, 576, 578, 581, 583, 585, 586, 587, 589, 591, 595, 596, 597, 598, 599, 601, 602, 604, 607, 608, 612, 613, 616, 617, 621, 622, 623, 624, 625, 628, 630, 631, 632, 634, 635, 637, 639, 640, 641, 645, 649, 650, 651, 653, 654, 656, 659, 660, 661, 662, 666, 667, 670, 671, 677, 680, 682, 684, 686, 687, 688, 689, 694, 695, 696, 698, 700, 703, 704, 706, 707, 710, 711, 712, 713, 714, 715, 716, 718, 720, 725, 727, 728, 732, 734, 736, 737, 738, 740, 742, 743, 747, 749, 750, 752, 755, 756, 759, 761, 762, 763, 765, 767, 768, 770, 772, 773, 775, 776, 779, 780, 781, 783, 784, 786, 787, 791, 792, 793, 795, 796, 799, 800, 801, 802, 805, 806, 807, 808, 809, 810, 811, 816, 817, 819, 820, 821, 822, 824, 825, 826, 827, 828, 829, 830, 834, 835, 836, 841, 842, 844, 845, 846, 847, 848, 850, 851, 852, 853, 854, 855, 856, 857, 859, 861, 862, 864, 865, 867, 868, 871, 876, 877, 879, 882, 883, 888, 889, 890, 892, 893, 895, 896, 898, 899, 901, 902, 903, 909, 912, 913, 915, 918, 928, 931, 933, 935, 936, 938, 939, 940, 941, 942, 943, 944, 945, 947, 948, 951, 952, 953, 954, 956, 957, 958, 959, 960, 963, 964, 965, 966, 967, 968, 972, 980, 981, 982, 983, 984, 985, 986, 987, 989, 991, 992, 994, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1007, 1009, 1012, 1013, 1014, 1017, 1018, 1019, 1020, 1021, 1023, 1024, 1026, 1028, 1029, 1033, 1034, 1035, 1038, 1039, 1040, 1041, 1043, 1045, 1047, 1050, 1051, 1053, 1056, 1058, 1059, 1060, 1061, 1062, 1064, 1065, 1066, 1067, 1068, 1069, 1071, 1074, 1076, 1078, 1079, 1081, 1082, 1085, 1086, 1087, 1088, 1089, 1090, 1091, 1092, 1093, 1095, 1099, 1101, 1102, 1103, 1104, 1106, 1107, 1110, 1111, 1114, 1116, 1117, 1118, 1121, 1122, 1123, 1124, 1126, 1128, 1129, 1130, 1131, 1133, 1134, 1136, 1138, 1139, 1140, 1142, 1144, 1145, 1146, 1147, 1149, 1150, 1152, 1154, 1157, 1158, 1160, 1162, 1163, 1166, 1167, 1169, 1174, 1175, 1177, 1180, 1182, 1188, 1189, 1190, 1191, 1195, 1196, 1197, 1199, 1203, 1205, 1207, 1208, 1209, 1210, 1211, 1212, 1213, 1219, 1220, 1221, 1222, 1224, 1226, 1230, 1234, 1238, 1239, 1248, 1249, 1257, 1259, 1260, 1261, 1265, 1266, 1268, 1272, 1273, 1275, 1277, 1278, 1279, 1280, 1281, 1282, 1283, 1285, 1287, 1288, 1290, 1291, 1292, 1294, 1295, 1297, 1298, 1299, 1300, 1301, 1302, 1303, 1304, 1306, 1309, 1311, 1313, 1314, 1317, 1319, 1321, 1323, 1324, 1328, 1329, 1331, 1332, 1333, 1339, 1342, 1343, 1345, 1347, 1348, 1351, 1352, 1353, 1354, 1355, 1358, 1359, 1361, 1368, 1369, 1370, 1374, 1377, 1379, 1380, 1381, 1383, 1384, 1385, 1386, 1387, 1388, 1392, 1393, 1394, 1395, 1396, 1397, 1402, 1404, 1405, 1407, 1410, 1411, 1412, 1413, 1414, 1418, 1419, 1420, 1421, 1422, 1423, 1424, 1427, 1428, 1429, 1431, 1432, 1433, 1438, 1439, 1443, 1445, 1451, 1452, 1455, 1456, 1457, 1458, 1459, 1461, 1462, 1464, 1465, 1466, 1467, 1468, 1470, 1473, 1474, 1475, 1478, 1479, 1481, 1482, 1484, 1487, 1488, 1489, 1491, 1494, 1502, 1503, 1507, 1511, 1512, 1514, 1515, 1520, 1526, 1527, 1528, 1529, 1530, 1531, 1533, 1534, 1542, 1545, 1547, 1549, 1550, 1551, 1552, 1553, 1554, 1555, 1556, 1557, 1558, 1559, 1560, 1562, 1563, 1568, 1571, 1572, 1573, 1574, 1576, 1577, 1579, 1581, 1584, 1585, 1586, 1588, 1589, 1590, 1592, 1593, 1594, 1595, 1597, 1598, 1599, 1601, 1604, 1605, 1606, 1607, 1609, 1610, 1611, 1614, 1615, 1617, 1619, 1620, 1621, 1623, 1625, 1626, 1627, 1628, 1631, 1632, 1633, 1634, 1635, 1638, 1639, 1640, 1642, 1644, 1645, 1647, 1648, 1650, 1653, 1655, 1656, 1657, 1660, 1661, 1663, 1666, 1667, 1670, 1672, 1674, 1675, 1676, 1677, 1679, 1680, 1682, 1683, 1684, 1685, 1686, 1688, 1690, 1691, 1692, 1693, 1694, 1697, 1699, 1700, 1704, 1705, 1706, 1707, 1708, 1710, 1711, 1712, 1716, 1717, 1718, 1719, 1720, 1721, 1722, 1723, 1724, 1725, 1730, 1731, 1733, 1735, 1737, 1738, 1740, 1741, 1742, 1743, 1744, 1745, 1746, 1748, 1750, 1751, 1752, 1753, 1755, 1758, 1759, 1760, 1761, 1762, 1763, 1767, 1769, 1776, 1777, 1778, 1780, 1781, 1783, 1785, 1786, 1788, 1790, 1793, 1794, 1795, 1799, 1801, 1803, 1805, 1806, 1811, 1815, 1816, 1817, 1820, 1821, 1823, 1825, 1827, 1828, 1829, 1830, 1831, 1832, 1833, 1835, 1836, 1839, 1840, 1841, 1843, 1844, 1845, 1847, 1848, 1849, 1850, 1851, 1852, 1853, 1854, 1855, 1856, 1857, 1860, 1864, 1865, 1866, 1868, 1869, 1871, 1872, 1874, 1876, 1878, 1881, 1882, 1883, 1885, 1888, 1890, 1891, 1892, 1893, 1894, 1895, 1898, 1900, 1902, 1903, 1904, 1907, 1908, 1909, 1911, 1912, 1915, 1916, 1917, 1918, 1919, 1922, 1923, 1925, 1926, 1927, 1928, 1929, 1930, 1932, 1934, 1936, 1940, 1941, 1942, 1943, 1945, 1949, 1951, 1952, 1955, 1957, 1960, 1962, 1963, 1967, 1968, 1969, 1971, 1976, 1977, 1980, 1981, 1982, 1984, 1985, 1987, 1989, 1990, 1991, 1993, 1995, 1996, 1997, 1999, 2004, 2006, 2007, 2010, 2012, 2014, 2016, 2019, 2022, 2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2032, 2033, 2035, 2036, 2037, 2039, 2040, 2041, 2042, 2044, 2045, 2047, 2048, 2050, 2051, 2053, 2054, 2056, 2059, 2060, 2061, 2063, 2066, 2069, 2070, 2071, 2072, 2073, 2076, 2077, 2078, 2080, 2082, 2084, 2085, 2086, 2088, 2089, 2090, 2091, 2094, 2095, 2096, 2097, 2099, 2100, 2102, 2103, 2104, 2106, 2107, 2109, 2111, 2116, 2117, 2118, 2123, 2126, 2127, 2128, 2130, 2132, 2133, 2136, 2139, 2141, 2143, 2145, 2147, 2148, 2149, 2150, 2152, 2153, 2154, 2157, 2158, 2159, 2161, 2163, 2164, 2167, 2170, 2171, 2172, 2174, 2175, 2177, 2178, 2179, 2180, 2182, 2189, 2190, 2191, 2192, 2193, 2194, 2195, 2196, 2198, 2199, 2200, 2202, 2204, 2205, 2207, 2208, 2209, 2210, 2211, 2212, 2213, 2216, 2218, 2219, 2223, 2225, 2226, 2228, 2230, 2231, 2232, 2234, 2235, 2236, 2240, 2242, 2244, 2245, 2246, 2247, 2249, 2251, 2254, 2255, 2257, 2258, 2259, 2263, 2264, 2266, 2268, 2269, 2271, 2272, 2273, 2275, 2276, 2277, 2278, 2279, 2280, 2281, 2282, 2283, 2284, 2285, 2286, 2287, 2288, 2289, 2290, 2291, 2294, 2299, 2300, 2301, 2302, 2303, 2304, 2305, 2307, 2309, 2311, 2313, 2316, 2318, 2320, 2322, 2323, 2326, 2327, 2328, 2329, 2330, 2331, 2332, 2333, 2334, 2335, 2340, 2341, 2343, 2345, 2346, 2348, 2353, 2354, 2355, 2358, 2359, 2360, 2364, 2365, 2366, 2367, 2368, 2369, 2370, 2371, 2372, 2373, 2374, 2376, 2383, 2384, 2385, 2386, 2387, 2391, 2392, 2394, 2396, 2401, 2402, 2403, 2404, 2405, 2406, 2407, 2408, 2410, 2411, 2413, 2415, 2416, 2418, 2419, 2425, 2426, 2427, 2429, 2431, 2433, 2434, 2435, 2438, 2439, 2440, 2441, 2442, 2444, 2445, 2448, 2449, 2450, 2451, 2452, 2455, 2461, 2464, 2465, 2467, 2470, 2474, 2475, 2476, 2477, 2478, 2479, 2480, 2481, 2483, 2484, 2486, 2488, 2489, 2491, 2494, 2495, 2496, 2498, 2501, 2503, 2510, 2511, 2516, 2518, 2519, 2520, 2523, 2524, 2525, 2527, 2529, 2531, 2533, 2534, 2537, 2539, 2541, 2542, 2543, 2544, 2545, 2547, 2549, 2550, 2551, 2552, 2557, 2558, 2559, 2560, 2561, 2563, 2565, 2566, 2567, 2568, 2569, 2570, 2571, 2572, 2577, 2581, 2582, 2583, 2589, 2590, 2591, 2596, 2598, 2599, 2600, 2601, 2603, 2604, 2605, 2606, 2610, 2612, 2613, 2615, 2617, 2619, 2620, 2622, 2624, 2626, 2629, 2630, 2631, 2632, 2633, 2634, 2637, 2638, 2640, 2641, 2642, 2643, 2644, 2645, 2648, 2650, 2652, 2653, 2655, 2656, 2657, 2658, 2659, 2660, 2662, 2664, 2666, 2667, 2669, 2672, 2673, 2674, 2675, 2677, 2678, 2679, 2680, 2682, 2683, 2684, 2685, 2690, 2693, 2695, 2696, 2697, 2699, 2700, 2701, 2702, 2703, 2704, 2705, 2706, 2708, 2709, 2710, 2711, 2713, 2714, 2715, 2718, 2720, 2721, 2722, 2724, 2725, 2727, 2730, 2731, 2733, 2735, 2737, 2738, 2741, 2742, 2745, 2746, 2748, 2750, 2751, 2753, 2755, 2758, 2759, 2761, 2762, 2763, 2764, 2768, 2769, 2770, 2771, 2773, 2776, 2779, 2780, 2781, 2787, 2788, 2789, 2790, 2792, 2793, 2794, 2795, 2796, 2797, 2800, 2801, 2806, 2809, 2810, 2811, 2812, 2813, 2814, 2817, 2818, 2819, 2827, 2828, 2829, 2830, 2831, 2832, 2835, 2836, 2838, 2840, 2841, 2843, 2849, 2850, 2851, 2854, 2858, 2861, 2863, 2865, 2870, 2872, 2873, 2874, 2877, 2878, 2880, 2882, 2883, 2884, 2886, 2887, 2889, 2890, 2891, 2894, 2896, 2897, 2898, 2900, 2901, 2905, 2906, 2907, 2911, 2914, 2915, 2918, 2919, 2920, 2921, 2922, 2923, 2926, 2927, 2928, 2932, 2934, 2940, 2943, 2944, 2945, 2946, 2947, 2948, 2950, 2953, 2954, 2955, 2956, 2957, 2958, 2959, 2960, 2965, 2967, 2968, 2969, 2970, 2971, 2972, 2973, 2975, 2977, 2978, 2979, 2980, 2981, 2982, 2985, 2987, 2988, 2989, 2991, 2992, 2995, 2996, 2997, 2998, 2999, 3003, 3005, 3008, 3009, 3010, 3013, 3016, 3017, 3020, 3021, 3025, 3026, 3028, 3030, 3031, 3032, 3036, 3037, 3043, 3045, 3046, 3047, 3050, 3051, 3053, 3058, 3059, 3063, 3064, 3066, 3067, 3068, 3069, 3070, 3071, 3073, 3074, 3077, 3078, 3080, 3082, 3083, 3084, 3085, 3090, 3092, 3095, 3101, 3102, 3103, 3104, 3108, 3110, 3111, 3113, 3117, 3119, 3120, 3121, 3123, 3124, 3125, 3126, 3127, 3128, 3129, 3131, 3132, 3133, 3134, 3135, 3137, 3138, 3142, 3143, 3148, 3152, 3154, 3155, 3157, 3159, 3160, 3162, 3163, 3167, 3168, 3169, 3170, 3171, 3172, 3173, 3174, 3177, 3178, 3180, 3181, 3182, 3184, 3185, 3186, 3188, 3189, 3191, 3192, 3195, 3197, 3199, 3202, 3203, 3205, 3206, 3207, 3209, 3210, 3213, 3215, 3216, 3217, 3218, 3219, 3222, 3224, 3226, 3228, 3229, 3230, 3231, 3232, 3233, 3234, 3235, 3238, 3239, 3240, 3241, 3244, 3245, 3246, 3248, 3250, 3251, 3254, 3255, 3256, 3257, 3260, 3261, 3262, 3263, 3264, 3265, 3270, 3271, 3272, 3273, 3275, 3276, 3278, 3279, 3281, 3283, 3286, 3288, 3289, 3290, 3293, 3296, 3297, 3298, 3300, 3301, 3302, 3303, 3304, 3305, 3308, 3311, 3312, 3314, 3316, 3318, 3319, 3320, 3323, 3325, 3327, 3329, 3330, 3333, 3338, 3341, 3344, 3345, 3349, 3356, 3357, 3358, 3359, 3360, 3361, 3363, 3364, 3366, 3369, 3376, 3378, 3379, 3380, 3382, 3387, 3388, 3389, 3391, 3392, 3394, 3398, 3402, 3403, 3404, 3405, 3407, 3409, 3410, 3412, 3413, 3414, 3415, 3417, 3419, 3420, 3423, 3424, 3430, 3431, 3432, 3433, 3435, 3436, 3437, 3438, 3439, 3441, 3443, 3444, 3446, 3447, 3452, 3453, 3456, 3457, 3461, 3462, 3463, 3464, 3465, 3466, 3471, 3473, 3475, 3477, 3478, 3479, 3481, 3482, 3486, 3487, 3488, 3489, 3490, 3492, 3493, 3495, 3496, 3498, 3501, 3502, 3503, 3505, 3507, 3508, 3509, 3510, 3512, 3513, 3514, 3515, 3516, 3517, 3522, 3524, 3525, 3529, 3531, 3533, 3535, 3536, 3538, 3540, 3541, 3545, 3546, 3548, 3549, 3550, 3551, 3552, 3555, 3557, 3558, 3560, 3561, 3562, 3563, 3564, 3565, 3566, 3567, 3570, 3571, 3572, 3573, 3574, 3575, 3576, 3577, 3578, 3579, 3580, 3581]
    is_zoom_down = module.weight.shape == (2048, 3584)
    is_zoom_up = module.weight.shape == (3584, 2048)

    if isinstance(module, nn.Linear):
        
        if is_zoom_down:
            new_weight = torch.zeros_like(module.weight)
            for i, dim_idx in enumerate(important_dims):
                new_weight[i, dim_idx] = 1.0
            module.weight.data = new_weight
        
        elif is_zoom_up:
            new_weight = torch.zeros_like(module.weight)
            for i, dim_idx in enumerate(important_dims):
                new_weight[dim_idx, i] = 1.0
            module.weight.data = new_weight
       
        else:
            exit()
            # import pdb
            # pdb.set_trace()
        
        if module.bias is not None and module.bias.requires_grad:
            module.bias.data.zero_()

    if type(module) == DebugLlamaRMSNorm:
        if module.weight.requires_grad:
            module.weight.data.fill_(1.0)


class CustomLayer(Qwen2DecoderLayer):
    def __init__(self, config: CustomConfig, layer_idx):
        super().__init__(config, layer_idx)
        self.config = config
        self.layer_idx = layer_idx

        if self.config.use_std_like_attn:
            raise ValueError("Low Performance")
        elif self.config.student_attn_from_scratch:
            raise NotImplementedError
        elif self.config.use_all_attn:
            # print(f"[arch] using all attn")
            self.self_attn = AllAttn(config, layer_idx)
        else:
            # print(f"[arch] using io attn")
            self.self_attn = Attn(config, layer_idx)
        if self.config.use_in_out_mlp:
            raise NotImplementedError
        else:
            # print(f"[arch] using all ffn")
            self.mlp = MLP(config, layer_idx)
        self.target_input_layernorm = DebugLlamaRMSNorm(config.target_hidden_size, eps=config.target_rms_norm_eps)
        self.target_post_attention_layernorm = DebugLlamaRMSNorm(config.target_hidden_size, eps=config.target_rms_norm_eps)

    def forward(
        self,
        hidden_states,
        compressed_hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        cache_position=None,
    ):
        residual = hidden_states
        compressed_residual = compressed_hidden_states

        hidden_states = self.input_layernorm(hidden_states)
        compressed_hidden_states = self.target_input_layernorm(compressed_hidden_states)

        # loss_dict = {}

        # Self Attention
        hidden_states, compressed_hidden_states, self_attn_weights, present_key_value, loss_dict = self.self_attn(
            hidden_states=hidden_states,
            compressed_hidden_states=compressed_hidden_states,
            loss_dict={},
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )

        hidden_states = residual + hidden_states
        compressed_hidden_states = compressed_hidden_states + compressed_residual

        # Fully Connected (MLP)
        residual = hidden_states
        compressed_residual = compressed_hidden_states

        hidden_states = self.post_attention_layernorm(hidden_states)
        compressed_hidden_states = self.target_post_attention_layernorm(compressed_hidden_states)

        hidden_states, compressed_hidden_states, loss_dict = self.mlp(hidden_states, compressed_hidden_states, loss_dict)

        hidden_states = residual + hidden_states
        compressed_hidden_states = compressed_hidden_states + compressed_residual

        # MLP end
        # print("debug 3", loss_dict)
        outputs = (hidden_states, compressed_hidden_states, loss_dict)

        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)

        return outputs
    
    def merge_weight(self):
        self.input_layernorm.weight.data = self.target_input_layernorm.weight.data
        self.post_attention_layernorm.weight.data = self.target_post_attention_layernorm.weight.data
    def init_norm(self):
       
        #important_indices = [0, 1, 2, 3, 8, 11, 12, 14, 16, 18, 19, 20, 22, 24, 26, 27, 28, 31, 33, 34, 35, 37, 38, 40, 42, 43, 46, 47, 49, 57, 58, 63, 64, 65, 66, 70, 74, 77, 78, 79, 80, 81, 84, 85, 87, 88, 89, 93, 99, 103, 105, 107, 108, 109, 110, 112, 119, 120, 121, 122, 123, 124, 125, 127, 128, 132, 133, 137, 138, 141, 142, 143, 145, 147, 148, 150, 151, 156, 158, 160, 161, 163, 165, 166, 167, 174, 176, 177, 180, 183, 184, 187, 191, 194, 195, 196, 197, 201, 203, 204, 205, 207, 212, 215, 216, 217, 218, 219, 221, 223, 224, 227, 230, 232, 233, 234, 235, 236, 237, 238, 239, 240, 242, 243, 245, 246, 247, 251, 254, 256, 260, 262, 263, 264, 265, 267, 272, 273, 274, 275, 277, 278, 279, 280, 281, 282, 287, 289, 291, 292, 294, 298, 300, 301, 303, 304, 307, 308, 310, 311, 312, 314, 319, 322, 323, 326, 328, 329, 331, 332, 333, 335, 338, 343, 348, 349, 350, 351, 353, 354, 355, 356, 357, 359, 360, 364, 366, 368, 370, 372, 373, 375, 376, 377, 379, 380, 384, 385, 388, 389, 393, 398, 399, 400, 403, 404, 409, 411, 412, 413, 415, 416, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 431, 433, 435, 436, 437, 438, 441, 443, 446, 447, 448, 453, 454, 456, 457, 458, 459, 460, 462, 466, 469, 472, 476, 478, 480, 481, 482, 484, 485, 486, 487, 488, 489, 492, 495, 496, 497, 500, 501, 507, 509, 511, 512, 513, 514, 517, 518, 519, 520, 522, 524, 525, 527, 528, 530, 531, 533, 535, 538, 539, 540, 542, 554, 555, 557, 559, 562, 563, 564, 567, 568, 570, 571, 573, 574, 575, 580, 581, 584, 585, 586, 588, 589, 591, 593, 595, 596, 600, 605, 606, 608, 611, 613, 615, 616, 617, 621, 624, 627, 628, 629, 630, 634, 642, 643, 647, 649, 650, 651, 653, 658, 661, 662, 663, 664, 666, 667, 668, 669, 671, 672, 674, 676, 677, 678, 685, 687, 688, 689, 690, 692, 696, 698, 700, 702, 703, 704, 705, 708, 710, 712, 713, 714, 717, 718, 719, 720, 723, 726, 727, 728, 730, 732, 733, 737, 738, 739, 742, 743, 744, 745, 746, 747, 749, 752, 753, 755, 756, 757, 758, 759, 760, 763, 764, 766, 767, 769, 770, 775, 776, 777, 778, 779, 780, 781, 785, 786, 787, 790, 793, 794, 795, 796, 798, 801, 803, 806, 808, 810, 811, 814, 816, 817, 820, 822, 823, 824, 825, 827, 830, 832, 834, 835, 836, 838, 839, 840, 842, 843, 844, 849, 851, 852, 853, 854, 855, 856, 859, 860, 864, 865, 866, 868, 869, 870, 873, 875, 876, 878, 881, 882, 883, 884, 887, 891, 893, 895, 896, 897, 898, 900, 901, 902, 906, 907, 909, 913, 914, 917, 919, 921, 922, 925, 927, 928, 929, 931, 932, 933, 934, 936, 938, 940, 945, 946, 952, 953, 954, 956, 957, 958, 960, 962, 963, 965, 967, 968, 972, 973, 976, 977, 981, 983, 984, 985, 987, 993, 994, 996, 997, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1016, 1018, 1023, 1024, 1026, 1027, 1031, 1032, 1034, 1035, 1037, 1041, 1043, 1044, 1045, 1047, 1048, 1049, 1050, 1052, 1053, 1055, 1056, 1057, 1058, 1060, 1069, 1072, 1076, 1077, 1079, 1082, 1083, 1084, 1085, 1086, 1087, 1088, 1090, 1095, 1097, 1099, 1102, 1103, 1104, 1105, 1106, 1109, 1110, 1111, 1113, 1114, 1115, 1119, 1122, 1123, 1129, 1131, 1133, 1136, 1139, 1144, 1145, 1146, 1147, 1151, 1154, 1155, 1156, 1158, 1161, 1162, 1166, 1167, 1170, 1171, 1172, 1175, 1176, 1177, 1179, 1181, 1184, 1185, 1188, 1189, 1194, 1195, 1197, 1200, 1203, 1205, 1208, 1209, 1210, 1212, 1215, 1216, 1217, 1218, 1220, 1221, 1223, 1224, 1225, 1226, 1227, 1228, 1229, 1230, 1234, 1235, 1236, 1237, 1238, 1239, 1241, 1243, 1244, 1249, 1251, 1252, 1253, 1254, 1255, 1259, 1264, 1266, 1268, 1270, 1271, 1272, 1277, 1278, 1279, 1280, 1281, 1283, 1285, 1286, 1288, 1290, 1291, 1292, 1293, 1295, 1299, 1302, 1304, 1305, 1306, 1307, 1308, 1310, 1313, 1314, 1315, 1316, 1319, 1320, 1326, 1332, 1333, 1335, 1336, 1342, 1345, 1346, 1347, 1348, 1349, 1354, 1355, 1356, 1358, 1361, 1362, 1363, 1364, 1367, 1368, 1369, 1372, 1374, 1375, 1376, 1378, 1379, 1380, 1382, 1389, 1391, 1392, 1394, 1395, 1398, 1399, 1401, 1402, 1403, 1404, 1407, 1408, 1409, 1410, 1416, 1417, 1422, 1424, 1425, 1428, 1431, 1433, 1435, 1437, 1439, 1440, 1443, 1444, 1447, 1448, 1453, 1454, 1456, 1457, 1458, 1459, 1465, 1466, 1467, 1470, 1471, 1474, 1476, 1477, 1479, 1480, 1482, 1484, 1485, 1487, 1488, 1493, 1499, 1500, 1503, 1505, 1507, 1508, 1510, 1515, 1519, 1522, 1523, 1524, 1525, 1526, 1527, 1528, 1529, 1531, 1533, 1534, 1535, 1536, 1537, 1539, 1540, 1541, 1542, 1543, 1544, 1545, 1546, 1552, 1553, 1554, 1557, 1558, 1559, 1560, 1562, 1563, 1564, 1566, 1568, 1571, 1572, 1573, 1576, 1578, 1579, 1582, 1584, 1585, 1586, 1589, 1591, 1595, 1597, 1602, 1603, 1604, 1605, 1608, 1609, 1610, 1614, 1616, 1617, 1619, 1622, 1624, 1627, 1629, 1630, 1632, 1634, 1638, 1639, 1640, 1641, 1643, 1644, 1646, 1647, 1648, 1649, 1650, 1651, 1653, 1654, 1655, 1657, 1659, 1670, 1673, 1674, 1675, 1677, 1678, 1680, 1681, 1683, 1685, 1686, 1687, 1688, 1690, 1692, 1693, 1694, 1696, 1698, 1705, 1709, 1710, 1711, 1712, 1713, 1714, 1715, 1716, 1717, 1719, 1722, 1724, 1726, 1727, 1728, 1729, 1730, 1731, 1739, 1742, 1743, 1744, 1745, 1746, 1747, 1755, 1756, 1760, 1761, 1762, 1765, 1774, 1777, 1778, 1781, 1784, 1786, 1788, 1789, 1791, 1792, 1793, 1794, 1795, 1796, 1797, 1798, 1800, 1803, 1804, 1805, 1806, 1807, 1808, 1809, 1810, 1812, 1814, 1815, 1817, 1820, 1823, 1825, 1826, 1827, 1830, 1831, 1837, 1840, 1842, 1844, 1845, 1847, 1848, 1851, 1852, 1853, 1857, 1859, 1861, 1862, 1866, 1867, 1874, 1878, 1879, 1882, 1886, 1887, 1888, 1889, 1894, 1895, 1897, 1899, 1901, 1902, 1903, 1904, 1905, 1906, 1909, 1910, 1911, 1914, 1915, 1916, 1917, 1924, 1925, 1926, 1927, 1928, 1930, 1932, 1934, 1935, 1940, 1941, 1943, 1947, 1948, 1950, 1951, 1955, 1956, 1958, 1964, 1965, 1967, 1968, 1973, 1974, 1976, 1978, 1980, 1981, 1982, 1983, 1986, 1988, 1991, 1992, 1994, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2007, 2008, 2010, 2011, 2015, 2016, 2018, 2019, 2021, 2022, 2024, 2026, 2027, 2029, 2030, 2033, 2035, 2036, 2041, 2042, 2044, 2045, 2046, 2047, 2048, 2049, 2052, 2053, 2054, 2056, 2059, 2065, 2066, 2067, 2071, 2072, 2073, 2076, 2079, 2080, 2083, 2086, 2087, 2090, 2092, 2094, 2095, 2096, 2098, 2099, 2100, 2101, 2102, 2104, 2107, 2109, 2112, 2113, 2115, 2117, 2119, 2121, 2122, 2124, 2125, 2126, 2127, 2130, 2132, 2136, 2139, 2141, 2143, 2145, 2147, 2151, 2152, 2154, 2156, 2161, 2165, 2166, 2168, 2171, 2172, 2176, 2177, 2178, 2181, 2183, 2184, 2185, 2188, 2189, 2192, 2193, 2195, 2196, 2197, 2198, 2200, 2202, 2203, 2204, 2205, 2209, 2210, 2212, 2214, 2216, 2217, 2218, 2219, 2223, 2224, 2225, 2227, 2228, 2231, 2233, 2234, 2236, 2237, 2239, 2240, 2243, 2244, 2245, 2247, 2250, 2254, 2255, 2257, 2258, 2262, 2263, 2264, 2265, 2267, 2268, 2270, 2271, 2273, 2275, 2279, 2280, 2287, 2288, 2289, 2292, 2294, 2297, 2299, 2303, 2305, 2307, 2308, 2312, 2317, 2320, 2322, 2326, 2327, 2328, 2329, 2330, 2331, 2336, 2338, 2339, 2344, 2345, 2346, 2348, 2352, 2354, 2355, 2356, 2357, 2358, 2359, 2360, 2362, 2363, 2364, 2366, 2369, 2370, 2371, 2372, 2373, 2374, 2378, 2379, 2380, 2381, 2383, 2386, 2388, 2392, 2393, 2395, 2398, 2400, 2403, 2407, 2411, 2412, 2420, 2423, 2425, 2428, 2430, 2432, 2433, 2434, 2435, 2437, 2444, 2445, 2446, 2448, 2452, 2453, 2454, 2458, 2461, 2462, 2465, 2469, 2472, 2473, 2474, 2476, 2478, 2480, 2482, 2484, 2486, 2487, 2489, 2493, 2497, 2498, 2502, 2507, 2509, 2510, 2513, 2515, 2517, 2528, 2530, 2531, 2535, 2536, 2537, 2540, 2541, 2544, 2551, 2552, 2558, 2560, 2563, 2574, 2576, 2577, 2579, 2581, 2582, 2584, 2585, 2587, 2588, 2589, 2590, 2593, 2599, 2602, 2603, 2607, 2609, 2610, 2612, 2613, 2614, 2615, 2618, 2619, 2621, 2626, 2629, 2630, 2632, 2634, 2637, 2638, 2647, 2648, 2655, 2656, 2658, 2661, 2663, 2668, 2669, 2670, 2671, 2673, 2676, 2678, 2679, 2680, 2681, 2682, 2683, 2685, 2686, 2688, 2689, 2690, 2692, 2696, 2697, 2698, 2701, 2702, 2703, 2704, 2707, 2709, 2712, 2713, 2714, 2715, 2717, 2720, 2721, 2722, 2723, 2724, 2726, 2727, 2734, 2739, 2741, 2748, 2749, 2750, 2752, 2758, 2759, 2760, 2761, 2764, 2766, 2771, 2773, 2774, 2775, 2776, 2777, 2778, 2779, 2781, 2783, 2784, 2786, 2787, 2789, 2791, 2792, 2795, 2797, 2799, 2800, 2801, 2802, 2803, 2805, 2807, 2808, 2809, 2814, 2817, 2822, 2823, 2824, 2825, 2834, 2837, 2839, 2841, 2844, 2848, 2849, 2852, 2854, 2856, 2858, 2859, 2860, 2864, 2868, 2873, 2874, 2875, 2876, 2878, 2879, 2881, 2883, 2884, 2885, 2888, 2892, 2895, 2896, 2898, 2902, 2905, 2912, 2914, 2917, 2920, 2921, 2922, 2923, 2924, 2926, 2927, 2930, 2931, 2932, 2934, 2936, 2937, 2938, 2939, 2940, 2941, 2943, 2945, 2946, 2949, 2952, 2953, 2954, 2962, 2963, 2965, 2966, 2970, 2971, 2973, 2986, 2988, 2990, 2992, 2993, 2994, 2995, 2998, 2999, 3000, 3004, 3005, 3006, 3008, 3009, 3010, 3018, 3020, 3026, 3027, 3029, 3031, 3033, 3034, 3036, 3037, 3039, 3046, 3049, 3050, 3052, 3053, 3057, 3061, 3062, 3064, 3065, 3066, 3067, 3069, 3070]
        important_indices = [3, 5, 8, 9, 11, 13, 14, 17, 19, 20, 23, 24, 26, 27, 28, 29, 31, 32, 33, 36, 37, 39, 46, 49, 51, 52, 53, 55, 59, 60, 62, 66, 67, 68, 69, 70, 71, 73, 76, 77, 78, 83, 84, 85, 89, 95, 96, 97, 99, 100, 101, 104, 107, 108, 109, 111, 112, 113, 114, 115, 116, 119, 120, 121, 124, 125, 127, 128, 129, 130, 131, 132, 134, 135, 136, 138, 140, 143, 144, 145, 152, 154, 155, 156, 157, 160, 161, 162, 163, 164, 165, 166, 167, 168, 171, 172, 173, 177, 179, 180, 183, 184, 186, 187, 188, 189, 190, 191, 192, 193, 198, 199, 200, 201, 203, 206, 207, 209, 211, 212, 216, 217, 219, 220, 221, 222, 226, 227, 228, 229, 233, 234, 236, 238, 239, 244, 245, 252, 253, 255, 257, 258, 259, 265, 266, 268, 270, 273, 274, 277, 278, 280, 282, 284, 286, 289, 291, 292, 293, 294, 295, 296, 297, 300, 301, 302, 306, 307, 308, 310, 311, 314, 316, 319, 321, 322, 323, 324, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 339, 340, 341, 342, 343, 344, 348, 351, 355, 358, 359, 362, 363, 364, 365, 366, 368, 369, 373, 374, 376, 377, 382, 385, 387, 388, 389, 392, 393, 394, 396, 397, 399, 400, 402, 404, 405, 409, 410, 418, 420, 421, 422, 423, 424, 425, 428, 429, 431, 432, 433, 435, 436, 438, 439, 440, 444, 445, 446, 451, 453, 456, 457, 458, 459, 461, 463, 464, 466, 467, 468, 469, 470, 471, 475, 476, 477, 481, 482, 484, 489, 490, 495, 497, 500, 501, 502, 503, 504, 505, 507, 509, 510, 515, 516, 517, 518, 520, 523, 525, 528, 529, 534, 536, 537, 538, 540, 541, 542, 543, 544, 547, 548, 550, 553, 554, 555, 556, 557, 559, 561, 562, 563, 564, 566, 569, 571, 572, 573, 574, 575, 576, 578, 581, 583, 585, 586, 587, 589, 591, 595, 596, 597, 598, 599, 601, 602, 604, 607, 608, 612, 613, 616, 617, 621, 622, 623, 624, 625, 628, 630, 631, 632, 634, 635, 637, 639, 640, 641, 645, 649, 650, 651, 653, 654, 656, 659, 660, 661, 662, 666, 667, 670, 671, 677, 680, 682, 684, 686, 687, 688, 689, 694, 695, 696, 698, 700, 703, 704, 706, 707, 710, 711, 712, 713, 714, 715, 716, 718, 720, 725, 727, 728, 732, 734, 736, 737, 738, 740, 742, 743, 747, 749, 750, 752, 755, 756, 759, 761, 762, 763, 765, 767, 768, 770, 772, 773, 775, 776, 779, 780, 781, 783, 784, 786, 787, 791, 792, 793, 795, 796, 799, 800, 801, 802, 805, 806, 807, 808, 809, 810, 811, 816, 817, 819, 820, 821, 822, 824, 825, 826, 827, 828, 829, 830, 834, 835, 836, 841, 842, 844, 845, 846, 847, 848, 850, 851, 852, 853, 854, 855, 856, 857, 859, 861, 862, 864, 865, 867, 868, 871, 876, 877, 879, 882, 883, 888, 889, 890, 892, 893, 895, 896, 898, 899, 901, 902, 903, 909, 912, 913, 915, 918, 928, 931, 933, 935, 936, 938, 939, 940, 941, 942, 943, 944, 945, 947, 948, 951, 952, 953, 954, 956, 957, 958, 959, 960, 963, 964, 965, 966, 967, 968, 972, 980, 981, 982, 983, 984, 985, 986, 987, 989, 991, 992, 994, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1007, 1009, 1012, 1013, 1014, 1017, 1018, 1019, 1020, 1021, 1023, 1024, 1026, 1028, 1029, 1033, 1034, 1035, 1038, 1039, 1040, 1041, 1043, 1045, 1047, 1050, 1051, 1053, 1056, 1058, 1059, 1060, 1061, 1062, 1064, 1065, 1066, 1067, 1068, 1069, 1071, 1074, 1076, 1078, 1079, 1081, 1082, 1085, 1086, 1087, 1088, 1089, 1090, 1091, 1092, 1093, 1095, 1099, 1101, 1102, 1103, 1104, 1106, 1107, 1110, 1111, 1114, 1116, 1117, 1118, 1121, 1122, 1123, 1124, 1126, 1128, 1129, 1130, 1131, 1133, 1134, 1136, 1138, 1139, 1140, 1142, 1144, 1145, 1146, 1147, 1149, 1150, 1152, 1154, 1157, 1158, 1160, 1162, 1163, 1166, 1167, 1169, 1174, 1175, 1177, 1180, 1182, 1188, 1189, 1190, 1191, 1195, 1196, 1197, 1199, 1203, 1205, 1207, 1208, 1209, 1210, 1211, 1212, 1213, 1219, 1220, 1221, 1222, 1224, 1226, 1230, 1234, 1238, 1239, 1248, 1249, 1257, 1259, 1260, 1261, 1265, 1266, 1268, 1272, 1273, 1275, 1277, 1278, 1279, 1280, 1281, 1282, 1283, 1285, 1287, 1288, 1290, 1291, 1292, 1294, 1295, 1297, 1298, 1299, 1300, 1301, 1302, 1303, 1304, 1306, 1309, 1311, 1313, 1314, 1317, 1319, 1321, 1323, 1324, 1328, 1329, 1331, 1332, 1333, 1339, 1342, 1343, 1345, 1347, 1348, 1351, 1352, 1353, 1354, 1355, 1358, 1359, 1361, 1368, 1369, 1370, 1374, 1377, 1379, 1380, 1381, 1383, 1384, 1385, 1386, 1387, 1388, 1392, 1393, 1394, 1395, 1396, 1397, 1402, 1404, 1405, 1407, 1410, 1411, 1412, 1413, 1414, 1418, 1419, 1420, 1421, 1422, 1423, 1424, 1427, 1428, 1429, 1431, 1432, 1433, 1438, 1439, 1443, 1445, 1451, 1452, 1455, 1456, 1457, 1458, 1459, 1461, 1462, 1464, 1465, 1466, 1467, 1468, 1470, 1473, 1474, 1475, 1478, 1479, 1481, 1482, 1484, 1487, 1488, 1489, 1491, 1494, 1502, 1503, 1507, 1511, 1512, 1514, 1515, 1520, 1526, 1527, 1528, 1529, 1530, 1531, 1533, 1534, 1542, 1545, 1547, 1549, 1550, 1551, 1552, 1553, 1554, 1555, 1556, 1557, 1558, 1559, 1560, 1562, 1563, 1568, 1571, 1572, 1573, 1574, 1576, 1577, 1579, 1581, 1584, 1585, 1586, 1588, 1589, 1590, 1592, 1593, 1594, 1595, 1597, 1598, 1599, 1601, 1604, 1605, 1606, 1607, 1609, 1610, 1611, 1614, 1615, 1617, 1619, 1620, 1621, 1623, 1625, 1626, 1627, 1628, 1631, 1632, 1633, 1634, 1635, 1638, 1639, 1640, 1642, 1644, 1645, 1647, 1648, 1650, 1653, 1655, 1656, 1657, 1660, 1661, 1663, 1666, 1667, 1670, 1672, 1674, 1675, 1676, 1677, 1679, 1680, 1682, 1683, 1684, 1685, 1686, 1688, 1690, 1691, 1692, 1693, 1694, 1697, 1699, 1700, 1704, 1705, 1706, 1707, 1708, 1710, 1711, 1712, 1716, 1717, 1718, 1719, 1720, 1721, 1722, 1723, 1724, 1725, 1730, 1731, 1733, 1735, 1737, 1738, 1740, 1741, 1742, 1743, 1744, 1745, 1746, 1748, 1750, 1751, 1752, 1753, 1755, 1758, 1759, 1760, 1761, 1762, 1763, 1767, 1769, 1776, 1777, 1778, 1780, 1781, 1783, 1785, 1786, 1788, 1790, 1793, 1794, 1795, 1799, 1801, 1803, 1805, 1806, 1811, 1815, 1816, 1817, 1820, 1821, 1823, 1825, 1827, 1828, 1829, 1830, 1831, 1832, 1833, 1835, 1836, 1839, 1840, 1841, 1843, 1844, 1845, 1847, 1848, 1849, 1850, 1851, 1852, 1853, 1854, 1855, 1856, 1857, 1860, 1864, 1865, 1866, 1868, 1869, 1871, 1872, 1874, 1876, 1878, 1881, 1882, 1883, 1885, 1888, 1890, 1891, 1892, 1893, 1894, 1895, 1898, 1900, 1902, 1903, 1904, 1907, 1908, 1909, 1911, 1912, 1915, 1916, 1917, 1918, 1919, 1922, 1923, 1925, 1926, 1927, 1928, 1929, 1930, 1932, 1934, 1936, 1940, 1941, 1942, 1943, 1945, 1949, 1951, 1952, 1955, 1957, 1960, 1962, 1963, 1967, 1968, 1969, 1971, 1976, 1977, 1980, 1981, 1982, 1984, 1985, 1987, 1989, 1990, 1991, 1993, 1995, 1996, 1997, 1999, 2004, 2006, 2007, 2010, 2012, 2014, 2016, 2019, 2022, 2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2032, 2033, 2035, 2036, 2037, 2039, 2040, 2041, 2042, 2044, 2045, 2047, 2048, 2050, 2051, 2053, 2054, 2056, 2059, 2060, 2061, 2063, 2066, 2069, 2070, 2071, 2072, 2073, 2076, 2077, 2078, 2080, 2082, 2084, 2085, 2086, 2088, 2089, 2090, 2091, 2094, 2095, 2096, 2097, 2099, 2100, 2102, 2103, 2104, 2106, 2107, 2109, 2111, 2116, 2117, 2118, 2123, 2126, 2127, 2128, 2130, 2132, 2133, 2136, 2139, 2141, 2143, 2145, 2147, 2148, 2149, 2150, 2152, 2153, 2154, 2157, 2158, 2159, 2161, 2163, 2164, 2167, 2170, 2171, 2172, 2174, 2175, 2177, 2178, 2179, 2180, 2182, 2189, 2190, 2191, 2192, 2193, 2194, 2195, 2196, 2198, 2199, 2200, 2202, 2204, 2205, 2207, 2208, 2209, 2210, 2211, 2212, 2213, 2216, 2218, 2219, 2223, 2225, 2226, 2228, 2230, 2231, 2232, 2234, 2235, 2236, 2240, 2242, 2244, 2245, 2246, 2247, 2249, 2251, 2254, 2255, 2257, 2258, 2259, 2263, 2264, 2266, 2268, 2269, 2271, 2272, 2273, 2275, 2276, 2277, 2278, 2279, 2280, 2281, 2282, 2283, 2284, 2285, 2286, 2287, 2288, 2289, 2290, 2291, 2294, 2299, 2300, 2301, 2302, 2303, 2304, 2305, 2307, 2309, 2311, 2313, 2316, 2318, 2320, 2322, 2323, 2326, 2327, 2328, 2329, 2330, 2331, 2332, 2333, 2334, 2335, 2340, 2341, 2343, 2345, 2346, 2348, 2353, 2354, 2355, 2358, 2359, 2360, 2364, 2365, 2366, 2367, 2368, 2369, 2370, 2371, 2372, 2373, 2374, 2376, 2383, 2384, 2385, 2386, 2387, 2391, 2392, 2394, 2396, 2401, 2402, 2403, 2404, 2405, 2406, 2407, 2408, 2410, 2411, 2413, 2415, 2416, 2418, 2419, 2425, 2426, 2427, 2429, 2431, 2433, 2434, 2435, 2438, 2439, 2440, 2441, 2442, 2444, 2445, 2448, 2449, 2450, 2451, 2452, 2455, 2461, 2464, 2465, 2467, 2470, 2474, 2475, 2476, 2477, 2478, 2479, 2480, 2481, 2483, 2484, 2486, 2488, 2489, 2491, 2494, 2495, 2496, 2498, 2501, 2503, 2510, 2511, 2516, 2518, 2519, 2520, 2523, 2524, 2525, 2527, 2529, 2531, 2533, 2534, 2537, 2539, 2541, 2542, 2543, 2544, 2545, 2547, 2549, 2550, 2551, 2552, 2557, 2558, 2559, 2560, 2561, 2563, 2565, 2566, 2567, 2568, 2569, 2570, 2571, 2572, 2577, 2581, 2582, 2583, 2589, 2590, 2591, 2596, 2598, 2599, 2600, 2601, 2603, 2604, 2605, 2606, 2610, 2612, 2613, 2615, 2617, 2619, 2620, 2622, 2624, 2626, 2629, 2630, 2631, 2632, 2633, 2634, 2637, 2638, 2640, 2641, 2642, 2643, 2644, 2645, 2648, 2650, 2652, 2653, 2655, 2656, 2657, 2658, 2659, 2660, 2662, 2664, 2666, 2667, 2669, 2672, 2673, 2674, 2675, 2677, 2678, 2679, 2680, 2682, 2683, 2684, 2685, 2690, 2693, 2695, 2696, 2697, 2699, 2700, 2701, 2702, 2703, 2704, 2705, 2706, 2708, 2709, 2710, 2711, 2713, 2714, 2715, 2718, 2720, 2721, 2722, 2724, 2725, 2727, 2730, 2731, 2733, 2735, 2737, 2738, 2741, 2742, 2745, 2746, 2748, 2750, 2751, 2753, 2755, 2758, 2759, 2761, 2762, 2763, 2764, 2768, 2769, 2770, 2771, 2773, 2776, 2779, 2780, 2781, 2787, 2788, 2789, 2790, 2792, 2793, 2794, 2795, 2796, 2797, 2800, 2801, 2806, 2809, 2810, 2811, 2812, 2813, 2814, 2817, 2818, 2819, 2827, 2828, 2829, 2830, 2831, 2832, 2835, 2836, 2838, 2840, 2841, 2843, 2849, 2850, 2851, 2854, 2858, 2861, 2863, 2865, 2870, 2872, 2873, 2874, 2877, 2878, 2880, 2882, 2883, 2884, 2886, 2887, 2889, 2890, 2891, 2894, 2896, 2897, 2898, 2900, 2901, 2905, 2906, 2907, 2911, 2914, 2915, 2918, 2919, 2920, 2921, 2922, 2923, 2926, 2927, 2928, 2932, 2934, 2940, 2943, 2944, 2945, 2946, 2947, 2948, 2950, 2953, 2954, 2955, 2956, 2957, 2958, 2959, 2960, 2965, 2967, 2968, 2969, 2970, 2971, 2972, 2973, 2975, 2977, 2978, 2979, 2980, 2981, 2982, 2985, 2987, 2988, 2989, 2991, 2992, 2995, 2996, 2997, 2998, 2999, 3003, 3005, 3008, 3009, 3010, 3013, 3016, 3017, 3020, 3021, 3025, 3026, 3028, 3030, 3031, 3032, 3036, 3037, 3043, 3045, 3046, 3047, 3050, 3051, 3053, 3058, 3059, 3063, 3064, 3066, 3067, 3068, 3069, 3070, 3071, 3073, 3074, 3077, 3078, 3080, 3082, 3083, 3084, 3085, 3090, 3092, 3095, 3101, 3102, 3103, 3104, 3108, 3110, 3111, 3113, 3117, 3119, 3120, 3121, 3123, 3124, 3125, 3126, 3127, 3128, 3129, 3131, 3132, 3133, 3134, 3135, 3137, 3138, 3142, 3143, 3148, 3152, 3154, 3155, 3157, 3159, 3160, 3162, 3163, 3167, 3168, 3169, 3170, 3171, 3172, 3173, 3174, 3177, 3178, 3180, 3181, 3182, 3184, 3185, 3186, 3188, 3189, 3191, 3192, 3195, 3197, 3199, 3202, 3203, 3205, 3206, 3207, 3209, 3210, 3213, 3215, 3216, 3217, 3218, 3219, 3222, 3224, 3226, 3228, 3229, 3230, 3231, 3232, 3233, 3234, 3235, 3238, 3239, 3240, 3241, 3244, 3245, 3246, 3248, 3250, 3251, 3254, 3255, 3256, 3257, 3260, 3261, 3262, 3263, 3264, 3265, 3270, 3271, 3272, 3273, 3275, 3276, 3278, 3279, 3281, 3283, 3286, 3288, 3289, 3290, 3293, 3296, 3297, 3298, 3300, 3301, 3302, 3303, 3304, 3305, 3308, 3311, 3312, 3314, 3316, 3318, 3319, 3320, 3323, 3325, 3327, 3329, 3330, 3333, 3338, 3341, 3344, 3345, 3349, 3356, 3357, 3358, 3359, 3360, 3361, 3363, 3364, 3366, 3369, 3376, 3378, 3379, 3380, 3382, 3387, 3388, 3389, 3391, 3392, 3394, 3398, 3402, 3403, 3404, 3405, 3407, 3409, 3410, 3412, 3413, 3414, 3415, 3417, 3419, 3420, 3423, 3424, 3430, 3431, 3432, 3433, 3435, 3436, 3437, 3438, 3439, 3441, 3443, 3444, 3446, 3447, 3452, 3453, 3456, 3457, 3461, 3462, 3463, 3464, 3465, 3466, 3471, 3473, 3475, 3477, 3478, 3479, 3481, 3482, 3486, 3487, 3488, 3489, 3490, 3492, 3493, 3495, 3496, 3498, 3501, 3502, 3503, 3505, 3507, 3508, 3509, 3510, 3512, 3513, 3514, 3515, 3516, 3517, 3522, 3524, 3525, 3529, 3531, 3533, 3535, 3536, 3538, 3540, 3541, 3545, 3546, 3548, 3549, 3550, 3551, 3552, 3555, 3557, 3558, 3560, 3561, 3562, 3563, 3564, 3565, 3566, 3567, 3570, 3571, 3572, 3573, 3574, 3575, 3576, 3577, 3578, 3579, 3580, 3581]
       
        if len(important_indices) != self.config.target_hidden_size:
            raise ValueError(
           
                f"dim not macth"
            )

      
        device = self.input_layernorm.weight.device
        indices_tensor = torch.tensor(important_indices, dtype=torch.long, device=device)

     
        with torch.no_grad():
            # 1. 初始化 target_input_layernorm
            selected_weights_input = self.input_layernorm.weight.data.index_select(0, indices_tensor)
            self.target_input_layernorm.weight.data.copy_(selected_weights_input)

            # 2. 初始化 target_post_attention_layernorm
            selected_weights_post = self.post_attention_layernorm.weight.data.index_select(0, indices_tensor)
            self.target_post_attention_layernorm.weight.data.copy_(selected_weights_post)


@dataclass
class IIModelOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    compressed_hidden_states: torch.FloatTensor = None
    aux_loss: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class Model(Qwen2Model):
    _no_split_modules = ["CustomLayer"]

    def __init__(self, config: CustomConfig):
        super().__init__(config)

        self.zoom = nn.Linear(config.hidden_size, config.target_hidden_size, bias=False)
        self.layers = nn.ModuleList(
            [CustomLayer(config, layer_idx) if layer_idx not in config.del_layers else Qwen2DecoderLayer(config, layer_idx)
             for layer_idx in range(config.num_hidden_layers)]
        )
        self.target_norm = DebugLlamaRMSNorm(config.target_hidden_size, eps=config.target_rms_norm_eps)
        self.cur_step = 0

        self.post_init()

    def merge_weight(self):
        self.embed_tokens.weight.data = self.zoom(self.embed_tokens.weight.data).contiguous()
        self.norm.weight.data = self.target_norm.weight.data
    def init_norm(self):
      
        #important_indices = [0, 1, 2, 3, 8, 11, 12, 14, 16, 18, 19, 20, 22, 24, 26, 27, 28, 31, 33, 34, 35, 37, 38, 40, 42, 43, 46, 47, 49, 57, 58, 63, 64, 65, 66, 70, 74, 77, 78, 79, 80, 81, 84, 85, 87, 88, 89, 93, 99, 103, 105, 107, 108, 109, 110, 112, 119, 120, 121, 122, 123, 124, 125, 127, 128, 132, 133, 137, 138, 141, 142, 143, 145, 147, 148, 150, 151, 156, 158, 160, 161, 163, 165, 166, 167, 174, 176, 177, 180, 183, 184, 187, 191, 194, 195, 196, 197, 201, 203, 204, 205, 207, 212, 215, 216, 217, 218, 219, 221, 223, 224, 227, 230, 232, 233, 234, 235, 236, 237, 238, 239, 240, 242, 243, 245, 246, 247, 251, 254, 256, 260, 262, 263, 264, 265, 267, 272, 273, 274, 275, 277, 278, 279, 280, 281, 282, 287, 289, 291, 292, 294, 298, 300, 301, 303, 304, 307, 308, 310, 311, 312, 314, 319, 322, 323, 326, 328, 329, 331, 332, 333, 335, 338, 343, 348, 349, 350, 351, 353, 354, 355, 356, 357, 359, 360, 364, 366, 368, 370, 372, 373, 375, 376, 377, 379, 380, 384, 385, 388, 389, 393, 398, 399, 400, 403, 404, 409, 411, 412, 413, 415, 416, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 431, 433, 435, 436, 437, 438, 441, 443, 446, 447, 448, 453, 454, 456, 457, 458, 459, 460, 462, 466, 469, 472, 476, 478, 480, 481, 482, 484, 485, 486, 487, 488, 489, 492, 495, 496, 497, 500, 501, 507, 509, 511, 512, 513, 514, 517, 518, 519, 520, 522, 524, 525, 527, 528, 530, 531, 533, 535, 538, 539, 540, 542, 554, 555, 557, 559, 562, 563, 564, 567, 568, 570, 571, 573, 574, 575, 580, 581, 584, 585, 586, 588, 589, 591, 593, 595, 596, 600, 605, 606, 608, 611, 613, 615, 616, 617, 621, 624, 627, 628, 629, 630, 634, 642, 643, 647, 649, 650, 651, 653, 658, 661, 662, 663, 664, 666, 667, 668, 669, 671, 672, 674, 676, 677, 678, 685, 687, 688, 689, 690, 692, 696, 698, 700, 702, 703, 704, 705, 708, 710, 712, 713, 714, 717, 718, 719, 720, 723, 726, 727, 728, 730, 732, 733, 737, 738, 739, 742, 743, 744, 745, 746, 747, 749, 752, 753, 755, 756, 757, 758, 759, 760, 763, 764, 766, 767, 769, 770, 775, 776, 777, 778, 779, 780, 781, 785, 786, 787, 790, 793, 794, 795, 796, 798, 801, 803, 806, 808, 810, 811, 814, 816, 817, 820, 822, 823, 824, 825, 827, 830, 832, 834, 835, 836, 838, 839, 840, 842, 843, 844, 849, 851, 852, 853, 854, 855, 856, 859, 860, 864, 865, 866, 868, 869, 870, 873, 875, 876, 878, 881, 882, 883, 884, 887, 891, 893, 895, 896, 897, 898, 900, 901, 902, 906, 907, 909, 913, 914, 917, 919, 921, 922, 925, 927, 928, 929, 931, 932, 933, 934, 936, 938, 940, 945, 946, 952, 953, 954, 956, 957, 958, 960, 962, 963, 965, 967, 968, 972, 973, 976, 977, 981, 983, 984, 985, 987, 993, 994, 996, 997, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1016, 1018, 1023, 1024, 1026, 1027, 1031, 1032, 1034, 1035, 1037, 1041, 1043, 1044, 1045, 1047, 1048, 1049, 1050, 1052, 1053, 1055, 1056, 1057, 1058, 1060, 1069, 1072, 1076, 1077, 1079, 1082, 1083, 1084, 1085, 1086, 1087, 1088, 1090, 1095, 1097, 1099, 1102, 1103, 1104, 1105, 1106, 1109, 1110, 1111, 1113, 1114, 1115, 1119, 1122, 1123, 1129, 1131, 1133, 1136, 1139, 1144, 1145, 1146, 1147, 1151, 1154, 1155, 1156, 1158, 1161, 1162, 1166, 1167, 1170, 1171, 1172, 1175, 1176, 1177, 1179, 1181, 1184, 1185, 1188, 1189, 1194, 1195, 1197, 1200, 1203, 1205, 1208, 1209, 1210, 1212, 1215, 1216, 1217, 1218, 1220, 1221, 1223, 1224, 1225, 1226, 1227, 1228, 1229, 1230, 1234, 1235, 1236, 1237, 1238, 1239, 1241, 1243, 1244, 1249, 1251, 1252, 1253, 1254, 1255, 1259, 1264, 1266, 1268, 1270, 1271, 1272, 1277, 1278, 1279, 1280, 1281, 1283, 1285, 1286, 1288, 1290, 1291, 1292, 1293, 1295, 1299, 1302, 1304, 1305, 1306, 1307, 1308, 1310, 1313, 1314, 1315, 1316, 1319, 1320, 1326, 1332, 1333, 1335, 1336, 1342, 1345, 1346, 1347, 1348, 1349, 1354, 1355, 1356, 1358, 1361, 1362, 1363, 1364, 1367, 1368, 1369, 1372, 1374, 1375, 1376, 1378, 1379, 1380, 1382, 1389, 1391, 1392, 1394, 1395, 1398, 1399, 1401, 1402, 1403, 1404, 1407, 1408, 1409, 1410, 1416, 1417, 1422, 1424, 1425, 1428, 1431, 1433, 1435, 1437, 1439, 1440, 1443, 1444, 1447, 1448, 1453, 1454, 1456, 1457, 1458, 1459, 1465, 1466, 1467, 1470, 1471, 1474, 1476, 1477, 1479, 1480, 1482, 1484, 1485, 1487, 1488, 1493, 1499, 1500, 1503, 1505, 1507, 1508, 1510, 1515, 1519, 1522, 1523, 1524, 1525, 1526, 1527, 1528, 1529, 1531, 1533, 1534, 1535, 1536, 1537, 1539, 1540, 1541, 1542, 1543, 1544, 1545, 1546, 1552, 1553, 1554, 1557, 1558, 1559, 1560, 1562, 1563, 1564, 1566, 1568, 1571, 1572, 1573, 1576, 1578, 1579, 1582, 1584, 1585, 1586, 1589, 1591, 1595, 1597, 1602, 1603, 1604, 1605, 1608, 1609, 1610, 1614, 1616, 1617, 1619, 1622, 1624, 1627, 1629, 1630, 1632, 1634, 1638, 1639, 1640, 1641, 1643, 1644, 1646, 1647, 1648, 1649, 1650, 1651, 1653, 1654, 1655, 1657, 1659, 1670, 1673, 1674, 1675, 1677, 1678, 1680, 1681, 1683, 1685, 1686, 1687, 1688, 1690, 1692, 1693, 1694, 1696, 1698, 1705, 1709, 1710, 1711, 1712, 1713, 1714, 1715, 1716, 1717, 1719, 1722, 1724, 1726, 1727, 1728, 1729, 1730, 1731, 1739, 1742, 1743, 1744, 1745, 1746, 1747, 1755, 1756, 1760, 1761, 1762, 1765, 1774, 1777, 1778, 1781, 1784, 1786, 1788, 1789, 1791, 1792, 1793, 1794, 1795, 1796, 1797, 1798, 1800, 1803, 1804, 1805, 1806, 1807, 1808, 1809, 1810, 1812, 1814, 1815, 1817, 1820, 1823, 1825, 1826, 1827, 1830, 1831, 1837, 1840, 1842, 1844, 1845, 1847, 1848, 1851, 1852, 1853, 1857, 1859, 1861, 1862, 1866, 1867, 1874, 1878, 1879, 1882, 1886, 1887, 1888, 1889, 1894, 1895, 1897, 1899, 1901, 1902, 1903, 1904, 1905, 1906, 1909, 1910, 1911, 1914, 1915, 1916, 1917, 1924, 1925, 1926, 1927, 1928, 1930, 1932, 1934, 1935, 1940, 1941, 1943, 1947, 1948, 1950, 1951, 1955, 1956, 1958, 1964, 1965, 1967, 1968, 1973, 1974, 1976, 1978, 1980, 1981, 1982, 1983, 1986, 1988, 1991, 1992, 1994, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2007, 2008, 2010, 2011, 2015, 2016, 2018, 2019, 2021, 2022, 2024, 2026, 2027, 2029, 2030, 2033, 2035, 2036, 2041, 2042, 2044, 2045, 2046, 2047, 2048, 2049, 2052, 2053, 2054, 2056, 2059, 2065, 2066, 2067, 2071, 2072, 2073, 2076, 2079, 2080, 2083, 2086, 2087, 2090, 2092, 2094, 2095, 2096, 2098, 2099, 2100, 2101, 2102, 2104, 2107, 2109, 2112, 2113, 2115, 2117, 2119, 2121, 2122, 2124, 2125, 2126, 2127, 2130, 2132, 2136, 2139, 2141, 2143, 2145, 2147, 2151, 2152, 2154, 2156, 2161, 2165, 2166, 2168, 2171, 2172, 2176, 2177, 2178, 2181, 2183, 2184, 2185, 2188, 2189, 2192, 2193, 2195, 2196, 2197, 2198, 2200, 2202, 2203, 2204, 2205, 2209, 2210, 2212, 2214, 2216, 2217, 2218, 2219, 2223, 2224, 2225, 2227, 2228, 2231, 2233, 2234, 2236, 2237, 2239, 2240, 2243, 2244, 2245, 2247, 2250, 2254, 2255, 2257, 2258, 2262, 2263, 2264, 2265, 2267, 2268, 2270, 2271, 2273, 2275, 2279, 2280, 2287, 2288, 2289, 2292, 2294, 2297, 2299, 2303, 2305, 2307, 2308, 2312, 2317, 2320, 2322, 2326, 2327, 2328, 2329, 2330, 2331, 2336, 2338, 2339, 2344, 2345, 2346, 2348, 2352, 2354, 2355, 2356, 2357, 2358, 2359, 2360, 2362, 2363, 2364, 2366, 2369, 2370, 2371, 2372, 2373, 2374, 2378, 2379, 2380, 2381, 2383, 2386, 2388, 2392, 2393, 2395, 2398, 2400, 2403, 2407, 2411, 2412, 2420, 2423, 2425, 2428, 2430, 2432, 2433, 2434, 2435, 2437, 2444, 2445, 2446, 2448, 2452, 2453, 2454, 2458, 2461, 2462, 2465, 2469, 2472, 2473, 2474, 2476, 2478, 2480, 2482, 2484, 2486, 2487, 2489, 2493, 2497, 2498, 2502, 2507, 2509, 2510, 2513, 2515, 2517, 2528, 2530, 2531, 2535, 2536, 2537, 2540, 2541, 2544, 2551, 2552, 2558, 2560, 2563, 2574, 2576, 2577, 2579, 2581, 2582, 2584, 2585, 2587, 2588, 2589, 2590, 2593, 2599, 2602, 2603, 2607, 2609, 2610, 2612, 2613, 2614, 2615, 2618, 2619, 2621, 2626, 2629, 2630, 2632, 2634, 2637, 2638, 2647, 2648, 2655, 2656, 2658, 2661, 2663, 2668, 2669, 2670, 2671, 2673, 2676, 2678, 2679, 2680, 2681, 2682, 2683, 2685, 2686, 2688, 2689, 2690, 2692, 2696, 2697, 2698, 2701, 2702, 2703, 2704, 2707, 2709, 2712, 2713, 2714, 2715, 2717, 2720, 2721, 2722, 2723, 2724, 2726, 2727, 2734, 2739, 2741, 2748, 2749, 2750, 2752, 2758, 2759, 2760, 2761, 2764, 2766, 2771, 2773, 2774, 2775, 2776, 2777, 2778, 2779, 2781, 2783, 2784, 2786, 2787, 2789, 2791, 2792, 2795, 2797, 2799, 2800, 2801, 2802, 2803, 2805, 2807, 2808, 2809, 2814, 2817, 2822, 2823, 2824, 2825, 2834, 2837, 2839, 2841, 2844, 2848, 2849, 2852, 2854, 2856, 2858, 2859, 2860, 2864, 2868, 2873, 2874, 2875, 2876, 2878, 2879, 2881, 2883, 2884, 2885, 2888, 2892, 2895, 2896, 2898, 2902, 2905, 2912, 2914, 2917, 2920, 2921, 2922, 2923, 2924, 2926, 2927, 2930, 2931, 2932, 2934, 2936, 2937, 2938, 2939, 2940, 2941, 2943, 2945, 2946, 2949, 2952, 2953, 2954, 2962, 2963, 2965, 2966, 2970, 2971, 2973, 2986, 2988, 2990, 2992, 2993, 2994, 2995, 2998, 2999, 3000, 3004, 3005, 3006, 3008, 3009, 3010, 3018, 3020, 3026, 3027, 3029, 3031, 3033, 3034, 3036, 3037, 3039, 3046, 3049, 3050, 3052, 3053, 3057, 3061, 3062, 3064, 3065, 3066, 3067, 3069, 3070]
        important_indices = [3, 5, 8, 9, 11, 13, 14, 17, 19, 20, 23, 24, 26, 27, 28, 29, 31, 32, 33, 36, 37, 39, 46, 49, 51, 52, 53, 55, 59, 60, 62, 66, 67, 68, 69, 70, 71, 73, 76, 77, 78, 83, 84, 85, 89, 95, 96, 97, 99, 100, 101, 104, 107, 108, 109, 111, 112, 113, 114, 115, 116, 119, 120, 121, 124, 125, 127, 128, 129, 130, 131, 132, 134, 135, 136, 138, 140, 143, 144, 145, 152, 154, 155, 156, 157, 160, 161, 162, 163, 164, 165, 166, 167, 168, 171, 172, 173, 177, 179, 180, 183, 184, 186, 187, 188, 189, 190, 191, 192, 193, 198, 199, 200, 201, 203, 206, 207, 209, 211, 212, 216, 217, 219, 220, 221, 222, 226, 227, 228, 229, 233, 234, 236, 238, 239, 244, 245, 252, 253, 255, 257, 258, 259, 265, 266, 268, 270, 273, 274, 277, 278, 280, 282, 284, 286, 289, 291, 292, 293, 294, 295, 296, 297, 300, 301, 302, 306, 307, 308, 310, 311, 314, 316, 319, 321, 322, 323, 324, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 339, 340, 341, 342, 343, 344, 348, 351, 355, 358, 359, 362, 363, 364, 365, 366, 368, 369, 373, 374, 376, 377, 382, 385, 387, 388, 389, 392, 393, 394, 396, 397, 399, 400, 402, 404, 405, 409, 410, 418, 420, 421, 422, 423, 424, 425, 428, 429, 431, 432, 433, 435, 436, 438, 439, 440, 444, 445, 446, 451, 453, 456, 457, 458, 459, 461, 463, 464, 466, 467, 468, 469, 470, 471, 475, 476, 477, 481, 482, 484, 489, 490, 495, 497, 500, 501, 502, 503, 504, 505, 507, 509, 510, 515, 516, 517, 518, 520, 523, 525, 528, 529, 534, 536, 537, 538, 540, 541, 542, 543, 544, 547, 548, 550, 553, 554, 555, 556, 557, 559, 561, 562, 563, 564, 566, 569, 571, 572, 573, 574, 575, 576, 578, 581, 583, 585, 586, 587, 589, 591, 595, 596, 597, 598, 599, 601, 602, 604, 607, 608, 612, 613, 616, 617, 621, 622, 623, 624, 625, 628, 630, 631, 632, 634, 635, 637, 639, 640, 641, 645, 649, 650, 651, 653, 654, 656, 659, 660, 661, 662, 666, 667, 670, 671, 677, 680, 682, 684, 686, 687, 688, 689, 694, 695, 696, 698, 700, 703, 704, 706, 707, 710, 711, 712, 713, 714, 715, 716, 718, 720, 725, 727, 728, 732, 734, 736, 737, 738, 740, 742, 743, 747, 749, 750, 752, 755, 756, 759, 761, 762, 763, 765, 767, 768, 770, 772, 773, 775, 776, 779, 780, 781, 783, 784, 786, 787, 791, 792, 793, 795, 796, 799, 800, 801, 802, 805, 806, 807, 808, 809, 810, 811, 816, 817, 819, 820, 821, 822, 824, 825, 826, 827, 828, 829, 830, 834, 835, 836, 841, 842, 844, 845, 846, 847, 848, 850, 851, 852, 853, 854, 855, 856, 857, 859, 861, 862, 864, 865, 867, 868, 871, 876, 877, 879, 882, 883, 888, 889, 890, 892, 893, 895, 896, 898, 899, 901, 902, 903, 909, 912, 913, 915, 918, 928, 931, 933, 935, 936, 938, 939, 940, 941, 942, 943, 944, 945, 947, 948, 951, 952, 953, 954, 956, 957, 958, 959, 960, 963, 964, 965, 966, 967, 968, 972, 980, 981, 982, 983, 984, 985, 986, 987, 989, 991, 992, 994, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1007, 1009, 1012, 1013, 1014, 1017, 1018, 1019, 1020, 1021, 1023, 1024, 1026, 1028, 1029, 1033, 1034, 1035, 1038, 1039, 1040, 1041, 1043, 1045, 1047, 1050, 1051, 1053, 1056, 1058, 1059, 1060, 1061, 1062, 1064, 1065, 1066, 1067, 1068, 1069, 1071, 1074, 1076, 1078, 1079, 1081, 1082, 1085, 1086, 1087, 1088, 1089, 1090, 1091, 1092, 1093, 1095, 1099, 1101, 1102, 1103, 1104, 1106, 1107, 1110, 1111, 1114, 1116, 1117, 1118, 1121, 1122, 1123, 1124, 1126, 1128, 1129, 1130, 1131, 1133, 1134, 1136, 1138, 1139, 1140, 1142, 1144, 1145, 1146, 1147, 1149, 1150, 1152, 1154, 1157, 1158, 1160, 1162, 1163, 1166, 1167, 1169, 1174, 1175, 1177, 1180, 1182, 1188, 1189, 1190, 1191, 1195, 1196, 1197, 1199, 1203, 1205, 1207, 1208, 1209, 1210, 1211, 1212, 1213, 1219, 1220, 1221, 1222, 1224, 1226, 1230, 1234, 1238, 1239, 1248, 1249, 1257, 1259, 1260, 1261, 1265, 1266, 1268, 1272, 1273, 1275, 1277, 1278, 1279, 1280, 1281, 1282, 1283, 1285, 1287, 1288, 1290, 1291, 1292, 1294, 1295, 1297, 1298, 1299, 1300, 1301, 1302, 1303, 1304, 1306, 1309, 1311, 1313, 1314, 1317, 1319, 1321, 1323, 1324, 1328, 1329, 1331, 1332, 1333, 1339, 1342, 1343, 1345, 1347, 1348, 1351, 1352, 1353, 1354, 1355, 1358, 1359, 1361, 1368, 1369, 1370, 1374, 1377, 1379, 1380, 1381, 1383, 1384, 1385, 1386, 1387, 1388, 1392, 1393, 1394, 1395, 1396, 1397, 1402, 1404, 1405, 1407, 1410, 1411, 1412, 1413, 1414, 1418, 1419, 1420, 1421, 1422, 1423, 1424, 1427, 1428, 1429, 1431, 1432, 1433, 1438, 1439, 1443, 1445, 1451, 1452, 1455, 1456, 1457, 1458, 1459, 1461, 1462, 1464, 1465, 1466, 1467, 1468, 1470, 1473, 1474, 1475, 1478, 1479, 1481, 1482, 1484, 1487, 1488, 1489, 1491, 1494, 1502, 1503, 1507, 1511, 1512, 1514, 1515, 1520, 1526, 1527, 1528, 1529, 1530, 1531, 1533, 1534, 1542, 1545, 1547, 1549, 1550, 1551, 1552, 1553, 1554, 1555, 1556, 1557, 1558, 1559, 1560, 1562, 1563, 1568, 1571, 1572, 1573, 1574, 1576, 1577, 1579, 1581, 1584, 1585, 1586, 1588, 1589, 1590, 1592, 1593, 1594, 1595, 1597, 1598, 1599, 1601, 1604, 1605, 1606, 1607, 1609, 1610, 1611, 1614, 1615, 1617, 1619, 1620, 1621, 1623, 1625, 1626, 1627, 1628, 1631, 1632, 1633, 1634, 1635, 1638, 1639, 1640, 1642, 1644, 1645, 1647, 1648, 1650, 1653, 1655, 1656, 1657, 1660, 1661, 1663, 1666, 1667, 1670, 1672, 1674, 1675, 1676, 1677, 1679, 1680, 1682, 1683, 1684, 1685, 1686, 1688, 1690, 1691, 1692, 1693, 1694, 1697, 1699, 1700, 1704, 1705, 1706, 1707, 1708, 1710, 1711, 1712, 1716, 1717, 1718, 1719, 1720, 1721, 1722, 1723, 1724, 1725, 1730, 1731, 1733, 1735, 1737, 1738, 1740, 1741, 1742, 1743, 1744, 1745, 1746, 1748, 1750, 1751, 1752, 1753, 1755, 1758, 1759, 1760, 1761, 1762, 1763, 1767, 1769, 1776, 1777, 1778, 1780, 1781, 1783, 1785, 1786, 1788, 1790, 1793, 1794, 1795, 1799, 1801, 1803, 1805, 1806, 1811, 1815, 1816, 1817, 1820, 1821, 1823, 1825, 1827, 1828, 1829, 1830, 1831, 1832, 1833, 1835, 1836, 1839, 1840, 1841, 1843, 1844, 1845, 1847, 1848, 1849, 1850, 1851, 1852, 1853, 1854, 1855, 1856, 1857, 1860, 1864, 1865, 1866, 1868, 1869, 1871, 1872, 1874, 1876, 1878, 1881, 1882, 1883, 1885, 1888, 1890, 1891, 1892, 1893, 1894, 1895, 1898, 1900, 1902, 1903, 1904, 1907, 1908, 1909, 1911, 1912, 1915, 1916, 1917, 1918, 1919, 1922, 1923, 1925, 1926, 1927, 1928, 1929, 1930, 1932, 1934, 1936, 1940, 1941, 1942, 1943, 1945, 1949, 1951, 1952, 1955, 1957, 1960, 1962, 1963, 1967, 1968, 1969, 1971, 1976, 1977, 1980, 1981, 1982, 1984, 1985, 1987, 1989, 1990, 1991, 1993, 1995, 1996, 1997, 1999, 2004, 2006, 2007, 2010, 2012, 2014, 2016, 2019, 2022, 2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2032, 2033, 2035, 2036, 2037, 2039, 2040, 2041, 2042, 2044, 2045, 2047, 2048, 2050, 2051, 2053, 2054, 2056, 2059, 2060, 2061, 2063, 2066, 2069, 2070, 2071, 2072, 2073, 2076, 2077, 2078, 2080, 2082, 2084, 2085, 2086, 2088, 2089, 2090, 2091, 2094, 2095, 2096, 2097, 2099, 2100, 2102, 2103, 2104, 2106, 2107, 2109, 2111, 2116, 2117, 2118, 2123, 2126, 2127, 2128, 2130, 2132, 2133, 2136, 2139, 2141, 2143, 2145, 2147, 2148, 2149, 2150, 2152, 2153, 2154, 2157, 2158, 2159, 2161, 2163, 2164, 2167, 2170, 2171, 2172, 2174, 2175, 2177, 2178, 2179, 2180, 2182, 2189, 2190, 2191, 2192, 2193, 2194, 2195, 2196, 2198, 2199, 2200, 2202, 2204, 2205, 2207, 2208, 2209, 2210, 2211, 2212, 2213, 2216, 2218, 2219, 2223, 2225, 2226, 2228, 2230, 2231, 2232, 2234, 2235, 2236, 2240, 2242, 2244, 2245, 2246, 2247, 2249, 2251, 2254, 2255, 2257, 2258, 2259, 2263, 2264, 2266, 2268, 2269, 2271, 2272, 2273, 2275, 2276, 2277, 2278, 2279, 2280, 2281, 2282, 2283, 2284, 2285, 2286, 2287, 2288, 2289, 2290, 2291, 2294, 2299, 2300, 2301, 2302, 2303, 2304, 2305, 2307, 2309, 2311, 2313, 2316, 2318, 2320, 2322, 2323, 2326, 2327, 2328, 2329, 2330, 2331, 2332, 2333, 2334, 2335, 2340, 2341, 2343, 2345, 2346, 2348, 2353, 2354, 2355, 2358, 2359, 2360, 2364, 2365, 2366, 2367, 2368, 2369, 2370, 2371, 2372, 2373, 2374, 2376, 2383, 2384, 2385, 2386, 2387, 2391, 2392, 2394, 2396, 2401, 2402, 2403, 2404, 2405, 2406, 2407, 2408, 2410, 2411, 2413, 2415, 2416, 2418, 2419, 2425, 2426, 2427, 2429, 2431, 2433, 2434, 2435, 2438, 2439, 2440, 2441, 2442, 2444, 2445, 2448, 2449, 2450, 2451, 2452, 2455, 2461, 2464, 2465, 2467, 2470, 2474, 2475, 2476, 2477, 2478, 2479, 2480, 2481, 2483, 2484, 2486, 2488, 2489, 2491, 2494, 2495, 2496, 2498, 2501, 2503, 2510, 2511, 2516, 2518, 2519, 2520, 2523, 2524, 2525, 2527, 2529, 2531, 2533, 2534, 2537, 2539, 2541, 2542, 2543, 2544, 2545, 2547, 2549, 2550, 2551, 2552, 2557, 2558, 2559, 2560, 2561, 2563, 2565, 2566, 2567, 2568, 2569, 2570, 2571, 2572, 2577, 2581, 2582, 2583, 2589, 2590, 2591, 2596, 2598, 2599, 2600, 2601, 2603, 2604, 2605, 2606, 2610, 2612, 2613, 2615, 2617, 2619, 2620, 2622, 2624, 2626, 2629, 2630, 2631, 2632, 2633, 2634, 2637, 2638, 2640, 2641, 2642, 2643, 2644, 2645, 2648, 2650, 2652, 2653, 2655, 2656, 2657, 2658, 2659, 2660, 2662, 2664, 2666, 2667, 2669, 2672, 2673, 2674, 2675, 2677, 2678, 2679, 2680, 2682, 2683, 2684, 2685, 2690, 2693, 2695, 2696, 2697, 2699, 2700, 2701, 2702, 2703, 2704, 2705, 2706, 2708, 2709, 2710, 2711, 2713, 2714, 2715, 2718, 2720, 2721, 2722, 2724, 2725, 2727, 2730, 2731, 2733, 2735, 2737, 2738, 2741, 2742, 2745, 2746, 2748, 2750, 2751, 2753, 2755, 2758, 2759, 2761, 2762, 2763, 2764, 2768, 2769, 2770, 2771, 2773, 2776, 2779, 2780, 2781, 2787, 2788, 2789, 2790, 2792, 2793, 2794, 2795, 2796, 2797, 2800, 2801, 2806, 2809, 2810, 2811, 2812, 2813, 2814, 2817, 2818, 2819, 2827, 2828, 2829, 2830, 2831, 2832, 2835, 2836, 2838, 2840, 2841, 2843, 2849, 2850, 2851, 2854, 2858, 2861, 2863, 2865, 2870, 2872, 2873, 2874, 2877, 2878, 2880, 2882, 2883, 2884, 2886, 2887, 2889, 2890, 2891, 2894, 2896, 2897, 2898, 2900, 2901, 2905, 2906, 2907, 2911, 2914, 2915, 2918, 2919, 2920, 2921, 2922, 2923, 2926, 2927, 2928, 2932, 2934, 2940, 2943, 2944, 2945, 2946, 2947, 2948, 2950, 2953, 2954, 2955, 2956, 2957, 2958, 2959, 2960, 2965, 2967, 2968, 2969, 2970, 2971, 2972, 2973, 2975, 2977, 2978, 2979, 2980, 2981, 2982, 2985, 2987, 2988, 2989, 2991, 2992, 2995, 2996, 2997, 2998, 2999, 3003, 3005, 3008, 3009, 3010, 3013, 3016, 3017, 3020, 3021, 3025, 3026, 3028, 3030, 3031, 3032, 3036, 3037, 3043, 3045, 3046, 3047, 3050, 3051, 3053, 3058, 3059, 3063, 3064, 3066, 3067, 3068, 3069, 3070, 3071, 3073, 3074, 3077, 3078, 3080, 3082, 3083, 3084, 3085, 3090, 3092, 3095, 3101, 3102, 3103, 3104, 3108, 3110, 3111, 3113, 3117, 3119, 3120, 3121, 3123, 3124, 3125, 3126, 3127, 3128, 3129, 3131, 3132, 3133, 3134, 3135, 3137, 3138, 3142, 3143, 3148, 3152, 3154, 3155, 3157, 3159, 3160, 3162, 3163, 3167, 3168, 3169, 3170, 3171, 3172, 3173, 3174, 3177, 3178, 3180, 3181, 3182, 3184, 3185, 3186, 3188, 3189, 3191, 3192, 3195, 3197, 3199, 3202, 3203, 3205, 3206, 3207, 3209, 3210, 3213, 3215, 3216, 3217, 3218, 3219, 3222, 3224, 3226, 3228, 3229, 3230, 3231, 3232, 3233, 3234, 3235, 3238, 3239, 3240, 3241, 3244, 3245, 3246, 3248, 3250, 3251, 3254, 3255, 3256, 3257, 3260, 3261, 3262, 3263, 3264, 3265, 3270, 3271, 3272, 3273, 3275, 3276, 3278, 3279, 3281, 3283, 3286, 3288, 3289, 3290, 3293, 3296, 3297, 3298, 3300, 3301, 3302, 3303, 3304, 3305, 3308, 3311, 3312, 3314, 3316, 3318, 3319, 3320, 3323, 3325, 3327, 3329, 3330, 3333, 3338, 3341, 3344, 3345, 3349, 3356, 3357, 3358, 3359, 3360, 3361, 3363, 3364, 3366, 3369, 3376, 3378, 3379, 3380, 3382, 3387, 3388, 3389, 3391, 3392, 3394, 3398, 3402, 3403, 3404, 3405, 3407, 3409, 3410, 3412, 3413, 3414, 3415, 3417, 3419, 3420, 3423, 3424, 3430, 3431, 3432, 3433, 3435, 3436, 3437, 3438, 3439, 3441, 3443, 3444, 3446, 3447, 3452, 3453, 3456, 3457, 3461, 3462, 3463, 3464, 3465, 3466, 3471, 3473, 3475, 3477, 3478, 3479, 3481, 3482, 3486, 3487, 3488, 3489, 3490, 3492, 3493, 3495, 3496, 3498, 3501, 3502, 3503, 3505, 3507, 3508, 3509, 3510, 3512, 3513, 3514, 3515, 3516, 3517, 3522, 3524, 3525, 3529, 3531, 3533, 3535, 3536, 3538, 3540, 3541, 3545, 3546, 3548, 3549, 3550, 3551, 3552, 3555, 3557, 3558, 3560, 3561, 3562, 3563, 3564, 3565, 3566, 3567, 3570, 3571, 3572, 3573, 3574, 3575, 3576, 3577, 3578, 3579, 3580, 3581]
     
        if len(important_indices) != self.target_norm.weight.shape[0]:
            raise ValueError(
            
                f"dim not macth"
            )

       
        device = self.norm.weight.device

  
        indices_tensor = torch.tensor(important_indices, dtype=torch.long, device=device)

  
        selected_weights = self.norm.weight.data.index_select(0, indices_tensor)

        
        with torch.no_grad():
            self.target_norm.weight.data.copy_(selected_weights)
       

     
        for i, layer in enumerate(self.layers):
         
            if isinstance(layer, CustomLayer):
               
             
                layer.init_norm()
            else:
                pass
        

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            raise NotImplementedError

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # causal_mask = self._update_causal_mask(
        #     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        # )
        assert attention_mask is None

        # embed positions
        hidden_states = inputs_embeds
        Wemb = self.zoom(self.embed_tokens.weight).to(device=input_ids.device)
        if os.environ.get("DEBUG", False):
            print("emb token", Wemb[0, :6])
        compressed_hidden_states = embedding(input_ids, Wemb)
        # assert not torch.isnan(compressed_hidden_states).any(), f"NaN detected in model output in af emb"

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None
        aux_loss = 0

        # set state for logging loss
        grad_acumulation_steps = hyper_params["gradient_accumulation_steps"]
        cur_train_step = None
        if (self.cur_step + 1) % (grad_acumulation_steps * 20) == 0:
            cur_train_step = (self.cur_step + 1) // grad_acumulation_steps
        self.cur_step += 1
        
        for layer_idx, decoder_layer in enumerate(self.layers):
            if self.gradient_checkpointing and self.training:
                raise NotImplementedError
            
            if layer_idx not in self.config.del_layers:
                layer_outputs = decoder_layer(
                    hidden_states,
                    compressed_hidden_states,
                    attention_mask=None,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

                compressed_hidden_states = layer_outputs[1]
                loss_dict = layer_outputs[2]
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=None,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            
            hidden_states = layer_outputs[0]

            if layer_idx not in self.config.del_layers:
                _log_dict = {}
                for k, v in loss_dict.items():
                    if self.config.use_aux_loss:
                        if isinstance(aux_loss, torch.Tensor):
                            aux_loss = aux_loss.to(v.device)
                        aux_loss = aux_loss + v * hyper_params["aux_loss_scale_factor"]
                    main_logger.debug(f"L{decoder_layer.layer_idx}-{k}: {v.item()}")
                    
                    if cur_train_step:
                        _log_dict[f"L{decoder_layer.layer_idx}-{k}"] = v.item()
                
                if cur_train_step and (os.environ.get("LOCAL_RANK", 0) == 0 or accelerator.is_main_process) and len(_log_dict) > 0:
                    wandb.log(_log_dict, cur_train_step)

        hidden_states = self.norm(hidden_states)
        compressed_hidden_states = self.target_norm(compressed_hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, None, all_hidden_states, all_self_attns] if v is not None)
        
        return IIModelOutput(
            last_hidden_state=hidden_states,
            compressed_hidden_states=compressed_hidden_states,
            aux_loss=aux_loss,
            past_key_values=None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
    


def calculate_language_loss(lgts, labels, vocab_size):
    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = lgts[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)
    return loss


class CoTrainLM(Qwen2ForCausalLM):
    _tied_weights_keys = ["lm_head.weight"]
    def __init__(self, config: CustomConfig):
        super().__init__(config)
        self.model = Model(config)
        # self.zoom_up = nn.Linear(config.target_hidden_size, config.hidden_size, bias=False)
        if not config.tie_word_emb_proj:
            self.zoom_down = nn.Linear(config.hidden_size, config.target_hidden_size, bias=False)
            self.zoom_down.weight.data.normal_(mean=0.0, std=0.01)  # no init weights
        self.mseloss = LOSS_DICT[config.aux_loss_type]()
        self.kl_temperature = self.config.kl_temperature
        self.cur_step = 0
        self.cur_loss_accumulation = 0
        self.cur_logit_loss_accumulation = 0
        self.check_data_cls_loss = config.check_data_cls_loss
        self.data_cls_losses = [0] * 8
        self.data_cls_cnt = [0] * 8
        self.post_init()

    def merge_weight(self):
        # print(self.lm_head.weight.data.shape)
        if not self.config.tie_word_emb_proj:
            self.lm_head.weight.data = self.zoom_down(self.lm_head.weight.data).contiguous()
        else:
            self.lm_head.weight.data = self.model.zoom(self.lm_head.weight.data).contiguous()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        data_cls=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        compressed_hidden_states = outputs[1]
        aux_loss = outputs[2]

        logits = self.lm_head(hidden_states)
        if not self.config.tie_word_emb_proj:
            Whead = self.zoom_down(self.lm_head.weight)
        else:
            Whead = self.model.zoom(self.lm_head.weight)
        if os.environ.get("DEBUG", False):
            print("head weight", Whead[0, :6])
        target_logits = linear(compressed_hidden_states, Whead)

        if self.config.use_logits_loss:
            target_logp = F.log_softmax(target_logits / self.kl_temperature, dim=-1)
            raw_logp = F.log_softmax(logits / self.kl_temperature, dim=-1)
            logits_loss = F.kl_div(target_logp, raw_logp, log_target=True, reduction="batchmean")
            # logits_loss = self.mseloss(target_logits, logits)
            aux_loss = aux_loss + logits_loss
            main_logger.debug(f"logits_loss: {round(logits_loss.item(), 4)}")
        
        raw_loss = calculate_language_loss(logits.float(), labels, self.config.vocab_size)
        target_loss = calculate_language_loss(target_logits.float(), labels, self.config.vocab_size)
        main_logger.debug(f"raw_loss: {round(raw_loss.item(), 4)}, target_loss: {round(target_loss.item(), 4)}")

        # wandb log
        self.cur_loss_accumulation += target_loss.item()
        if self.config.use_logits_loss:
            self.cur_logit_loss_accumulation += logits_loss.item()
        loss_log_steps = hyper_params["gradient_accumulation_steps"] * 5
        if self.check_data_cls_loss:
            assert hidden_states.shape[0] == 1, "only appliable in bs = 1"
            spec_cls = data_cls[0].item()
            self.data_cls_cnt[spec_cls] += 1
            self.data_cls_losses[spec_cls] += target_loss.item()
        if (self.cur_step + 1) % loss_log_steps == 0:
            cur_train_step = (self.cur_step + 1) // hyper_params["gradient_accumulation_steps"]
            _log_dict = {"target_loss": self.cur_loss_accumulation / loss_log_steps}
            if self.config.use_logits_loss:
                _log_dict["logits_loss"] = self.cur_logit_loss_accumulation / loss_log_steps
            # self.kl_temperature = 0.9 * self.kl_temperature + 0.1 * _log_dict["target_loss"] * 1.5
            if self.check_data_cls_loss:
                _log_dict.update({
                    f"{data_cls_reversed_dict[i]}_loss": loss / self.data_cls_cnt[i] 
                    for i, loss in enumerate(self.data_cls_losses) if self.data_cls_cnt[i] > 0
                })
            if (os.environ.get("LOCAL_RANK", 0) == 0 or accelerator.is_main_process):
                wandb.log(_log_dict, step=cur_train_step)
            self.cur_loss_accumulation = 0
            self.cur_logit_loss_accumulation = 0
            self.data_cls_cnt = [0] * 8
            self.data_cls_losses = [0] * 8
        self.cur_step += 1

        if not return_dict:
            raise NotImplementedError

        return CausalLMOutputWithPast(
            loss=target_loss + aux_loss if self.config.use_ntp_loss else aux_loss,
            logits=target_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def freeze_original_model(self):
        key_words = ["zoom", "target"]
        for n, p in self.named_parameters():
            flag = False

            for key in key_words:
                if key in n:
                    flag = True

            p.requires_grad = flag

    def tie_custom_weights(self, tie_n):
        raise ValueError("low perf")
        layers = self.model.layers
        for i in range(2, self.config.num_hidden_layers - 1, tie_n):
            share_layer: CustomLayer = layers[i]
            for j in range(i + 1, min(self.config.num_hidden_layers - 1, i + tie_n)):
                cur_layer: CustomLayer = layers[j]
                cur_layer.mlp.zoom.weight = share_layer.mlp.zoom.weight
                cur_layer.self_attn.zoom_up.weight = share_layer.self_attn.zoom_up.weight
                cur_layer.self_attn.zoom_down.weight = share_layer.self_attn.zoom_down.weight

    def tie_word_emb_proj(self):
        # self.model.zoom.weight = self.zoom_down.weight
        self.zoom_down.weight = self.model.zoom.weight

    def get_trained_params(self):
        state_dict = {}
        for n, p in self.named_parameters():
            if p.requires_grad:
                state_dict[n] = p
        return state_dict

    def save_pretrained(self, *args, **kwargs):
        if kwargs.get("only_save_trainable", True):
            state_dict = self.get_trained_params()
            kwargs["state_dict"] = state_dict
        return super().save_pretrained(*args, **kwargs)
