import torch 
import torch.nn as nn 
from Pruner.datasets.train_data import get_wanda_c4

class WrappedGPT:
    """
    Wrap a layer and accumulate running statistics of its input activations.

    ``add_batch`` (typically driven by a forward hook) updates, per input
    feature (one entry per weight column):

    * ``scaler_row``   -- sum of squared activations over all tokens seen,
      divided by the number of samples seen (the Wanda activation term).
    * ``baseline_inp`` -- running mean of the activations.
    * ``fluc_inp``     -- running fluctuation (variance-like) estimate used by
      the FLAP metric; the first batch only seeds the mean.

    ``layer_id`` and ``layer_name`` are accepted for API compatibility but are
    not used.
    """
    def __init__(self, layer, layer_id=0, layer_name="none"):
        self.layer = layer
        self.dev = self.layer.weight.device
        self.rows = layer.weight.data.shape[0]
        self.columns = layer.weight.data.shape[1]
        self.scaler_row = torch.zeros((self.columns), device=self.dev)
        self.nsamples = 0
        self.fluc_inp = torch.zeros((self.columns), device=self.dev)
        self.baseline_inp = torch.zeros((self.columns), device=self.dev)

    def add_batch(self, inp, out):
        """
        Fold one batch of layer inputs into the running statistics.

        Args:
            inp (torch.Tensor): Input activations of the wrapped layer. A 2-D
                tensor is treated as a single sample; dim 0 is the batch size.
                For ``nn.Linear`` all leading dims are flattened into tokens.
            out (torch.Tensor): Layer output; unused (kept for the hook signature).
        """
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        batch_size = inp.shape[0]
        if isinstance(self.layer, nn.Linear):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()  # -> (columns, tokens)
        # Accumulate in float32: reductions in half precision lose accuracy.
        inp = inp.float()

        total = self.nsamples + batch_size
        # Clone: the in-place updates below would otherwise also overwrite the
        # previous mean that the fluctuation update needs.
        old_baseline_inp = self.baseline_inp.clone()
        self.baseline_inp *= self.nsamples / total
        # Weight the batch mean by its sample count so the weights sum to one.
        self.baseline_inp += torch.mean(inp, dim=1) * batch_size / total

        if self.nsamples == 0:
            self.fluc_inp = torch.zeros_like(self.baseline_inp)
        else:
            self.fluc_inp *= (self.nsamples - 1) / (total - 1)
            # Welford-style cross term: (x - new_mean) * (x - old_mean).
            self.fluc_inp += torch.sum(
                (inp - self.baseline_inp.unsqueeze(1)) * (inp - old_baseline_inp.unsqueeze(1)),
                dim=1,
            ) / total

        self.scaler_row *= self.nsamples / total
        self.nsamples = total
        self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples

    def free(self):
        """Drop the accumulated statistics and release cached CUDA memory."""
        self.scaler_row = None
        self.baseline_inp = None
        self.fluc_inp = None
        torch.cuda.empty_cache()


def find_layers(module, layers=(nn.Linear,), name=''):
    """
    Recursively find the layers of a certain type in a module.

    Matching uses the exact type (``type(m) in layers``), so subclasses of a
    listed type are not matched. Children of a matching module are not
    searched.

    Args:
        module (nn.Module): PyTorch module.
        layers (Sequence[type]): Layer types to find. Defaults to ``(nn.Linear,)``;
            an immutable tuple avoids the shared-mutable-default pitfall.
        name (str): Dotted name of ``module`` relative to the root module.

    Returns:
        dict: Mapping from dotted module name to each matching layer.
    """
    if type(module) in layers:
        return {name: module}
    found = {}
    for child_name, child in module.named_children():
        full_name = f"{name}.{child_name}" if name else child_name
        found.update(find_layers(child, layers=layers, name=full_name))
    return found


