import os
import math
from torch.autograd import Function
import numpy
import numpy as np
import pickle
import matplotlib.pyplot as plt
import ot

import torch
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
import time
import logging
import numpy as np
from scipy.stats import kendalltau
from copulas.bivariate import Frank
from scipy.spatial.distance import mahalanobis
import scipy.stats as stats

from utils.utils import (
    flatten_tensor_list,
    get_summary_stats,
    dump_tensor_to_mat,
    add_noise_to_grads,
)
from policies.pruners import GradualPruner
from utils.plot import analyze_new_weights
from common.io import _load, _dump
from common.timer import Timer
from copy import copy



class CopulaPruner(GradualPruner):
    have_not_started_pruning = True

    def __init__(self, model, inp_args, **kwargs):
        """Initialise the pruner and cache the options it reads from ``inp_args``.

        Args:
            model: Forwarded unchanged to ``GradualPruner.__init__``.
            inp_args: Argument namespace. ``prune_direction``,
                ``zero_after_prune``, ``inspect_inv`` and ``fisher_mini_bsz``
                are read from it directly in this method.
            **kwargs: Forwarded unchanged to ``GradualPruner.__init__``.
        """
        super(CopulaPruner, self).__init__(model, inp_args, **kwargs)
        print("In Optimal Transport Pruner")
        self._fisher_inv_diag = None
        self._prune_direction = inp_args.prune_direction
        self._zero_after_prune = inp_args.zero_after_prune
        self._inspect_inv = inp_args.inspect_inv
        self._fisher_mini_bsz = inp_args.fisher_mini_bsz
        # A negative mini-batch size falls back to one sample per batch.
        if self._fisher_mini_bsz < 0:
            self._fisher_mini_bsz = 1
        # NOTE(review): ``self.args`` is presumably set by the base-class
        # constructor from ``inp_args`` - confirm against GradualPruner.
        if self.args.woodburry_joint_sparsify:
            self._param_stats = []
        # ``_compute_sample_fisher`` appends to ``_all_grads`` whenever
        # ``dump_grads_mat`` is set, so the list has to exist for that flag as
        # well; otherwise the first append raises AttributeError.
        if self.args.dump_fisher_inv_mat or getattr(
            self.args, "dump_grads_mat", False
        ):
            self._all_grads = []
        if self.args.fisher_inv_path is None:
            # Default cache path encodes architecture, dataset, sample/batch
            # counts and seed.
            N_samples = self.args.fisher_subsample_size * self.args.fisher_mini_bsz
            N_batches = self.args.fisher_subsample_size
            seed = self.args.seed
            self.args.fisher_inv_path = os.path.join(
                "./prob_regressor_data",
                f"{self.args.arch}_{self.args.dset}_{N_samples}samples_{N_batches}batches_{seed}seed.fisher_inv",
            )

    def _get_weights(self):
        weights = []
        masks = []

        for idx, module in enumerate(self._modules):
            assert self._weight_only
            weights.append(module.weight.data.flatten())
            masks.append(module.weight_mask.data.flatten())
        weights = torch.cat(weights).to(module.weight.device)
        masks = torch.cat(masks).to(module.weight_mask.device)
        return weights, masks

    def _compute_sample_fisher(self, loss, return_outer_product=False):
        """Differentiate ``loss`` w.r.t. the tracked parameters and flatten the result.

        Args:
            loss: Value passed to ``torch.autograd.grad`` as the output to
                differentiate.
            return_outer_product: When true, the second returned item is
                ``torch.ger(grads, grads)``; otherwise it is ``None``.

        Returns:
            tuple: ``(grads, outer, gTw, params)`` where ``grads`` and
            ``params`` are the outputs of ``flatten_tensor_list``, ``outer`` is
            described above and ``gTw`` is always ``None``.

        Side effects:
            Multiplies the gradients and the parameter data in place by each
            module's ``weight_mask``, sets ``self._num_params`` and
            ``self._old_weights``, and appends ``grads`` to ``self._all_grads``
            when ``self.args.dump_grads_mat`` is set.
        """
        # Every named parameter is tracked, except biases in weight-only mode.
        tracked = [
            param
            for module in self._modules
            for name, param in module.named_parameters()
            if not (self._weight_only and "bias" in name)
        ]

        raw_grads = torch.autograd.grad(loss, tracked)  # first order gradient
        if self.args.add_noise:
            noise_values = self.args.add_noise
            raw_grads = add_noise_to_grads(
                raw_grads, noise_std_scale=noise_values[0], prop=noise_values[1]
            )

        # Gradient masking: zero the entries of previously pruned parameters.
        # Indexing by module position assumes one tracked parameter per module,
        # in module order - TODO confirm for the non weight-only case.
        for position, module in enumerate(self._modules):
            keep = module.weight_mask
            raw_grads[position].data.mul_(keep)
            tracked[position].data.mul_(keep)

        flat_grads = flatten_tensor_list(raw_grads)
        flat_params = flatten_tensor_list(tracked)

        if self.args.dump_grads_mat:
            self._all_grads.append(flat_grads)

        self._num_params = len(flat_grads)
        self._old_weights = flat_params

        outer = torch.ger(flat_grads, flat_grads) if return_outer_product else None
        return flat_grads, outer, None, flat_params

    def _get_pruned_wts_scaled_basis(self, pruned_params, flattened_params):
        # import pdb;pdb.set_trace()
        return -1 * torch.div(
            torch.mul(pruned_params, flattened_params), self._fisher_inv_diag
        )

    def _release_grads(self):
        optim.SGD(self._model.parameters(), lr=1e-10).zero_grad()

    def _compute_wgH(self, dset, subset_inds, device, num_workers, debug=False):
        """Collect one flattened loss gradient per mini-batch with all masks lifted.

        Every module's ``weight_mask`` is replaced by a tensor of ones while the
        gradients are gathered and is put back afterwards. The restore runs in
        a ``finally`` block so that a failure while loading data or
        differentiating cannot leave the model with all-ones masks.

        Args:
            dset: Dataset handed to ``torch.utils.data.DataLoader``.
            subset_inds: Sample indices drawn through ``SubsetRandomSampler``;
                its length must equal
                ``args.fisher_subsample_size * self._fisher_mini_bsz``.
            device: Device the model and each batch are moved to.
            num_workers: Worker count for the data loader.
            debug: Unused.

        Returns:
            tuple: ``(grads, GTWs, w, None)``. ``grads`` stacks the per-batch
            gradient rows along dim 0, ``GTWs`` is always an empty list and
            ``w`` is a numpy copy of the flattened parameters returned for the
            last processed batch.

        Raises:
            AssertionError: If ``len(subset_inds)`` does not match the expected
                number of samples.
        """
        backup_masks = []
        try:
            # Lift every mask so gradients are taken over the dense weights.
            for module in self._modules:
                backup_masks.append(module.weight_mask)
                module.weight_mask = torch.ones_like(module.weight_mask)

            st_time = time.perf_counter()

            self._model = self._model.to(device)

            print("in copula pruning: len of subset_inds is ", len(subset_inds))

            goal = self.args.fisher_subsample_size

            # Use the clamped mini-batch size: it is the one the loader and the
            # stopping condition below actually use.
            assert len(subset_inds) == goal * self._fisher_mini_bsz

            dummy_loader = torch.utils.data.DataLoader(
                dset,
                batch_size=self._fisher_mini_bsz,
                num_workers=num_workers,
                sampler=SubsetRandomSampler(subset_inds),
            )
            # Per-batch gradient rows; GTWs stays empty and is only returned
            # to keep the return tuple's shape.
            Gs = []
            GTWs = []

            if self.args.disable_log_soft:
                # set to true for resnet20 case
                # set to false for mlpnet as it then returns the log softmax and we go to NLL
                criterion = torch.nn.functional.cross_entropy
            else:
                criterion = F.nll_loss

            self._fisher_inv = None

            num_batches = 0
            num_samples = 0

            for in_tensor, target in dummy_loader:
                self._release_grads()

                in_tensor, target = in_tensor.to(device), target.to(device)
                output = self._model(in_tensor)
                loss = criterion(output, target, reduction="mean")

                g, _, gTw, w = self._compute_sample_fisher(
                    loss, return_outer_product=False
                )
                # Keep gradients on the CPU so the accelerator does not have to
                # hold one full gradient per batch.
                Gs.append(torch.Tensor(g[None, :].detach().cpu().numpy()))
                w = w.detach().cpu().numpy()
                del g, gTw

                num_batches += 1
                num_samples += self._fisher_mini_bsz
                if num_samples == goal * self._fisher_mini_bsz:
                    break

            grads = torch.cat(tuple(Gs), 0)
            print(
                "# of examples done {} and the goal (#outer products) is {}".format(
                    num_samples, goal
                )
            )
            print("# of batches done {}".format(num_batches))

            end_time = time.perf_counter()
            print(
                "Time taken to compute empirical fisher is {} seconds".format(
                    str(end_time - st_time)
                )
            )
        finally:
            # Restore whichever masks were swapped out, even on failure.
            for module, saved_mask in zip(self._modules, backup_masks):
                module.weight_mask = saved_mask

        return grads, GTWs, w, None

    def __top_k_indice(self, vector, k):
        """Function to get the indices of the k largest values of a vector"""
        _, indices = torch.topk(vector.abs(), k)
        return set(indices.tolist())

    def __same_support(self, a, b, k):
        """Check if two vectors have the same support"""
        return self.__top_k_indice(a, k) == self.__top_k_indice(b, k)

    def __find_tau_c(self, a, b, k, eps=1e-12):
        """Find the largest tau such that the support of a-tau*b is the same as a"""
        low, high = 0, min(torch.max(a / (b + eps)), 100)

        count = 0
        while (high - low > eps) and (count < 50):
            count += 1
            # print(f'low={low}, high={high}')
            mid = (low + high) / 2
            if self.__same_support(a, a - mid * b, k):
                low = mid
            else:
                high = mid

        return torch.tensor(low)


    class CustomLoss(Function):
        """Autograd function for the copula pruning loss as a function of ``tau``.

        ``forward`` evaluates the loss at ``w = a - tau * b`` using numpy-based
        copula helpers, so autograd cannot trace it; ``backward`` supplies a
        hand-written gradient for ``tau`` only.
        """

        @staticmethod
        def forward(ctx, a, b, X, w_bar, lam, ome, CopulaPruner, tau):
            """Evaluate the loss at ``w = a - tau * b``.

            Args:
                ctx: Autograd context.
                a: Starting point of the line search.
                b: Search direction.
                X: 2-D tensor; ``X.shape[0]`` is used as ``n`` and each row is
                    paired with ``w`` by the copula helpers.
                w_bar: Reference vector that ``w`` is compared against.
                lam: Weight of the squared ``||w - w_bar||`` term.
                ome: Weight of the entropy-difference term.
                CopulaPruner: The pruner *instance* providing the helper
                    methods (the name shadows the enclosing class here; the
                    visible caller passes ``self``).
                tau: Step size; the only input that receives a gradient.

            Returns:
                The scalar loss tensor.

            Raises:
                AssertionError: If the loss contains NaN.
            """
            lam_torch = torch.tensor(lam, device=X.device)
            ome_torch = torch.tensor(ome, device=X.device)

            w = (a - tau * b).to(X.device)
            
            # Both vectors are clipped into [0, 1] before being fed to the
            # copula helpers.
            w = torch.clamp(w, min=0, max=1)
            w_bar = torch.clamp(w_bar, min=0, max=1)

            # NOTE(review): the partial derivatives are evaluated at w_bar, not
            # at w, yet backward combines them with the densities at w -
            # confirm this is intended.
            joint_density_w = CopulaPruner.frankc_n_np(w, X)
            joint_density_w_bar_r, partial_derivatives = CopulaPruner.frankc_n(w_bar, X)

            ctx.save_for_backward(w, X, b, w_bar, lam_torch, ome_torch,
                                torch.from_numpy(joint_density_w_bar_r).to(lam_torch.device),
                                torch.from_numpy(joint_density_w).to(lam_torch.device),
                                torch.from_numpy(partial_derivatives).to(lam_torch.device))
            ctx.CopulaPruner = CopulaPruner

            n = X.shape[0]

            # Here "w1" is the entropy at w_bar and "w2" the entropy at w.
            entropy_w1_l1 = CopulaPruner.compute_copula_entropy(joint_density_w_bar_r)
            entropy_w2_l1 = CopulaPruner.compute_copula_entropy(joint_density_w)

            entropy_w1_l1 = torch.from_numpy(entropy_w1_l1).to(lam_torch.device)
            entropy_w2_l1 = torch.from_numpy(entropy_w2_l1).to(lam_torch.device)

            entropy_w1_l1 = CopulaPruner.replace_nans_with_zero(entropy_w1_l1).to(lam_torch.device)
            entropy_w2_l1 = CopulaPruner.replace_nans_with_zero(entropy_w2_l1).to(lam_torch.device)

            # Mean over rows of the L2 norm of the entropy difference.
            diff = torch.sum(torch.linalg.norm(entropy_w1_l1 - entropy_w2_l1, dim=1), dim=0) / n

            # loss = 0.5 * (summed squared row distances between the densities) / 100
            #        + (n / 2) * lam * ||w - w_bar||^2
            #        + (n / 2) * ome * diff
            loss = (1 / 2) * torch.sum(torch.from_numpy(np.array([CopulaPruner.euclidean_distance_2d(joint_density_w, joint_density_w_bar_r)])).to(lam_torch.device)) / 100 + \
                (n / 2) * lam_torch * torch.linalg.norm(w - w_bar) ** 2 + \
                (n / 2) * (ome_torch * diff)

            assert not torch.isnan(loss).any(), "loss contains NaN"

            print(f"\t loss: {loss}")
            return loss

        @staticmethod
        def backward(ctx, grad_output):
            """Return the hand-written gradient of the loss with respect to ``tau``.

            Args:
                ctx: Autograd context populated by ``forward``.
                grad_output: Upstream gradient of the loss.

            Returns:
                An 8-tuple aligned with ``forward``'s inputs; only the last
                entry (for ``tau``) is not ``None``.
            """
            w, X, b, w_bar, lam_torch, ome_torch, joint_density_w_bar_r, joint_density_w, partial_derivatives = ctx.saved_tensors
            CopulaPruner = ctx.CopulaPruner

            n = X.shape[0]
            n_torch = torch.tensor(n, device=w.device)
            joint_density_w_bar_r = joint_density_w_bar_r.cpu().numpy()
            joint_density_w = joint_density_w.cpu().numpy()
            partial_derivatives = partial_derivatives.cpu().numpy()

            partial_deriv_w1 = CopulaPruner.partial_derivative_entropy_w1(joint_density_w, partial_derivatives)
 

            # Unlike in forward, "w1" is the entropy at w and "w2" the entropy
            # at w_bar here.
            entropy_w1_l1 = CopulaPruner.compute_copula_entropy(joint_density_w)
            entropy_w2_l1 = CopulaPruner.compute_copula_entropy(joint_density_w_bar_r)

            entropy_w2_l1_sum = np.sum(entropy_w2_l1, axis=1)
            # Absolute value of every row sum.
            abs_entropy_w2_l1_sum = np.abs(entropy_w2_l1_sum)

            # Total of the absolute row sums.
            total_abs_sum = np.sum(abs_entropy_w2_l1_sum)

            # Share of each row in that total (no guard against a zero total).
            proportions = abs_entropy_w2_l1_sum / total_abs_sum


            # Gradient of the density-distance term, weighted per row by
            # `proportions`.
            gong1_g = CopulaPruner.euclidean_distance_and_gradient_torch(joint_density_w, joint_density_w_bar_r, partial_derivatives)

            gong1_g = (gong1_g.T @ proportions)

            # Gradient of the entropy term, summed over rows and scaled by 1/100.
            gong3 = entropy_w1_l1 - entropy_w2_l1
            gong3_g = -np.sum(gong3 * partial_deriv_w1, axis=0) / 100

            gong1_g = torch.from_numpy(gong1_g).to(w.device)
            gong3_g = torch.from_numpy(gong3_g).to(w.device)





            grad_w = gong1_g + n_torch * lam_torch * (w - w_bar) + n_torch * ome_torch * gong3_g
            b = b.to(w.device)
            # Clipping bounds.
            min_value = -0.01
            max_value = 0.01

            # Clip every element of grad_w and b into [min_value, max_value],
            # i.e. [-0.01, 0.01].
            grad_w = torch.clamp(grad_w, min=min_value, max=max_value).to(w.device)
            b = torch.clamp(b, min=min_value, max=max_value).to(w.device)

            # Chain rule for w = a - tau * b: d(loss)/d(tau) = -(b . grad_w).
            k=-(b @ grad_w)
            print(f"\t k: {k}")

            # Fallback value used when the dot product is NaN.
            min_value = -5e-07

            if torch.isnan(k):
                k = torch.tensor([min_value]).to(w.device)

            grad_tau = (k) * grad_output
            grad_tau = grad_tau.view(-1)
            


            print(f"\t grad_tau: {grad_tau}")


            print(f"\t b: {b}")
            print(f"\t grad_output: {grad_output}")

            print(f"\t grad_w: {grad_w}")
            print(f"\t gong1_g: {gong1_g}")
            print(f"\t gong2_g: {(n_torch * lam_torch * (w - w_bar))}")
            print(f"\t gong3_g: {(n_torch * ome_torch * gong3_g)}")

            # One slot per forward input; only tau is differentiated.
            return None, None, None, None, None, None, None, grad_tau




    def __find_minimizing_tau(
            self, a, b, tau_c, k, X, y, w_bar, lam, ome, reg, transport,entropy, PI, lr=0.005, eps=9e-10
    ):
        """Run Adam on ``tau`` to minimise ``CustomLoss`` along ``a - tau * b``.

        Args:
            a: Starting vector; hard-thresholded to its ``k`` largest-magnitude
                entries before the search.
            b: Search direction; entries outside the support of the
                thresholded ``a`` are zeroed in place.
            tau_c: Upper bound for ``tau`` (a tensor; ``.to("cpu")`` is called
                on it).
            k: Number of entries kept by the hard threshold.
            X: Tensor passed to ``CustomLoss``; its device is used for ``tau``.
            y, reg, transport, entropy, PI: Not used by the active code (only
                by the commented-out call below).
            w_bar, lam, ome: Passed through to ``CustomLoss``.
            lr: Adam learning rate.
            eps: Gradient-magnitude stopping tolerance and margin below
                ``tau_c`` for the initial clamp.

        Returns:
            ``tau.data`` clamped into ``[0, tau_c]``.

        Raises:
            Exception: If ``tau.grad`` is ``None`` after a backward pass.
        """
        # Start at 0.5, then pull the start into [0, tau_c - eps].
        tau = torch.tensor([0.5], requires_grad=True, device=X.device)
        tau.data = tau.data.clamp(min=0, max=tau_c - eps)

        optimizer = optim.Adam([tau], lr=lr)

        a = self._hard_threshold(a, k)

        a = a.to("cpu")
        b = b.to("cpu")

        # Restrict the direction to the support of the thresholded `a`.
        # NOTE(review): when `b` is already on the CPU, `.to("cpu")` returns the
        # same tensor, so this in-place multiply also modifies the caller's
        # tensor - confirm callers expect that.
        mask = (a != 0).float()
        b *= mask.to("cpu")
        
        for i in range(150):
            print(f"\t 开始循环 = {i}")
            optimizer.zero_grad()
            
            # loss=self.CustomLoss.apply((a - tau * b).to(X.device), b , X, w_bar, lam, ome, self, tau)
            loss=self.CustomLoss.apply(a.to(X.device), b.to(X.device) , X, w_bar, lam, ome, self, tau)


            # loss, _, _ = self._pruning_objective(
            #     X=X,
            #     y=y,
            #     w=(a - tau * b).to(X.device),  # on the CPU
            #     w_bar=w_bar,
            #     lam=lam,
            #     ome=ome,
            #     reg=reg,
            #     entropy=entropy,
            #     transport=transport,
            #     PI=PI,
            # )
            loss.backward()
            optimizer.step()

            print(f"\t optimized tau = {tau.data}")

            if tau.grad is None:
                raise Exception("Sorry, grads seem to be None")
            # Stop when the gradient is close to 0 or tau has left (0, tau_c).
            if np.abs(tau.grad.to("cpu")) < eps or tau >= tau_c.to("cpu") or tau <= 0:
                break

        tau.data = tau.data.clamp(min=0, max=tau_c)
        return tau.data

    def _hard_threshold(self, x, k):
        # Set all but the largest k elements in x to zero
        if k < 1:
            return torch.zeros(len(x))
        weights = x.clone()
        threshold = (
            weights.abs().flatten().kthvalue(int(weights.numel() - k + 1), dim=-1)[0]
        )
        weights[weights.abs() < threshold] = 0
        return weights

    def _set_weights(self, weight_updates, module_param_indices_list, set_mask):
        for idx, module in enumerate(self._modules):
            weight = weight_updates[
                module_param_indices_list[idx] : module_param_indices_list[idx + 1]
            ]
            weight = weight.view(module.weight.shape)

            mask = (weight != 0).float()
            if set_mask:
                module.weight_mask = mask

            with torch.no_grad():
                module.weight.data = module.weight.data * mask

    def _pruning_objective(self, X, y, w, w_bar, lam, ome, reg,entropy, transport, PI=None):
        """Evaluate the pruning objective ``Q`` for ``w`` against ``w_bar``.

        Args:
            X: 2-D tensor; ``X.shape[0]`` is used as ``n`` and its rows are
                paired with ``w`` / ``w_bar`` by ``frankc_n_np``.
            y: Not used in this method.
            w: Candidate vector (clipped into [0, 1] here).
            w_bar: Reference vector (clipped into [0, 1] here).
            lam: Weight of the squared ``||w - w_bar||`` term.
            ome: Weight of the entropy-difference term.
            reg: Only forwarded to ``_get_transportation_plan`` in a branch
                that is currently unreachable.
            entropy: Selects whether the entropy term is part of ``Q``.
            transport: Selects the transport branch (see the notes below).
            PI: Returned unchanged when ``transport`` is false.

        Returns:
            tuple: ``(Q, PI, M)``; ``M`` is always an empty list.
        """
        n = X.shape[0]
        # w = w.to(X.device)

        lam_torch = torch.tensor(lam, device=w.device)
        ome_torch = torch.tensor(ome, device=w.device)

        w = torch.clamp(w, min=0, max=1)
        w_bar = torch.clamp(w_bar, min=0, max=1)
        joint_density_w=self.frankc_n_np(w,X)
        joint_density_w_bar_r =self.frankc_n_np(w_bar,X)

        # "w1" is the entropy at w_bar, "w2" the entropy at w.
        entropy_w1_l1 = self.compute_copula_entropy(joint_density_w_bar_r)
        entropy_w2_l1 = self.compute_copula_entropy(joint_density_w)

        # NOTE(review): the second print reuses the "entropy_w1_l1" label but
        # shows the shape of entropy_w2_l1.
        print("\t entropy_w1_l1", entropy_w1_l1.shape)
        print("\t entropy_w1_l1", entropy_w2_l1.shape)

        entropy_w1_l1=torch.from_numpy(entropy_w1_l1)
        entropy_w2_l1=torch.from_numpy(entropy_w2_l1)
        
        entropy_w1_l1=self.replace_nans_with_zero(entropy_w1_l1).to(lam_torch.device)
        entropy_w2_l1=self.replace_nans_with_zero(entropy_w2_l1).to(lam_torch.device)

        # Mean over rows of the L2 norm of the entropy difference.
        diff = torch.sum(torch.linalg.norm(entropy_w1_l1 - entropy_w2_l1, dim=1),dim=0)/n
        if not transport:

            if not entropy:
                # Q = summed squared row distances + (n / 2) * lam * ||w - w_bar||^2
                Q = torch.sum(torch.from_numpy(np.array([self.euclidean_distance_2d(joint_density_w, joint_density_w_bar_r)])).to(lam_torch.device)) + (
                        n / 2
                ) * lam_torch * torch.linalg.norm(w - w_bar) ** 2
            else:

                # The three prints show the individual terms of Q computed below.
                print("\t q第一项", (
                        1 / 2
                )*torch.sum(torch.from_numpy(np.array([self.euclidean_distance_2d(joint_density_w, joint_density_w_bar_r)])).to(lam_torch.device))/100)
 
                print("\t q第三项", (
                        n / 2
                )*(ome_torch *diff )
                      )

                print("\t q第二项", (
                        n / 2
                ) * lam_torch * torch.linalg.norm(w - w_bar) ** 2

                     )


                # Q =torch.sum(torch.from_numpy(np.array([self.euclidean_distance_2d(joint_density_w, joint_density_w_bar_r)])).to(lam_torch.device))/100+(
                #         n / 2
                # )*(ome_torch *diff )



                # Q = 0.5 * (summed squared row distances) / 100
                #     + (n / 2) * lam * ||w - w_bar||^2
                #     + (n / 2) * ome * diff
                Q =(
                        1 / 2
                )*torch.sum(torch.from_numpy(np.array([self.euclidean_distance_2d(joint_density_w, joint_density_w_bar_r)])).to(lam_torch.device))/100 + (
                        n / 2
                ) * lam_torch * torch.linalg.norm(w - w_bar) ** 2 +(
                        n / 2
                )*(ome_torch *diff )


        else:
            # NOTE(review): this branch looks unfinished. Q stays an empty
            # list, so the `Q.shape` print below raises AttributeError, and PI
            # is overwritten with eye(n) / n right before the `is None` test,
            # which therefore never calls _get_transportation_plan.
            Q=[]
            PI, M = torch.eye(n) * 1 / n, None
            if PI is None:
                PI, M = self._get_transportation_plan(
                    joint_density_w=joint_density_w, joint_density_w_bar_r=joint_density_w_bar_r, reg=reg, transport=transport
                )

        # M = (
        #         torch.cdist(
        #             (joint_density_w_bar_r).reshape(n, 1),
        #             (joint_density_w).reshape(n, 1),
        #             p=2,
        #         )
        #         ** 2
        #     )
            # ot_dist = torch.sum(PI * M.to(PI.device))
            # # print('\tScaled OT distance', ot_dist*n)
            # # print('\tSq Euclidean distance', torch.linalg.norm(X @ w_bar - X @ w)**2)
            # Q = (n / 2) * ot_dist + (n / 2) * lam_torch * torch.linalg.norm(
            #     joint_density_w - joint_density_w_bar_r
            # ) ** 2+ ome * (entropy_w1_l1 - entropy_w2_l1)
        # M is reset unconditionally, so callers always receive an empty list.
        M=[]

        print("\t Q.shap", Q.shape)
        return Q, PI, M

    def _get_transportation_cost(self, joint_density_w, joint_density_w_bar_r):
        """Return the pairwise squared-Euclidean cost matrix between two tensors.

        Both inputs are detached, moved to the CPU, converted to numpy and
        reshaped to ``(n, 1)`` with ``n = joint_density_w.shape[0]`` before
        being passed to ``ot.dist`` with ``metric="sqeuclidean"``.
        """
        n = joint_density_w.shape[0]

        source_points = joint_density_w.detach().cpu().numpy().reshape(n, 1)
        target_points = joint_density_w_bar_r.detach().cpu().numpy().reshape(n, 1)

        return ot.dist(source_points, target_points, metric="sqeuclidean")

    def _get_transportation_plan(self, joint_density_w,joint_density_w_bar_r, reg, transport):

        n = len(joint_density_w)
        if not transport:
            return torch.eye(n) * 1 / n, None

        M = self._get_transportation_cost(joint_density_w,joint_density_w_bar_r)

        original_distr_mass = [1 / n for i in range(n)]
        embedded_distr_mass = [1 / n for i in range(n)]#往后的权重逐步扩大

        if reg == "inf":
            PI = np.ones((n, n)) / n**2
        elif reg <= 1e-10:
            PI = ot.emd(original_distr_mass, embedded_distr_mass, M)
        else:
            PI = ot.bregman.sinkhorn(
                original_distr_mass, embedded_distr_mass, M, reg=reg, numItermax=5000
            )
            # PI = ot.bregman.sinkhorn_epsilon_scaling(
            #     original_distr_mass, embedded_distr_mass, M, reg=reg, numItermax=5000
            # )

        return torch.from_numpy(PI).float().to(joint_density_w.device), torch.from_numpy(
            M
        ).float().to(joint_density_w.device)


    def _compute_local_proportions_1d_numpy(self,tensor, r):

        tensor=tensor.cpu().numpy()
        length = len(tensor)
        result = np.zeros_like(tensor, dtype=np.float32)

        for i in range(length):
            if tensor[i] == 0:
                continue

            start = max(0, i - r)
            end = min(length, i + r + 1)
            window = tensor[start:end]

            window_abs = np.abs(window)
            window_sum = np.sum(window_abs)

            if window_sum == 0:
                result[i] = 0
            else:
                result[i] = np.abs(tensor[i]) / window_sum

        return result

    def _compute_local_proportions_2d_numpy(self,tensor, r):
        tensor=tensor.cpu().numpy()
        rows, cols = tensor.shape
        result = np.zeros_like(tensor, dtype=np.float32)

        for row in range(rows):
            for col in range(cols):
                if tensor[row, col] == 0:
                    continue

                start = max(0, col - r)
                end = min(cols, col + r + 1)
                window = tensor[row, start:end]

                window_abs = np.abs(window)
                window_sum = np.sum(window_abs)

                if window_sum == 0:
                    result[row, col] = 0
                else:
                    result[row, col] = np.abs(tensor[row, col]) / window_sum

        return result




    # def frankc(self, w, l):
    #     w=w.detach().cpu().numpy()
    #     l=l.detach().cpu().numpy()
    #
    #
    #     tau, _ = kendalltau(w, l)
    #     frank_copula = Frank()
    #     frank_copula.tau = tau
    #
    #     print(f"\t tau: {tau}")
    #
    #     theta = frank_copula.compute_theta()
    #     frank_copula.theta = theta
    #     ranked_data = np.column_stack((w, l))
    #     joint_density = frank_copula.probability_density(ranked_data)
    #     return joint_density, frank_copula

    def frankc_n(self, w, l):
        """Fit a Frank copula between ``w`` and every row of ``l``.

        For each row, Kendall's tau between the flattened ``w`` and the row
        sets the copula's ``tau``; ``theta`` comes from ``compute_theta()``.
        The copula is then evaluated on the column-stacked ``(w, row)`` pairs.

        Args:
            w: Tensor or array; flattened before use.
            l: 2-D tensor or array.

        Returns:
            tuple: ``(joint_density, partial_derivatives)``, two numpy arrays
            with the shape of ``l`` holding the outputs of
            ``Frank.probability_density`` and ``Frank.partial_derivative``.
        """
        if isinstance(w, torch.Tensor):
            w = w.detach().cpu().numpy()
        if isinstance(l, torch.Tensor):
            l = l.detach().cpu().numpy()

        n_rows, n_cols = l.shape
        densities = np.zeros((n_rows, n_cols))
        derivatives = np.zeros((n_rows, n_cols))

        for row_idx in range(n_rows):
            flat_w = w.flatten()
            row = l[row_idx]

            copula = Frank()
            copula.tau, _ = kendalltau(flat_w, row)
            copula.theta = copula.compute_theta()

            pairs = np.column_stack((flat_w, row))
            densities[row_idx] = copula.probability_density(pairs)
            derivatives[row_idx] = copula.partial_derivative(pairs)

        return densities, derivatives

    def frankc_n_np(self, w, l):
        """Return the Frank-copula density of ``w`` paired with every row of ``l``.

        Tensor inputs first have NaN and infinite entries set to zero in place
        (through ``replace_nans_with_zero``) and are then converted to numpy.
        For each row, Kendall's tau between the flattened ``w`` and the row
        sets the copula's ``tau``; ``theta`` comes from ``compute_theta()``.

        Args:
            w: Tensor or array; flattened before use.
            l: 2-D tensor or array.

        Returns:
            numpy.ndarray: Array with the shape of ``l`` holding the outputs of
            ``Frank.probability_density``.
        """
        w_is_tensor = isinstance(w, torch.Tensor)
        l_is_tensor = isinstance(l, torch.Tensor)
        if w_is_tensor:
            w = self.replace_nans_with_zero(w)
        if l_is_tensor:
            l = self.replace_nans_with_zero(l)
        if w_is_tensor:
            w = w.detach().cpu().numpy()
        if l_is_tensor:
            l = l.detach().cpu().numpy()

        n_rows, n_cols = l.shape
        densities = np.zeros((n_rows, n_cols))

        for row_idx in range(n_rows):
            flat_w = w.flatten()
            row = l[row_idx]

            copula = Frank()
            copula.tau, _ = kendalltau(flat_w, row)
            copula.theta = copula.compute_theta()

            densities[row_idx] = copula.probability_density(
                np.column_stack((flat_w, row))
            )

        return densities





    def euclidean_distance_1d(self, x, y):
        return np.sqrt(np.sum((x - y) ** 2))

    def euclidean_distance_2d(self, x, y):
        return  np.linalg.norm(x - y, axis=1, keepdims=True)**2

    def compute_copula_entropy(self, joint_density, epsilon=1e-20):
        # 加一个很小的常数以避免 log(0) 和乘以 0 的情况
        joint_density_safe = joint_density + epsilon
        # 计算 copula 熵
        copula_entropy = -joint_density_safe * np.log(np.abs(joint_density_safe))
        # 将原始 joint_density 为 0 的地方的熵值设置为 0
        copula_entropy[joint_density == 0] = 0
        return copula_entropy

    def gradient_w1(self, joint_density, joint_density2, partial_derivatives):
        """Return a Mahalanobis-scaled gradient term built from two density arrays.

        The two inputs are stacked vertically, their covariance (columns as
        variables) is inverted, and the result is
        ``((partial_derivatives * diff.T) @ inv_cov) / mahalanobis(...)``.

        NOTE(review): ``scipy.spatial.distance.mahalanobis`` works on 1-D
        vectors, and stacking two 1-D vectors gives only two observations for
        the covariance - confirm the shapes this is meant to be called with.
        No caller is visible in this part of the file.
        """
        # Treat the two inputs as observations (rows) of the same variables.
        data = np.vstack((joint_density, joint_density2))
        cov_matrix = np.cov(data, rowvar=False)
        inv_cov_matrix = np.linalg.inv(cov_matrix)

        diff = joint_density - joint_density2
        # `*`, `@` and `/` share one precedence level and evaluate left to right.
        grad = partial_derivatives * diff.T @ inv_cov_matrix / mahalanobis(joint_density, joint_density2,
                                                                           inv_cov_matrix)
        return grad

    def partial_derivative_entropy_w1(self, density, partial_derivatives):

        entropy = partial_derivatives * (np.log(np.abs(density)) + 1)



        return entropy

    def euclidean_distance_and_gradient_torch(self,x, y, partial_derivatives):
        """
        使用PyTorch计算两个向量之间的欧氏距离及其对x的导数

        参数:
        x (numpy.ndarray): 第一个向量
        y (numpy.ndarray): 第二个向量

        返回:
        distance (float): 欧氏距离
        grad_x (numpy.ndarray): 欧氏距离对x的导数
        """
        grad = np.subtract(x, y)

        grad = partial_derivatives * grad


        return grad

    def replace_nans_with_zero(self,tensor):
        # 将NaN值替换为0
        tensor[torch.isnan(tensor)] = 0
        tensor[torch.isinf(tensor)] = 0
        return tensor




    def _get_weight_update(
        self,
        grads,
        target_weights,
        lam,
        ome,
        r,
        transport,
        entropy,
        reg,
        dset,
        subset_inds,
        device,
        num_workers,
        module_param_indices_list,
        sparsity,
        pruning_stage,
    ):



        n = len(grads)
        params_num = grads.shape[1]

        lam_torch = torch.tensor(lam, device=device)
        ome_torch = torch.tensor(ome, device=device)

        n_torch = torch.tensor(n, device=device)

        w, _ = self._get_weights()
        w_bar = copy(target_weights)

        w=w.to(device)

        print("\t w", w)
        print("\t w_bar", w_bar)



        w_r=self._compute_local_proportions_1d_numpy(w, r)
        w_bar_r=self._compute_local_proportions_1d_numpy(w_bar, r)

        # # 将边缘分布转换为均匀分布
        # w_r = stats.norm.cdf(w_r)
        # w_bar_r = stats.norm.cdf(w_bar_r)


        w_r_t =self.replace_nans_with_zero(torch.from_numpy(w_r)).to(device)
        w_bar_r_t = self.replace_nans_with_zero(torch.from_numpy(w_bar_r)).to(device)

        # w_r=w_r.cpu().numpy()
        # w_bar_r =w_bar_r.cpu().numpy()

        print(f"\t reg={reg}")
        non_zero_params_num = int(params_num * (1 - sparsity))
        grads, _, _, _ = self._compute_wgH(
            dset, subset_inds, device, num_workers, debug=False
        )

        grads=self.replace_nans_with_zero(grads)
        # if not transport:
        #     grads_r = torch.flatten(torch.sum(grads, dim=0, keepdim=True))

        # print(grads_r.shape)


        # grads_r=self._compute_local_proportions_2d_numpy(grads, r)

        # # 将边缘分布转换为均匀分布
        # grads_r = stats.norm.cdf(grads_r.cpu().numpy())

        #
        X = grads
        X_t=grads.to(device)


        print("\t X", X)
        print("\t w_r", w_r)
        print("\t w_bar_r", w_bar_r)

        print("\t X.shape", X.shape)
        print("\t w_r.shape", w_r.shape)
        print("\t w_bar_r.shape", w_bar_r.shape)

        # # 使用 torch.isnan 检查张量中是否存在 NaN
        # nan_mask1 = torch.isnan(w_bar_r)
        # nan_mask2 = torch.isnan(w_r)
        # nan_mask3 = torch.isnan(X)
        #
        #
        # # 检查是否有任何NaN值
        # has_nan1 = torch.any(nan_mask1)
        # has_nan2 = torch.any(nan_mask2)
        # has_nan3 = torch.any(nan_mask3)
        #
        # print(f"\t has_nan1: {has_nan1}")
        # print(f"\t has_nan2: {has_nan2}")
        # print(f"\t has_nan3: {has_nan3}")

        #这里用frank copula构建X和w_bar_r的联合分布
        joint_density_w_bar_r = self.frankc_n_np(w_bar_r, X)




        y = joint_density_w_bar_r




        joint_density_w, partial_derivatives = self.frankc_n(w_r, X)




        # print(f"\t joint_density_w: {joint_density_w}")
        # print(f"\t joint_density_w_bar_r: {joint_density_w_bar_r}")
        #Q是loss，PI是运输方案，M是距离矩阵
        Q, PI, M = self._pruning_objective(
            X=X_t,
            y=y,
            w=w_r_t,
            w_bar=w_bar_r_t,
            lam=lam,
            ome=ome,
            reg=reg,
            transport=transport,
            entropy=entropy,

        )


        print(f"\t Objective function value: {Q}")





        partial_deriv_w1 = self.partial_derivative_entropy_w1(joint_density_w, partial_derivatives)


       



        entropy_w1_l1 = self.compute_copula_entropy(joint_density_w)
        entropy_w2_l1 = self.compute_copula_entropy(joint_density_w_bar_r)

        entropy_w2_l1_sum = np.sum(entropy_w2_l1, axis=1)
        # 计算绝对值
        abs_entropy_w2_l1_sum = np.abs(entropy_w2_l1_sum)

        # 计算绝对值的总和
        total_abs_sum = np.sum(abs_entropy_w2_l1_sum)

        # 计算每个值占总和的比例
        proportions = abs_entropy_w2_l1_sum / total_abs_sum

        gong1_g = self.euclidean_distance_and_gradient_torch(joint_density_w, joint_density_w_bar_r,  partial_derivatives)
        gong1_g = (gong1_g.T @ proportions)


        gong3 = entropy_w1_l1 - entropy_w2_l1
        


        print(f"\t gong3.T.shape: {gong3.T.shape}")
        print(f"\t partial_deriv_w1.shape: {partial_deriv_w1.shape}")
        gong3_g = -np.sum(gong3 * partial_deriv_w1,axis=0)/100



        gong1_g = torch.from_numpy(gong1_g).to(X_t.device)

        gong3_g = torch.from_numpy(gong3_g).to(X_t.device)
        # noise_std=1e-15
        if not transport:
            print("\t Perform no transport update")
            if not entropy:
                delta_Qw = gong1_g+ n_torch * lam_torch * (w_r_t - w_bar_r_t)

            else:
                print(f"\t gong1_g: {gong1_g}")
                print(f"\t gong2_g: {(n_torch * lam_torch * (w_r_t - w_bar_r_t))}")
                print(f"\t gong3_g: {(n_torch * ome_torch*gong3_g)}")
                # delta_Qw = gong1_g + n_torch * ome_torch*gong3_g
                delta_Qw = gong1_g+ n_torch * lam_torch * (w_r_t - w_bar_r_t)+ n_torch * ome_torch*gong3_g
                # delta_Qw += noise_std * torch.randn_like(delta_Qw)
        else:
            print("\t Perform optimal transport update")
            if not entropy:
                delta_Qw = PI @ gong1_g + lam_torch * (
                            w_r_t - w_bar_r_t)
            else:
                delta_Qw = PI @ gong1_g + lam_torch * (
                            w_r_t - w_bar_r_t) + ome_torch*gong3_g

        # if not transport:
        #     print("\t Perform no transport update")
        #     if not entropy:
        #         delta_Qw = self.euclidean_distance_1d(joint_density_w, y) + n_torch * lam_torch * (
        #                     w_r - w_bar_r)
        #
        #     else:
        #         delta_Qw = self.euclidean_distance_1d(joint_density_w,y)+ n_torch * lam_torch * (w_r - w_bar_r)+ome_torch*(self.compute_copula_entropy(joint_density_w_bar_r)-self.compute_copula_entropy(joint_density_w))
        # else:
        #     print("\t Perform optimal transport update")
        #     if not entropy:
        #         delta_Qw = PI @ self.euclidean_distance_1d(joint_density_w, y) + lam_torch * (
        #                     w_r - w_bar_r)
        #     else:
        #         delta_Qw = PI @ self.euclidean_distance_1d(joint_density_w, y) + lam_torch * (
        #                     w_r - w_bar_r) + ome_torch * (
        #                                self.compute_copula_entropy(joint_density_w_bar_r) - self.compute_copula_entropy(
        #                            joint_density_w))

        print(f"\t delta_Qw={delta_Qw}")
        tau_c = self.__find_tau_c(a=w_r_t, b=delta_Qw, k=non_zero_params_num)
        #w前几个绝对值最大的位置不会改变
        print(f"\t tau_c = {tau_c}")

        tau_m = self.__find_minimizing_tau(
            a=w_r_t,
            b=delta_Qw,
            tau_c=tau_c,
            k=non_zero_params_num,
            X=X_t,
            y=y,
            w_bar=w_bar_r_t,
            lam=lam,
            ome=ome,
            reg=reg,
            transport=transport,
            entropy=entropy,
            PI=PI,
        )
        print(f"\t tau_m = {tau_m}")

        if tau_m < tau_c.to("cpu"):
            tau = tau_m
        else:
            print("\t tau_m >= tau_c, and we optimize tau with gamma")
            gamma = 1.05
            tau = tau_c
            Q_best, _, _ = self._pruning_objective(
                X=X_t,
                y=y,
                w=w_r_t - tau * delta_Qw,
                w_bar=w_bar_r_t,

                lam=lam,
                ome=ome,
                reg=reg,
                transport=transport,
                entropy=entropy,
                PI=PI,
            )
            Q_gamma_tau, _, _ = self._pruning_objective(
                X=X_t,
                y=y,
                w=w_r_t - gamma * tau * delta_Qw,
                w_bar=w_bar_r_t,
                lam=lam,
                ome=ome,
                reg=reg,
                transport=transport,
                entropy=entropy,
                PI=PI,
            )

            while Q_best > Q_gamma_tau:
                print("\t 开始死循环2")
                Q_best = Q_gamma_tau
                tau = gamma * tau
                Q_gamma_tau, _, _ = self._pruning_objective(
                    X=X_t,
                    y=y,
                    w=w_r_t - gamma * tau * delta_Qw,
                    w_bar=w_bar_r_t,
                    lam=lam,
                    ome=ome,
                    reg=reg,
                    entropy=entropy,
                    transport=transport,
                    PI=PI,
                )
                # print(f"\t tau={tau} with gamma={gamma}")
        print(f"\t tau = {tau}")

        tau = tau
        w_new_r = torch.clamp(w_r_t - tau * delta_Qw, min=0, max=1)
        print(f"\t tau*delta_Qw = {delta_Qw}")
        print(f"\t w_new = {w_new_r}")

        print(f"\t Non-zero value num of w_{pruning_stage}:", (w != 0).sum())
        print(f"\t w_{pruning_stage} = {w}")
        # self._target_weights = copy(w)


        w_new_r_p = self._hard_threshold(w_new_r, non_zero_params_num)

        print(f"\t Non-zero value num of w_{pruning_stage+1}", (w != 0).sum())
        print(f"\t w_{pruning_stage+1} = {w}")

        print("\t Model weights updated")

        print(f'Tensor A is on device: {w_new_r_p.device}')
        print(f'Tensor B is on device: {w.device}')


        # 使用 torch.where 将 tensor_a 中对应 tensor_b 为0的位置设置为0
        w = torch.where(w_new_r_p == 0, torch.tensor(0.0).to(device), w)




        self._set_weights(
            weight_updates=w,
            module_param_indices_list=module_param_indices_list,
            set_mask=True,
        )

        if self.args.ot and self.args.dump_ot_files:
            od = X @ w_bar
            ed = X @ w

            sorted_od, indices_Xwbar = torch.sort(od)
            sorted_ed, indices_Xw = torch.sort(ed)

            sorted_PI = PI[indices_Xwbar]
            sorted_PI = sorted_PI[:, indices_Xw]

            np.savetxt(
                f"./logs/ot_files/std_1_prop_01/od_reg={self.args.reg}_seed={self.args.seed}.csv",
                sorted_od.detach().cpu().numpy(),
                delimiter=",",
            )
            np.savetxt(
                f"./logs/ot_files/std_1_prop_01/ed_reg={self.args.reg}_seed={self.args.seed}.csv",
                sorted_ed.detach().cpu().numpy(),
                delimiter=",",
            )
            np.savetxt(
                f"./logs/ot_files/std_1_prop_01/PI_reg={self.args.reg}_seed={self.args.seed}.csv",
                sorted_PI.detach().cpu().numpy(),
                delimiter=",",
            )

    def on_epoch_begin(
        self, dset, subset_inds, device, num_workers, epoch_num, **kwargs
    ):
        """Run one pruning stage at the start of an epoch, if the pruner is active.

        The pruner is treated as inactive when ``epoch_num == 0``, when
        ``self._pruner_not_active(epoch_num)`` is truthy, or when
        ``self._end == 1``.  In that case the pair returned by
        ``self._get_weights()`` is cached as ``self._target_weights`` and
        ``self._original_mask`` and nothing else happens.

        Otherwise the method records the flat start offset of every module's
        weight tensor, obtains gradients from ``self._compute_wgH``, derives
        the sparsity for the current stage and delegates the update to
        ``self._get_weight_update``.

        Args:
            dset: Forwarded unchanged to ``_compute_wgH`` and
                ``_get_weight_update``.
            subset_inds: Forwarded unchanged to the same two calls.
            device: Forwarded unchanged to the same two calls.
            num_workers: Forwarded unchanged to the same two calls.
            epoch_num: Index of the epoch that is about to start.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ``(False, {})`` when no pruning was performed, otherwise
            ``(True, meta)`` where ``meta`` is an (empty) dict.
        """
        if epoch_num == 0 or self._pruner_not_active(epoch_num) or self._end == 1:
            print("Pruner is not ACTIVEEEE yaa!")
            # Snapshot the current weights/mask; the next active stage passes
            # self._target_weights on to _get_weight_update.
            self._target_weights, self._original_mask = self._get_weights()
            return False, {}

        CopulaPruner.have_not_started_pruning = False
        # Ensure that the model is not in training mode. This is important,
        # because otherwise the pruning procedure will interfere with and
        # affect the batch-norm statistics.
        assert not self._model.training

        # Re-initialise buffers if they were deleted during gradual pruning.
        if not hasattr(self, "_all_grads"):
            self._all_grads = []
        if not hasattr(self, "_param_stats"):
            self._param_stats = []

        # Start offset of each module's weights in the flattened parameter
        # vector, followed by one trailing entry holding the total count.
        self._param_idx = 0
        module_param_indices_list = []
        for module in self._modules:
            module_param_indices_list.append(self._param_idx)
            assert self._weight_only
            self._param_idx += module.weight.numel()
        module_param_indices_list.append(self._param_idx)

        # Only the first element of the returned 4-tuple is used here.
        grads, _, _, _ = self._compute_wgH(
            dset, subset_inds, device, num_workers, debug=False
        )

        pruning_stage = (epoch_num - self._start) // self._freq
        total_stages = (self._end - self._start) // self._freq
        print(f"PRUNING STAGE {pruning_stage}")

        if total_stages > 1:
            # Cubic schedule: equals self._initial_sparsity at stage 1 and
            # self._target_sparsity at stage `total_stages`.
            # NOTE(review): for pruning_stage == 0 the cubed factor exceeds 1,
            # so the value overshoots past self._initial_sparsity -- confirm
            # that the first active stage is expected to be 1.
            remaining = 1 - (pruning_stage - 1) / (total_stages - 1)
            sparsity = (
                self._target_sparsity
                + (self._initial_sparsity - self._target_sparsity) * remaining ** 3
            )
        else:
            # With at most one stage there is nothing to interpolate.
            sparsity = self._target_sparsity

        print(f"Sparsity={sparsity}")
        self._get_weight_update(
            grads=grads,
            target_weights=self._target_weights,
            lam=0.0000000000001,
            ome=0.11,
            r=2500,
            transport=self.args.ot,
            entropy=self.args.et,
            reg=self.args.reg,
            dset=dset,
            subset_inds=subset_inds,
            device=device,
            num_workers=num_workers,
            module_param_indices_list=module_param_indices_list,
            sparsity=sparsity,
            pruning_stage=pruning_stage,
        )

        meta = {}
        return True, meta
