import copy
import torch
from torch import nn
from argparse import Namespace as NS
from typing import List
import numpy as np

class Mask(nn.Module):
    """Learnable Bernoulli-style mask over a tensor of shape ``mask_shape``.

    ``score`` holds one value per mask entry. It is initialised around
    ``initial_score`` (normal noise, std 1e-2) and kept inside [0, 1] by
    :meth:`constrain_parameters`. :meth:`forward` draws a stochastic 0/1 mask
    with ``P(entry == 1) = score``. :meth:`determain_forward` builds a
    deterministic top-k mask instead.

    Args:
        name: Mask identifier. ``"head"`` enables the "no empty row" rule in
            :meth:`sample_z`.
        mask_shape: Shape of the learnable score tensor.
        mask_output_shape: Shape the produced mask is reshaped to on output.
        device: Device on which the score tensor is allocated.
        initial_score: Mean of the initial scores.
    """

    def __init__(self,
                 name: str,
                 mask_shape: List,
                 mask_output_shape: List,
                 device: str,
                 initial_score: float) -> None:
        super().__init__()
        self.name = name
        self.mask_output_shape = mask_output_shape
        self.initial_score = initial_score

        self.device = device

        self.mask_shape = mask_shape
        self.score = self.initialize_mask(mask_shape)
        # Pre-allocate .grad so gradients can be accumulated into it in place
        # (OurMask.update does `mask.score.grad.data += ...`).
        self.score.grad = torch.zeros_like(self.score)

    def param_init_fn(self, module):
        """Fill a parameter (or every parameter of a module) with N(initial_score, 1e-2)."""
        mean = self.initial_score
        if isinstance(module, nn.Parameter):
            module.data.normal_(mean, 1e-2)
        else:
            for tensor in module.parameters():
                tensor.data.normal_(mean, 1e-2)

    def initialize_mask(self, mask_shape: List):
        """Create the score parameter of shape ``mask_shape`` on ``self.device``."""
        z_score = nn.Parameter(torch.ones(*mask_shape, device=self.device))
        self.param_init_fn(z_score)
        return z_score

    def sample_z(self, max_resample: int = 100):
        """Draw a 0/1 mask from ``score``.

        Each entry is 1 when a uniform draw falls below its score. For the
        ``"head"`` mask every row (index along dim 0) must keep at least one
        entry, so the whole mask is redrawn while any row is empty.

        Redrawing is capped at ``max_resample`` extra attempts. A row whose
        scores are all 0 can never be sampled non-empty, so an uncapped loop
        would never terminate. After the cap, each row that is still empty
        gets its highest-score entry switched on.

        Args:
            max_resample: Maximum number of redraws for the ``"head"`` mask.

        Returns:
            ``(z, grad)``. ``z`` is the bfloat16 sample with the shape of
            ``score``. ``grad`` is
            ``(z - score) / sqrt((score + 1e-8) * (1 - score + 1e-8))``.
        """
        def draw():
            return (torch.rand_like(self.score) < self.score).to(self.score.device, dtype=torch.bfloat16)

        def empty_rows(sample):
            # One flag per row: True when no entry of that row is kept.
            return sample.reshape(sample.size(0), -1).sum(dim=1) == 0

        z = draw()
        if self.name == 'head' and z.numel() > 0:
            attempts = 0
            while empty_rows(z).any() and attempts < max_resample:
                z = draw()
                attempts += 1
            empty = empty_rows(z)
            if empty.any():
                # Fallback for rows that could not be filled by sampling.
                flat = z.reshape(z.size(0), -1)
                best = self.score.detach().reshape(z.size(0), -1).argmax(dim=1)
                rows = empty.nonzero(as_tuple=True)[0]
                flat[rows, best[rows]] = 1
                z = flat.reshape(self.score.shape)

        grad = (z - self.score) / torch.sqrt((self.score + 1e-8) * (1 - self.score + 1e-8))
        return z, grad

    def forward(self):
        """Sample a mask; return it in ``mask_output_shape`` plus its gradient in ``mask_shape``."""
        z, grad = self.sample_z()
        return z.reshape(self.mask_output_shape), grad.reshape(self.mask_shape)

    def determain_forward(self):
        """Build a deterministic bfloat16 mask that keeps the top-scoring entries.

        The number of kept entries is ``int(score.sum() / total * numel)``.
        Here ``total`` is ``mask_param_num()``, which equals ``numel`` because
        ``score`` is created with ``mask_shape``, so this is essentially the
        (truncated) score sum. It is capped at ``numel``. The result is
        reshaped to ``mask_output_shape``.
        """
        total_score = self.mask_param_num()

        num_elements_to_keep = int((self.score.sum() / total_score) * self.score.numel())
        num_elements_to_keep = min(num_elements_to_keep, self.score.numel())

        flat_scores = self.score.view(-1)
        _, sorted_indices = flat_scores.sort(descending=True)
        top_x_indices = sorted_indices[:num_elements_to_keep]

        z = torch.zeros_like(self.score)

        z.view(-1)[top_x_indices] = 1
        return z.to(self.score.device, dtype=torch.bfloat16).reshape(self.mask_output_shape)

    def constrain_parameters(self):
        """Clamp scores into [0, 1] in place."""
        self.score.data.clamp_(min=.0, max=1.0)

    def mask_param_num(self):
        """Return the number of mask entries (product of ``mask_shape``)."""
        product = 1
        for num in self.mask_shape:
            product *= num
        return product

    def solve_v_total(self, subset):
        """Find a shift ``v >= 0`` with ``sum(clamp(score - v, 0, 1)) ~= subset``.

        Uses bisection on ``[0, max(score)]``. It stops when the residual is
        below 1e-3 or after 21 iterations. Returns 0 when the clamped score
        sum is already below ``subset``. The result may be a 0-d tensor.
        """
        score = self.score

        k = subset
        a, b = 0, 0
        b = max(b, score.max())

        def f(v):
            s = (score - v).clamp(0, 1).sum()
            return s - k

        if f(0) < 0:
            return 0
        itr = 0
        while (1):
            itr += 1
            v = (a + b) / 2
            obj = f(v)
            if abs(obj) < 1e-3 or itr > 20:
                break
            if obj < 0:
                b = v
            else:
                a = v
        v = max(0, v)
        return v


