import copy
import logging
import math
from abc import abstractmethod

import deepspeed
from tqdm import tqdm
import torch

from .utils import *
from tasks.pruning.pruners import Pruner
from modules.eval.setup_eval import eval_ppl, eval_lm_eval

# Per-architecture module names consumed by sec_pruner.
#
# Outer key: matched as a case-sensitive *substring* of config.model.name in
# sec_pruner.__init__; the first key (in insertion order) that matches wins.
#
# Inner layout:
#   'attn' / 'mlp' -> short role keys ('q', 'k', 'v', 'o', 'g', 'u', 'd') map to
#                     dotted sub-module paths (presumably relative to a single
#                     decoder layer -- TODO confirm against callers); '*_name'
#                     keys hold the bare attribute name; 'block' is the attribute
#                     that sec_pruner.substract_attn / substract_mlp fetch with
#                     getattr(layer, ...).
#   'layers'       -> path resolved with nested_getattr(model, ...) to obtain the
#                     list of decoder layers.
LAYER_NAME_MAPPING = {
    'Llama': {
        'attn': {
            'q': 'self_attn.q_proj',
            'k': 'self_attn.k_proj',
            'v': 'self_attn.v_proj',
            'o': 'self_attn.o_proj',
            'q_name': 'q_proj',
            'k_name': 'k_proj',
            'v_name': 'v_proj',
            'o_name': 'o_proj',
            'block': 'self_attn'
        },
        'mlp': {
            'd': 'mlp.down_proj',
            'g': 'mlp.gate_proj',
            'u': 'mlp.up_proj',
            'd_name': 'down_proj',
            'g_name': 'gate_proj',
            'u_name': 'up_proj',
            'block': 'mlp'
        },
        'layers': 'model.layers'
    },
    # Same module naming as 'Llama'.
    'Mistral': {
        'attn': {
            'q': 'self_attn.q_proj',
            'k': 'self_attn.k_proj',
            'v': 'self_attn.v_proj',
            'o': 'self_attn.o_proj',
            'q_name': 'q_proj',
            'k_name': 'k_proj',
            'v_name': 'v_proj',
            'o_name': 'o_proj',
            'block': 'self_attn'
        },
        'mlp': {
            'd': 'mlp.down_proj',
            'g': 'mlp.gate_proj',
            'u': 'mlp.up_proj',
            'd_name': 'down_proj',
            'g_name': 'gate_proj',
            'u_name': 'up_proj',
            'block': 'mlp'
        },
        'layers': 'model.layers'
    },
    # Same module naming as 'Llama'.
    'Qwen': {
        'attn': {
            'q': 'self_attn.q_proj',
            'k': 'self_attn.k_proj',
            'v': 'self_attn.v_proj',
            'o': 'self_attn.o_proj',
            'q_name': 'q_proj',
            'k_name': 'k_proj',
            'v_name': 'v_proj',
            'o_name': 'o_proj',
            'block': 'self_attn'
        },
        'mlp': {
            'd': 'mlp.down_proj',
            'g': 'mlp.gate_proj',
            'u': 'mlp.up_proj',
            'd_name': 'down_proj',
            'g_name': 'gate_proj',
            'u_name': 'up_proj',
            'block': 'mlp'
        },
        'layers': 'model.layers'
    },
    # Output projection is called 'dense'; the MLP has no gate ('g') entry,
    # only fc1 ('u') and fc2 ('d').
    'phi-2': {
        'attn': {
            'q': 'self_attn.q_proj',
            'k': 'self_attn.k_proj',
            'v': 'self_attn.v_proj',
            'o': 'self_attn.dense',
            'q_name': 'q_proj',
            'k_name': 'k_proj',
            'v_name': 'v_proj',
            'o_name': 'dense',
            'block': 'self_attn'
        },
        'mlp': {
            'u': 'mlp.fc1',
            'd': 'mlp.fc2',
            'u_name': 'fc1',
            'd_name': 'fc2',
            'block': 'mlp'
        },
        'layers': 'model.layers'
    },
    # NOTE(review): unlike the other entries this one has no '*_name' keys, and
    # mlp 'block' is '' -- sec_pruner.substract_mlp would call getattr(layer, '')
    # and fail, so mlp-only pruning looks unsupported for this entry. Also, the
    # substring match on 'opt' is broad; confirm it cannot hit unrelated names.
    'opt': {
        'attn': {
            'q': 'self_attn.q_proj',
            'k': 'self_attn.k_proj',
            'v': 'self_attn.v_proj',
            'o': 'self_attn.out_proj',
            'block': 'self_attn'
        },
        'mlp': {
            'u': 'fc1',
            'd': 'fc2',
            'block': ''
        },
        'layers': 'base_model.decoder.layers'
    },
    # Disabled GPT-2 entry (incomplete: only output projections listed).
    # 'gpt2': {
    #     'o': 'attn.c_proj',
    #     'd': 'mlp.c_proj'
    # }

}

