import os
import time
import json
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader

from utils.data_utils import get_trainloaders
from utils.eval_utils import load_and_eval_ppl, eval_zero_shot
from utils.util import *


def block_influence(
    input_hidden_state: torch.Tensor,
    output_hidden_state: torch.Tensor,
    angular=False,
):
    """Per-token Block Influence (BI) between a block's input and output hidden states.

    For every token position the cosine similarity between the input and output
    vectors is computed. NaN (a zero-norm vector) is replaced by 0.5, and the
    similarity is clamped to [0, 1].

    Args:
        input_hidden_state: tensor of shape (..., D), typically (B, S, D).
        output_hidden_state: tensor with the same shape as ``input_hidden_state``.
        angular: if True return ``arccos(sim) / pi`` instead of ``1 - sim``.

    Returns:
        1-D tensor with one score in [0, 1] per token (length B*S for 3-D input).
        Higher means the block changes the representation more. The dtype is at
        least float32.
    """
    d = input_hidden_state.shape[-1]
    # Upcast to at least float32: in fp16, dot products of large activations
    # overflow to inf, which the clamp below would silently turn into sim = 1.
    compute_dtype = torch.promote_types(input_hidden_state.dtype, torch.float32)
    x = input_hidden_state.reshape(-1, d).to(compute_dtype)
    y = output_hidden_state.reshape(-1, d).to(compute_dtype)

    # Row-wise cosine similarity in O(N*D), instead of materialising the
    # N x N matrix x @ y.T only to read its diagonal.
    sim = (x * y).sum(dim=-1) / (x.norm(dim=-1) * y.norm(dim=-1))
    # 0/0 (zero-norm token) -> NaN -> treated as "half similar".
    sim = sim.nan_to_num(nan=0.5)
    sim = torch.clamp(sim, min=0, max=1)

    if angular:
        return torch.arccos(sim) / torch.pi

    return 1 - sim


def compute_bi(hiddens: List[torch.Tensor], angular: bool=False, n_prune_layers=None):
    """Sum the Block Influence of every layer (or span of layers) over all tokens.

    Args:
        hiddens: sequence of hidden-state tensors. ``hiddens[start]`` is compared
            with ``hiddens[start + span]``. Presumably this is the
            ``hidden_states`` tuple of a HF model; verify against the caller.
        angular: when True, use the angular distance on the last token only,
            over a span of ``n_prune_layers`` layers.
        n_prune_layers: span length for angular mode; must be set when
            ``angular`` is True.

    Returns:
        List of ``len(hiddens) - 1`` scores. Entries without a full span (the
        trailing ``n_prune_layers - 1`` in angular mode) stay 0.
    """
    scores = [0] * (len(hiddens) - 1)
    if angular:
        assert n_prune_layers is not None, "Set number of layers to prune to use angular importance"
        span = n_prune_layers
    else:
        span = 1

    for start in range(len(hiddens) - span):
        before, after = hiddens[start], hiddens[start + span]
        if angular:
            # Section 3.2 of https://arxiv.org/pdf/2403.17887.pdf: the angular
            # distance is measured on the last token only.
            before, after = before[:, -1:], after[:, -1:]

        influence = block_influence(before, after, angular=angular)
        scores[start] += influence.sum().cpu().item()

    return scores

