import json
import torch
import torch.utils.data as Data
from torch import nn, optim
import numpy as np
import shutil
from model import *
from utils import *
from data import *
from torch.optim.lr_scheduler import CosineAnnealingLR
from warmup_scheduler import GradualWarmupScheduler 
from collections import defaultdict
import os 
from draw_condense_heatmap import plot_weight_heatmap_eigen, weighted_singular_vector_cosine

def get_condense_condition1_modified(args, model, device, layer):
    """Count indices where two weight-derived vectors of decoder layer ``layer`` agree in sign.

    All weights are read as NumPy arrays and transposed (``weight.T``). With
    ``Wv, W1, W2, Wp`` the transposed ``W_V``, ``fc[0]``, ``fc[2]`` and
    ``projection`` weights:

    * ``w_a = unit(Wv[:, 0]) @ Wv``
    * ``tmp = (W1 @ W2 @ Wp)[:, 0]``

    Returns the number of positions where ``w_a * tmp > 0`` as a Python int.
    ``args`` and ``device`` are accepted for interface compatibility and are unused.
    """
    decoder_layer = model.decoder.layers[layer]
    with torch.no_grad():
        wv, w1, w2, wp = (
            module.weight.data.cpu().numpy().T
            for module in (
                decoder_layer.dec_self_attn.W_V,
                decoder_layer.pos_ffn.fc[0],
                decoder_layer.pos_ffn.fc[2],
                model.projection,
            )
        )

    first_column = wv[:, 0]
    w_a = (first_column / np.linalg.norm(first_column)) @ wv

    ffn_then_proj = (w1 @ w2) @ wp
    tmp = ffn_then_proj[:, 0]

    return int(np.count_nonzero(w_a * tmp > 0))

def get_condense_condition2_modified(args, model, device, layer):
    """Count rows of ``diag(w_a) @ (W1 @ W2)`` that point the same way as its first row.

    Weights of decoder layer ``layer`` are read as NumPy arrays and transposed
    (``weight.T``). With ``Wv, W1, W2`` the transposed ``W_V``, ``fc[0]`` and
    ``fc[2]`` weights and ``w_a = unit(Wv[:, 0]) @ Wv``, each row ``r`` of
    ``diag(w_a) @ (W1 @ W2)`` is counted when ``dot(row_0, r) > 0`` (row 0
    itself included when non-zero).

    The projection weight is read as well, though the result does not use it.
    ``args`` and ``device`` are accepted for interface compatibility and are unused.
    Returns a Python int.
    """
    decoder_layer = model.decoder.layers[layer]
    with torch.no_grad():
        wv, w1, w2, _unused_proj = (
            module.weight.data.cpu().numpy().T
            for module in (
                decoder_layer.dec_self_attn.W_V,
                decoder_layer.pos_ffn.fc[0],
                decoder_layer.pos_ffn.fc[2],
                model.projection,
            )
        )

    ffn_product = w1 @ w2

    first_column = wv[:, 0]
    w_a = (first_column / np.linalg.norm(first_column)) @ wv
    scaled_rows = np.diag(w_a) @ ffn_product

    reference_row = scaled_rows[0, :]
    return sum(1 for row in scaled_rows if np.dot(reference_row, row) > 0)

def get_smoothed_rank(tensor):
    """Return ``nuclear_norm(tensor) / tensor.norm(2)`` as a Python float.

    NOTE(review): with no ``dim`` argument, ``Tensor.norm(2)`` is the 2-norm of the
    flattened tensor, i.e. the Frobenius norm for a matrix. It is NOT the spectral
    norm that ``numpy.linalg.norm(A, 2)`` returns. If the spectral norm was intended,
    use ``torch.linalg.matrix_norm(tensor, 2)`` instead. Confirm which is wanted.

    An all-zero tensor raises ZeroDivisionError.
    """
    nuclear = tensor.norm('nuc').item()
    flattened_l2 = tensor.norm(2).item()
    return nuclear / flattened_l2

# def train_step(args, model, train_dataloader, optimizer, criterion, device, clip=1, scheduler=None, current_epoch=0, logger=None):
#     model.train()
#     epoch_loss = 0
#     total_samples = 0
#     loss_list = []
#     ## 新增
#     grad_norms = defaultdict(list)
#     param_norms = defaultdict(list)
#     smoothed_ranks = defaultdict(list)

#     condense_cond_num1 = {f'{layer}': [] for layer in range(args.n_layers)}
#     condense_cond_num2 = {f'{layer}': [] for layer in range(args.n_layers)}
#     direction_cos_similarity = defaultdict(list)
#     module_names = ['W_V.weight', 'W_Q.weight', 'W_K.weight', 'fc.0.weight', 'fc.2.weight']
    
    
#     for i, dec_inputs in enumerate(train_dataloader):  
#         with torch.no_grad():
#             # 计算condense的条件
#             for layer in range(args.n_layers):
#                 # if layer not in condense_cond_num1:
#                 #     condense_cond_num1[layer] = []
#                 # if layer not in condense_cond_num2:
#                 #     condense_cond_num2[layer] = []
                
