import copy
import logging
import torch
import types

from torch import nn
from tqdm import tqdm

from modules.eval.setup_eval import eval_ppl
from . import sec_pruner
from .sec_pruner import WrappedGPT
from .utils import *

# Module-level logger named after this module; used for pruning progress and diagnostics.
logger: logging.Logger = logging.getLogger(__name__)


class wanda_sp(sec_pruner):
    """Wanda-style structured pruner for the MLP blocks of a decoder-only model.

    :meth:`prune` dispatches on ``config.task.prune.func_name``. For a pruned
    layer, per-input-channel scores of the MLP down-projection are computed
    from weight magnitudes and calibration statistics (``WrappedGPT``), a
    channel mask is derived from ``config.task.prune.ratio``, a residue MLP is
    built with ``compress_residue`` and the MLP block's ``forward`` is rebound
    so that it returns only the block's own output while the residue output is
    computed alongside it.
    NOTE(review): how ``compress_residue`` splits weights between the block and
    the residue is defined in ``.utils`` — confirm there.
    """

    def __init__(self, model, config, data):
        super().__init__(model, config, data)

    def prune(self):
        """Run the routine named by ``config.task.prune.func_name``.

        Raises:
            ValueError: if ``func_name`` is not a supported routine.
        """
        func_name = self.config.task.prune.func_name
        routines = {
            'baseline_uniform_wanda_TEE': self.baseline_uniform_wanda_TEE,
            'critical_block_analysis': self.critical_block_analysis,
            'detect_critical_layer': self.detect_critical_layer,
        }
        if func_name not in routines:
            raise ValueError(f"Unknown prune func_name: {func_name!r}; expected one of {sorted(routines)}")
        routines[func_name]()

    # ------------------------------------------------------------------ private helpers

    @staticmethod
    def _to_device(obj, dev):
        """Move a tensor, or a tuple/list of tensors, to ``dev``; ``None`` is passed through unchanged."""
        if obj is None:
            return None
        if isinstance(obj, tuple):
            return tuple(wanda_sp._to_device(item, dev) for item in obj)
        if isinstance(obj, list):
            return [wanda_sp._to_device(item, dev) for item in obj]
        return obj.to(dev)

    def _move_to_layer_device(self, i, inps, outs, layer_kwargs):
        """Move calibration state to decoder layer ``i``'s device if the model has an ``hf_device_map`` entry for it.

        This covers models sharded over several GPUs. Returns the (possibly
        moved) ``(inps, outs, layer_kwargs)``; everything is returned unchanged
        when ``model.layers.{i}`` is not in the device map.
        """
        device_map = getattr(self.get_model(), 'hf_device_map', None) or {}
        key = f"model.layers.{i}"
        if key not in device_map:
            return inps, outs, layer_kwargs
        dev = device_map[key]
        moved_kwargs = {k: self._to_device(v, dev) for k, v in layer_kwargs.items()}
        return inps.to(dev), outs.to(dev), moved_kwargs

    @staticmethod
    def _run_layer(layer, inps, outs, n_samples, layer_kwargs):
        """Forward each of the first ``n_samples`` entries of ``inps`` through ``layer``, writing into ``outs`` in place."""
        with torch.no_grad():
            for j in range(n_samples):
                # Element 0 of the layer's return value is kept (presumably the hidden states — verify per HF version).
                outs[j] = layer(inps[j].unsqueeze(0), **layer_kwargs)[0]

    @staticmethod
    def _make_mlp_forward(residue_block_dict, activations, include_all=True, reverse=False):
        """Build a replacement ``forward(self, x)`` for a gated MLP block that also evaluates a residue MLP.

        Args:
            residue_block_dict: mapping with ``'g_pruned'``, ``'u_pruned'`` and
                ``'d_pruned'`` callables (residue gate / up / down projections).
            activations: dict; every call stores the residue output under ``'activation'``.
            include_all: if True, the forward returns block output + residue output.
            reverse: only used when ``include_all`` is False — return only the
                residue output if True, otherwise only the block's own output.
        """
        if include_all:
            logger.info('All weights are involved.')
        elif reverse:
            logger.info('only residue part are involved.')
            logger.info(f"residue shape:{residue_block_dict['g_pruned']}")
        else:
            logger.info('only remaining weights on model are involved.')
            logger.info(f"residue shape:{residue_block_dict['g_pruned']}")

        def forward(mlp, x):
            down_proj = mlp.down_proj(mlp.act_fn(mlp.gate_proj(x)) * mlp.up_proj(x))
            down_proj_residue = residue_block_dict['d_pruned'](
                mlp.act_fn(residue_block_dict['g_pruned'](x)) * residue_block_dict['u_pruned'](x))
            activations['activation'] = down_proj_residue
            if include_all:
                return down_proj + down_proj_residue
            return down_proj_residue if reverse else down_proj

        return forward

    def _prune_mlp(self, layer_idx, layer, name, subset, wrapped_layers, pruning_ratio, head_dim, is_gqa,
                   activations):
        """Score down-projection input channels, build the residue MLP and hook the MLP block's forward.

        Records the channel scores in ``self.W_metrics`` under ``"{layer_idx}.{name}"``.

        Returns:
            the layer's MLP block as returned by ``substract_mlp``.
        """
        # -(|W| * sqrt(scaler_row)) averaged over dim 0: one score per input channel
        # (assumes an nn.Linear-style (out_features, in_features) weight — TODO confirm).
        W_metric = torch.neg(torch.abs(subset[name].weight.data) * torch.sqrt(
            wrapped_layers[name].scaler_row.reshape((1, -1))))
        W_metric = W_metric.mean(axis=0)
        if pruning_ratio >= 1:
            thresh = torch.inf
        else:
            # Sort on the metric's own device (previously forced onto CUDA); keep the threshold on CPU as before.
            k = int(W_metric.numel() * pruning_ratio)
            thresh = torch.sort(W_metric)[0][k].cpu()
        W_mask = (W_metric >= thresh)

        self.during_pruning_step(self.substract_mlp(layer), W_mask, W_metric, thresh)
        residue_mlp = compress_residue(layer, None, W_mask, None, None, self.model.device,
                                       mapping=self.layer_mapping,
                                       head_dim=head_dim, bias=False, is_gqa=is_gqa)
        mlp_block = getattr(layer, self.layer_mapping['mlp']['block'])
        mlp_block.forward = types.MethodType(
            self._make_mlp_forward(residue_mlp, activations, include_all=False, reverse=False), mlp_block)
        # Only record the metric where it was actually computed for this layer.
        self.W_metrics[f"{layer_idx}.{name}"] = W_metric.clone()
        return self.substract_mlp(layer)

    def _calibrate_and_prune_layer(self, i, layer, inps, outs, n_samples, layer_kwargs, pruning_ratio,
                                   head_dim, is_gqa, activations, prune_mlp=True):
        """Collect activation statistics for decoder layer ``i`` and prune its MLP.

        ``outs`` receives the layer's outputs from the statistics pass; if
        ``act_type == 'sparse'`` they are recomputed after pruning.

        Returns:
            list of pruned MLP blocks (empty when nothing was pruned).
        """
        attn_o = self.layer_mapping['attn']['o']
        mlp_d = self.layer_mapping['mlp']['d']
        found = find_layers(layer)
        subset = {attn_o: found[attn_o], mlp_d: found[mlp_d]}
        wrapped_layers = {name: WrappedGPT(module) for name, module in subset.items()}

        def add_batch(name):
            def tmp(_, inp, out):
                wrapped_layers[name].add_batch(inp[0].data, out.data)

            return tmp

        self.before_pruning_step()
        handles = [subset[name].register_forward_hook(add_batch(name)) for name in wrapped_layers]
        try:
            self._run_layer(layer, inps, outs, n_samples, layer_kwargs)
        finally:
            for handle in handles:
                handle.remove()

        pruned_blocks = []
        prune_modules = self.config.task.prune.prune_modules
        for name in subset:
            logger.info(f"pruning layer {i} name {name}")
            # Attention (o_proj) pruning is not implemented; its statistics are collected but unused.
            if name != attn_o and prune_mlp and prune_modules in ['mlp', 'all']:
                pruned_blocks.append(self._prune_mlp(i, layer, name, subset, wrapped_layers, pruning_ratio,
                                                     head_dim, is_gqa, activations))
            wrapped_layers[name].free()
        self.after_pruning_step()

        if self.config.task.prune.act_type in ['sparse']:
            # Recompute outputs with the pruned layer so the next layer sees pruned activations.
            self._run_layer(layer, inps, outs, n_samples, layer_kwargs)
        return pruned_blocks

    # ------------------------------------------------------------------ pruning routines

    def baseline_uniform_wanda_TEE(self):
        """Prune the MLPs of ``config.task.prune.outlier_layers`` with a uniform Wanda-style channel split.

        Calibration data flows through the layers sequentially; layers not in
        ``outlier_layers`` are only forwarded. The MLP of layer
        ``outlier_layers[0] + 1`` is never pruned. The pruned MLP blocks are
        passed to :meth:`finishing_pruning`.
        """
        prune_cfg = self.config.task.prune
        n_samples = prune_cfg.prune_dataset.n_samples
        seq_len = prune_cfg.prune_dataset.seq_len
        pruning_ratio = prune_cfg.ratio
        head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
        is_gqa = self.model.config.num_key_value_heads < self.model.config.num_attention_heads

        self.before_pruning()
        self.before_pruning_parameters = super().count_params()

        with torch.no_grad():
            inps, outs, attention_mask, position_ids, cache_position, position_embeddings = \
                prepare_calibration_input_qwen3(self.get_model(), self.data, n_samples, seq_len)
        layer_kwargs = {
            'attention_mask': attention_mask,
            'position_ids': position_ids,
            'cache_position': cache_position,
            'position_embeddings': position_embeddings,
        }

        activations = {}  # last residue activation, written by the hooked MLP forward
        tee_blocks = []
        outlier_layers = prune_cfg.outlier_layers
        layers = self.get_layers()
        for i in tqdm(range(len(layers)), desc="Processing layers"):
            layer = layers[i]
            inps, outs, layer_kwargs = self._move_to_layer_device(i, inps, outs, layer_kwargs)
            if i in outlier_layers:
                tee_blocks.extend(self._calibrate_and_prune_layer(
                    i, layer, inps, outs, n_samples, layer_kwargs, pruning_ratio, head_dim, is_gqa,
                    activations, prune_mlp=(i != outlier_layers[0] + 1)))
            else:
                self._run_layer(layer, inps, outs, n_samples, layer_kwargs)
            inps, outs = outs, inps  # this layer's outputs are the next layer's inputs
            torch.cuda.empty_cache()
        self.finishing_pruning(tee_blocks)

    def critical_block_analysis(self):
        """For every layer index ``d``, restore the original model and prune only layer ``d``'s MLP.

        Layers ``0..d`` are run on a fresh copy of the calibration data (later
        layers are not processed); then :meth:`finishing_pruning` is called with
        the MLP block(s) pruned in layer ``d``.
        """
        prune_cfg = self.config.task.prune
        n_samples = prune_cfg.prune_dataset.n_samples
        seq_len = prune_cfg.prune_dataset.seq_len
        pruning_ratio = prune_cfg.ratio
        head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
        is_gqa = self.model.config.num_key_value_heads < self.model.config.num_attention_heads

        self.before_pruning()
        self.before_pruning_parameters = super().count_params()

        with torch.no_grad():
            inps, outs, attention_mask, position_ids, cache_position, position_embeddings = \
                prepare_calibration_input_qwen3(self.get_model(), self.data, n_samples, seq_len)

        # Pristine calibration state, restored before every analysis run.
        inps_backup = inps.clone()
        outs_backup = outs.clone()
        layer_kwargs_backup = {
            'attention_mask': copy.deepcopy(attention_mask),
            'position_ids': copy.deepcopy(position_ids),
            'cache_position': copy.deepcopy(cache_position),
            'position_embeddings': copy.deepcopy(position_embeddings),
        }

        def back_up_model():
            """Return a deep copy of ``self.model`` made on CPU; the live model is moved back to CUDA."""
            self.model.to('cpu')
            backup_model = copy.deepcopy(self.model)
            self.model.to('cuda')
            return backup_model

        def restore_model(backup_model):
            """Replace ``self.model`` with a fresh CUDA copy of ``backup_model``."""
            self.model.to('cpu')
            self.model = copy.deepcopy(backup_model)
            self.model.to('cuda')

        activations = {}  # last residue activation, written by the hooked MLP forward
        model_backup = back_up_model()
        layer_count = len(self.get_layers())

        for detect_index in tqdm(range(layer_count), desc='Analysis of outlier saliency'):
            inps = inps_backup.clone()
            outs = outs_backup.clone()
            layer_kwargs = {k: copy.deepcopy(v) for k, v in layer_kwargs_backup.items()}

            restore_model(model_backup)
            layers = self.get_layers()
            tee_blocks = []
            for i in tqdm(range(detect_index + 1), desc="Processing layers"):
                layer = layers[i]
                inps, outs, layer_kwargs = self._move_to_layer_device(i, inps, outs, layer_kwargs)
                if i == detect_index:
                    tee_blocks.extend(self._calibrate_and_prune_layer(
                        i, layer, inps, outs, n_samples, layer_kwargs, pruning_ratio, head_dim, is_gqa,
                        activations))
                else:
                    self._run_layer(layer, inps, outs, n_samples, layer_kwargs)
                inps, outs = outs, inps  # this layer's outputs are the next layer's inputs
                torch.cuda.empty_cache()
            # finishing_pruning counts parameters of the pruned blocks, so pass the blocks, not the index.
            self.finishing_pruning(tee_blocks)

    def detect_critical_layer(self):
        """Greedily grow a contiguous block of protected layers outward from the critical layer.

        The critical layer is the argmax of the second value returned by
        ``obtain_information()``. At each step, the neighbouring layer (left or
        right of the current block) whose entry in the cosine similarities from
        ``obtain_information_block`` is lowest joins the block.

        Returns:
            list of snapshots of the protected block, one per growth step (the
            initial single-layer block is not included).
        """
        total = len(self.get_layers())

        records = []
        _, std0, _ = self.obtain_information()
        critical = int(torch.argmax(std0).item())
        logger.info(f"critical layer index is: {critical}.")
        current_block = [critical]

        while len(current_block) < total:
            cos_sim, _ = self.obtain_information_block(block_record=current_block)

            # Collapse each run of protected layers into one block; other layers are singleton blocks.
            protected = set(current_block)
            blocks = []
            i = 0
            while i < total:
                if i in protected:
                    j = i
                    while j < total and j in protected:
                        j += 1
                    blocks.append(list(range(i, j)))
                    i = j
                else:
                    blocks.append([i])
                    i += 1

            # current_block only grows by adjacent layers, so it is exactly one entry of `blocks`.
            pos = blocks.index(sorted(current_block))
            candidates = []
            if pos > 0:
                candidates.append((cos_sim[pos - 1].item(), blocks[pos - 1][0]))
            if pos < len(blocks) - 1:
                candidates.append((cos_sim[pos + 1].item(), blocks[pos + 1][0]))

            _, next_layer = min(candidates, key=lambda t: t[0])
            current_block.append(next_layer)
            records.append(list(current_block))
        logger.info("critical layer and cosine similarity processed.")
        logger.info(f"The final order of protection is: {records}.")
        return records

    # ------------------------------------------------------------------ finishing / evaluation

    def count_params(self, block_lists=None):
        """Count parameters in ``block_lists``; with no argument, defer to the base-class count.

        The optional argument keeps this override call-compatible with the base
        class's zero-argument ``count_params()``.
        """
        if block_lists is None:
            return super().count_params()
        return sum(p.numel() for block in block_lists for p in block.parameters())

    def finishing_pruning(self, block_lists):
        """Finalize pruning, optionally load external weights, then evaluate.

        Args:
            block_lists: sequence of pruned MLP blocks. Their parameter count,
                divided by ``before_pruning_parameters``, is passed to
                ``evaluation``; ``block_lists[0]`` is used by
                :meth:`load_external_weights` when external weights are enabled.
        """
        super().finishing_pruning()
        self.block_lists = block_lists
        self.load_external_weights()
        tee_count = self.count_params(block_lists)
        self.evaluation(tee_count / self.before_pruning_parameters)

    def load_external_weights(self):
        """Optionally replace model parts with externally stored weights.

        Driven by ``config.task.prune.external_weight_config``:

        * ``enable`` — nothing happens unless true.
        * ``path`` — loaded with ``torch.load`` onto the device of
          ``self.block_lists[0]``; loaded tensors/modules are cast to that
          block's dtype.
        * ``target == 'decoder_layer'`` — decoder layers listed in ``index`` are
          replaced by ``nn.Identity`` and ``layers[index[0]]`` is filled
          according to ``arch`` (``'mlp'``, ``'decoder_layer'`` or
          ``'mlp_with_rms_norm'``).
        * any other ``target`` — the loaded entries are assigned onto
          ``self.block_lists[0]``.

        Raises:
            ValueError: for an unsupported ``arch`` with ``target == 'decoder_layer'``
                (checked before any layer is replaced).
        """
        ext_cfg = self.config.task.prune.external_weight_config
        if not ext_cfg.enable:
            return
        outlier_mlp = self.block_lists[0]
        device = next(outlier_mlp.parameters()).device
        # NOTE(review): weights_only=False unpickles arbitrary Python objects — only load trusted files.
        external_weight_dict = torch.load(ext_cfg.path, map_location=device, weights_only=False)
        default_dtype = next(outlier_mlp.parameters()).dtype
        mlp_mapping = self.layer_mapping['mlp']

        def preprocess_layers(target, weights):
            """Cast each recognised entry of ``weights`` to ``default_dtype`` and set it on ``target``.

            The attribute name is ``mlp_mapping[f'{key[0]}_name']`` (first
            character of the key); keys without such a mapping are logged and skipped.
            """
            for name, layer in weights.items():
                mapping_name = f'{name[0]}_name'
                if mapping_name not in mlp_mapping:
                    logger.info(f'unknown key detected and skipped: {name}. ')
                    continue
                # nn.Parameter subclasses torch.Tensor, so it must be tested first to be reachable.
                if isinstance(layer, nn.Parameter):
                    layer = nn.Parameter(layer.data.to(dtype=default_dtype))
                elif isinstance(layer, torch.Tensor):
                    layer = layer.to(dtype=default_dtype)
                elif isinstance(layer, nn.Module):
                    for param in layer.parameters():
                        param.data = param.data.to(dtype=default_dtype)
                setattr(target, mlp_mapping[mapping_name], layer)

        def preprocess_blocks(decoder_class, weights):
            """Build decoder layers from ``weights['config']`` and load ``weights['state_dict']`` into them.

            Keys are expected as ``"<block_idx>.<param path>"``; if any key lacks
            an integer prefix, every key is treated as belonging to block 0.
            """
            config = weights['config']
            state_dict = weights['state_dict']
            try:
                # Parse the whole numeric prefix so block indices >= 10 are read correctly.
                max_block_idx = max(int(key.split('.', 1)[0]) for key in state_dict)
            except ValueError:
                max_block_idx = 0
                state_dict = {f"0.{key}": value for key, value in state_dict.items()}
            # layer_idx values start at 99 (kept from the original; presumably to avoid clashing with real layers).
            decoders = nn.ModuleList([decoder_class(config, layer_idx)
                                      for layer_idx in range(99, 100 + max_block_idx)])
            decoders.eval()
            decoders.load_state_dict(state_dict)
            decoders = decoders.to(device)
            decoders = decoders.to(dtype=default_dtype)
            return decoders

        if ext_cfg.target in ['decoder_layer']:
            arch = ext_cfg.arch
            if arch not in ['mlp', 'decoder_layer', 'mlp_with_rms_norm']:
                raise ValueError(f"Unsupported external_weight_config.arch: {arch!r}")
            layers = self.get_layers()
            mlp_class = type(self.substract_mlp(layers[0]))
            decoder_class = type(layers[0])
            saved_layers = []
            for i in range(len(layers)):
                if i in ext_cfg.index:
                    saved_layers.append(copy.deepcopy(layers[i]))
                    layers[i] = nn.Identity()
            replace_idx = ext_cfg.index[0]
            if arch == 'mlp':
                mlp = mlp_class(external_weight_dict['config'])
                preprocess_layers(mlp, external_weight_dict)
                layers[replace_idx] = mlp
            elif arch == 'decoder_layer':
                # NOTE(review): nn.ModuleList defines no forward(); confirm the model's layer loop handles this.
                layers[replace_idx] = preprocess_blocks(decoder_class, external_weight_dict)
            else:  # 'mlp_with_rms_norm'
                mlp = mlp_class(external_weight_dict['config'])
                preprocess_layers(mlp, external_weight_dict)
                norm_func = saved_layers[0].input_layernorm

                def normed_forward(norm_func):
                    def forward(self, x):
                        x = norm_func(x)
                        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

                    return forward

                mlp.forward = types.MethodType(normed_forward(norm_func), mlp)
                layers[replace_idx] = mlp
        else:
            # Replace weights inside the first pruned (outlier) MLP block.
            preprocess_layers(outlier_mlp, external_weight_dict)

    def get_imps(self):
        """Return the recorded per-layer channel metrics (``self.W_metrics``)."""
        return self.W_metrics

    def step(self):
        """No-op; nothing to do per step for this pruner."""
        pass
