import torch
import wandb
import os
import transformers

# The classes below subclass transformers' Qwen2 modules (e.g. Qwen2FlashAttention2)
# and call into their internals (self.rotary_emb, self._flash_attention_forward,
# Cache.get_usable_length), so this file is pinned to one exact release. An explicit
# check is used instead of `assert`, which is stripped when Python runs with -O.
_REQUIRED_TRANSFORMERS_VERSION = "4.41.2"
if transformers.__version__ != _REQUIRED_TRANSFORMERS_VERSION:
    raise ImportError(
        f"this module requires transformers=={_REQUIRED_TRANSFORMERS_VERSION}, "
        f"found {transformers.__version__}"
    )

from torch import nn
from torch.nn.functional import linear, embedding
from transformers.models.qwen2.modeling_qwen2 import *
from transformers.modeling_outputs import ModelOutput
from tools.log import main_logger
from dataclasses import dataclass
from tools.global_state import hyper_params, data_cls_reversed_dict, ban_losses, ban_layers
from accelerate import Accelerator


# Module-level accelerate.Accelerator, created as a side effect of importing this
# module. NOTE(review): it is not referenced in this part of the file; presumably it
# is used further down or by importers — confirm before removing or moving it.
accelerator = Accelerator()


class BigValueFirstLoss(nn.Module):
    """Squared error weighted by ``|target + 0.01|``, averaged over every element.

    Elements whose target has a larger magnitude contribute proportionally more.
    Because of the +0.01 shift, a target of exactly 0 still gets weight 0.01.
    """

    def __init__(self):
        super().__init__()
        self.mseloss = nn.MSELoss(reduction="none")

    def forward(self, output, target):
        weights = (target + 1e-2).abs()
        per_element = self.mseloss(output, target)
        return (weights * per_element).mean()


class MSELossV2(nn.Module):
    """Squared error summed over the last dimension, then averaged over the rest."""

    def __init__(self):
        super().__init__()
        self.mseloss = nn.MSELoss(reduction="none")

    def forward(self, output, target):
        squared = self.mseloss(output, target)
        per_vector = squared.sum(dim=-1)
        return per_vector.mean()


class L1LossV2(nn.Module):
    """Absolute error summed over the last dimension and divided by 50, then averaged."""

    def forward(self, output, target):
        abs_err = torch.abs(output - target)
        per_vector = abs_err.sum(dim=-1) / 50.0
        return torch.mean(per_vector)


# Registry of auxiliary-loss classes selectable by name (CustomConfig.aux_loss_type).
# Values are classes, not instances; callers instantiate them with no arguments.
LOSS_DICT = dict(
    mseloss=nn.MSELoss,
    mseloss_v2=MSELossV2,
    l1loss=nn.L1Loss,
    l1loss_v2=L1LossV2,
    # big_value_first=BigValueFirstLoss,
)


class CustomConfig(Qwen2Config):
    """Qwen2Config plus the options that drive the compressed student branch."""

    def set_custom_kwargs(self, **kwargs):
        """Copy the custom options from ``kwargs`` onto this config as attributes.

        ``target_hidden_size`` is mandatory and raises KeyError when absent (before
        any other attribute is set). Every other option falls back to the default in
        the table below; ``target_rms_norm_eps`` defaults to this config's own
        ``rms_norm_eps``.
        """
        # Mandatory option, looked up first.
        self.target_hidden_size = kwargs["target_hidden_size"]

        optional_defaults = {
            "use_attn_map": False,
            "target_rms_norm_eps": self.rms_norm_eps,
            "use_aux_loss": True,
            "use_std_like_attn": False,
            "use_logits_loss": True,
            "use_ntp_loss": True,
            "check_data_cls_loss": False,
            "kl_temperature": 10.0,
            "aux_loss_type": "mseloss",
            "student_attn_from_scratch": False,
            "tie_word_emb_proj": False,
            # A fresh list per call, so configs never share it.
            "del_layers": [],
            "use_in_out_mlp": False,
            "use_all_attn": False,
        }
        for option, fallback in optional_defaults.items():
            setattr(self, option, kwargs.get(option, fallback))


class AllAttn(Qwen2FlashAttention2):
    """Flash attention run on a full-width input and on a compressed input.

    The full-width branch applies the inherited q/k/v projections to
    ``hidden_states``. The compressed branch first lifts ``compressed_hidden_states``
    (``config.target_hidden_size`` features) with separate ``zoom_q``/``zoom_k``/
    ``zoom_v`` maps, then reuses the same q/k/v projections. Its attention output
    is mapped back down by ``zoom_down``. Auxiliary losses pull the compressed
    branch's q/k/v and attention output towards the full-width ones.
    """

    def __init__(self, config: CustomConfig, layer_idx=None):
        super().__init__(config, layer_idx)
        self.config = config
        self.zoom_q = nn.Linear(config.target_hidden_size, self.hidden_size, bias=False)
        self.zoom_k = nn.Linear(config.target_hidden_size, self.hidden_size, bias=False)
        self.zoom_v = nn.Linear(config.target_hidden_size, self.hidden_size, bias=False)
        self.zoom_down = nn.Linear(self.hidden_size, config.target_hidden_size, bias=False)
        self.mseloss = LOSS_DICT[config.aux_loss_type]()
        self.layer_idx = layer_idx

    def _record_aux_loss(self, loss_dict, name, prediction, reference):
        """Store ``self.mseloss(prediction, reference)`` in ``loss_dict[name]`` unless banned.

        A loss is skipped when ``name`` is in ``ban_losses`` or this layer's index
        is in ``ban_layers``. Using one name for both the ban check and the dict
        key keeps the two from drifting apart.
        """
        if name in ban_losses or self.layer_idx in ban_layers:
            return
        loss_dict[name] = self.mseloss(prediction, reference)

    def part_forward(self, query_states, key_states, value_states, bsz, q_len, position_ids,
                     past_key_value=None, attention_mask=None):
        """Apply rotary embedding, flash attention and ``o_proj`` to projected q/k/v.

        Args:
            query_states: q-projection output; viewed here as
                (bsz, q_len, num_heads, head_dim).
            key_states, value_states: k/v-projection outputs; viewed here as
                (bsz, q_len, num_key_value_heads, head_dim).
            bsz, q_len: batch size and sequence length used for those views.
            position_ids: positions for the rotary embedding.
            past_key_value: KV cache; any non-None value ends in NotImplementedError.
            attention_mask: forwarded unchanged to ``_flash_attention_forward``.

        Returns:
            ``(attn_output, None, past_key_value)``. ``attn_output`` is the attention
            result reshaped to (bsz, q_len, hidden_size) and passed through ``o_proj``.
            Attention weights are never materialised.

        Raises:
            ValueError: a cache is given but ``layer_idx`` is None.
            NotImplementedError: sliding-window attention would be required, or a
                cache is given.
        """
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        # Because the input can be padded, the absolute sequence length depends on the max position id.
        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        use_sliding_windows = (
            getattr(self.config, "sliding_window", None) is not None
            and kv_seq_len > self.config.sliding_window
            and self.config.use_sliding_window
        )
        if use_sliding_windows:
            raise NotImplementedError

        if past_key_value is not None:
            raise NotImplementedError

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the (batch, seq, heads, head_dim) layout flash attention expects.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_sliding_windows=use_sliding_windows,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        # Attention weights are never returned by this path.
        return attn_output, None, past_key_value

    def forward(
        self,
        hidden_states,
        compressed_hidden_states,
        loss_dict,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        **kwargs,
    ):
        """Run the full-width and compressed branches and record auxiliary losses.

        Args:
            hidden_states: full-width input; its first two dims are (bsz, q_len).
            compressed_hidden_states: compressed input with
                ``config.target_hidden_size`` features (the zoom_* in_features).
            loss_dict: dict the auxiliary losses are written into (mutated and
                returned). Keys: "attn-q-sim-loss", "attn-k-sim-loss",
                "attn-v-sim-loss", "attn-out-sim-loss".
            attention_mask: NOTE(review): accepted but NOT forwarded. Both branches
                call ``part_forward`` with ``attention_mask=None``; confirm inputs
                are never padded.
            position_ids: rotary positions for both branches.
            past_key_value: must be None.
            output_attentions, use_cache, **kwargs: accepted for signature
                compatibility and otherwise unused.

        Returns:
            ``(raw_out, compressed_hidden_states, raw_attn_map, past_key_value,
            loss_dict)``. ``raw_attn_map`` and ``past_key_value`` are whatever
            ``part_forward`` returned (both None on this path).
        """
        assert past_key_value is None

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Compressed branch: lift each input to hidden_size, then reuse the shared projections.
        _up_hs_q = self.q_proj(self.zoom_q(compressed_hidden_states))
        _up_hs_k = self.k_proj(self.zoom_k(compressed_hidden_states))
        _up_hs_v = self.v_proj(self.zoom_v(compressed_hidden_states))

        raw_out, raw_attn_map, _ = self.part_forward(query_states, key_states, value_states, bsz, q_len, position_ids,
                                                     attention_mask=None)
        out, _, present_key_value = self.part_forward(_up_hs_q, _up_hs_k, _up_hs_v, bsz, q_len, position_ids,
                                                      attention_mask=None)
        compressed_hidden_states = self.zoom_down(out)

        self._record_aux_loss(loss_dict, "attn-q-sim-loss", _up_hs_q, query_states)
        self._record_aux_loss(loss_dict, "attn-k-sim-loss", _up_hs_k, key_states)
        self._record_aux_loss(loss_dict, "attn-v-sim-loss", _up_hs_v, value_states)
        # Gated on its own name; it used to be gated on "attn-k-sim-loss" by mistake.
        self._record_aux_loss(loss_dict, "attn-out-sim-loss", out, raw_out)

        return raw_out, compressed_hidden_states, raw_attn_map, present_key_value, loss_dict

    def merge_weight(self):
        """Fold the zoom maps into the q/k/v/o projection weights, in place.

        Because the zoom layers have no bias, ``q_proj(zoom_q(x))`` equals a single
        linear map with weight ``q_proj.W @ zoom_q.W`` and the original bias (same
        for k/v). Likewise ``zoom_down(o_proj(x))`` has weight ``zoom_down.W @ o_proj.W``.
        Biases are left unchanged.

        NOTE(review): only ``.weight.data`` is replaced. The zoom modules stay
        registered, and the projections' in_features/out_features attributes are not
        updated, so forward() should not be called after merging — confirm with the
        export path.
        """
        self.q_proj.weight.data = (self.q_proj.weight.data @ self.zoom_q.weight.data).contiguous()
        self.k_proj.weight.data = (self.k_proj.weight.data @ self.zoom_k.weight.data).contiguous()
        self.v_proj.weight.data = (self.v_proj.weight.data @ self.zoom_v.weight.data).contiguous()
        self.o_proj.weight.data = (self.zoom_down.weight.data @ self.o_proj.weight.data).contiguous()