#                 condense_cond_num1[f'{layer}'].append(get_condense_condition1_modified(args, model, device, layer))
#                 condense_cond_num2[f'{layer}'].append(get_condense_condition2_modified(args, model, device, layer))
#             if (i * 4) % len(train_dataloader) == 0 or (current_epoch == 0 and i % 100 == 0):
#                 debug_dir = os.path.join(args.working_dir, 'debug_checkpoints')
#                 os.makedirs(debug_dir, exist_ok=True)
#                 torch.save(model.state_dict(), os.path.join(debug_dir, f'model_{current_epoch}_{i}.pt'))

#                 os.makedirs(args.working_dir + '/condense_heatmap', exist_ok=True)
#                 for name, module in model.named_modules():
#                     if isinstance(module, nn.Linear) and ('proj' not in name):
#                         weight = module.weight.data.cpu().numpy()
#                         # cos_sim_matrix = cosine_similarity_array(weight)
#                         plot_weight_heatmap_eigen(weight, args.working_dir + f'/condense_heatmap/epoch{current_epoch}_{i}_{name}.png')
#         # print(dec_inputs)
#         # print(type(dec_inputs),)
#         if isinstance(dec_inputs, list):
#             dec_inputs, dec_outputs = dec_inputs
#             seq_len = args.seq_len
#             optimizer.zero_grad()
#             dec_inputs, dec_outputs = dec_inputs.to(device), dec_outputs.to(device)
#             outputs, _ = model(dec_inputs)

#             batch_size = dec_inputs.size(0)  # 获取当前批次的实际大小
#             total_samples += batch_size
            
#             loss = criterion(outputs.view(batch_size, seq_len, args.vocab_size)[:,-1,:], dec_outputs[:,-1].view(-1))
#         else:
#             dec_outputs = dec_inputs[:, 1:]
#             dec_inputs = dec_inputs[:, :-1]
#             seq_len = args.seq_len - 1
#             optimizer.zero_grad()
#             dec_inputs, dec_outputs = dec_inputs.to(device), dec_outputs.to(device)
#             outputs, _ = model(dec_inputs)

#             batch_size = dec_inputs.size(0)  # 获取当前批次的实际大小
#             total_samples += batch_size
            
#             loss = criterion(outputs.view(-1, args.vocab_size), dec_outputs.view(-1))

#         loss_list.append(loss.item())
#         epoch_loss += loss.item() * batch_size  # 将损失乘以批次大小
#         loss.backward()
        

#         # 梯度
#         for name, param in model.named_parameters():
#             if param.grad is not None:
#                 # 计算L2范数
#                 grad_norm = param.grad.norm(2).item()
#                 grad_norms[name].append(grad_norm)

#         # 范数
#         for name, param in model.named_parameters():
#             param_norm = param.data.norm(2).item()
#             param_norms[name].append(param_norm)

#         ## 算起来太慢了
#         if i % 100 == 0:
#             for name, param in model.named_parameters():
#                 for module_name in module_names:
#                     if module_name in name:
#                         smoothed_rank = get_smoothed_rank(param.data.cpu())
#                         smoothed_ranks[name].append(smoothed_rank)
#             # 计算方向变化
#             old_params = {name: param.data.cpu().numpy() for name, param in model.named_parameters()}
#         # torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
#         optimizer.step()
    
#         if scheduler is not None:
#             scheduler.step()

#         if i % 100 == 0:
#             if logger is not None:
#                 logger.info(f'Epoch: {current_epoch:<5}  Batch: {i:<5}  Loss: {loss.item():.4e}')
#             else:
#                 print(f'Epoch: {current_epoch:<5}  Batch: {i:<5}  Loss: {loss.item():.4e}')
#             # 计算方向变化
#             for name, param in model.named_parameters():
#                 for module_name in module_names:
#                     if module_name in name:
#                         if name in old_params:
#                             cos_sim, _, _ = weighted_singular_vector_cosine(old_params[name], param.data.cpu().numpy())
#                             direction_cos_similarity[name].append(cos_sim)
                            
                            
#     os.makedirs(f'{args.working_dir}/loss', exist_ok=True)
#     os.makedirs(f'{args.working_dir}/condense_num_step', exist_ok=True)
#     np.save(f'{args.working_dir}/loss/epoch_loss_{current_epoch}.npy', np.array(loss_list))
    
#     np.savez(f'{args.working_dir}/condense_num_step/condense_cond_num1_{current_epoch}.npz', **{k: np.array(v) for k, v in condense_cond_num1.items()})
#     np.savez(f'{args.working_dir}/condense_num_step/condense_cond_num2_{current_epoch}.npz', **{k: np.array(v) for k, v in condense_cond_num2.items()})
#     np.savez(f'{args.working_dir}/grad_norms_{current_epoch}.npz', **{k: np.array(v) for k, v in grad_norms.items()})
#     np.savez(f'{args.working_dir}/param_norms_{current_epoch}.npz', **{k: np.array(v) for k, v in param_norms.items()})
#     np.savez(f'{args.working_dir}/smoothed_ranks_{current_epoch}.npz', **{k: np.array(v) for k, v in smoothed_ranks.items()})
#     np.savez(f'{args.working_dir}/direction_cos_similarity_{current_epoch}.npz', **{k: np.array(v) for k, v in direction_cos_similarity.items()})