@torch.inference_mode()
def get_layer_importance(model, dataloader, bs=1, device=None, angular=False, n_prune_layers=None):
    """Compute per-layer Block Influence scores averaged over calibration sequences.

    The calibration tokens in ``dataloader.input_ids`` are cut into
    ``model.seqlen``-long sequences. These are fed through ``model`` ``bs`` at a
    time with ``output_hidden_states=True``. The per-layer scores from
    ``compute_bi`` are accumulated over all batches.

    Args:
        model: model exposing ``seqlen``, ``device`` and
            ``config.num_hidden_layers``. Its forward output must have
            ``hidden_states``.
        dataloader: object with ``input_ids`` and ``attention_mask`` tensors. The
            slicing below assumes shape (1, total_tokens); TODO confirm.
        bs: number of sequences per forward pass.
        device: device for the inputs; defaults to ``model.device``.
        angular: use angular distance instead of ``1 - cos``.
        n_prune_layers: span length forwarded to ``compute_bi``; required when
            ``angular`` is True.

    Returns:
        List with one float per layer: the accumulated score divided by the
        number of calibration sequences. Lower means less important.
    """
    if device is None:
        device = model.device
    testenc = dataloader.input_ids
    testattn = dataloader.attention_mask
    seqlen = model.seqlen

    # Number of full seqlen-long calibration sequences in the token stream.
    nsamples = testenc.numel() // seqlen

    layer_importance = [0 for _ in range(model.config.num_hidden_layers)]

    for i in tqdm(range(0, nsamples, bs), desc="BI Score calculating ..."):
        j = min(i + bs, nsamples)
        span = slice(i * seqlen, j * seqlen)
        inputs = testenc[:, span].to(device).reshape(j - i, seqlen)
        attn = testattn[:, span].to(device).reshape(j - i, seqlen)

        hidden_states = model(inputs, attention_mask=attn, output_hidden_states=True).hidden_states

        layer_importance = list_item_sum(
            compute_bi(hidden_states, angular=angular, n_prune_layers=n_prune_layers),
            layer_importance,
        )

    # Average over calibration sequences, not over layers. Guard against an
    # empty calibration set.
    denom = max(nsamples, 1)
    return [score / denom for score in layer_importance]

def _timed_layer_importance(modelhander, dataloader):
    """Run plain (non-angular, bs=1) BI scoring on the current model; return (scores, seconds)."""
    start_time = time.time()
    layer_importance = get_layer_importance(
        model=modelhander.model,
        dataloader=dataloader,
        bs=1,
        device=modelhander.model.device,
        angular=False,
    )
    return layer_importance, time.time() - start_time


def _rank_layers(layer_importance):
    """Return layer indices ordered from lowest to highest BI score (stable on ties)."""
    return sorted(range(len(layer_importance)), key=lambda i: layer_importance[i])


def _save_submodel(args, modelhander, num_layers):
    """Save the current model to ``<args.save_path>/<args.save_name>_<num_layers>``; return the path."""
    save_path = os.path.join(args.save_path, args.save_name + f"_{num_layers}")
    modelhander.save(path=save_path)
    return save_path


def _eval_ppl_into(modelhander, datasets, record, log_prefix=""):
    """Evaluate the current model's PPL on each dataset and store it as ``record[f"{name}_ppl"]``."""
    for da in datasets:
        logging.info(f"{log_prefix}Starting {da} PPL evaluation...")
        record[f"{da}_ppl"] = load_and_eval_ppl(
            modelhander.model,
            dataset=da,
            tokenizer=modelhander.tokenizer,
        )


