import torch
from tqdm.auto import tqdm
import time
from torch.cuda.amp import GradScaler
import torch.nn as nn
import numpy as np
from torch.cuda.amp import autocast
import os
import math
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from .hypernetwork import learnable_mask

class collect_info_reg_llama(nn.Module):
    """Parameter-budget regulariser for LLaMA-style decoders with learnable masks.

    Construction mutates ``model`` in place. For every decoder layer, each of
    the seven projection masks (``q/k/v/o`` under ``self_attn`` and
    ``gate/up/down`` under ``mlp``) is updated as follows:
      * ``p`` is assigned;
      * ``logits`` becomes a fresh ``nn.Parameter`` of length
        ``proj.weight.shape[0]``, filled with ``1 - p``;
      * ``c`` is taken from ``importance_dict``, moved to the device of the
        mask's previous ``c``.

    Calling the module returns ``lam * |log(ratio / p)|``. Here ``ratio`` is the
    summed ``get_nnz()`` of all non-empty ``learnable_mask`` modules divided by
    the original projection parameter count (``sum_ori_params``).

    Args:
        model: model exposing ``model.model.layers``; each layer carries the
            ``*_proj`` linear layers and the matching ``*_mask`` modules.
        p: target parameter ratio. It must be given despite the ``None``
            default, because ``1 - p`` is evaluated for every mask.
        lam: scale applied to the returned loss.
        importance_dict: indexable by consecutive ints starting at 0, with one
            entry per mask in the order q, k, v, o, gate, up, down, layer by
            layer.
    """

    def __init__(self, model, p=None, lam=4.0, importance_dict=None):
        super().__init__()
        self.sum_ori_params = 0
        self.p = p
        self.lam = lam
        ind = 0
        for layer in model.model.layers:
            attn, mlp = layer.self_attn, layer.mlp
            # The tuple order fixes which importance_dict entry each mask gets.
            proj_mask_pairs = (
                (attn.q_proj, attn.q_mask),
                (attn.k_proj, attn.k_mask),
                (attn.v_proj, attn.v_mask),
                (attn.o_proj, attn.o_mask),
                (mlp.gate_proj, mlp.gate_mask),
                (mlp.up_proj, mlp.up_mask),
                (mlp.down_proj, mlp.down_mask),
            )
            for proj, mask in proj_mask_pairs:
                self.sum_ori_params += self._init_mask(proj, mask, p, importance_dict[ind])
                ind += 1

        print("Number of original parameters: %.3f" % (self.sum_ori_params / 10 ** 6))

    @staticmethod
    def _init_mask(proj, mask, p, importance):
        """Initialise ``mask`` for projection ``proj`` and return ``proj``'s weight count."""
        dim = proj.weight.shape
        mask.p = p
        # One logit per output row of the projection, all starting at 1 - p.
        mask.logits = nn.Parameter(torch.zeros(dim[0]) + 1 - p)
        mask.c = importance.to(mask.c.device)
        return dim[0] * dim[1]

    def forward(self, model):
        """Return the scaled log-ratio penalty for the current mask state.

        Args:
            model: module whose ``learnable_mask`` submodules are counted. They
                are matched by class name, and masks with empty ``logits`` are
                skipped.

        Returns:
            0-d tensor ``lam * |log(ratio / p)|``, which is zero exactly when
            ``ratio == p``.
        """
        sum_params = 0
        for m in model.modules():
            if type(m).__name__ == 'learnable_mask' and m.logits.numel() != 0:
                sum_params += m.get_nnz()

        # NOTE(review): float(...) followed by torch.tensor(...) builds a new leaf
        # tensor, so this loss is detached from any autograd graph get_nnz() may
        # carry. Confirm it is not meant to back-propagate into the mask logits.
        device = f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu"
        total = torch.tensor(float(sum_params), device=device)
        # Reduce only when a process group exists, so single-process runs work.
        # NOTE(review): the count is summed over ranks, but sum_ori_params is the
        # local count from __init__. Confirm each rank holds a disjoint model part.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.all_reduce(total)

        param_ratio = total / self.sum_ori_params
        # This equals log(ratio / p) above the target and log(p / ratio) below it.
        loss = torch.abs(torch.log(param_ratio / self.p))
        return self.lam * loss
    