class Attn(Qwen2FlashAttention2):
    """Inherited flash attention applied to a full-width input and a lifted compressed input.

    ``zoom_up`` lifts the compressed input (``config.target_hidden_size`` features)
    to ``hidden_size``. The same inherited attention runs on the result, and
    ``zoom_down`` maps the output back. Auxiliary losses compare the lifted input
    and its attention output with the full-width ones.
    """

    def __init__(self, config: CustomConfig, layer_idx=None):
        super().__init__(config, layer_idx)
        self.config = config
        small = config.target_hidden_size
        self.zoom_up = nn.Linear(small, self.hidden_size, bias=False)
        self.zoom_down = nn.Linear(self.hidden_size, small, bias=False)
        self.mseloss = LOSS_DICT[config.aux_loss_type]()
        self.layer_idx = layer_idx

    def _parent_attention(self, states, attention_mask, position_ids, past_key_value,
                          output_attentions, use_cache):
        """Call the inherited Qwen2FlashAttention2.forward positionally on ``states``."""
        return super().forward(
            states,
            attention_mask,
            position_ids,
            past_key_value,
            output_attentions,
            use_cache,
        )

    def forward(
        self,
        hidden_states,
        compressed_hidden_states,
        loss_dict,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        cache_position=None,
        **kwargs,
    ):
        """Run both attention passes and record "attn-in-sim-loss" / "attn-out-sim-loss".

        The ``output_attentions`` argument is overridden by ``config.use_attn_map``;
        a truthy value raises NotImplementedError. ``attention_mask`` and
        ``past_key_value`` must be None. ``cache_position`` and ``**kwargs`` are
        ignored.

        Returns:
            ``(raw_out, compressed_hidden_states, raw_attn_map, present, loss_dict)``.
            ``present`` is the third value returned by the compressed-branch call.
        """
        output_attentions = self.config.use_attn_map
        if output_attentions:
            raise NotImplementedError
        assert attention_mask is None
        assert past_key_value is None

        shared = (attention_mask, position_ids, past_key_value, output_attentions, use_cache)

        raw_out, raw_attn_map, _ = self._parent_attention(hidden_states, *shared)
        zoomed_hs = self.zoom_up(compressed_hidden_states)
        out, _, present = self._parent_attention(zoomed_hs, *shared)
        shrunk = self.zoom_down(out)

        comparisons = (
            ("attn-in-sim-loss", zoomed_hs, hidden_states),
            ("attn-out-sim-loss", out, raw_out),
        )
        for key, prediction, reference in comparisons:
            if key not in ban_losses and self.layer_idx not in ban_layers:
                loss_dict[key] = self.mseloss(prediction, reference)

        return raw_out, shrunk, raw_attn_map, present, loss_dict

    def merge_weight(self):
        """Fold ``zoom_up`` into q/k/v and ``zoom_down`` into o_proj (weights only, in place)."""
        lift = self.zoom_up.weight.data
        for proj in (self.q_proj, self.k_proj, self.v_proj):
            proj.weight.data = (proj.weight.data @ lift).contiguous()
        folded_out = self.zoom_down.weight.data @ self.o_proj.weight.data
        self.o_proj.weight.data = folded_out.contiguous()


class MLP(Qwen2MLP):
    """Qwen2MLP that also runs a compressed copy of itself on the student input.

    The compressed weights are not stored. On every call they are derived from
    the full ``gate_proj``/``up_proj``/``down_proj`` weights through the learnable
    ``zoom_gate``/``zoom_up``/``zoom_down`` maps.
    """

    def __init__(self, config: CustomConfig, layer_idx=None):
        super().__init__(config)
        # Each zoom map is Linear(hidden_size -> target_hidden_size). small_forward
        # applies them to the full weight matrices, and zoom_down is also applied
        # to the full MLP output to build the "mlp-out-loss" target.
        self.zoom_up = nn.Linear(self.hidden_size, config.target_hidden_size, bias=False)
        self.zoom_gate = nn.Linear(self.hidden_size, config.target_hidden_size, bias=False)
        self.zoom_down = nn.Linear(self.hidden_size, config.target_hidden_size, bias=False)
        self.mseloss = LOSS_DICT[config.aux_loss_type]()
        self.layer_idx = layer_idx

    def small_forward(self, compressed_x, raw_gate, raw_act_gate, raw_up, raw_x, raw_out, loss_dict: dict):
        """Run the compressed MLP on ``compressed_x`` and record its auxiliary losses.

        Args:
            compressed_x: student input with ``target_hidden_size`` features.
            raw_gate, raw_up: full-width gate / up projections (loss targets).
            raw_act_gate, raw_x: currently unused (their losses are disabled).
            raw_out: full-width MLP output; ``zoom_down(raw_out)`` is the target for
                "mlp-out-loss".
            loss_dict: mutated in place with "mlp-gate-loss", "mlp-up-loss" and
                "mlp-out-loss", unless the name or this layer is banned.

        Returns:
            The compressed MLP output (``target_hidden_size`` features).
        """
        # F.linear(W, Z) == W @ Z.T, so these contract the hidden_size axis of each
        # full weight down to target_hidden_size.
        Wup = self.zoom_up(self.up_proj.weight)
        Wgate = self.zoom_gate(self.gate_proj.weight)
        Wdown = self.zoom_down(self.down_proj.weight.T).T
        gate = linear(compressed_x, Wgate)
        act_gate = self.act_fn(gate)
        up = linear(compressed_x, Wup)
        down = linear(act_gate * up, Wdown)

        layer_enabled = self.layer_idx not in ban_layers
        if "mlp-gate-loss" not in ban_losses and layer_enabled:
            loss_dict["mlp-gate-loss"] = self.mseloss(gate, raw_gate)
        # loss_dict["mlp-act-gate-loss"] = self.mseloss(act_gate, raw_act_gate)
        if "mlp-up-loss" not in ban_losses and layer_enabled:
            loss_dict["mlp-up-loss"] = self.mseloss(up, raw_up)
        # loss_dict["mlp-in-loss"] = self.mseloss(compressed_x, self.zoom(raw_x))
        if "mlp-out-loss" not in ban_losses and layer_enabled:
            # zoom_down(raw_out) is only computed when the loss is actually recorded.
            loss_dict["mlp-out-loss"] = self.mseloss(down, self.zoom_down(raw_out))

        return down

    def forward(self, x, compressed_x, loss_dict: dict):
        """Run the full SwiGLU MLP on ``x`` and the compressed one on ``compressed_x``.

        Returns:
            ``(full_output, compressed_output, loss_dict)``; ``loss_dict`` is the
            same dict object, mutated by ``small_forward``.
        """
        gate = self.gate_proj(x)
        act_gate = self.act_fn(gate)
        up = self.up_proj(x)
        down = self.down_proj(act_gate * up)

        return down, self.small_forward(compressed_x, gate, act_gate, up, x, down, loss_dict), loss_dict

    def merge_weight(self):
        """Replace the full projection weights with their compressed versions, in place.

        This is the same folding small_forward performs (the zoom layers have no
        bias, so ``zoom(W) == W @ zoom.W.T``). It is written as plain matmuls on
        ``.data``, matching Attn/AllAttn.merge_weight. Calling the trainable zoom
        modules here would build an autograd graph for what is a one-off weight
        surgery.

        NOTE(review): the Linear modules' in_features/out_features metadata is not
        updated, and the zoom modules stay registered.
        """
        self.gate_proj.weight.data = (self.gate_proj.weight.data @ self.zoom_gate.weight.data.T).contiguous()
        self.up_proj.weight.data = (self.up_proj.weight.data @ self.zoom_up.weight.data.T).contiguous()
        self.down_proj.weight.data = (self.zoom_down.weight.data @ self.down_proj.weight.data).contiguous()