def _prune_one_shot(args, modelhander, dataloader, shortgpt_info):
    """Rank layers once by BI and remove the ``num_hidden_layers - target_layers`` lowest ones.

    With ``args.continue_saving`` the layers are removed one at a time in ranking
    order, and every intermediate model is saved and evaluated. Otherwise they
    are removed with a single ``remove_layers`` call.
    """
    logging.info("Compute the BI importance score of layers (one-shot).")
    layer_importance, total_time = _timed_layer_importance(modelhander, dataloader)
    logging.info(f"Total time elapsed (s): {total_time}")

    sorted_indices = _rank_layers(layer_importance)
    logging.info("The BI score of Layers:")
    for idx, score in zip(range(modelhander.config.num_hidden_layers), layer_importance):
        logging.info(f"{idx:>3} → {score}")

    logging.info(f"Remove order of layer index: \n{sorted_indices}")

    shortgpt_info["layer_importance"] = layer_importance
    shortgpt_info["sorted_indices"] = sorted_indices
    shortgpt_info["time_elapsed"] = total_time

    num_remove_blocks = modelhander.config.num_hidden_layers - args.target_layers

    if args.continue_saving:
        logging.info("Continues saving sub-model (one-shot order, no re-compute BI).")
        removal_list = []
        for idx in sorted_indices[:num_remove_blocks]:
            removal_list.append(idx)
            num_layers = modelhander.model.config.num_hidden_layers
            logging.info(
                "*" * 10
                + f" Layers {num_layers} → {num_layers - 1}. "
                f"Remove layer index (original space) is {idx}"
                + "*" * 10
            )
            # Map the original-space index to the current (already shrunk) model:
            # each previously removed layer below it shifts it down by one.
            cur_idx = idx - sum(1 for x in removal_list if x < idx)
            modelhander.remove_layers(removal_list=cur_idx)

            save_path = _save_submodel(
                args, modelhander, modelhander.model.config.num_hidden_layers
            )
            key = f"Layers_{modelhander.model.config.num_hidden_layers}"
            shortgpt_info[key] = {
                "remove_layer_idx_list_original": list(removal_list)
            }
            logging.info(f"Save model to {save_path}")
            _eval_ppl_into(modelhander, args.ppl_data, shortgpt_info[key])
    else:
        removal_list = sorted_indices[:num_remove_blocks]
        logging.info(f"Remove layer index list is {removal_list}")
        modelhander.remove_layers(removal_list=removal_list)
        save_path = _save_submodel(
            args, modelhander, modelhander.model.config.num_hidden_layers
        )
        logging.info(f"Save model to {save_path}")
        key = f"Layers_{modelhander.model.config.num_hidden_layers}"
        shortgpt_info[key] = {"remove_layer_idx_list_original": removal_list}
        _eval_ppl_into(modelhander, args.ppl_data, shortgpt_info[key])


def _prune_iterative(args, modelhander, dataloader, shortgpt_info):
    """Remove one layer per round, recomputing BI on the shrunken model each round.

    Each intermediate model is saved and PPL-evaluated. Per-round details are
    recorded in ``shortgpt_info["iter_<k>"]``.
    """
    logging.info(">>> Iterative depth pruning with BI (one layer per iteration, recompute BI each round).")

    orig_num_layers = modelhander.config.num_hidden_layers
    target_layers = args.target_layers
    num_remove_blocks = orig_num_layers - target_layers
    logging.info(
        f"Original layers: {orig_num_layers}, "
        f"target layers: {target_layers}, "
        f"num_remove_blocks: {num_remove_blocks}"
    )

    # current2orig[k] is the original-model index of the layer now at position k.
    current2orig = list(range(orig_num_layers))
    shortgpt_info["iterative_depth"] = True
    shortgpt_info["orig_num_layers"] = orig_num_layers
    shortgpt_info["target_layers"] = target_layers
    shortgpt_info["mapping_current2orig_after_each_iter"] = {}

    for it in range(num_remove_blocks):
        cur_L = modelhander.config.num_hidden_layers
        logging.info("=" * 80)
        logging.info(
            f"[Iter {it}] current #layers = {cur_L}. "
            f"Will compute BI and prune 1 layer."
        )

        # Recompute layer importance on the current (already pruned) model.
        layer_importance, elapsed = _timed_layer_importance(modelhander, dataloader)
        logging.info(f"[Iter {it}] BI computed. time elapsed (s): {elapsed}")

        # Least important layer in current-model index space, plus its original index.
        prune_idx_cur = _rank_layers(layer_importance)[0]
        prune_idx_orig = current2orig[prune_idx_cur]

        logging.info(f"[Iter {it}] Layer importance (current model space):")
        for idx, score in enumerate(layer_importance):
            logging.info(f"{idx:>3} (orig {current2orig[idx]:>3}) → {score}")

        logging.info(
            f"[Iter {it}] Will remove layer idx={prune_idx_cur} "
            f"(original idx={prune_idx_orig}), "
            f"BI score={layer_importance[prune_idx_cur]}"
        )

        key = f"iter_{it}"
        shortgpt_info[key] = {
            "current_num_layers_before": cur_L,
            "remove_layer_idx_current": prune_idx_cur,
            "remove_layer_idx_original": prune_idx_orig,
            "layer_importance_current": layer_importance,
            "time_elapsed": elapsed,
        }

        modelhander.remove_layers(removal_list=prune_idx_cur)

        # Keep the index mapping in sync with the removed position.
        current2orig.pop(prune_idx_cur)
        new_L = modelhander.config.num_hidden_layers
        # Sanity check: remove_layers must drop exactly one layer per call.
        assert new_L == cur_L - 1, "remove_layers 应该每次删 1 层"

        shortgpt_info["mapping_current2orig_after_each_iter"][key] = current2orig.copy()

        save_path = _save_submodel(args, modelhander, new_L)
        logging.info(f"[Iter {it}] Save pruned model to {save_path}")

        shortgpt_info[key]["save_path"] = save_path
        shortgpt_info[key]["current_num_layers_after"] = new_L

        _eval_ppl_into(modelhander, args.ppl_data, shortgpt_info[key], log_prefix=f"[Iter {it}] ")

    logging.info(
        f"Iterative depth pruning finished. "
        f"Final #layers = {modelhander.config.num_hidden_layers}"
    )