class collect_info_reg_opt(nn.Module):
    """Parameter-budget regulariser for OPT-style decoders with learnable masks.

    Construction mutates ``model`` in place. For every decoder layer, each of
    the six masks (``k/v/q/out`` under ``self_attn``, plus ``fc1_mask`` and
    ``fc2_mask`` on the layer) is updated as follows:
      * ``p`` is assigned;
      * ``logits`` becomes a fresh ``nn.Parameter`` of length
        ``proj.weight.shape[0]``, filled with ``1 - p``;
      * ``c`` is taken from ``importance_dict``, moved to the device of the
        mask's previous ``c``.

    Calling the module returns ``lam * |log(ratio / p)|``. Here ``ratio`` is the
    summed ``get_nnz()`` of all non-empty ``learnable_mask`` modules divided by
    the original projection parameter count (``sum_ori_params``).

    Args:
        model: model exposing ``model.model.decoder.layers``.
        p: target parameter ratio. It must be given despite the ``None``
            default, because ``1 - p`` is evaluated for every mask.
        lam: scale applied to the returned loss.
        importance_dict: indexable by consecutive ints starting at 0, with one
            entry per mask in the order k, v, q, out, fc1, fc2, layer by layer.
    """

    def __init__(self, model, p=None, lam=4.0, importance_dict=None):
        # Bug fix: the original called super(collect_info_reg_llama, self),
        # which raises TypeError because this class does not derive from it.
        super().__init__()
        self.sum_ori_params = 0
        self.p = p
        self.lam = lam
        ind = 0
        for layer in model.model.decoder.layers:
            attn = layer.self_attn
            # The tuple order fixes which importance_dict entry each mask gets.
            proj_mask_pairs = (
                (attn.k_proj, attn.k_mask),
                (attn.v_proj, attn.v_mask),
                (attn.q_proj, attn.q_mask),
                (attn.out_proj, attn.out_mask),
                (layer.fc1, layer.fc1_mask),
                (layer.fc2, layer.fc2_mask),
            )
            for proj, mask in proj_mask_pairs:
                self.sum_ori_params += self._init_mask(proj, mask, p, importance_dict[ind])
                ind += 1

        print("Number of original parameters: %.3f" % (self.sum_ori_params / 10 ** 6))

    @staticmethod
    def _init_mask(proj, mask, p, importance):
        """Initialise ``mask`` for projection ``proj`` and return ``proj``'s weight count."""
        dim = proj.weight.shape
        mask.p = p
        # One logit per output row of the projection, all starting at 1 - p.
        mask.logits = nn.Parameter(torch.zeros(dim[0]) + 1 - p)
        mask.c = importance.to(mask.c.device)
        return dim[0] * dim[1]

    def forward(self, model):
        """Return the scaled log-ratio penalty for the current mask state.

        Args:
            model: module whose ``learnable_mask`` submodules are counted. They
                are matched by class name, and masks with empty ``logits`` are
                skipped.

        Returns:
            0-d tensor ``lam * |log(ratio / p)|``, which is zero exactly when
            ``ratio == p``.
        """
        sum_params = 0
        for m in model.modules():
            if type(m).__name__ == 'learnable_mask' and m.logits.numel() != 0:
                sum_params += m.get_nnz()

        # NOTE(review): float(...) followed by torch.tensor(...) builds a new leaf
        # tensor, so this loss is detached from any autograd graph get_nnz() may
        # carry. Confirm it is not meant to back-propagate into the mask logits.
        device = f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu"
        total = torch.tensor(float(sum_params), device=device)
        # Reduce only when a process group exists, so single-process runs work.
        # NOTE(review): the count is summed over ranks, but sum_ori_params is the
        # local count from __init__. Confirm each rank holds a disjoint model part.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.all_reduce(total)

        param_ratio = total / self.sum_ori_params
        # This equals log(ratio / p) above the target and log(p / ratio) below it.
        loss = torch.abs(torch.log(param_ratio / self.p))
        return self.lam * loss