def prepare_calibration_input(model_name, model, dataloader, device, nsamples=128, seqlen=128):
    """
    Capture the hidden states fed into the first layer for each calibration batch.

    ``model.layers[0]`` is temporarily replaced by a ``Catcher`` that stores its
    input (plus the mask / rotary frequencies it receives) and then aborts the
    forward pass with a private exception. The original layer is always
    restored, even if the capture fails.

    Args:
        model_name (str): Model identifier; a ``'llama'`` or ``'opt'`` substring
            selects how the model is invoked.
        model: Model exposing an indexable ``layers`` container, ``cfg.d_model``
            and at least one parameter (its dtype is used for the buffer).
        dataloader: Iterable of input tensors, one calibration sample each.
        device: Device for the captured activations.
        nsamples (int): Number of batches the capture buffer can hold.
        seqlen (int): Sequence length of each calibration sample.

    Returns:
        tuple: ``(inps, outs, mask, freqs_cis)``. ``inps`` has shape
        ``(nsamples, seqlen, d_model)``; ``outs`` is a zero tensor of the same
        shape; ``mask`` / ``freqs_cis`` are the last values the catcher saw
        (``None`` if never set; ``freqs_cis`` is only captured on the llama path).
    """
    layers = model.layers

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros((nsamples, seqlen, model.cfg.d_model), dtype=dtype, device=device)
    cache = {'i': 0, 'mask': None, "freqs_cis": None}

    class _StopForward(ValueError):
        """Raised by the catcher to abort the forward pass once layer-0 input is stored."""

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def instantation_forward(self, inp, start_pos, freqs_cis, mask):
            inps[cache['i']] = inp.view(inps.shape[1], inps.shape[2])
            cache['i'] += 1
            cache['mask'] = mask
            cache['freqs_cis'] = freqs_cis
            raise _StopForward

        def forward(self, inp, attention_mask=None, layer_head_mask=None, past_key_value=None,
                    output_attentions=None, use_cache=None, intermediate_mask=None):
            inps[cache['i']] = inp.view(inps.shape[1], inps.shape[2])
            cache['i'] += 1
            cache['mask'] = attention_mask
            raise _StopForward

    layers[0] = Catcher(layers[0])
    try:
        for batch in dataloader:
            try:
                if 'llama' in model_name:
                    # NOTE(review): passes start_pos=0 and None for the remaining
                    # arguments — confirm against the model's instantation_forward.
                    model.instantation_forward(batch.to(device), 0, None, None)
                elif 'opt' in model_name:
                    model(batch.to(device))
                # NOTE(review): for any other model_name nothing is captured and
                # inps stays all zeros.
            except _StopForward:
                # Expected: the catcher stops the forward right after layer 0's input.
                pass
    finally:
        # Always restore the real first layer so the model stays usable on failure.
        layers[0] = layers[0].module

    outs = torch.zeros_like(inps)
    mask = cache['mask']
    freqs_cis = cache['freqs_cis']

    return inps, outs, mask, freqs_cis


def _wanda_mask_and_score(W_metric, sparsity):
    """
    Turn a 1-D Wanda importance vector into a binary keep-mask and an initial score.

    Args:
        W_metric (torch.Tensor): 1-D metric, one entry per head / channel.
        sparsity (float): ``model.cfg.mask.start_sparsity``.

    Returns:
        tuple: ``(W_mask, score)``. ``W_mask`` is a float 0/1 tensor marking the
        entries whose metric is >= the value at descending-sorted position
        ``int(sparsity * numel)`` (clamped to the last position). ``score`` is the
        sigmoid of the standardized metric, shifted so its mean equals
        ``sparsity`` before being clamped to [0, 1].
    """
    numel = W_metric.numel()
    # Clamp so sparsity == 1.0 does not index past the end of the sorted vector.
    k = min(int(sparsity * numel), numel - 1)
    thresh = torch.sort(W_metric, descending=True)[0][k].cpu()
    W_mask = (W_metric >= thresh).float()

    norm_data = (W_metric - W_metric.mean()) / (W_metric.std() + 1e-8)
    sigmoid_data = torch.sigmoid(norm_data)
    mean_res = sparsity - sigmoid_data.mean()
    score = (sigmoid_data + mean_res).clamp_(0, 1)
    return W_mask, score