#     return epoch_loss / total_samples, grad_norms, param_norms, condense_cond_num1, condense_cond_num2, loss_list, smoothed_ranks  # 返回平均损失

# Module-name suffixes whose [W | b] matrices are tracked for smoothed rank and
# step-direction similarity. Matching is per module, so every decoder layer gets
# its own entry instead of all layers collapsing onto one key.
_TRACKED_MODULE_SUFFIXES = ('W_V', 'W_Q', 'W_K', 'fc.0', 'fc.2')


def _is_tracked_module(name, suffixes):
    """Return True if module path ``name`` equals, or ends with ``.<suffix>`` for, some suffix."""
    return any(name == sfx or name.endswith('.' + sfx) for sfx in suffixes)


def _snapshot_tracked_matrices(model, suffixes):
    """Return ``{module_name: [W | b]}`` as NumPy copies for every tracked module.

    ``W`` is the module's ``weight`` (``[out, in]``). When the module has a bias,
    it is appended as an extra column, giving ``[out, in + 1]``. Weight and bias
    always come from the same module.

    The arrays are copies on purpose. For a model already on the CPU,
    ``.cpu().numpy()`` aliases the parameter storage, and torch optimizers
    update parameters in place. A pre-step snapshot would therefore silently
    become the post-step value.
    """
    snapshot = {}
    with torch.no_grad():
        for name, module in model.named_modules():
            if not _is_tracked_module(name, suffixes):
                continue
            weight = getattr(module, 'weight', None)
            if not isinstance(weight, torch.Tensor):
                continue
            matrix = weight.detach().cpu().clone()
            bias = getattr(module, 'bias', None)
            if isinstance(bias, torch.Tensor):
                matrix = torch.cat([matrix, bias.detach().cpu().view(-1, 1)], dim=1)
            snapshot[name] = matrix.numpy()
    return snapshot


def _forward_batch_loss(args, model, batch, criterion, device):
    """Run one forward pass on ``batch`` and return ``(loss, batch_size)``.

    A ``list`` batch is ``(inputs, targets)`` and only the last position is scored.
    Any other batch is a token tensor used for next-token prediction: inputs are
    ``batch[:, :-1]`` and targets are ``batch[:, 1:]``.
    """
    if isinstance(batch, list):
        dec_inputs, dec_outputs = batch
        dec_inputs, dec_outputs = dec_inputs.to(device), dec_outputs.to(device)
        outputs, _ = model(dec_inputs)
        batch_size = dec_inputs.size(0)
        loss = criterion(outputs.view(batch_size, args.seq_len, args.vocab_size)[:, -1, :], dec_outputs[:, -1].view(-1))
    else:
        dec_inputs, dec_outputs = batch[:, :-1].to(device), batch[:, 1:].to(device)
        outputs, _ = model(dec_inputs)
        batch_size = dec_inputs.size(0)
        loss = criterion(outputs.view(-1, args.vocab_size), dec_outputs.view(-1))
    return loss, batch_size


def _save_debug_snapshot(args, model, current_epoch, step):
    """Save a state_dict checkpoint and a ``[W | b]`` heatmap for every non-projection nn.Linear."""
    debug_dir = os.path.join(args.working_dir, 'debug_checkpoints')
    os.makedirs(debug_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(debug_dir, f'model_{current_epoch}_{step}.pt'))

    os.makedirs(args.working_dir + '/condense_heatmap', exist_ok=True)
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and ('proj' not in name):
            weight = module.weight.data.cpu().numpy()
            if module.bias is not None:
                weight = np.concatenate([weight, module.bias.data.cpu().numpy()[:, None]], axis=1)
            plot_weight_heatmap_eigen(weight, args.working_dir + f'/condense_heatmap/epoch{current_epoch}_{step}_{name}.png')