class DebugLlamaRMSNorm(nn.Module):
    """RMS normalisation over the last dimension (same formulation as T5LayerNorm).

    Statistics are computed in float32. The normalised values are cast back to the
    input dtype before being scaled by the learnable ``weight`` (initialised to ones).
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalised = (as_fp32 * inv_rms).to(orig_dtype)
        return self.weight * normalised
    

# def reinit_weight(module: nn.Module):
#     if type(module) == nn.Linear:
#         if module.weight.requires_grad:
#             module.weight.data.normal_(mean=0.0, std=0.02)
#             if module.bias is not None:
#                 module.bias.data.zero_()
#     if type(module) == DebugLlamaRMSNorm:
#         if module.weight.requires_grad:
#             module.weight.data.fill_(1.0)


# Alternative dimension selection kept for reference. Its indices go up to 3070,
# so it cannot index the 2048-wide shapes handled by reinit_weight below.
# important_dims = [0, 1, 2, 3, 8, 11, 12, 14, 16, 18, 19, 20, 22, 24, 26, 27, 28, 31, 33, 34, 35, 37, 38, 40, 42, 43, 46, 47, 49, 57, 58, 63, 64, 65, 66, 70, 74, 77, 78, 79, 80, 81, 84, 85, 87, 88, 89, 93, 99, 103, 105, 107, 108, 109, 110, 112, 119, 120, 121, 122, 123, 124, 125, 127, 128, 132, 133, 137, 138, 141, 142, 143, 145, 147, 148, 150, 151, 156, 158, 160, 161, 163, 165, 166, 167, 174, 176, 177, 180, 183, 184, 187, 191, 194, 195, 196, 197, 201, 203, 204, 205, 207, 212, 215, 216, 217, 218, 219, 221, 223, 224, 227, 230, 232, 233, 234, 235, 236, 237, 238, 239, 240, 242, 243, 245, 246, 247, 251, 254, 256, 260, 262, 263, 264, 265, 267, 272, 273, 274, 275, 277, 278, 279, 280, 281, 282, 287, 289, 291, 292, 294, 298, 300, 301, 303, 304, 307, 308, 310, 311, 312, 314, 319, 322, 323, 326, 328, 329, 331, 332, 333, 335, 338, 343, 348, 349, 350, 351, 353, 354, 355, 356, 357, 359, 360, 364, 366, 368, 370, 372, 373, 375, 376, 377, 379, 380, 384, 385, 388, 389, 393, 398, 399, 400, 403, 404, 409, 411, 412, 413, 415, 416, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 431, 433, 435, 436, 437, 438, 441, 443, 446, 447, 448, 453, 454, 456, 457, 458, 459, 460, 462, 466, 469, 472, 476, 478, 480, 481, 482, 484, 485, 486, 487, 488, 489, 492, 495, 496, 497, 500, 501, 507, 509, 511, 512, 513, 514, 517, 518, 519, 520, 522, 524, 525, 527, 528, 530, 531, 533, 535, 538, 539, 540, 542, 554, 555, 557, 559, 562, 563, 564, 567, 568, 570, 571, 573, 574, 575, 580, 581, 584, 585, 586, 588, 589, 591, 593, 595, 596, 600, 605, 606, 608, 611, 613, 615, 616, 617, 621, 624, 627, 628, 629, 630, 634, 642, 643, 647, 649, 650, 651, 653, 658, 661, 662, 663, 664, 666, 667, 668, 669, 671, 672, 674, 676, 677, 678, 685, 687, 688, 689, 690, 692, 696, 698, 700, 702, 703, 704, 705, 708, 710, 712, 713, 714, 717, 718, 719, 720, 723, 726, 727, 728, 730, 732, 733, 737, 738, 739, 742, 743, 744, 745, 746, 747, 749, 752, 753, 755, 756, 757, 758, 759, 760, 763, 764, 766, 767, 769, 770, 775, 776, 777, 778, 779,
#     780, 781, 785, 786, 787, 790, 793, 794, 795, 796, 798, 801, 803, 806, 808, 810, 811, 814, 816, 817, 820, 822, 823, 824, 825, 827, 830, 832, 834, 835, 836, 838, 839, 840, 842, 843, 844, 849, 851, 852, 853, 854, 855, 856, 859, 860, 864, 865, 866, 868, 869, 870, 873, 875, 876, 878, 881, 882, 883, 884, 887, 891, 893, 895, 896, 897, 898, 900, 901, 902, 906, 907, 909, 913, 914, 917, 919, 921, 922, 925, 927, 928, 929, 931, 932, 933, 934, 936, 938, 940, 945, 946, 952, 953, 954, 956, 957, 958, 960, 962, 963, 965, 967, 968, 972, 973, 976, 977, 981, 983, 984, 985, 987, 993, 994, 996, 997, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1016, 1018, 1023, 1024, 1026, 1027, 1031, 1032, 1034, 1035, 1037, 1041, 1043, 1044, 1045, 1047, 1048, 1049, 1050, 1052, 1053, 1055, 1056, 1057, 1058, 1060, 1069, 1072, 1076, 1077, 1079, 1082, 1083, 1084, 1085, 1086, 1087, 1088, 1090, 1095, 1097, 1099, 1102, 1103, 1104, 1105, 1106, 1109, 1110, 1111, 1113, 1114, 1115, 1119, 1122, 1123, 1129, 1131, 1133, 1136, 1139, 1144, 1145, 1146, 1147, 1151, 1154, 1155, 1156, 1158, 1161, 1162, 1166, 1167, 1170, 1171, 1172, 1175, 1176, 1177, 1179, 1181, 1184, 1185, 1188, 1189, 1194, 1195, 1197, 1200, 1203, 1205, 1208, 1209, 1210, 1212, 1215, 1216, 1217, 1218, 1220, 1221, 1223, 1224, 1225, 1226, 1227, 1228, 1229, 1230, 1234, 1235, 1236, 1237, 1238, 1239, 1241, 1243, 1244, 1249, 1251, 1252, 1253, 1254, 1255, 1259, 1264, 1266, 1268, 1270, 1271, 1272, 1277, 1278, 1279, 1280, 1281, 1283, 1285, 1286, 1288, 1290, 1291, 1292, 1293, 1295, 1299, 1302, 1304, 1305, 1306, 1307, 1308, 1310, 1313, 1314, 1315, 1316, 1319, 1320, 1326, 1332, 1333, 1335, 1336, 1342, 1345, 1346, 1347, 1348, 1349, 1354, 1355, 1356, 1358, 1361, 1362, 1363, 1364, 1367, 1368, 1369, 1372, 1374, 1375, 1376, 1378, 1379, 1380, 1382, 1389, 1391, 1392, 1394, 1395, 1398, 1399, 1401, 1402, 1403, 1404, 1407, 1408, 1409, 1410, 1416, 1417, 1422, 1424, 1425, 1428, 1431, 1433, 1435, 1437, 1439, 1440, 1443, 1444, 1447, 1448, 1453, 1454, 1456, 1457,
#     1458, 1459, 1465, 1466, 1467, 1470, 1471, 1474, 1476, 1477, 1479, 1480, 1482, 1484, 1485, 1487, 1488, 1493, 1499, 1500, 1503, 1505, 1507, 1508, 1510, 1515, 1519, 1522, 1523, 1524, 1525, 1526, 1527, 1528, 1529, 1531, 1533, 1534, 1535, 1536, 1537, 1539, 1540, 1541, 1542, 1543, 1544, 1545, 1546, 1552, 1553, 1554, 1557, 1558, 1559, 1560, 1562, 1563, 1564, 1566, 1568, 1571, 1572, 1573, 1576, 1578, 1579, 1582, 1584, 1585, 1586, 1589, 1591, 1595, 1597, 1602, 1603, 1604, 1605, 1608, 1609, 1610, 1614, 1616, 1617, 1619, 1622, 1624, 1627, 1629, 1630, 1632, 1634, 1638, 1639, 1640, 1641, 1643, 1644, 1646, 1647, 1648, 1649, 1650, 1651, 1653, 1654, 1655, 1657, 1659, 1670, 1673, 1674, 1675, 1677, 1678, 1680, 1681, 1683, 1685, 1686, 1687, 1688, 1690, 1692, 1693, 1694, 1696, 1698, 1705, 1709, 1710, 1711, 1712, 1713, 1714, 1715, 1716, 1717, 1719, 1722, 1724, 1726, 1727, 1728, 1729, 1730, 1731, 1739, 1742, 1743, 1744, 1745, 1746, 1747, 1755, 1756, 1760, 1761, 1762, 1765, 1774, 1777, 1778, 1781, 1784, 1786, 1788, 1789, 1791, 1792, 1793, 1794, 1795, 1796, 1797, 1798, 1800, 1803, 1804, 1805, 1806, 1807, 1808, 1809, 1810, 1812, 1814, 1815, 1817, 1820, 1823, 1825, 1826, 1827, 1830, 1831, 1837, 1840, 1842, 1844, 1845, 1847, 1848, 1851, 1852, 1853, 1857, 1859, 1861, 1862, 1866, 1867, 1874, 1878, 1879, 1882, 1886, 1887, 1888, 1889, 1894, 1895, 1897, 1899, 1901, 1902, 1903, 1904, 1905, 1906, 1909, 1910, 1911, 1914, 1915, 1916, 1917, 1924, 1925, 1926, 1927, 1928, 1930, 1932, 1934, 1935, 1940, 1941, 1943, 1947, 1948, 1950, 1951, 1955, 1956, 1958, 1964, 1965, 1967, 1968, 1973, 1974, 1976, 1978, 1980, 1981, 1982, 1983, 1986, 1988, 1991, 1992, 1994, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2007, 2008, 2010, 2011, 2015, 2016, 2018, 2019, 2021, 2022, 2024, 2026, 2027, 2029, 2030, 2033, 2035, 2036, 2041, 2042, 2044, 2045, 2046, 2047, 2048, 2049, 2052, 2053, 2054, 2056, 2059, 2065, 2066, 2067, 2071, 2072, 2073, 2076, 2079, 2080, 2083, 2086, 2087, 2090, 2092, 2094, 2095, 2096, 2098, 2099, 2100, 2101,
#     2102, 2104, 2107, 2109, 2112, 2113, 2115, 2117, 2119, 2121, 2122, 2124, 2125, 2126, 2127, 2130, 2132, 2136, 2139, 2141, 2143, 2145, 2147, 2151, 2152, 2154, 2156, 2161, 2165, 2166, 2168, 2171, 2172, 2176, 2177, 2178, 2181, 2183, 2184, 2185, 2188, 2189, 2192, 2193, 2195, 2196, 2197, 2198, 2200, 2202, 2203, 2204, 2205, 2209, 2210, 2212, 2214, 2216, 2217, 2218, 2219, 2223, 2224, 2225, 2227, 2228, 2231, 2233, 2234, 2236, 2237, 2239, 2240, 2243, 2244, 2245, 2247, 2250, 2254, 2255, 2257, 2258, 2262, 2263, 2264, 2265, 2267, 2268, 2270, 2271, 2273, 2275, 2279, 2280, 2287, 2288, 2289, 2292, 2294, 2297, 2299, 2303, 2305, 2307, 2308, 2312, 2317, 2320, 2322, 2326, 2327, 2328, 2329, 2330, 2331, 2336, 2338, 2339, 2344, 2345, 2346, 2348, 2352, 2354, 2355, 2356, 2357, 2358, 2359, 2360, 2362, 2363, 2364, 2366, 2369, 2370, 2371, 2372, 2373, 2374, 2378, 2379, 2380, 2381, 2383, 2386, 2388, 2392, 2393, 2395, 2398, 2400, 2403, 2407, 2411, 2412, 2420, 2423, 2425, 2428, 2430, 2432, 2433, 2434, 2435, 2437, 2444, 2445, 2446, 2448, 2452, 2453, 2454, 2458, 2461, 2462, 2465, 2469, 2472, 2473, 2474, 2476, 2478, 2480, 2482, 2484, 2486, 2487, 2489, 2493, 2497, 2498, 2502, 2507, 2509, 2510, 2513, 2515, 2517, 2528, 2530, 2531, 2535, 2536, 2537, 2540, 2541, 2544, 2551, 2552, 2558, 2560, 2563, 2574, 2576, 2577, 2579, 2581, 2582, 2584, 2585, 2587, 2588, 2589, 2590, 2593, 2599, 2602, 2603, 2607, 2609, 2610, 2612, 2613, 2614, 2615, 2618, 2619, 2621, 2626, 2629, 2630, 2632, 2634, 2637, 2638, 2647, 2648, 2655, 2656, 2658, 2661, 2663, 2668, 2669, 2670, 2671, 2673, 2676, 2678, 2679, 2680, 2681, 2682, 2683, 2685, 2686, 2688, 2689, 2690, 2692, 2696, 2697, 2698, 2701, 2702, 2703, 2704, 2707, 2709, 2712, 2713, 2714, 2715, 2717, 2720, 2721, 2722, 2723, 2724, 2726, 2727, 2734, 2739, 2741, 2748, 2749, 2750, 2752, 2758, 2759, 2760, 2761, 2764, 2766, 2771, 2773, 2774, 2775, 2776, 2777, 2778, 2779, 2781, 2783, 2784, 2786, 2787, 2789, 2791, 2792, 2795, 2797, 2799, 2800, 2801, 2802, 2803, 2805, 2807, 2808, 2809, 2814,
#     2817, 2822, 2823, 2824, 2825, 2834, 2837, 2839, 2841, 2844, 2848, 2849, 2852, 2854, 2856, 2858, 2859, 2860, 2864, 2868, 2873, 2874, 2875, 2876, 2878, 2879, 2881, 2883, 2884, 2885, 2888, 2892, 2895, 2896, 2898, 2902, 2905, 2912, 2914, 2917, 2920, 2921, 2922, 2923, 2924, 2926, 2927, 2930, 2931, 2932, 2934, 2936, 2937, 2938, 2939, 2940, 2941, 2943, 2945, 2946, 2949, 2952, 2953, 2954, 2962, 2963, 2965, 2966, 2970, 2971, 2973, 2986, 2988, 2990, 2992, 2993, 2994, 2995, 2998, 2999, 3000, 3004, 3005, 3006, 3008, 3009, 3010, 3018, 3020, 3026, 3027, 3029, 3031, 3033, 3034, 3036, 3037, 3039, 3046, 3049, 3050, 3052, 3053, 3057, 3061, 3062, 3064, 3065, 3066, 3067, 3069, 3070]

# Hidden dimensions kept by the zoom maps: row i of a "zoom down" matrix selects
# input dimension _ZOOM_IMPORTANT_DIMS[i]. How the list was chosen is not recorded
# here. NOTE(review): it is expected to hold _ZOOM_TARGET_DIM unique indices below
# _ZOOM_FULL_DIM — TODO confirm.
_ZOOM_IMPORTANT_DIMS = (0, 1, 2, 3, 6, 7, 9, 10, 14, 16, 17, 18, 20, 21, 24, 25, 26, 27, 28, 30, 31, 32, 34, 37, 38, 40, 41, 43, 45, 46, 47, 48, 49, 50, 52, 53, 54, 55, 56, 58, 59, 61, 62, 65, 71, 72, 73, 74, 75, 79, 80, 88, 90, 91, 95, 97, 99, 102, 103, 105, 110, 111, 115, 116, 117, 120, 124, 128, 129, 130, 132, 134, 136, 138, 139, 140, 141, 143, 145, 148, 149, 150, 152, 153, 158, 159, 163, 164, 165, 166, 167, 169, 171, 173, 174, 175, 176, 177, 182, 184, 185, 188, 190, 193, 194, 195, 196, 198, 204, 207, 209, 213, 218, 220, 221, 222, 223, 225, 226, 227, 229, 230, 232, 233, 234, 236, 237, 238, 240, 241, 242, 243, 245, 246, 247, 249, 250, 251, 253, 254, 255, 256, 257, 259, 260, 263, 264, 265, 268, 270, 272, 273, 275, 276, 277, 279, 280, 283, 284, 285, 286, 287, 289, 291, 292, 293, 294, 295, 296, 297, 298, 301, 305, 306, 307, 308, 310, 311, 312, 313, 314, 315, 317, 318, 319, 321, 322, 323, 324, 326, 328, 332, 333, 335, 336, 337, 338, 340, 341, 342, 345, 346, 347, 348, 350, 351, 357, 358, 360, 362, 364, 366, 368, 373, 374, 376, 378, 379, 381, 383, 384, 385, 386, 390, 391, 392, 393, 394, 395, 396, 398, 399, 401, 402, 403, 409, 411, 412, 414, 416, 419, 420, 421, 423, 424, 426, 427, 428, 430, 432, 437, 439, 440, 441, 442, 443, 444, 446, 449, 450, 451, 453, 454, 456, 457, 458, 459, 460, 463, 464, 465, 467, 468, 469, 471, 472, 473, 474, 477, 480, 482, 483, 484, 485, 486, 488, 489, 491, 493, 495, 496, 497, 499, 501, 502, 503, 504, 505, 506, 507, 510, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 525, 526, 527, 528, 529, 530, 531, 532, 534, 535, 536, 537, 545, 550, 552, 553, 554, 556, 558, 559, 560, 561, 562, 564, 567, 568, 569, 571, 572, 573, 575, 578, 581, 582, 583, 585, 588, 589, 590, 591, 594, 595, 598, 601, 605, 607, 608, 612, 613, 614, 615, 616, 619, 620, 621, 622, 627, 629, 634, 635, 636, 641, 642, 643, 645, 647, 649, 651, 653, 657, 658, 660, 663, 664, 665, 666, 667, 668, 669, 670, 671, 673, 676, 677, 678, 680, 681, 682, 684, 685, 687, 688, 689, 690, 693, 694, 695,
    697, 700, 701, 702, 704, 705, 706, 708, 709, 710, 711, 713, 714, 715, 716, 717, 721, 724, 725, 726, 728, 730, 733, 734, 735, 736, 740, 741, 742, 743, 744, 747, 749, 750, 751, 753, 758, 759, 760, 762, 763, 764, 765, 766, 767, 771, 772, 773, 774, 777, 778, 780, 781, 782, 784, 786, 787, 788, 789, 791, 793, 794, 797, 798, 800, 802, 803, 807, 808, 809, 810, 812, 813, 817, 819, 820, 821, 822, 823, 825, 826, 827, 829, 830, 831, 832, 833, 834, 835, 837, 839, 840, 842, 844, 845, 846, 849, 851, 852, 853, 855, 859, 860, 861, 863, 864, 865, 866, 867, 868, 870, 872, 874, 878, 880, 884, 890, 892, 894, 895, 902, 903, 904, 906, 911, 914, 916, 917, 919, 920, 921, 922, 925, 926, 927, 930, 936, 937, 938, 939, 941, 943, 945, 946, 947, 950, 952, 954, 957, 958, 968, 970, 973, 976, 978, 979, 981, 982, 985, 987, 992, 995, 997, 998, 999, 1003, 1004, 1005, 1007, 1008, 1010, 1011, 1020, 1021, 1024, 1025, 1027, 1029, 1030, 1031, 1033, 1034, 1035, 1036, 1038, 1039, 1041, 1045, 1046, 1048, 1050, 1052, 1053, 1054, 1056, 1059, 1062, 1065, 1066, 1069, 1070, 1071, 1073, 1075, 1077, 1078, 1080, 1081, 1083, 1085, 1087, 1088, 1091, 1094, 1097, 1098, 1099, 1100, 1101, 1104, 1105, 1106, 1109, 1110, 1111, 1112, 1114, 1115, 1117, 1118, 1122, 1126, 1127, 1128, 1129, 1130, 1131, 1132, 1135, 1136, 1137, 1138, 1139, 1140, 1141, 1142, 1143, 1149, 1152, 1153, 1154, 1157, 1158, 1159, 1160, 1161, 1162, 1164, 1165, 1166, 1167, 1170, 1171, 1172, 1176, 1177, 1178, 1179, 1180, 1181, 1183, 1184, 1188, 1189, 1191, 1192, 1194, 1195, 1196, 1198, 1199, 1201, 1202, 1205, 1207, 1209, 1210, 1212, 1213, 1214, 1215, 1216, 1217, 1218, 1219, 1221, 1223, 1224, 1226, 1228, 1229, 1230, 1232, 1234, 1238, 1239, 1240, 1241, 1245, 1247, 1248, 1249, 1251, 1252, 1254, 1255, 1257, 1258, 1259, 1260, 1261, 1263, 1266, 1270, 1271, 1272, 1274, 1275, 1279, 1280, 1281, 1282, 1283, 1284, 1288, 1290, 1291, 1292, 1294, 1295, 1296, 1297, 1299, 1300, 1304, 1308, 1309, 1310, 1311, 1312, 1313, 1315, 1316, 1317, 1320, 1321, 1324, 1325, 1326, 1332,
    1333, 1335, 1336, 1337, 1338, 1343, 1345, 1346, 1351, 1352, 1354, 1356, 1358, 1359, 1361, 1362, 1363, 1365, 1366, 1367, 1368, 1369, 1370, 1371, 1372, 1374, 1375, 1376, 1377, 1380, 1381, 1382, 1383, 1384, 1388, 1391, 1392, 1394, 1395, 1396, 1397, 1399, 1400, 1401, 1408, 1410, 1412, 1413, 1414, 1415, 1417, 1418, 1419, 1421, 1422, 1423, 1424, 1425, 1430, 1431, 1432, 1433, 1438, 1441, 1443, 1445, 1446, 1447, 1449, 1451, 1453, 1455, 1456, 1458, 1459, 1464, 1465, 1467, 1469, 1470, 1474, 1475, 1477, 1478, 1479, 1480, 1481, 1482, 1483, 1484, 1488, 1489, 1491, 1492, 1494, 1495, 1497, 1498, 1500, 1501, 1504, 1506, 1507, 1509, 1510, 1511, 1513, 1514, 1515, 1516, 1517, 1518, 1519, 1520, 1521, 1525, 1526, 1527, 1528, 1529, 1530, 1533, 1535, 1540, 1543, 1544, 1545, 1546, 1547, 1548, 1551, 1552, 1554, 1555, 1556, 1560, 1561, 1562, 1563, 1564, 1565, 1566, 1567, 1568, 1569, 1572, 1574, 1575, 1578, 1579, 1580, 1586, 1587, 1589, 1593, 1594, 1596, 1597, 1598, 1599, 1601, 1602, 1603, 1606, 1607, 1608, 1609, 1610, 1611, 1614, 1617, 1618, 1619, 1620, 1622, 1623, 1625, 1626, 1627, 1628, 1629, 1631, 1632, 1633, 1634, 1637, 1638, 1639, 1642, 1643, 1644, 1645, 1646, 1647, 1651, 1653, 1654, 1655, 1656, 1657, 1659, 1661, 1663, 1664, 1665, 1667, 1668, 1669, 1670, 1672, 1675, 1676, 1677, 1679, 1680, 1681, 1684, 1689, 1691, 1692, 1697, 1700, 1701, 1703, 1704, 1707, 1709, 1711, 1713, 1714, 1715, 1716, 1717, 1718, 1721, 1722, 1724, 1725, 1727, 1729, 1730, 1731, 1732, 1734, 1737, 1742, 1743, 1744, 1746, 1747, 1750, 1751, 1752, 1753, 1754, 1756, 1758, 1761, 1764, 1766, 1767, 1769, 1771, 1772, 1774, 1775, 1779, 1781, 1782, 1783, 1785, 1786, 1789, 1791, 1794, 1795, 1796, 1797, 1798, 1799, 1800, 1803, 1804, 1805, 1807, 1809, 1810, 1811, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1823, 1824, 1826, 1827, 1829, 1831, 1833, 1834, 1835, 1837, 1838, 1841, 1843, 1844, 1845, 1851, 1852, 1854, 1856, 1861, 1865, 1866, 1868, 1870, 1871, 1872, 1873, 1874, 1875, 1877, 1878, 1881, 1883, 1884, 1886, 1889, 1890, 1891,
    1894, 1897, 1899, 1900, 1901, 1902, 1904, 1905, 1906, 1909, 1910, 1911, 1912, 1915, 1916, 1917, 1923, 1924, 1925, 1926, 1927, 1929, 1930, 1931, 1933, 1934, 1936, 1937, 1938, 1941, 1944, 1947, 1949, 1951, 1954, 1955, 1956, 1957, 1958, 1960, 1961, 1962, 1965, 1968, 1970, 1971, 1973, 1974, 1976, 1978, 1979, 1980, 1981, 1982, 1984, 1985, 1986, 1987, 1988, 1989, 1991, 1992, 1993, 1994, 1995, 1996, 1997, 1998, 1999, 2001, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2016, 2017, 2018, 2020, 2021, 2023, 2024, 2025, 2027, 2030, 2031, 2034, 2035, 2036, 2037, 2039, 2041, 2042, 2044, 2045, 2046, 2047)

# Weight shapes recognised by reinit_weight: (target, full) is a "zoom down" map,
# (full, target) is a "zoom up" map.
_ZOOM_FULL_DIM = 2048
_ZOOM_TARGET_DIM = 1200


def reinit_weight(module: nn.Module):
    """Re-initialise one trainable module in place (suitable for ``model.apply``).

    - Modules without a ``weight`` (or with ``weight`` None), and modules whose
      weight does not require grad, are left untouched.
    - A trainable ``nn.Linear`` with weight shape (1200, 2048) ("zoom down")
      becomes a 0/1 selection matrix with ``W[i, _ZOOM_IMPORTANT_DIMS[i]] = 1``.
    - A trainable ``nn.Linear`` with weight shape (2048, 1200) ("zoom up") becomes
      the transpose pattern, ``W[_ZOOM_IMPORTANT_DIMS[i], i] = 1``.
    - For either zoom shape, a bias that requires grad is zeroed.
    - A trainable ``DebugLlamaRMSNorm`` has its weight filled with 1.0.
    - Other trainable modules with a weight are left untouched.

    Raises:
        ValueError: a trainable ``nn.Linear`` has any other weight shape. Previously
            this called ``exit()``, which silently ended the process with status 0.
    """
    weight = getattr(module, "weight", None)
    if weight is None or not weight.requires_grad:
        return

    if isinstance(module, nn.Linear):
        shape = tuple(weight.shape)
        zoom_down_shape = (_ZOOM_TARGET_DIM, _ZOOM_FULL_DIM)
        zoom_up_shape = (_ZOOM_FULL_DIM, _ZOOM_TARGET_DIM)
        if shape not in (zoom_down_shape, zoom_up_shape):
            raise ValueError(
                f"reinit_weight: unexpected trainable nn.Linear with weight shape {shape}; "
                f"expected {zoom_down_shape} or {zoom_up_shape}"
            )

        dims = torch.tensor(_ZOOM_IMPORTANT_DIMS, device=weight.device)
        positions = torch.arange(len(_ZOOM_IMPORTANT_DIMS), device=weight.device)
        new_weight = torch.zeros_like(weight)
        # One advanced-index assignment instead of a per-element Python loop.
        if shape == zoom_down_shape:
            new_weight[positions, dims] = 1.0
        else:
            new_weight[dims, positions] = 1.0
        module.weight.data = new_weight

        if module.bias is not None and module.bias.requires_grad:
            module.bias.data.zero_()
        return

    if isinstance(module, DebugLlamaRMSNorm):
        weight.data.fill_(1.0)


class CustomLayer(Qwen2DecoderLayer):
    """Qwen2 decoder layer that also carries a parallel, narrower "compressed" stream.

    Besides the regular ``hidden_states`` path inherited from ``Qwen2DecoderLayer``,
    the attention and MLP sub-blocks also receive and return
    ``compressed_hidden_states``. That stream is normalized by dedicated
    ``DebugLlamaRMSNorm`` layers of width ``config.target_hidden_size``. A
    ``loss_dict`` is threaded through the attention and MLP modules and returned
    to the caller.
    """

    # Channels of the full hidden dimension kept in the compressed stream; used by
    # ``init_norm`` when no explicit indices are given. Its length must equal
    # ``config.target_hidden_size`` (``init_norm`` checks this).
    # NOTE(review): how these indices were chosen is not visible here (presumably an
    # offline importance ranking) -- confirm before reusing with another checkpoint.
    DEFAULT_IMPORTANT_INDICES = (
        0, 1, 2, 3, 6, 7, 9, 10, 14, 16, 17, 18, 20, 21, 24, 25, 26, 27, 28, 30, 31, 32, 34, 37, 38, 40, 41, 43, 45, 46, 47, 48, 49, 50, 52, 53, 54, 55, 56, 58, 59, 61, 62, 65, 71, 72, 73, 74, 75, 79, 80, 88, 90, 91, 95, 97, 99, 102, 103, 105, 110, 111, 115, 116, 117, 120, 124, 128, 129, 130, 132, 134, 136, 138, 139, 140, 141, 143, 145, 148, 149, 150, 152, 153, 158, 159, 163, 164, 165, 166, 167, 169, 171, 173, 174, 175, 176, 177, 182, 184, 185, 188, 190, 193, 194, 195, 196, 198, 204, 207, 209, 213, 218, 220, 221, 222, 223, 225, 226, 227, 229, 230, 232, 233, 234, 236, 237, 238, 240, 241, 242, 243, 245, 246, 247, 249, 250, 251, 253, 254, 255, 256, 257, 259, 260, 263, 264, 265, 268, 270, 272, 273, 275, 276, 277, 279, 280, 283, 284, 285, 286, 287, 289, 291, 292, 293, 294, 295, 296, 297, 298, 301, 305, 306, 307, 308, 310, 311, 312, 313, 314, 315, 317, 318, 319, 321, 322, 323, 324, 326, 328, 332, 333, 335, 336, 337, 338, 340, 341, 342, 345, 346, 347, 348, 350, 351, 357, 358, 360, 362, 364, 366, 368, 373, 374, 376, 378, 379, 381, 383, 384, 385, 386, 390, 391, 392, 393, 394, 395, 396, 398, 399, 401, 402, 403, 409, 411, 412, 414, 416, 419, 420, 421, 423, 424, 426, 427, 428, 430, 432, 437, 439, 440, 441, 442, 443, 444, 446, 449, 450, 451, 453, 454, 456, 457, 458, 459, 460, 463, 464, 465, 467, 468, 469, 471, 472, 473, 474, 477, 480, 482, 483, 484, 485, 486, 488, 489, 491, 493, 495, 496, 497, 499, 501, 502, 503, 504, 505, 506, 507, 510, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 525, 526, 527, 528, 529, 530, 531, 532, 534, 535, 536, 537, 545, 550, 552, 553, 554, 556, 558, 559, 560, 561, 562, 564, 567, 568, 569, 571, 572, 573, 575, 578, 581, 582, 583, 585, 588, 589, 590, 591, 594, 595, 598, 601, 605, 607, 608, 612, 613, 614, 615, 616, 619, 620, 621, 622, 627, 629, 634, 635, 636, 641, 642, 643, 645, 647, 649, 651, 653, 657, 658, 660, 663, 664, 665, 666, 667, 668, 669, 670, 671, 673, 676, 677, 678, 680, 681, 682, 684, 685, 687, 688, 689, 690, 693, 694,
        695, 697, 700, 701, 702, 704, 705, 706, 708, 709, 710, 711, 713, 714, 715, 716, 717, 721, 724, 725, 726, 728, 730, 733, 734, 735, 736, 740, 741, 742, 743, 744, 747, 749, 750, 751, 753, 758, 759, 760, 762, 763, 764, 765, 766, 767, 771, 772, 773, 774, 777, 778, 780, 781, 782, 784, 786, 787, 788, 789, 791, 793, 794, 797, 798, 800, 802, 803, 807, 808, 809, 810, 812, 813, 817, 819, 820, 821, 822, 823, 825, 826, 827, 829, 830, 831, 832, 833, 834, 835, 837, 839, 840, 842, 844, 845, 846, 849, 851, 852, 853, 855, 859, 860, 861, 863, 864, 865, 866, 867, 868, 870, 872, 874, 878, 880, 884, 890, 892, 894, 895, 902, 903, 904, 906, 911, 914, 916, 917, 919, 920, 921, 922, 925, 926, 927, 930, 936, 937, 938, 939, 941, 943, 945, 946, 947, 950, 952, 954, 957, 958, 968, 970, 973, 976, 978, 979, 981, 982, 985, 987, 992, 995, 997, 998, 999, 1003, 1004, 1005, 1007, 1008, 1010, 1011, 1020, 1021, 1024, 1025, 1027, 1029, 1030, 1031, 1033, 1034, 1035, 1036, 1038, 1039, 1041, 1045, 1046, 1048, 1050, 1052, 1053, 1054, 1056, 1059, 1062, 1065, 1066, 1069, 1070, 1071, 1073, 1075, 1077, 1078, 1080, 1081, 1083, 1085, 1087, 1088, 1091, 1094, 1097, 1098, 1099, 1100, 1101, 1104, 1105, 1106, 1109, 1110, 1111, 1112, 1114, 1115, 1117, 1118, 1122, 1126, 1127, 1128, 1129, 1130, 1131, 1132, 1135, 1136, 1137, 1138, 1139, 1140, 1141, 1142, 1143, 1149, 1152, 1153, 1154, 1157, 1158, 1159, 1160, 1161, 1162, 1164, 1165, 1166, 1167, 1170, 1171, 1172, 1176, 1177, 1178, 1179, 1180, 1181, 1183, 1184, 1188, 1189, 1191, 1192, 1194, 1195, 1196, 1198, 1199, 1201, 1202, 1205, 1207, 1209, 1210, 1212, 1213, 1214, 1215, 1216, 1217, 1218, 1219, 1221, 1223, 1224, 1226, 1228, 1229, 1230, 1232, 1234, 1238, 1239, 1240, 1241, 1245, 1247, 1248, 1249, 1251, 1252, 1254, 1255, 1257, 1258, 1259, 1260, 1261, 1263, 1266, 1270, 1271, 1272, 1274, 1275, 1279, 1280, 1281, 1282, 1283, 1284, 1288, 1290, 1291, 1292, 1294, 1295, 1296, 1297, 1299, 1300, 1304, 1308, 1309, 1310, 1311, 1312, 1313, 1315, 1316, 1317, 1320, 1321, 1324, 1325, 1326, 1332,
        1333, 1335, 1336, 1337, 1338, 1343, 1345, 1346, 1351, 1352, 1354, 1356, 1358, 1359, 1361, 1362, 1363, 1365, 1366, 1367, 1368, 1369, 1370, 1371, 1372, 1374, 1375, 1376, 1377, 1380, 1381, 1382, 1383, 1384, 1388, 1391, 1392, 1394, 1395, 1396, 1397, 1399, 1400, 1401, 1408, 1410, 1412, 1413, 1414, 1415, 1417, 1418, 1419, 1421, 1422, 1423, 1424, 1425, 1430, 1431, 1432, 1433, 1438, 1441, 1443, 1445, 1446, 1447, 1449, 1451, 1453, 1455, 1456, 1458, 1459, 1464, 1465, 1467, 1469, 1470, 1474, 1475, 1477, 1478, 1479, 1480, 1481, 1482, 1483, 1484, 1488, 1489, 1491, 1492, 1494, 1495, 1497, 1498, 1500, 1501, 1504, 1506, 1507, 1509, 1510, 1511, 1513, 1514, 1515, 1516, 1517, 1518, 1519, 1520, 1521, 1525, 1526, 1527, 1528, 1529, 1530, 1533, 1535, 1540, 1543, 1544, 1545, 1546, 1547, 1548, 1551, 1552, 1554, 1555, 1556, 1560, 1561, 1562, 1563, 1564, 1565, 1566, 1567, 1568, 1569, 1572, 1574, 1575, 1578, 1579, 1580, 1586, 1587, 1589, 1593, 1594, 1596, 1597, 1598, 1599, 1601, 1602, 1603, 1606, 1607, 1608, 1609, 1610, 1611, 1614, 1617, 1618, 1619, 1620, 1622, 1623, 1625, 1626, 1627, 1628, 1629, 1631, 1632, 1633, 1634, 1637, 1638, 1639, 1642, 1643, 1644, 1645, 1646, 1647, 1651, 1653, 1654, 1655, 1656, 1657, 1659, 1661, 1663, 1664, 1665, 1667, 1668, 1669, 1670, 1672, 1675, 1676, 1677, 1679, 1680, 1681, 1684, 1689, 1691, 1692, 1697, 1700, 1701, 1703, 1704, 1707, 1709, 1711, 1713, 1714, 1715, 1716, 1717, 1718, 1721, 1722, 1724, 1725, 1727, 1729, 1730, 1731, 1732, 1734, 1737, 1742, 1743, 1744, 1746, 1747, 1750, 1751, 1752, 1753, 1754, 1756, 1758, 1761, 1764, 1766, 1767, 1769, 1771, 1772, 1774, 1775, 1779, 1781, 1782, 1783, 1785, 1786, 1789, 1791, 1794, 1795, 1796, 1797, 1798, 1799, 1800, 1803, 1804, 1805, 1807, 1809, 1810, 1811, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1823, 1824, 1826, 1827, 1829, 1831, 1833, 1834, 1835, 1837, 1838, 1841, 1843, 1844, 1845, 1851, 1852, 1854, 1856, 1861, 1865, 1866, 1868, 1870, 1871, 1872, 1873, 1874, 1875, 1877, 1878, 1881, 1883, 1884, 1886, 1889, 1890, 1891,
        1894, 1897, 1899, 1900, 1901, 1902, 1904, 1905, 1906, 1909, 1910, 1911, 1912, 1915, 1916, 1917, 1923, 1924, 1925, 1926, 1927, 1929, 1930, 1931, 1933, 1934, 1936, 1937, 1938, 1941, 1944, 1947, 1949, 1951, 1954, 1955, 1956, 1957, 1958, 1960, 1961, 1962, 1965, 1968, 1970, 1971, 1973, 1974, 1976, 1978, 1979, 1980, 1981, 1982, 1984, 1985, 1986, 1987, 1988, 1989, 1991, 1992, 1993, 1994, 1995, 1996, 1997, 1998, 1999, 2001, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2016, 2017, 2018, 2020, 2021, 2023, 2024, 2025, 2027, 2030, 2031, 2034, 2035, 2036, 2037, 2039, 2041, 2042, 2044, 2045, 2046, 2047,
    )

    def __init__(self, config: CustomConfig, layer_idx):
        super().__init__(config, layer_idx)
        self.config = config
        self.layer_idx = layer_idx

        # Swap the stock Qwen2 attention for the project's dual-stream variant.
        if self.config.use_std_like_attn:
            raise ValueError("Low Performance")
        elif self.config.student_attn_from_scratch:
            raise NotImplementedError
        elif self.config.use_all_attn:
            self.self_attn = AllAttn(config, layer_idx)
        else:
            self.self_attn = Attn(config, layer_idx)

        # Likewise for the feed-forward block.
        if self.config.use_in_out_mlp:
            raise NotImplementedError
        else:
            self.mlp = MLP(config, layer_idx)

        # Norms for the compressed stream, sized to target_hidden_size.
        self.target_input_layernorm = DebugLlamaRMSNorm(config.target_hidden_size, eps=config.target_rms_norm_eps)
        self.target_post_attention_layernorm = DebugLlamaRMSNorm(config.target_hidden_size, eps=config.target_rms_norm_eps)

    def forward(
        self,
        hidden_states,
        compressed_hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        cache_position=None,
    ):
        """Run attention + MLP on both the full and the compressed stream.

        Each stream gets its own pre-norm and residual connection. A fresh
        ``loss_dict`` is handed to ``self_attn``; the dict it returns is passed
        on to ``mlp``, and the dict ``mlp`` returns is part of the output.

        Returns:
            A tuple ``(hidden_states, compressed_hidden_states, loss_dict)``.
            ``self_attn_weights`` is appended when ``output_attentions`` is true,
            then ``present_key_value`` when ``use_cache`` is true.
        """
        residual = hidden_states
        compressed_residual = compressed_hidden_states

        hidden_states = self.input_layernorm(hidden_states)
        compressed_hidden_states = self.target_input_layernorm(compressed_hidden_states)

        # Self attention (both streams).
        hidden_states, compressed_hidden_states, self_attn_weights, present_key_value, loss_dict = self.self_attn(
            hidden_states=hidden_states,
            compressed_hidden_states=compressed_hidden_states,
            loss_dict={},
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )

        hidden_states = residual + hidden_states
        compressed_hidden_states = compressed_hidden_states + compressed_residual

        # Fully connected (MLP), both streams.
        residual = hidden_states
        compressed_residual = compressed_hidden_states

        hidden_states = self.post_attention_layernorm(hidden_states)
        compressed_hidden_states = self.target_post_attention_layernorm(compressed_hidden_states)

        hidden_states, compressed_hidden_states, loss_dict = self.mlp(hidden_states, compressed_hidden_states, loss_dict)

        hidden_states = residual + hidden_states
        compressed_hidden_states = compressed_hidden_states + compressed_residual

        outputs = (hidden_states, compressed_hidden_states, loss_dict)

        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)

        return outputs

    def merge_weight(self):
        """Make the base norms use the compressed-stream norm weights.

        Assigning ``.data`` shares the tensors, so afterwards each base norm and its
        ``target_*`` counterpart point at the same storage.
        """
        self.input_layernorm.weight.data = self.target_input_layernorm.weight.data
        self.post_attention_layernorm.weight.data = self.target_post_attention_layernorm.weight.data

    def init_norm(self, important_indices=None):
        """Initialize the compressed-stream norm weights from the full-width norms.

        Copies ``input_layernorm.weight[important_indices]`` into
        ``target_input_layernorm.weight``, and ``post_attention_layernorm.weight``
        likewise into ``target_post_attention_layernorm.weight``.

        Args:
            important_indices: Channel indices into the full hidden dimension to
                keep, in output order. Defaults to ``DEFAULT_IMPORTANT_INDICES``.
                Must contain exactly ``config.target_hidden_size`` entries, each in
                ``[0, input_layernorm.weight.shape[0])``.

        Raises:
            ValueError: if the number of indices differs from
                ``config.target_hidden_size`` or any index is out of range.
        """
        if important_indices is None:
            important_indices = self.DEFAULT_IMPORTANT_INDICES
        # Normalize to plain Python ints (accepts lists, tuples, tensors, numpy arrays).
        important_indices = [int(i) for i in important_indices]

        expected = self.config.target_hidden_size
        if len(important_indices) != expected:
            raise ValueError(
                f"init_norm needs exactly target_hidden_size ({expected}) indices, "
                f"got {len(important_indices)}."
            )

        # Validate up front: an out-of-range index_select on CUDA is a device-side
        # assert that is much harder to diagnose than this error.
        full_size = self.input_layernorm.weight.shape[0]
        out_of_range = [i for i in important_indices if not 0 <= i < full_size]
        if out_of_range:
            raise ValueError(
                f"important_indices must lie in [0, {full_size}); "
                f"got out-of-range values such as {out_of_range[:5]}."
            )

        device = self.input_layernorm.weight.device
        indices_tensor = torch.tensor(important_indices, dtype=torch.long, device=device)

        with torch.no_grad():
            self.target_input_layernorm.weight.data.copy_(
                self.input_layernorm.weight.data.index_select(0, indices_tensor)
            )
            self.target_post_attention_layernorm.weight.data.copy_(
                self.post_attention_layernorm.weight.data.index_select(0, indices_tensor)
            )


@dataclass
class IIModelOutput(ModelOutput):
    """Output container returned by the dual-stream model.

    Alongside the standard base-model fields (``last_hidden_state``,
    ``past_key_values``, ``hidden_states``, ``attentions``), it carries the final
    compressed-stream states and an auxiliary loss.

    NOTE(review): ``ModelOutput`` also exposes fields positionally (``to_tuple()``
    and integer indexing) in declaration order, so do not reorder these fields.
    """

    # Final full-width hidden states (presumably after the final norm -- confirm in Model.forward).
    last_hidden_state: torch.FloatTensor = None
    # Final states of the narrower compressed stream (width presumably config.target_hidden_size -- TODO confirm).
    compressed_hidden_states: torch.FloatTensor = None
    # Auxiliary loss; presumably aggregated from the per-layer loss_dict entries -- verify against Model.forward.
    aux_loss: torch.FloatTensor = None
    # Remaining fields mirror the standard Hugging Face base-model outputs (assumed same meaning).
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class Model(Qwen2Model):
    """Qwen2 backbone that carries a second, narrower ("compressed") hidden stream.

    Token embeddings are projected from ``hidden_size`` to ``target_hidden_size``
    by ``zoom``; the resulting compressed stream is updated by every
    ``CustomLayer`` alongside the original stream. Layers whose index is listed
    in ``config.del_layers`` are plain ``Qwen2DecoderLayer`` instances that only
    update the original stream.
    """

    _no_split_modules = ["CustomLayer"]

    def __init__(self, config: CustomConfig):
        super().__init__(config)

        # Projection hidden_size -> target_hidden_size (applied to the embedding table in forward).
        self.zoom = nn.Linear(config.hidden_size, config.target_hidden_size, bias=False)
        self.layers = nn.ModuleList(
            [
                CustomLayer(config, layer_idx)
                if layer_idx not in config.del_layers
                else Qwen2DecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.target_norm = DebugLlamaRMSNorm(config.target_hidden_size, eps=config.target_rms_norm_eps)
        # Counts forward calls; decides when per-layer losses are pushed to wandb.
        self.cur_step = 0

        self.post_init()

    def merge_weight(self):
        """Fold the compressed-model parameters into the base modules in place.

        ``embed_tokens.weight`` is replaced by its ``zoom`` projection (width
        ``target_hidden_size``) and ``norm.weight`` is rebound to the storage of
        ``target_norm.weight``.
        """
        self.embed_tokens.weight.data = self.zoom(self.embed_tokens.weight.data).contiguous()
        self.norm.weight.data = self.target_norm.weight.data

    def init_norm(self):
        """Initialise the target norms from a fixed subset of the original norm channels.

        ``target_norm.weight`` is set to ``norm.weight[important_indices]``. Every
        ``CustomLayer`` then initialises its own target norms through its
        argument-less ``init_norm()``. Plain ``Qwen2DecoderLayer`` instances are
        skipped.

        Raises:
            ValueError: if the number of hard-coded indices differs from the
                width of ``target_norm``.
        """
        # Channels of the original hidden dimension kept in the compressed model.
        # NOTE(review): presumably selected offline; must contain exactly
        # target_hidden_size entries.
        important_indices = [0, 1, 2, 3, 6, 7, 9, 10, 14, 16, 17, 18, 20, 21, 24, 25, 26, 27, 28, 30, 31, 32, 34, 37, 38, 40, 41, 43, 45, 46, 47, 48, 49, 50, 52, 53, 54, 55, 56, 58, 59, 61, 62, 65, 71, 72, 73, 74, 75, 79, 80, 88, 90, 91, 95, 97, 99, 102, 103, 105, 110, 111, 115, 116, 117, 120, 124, 128, 129, 130, 132, 134, 136, 138, 139, 140, 141, 143, 145, 148, 149, 150, 152, 153, 158, 159, 163, 164, 165, 166, 167, 169, 171, 173, 174, 175, 176, 177, 182, 184, 185, 188, 190, 193, 194, 195, 196, 198, 204, 207, 209, 213, 218, 220, 221, 222, 223, 225, 226, 227, 229, 230, 232, 233, 234, 236, 237, 238, 240, 241, 242, 243, 245, 246, 247, 249, 250, 251, 253, 254, 255, 256, 257, 259, 260, 263, 264, 265, 268, 270, 272, 273, 275, 276, 277, 279, 280, 283, 284, 285, 286, 287, 289, 291, 292, 293, 294, 295, 296, 297, 298, 301, 305, 306, 307, 308, 310, 311, 312, 313, 314, 315, 317, 318, 319, 321, 322, 323, 324, 326, 328, 332, 333, 335, 336, 337, 338, 340, 341, 342, 345, 346, 347, 348, 350, 351, 357, 358, 360, 362, 364, 366, 368, 373, 374, 376, 378, 379, 381, 383, 384, 385, 386, 390, 391, 392, 393, 394, 395, 396, 398, 399, 401, 402, 403, 409, 411, 412, 414, 416, 419, 420, 421, 423, 424, 426, 427, 428, 430, 432, 437, 439, 440, 441, 442, 443, 444, 446, 449, 450, 451, 453, 454, 456, 457, 458, 459, 460, 463, 464, 465, 467, 468, 469, 471, 472, 473, 474, 477, 480, 482, 483, 484, 485, 486, 488, 489, 491, 493, 495, 496, 497, 499, 501, 502, 503, 504, 505, 506, 507, 510, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 525, 526, 527, 528, 529, 530, 531, 532, 534, 535, 536, 537, 545, 550, 552, 553, 554, 556, 558, 559, 560, 561, 562, 564, 567, 568, 569, 571, 572, 573, 575, 578, 581, 582, 583, 585, 588, 589, 590, 591, 594, 595, 598, 601, 605, 607, 608, 612, 613, 614, 615, 616, 619, 620, 621, 622, 627, 629, 634, 635, 636, 641, 642, 643, 645, 647, 649, 651, 653, 657, 658, 660, 663, 664, 665, 666, 667, 668, 669, 670, 671, 673, 676, 677, 678, 680, 681, 682, 684, 685, 687, 688, 689, 690, 693, 694, 
695, 697, 700, 701, 702, 704, 705, 706, 708, 709, 710, 711, 713, 714, 715, 716, 717, 721, 724, 725, 726, 728, 730, 733, 734, 735, 736, 740, 741, 742, 743, 744, 747, 749, 750, 751, 753, 758, 759, 760, 762, 763, 764, 765, 766, 767, 771, 772, 773, 774, 777, 778, 780, 781, 782, 784, 786, 787, 788, 789, 791, 793, 794, 797, 798, 800, 802, 803, 807, 808, 809, 810, 812, 813, 817, 819, 820, 821, 822, 823, 825, 826, 827, 829, 830, 831, 832, 833, 834, 835, 837, 839, 840, 842, 844, 845, 846, 849, 851, 852, 853, 855, 859, 860, 861, 863, 864, 865, 866, 867, 868, 870, 872, 874, 878, 880, 884, 890, 892, 894, 895, 902, 903, 904, 906, 911, 914, 916, 917, 919, 920, 921, 922, 925, 926, 927, 930, 936, 937, 938, 939, 941, 943, 945, 946, 947, 950, 952, 954, 957, 958, 968, 970, 973, 976, 978, 979, 981, 982, 985, 987, 992, 995, 997, 998, 999, 1003, 1004, 1005, 1007, 1008, 1010, 1011, 1020, 1021, 1024, 1025, 1027, 1029, 1030, 1031, 1033, 1034, 1035, 1036, 1038, 1039, 1041, 1045, 1046, 1048, 1050, 1052, 1053, 1054, 1056, 1059, 1062, 1065, 1066, 1069, 1070, 1071, 1073, 1075, 1077, 1078, 1080, 1081, 1083, 1085, 1087, 1088, 1091, 1094, 1097, 1098, 1099, 1100, 1101, 1104, 1105, 1106, 1109, 1110, 1111, 1112, 1114, 1115, 1117, 1118, 1122, 1126, 1127, 1128, 1129, 1130, 1131, 1132, 1135, 1136, 1137, 1138, 1139, 1140, 1141, 1142, 1143, 1149, 1152, 1153, 1154, 1157, 1158, 1159, 1160, 1161, 1162, 1164, 1165, 1166, 1167, 1170, 1171, 1172, 1176, 1177, 1178, 1179, 1180, 1181, 1183, 1184, 1188, 1189, 1191, 1192, 1194, 1195, 1196, 1198, 1199, 1201, 1202, 1205, 1207, 1209, 1210, 1212, 1213, 1214, 1215, 1216, 1217, 1218, 1219, 1221, 1223, 1224, 1226, 1228, 1229, 1230, 1232, 1234, 1238, 1239, 1240, 1241, 1245, 1247, 1248, 1249, 1251, 1252, 1254, 1255, 1257, 1258, 1259, 1260, 1261, 1263, 1266, 1270, 1271, 1272, 1274, 1275, 1279, 1280, 1281, 1282, 1283, 1284, 1288, 1290, 1291, 1292, 1294, 1295, 1296, 1297, 1299, 1300, 1304, 1308, 1309, 1310, 1311, 1312, 1313, 1315, 1316, 1317, 1320, 1321, 1324, 1325, 1326, 1332, 
1333, 1335, 1336, 1337, 1338, 1343, 1345, 1346, 1351, 1352, 1354, 1356, 1358, 1359, 1361, 1362, 1363, 1365, 1366, 1367, 1368, 1369, 1370, 1371, 1372, 1374, 1375, 1376, 1377, 1380, 1381, 1382, 1383, 1384, 1388, 1391, 1392, 1394, 1395, 1396, 1397, 1399, 1400, 1401, 1408, 1410, 1412, 1413, 1414, 1415, 1417, 1418, 1419, 1421, 1422, 1423, 1424, 1425, 1430, 1431, 1432, 1433, 1438, 1441, 1443, 1445, 1446, 1447, 1449, 1451, 1453, 1455, 1456, 1458, 1459, 1464, 1465, 1467, 1469, 1470, 1474, 1475, 1477, 1478, 1479, 1480, 1481, 1482, 1483, 1484, 1488, 1489, 1491, 1492, 1494, 1495, 1497, 1498, 1500, 1501, 1504, 1506, 1507, 1509, 1510, 1511, 1513, 1514, 1515, 1516, 1517, 1518, 1519, 1520, 1521, 1525, 1526, 1527, 1528, 1529, 1530, 1533, 1535, 1540, 1543, 1544, 1545, 1546, 1547, 1548, 1551, 1552, 1554, 1555, 1556, 1560, 1561, 1562, 1563, 1564, 1565, 1566, 1567, 1568, 1569, 1572, 1574, 1575, 1578, 1579, 1580, 1586, 1587, 1589, 1593, 1594, 1596, 1597, 1598, 1599, 1601, 1602, 1603, 1606, 1607, 1608, 1609, 1610, 1611, 1614, 1617, 1618, 1619, 1620, 1622, 1623, 1625, 1626, 1627, 1628, 1629, 1631, 1632, 1633, 1634, 1637, 1638, 1639, 1642, 1643, 1644, 1645, 1646, 1647, 1651, 1653, 1654, 1655, 1656, 1657, 1659, 1661, 1663, 1664, 1665, 1667, 1668, 1669, 1670, 1672, 1675, 1676, 1677, 1679, 1680, 1681, 1684, 1689, 1691, 1692, 1697, 1700, 1701, 1703, 1704, 1707, 1709, 1711, 1713, 1714, 1715, 1716, 1717, 1718, 1721, 1722, 1724, 1725, 1727, 1729, 1730, 1731, 1732, 1734, 1737, 1742, 1743, 1744, 1746, 1747, 1750, 1751, 1752, 1753, 1754, 1756, 1758, 1761, 1764, 1766, 1767, 1769, 1771, 1772, 1774, 1775, 1779, 1781, 1782, 1783, 1785, 1786, 1789, 1791, 1794, 1795, 1796, 1797, 1798, 1799, 1800, 1803, 1804, 1805, 1807, 1809, 1810, 1811, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1823, 1824, 1826, 1827, 1829, 1831, 1833, 1834, 1835, 1837, 1838, 1841, 1843, 1844, 1845, 1851, 1852, 1854, 1856, 1861, 1865, 1866, 1868, 1870, 1871, 1872, 1873, 1874, 1875, 1877, 1878, 1881, 1883, 1884, 1886, 1889, 1890, 1891, 
1894, 1897, 1899, 1900, 1901, 1902, 1904, 1905, 1906, 1909, 1910, 1911, 1912, 1915, 1916, 1917, 1923, 1924, 1925, 1926, 1927, 1929, 1930, 1931, 1933, 1934, 1936, 1937, 1938, 1941, 1944, 1947, 1949, 1951, 1954, 1955, 1956, 1957, 1958, 1960, 1961, 1962, 1965, 1968, 1970, 1971, 1973, 1974, 1976, 1978, 1979, 1980, 1981, 1982, 1984, 1985, 1986, 1987, 1988, 1989, 1991, 1992, 1993, 1994, 1995, 1996, 1997, 1998, 1999, 2001, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2016, 2017, 2018, 2020, 2021, 2023, 2024, 2025, 2027, 2030, 2031, 2034, 2035, 2036, 2037, 2039, 2041, 2042, 2044, 2045, 2046, 2047]

        target_dim = self.target_norm.weight.shape[0]
        if len(important_indices) != target_dim:
            raise ValueError(
                f"dim not match: {len(important_indices)} indices for a target_norm of size {target_dim}"
            )

        indices_tensor = torch.tensor(important_indices, dtype=torch.long, device=self.norm.weight.device)
        with torch.no_grad():
            selected_weights = self.norm.weight.data.index_select(0, indices_tensor)
            self.target_norm.weight.data.copy_(selected_weights)

        for layer in self.layers:
            # Deleted layers are plain Qwen2DecoderLayer instances without target norms.
            if isinstance(layer, CustomLayer):
                layer.init_norm()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        """Run the original and the compressed streams through all decoder layers.

        Args:
            input_ids: Token ids. Required, because the compressed stream is built
                by looking them up in the ``zoom``-projected embedding table.
            attention_mask: Must be ``None``; no causal mask is built here.
            Remaining arguments follow ``Qwen2Model.forward``.

        Returns:
            ``IIModelOutput`` with the normed original stream, the normed
            compressed stream and the accumulated auxiliary loss (the int ``0``
            when no layer loss is added).

            - With ``return_dict=False``, the same values are returned as a tuple
              in ``IIModelOutput`` field order, with ``None`` entries dropped.
            - ``past_key_values`` is always ``None``.
            - With ``output_hidden_states``, the input of every layer plus the
              final normed state is returned.
            - With ``output_attentions``, ``attentions`` is an empty tuple:
              attention weights are not collected.

        Raises:
            ValueError: if not exactly one of ``input_ids``/``inputs_embeds`` is
                given, or if ``input_ids`` is missing.
            NotImplementedError: under gradient checkpointing in training mode.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )
        if input_ids is None:
            raise ValueError("input_ids are required to build the compressed hidden states")

        if self.gradient_checkpointing and self.training:
            raise NotImplementedError

        inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # No causal mask is built; every layer receives attention_mask=None.
        # NOTE(review): causality therefore depends on the attention implementation
        # applying it implicitly — confirm for the configured attn_implementation.
        assert attention_mask is None

        hidden_states = inputs_embeds
        # The projected table is recomputed on every call from the current zoom weights.
        Wemb = self.zoom(self.embed_tokens.weight).to(device=input_ids.device)
        if os.environ.get("DEBUG", False):
            print("emb token", Wemb[0, :6])
        compressed_hidden_states = embedding(input_ids, Wemb)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        aux_loss = 0

        # Per-layer losses are pushed to wandb every 20 optimizer steps.
        grad_accumulation_steps = hyper_params["gradient_accumulation_steps"]
        cur_train_step = None
        if (self.cur_step + 1) % (grad_accumulation_steps * 20) == 0:
            cur_train_step = (self.cur_step + 1) // grad_accumulation_steps
        self.cur_step += 1
        # os.environ values are strings; a missing LOCAL_RANK means a single-process run.
        is_logging_process = "LOCAL_RANK" not in os.environ or accelerator.is_main_process

        for layer_idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if layer_idx in self.config.del_layers:
                # Plain Qwen2 layer: only the original stream is updated.
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=None,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
                hidden_states = layer_outputs[0]
                continue

            layer_outputs = decoder_layer(
                hidden_states,
                compressed_hidden_states,
                attention_mask=None,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )
            hidden_states = layer_outputs[0]
            compressed_hidden_states = layer_outputs[1]
            loss_dict = layer_outputs[2]

            log_dict = {}
            for loss_name, loss_value in loss_dict.items():
                if self.config.use_aux_loss:
                    if isinstance(aux_loss, torch.Tensor):
                        # Layers may live on different devices under model parallelism.
                        aux_loss = aux_loss.to(loss_value.device)
                    aux_loss = aux_loss + loss_value * hyper_params["aux_loss_scale_factor"]
                main_logger.debug(f"L{decoder_layer.layer_idx}-{loss_name}: {loss_value.item()}")

                if cur_train_step is not None:
                    log_dict[f"L{decoder_layer.layer_idx}-{loss_name}"] = loss_value.item()

            if cur_train_step is not None and is_logging_process and log_dict:
                wandb.log(log_dict, cur_train_step)

        hidden_states = self.norm(hidden_states)
        compressed_hidden_states = self.target_norm(compressed_hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            # Same order as IIModelOutput (past_key_values is always None here).
            return tuple(
                v
                for v in [hidden_states, compressed_hidden_states, aux_loss, None, all_hidden_states, all_self_attns]
                if v is not None
            )

        return IIModelOutput(
            last_hidden_state=hidden_states,
            compressed_hidden_states=compressed_hidden_states,
            aux_loss=aux_loss,
            past_key_values=None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
    


def calculate_language_loss(lgts, labels, vocab_size):
    """Return the mean next-token cross-entropy of ``lgts`` against ``labels``.

    Position ``t`` of ``lgts`` is scored against token ``t + 1`` of ``labels``.
    Labels equal to ``CrossEntropyLoss``'s ``ignore_index`` (-100) are skipped.
    Returns ``None`` when ``labels`` is ``None``.
    """
    if labels is None:
        return None

    # Drop the last prediction and the first label so that tokens < n predict n.
    flat_logits = lgts[..., :-1, :].contiguous().view(-1, vocab_size)
    flat_targets = labels[..., 1:].contiguous().view(-1)
    # Labels may sit on another device under model parallelism.
    flat_targets = flat_targets.to(flat_logits.device)
    return CrossEntropyLoss()(flat_logits, flat_targets)


class CoTrainLM(Qwen2ForCausalLM):
    """Causal LM that co-trains a compressed (``target_hidden_size``) copy of Qwen2.

    How the two streams are decoded:

    - The original stream is decoded with ``lm_head``.
    - The compressed stream is decoded with ``lm_head.weight`` projected to the
      target width. The projection is either a dedicated ``zoom_down`` layer or,
      when ``config.tie_word_emb_proj`` is set, the backbone's ``zoom``.

    Averaged losses are periodically logged to wandb.
    """

    _tied_weights_keys = ["lm_head.weight"]
    # Number of data classes tracked when ``config.check_data_cls_loss`` is on.
    _NUM_DATA_CLASSES = 8

    def __init__(self, config: CustomConfig):
        super().__init__(config)
        self.model = Model(config)
        if not config.tie_word_emb_proj:
            # Dedicated projection of the LM-head weights to the target width.
            self.zoom_down = nn.Linear(config.hidden_size, config.target_hidden_size, bias=False)
            self.zoom_down.weight.data.normal_(mean=0.0, std=0.01)  # no init weights
        # Kept for an MSE-style logits loss; forward currently uses a KL loss instead.
        self.mseloss = LOSS_DICT[config.aux_loss_type]()
        self.kl_temperature = self.config.kl_temperature
        # Bookkeeping for periodic wandb logging (reset after every log).
        self.cur_step = 0
        self.cur_loss_accumulation = 0
        self.cur_logit_loss_accumulation = 0
        self.check_data_cls_loss = config.check_data_cls_loss
        self.data_cls_losses = [0] * self._NUM_DATA_CLASSES
        self.data_cls_cnt = [0] * self._NUM_DATA_CLASSES
        self.post_init()

    def merge_weight(self):
        """Replace ``lm_head.weight`` in place by its projection to the target width."""
        if not self.config.tie_word_emb_proj:
            self.lm_head.weight.data = self.zoom_down(self.lm_head.weight.data).contiguous()
        else:
            self.lm_head.weight.data = self.model.zoom(self.lm_head.weight.data).contiguous()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        data_cls=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        """Decode both streams and build the training loss.

        Args:
            data_cls: Per-sample data-class ids. Only read when
                ``check_data_cls_loss`` is enabled, which requires batch size 1.
            labels: Token ids for the next-token loss. When ``None``, no
                language-model loss is computed.
            Remaining arguments are forwarded to ``Model.forward``.

        Returns:
            ``CausalLMOutputWithPast`` whose ``logits`` are the compressed-model
            logits. ``loss`` is ``target_loss + aux_loss``; it is ``aux_loss``
            alone when ``config.use_ntp_loss`` is off or no labels are given.

        Raises:
            NotImplementedError: if ``return_dict`` resolves to false.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if not return_dict:
            # Only the dict-style output is implemented; fail before any work or logging.
            raise NotImplementedError

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = outputs.last_hidden_state
        compressed_hidden_states = outputs.compressed_hidden_states
        aux_loss = outputs.aux_loss

        logits = self.lm_head(hidden_states)
        head_proj = self.model.zoom if self.config.tie_word_emb_proj else self.zoom_down
        Whead = head_proj(self.lm_head.weight)
        if os.environ.get("DEBUG", False):
            print("head weight", Whead[0, :6])
        target_logits = linear(compressed_hidden_states, Whead)

        if self.config.use_logits_loss:
            # Temperature-scaled KL(p_raw || p_target) between next-token distributions.
            target_logp = F.log_softmax(target_logits / self.kl_temperature, dim=-1)
            raw_logp = F.log_softmax(logits / self.kl_temperature, dim=-1)
            logits_loss = F.kl_div(target_logp, raw_logp, log_target=True, reduction="batchmean")
            aux_loss = aux_loss + logits_loss
            main_logger.debug(f"logits_loss: {round(logits_loss.item(), 4)}")
            self.cur_logit_loss_accumulation += logits_loss.item()

        target_loss = None
        if labels is not None:
            # The original model's loss is only logged, so no autograd graph is kept for it.
            with torch.no_grad():
                raw_loss = calculate_language_loss(logits.float(), labels, self.config.vocab_size)
            target_loss = calculate_language_loss(target_logits.float(), labels, self.config.vocab_size)
            main_logger.debug(f"raw_loss: {round(raw_loss.item(), 4)}, target_loss: {round(target_loss.item(), 4)}")

            self.cur_loss_accumulation += target_loss.item()
            if self.check_data_cls_loss:
                assert hidden_states.shape[0] == 1, "only appliable in bs = 1"
                spec_cls = data_cls[0].item()
                self.data_cls_cnt[spec_cls] += 1
                self.data_cls_losses[spec_cls] += target_loss.item()

        # NOTE(review): nothing checks self.training, so evaluation forward passes
        # also advance these counters and feed the logged averages — confirm intent.
        self._log_losses_periodically()
        self.cur_step += 1

        if self.config.use_ntp_loss and target_loss is not None:
            loss = target_loss + aux_loss
        else:
            loss = aux_loss

        return CausalLMOutputWithPast(
            loss=loss,
            logits=target_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _log_losses_periodically(self):
        """Push averaged losses to wandb and reset the accumulators.

        Runs every ``gradient_accumulation_steps * 5`` forward calls.
        """
        grad_accum_steps = hyper_params["gradient_accumulation_steps"]
        loss_log_steps = grad_accum_steps * 5
        if (self.cur_step + 1) % loss_log_steps != 0:
            return

        cur_train_step = (self.cur_step + 1) // grad_accum_steps
        log_dict = {"target_loss": self.cur_loss_accumulation / loss_log_steps}
        if self.config.use_logits_loss:
            log_dict["logits_loss"] = self.cur_logit_loss_accumulation / loss_log_steps
        if self.check_data_cls_loss:
            log_dict.update({
                f"{data_cls_reversed_dict[i]}_loss": total / self.data_cls_cnt[i]
                for i, total in enumerate(self.data_cls_losses)
                if self.data_cls_cnt[i] > 0
            })
        # os.environ values are strings; a missing LOCAL_RANK means a single-process run.
        if "LOCAL_RANK" not in os.environ or accelerator.is_main_process:
            wandb.log(log_dict, step=cur_train_step)

        self.cur_loss_accumulation = 0
        self.cur_logit_loss_accumulation = 0
        self.data_cls_cnt = [0] * self._NUM_DATA_CLASSES
        self.data_cls_losses = [0] * self._NUM_DATA_CLASSES

    def freeze_original_model(self):
        """Make only the compressed-model parameters trainable.

        A parameter stays trainable iff its name contains "zoom" or "target";
        every other parameter gets ``requires_grad = False``.
        """
        key_words = ("zoom", "target")
        for name, param in self.named_parameters():
            param.requires_grad = any(key in name for key in key_words)

    def tie_custom_weights(self, tie_n):
        """Disabled: always raises ``ValueError("low perf")``.

        The code after the ``raise`` is unreachable. It records the former
        behaviour: sharing the zoom projections across groups of ``tie_n`` layers.
        """
        raise ValueError("low perf")
        layers = self.model.layers
        for i in range(2, self.config.num_hidden_layers - 1, tie_n):
            share_layer: CustomLayer = layers[i]
            for j in range(i + 1, min(self.config.num_hidden_layers - 1, i + tie_n)):
                cur_layer: CustomLayer = layers[j]
                cur_layer.mlp.zoom.weight = share_layer.mlp.zoom.weight
                cur_layer.self_attn.zoom_up.weight = share_layer.self_attn.zoom_up.weight
                cur_layer.self_attn.zoom_down.weight = share_layer.self_attn.zoom_down.weight

    def tie_word_emb_proj(self):
        """Make ``zoom_down`` share its weight with the backbone's ``zoom``.

        ``zoom_down`` only exists when ``config.tie_word_emb_proj`` is false.
        Otherwise this raises ``AttributeError``.
        """
        self.zoom_down.weight = self.model.zoom.weight

    def get_trained_params(self):
        """Return ``{name: parameter}`` for every parameter with ``requires_grad`` set."""
        return {name: param for name, param in self.named_parameters() if param.requires_grad}

    def save_pretrained(self, *args, **kwargs):
        """Save the model; by default only parameters with ``requires_grad`` are written.

        ``only_save_trainable`` (default ``True``) is consumed here rather than
        forwarded to the parent method. When true, any ``state_dict`` argument
        is replaced by the trainable parameters.
        """
        if kwargs.pop("only_save_trainable", True):
            kwargs["state_dict"] = self.get_trained_params()
        return super().save_pretrained(*args, **kwargs)