def initialize_wanda(model, tokenizer, device, mask_module, model_name): # TODO DEBUG
    """
    Initialize head / intermediate mask scores from Wanda importance metrics.

    Calibration data comes from ``get_wanda_c4(tokenizer, 128, 128, model_name)``
    (assumed to be 128 samples of 128 tokens, matching the capture buffer).
    The data is pushed through the model one layer at a time. For each layer:

    1. Forward hooks on every ``nn.Linear`` accumulate input statistics.
    2. The Wanda metric ``|W| * sqrt(scaler_row)`` of the attention output
       projection (``wo``/``out_proj``) and of the MLP down projection
       (``w2``/``fc2``) is reduced to one value per head / channel. Row ``i``
       of the corresponding ``mask_module`` score is then overwritten.
    3. The layer is re-run with the derived binary masks. Its outputs become
       the next layer's inputs.

    Args:
        model: Model exposing ``layers`` and ``cfg`` (``d_model``, ``n_heads``,
            ``n_kv_heads``, ``mask.start_sparsity``).
        tokenizer: Tokenizer forwarded to ``get_wanda_c4``.
        device: Device for the calibration activations.
        mask_module: Object with ``pruning_modules`` and ``masks``; updated in
            place. A ``'layer'`` score, if present, is reset to zeros.
        model_name (str): ``'llama'`` / ``'llama3'`` / ``'opt'`` substrings select
            the model-specific code paths.
    """
    print('================== WANDA INITIALIZE ==================')
    dataloader = get_wanda_c4(tokenizer, 128, 128, model_name)
    with torch.no_grad():
        inps, outs, mask, freqs_cis = prepare_calibration_input(model_name, model, dataloader, device)
    if 'layer' in mask_module.pruning_modules:
        mask_module.masks['layer'].score = nn.Parameter(torch.zeros(*mask_module.masks['layer'].mask_shape, device=device))
    if 'head' in mask_module.pruning_modules:
        num_head = model.cfg.n_kv_heads
        head_dim = model.cfg.d_model // model.cfg.n_heads
    layers = model.layers
    for i, layer in enumerate(layers):
        subset = find_layers(layer)
        wrapped_layers = {name: WrappedGPT(module) for name, module in subset.items()}

        def add_batch(name):
            def tmp(_, inp, out):
                wrapped_layers[name].add_batch(inp[0].data, out.data)
            return tmp

        # First pass: run the unmasked layer only to collect statistics via hooks.
        handles = [subset[name].register_forward_hook(add_batch(name)) for name in wrapped_layers]
        for j in range(128):
            with torch.no_grad():
                if 'llama' in model_name:
                    outs[j] = layer.instantation_forward(inps[j].unsqueeze(0), start_pos=0, freqs_cis=freqs_cis, mask=mask)[0]
                elif 'opt' in model_name:
                    outs[j] = layer(inps[j].unsqueeze(0))[0]
        for h in handles:
            h.remove()

        # Default to "no mask" so the second pass also works when a module type
        # is not being pruned (assumes the layer accepts None for these masks).
        head_mask = None
        intermediate_mask = None

        if 'head' in mask_module.pruning_modules:
            for name in subset:
                if 'wo' in name or 'out_proj' in name:
                    W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1, -1)))
                    # Average over output rows, then sum each head's head_dim input columns.
                    W_metric = W_metric.mean(axis=0).reshape(-1, head_dim).sum(dim=1)
                    if 'llama3' in model_name:
                        # Sum consecutive heads into n_kv_heads groups (assumed to be
                        # the query heads sharing one KV head under GQA).
                        W_metric = W_metric.reshape(num_head, -1).sum(dim=1)
                    W_mask, score = _wanda_mask_and_score(W_metric, model.cfg.mask.start_sparsity)
                    mask_module.masks['head'].score.data[i] = score

                    head_mask = W_mask.reshape(mask_module.masks['head'].mask_output_shape[1:])
                    if 'opt' in model_name:
                        head_mask = head_mask.squeeze().to(dtype=torch.bfloat16)
        if 'intermediate' in mask_module.pruning_modules:
            for name in subset:
                if 'w2' in name or 'fc2' in name:
                    W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1, -1)))
                    # One value per intermediate channel (input column of the down projection).
                    W_metric = W_metric.mean(axis=0)
                    W_mask, score = _wanda_mask_and_score(W_metric, model.cfg.mask.start_sparsity)
                    mask_module.masks['intermediate'].score.data[i] = score

                    intermediate_mask = W_mask.reshape(mask_module.masks['intermediate'].mask_output_shape[1:])
                    if 'opt' in model_name:
                        intermediate_mask = intermediate_mask.squeeze().to(dtype=torch.bfloat16)

        # Statistics are no longer needed once the scores are written.
        for wrapped in wrapped_layers.values():
            wrapped.free()

        # Second pass: re-run the layer with the derived masks; the outputs feed the next layer.
        for j in range(128):
            with torch.no_grad():
                if 'llama' in model_name:
                    outs[j] = layer(inps[j].unsqueeze(0), start_pos=0, freqs_cis=freqs_cis, mask=mask, head_z=head_mask, head_layer_z=None, intermediate_z=intermediate_mask, mlp_z=None, hidden_z=None)[0]
                elif 'opt' in model_name:
                    outs[j] = layer(inps[j].unsqueeze(0), layer_head_mask=head_mask, intermediate_mask=intermediate_mask)[0]
        inps, outs = outs, inps