# Module-level logger used by the pruner classes below.
logger = logging.getLogger(__name__)


class sec_pruner(Pruner):
    """Base pruner that resolves architecture-specific module names and offers
    calibration-based layer-similarity probes plus shared bookkeeping.

    Concrete strategies implement :meth:`prune`, :meth:`step` and
    :meth:`get_imps`.
    """

    def __init__(self, model, config, data):
        super().__init__(model, config, data)
        self.before_pruning_parameters = None
        self.use_cache = None
        self.train_data = None
        self.eval_data = None
        # Initialise before the lookup so an unsupported model reaches the
        # explicit "not supported" error below instead of an AttributeError.
        self.layer_mapping = None
        for arch, mapping in LAYER_NAME_MAPPING.items():
            # Case-sensitive substring match; first key in dict order wins.
            if arch in config.model.name:
                self.layer_mapping = mapping
                self.model_arch = arch
                self.model_name = config.model.alias
                break
        if self.layer_mapping is None:
            raise Exception(f'model {config.model.name} is not supported yet')
        self.W_metrics = {}
        self.tokenizer = None
        self.data_processed = False
        self.block_lists = None
        self.gqa_mask_record = None

    # ------------------------------------------------------------------ #
    # Calibration helpers
    # ------------------------------------------------------------------ #
    def _prepare_calibration(self):
        """Run ``prepare_calibration_input_qwen3`` on the configured calibration set.

        Returns its 6-tuple unchanged: (inps, outs, attention_mask,
        position_ids, cache_position, position_embeddings).
        """
        dataset_cfg = self.config.task.prune.prune_dataset
        with torch.no_grad():
            return prepare_calibration_input_qwen3(
                self.get_model(), self.data, dataset_cfg.n_samples, dataset_cfg.seq_len)

    @staticmethod
    def _forward_calibration_layer(layer, inputs, n_samples, layer_kwargs):
        """Push calibration hidden states through one decoder layer.

        ``nn.Identity`` placeholders (layers already removed) are applied
        directly. Otherwise the first ``n_samples`` rows of ``inputs`` are fed
        one at a time and the first element of the layer's output is kept.
        ``inputs`` itself is not modified.
        """
        with torch.no_grad():
            if isinstance(layer, nn.Identity):
                return layer(inputs)
            outputs = inputs.detach().clone()
            for j in range(n_samples):
                outputs[j] = layer(inputs[j].unsqueeze(0), **layer_kwargs)[0]
            return outputs

    @staticmethod
    def _group_blocks(block_record, stop_index):
        """Partition ``range(stop_index)`` into blocks.

        Maximal runs of consecutive indices present in ``block_record`` (ints
        or 1-element tensors) form one block; every other index is a
        singleton block.
        """
        recorded = {int(idx.item()) if isinstance(idx, torch.Tensor) else int(idx)
                    for idx in block_record}
        blocks = []
        i = 0
        while i < stop_index:
            if i in recorded:
                j = i
                while j < stop_index and j in recorded:
                    j += 1
                blocks.append(list(range(i, j)))
                i = j
            else:
                blocks.append([i])
                i += 1
        return blocks

    # ------------------------------------------------------------------ #
    # Similarity probes
    # ------------------------------------------------------------------ #
    def obtain_information(self, input=None, stop_index=None):
        """Measure how much each decoder layer changes its input hidden states.

        Args:
            input: optional hidden states to start from instead of the
                calibration inputs to the first layer.
            stop_index: number of leading layers to probe (default: all).

        Returns:
            ``(cos_sim, std_dis, l2_dis)`` float tensors with one entry per
            probed layer, or ``(None, None, None)`` when ``stop_index == 0``.
            ``model.config.use_cache`` is restored on every exit path.
        """
        n_samples = self.config.task.prune.prune_dataset.n_samples
        self.model.eval()
        use_cache = self.model.config.use_cache
        self.model.config.use_cache = False
        try:
            inps, _, attention_mask, position_ids, cache_position, position_embeddings = \
                self._prepare_calibration()
            layer_kwargs = dict(attention_mask=attention_mask, position_ids=position_ids,
                                cache_position=cache_position,
                                position_embeddings=position_embeddings)
            layers = self.get_layers()
            if input is not None:
                inps = input
            if stop_index is None:
                stop_index = len(layers)
            elif stop_index == 0:
                # Same arity as the normal return so callers can always unpack 3 values.
                return None, None, None

            cos_sims, std_dists, l2_dists = [], [], []
            inputs = inps.detach().clone()
            for j in tqdm(range(stop_index), desc="Obtaining following layers' cosine similarity"):
                outputs = self._forward_calibration_layer(layers[j], inputs, n_samples, layer_kwargs)
                cos_sims.append(cosine_similarity(inputs, outputs)[2])
                l2, _ = l2_distance(inputs.float(), outputs.float())
                l2_dists.append(l2)
                std_v, _ = std(inputs.float(), outputs.float())
                std_dists.append(std_v)
                # This layer's output is the next layer's input.
                inputs = outputs
            return (torch.tensor(cos_sims).float(),
                    torch.tensor(std_dists).float(),
                    torch.tensor(l2_dists).float())
        finally:
            self.model.config.use_cache = use_cache

    def obtain_information_block(self, input=None, stop_index=None, block_record=None):
        """Like :meth:`obtain_information`, but measured per block of layers.

        Consecutive layer indices listed in ``block_record`` are grouped and
        forwarded together. Each block's cosine similarity and std metric
        compare the block's input with its output.

        Returns:
            ``(block_cos, block_std)`` tensors with one entry per block, or
            ``(None, None)`` when ``stop_index == 0``. ``model.config.use_cache``
            is restored on every exit path.
        """
        if block_record is None:
            block_record = []
        n_samples = self.config.task.prune.prune_dataset.n_samples
        self.model.eval()
        use_cache = self.model.config.use_cache
        self.model.config.use_cache = False
        try:
            inps, _, attention_mask, position_ids, cache_position, position_embeddings = \
                self._prepare_calibration()
            layer_kwargs = dict(attention_mask=attention_mask, position_ids=position_ids,
                                cache_position=cache_position,
                                position_embeddings=position_embeddings)
            layers = self.get_layers()
            if input is not None:
                inps = input
            if stop_index is None:
                stop_index = len(layers)
            elif stop_index == 0:
                return None, None

            block_cos, block_std = [], []
            inputs = inps.detach().clone()
            for block in self._group_blocks(block_record, stop_index):
                out = inputs
                for layer_idx in block:
                    out = self._forward_calibration_layer(layers[layer_idx], out, n_samples, layer_kwargs)
                block_cos.append(cosine_similarity(inputs, out)[2])
                std_v, _ = std(inputs.float(), out.float())
                block_std.append(std_v)
                inputs = out.detach().clone()
            return torch.tensor(block_cos), torch.tensor(block_std)
        finally:
            self.model.config.use_cache = use_cache

    # ------------------------------------------------------------------ #
    # Module access
    # ------------------------------------------------------------------ #
    def get_layers(self):
        """Return the decoder-layer container named by ``layer_mapping['layers']``."""
        return nested_getattr(self.get_model(), self.layer_mapping['layers'])

    def substract_mlp(self, layer):
        """Return the MLP sub-module of ``layer``."""
        return getattr(layer, self.layer_mapping['mlp']['block'])

    def substract_attn(self, layer):
        """Return the attention sub-module of ``layer``."""
        return getattr(layer, self.layer_mapping['attn']['block'])

    def substract_layer(self, index):
        """Return decoder layer number ``index``."""
        return self.get_layers()[index]

    def _prune_target(self, layer):
        """Return the part of ``layer`` selected by ``config.task.prune.prune_modules``
        ('mlp', 'mha', or anything else meaning the whole layer)."""
        block = self.config.task.prune.prune_modules
        if block in ['mlp']:
            return self.substract_mlp(layer)
        if block in ['mha']:
            return self.substract_attn(layer)
        return layer

    # ------------------------------------------------------------------ #
    # Strategy interface
    # ------------------------------------------------------------------ #
    @abstractmethod
    def prune(self):
        pass

    @abstractmethod
    def step(self):
        pass

    @abstractmethod
    def get_imps(self):
        pass

    # ------------------------------------------------------------------ #
    # Lifecycle / evaluation
    # ------------------------------------------------------------------ #
    def evaluation(self, current_ratio, title=None):
        """Run perplexity and lm-eval evaluation on the wrapped model.

        Without ``title`` the outputs are named after ``current_ratio``.
        NOTE(review): when ``title`` is given, the same string is passed as
        the save target to both evaluators; confirm they cannot overwrite
        each other.
        """
        eval_ppl(self.get_wrapped_model(), self.tokenizer, save=True,
                 save_path=title if title is not None else f'{self.config.task.output_folder}/sp_{current_ratio}_ppl.pth')
        eval_lm_eval(self.get_wrapped_model(), self.tokenizer, self.config,
                     title if title is not None else f'sp_{current_ratio}_lm_eval.pth')
        return

    def before_pruning(self):
        """Put the model in eval mode and disable the KV cache, remembering its setting."""
        self.get_wrapped_model().eval()
        self.use_cache = self.get_model_config().use_cache
        self.get_model_config().use_cache = False

    def after_pruning_step(self, current_ratio=None, real_pruning=True):
        """Hook for subclasses; no-op by default."""
        pass

    def finishing_pruning(self):
        """Restore the KV-cache setting saved by :meth:`before_pruning` and free cached CUDA memory."""
        self.get_model_config().use_cache = self.use_cache
        torch.cuda.empty_cache()

    # ------------------------------------------------------------------ #
    # Statistics
    # ------------------------------------------------------------------ #
    def check_unstr_sparsity(self):
        """Log per-layer and return overall fraction of exactly-zero weights.

        Only modules returned by ``find_layers`` on the configured prune
        target are counted.

        Returns:
            ``(sparsity, n_zero, n_nonzero)``. Sparsity is 0.0 when no
            weights were found.
        """
        logger.info("*" * 30)
        count = 0
        total_params = 0
        for i, layer in enumerate(self.get_layers()):
            subset = find_layers(self._prune_target(layer))
            sub_count = 0
            sub_params = 0
            for module in subset.values():
                W = module.weight.data
                sub_count += (W == 0).sum().item()
                sub_params += W.numel()
            count += sub_count
            total_params += sub_params

            if sub_params:
                logger.info(f"layer {i} sparsity {float(sub_count) / sub_params:.6f}")
            else:
                # e.g. a layer already replaced by nn.Identity: nothing to measure.
                logger.info(f"layer {i} has no prunable weights")
        logger.info("*" * 30)
        sparsity = float(count) / total_params if total_params else 0.0
        return sparsity, count, total_params - count

    def count_params(self):
        """Return the total parameter count of the configured prune target over all layers."""
        return sum(p.numel()
                   for layer in self.get_layers()
                   for p in self._prune_target(layer).parameters())

    def real_metrics_mapping(self):
        """Map ``config.task.prune.real_metrics`` to its short tag.

        NOTE(review): '2rd' looks like a typo for '2nd' but is kept verbatim
        because the tag may already appear in saved file names.
        """
        mapping = {'first': '1st', 'second': '2rd', 'mix': 'mix', 'grad': 'grad', 'weight': 'weight'}
        return mapping[self.config.task.prune.real_metrics]