def train_step(args, model, train_dataloader, optimizer, criterion, device, clip=1, scheduler=None, current_epoch=0, logger=None):
    """Train ``model`` for one epoch while recording diagnostics, then save them to disk.

    Per step it records:
      * both condense-condition counts for every decoder layer (before the update);
      * the loss;
      * the L2 norm of each parameter's gradient and value.

    Every 100 steps it records, for each tracked module (see
    ``_TRACKED_MODULE_SUFFIXES``, keyed by full module name so layers stay
    separate):
      * the smoothed rank of the pre-update ``[W | b]``;
      * the ``weighted_singular_vector_cosine`` between the pre- and post-update ``[W | b]``.

    A checkpoint and heatmaps are saved about every quarter epoch, and every 100
    steps during epoch 0.

    ``clip`` is currently unused: gradient clipping is disabled in this loop.

    Returns:
        (mean loss per sample, grad_norms, param_norms, condense_cond_num1,
         condense_cond_num2, loss_list, smoothed_ranks). ``direction_cos_similarity``
        is only written to disk.
    """
    model.train()
    epoch_loss = 0
    total_samples = 0
    loss_list = []
    grad_norms = defaultdict(list)
    param_norms = defaultdict(list)
    smoothed_ranks = defaultdict(list)
    condense_cond_num1 = {f'{layer}': [] for layer in range(args.n_layers)}
    condense_cond_num2 = {f'{layer}': [] for layer in range(args.n_layers)}
    direction_cos_similarity = defaultdict(list)

    # Stride for ~4 debug snapshots per epoch. The old `(i * 4) % len == 0`
    # check fired only at i == 0 whenever the loader length was odd.
    quarter_epoch = max(1, len(train_dataloader) // 4)
    old_params = {}

    for i, batch in enumerate(train_dataloader):
        with torch.no_grad():
            # Condense conditions are sampled before every update, for every layer.
            for layer in range(args.n_layers):
                condense_cond_num1[f'{layer}'].append(get_condense_condition1_modified(args, model, device, layer))
                condense_cond_num2[f'{layer}'].append(get_condense_condition2_modified(args, model, device, layer))
            if i % quarter_epoch == 0 or (current_epoch == 0 and i % 100 == 0):
                _save_debug_snapshot(args, model, current_epoch, i)

        optimizer.zero_grad()
        loss, batch_size = _forward_batch_loss(args, model, batch, criterion, device)
        total_samples += batch_size
        current_loss = loss.item()
        loss_list.append(current_loss)
        epoch_loss += current_loss * batch_size  # weight by batch size -> per-sample mean
        loss.backward()

        # Gradient and parameter L2 norms, after backward and before the update.
        for name, param in model.named_parameters():
            if param.grad is not None:
                grad_norms[name].append(param.grad.norm(2).item())
            param_norms[name].append(param.data.norm(2).item())

        is_probe_step = i % 100 == 0  # these metrics are expensive, so sample sparsely
        if is_probe_step:
            # The pre-update snapshot feeds the smoothed rank now and the direction change after the step.
            old_params = _snapshot_tracked_matrices(model, _TRACKED_MODULE_SUFFIXES)
            for name, matrix in old_params.items():
                smoothed_ranks[name].append(float(get_smoothed_rank(torch.from_numpy(matrix))))

        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        if is_probe_step:
            message = f'Epoch: {current_epoch:<5}  Batch: {i:<5}  Loss: {loss.item():.4e}'
            if logger is not None:
                logger.info(message)
            else:
                print(message)
            new_params = _snapshot_tracked_matrices(model, _TRACKED_MODULE_SUFFIXES)
            for name, new_matrix in new_params.items():
                if name in old_params:
                    cos_sim, _, _ = weighted_singular_vector_cosine(old_params[name], new_matrix)
                    direction_cos_similarity[name].append(cos_sim)

    os.makedirs(f'{args.working_dir}/loss', exist_ok=True)
    os.makedirs(f'{args.working_dir}/condense_num_step', exist_ok=True)
    np.save(f'{args.working_dir}/loss/epoch_loss_{current_epoch}.npy', np.array(loss_list))

    np.savez(f'{args.working_dir}/condense_num_step/condense_cond_num1_{current_epoch}.npz', **{k: np.array(v) for k, v in condense_cond_num1.items()})
    np.savez(f'{args.working_dir}/condense_num_step/condense_cond_num2_{current_epoch}.npz', **{k: np.array(v) for k, v in condense_cond_num2.items()})
    np.savez(f'{args.working_dir}/grad_norms_{current_epoch}.npz', **{k: np.array(v) for k, v in grad_norms.items()})
    np.savez(f'{args.working_dir}/param_norms_{current_epoch}.npz', **{k: np.array(v) for k, v in param_norms.items()})
    np.savez(f'{args.working_dir}/smoothed_ranks_{current_epoch}.npz', **{k: np.array(v) for k, v in smoothed_ranks.items()})
    np.savez(f'{args.working_dir}/direction_cos_similarity_{current_epoch}.npz', **{k: np.array(v) for k, v in direction_cos_similarity.items()})

    return epoch_loss / total_samples, grad_norms, param_norms, condense_cond_num1, condense_cond_num2, loss_list, smoothed_ranks  # mean loss per sample



def test_step(args, model, test_data_loader, criterion, device):
    model.eval()
    epoch_loss = 0
    total_samples = 0


    for i, dec_inputs in enumerate(test_data_loader):
        if isinstance(dec_inputs, list):
            dec_inputs, dec_outputs = dec_inputs
            seq_len = args.seq_len
            dec_inputs, dec_outputs = dec_inputs.to(device), dec_outputs.to(device)
            outputs, _ = model(dec_inputs)
            
            batch_size = dec_inputs.size(0)  # 获取当前批次的实际大小
            total_samples += batch_size
            loss = criterion(outputs.view(batch_size, seq_len, args.vocab_size)[:,-1,:], dec_outputs[:,-1].view(-1))
            # loss = criterion(outputs.view(-1, args.vocab_size), dec_outputs.view(-1))
            
            epoch_loss += loss.item() * batch_size  # 将损失乘以批次大小
        else:
            dec_outputs = dec_inputs[:, 1:]
            dec_inputs = dec_inputs[:, :-1]
            seq_len = args.seq_len - 1
            dec_inputs, dec_outputs = dec_inputs.to(device), dec_outputs.to(device)
            outputs, _ = model(dec_inputs)
            
            batch_size = dec_inputs.size(0)  # 获取当前批次的实际大小
            total_samples += batch_size
            loss = criterion(outputs.view(-1, args.vocab_size), dec_outputs.view(-1))
            
            epoch_loss += loss.item() * batch_size  # 将损失乘以批次大小
    
    return epoch_loss / total_samples  # 返回平均损失


def train_step_wikitext(args, model, train_dataloader, optimizer, criterion, device, clip=1, scheduler=None, current_epoch=0, logger=None):
    """Run one next-token-prediction epoch over token batches and save its metrics.

    Each batch is a token tensor: inputs are ``batch[:, :-1]`` and targets are
    ``batch[:, 1:]``. Gradients are clipped to global norm ``clip`` before every
    optimizer step. ``scheduler.step()`` runs once per batch when a scheduler is given.

    Per step it records:
      * the loss;
      * the L2 norm of every parameter whose name contains one of ``module_names``;
      * both condense-condition counts for each decoder layer (after the update).

    Every 100 steps it logs the loss, re-draws the running loss curve, and records
    the smoothed rank of the tracked parameters. Every 400 steps (and step 0 of
    epoch 0) it also saves a debug checkpoint and weight heatmaps.

    Returns:
        (mean loss per sample, None, param_norms, condense_cond_num1,
         condense_cond_num2, loss_list). ``smoothed_ranks`` is only written to disk.
    """
    model.train()
    epoch_loss = 0
    total_samples = 0
    loss_list = []

    # Metrics collected over the whole epoch.
    param_norms = defaultdict(list)
    condense_cond_num1 = defaultdict(list)
    condense_cond_num2 = defaultdict(list)
    smoothed_ranks = defaultdict(list)
    module_names = ['W_V.weight', 'W_Q.weight', 'W_K.weight', 'fc.0.weight', 'fc.2.weight']

    # The loss curve is saved into working_dir/loss at step 0, so the directory
    # must exist before the loop rather than only at the end of the epoch.
    loss_dir = os.path.join(args.working_dir, 'loss')
    os.makedirs(loss_dir, exist_ok=True)

    for i, dec_inputs in enumerate(train_dataloader):
        dec_outputs = dec_inputs[:, 1:]
        dec_inputs = dec_inputs[:, :-1]
        dec_inputs, dec_outputs = dec_inputs.to(device), dec_outputs.to(device)

        # 1. Forward pass and loss.
        optimizer.zero_grad()
        outputs, _ = model(dec_inputs)
        loss = criterion(outputs.view(-1, args.vocab_size), dec_outputs.view(-1))

        # 2. Backward pass, gradient clipping, and update.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # --- Per-step bookkeeping ---
        batch_size = dec_inputs.size(0)
        total_samples += batch_size
        current_loss = loss.item()
        loss_list.append(current_loss)
        epoch_loss += current_loss * batch_size  # weight by batch size -> per-sample mean

        with torch.no_grad():
            for name, param in model.named_parameters():
                if any(module_name in name for module_name in module_names):
                    param_norms[name].append(param.data.norm(2).item())

            for layer in range(args.n_layers):
                condense_cond_num1[layer].append(get_condense_condition1_modified(args, model, device, layer))
                condense_cond_num2[layer].append(get_condense_condition2_modified(args, model, device, layer))

        # --- Periodic logging and debug output ---
        if i % 100 == 0:
            log_msg = f'Epoch: {current_epoch:<5}  Batch: {i:<5}/{len(train_dataloader)}  Loss: {current_loss:.4e}'
            if logger:
                logger.info(log_msg)
            else:
                print(log_msg)

            # NOTE(review): `plt` is not imported in this file's visible header;
            # presumably it comes from one of the star imports. Confirm.
            plt.figure()
            plt.plot(loss_list)
            plt.xlabel('Batch Index')
            plt.ylabel('Loss')
            plt.title(f'Training Loss Curve for Epoch {current_epoch}')
            plt.grid(True)
            plt.savefig(os.path.join(loss_dir, f'loss_epoch_{current_epoch}.png'))
            plt.close()

            # Checkpoint + heatmaps are slow: every 400 steps, plus step 0 of epoch 0.
            if i % 400 == 0 and (i > 0 or current_epoch == 0):
                debug_dir = os.path.join(args.working_dir, 'debug_checkpoints')
                os.makedirs(debug_dir, exist_ok=True)
                torch.save(model.state_dict(), os.path.join(debug_dir, f'model_{current_epoch}_{i}.pt'))
                os.makedirs(os.path.join(args.working_dir, 'condense_heatmap'), exist_ok=True)
                for name, module in model.named_modules():
                    if isinstance(module, nn.Linear) and ('proj' not in name):
                        weight = module.weight.data.cpu().numpy()
                        plot_weight_heatmap_eigen(weight, os.path.join(args.working_dir, 'condense_heatmap', f'epoch{current_epoch}_{i}_{name}.png'))

            for name, param in model.named_parameters():
                for module_name in module_names:
                    if module_name in name:
                        smoothed_ranks[name].append(get_smoothed_rank(param.data.cpu()))

    # --- End of epoch: persist metrics ---
    os.makedirs(loss_dir, exist_ok=True)
    os.makedirs(os.path.join(args.working_dir, 'condense_num_step'), exist_ok=True)
    os.makedirs(os.path.join(args.working_dir, 'norms'), exist_ok=True)

    print(f"Epoch {current_epoch} finished. Saving metrics...")
    np.save(os.path.join(args.working_dir, 'loss', f'epoch_loss_{current_epoch}.npy'), np.array(loss_list))

    # The condense dicts are saved as pickled objects; load them with np.load(..., allow_pickle=True).
    np.save(os.path.join(args.working_dir, 'condense_num_step', f'condense_cond_num1_{current_epoch}.npy'), condense_cond_num1)
    np.save(os.path.join(args.working_dir, 'condense_num_step', f'condense_cond_num2_{current_epoch}.npy'), condense_cond_num2)
    np.savez(f'{args.working_dir}/smoothed_ranks_{current_epoch}.npz', **{k: np.array(v) for k, v in smoothed_ranks.items()})
    np.savez(os.path.join(args.working_dir, 'norms', f'param_norms_{current_epoch}.npz'), **{k: np.array(v) for k, v in param_norms.items()})

    return epoch_loss / total_samples, None, param_norms, condense_cond_num1, condense_cond_num2, loss_list


def train(args, datas = None, train_dataloader = None, test_dataloader = None, **kwargs):
    '''
    Build a model and train it for ``args.n_epoch`` epochs, writing logs,
    checkpoints and metrics under ``args.working_dir``.

    Required:
        args: hyper-parameter namespace
        datas: dict of every dataset kind. It is used to build the dataloaders
            when ``train_dataloader`` / ``test_dataloader`` are not passed in.
    Optional:
        train_dataloader, test_dataloader: pre-built loaders that bypass ``datas``
        **kwargs: forwarded to ``get_model`` / ``get_optimizer`` and saved to config.json

    Saved condense counts (``condense_num/condense_cond_num{1,2}.npy``) have shape
    ``[n_layers, total_steps]``: one row per layer, in the order ``train_step``
    reports them.
    '''
    if train_dataloader is None:
        train_dataloader = get_train_data(args, datas)
        np.savez(f'{args.working_dir}/data/datas.npz', **datas)
    if test_dataloader is not None:
        test_data_loader = test_dataloader
    else:
        # Build loaders for every dataset kind and evaluate on the first one.
        data_loader_group = get_data_loader_group(args, datas)
        test_data_loader = data_loader_group[list(data_loader_group.keys())[0]]

    args.num_batches = len(train_dataloader)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    my_logger = Log(f'{args.working_dir}/train_log.log')

    # Model and parameter count.
    model = get_model(args, device, **kwargs)
    my_logger.info(f'Total parameters: {sum(p.numel() for p in model.parameters())}')

    criterion = nn.CrossEntropyLoss(ignore_index=-100).to(device)
    optimizer, scheduler = get_optimizer(model, args, **kwargs)

    # Save hyper-parameters, including kwargs, for reproducibility.
    save_args = dict(vars(args))
    save_args.update(kwargs)
    save_to_json_noindent(save_args, f'{args.working_dir}/config.json')

    # Snapshot the source code next to the results.
    for file in ['main.py', 'data.py', 'train.py', 'test.py', 'script_LTP.py']:
        shutil.copy(file, f'{args.working_dir}/src/{file}')
    for src_dir in ['utils', 'model', 'data_generator']:
        shutil.copytree(src_dir, f'{args.working_dir}/src/{src_dir}', dirs_exist_ok=True)

    train_loss_his = []        # mean training loss per epoch
    test_loss_his = []         # mean test loss per epoch
    train_loss_list_along_step = []
    grad_norms = defaultdict(list)
    param_norms = defaultdict(list)
    smoothed_rank = defaultdict(list)
    # layer key -> per-step counts accumulated across epochs. train_step returns a
    # dict per epoch; extending a list with that dict would collect only its keys.
    condense_cond_num1 = defaultdict(list)
    condense_cond_num2 = defaultdict(list)

    print('training...')
    torch.save(model.state_dict(), f'{args.working_dir}/model/model_ini.pt')
    for epoch in range(args.n_epoch):

        # Train one epoch and merge its step-level metrics.
        (train_loss, grad_norms_epoch, param_norms_epoch, condense_cond_num1_epoch,
         condense_cond_num2_epoch, tmp_train_loss_list, tmp_smoothed_rank) = train_step(
            args, model, train_dataloader, optimizer, criterion, device, args.clip,
            scheduler=scheduler, current_epoch=epoch, logger=my_logger)
        for key, values in grad_norms_epoch.items():
            grad_norms[key].extend(values)
        for key, values in param_norms_epoch.items():
            param_norms[key].extend(values)
        for key, values in tmp_smoothed_rank.items():
            smoothed_rank[key].extend(values)
        for key, values in condense_cond_num1_epoch.items():
            condense_cond_num1[key].extend(values)
        for key, values in condense_cond_num2_epoch.items():
            condense_cond_num2[key].extend(values)

        train_loss_his.append(train_loss)
        train_loss_list_along_step.extend(tmp_train_loss_list)

        test_loss = test_step(args, model, test_data_loader, criterion, device)
        test_loss_his.append(test_loss)

        if epoch % args.print_loss_epoch == 0:
            my_logger.info(f'Epoch: {epoch:<5}  Train Loss: {train_loss:.4e}  Test Loss: {test_loss:.4e}')

        if (epoch % args.save_model_epoch == 0) or epoch == args.n_epoch-1:
            torch.save(model.state_dict(), f'{args.working_dir}/model/model_{epoch}.pt')

        # Save loss histories and refresh the loss plot.
        if ((epoch % args.plot_loss_acc_epoch == 0) and (epoch != 0)) or (epoch == args.n_epoch-1):
            np.save(f'{args.working_dir}/loss/train_loss_his.npy', np.array(train_loss_his))
            np.save(f'{args.working_dir}/loss/test_loss_his.npy', np.array(test_loss_his))
            np.save(f'{args.working_dir}/loss/train_loss_list_along_step.npy', np.array(train_loss_list_along_step))
            plot_loss(args.working_dir)

        # Save accumulated grad/param norms and smoothed ranks.
        os.makedirs(f'{args.working_dir}/norms', exist_ok=True)
        grad_norms_path = f'{args.working_dir}/norms/grad_norms.npz'
        param_norms_path = f'{args.working_dir}/norms/param_norms.npz'
        smoothed_rank_path = f'{args.working_dir}/norms/smoothed_rank.npz'
        np.savez(grad_norms_path, **{k: np.array(v) for k, v in grad_norms.items()})
        np.savez(param_norms_path, **{k: np.array(v) for k, v in param_norms.items()})
        np.savez(smoothed_rank_path, **{k: np.array(v) for k, v in smoothed_rank.items()})
        my_logger.info(f'Saved grad norms to {grad_norms_path}')
        my_logger.info(f'Saved param norms to {param_norms_path}')

        # Condense counts: one row per layer, one column per training step.
        os.makedirs(f'{args.working_dir}/condense_num', exist_ok=True)
        condense_num_path1 = f'{args.working_dir}/condense_num/condense_cond_num1.npy'
        condense_num_path2 = f'{args.working_dir}/condense_num/condense_cond_num2.npy'
        np.save(condense_num_path1, np.array(list(condense_cond_num1.values())))
        np.save(condense_num_path2, np.array(list(condense_cond_num2.values())))
    print('training finished!')



def train_wikitext(args, datas = None, train_dataloader = None, test_dataloader = None, **kwargs):
    '''
    Train a model on WikiText-style data, logging losses/norms and saving artifacts.

    Required:
        args: hyper-parameter namespace; ``vars(args)`` is written to config.json.
        datas: dictionary of all dataset types. Not used by this function; kept
            for interface compatibility with the other training entry points.
        train_dataloader: training batches; must support ``len()``.
        test_dataloader: evaluation batches, passed to ``test_step`` each epoch.
        **kwargs: extra options forwarded to ``get_model`` / ``get_optimizer``
            and also recorded in config.json.

    Raises:
        ValueError: if ``train_dataloader`` or ``test_dataloader`` is missing.

    Side effects (all paths under ``args.working_dir``):
        - sets ``args.num_batches`` to ``len(train_dataloader)``;
        - writes train_log.log, config.json, model/model_ini.pt;
        - copies selected source files/directories into src/;
        - every epoch: overwrites norms/grad_norms.npz, norms/param_norms.npz,
          condense_num/condense_cond_num{1,2}.npy;
        - periodically (and on the last epoch): saves loss/*.npy and calls
          ``plot_loss``.
    '''
    # Fail fast with a clear message instead of a TypeError from len(None),
    # or a failure in test_step only after the first epoch has trained.
    if train_dataloader is None or test_dataloader is None:
        raise ValueError('train_wikitext requires both train_dataloader and test_dataloader')

    # Stored on args so downstream helpers can see the number of steps per epoch
    # (presumably read by the scheduler / train_step_wikitext — confirm).
    args.num_batches = len(train_dataloader)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    my_logger = Log(f'{args.working_dir}/train_log.log')

    # Build the model and report its parameter count
    model = get_model(args, device, **kwargs)
    my_logger.info(f'Total parameters: {sum(p.numel() for p in model.parameters())}')

    # Targets equal to -100 are excluded from the loss
    criterion = nn.CrossEntropyLoss(ignore_index=-100).to(device)

    optimizer, scheduler = get_optimizer(model, args, **kwargs)

    # Save hyper-parameters: args plus any extra kwargs (kwargs win on key clashes)
    save_args = dict(vars(args))
    save_args.update(kwargs)
    save_to_json_noindent(save_args, f'{args.working_dir}/config.json')

    # Snapshot the source code next to the run for reproducibility
    for src_file in ['main_wikitext.py', 'data.py', 'train.py', 'script_wikitext.py']:
        shutil.copy(src_file, f'{args.working_dir}/src/{src_file}')
    for src_dir in ['utils', 'model', 'data_generator']:
        shutil.copytree(src_dir, f'{args.working_dir}/src/{src_dir}', dirs_exist_ok=True)

    train_loss_his = []        # training-set loss per epoch
    test_loss_his = []         # test-set loss per epoch
    # NOTE(review): the group/acc histories below are never populated in this
    # function, so they are saved as empty arrays — presumably kept so the
    # output layout matches the other training loops.
    group_loss_his = []        # per-group loss
    acc_epoch_his = []
    train_acc_his = []         # training-set accuracy
    test_acc_his = []          # test-set accuracy
    group_acc_his = []         # per-group accuracy

    train_loss_list_along_step = []
    grad_norms = defaultdict(list)
    param_norms = defaultdict(list)
    condense_cond_num1_list = []
    condense_cond_num2_list = []

    # Output locations that do not change between epochs (hoisted out of the loop)
    norms_dir = f'{args.working_dir}/norms'
    condense_dir = f'{args.working_dir}/condense_num'
    os.makedirs(norms_dir, exist_ok=True)
    os.makedirs(condense_dir, exist_ok=True)
    grad_norms_path = f'{norms_dir}/grad_norms.npz'
    param_norms_path = f'{norms_dir}/param_norms.npz'
    condense_num_path1 = f'{condense_dir}/condense_cond_num1.npy'
    condense_num_path2 = f'{condense_dir}/condense_cond_num2.npy'

    print('training...')
    torch.save(model.state_dict(), f'{args.working_dir}/model/model_ini.pt')
    for epoch in range(args.n_epoch):

        # Train one epoch and collect per-step diagnostics
        train_loss, grad_norms_epoch, param_norms_epoch, condense_cond_num1_epoch, condense_cond_num2_epoch, tmp_train_loss_list = train_step_wikitext(args, model, train_dataloader, optimizer, criterion, device, args.clip, scheduler=scheduler, current_epoch = epoch, logger = my_logger)
        for key, values in grad_norms_epoch.items():
            grad_norms[key].extend(values)
        for key, values in param_norms_epoch.items():
            param_norms[key].extend(values)

        condense_cond_num1_list.extend(condense_cond_num1_epoch)
        condense_cond_num2_list.extend(condense_cond_num2_epoch)

        test_loss = test_step(args, model, test_dataloader, criterion, device)

        train_loss_his.append(train_loss)
        test_loss_his.append(test_loss)
        train_loss_list_along_step.extend(tmp_train_loss_list)

        # Progress report
        if epoch % args.print_loss_epoch == 0:
            my_logger.info(f'Epoch: {epoch:<5}  Train Loss: {train_loss:.4e}  Test Loss: {test_loss:.4e}')

        # Periodically (and on the final epoch) persist loss histories and redraw plots
        if ((epoch % args.plot_loss_acc_epoch == 0) and (epoch != 0)) or (epoch == args.n_epoch-1):
            np.save(f'{args.working_dir}/loss/train_loss_his.npy', np.array(train_loss_his))
            np.save(f'{args.working_dir}/loss/test_loss_his.npy', np.array(test_loss_his))
            np.save(f'{args.working_dir}/loss/group_loss_his.npy', np.array(group_loss_his))
            np.save(f'{args.working_dir}/loss/acc_epoch_his.npy', np.array(acc_epoch_his))
            np.save(f'{args.working_dir}/loss/train_acc_his.npy', np.array(train_acc_his))
            np.save(f'{args.working_dir}/loss/test_acc_his.npy', np.array(test_acc_his))
            np.save(f'{args.working_dir}/loss/group_acc_his.npy', np.array(group_acc_his))
            np.save(f'{args.working_dir}/loss/train_loss_list_along_step.npy', np.array(train_loss_list_along_step))

            plot_loss(args.working_dir)

        # Save grad/param norms every epoch so a crashed run still leaves recent data
        np.savez(grad_norms_path, **{k: np.array(v) for k, v in grad_norms.items()})
        np.savez(param_norms_path, **{k: np.array(v) for k, v in param_norms.items()})
        my_logger.info(f'Saved grad norms to {grad_norms_path}')
        my_logger.info(f'Saved param norms to {param_norms_path}')

        np.save(condense_num_path1, np.array(condense_cond_num1_list))
        np.save(condense_num_path2, np.array(condense_cond_num2_list))
    print('training finished!')