class OurMask(nn.Module):
    """Collection of learnable pruning masks, one per pruned component.

    One ``Mask`` is built for each name in ``cfg.mask.pruning_modules`` by the
    matching ``initialize_<name>`` method (``head`` and ``intermediate`` are
    implemented here). Scores are trained with :meth:`update` and projected
    onto a target keep-count with :meth:`constrain_score`.

    Args:
        cfg: Model config. Reads ``name``, ``d_model``, ``intermediate_size``,
            ``n_heads``, ``n_kv_heads``, ``n_layers``, ``vocab_size`` and the
            ``mask`` sub-config (``pruning_modules``, ``start_sparsity``,
            ``target_sparsity``).
        device: Device on which the mask scores are allocated.
    """

    def __init__(self, cfg, device):
        super(OurMask, self).__init__()

        # Base-model dimensions. Models whose name contains "opt" are treated
        # as having 2 MLP weight matrices per layer, all others 3.
        n_matrix_mlp = 2 if "opt" in cfg.name else 3
        self.base_model_info = self.set_model_info(cfg, n_matrix_mlp=n_matrix_mlp)

        mask_cfg = cfg.mask

        self.pruning_modules = mask_cfg.pruning_modules
        # start_sparsity is also used as the initial (mean) score of every mask.
        self.start_sparsity = mask_cfg.start_sparsity
        self.target_sparsity = mask_cfg.target_sparsity
        self.device = device

        self.masks = {}
        # Module names listed by the original author: head, head_layer, mlp,
        # intermediate. Each requires an `initialize_<name>` method.
        for pruning_module in self.pruning_modules:
            self.initialize_one_module(pruning_module)
        # ModuleDict registers every mask's score as a parameter of this module.
        self.masks = torch.nn.ModuleDict(self.masks)

    def set_model_info(self, cfg, n_matrix_mlp):
        """Collect the base-model dimensions needed to shape the masks.

        ``num_attention_heads`` is read from ``cfg.n_kv_heads``, while
        ``dim_per_head`` is derived from ``cfg.n_heads``.
        """
        ns = NS()
        ns.hidden_size = cfg.d_model
        ns.intermediate_size = cfg.intermediate_size
        ns.num_attention_heads = cfg.n_kv_heads
        ns.mlp_num_per_layer = n_matrix_mlp
        ns.dim_per_head = ns.hidden_size // cfg.n_heads
        ns.num_layers = cfg.n_layers
        ns.vocab_size = cfg.vocab_size

        return ns

    def initialize_one_module(self, module_name: str):
        """Dispatch to ``initialize_<module_name>``.

        Raises:
            NotImplementedError: if no such method exists.
        """
        func_name = f"initialize_{module_name}"
        try:
            method = getattr(self, func_name)
        except AttributeError as err:
            raise NotImplementedError("Instance `{}` does not implement `{}`".format(self, func_name)) from err
        method()

    def initialize_head(self):
        """Create the per-layer head mask: scores ``[L, H]``, output ``[L, 1, H, 1]``."""
        mask_shape = [self.base_model_info.num_layers, self.base_model_info.num_attention_heads]  # e.g. [32, 32] for LLaMA-2-7B
        mask_output_shape = [self.base_model_info.num_layers, 1, self.base_model_info.num_attention_heads, 1]
        head_mask = Mask(name="head",
                         mask_shape=mask_shape,
                         mask_output_shape=mask_output_shape,
                         device=self.device,
                         initial_score=self.start_sparsity)
        self.masks["head"] = head_mask

    def initialize_intermediate(self):
        """Create the MLP intermediate mask: scores ``[L, I]``, output ``[L, 1, 1, I]``."""
        mask_shape = [self.base_model_info.num_layers, self.base_model_info.intermediate_size]
        mask_output_shape = [self.base_model_info.num_layers, 1, 1, self.base_model_info.intermediate_size]
        int_mask = Mask(name="intermediate",
                        mask_shape=mask_shape,
                        mask_output_shape=mask_output_shape,
                        device=self.device,
                        initial_score=self.start_sparsity)
        self.masks["intermediate"] = int_mask

    def constrain_parameters(self):
        """Clamp every mask's scores into [0, 1]."""
        for key in self.masks:
            self.masks[key].constrain_parameters()

    def forward(self, ppl_during_train=False):
        """Produce a mask for every pruning module.

        Args:
            ppl_during_train: If False, draw stochastic masks via each mask's
                ``forward`` and also return their gradient estimates. If True,
                build deterministic masks via ``determain_forward`` under
                ``torch.no_grad``. For ``head``, every layer keeps at least one
                head: an empty layer gets its highest-score head switched on.

        Returns:
            ``(zs, grads)``. ``zs`` maps ``"<module>_z"`` to a mask in the
            mask's ``mask_output_shape``. ``grads`` maps ``"<module>_grad"``
            to a gradient estimate, or is ``None`` when ``ppl_during_train``
            is True. A ``layer`` mask, if present, is exposed as both
            ``mlp_z`` and ``head_layer_z``.
        """
        self.constrain_parameters()

        zs = {f"{pruning_module}_z": [] for pruning_module in self.pruning_modules}
        grads = {f"{pruning_module}_grad": [] for pruning_module in self.pruning_modules}

        if "layer" in self.pruning_modules:
            zs.pop("layer_z")
            zs["mlp_z"] = []
            zs["head_layer_z"] = []

        if not ppl_during_train:
            for pruning_module in self.pruning_modules:
                mask = self.masks[pruning_module]
                z, grad = mask()
                zs[f"{pruning_module}_z"] = z
                grads[f"{pruning_module}_grad"] = grad
        else:  # deterministic path; no layerwise handling here
            with torch.no_grad():
                for pruning_module in self.pruning_modules:
                    mask = self.masks[pruning_module]
                    z = mask.determain_forward()
                    if pruning_module == 'head':
                        # Index layers on the explicit [L, H] view rather than a
                        # squeeze(), which would also drop a size-1 L or H dim.
                        z2d = z.reshape(mask.mask_shape)
                        score2d = mask.score.reshape(mask.mask_shape)
                        for layer_idx in range(z2d.size(0)):
                            if z2d[layer_idx].sum() == 0:
                                z2d[layer_idx, score2d[layer_idx].argmax()] = 1
                        # reshape is not in-place: assign the result so the
                        # output shape matches the stochastic path.
                        z = z2d.reshape(mask.mask_output_shape)
                    zs[f"{pruning_module}_z"] = z
        if "layer_z" in zs:
            zs["mlp_z"] = zs.pop("layer_z")
            zs["head_layer_z"] = zs["mlp_z"]

        return (zs, grads) if not ppl_during_train else (zs, None)

    def reweight_score(self, outdated_score, update_score, zs):
        """Return the per-entry likelihood ratio ``p_update(zs) / p_outdated(zs)``.

        For 0/1 ``zs`` this is ``update/outdated`` where ``zs == 1`` and
        ``(1 - update)/(1 - outdated)`` where ``zs == 0``. A 1e-8 term keeps
        the denominators non-zero.
        """
        score = ((update_score / (outdated_score + 1e-8)) ** zs) * (((1 - update_score) / (1 - outdated_score + 1e-8)) ** (1 - zs))
        return score

    def update(self, t, fn_list, grad_list, K, outdated_masks, outdated_zs):
        """Accumulate gradient estimates into every mask's ``score.grad``.

        Args:
            t: Step index. At ``t == 0`` the baseline ``self.delta`` is reset
                to the mean of ``fn_list``. Otherwise, when ``outdated_masks``
                is not a list, it is updated as an EMA with ``T = 5``.
            fn_list: Objective values, one per sample (at least ``K`` of them).
            grad_list: Per-sample dicts mapping ``"<module>_grad"`` to the
                gradient estimate returned by :meth:`forward`.
            K: Number of samples to accumulate. The fresh-sample branch
                divides by ``K - 1``.
            outdated_masks: Mapping (or list of mappings, one per sample) from
                module name to an object with a ``score``. Used for importance
                re-weighting when ``outdated_zs`` is given.
            outdated_zs: A ``zs`` dict, a list of them, or ``None`` when the
                samples are fresh.

        Returns:
            Dict mapping ``"<module>_grad"`` to that mask's ``score.grad``
            tensor. It shares storage with the gradient, so it reflects the
            in-place clipping applied at the end.
        """
        T = 5
        # Hoisted loop invariants (same expressions/order as before).
        fn_sum = sum(fn_list)
        n_fn = len(fn_list)
        if t == 0:
            self.delta = fn_sum / n_fn
        elif not isinstance(outdated_masks, list):
            self.delta = (T - 1) / T * self.delta + 1 / T * fn_sum / n_fn
        grad_plot_list = {f"{pruning_module}_grad": [] for pruning_module in self.pruning_modules}
        for pruning_module in self.pruning_modules:
            mask = self.masks[pruning_module]
            for i in range(K):
                if outdated_zs is not None:
                    zs = outdated_zs
                    weight = 1
                    if isinstance(outdated_zs, list):
                        weight = K - 1
                        zs = outdated_zs[i]
                    if pruning_module == 'layer':
                        zs = zs['head_layer_z']
                    else:
                        zs = zs[f'{pruning_module}_z']
                    zs = zs.reshape(mask.mask_shape)
                    outdated_masks_t = outdated_masks
                    if isinstance(outdated_masks, list):
                        outdated_masks_t = outdated_masks[i]
                    # Importance weight: the sample was drawn from the old scores.
                    reweight_score = self.reweight_score(outdated_masks_t[pruning_module].score.data, mask.score.data, zs)

                    mask.score.grad.data += 1 / weight * (fn_list[i] - fn_sum / n_fn) * grad_list[i][f"{pruning_module}_grad"] * reweight_score
                else:
                    mask.score.grad.data += 1 / (K - 1) * (fn_list[i] - self.delta) * grad_list[i][f"{pruning_module}_grad"]
            grad_plot_list[f"{pruning_module}_grad"] = mask.score.grad.data
        torch.nn.utils.clip_grad_norm_(self.masks.parameters(), 3)
        return grad_plot_list

    def constrain_score(self, target_mask_num):
        """Project each mask's scores so they sum to roughly its target count.

        Scores are shifted down by ``Mask.solve_v_total`` and clamped to
        [0, 1]. If the sum still differs from the target by more than 1 %,
        the scores are rescaled multiplicatively and clamped again. A target
        of 0 (or less) zeroes the scores. All-zero scores are left unchanged,
        because rescaling would multiply by ``target / 0`` and produce NaN.
        """
        with torch.no_grad():
            for pruning_module in self.pruning_modules:
                mask = self.masks[pruning_module]
                target = target_mask_num[pruning_module]
                if target <= 0:
                    # Nothing may be kept; avoids dividing by a zero target.
                    mask.score.zero_()
                    continue
                v = mask.solve_v_total(target)
                mask.score.sub_(v).clamp_(0, 1)
                total = mask.score.sum()
                if total <= 0:
                    continue
                delta = (total - target) / target
                if abs(delta) > 1e-2:
                    mask.score.mul_(target / total)
                    mask.score.clamp_(0, 1)

    def calculate_target_mask_num(self, epoch, t, epochs, length):
        """Return the target number of kept entries per mask for this step.

        The keep ratio stays at ``start_sparsity`` for the first 16 % of all
        steps (``epochs * length``). It then follows a cubic decay to
        ``target_sparsity``, reached at 60 % of the steps, and stays there.
        """
        ts = int(0.16 * epochs * length)
        te = int(0.60 * epochs * length)

        if (epoch * length + t) < ts:
            sparsity_rate = self.start_sparsity
        elif (epoch * length + t) < te:
            sparsity_rate = self.target_sparsity + (self.start_sparsity - self.target_sparsity) * (1 - (epoch * length + t - ts) / (te - ts))**3
        else:
            sparsity_rate = self.target_sparsity
        target_mask_num = {}
        for pruning_module in self.pruning_modules:
            target_mask_num[pruning_module] = int(sparsity_rate * self.masks[pruning_module].mask_param_num())

        return target_mask_num