class WrappedGPT:
    """Collect a running activation statistic for the inputs of one layer.

    ``scaler_row`` holds, per input column of ``layer.weight``, the sum of
    squared input values seen so far divided by the number of samples
    (``nsamples``, counted along the leading batch dimension).
    """

    def __init__(self, layer, layer_id=0, layer_name="none"):
        weight = layer.weight
        self.layer = layer
        self.dev = weight.device
        self.rows, self.columns = weight.data.shape[0], weight.data.shape[1]

        self.scaler_row = torch.zeros(self.columns, device=self.dev)
        self.nsamples = 0
        self.feb_output = None
        self.layer_id = layer_id
        self.layer_name = layer_name

    def add_batch(self, inp, out):
        """Fold one batch of layer inputs into ``scaler_row``.

        A 2-D ``inp`` is treated as a batch of one. For ``nn.Linear`` layers a
        3-D input is flattened to rows and transposed, so the norm runs over
        all tokens for each input feature. ``out`` is accepted but unused.
        """
        if inp.dim() == 2:
            inp = inp.unsqueeze(0)
        batch = inp.shape[0]

        if isinstance(self.layer, nn.Linear):
            if inp.dim() == 3:
                inp = inp.reshape(-1, inp.shape[-1])
            inp = inp.t()

        # Rescale the previous average to the new sample count, then add this batch.
        seen = self.nsamples
        self.nsamples = seen + batch
        self.scaler_row *= seen / self.nsamples
        sq_norm = torch.norm(inp.type(torch.float32), p=2, dim=1) ** 2
        self.scaler_row += sq_norm / self.nsamples

    def free(self):
        """Drop the accumulated statistic and release cached CUDA memory."""
        self.scaler_row = None
        torch.cuda.empty_cache()