def _flap_standardize(x):
    """Standardize each row of a 2-D tensor to zero mean and unit (``torch.std``) deviation."""
    return (x - torch.mean(x, dim=1, keepdim=True)) / torch.std(x, dim=1, keepdim=True)


def _flap_scores_from_metric(metric, sparsity):
    """
    Convert a ``(layers, units)`` FLAP metric into initial mask scores.

    A single global threshold is taken at descending-sorted position
    ``int(numel * sparsity)`` (clamped to the last position). Entries strictly
    above it form a binary mask, which is standardized and passed through a
    sigmoid.
    """
    numel = metric.numel()
    # Clamp so sparsity == 1.0 does not index past the end of the sorted values.
    k = min(int(numel * sparsity), numel - 1)
    thresh = torch.sort(metric.flatten(), descending=True)[0][k]
    keep = (metric > thresh).float()
    # NOTE(review): the score is derived from the binary mask only, so every kept
    # unit gets one value and every pruned unit another; the continuous metric is
    # discarded at this point. Confirm this is intended.
    norm = (keep - torch.mean(keep)) / (torch.std(keep) + 1e-8)
    return torch.sigmoid(norm).clamp(0, 1)


def initialize_FLAP(model, tokenizer, device, mask_module, model_name):
    """
    Initialize mask scores from FLAP-style fluctuation metrics.

    Calibration data (``get_wanda_c4(tokenizer, 128, 128, model_name)``) is
    propagated through the *unmasked* layers one at a time. For each layer:

    * ``'layer'``: the sum of ``|W| * sqrt(scaler_row)`` over all linear layers
      is *added* to row ``i`` of the existing layer score. The score is not
      reset here.
    * ``'head'``: ``(fluc_inp * sum_rows(W_o ** 2)) ** 2`` per input channel of
      ``wo``/``out_proj``.
    * ``'intermediate'``: ``fluc_inp * sum_rows(W_2 ** 2)`` per input channel
      of ``w2``/``fc2``.

    After all layers, each metric is standardized per layer. The attention
    metric is averaged per head (and, for llama3, per group of
    ``n_heads // n_kv_heads`` consecutive heads). Scores are then set from a
    global threshold. Metric kinds that were not collected are skipped.

    Args:
        model: Model exposing ``layers`` and ``cfg`` (``d_model``, ``n_heads``,
            ``n_kv_heads``, ``mask.start_sparsity``).
        tokenizer: Tokenizer forwarded to ``get_wanda_c4``.
        device: Device for the calibration activations.
        mask_module: Object with ``pruning_modules`` and ``masks``; updated in place.
        model_name (str): ``'llama'`` / ``'llama3'`` / ``'opt'`` substrings select
            the model-specific code paths.
    """
    print('================== FLAP INITIALIZE ==================')
    dataloader = get_wanda_c4(tokenizer, 128, 128, model_name)
    with torch.no_grad():
        inps, outs, mask, freqs_cis = prepare_calibration_input(model_name, model, dataloader, device)
    layers = model.layers
    attn_metric_list, mlp_metric_list = [], []

    for i, layer in enumerate(layers):
        subset = find_layers(layer)
        wrapped_layers = {name: WrappedGPT(module) for name, module in subset.items()}

        def add_batch(name):
            def tmp(_, inp, out):
                wrapped_layers[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = [subset[name].register_forward_hook(add_batch(name)) for name in wrapped_layers]
        for j in range(128):
            with torch.no_grad():
                if 'llama' in model_name:
                    outs[j] = layer.instantation_forward(inps[j].unsqueeze(0), start_pos=0, freqs_cis=freqs_cis, mask=mask)[0]
                elif 'opt' in model_name:
                    outs[j] = layer(inps[j].unsqueeze(0))[0]
        for h in handles:
            h.remove()

        if 'layer' in mask_module.pruning_modules:
            for name in subset:
                score = (torch.abs(subset[name].weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1, -1)))).sum()
                mask_module.masks['layer'].score.data[i] += score.data

        if 'head' in mask_module.pruning_modules:
            for name in subset:
                if 'wo' in name or 'out_proj' in name:
                    W_metric = wrapped_layers[name].fluc_inp * torch.sum(subset[name].weight.data ** 2, dim=0)
                    attn_metric_list.append((W_metric ** 2).cpu())
                    wrapped_layers[name].free()
        if 'intermediate' in mask_module.pruning_modules:
            for name in subset:
                if 'w2' in name or 'fc2' in name:
                    W_metric = wrapped_layers[name].fluc_inp * torch.sum(subset[name].weight.data ** 2, dim=0)
                    mlp_metric_list.append(W_metric.cpu())
                    wrapped_layers[name].free()
        # Unmasked outputs become the next layer's inputs.
        inps, outs = outs, inps
        torch.cuda.empty_cache()

    n_layers = len(layers)
    head_dim = model.cfg.d_model // model.cfg.n_heads
    sparsity = model.cfg.mask.start_sparsity

    # Guarded: torch.stack fails on an empty list when a module type is not pruned.
    if attn_metric_list:
        attn_metric = _flap_standardize(torch.stack(attn_metric_list))
        # One value per head: mean over its head_dim channels.
        attn_metric = attn_metric.reshape(n_layers, -1, head_dim).mean(dim=2)
        if 'llama3' in model_name:
            attn_metric = attn_metric.reshape(n_layers, model.cfg.n_kv_heads, (model.cfg.n_heads // model.cfg.n_kv_heads)).mean(dim=2)
        attn_score = _flap_scores_from_metric(attn_metric, sparsity)
        for i in range(n_layers):
            mask_module.masks['head'].score.data[i] = attn_score[i].float()

    if mlp_metric_list:
        mlp_metric = _flap_standardize(torch.stack(mlp_metric_list))
        mlp_score = _flap_scores_from_metric(mlp_metric, sparsity)
        for i in range(n_layers):
            mask_module.masks['intermediate'].score.data[i] = mlp_score[i].float()


def initialize_random(cfg, mask_module):
    """
    Fill the head / intermediate mask scores with uniform noise rescaled to a target mean.

    Rows ``0 .. cfg.n_layers - 1`` of each enabled score tensor are drawn with
    ``torch.rand_like``. Each row is then multiplied by
    ``start_sparsity / mean``, where ``mean`` is taken over the whole score
    tensor, so the scores average to ``mask_module.start_sparsity``. The result
    is not clamped, so individual values can exceed 1.
    """
    target = mask_module.start_sparsity
    n_layers = cfg.n_layers
    enabled = [kind for kind in ('head', 'intermediate') if kind in mask_module.pruning_modules]

    # Draw layer by layer, heads before intermediates, which keeps the RNG call order.
    for layer_idx in range(n_layers):
        for kind in enabled:
            scores = mask_module.masks[kind].score.data
            scores[layer_idx] = torch.rand_like(scores[layer_idx])

    # Means are taken once, before any rescaling.
    means = {kind: mask_module.masks[kind].score.data.mean() for kind in enabled}

    for layer_idx in range(n_layers):
        for kind in enabled:
            scores = mask_module.masks[kind].score.data
            scores[layer_idx] = scores[layer_idx] * target / means[kind]