def main_func(args, modelhander):
    """Prune decoder layers by Block Influence (ShortGPT), saving and evaluating each result.

    Modes:
        * one-shot (default): rank layers once and remove the lowest-BI ones,
          optionally saving every intermediate model (``args.continue_saving``).
        * iterative (``args.iterative_depth`` truthy): remove one layer per
          round, recomputing BI after each removal.

    All collected scores, removal indices and PPL results are written to
    ``<args.save_path>/<args.method>_info.json``.

    Args:
        args: namespace read for ``calibration_dataset``, ``nsamples``, ``seed``,
            ``target_layers``, ``save_path``, ``save_name``, ``ppl_data``,
            ``method``, ``continue_saving`` (one-shot only) and the optional
            ``iterative_depth``.
        modelhander: wrapper exposing ``model``, ``tokenizer``, ``config``,
            ``remove_layers(removal_list=...)`` and ``save(path=...)``.

    Raises:
        ValueError: if ``args.target_layers`` is not in ``(0, num_hidden_layers]``.
    """
    num_layers = modelhander.config.num_hidden_layers
    if not 0 < args.target_layers <= num_layers:
        # Validated for both modes. In one-shot mode an oversized target would
        # make the removal count negative, and sorted_indices[:negative] would
        # silently drop most layers.
        raise ValueError(
            f"target_layers should be in (0, {num_layers}], got {args.target_layers}"
        )

    # Calibration data, loaded the same way as the original implementation.
    dataloader = get_trainloaders(args.calibration_dataset,
                                  tokenizer=modelhander.tokenizer,
                                  nsamples=args.nsamples,
                                  seed=args.seed,
                                  seqlen=modelhander.model.seqlen
                                  )

    shortgpt_info = {}
    logging.info(f"Dataloader({args.calibration_dataset}) loaded.")

    if getattr(args, "iterative_depth", False):
        _prune_iterative(args, modelhander, dataloader, shortgpt_info)
    else:
        _prune_one_shot(args, modelhander, dataloader, shortgpt_info)

    info_path = os.path.join(args.save_path, f'{args.method}_info.json')
    logging.info(f"Save {args.method} information to {info_path}")
    # The directory may not exist yet, e.g. when no sub-model was saved.
    os.makedirs(args.save_path, exist_ok=True)
    with open(info_path, 'w') as json_file:
        json.dump(shortgpt_info, json_file, indent=4)


