import gc
import os
import math
from concurrent.futures.thread import ThreadPoolExecutor

from torch.autograd import Function
import numpy
import numpy as np
import pickle
import matplotlib.pyplot as plt
import ot

import torch
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
import time
import logging
import numpy as np
from scipy.stats import kendalltau
from copulas.bivariate import Frank
from scipy.spatial.distance import mahalanobis
import scipy.stats as stats

from utils.utils import (
    flatten_tensor_list,
    get_summary_stats,
    dump_tensor_to_mat,
    add_noise_to_grads,
)
from policies.pruners import GradualPruner
from utils.plot import analyze_new_weights
from common.io import _load, _dump
from common.timer import Timer
from copy import copy



class CopulaPruner(GradualPruner):
    have_not_started_pruning = True

    def __init__(self, model, inp_args, **kwargs):
        """Set up the pruner and resolve the default Fisher-inverse cache path.

        Args:
            model: Model handed to the ``GradualPruner`` base-class constructor.
            inp_args: Run arguments; ``prune_direction``, ``zero_after_prune``,
                ``inspect_inv`` and ``fisher_mini_bsz`` are read from it here.
            **kwargs: Forwarded unchanged to the base-class constructor.
        """
        super(CopulaPruner, self).__init__(model, inp_args, **kwargs)
        print("In Optimal Transport Pruner")
        self._fisher_inv_diag = None
        self._prune_direction = inp_args.prune_direction
        self._zero_after_prune = inp_args.zero_after_prune
        self._inspect_inv = inp_args.inspect_inv

        # A negative mini-batch size falls back to one sample per batch.
        self._fisher_mini_bsz = inp_args.fisher_mini_bsz
        if self._fisher_mini_bsz < 0:
            self._fisher_mini_bsz = 1

        if self.args.woodburry_joint_sparsify:
            self._param_stats = []

        # _compute_sample_fisher appends to _all_grads whenever
        # args.dump_grads_mat is set, so the list has to exist for that flag as
        # well as for dump_fisher_inv_mat; otherwise the append raises
        # AttributeError.
        if self.args.dump_fisher_inv_mat or getattr(self.args, "dump_grads_mat", False):
            self._all_grads = []

        if self.args.fisher_inv_path is None:
            # NOTE(review): the file name uses the raw args.fisher_mini_bsz,
            # not the clamped self._fisher_mini_bsz — confirm this is intended.
            N_samples = self.args.fisher_subsample_size * self.args.fisher_mini_bsz
            N_batches = self.args.fisher_subsample_size
            seed = self.args.seed
            self.args.fisher_inv_path = os.path.join(
                "./prob_regressor_data",
                f"{self.args.arch}_{self.args.dset}_{N_samples}samples_{N_batches}batches_{seed}seed.fisher_inv",
            )

    def _get_weights(self):
        weights = []
        masks = []

        for idx, module in enumerate(self._modules):
            assert self._weight_only
            weights.append(module.weight.data.flatten())
            masks.append(module.weight_mask.data.flatten())
        weights = torch.cat(weights).to(module.weight.device)
        masks = torch.cat(masks).to(module.weight_mask.device)
        return weights, masks

    def _compute_sample_fisher(self, loss, return_outer_product=False):
        """Compute masked, flattened gradients of ``loss`` w.r.t. the pruned parameters.

        Side effects: the selected parameters are multiplied in place by their
        module's ``weight_mask``; ``self._num_params`` and ``self._old_weights``
        are overwritten; when ``args.dump_grads_mat`` is set the flattened
        gradients are appended to ``self._all_grads``.

        Args:
            loss (Tensor): Loss to differentiate.
            return_outer_product (bool, optional): Whether to also return
                ``grads.T @ grads``. Defaults to False.

        Returns:
            tuple: ``(grads, outer_product, gTw, params)`` where
                grads (Tensor): Flattened gradients, float32 on CPU.
                outer_product (Tensor or None): ``torch.matmul(grads.T, grads)``
                    when ``return_outer_product`` is True, otherwise None.
                gTw (None): Always None in this implementation.
                params (Tensor): Flattened (masked) parameters, float32 on CPU.
        """

        # Collect parameters to compute gradients (biases are skipped in
        # weight-only mode)
        params = []
        for module in self._modules:
            for name, param in module.named_parameters():
                if not (self._weight_only and "bias" in name):
                    params.append(param)

        # Compute gradients with respect to the parameters
        grads = torch.autograd.grad(loss, params, create_graph=False, retain_graph=False)

        # Optionally add noise to the gradients
        if self.args.add_noise:
            noise_values = self.args.add_noise
            grads = add_noise_to_grads(grads, noise_std_scale=noise_values[0], prop=noise_values[1])

        # Apply weight mask to gradients and parameters (in place).
        # NOTE(review): zip pairs the i-th gradient with the i-th module, which
        # lines up only if every module contributes exactly one parameter —
        # confirm for the non-weight-only case.
        for idx, (grad, module) in enumerate(zip(grads, self._modules)):
            grad.data.mul_(module.weight_mask)
            params[idx].data.mul_(module.weight_mask)

        # Reshape each per-parameter gradient to 2-D (first dim, everything
        # else) before flattening the list into one CPU float32 tensor
        grads = [grad.view(grad.size(0), -1) for grad in grads]
        grads = flatten_tensor_list(grads).to(device='cpu', dtype=torch.float32)
        params = flatten_tensor_list(params).to(device='cpu', dtype=torch.float32)

        # Optionally store gradients for debugging or analysis
        if self.args.dump_grads_mat:
            self._all_grads.append(grads)

        self._num_params = grads.numel()
        self._old_weights = params

        gTw = None

        if return_outer_product:
            # Compute and return grads^T @ grads.
            # NOTE(review): this is an outer product only if
            # flatten_tensor_list keeps a 2-D shape — TODO confirm.
            outer_product = torch.matmul(grads.T, grads).to(device='cpu', dtype=torch.float32)
            return grads, outer_product, gTw, params
        else:
            return grads, None, gTw, params

    def _get_pruned_wts_scaled_basis(self, pruned_params, flattened_params):
        # import pdb;pdb.set_trace()
        return -1 * torch.div(
            torch.mul(pruned_params, flattened_params), self._fisher_inv_diag
        )

    def _release_grads(self):
        optim.SGD(self._model.parameters(), lr=1e-10).zero_grad()

    def _compute_wgH(self, dset, subset_inds, device, num_workers, debug=False):
        """Collect flattened loss gradients over a subset of ``dset``.

        While gradients are gathered every module's ``weight_mask`` is replaced
        by an all-ones mask. The original masks are put back afterwards, also
        when an exception is raised part-way through.

        Args:
            dset: Dataset handed to ``torch.utils.data.DataLoader``.
            subset_inds: Indices to sample from; must hold exactly
                ``args.fisher_subsample_size * args.fisher_mini_bsz`` entries.
            device: Device the model and each mini-batch are moved to.
            num_workers: Worker count for the data loader.
            debug: Unused; kept for call compatibility.

        Returns:
            tuple: ``(grads, GTWs, w, None)``. ``grads`` is the concatenation
            (dim 0, on CPU) of the per-batch gradients, ``GTWs`` is always an
            empty list and ``w`` is the parameter tensor returned by
            ``_compute_sample_fisher`` for the last processed batch.

        Raises:
            AssertionError: If ``subset_inds`` has an unexpected length.
        """
        # Back up the original weight masks; they are restored in ``finally``.
        backup_masks = [module.weight_mask for module in self._modules]

        Gs = []
        GTWs = []

        try:
            # Gradients are computed with every weight unmasked.
            for module in self._modules:
                module.weight_mask = torch.ones_like(module.weight_mask)

            st_time = time.perf_counter()

            self._model.to(device)

            print("in copula pruning: len of subset_inds is ", len(subset_inds))

            goal = self.args.fisher_subsample_size
            assert len(subset_inds) == goal * self.args.fisher_mini_bsz

            dummy_loader = torch.utils.data.DataLoader(
                dset,
                batch_size=self._fisher_mini_bsz,
                num_workers=num_workers,
                sampler=SubsetRandomSampler(subset_inds),
                pin_memory=True,
            )

            criterion = torch.nn.functional.cross_entropy if self.args.disable_log_soft else F.nll_loss

            num_batches = 0
            num_samples = 0

            for in_tensor, target in dummy_loader:
                self._release_grads()

                in_tensor, target = in_tensor.to(device, non_blocking=True), target.to(device, non_blocking=True)
                output = self._model(in_tensor)
                loss = criterion(output, target, reduction="mean")

                g, _, gTw, w = self._compute_sample_fisher(loss, return_outer_product=False)

                # Keep only a CPU copy of the gradient, one row per batch.
                g_cpu = g.detach().cpu()
                if g_cpu.dim() == 1:
                    g_cpu = g_cpu.unsqueeze(0)
                Gs.append(g_cpu)

                del g, gTw
                torch.cuda.empty_cache()

                num_batches += 1
                num_samples += self._fisher_mini_bsz
                if num_samples == goal * self._fisher_mini_bsz:
                    break

            # Stack the per-batch rows into one 2-D CPU tensor.
            grads = torch.cat(Gs, 0).to('cpu')

            print(
                "# of examples done {} and the goal (#outer products) is {}".format(
                    num_samples, goal
                )
            )
            print("# of batches done {}".format(num_batches))

            end_time = time.perf_counter()
            print(
                "Time taken to compute empirical fisher is {} seconds".format(
                    str(end_time - st_time)
                )
            )
        finally:
            # Restore the original masks even if gradient collection failed,
            # so the pruner is never left with all-ones masks.
            for module, original_mask in zip(self._modules, backup_masks):
                module.weight_mask = original_mask
            torch.cuda.empty_cache()

        return grads, GTWs, w, None

    def __top_k_indice(self, vector, k):
        """Function to get the indices of the k largest values of a vector"""
        _, indices = torch.topk(vector.abs(), k)
        return set(indices.tolist())

    def __same_support(self, a, b, k):
        """Check if two vectors have the same support"""
        return self.__top_k_indice(a, k) == self.__top_k_indice(b, k)

    def __find_tau_c(self, a, b, k, eps=1e-12):
        """Find the largest tau such that the support of a-tau*b is the same as a"""
        low, high = 0, min(torch.max(a / (b + eps)), 100)

        count = 0
        while (high - low > eps) and (count < 50):
            count += 1
            # print(f'low={low}, high={high}')
            mid = (low + high) / 2
            if self.__same_support(a, a - mid * b, k):
                low = mid
            else:
                high = mid

        return torch.tensor(low)

    class CustomLoss(Function):
        """Autograd function for the pruning loss as a function of the step size ``tau``.

        The forward pass evaluates the loss at ``w = clamp(a - tau * b, 0, 1)``
        using NumPy/copula helpers of the pruner; the backward pass supplies a
        hand-written gradient for ``tau`` only (all other inputs get ``None``).
        """

        @staticmethod
        def forward(ctx, a, b, X, w_bar, lam, ome, CopulaPruner, tau):
            """Evaluate the loss.

            Args:
                ctx: Autograd context.
                a: Base vector.
                b: Direction vector; ``w = a - tau * b``.
                X: 2-D tensor whose rows are paired with ``w`` in the copula helpers.
                w_bar: Reference vector.
                lam: Weight of the ``||w - w_bar||^2`` term.
                ome: Weight of the entropy-difference term.
                CopulaPruner: The pruner *instance* providing the helper
                    methods (the name shadows the class inside this method).
                tau: Step-size tensor being optimised.

            Returns:
                Tensor: ``0.5 * sum(row-wise squared distance between the two
                densities) + 0.5 * lam * ||w - w_bar||^2 + 0.5 * ome * diff``.
            """
            lam_torch = torch.tensor(lam, device=X.device, dtype=torch.float32)
            ome_torch = torch.tensor(ome, device=X.device, dtype=torch.float32)

            w = (a - tau * b).to(X.device, dtype=torch.float32)

            # Both vectors are restricted to [0, 1] before the copula helpers see them.
            w = torch.clamp(w, min=0, max=1)
            w_bar = torch.clamp(w_bar, min=0, max=1)

            # NOTE(review): the partial derivatives are evaluated at w_bar here,
            # but backward() combines them with the density at w — confirm
            # whether frankc_n should be called on w instead.
            joint_density_w = CopulaPruner.frankc_n_np(w, X)
            joint_density_w_bar_r, partial_derivatives = CopulaPruner.frankc_n(w_bar, X)

            # Stash everything backward() needs; the NumPy results are wrapped as tensors.
            ctx.save_for_backward(w, X, b, w_bar, lam_torch, ome_torch,
                                  torch.from_numpy(joint_density_w_bar_r).to(lam_torch.device, dtype=torch.float32),
                                  torch.from_numpy(joint_density_w).to(lam_torch.device, dtype=torch.float32),
                                  torch.from_numpy(partial_derivatives).to(lam_torch.device, dtype=torch.float32))
            ctx.CopulaPruner = CopulaPruner

            n = X.shape[0]

            entropy_w1_l1 = CopulaPruner.compute_copula_entropy(joint_density_w_bar_r)
            entropy_w2_l1 = CopulaPruner.compute_copula_entropy(joint_density_w)

            entropy_w1_l1 = torch.from_numpy(entropy_w1_l1).to(lam_torch.device, dtype=torch.float32)
            entropy_w2_l1 = torch.from_numpy(entropy_w2_l1).to(lam_torch.device, dtype=torch.float32)

            # Mean over rows of the per-row norm of the entropy difference.
            diff = torch.sum(torch.linalg.norm(entropy_w1_l1 - entropy_w2_l1, dim=1), dim=0) / n

            loss = (1 / 2) * torch.sum(torch.from_numpy(
                np.array([CopulaPruner.euclidean_distance_2d(joint_density_w, joint_density_w_bar_r)])).to(
                lam_torch.device, dtype=torch.float32)) + \
                   (1 / 2) * lam_torch * torch.linalg.norm(w - w_bar) ** 2 + \
                   (1 / 2) * (ome_torch * diff)

            assert not torch.isnan(loss).any(), "loss contains NaN"

            print(f"\t loss: {loss}")
            return loss

        @staticmethod
        def backward(ctx, grad_output):
            """Return the gradient w.r.t. ``tau``; every other input gets ``None``.

            The gradient is ``-(b @ grad_w) * grad_output`` where ``grad_w`` is
            assembled from the saved densities and partial derivatives. A NaN
            value of ``-(b @ grad_w)`` is replaced by ``-5e-07``.
            """
            w, X, b, w_bar, lam_torch, ome_torch, joint_density_w_bar_r, joint_density_w, partial_derivatives = ctx.saved_tensors
            CopulaPruner = ctx.CopulaPruner

            n = X.shape[0]
            n_torch = torch.tensor(n, device=w.device, dtype=torch.float32)
            # The helper methods operate on NumPy arrays.
            joint_density_w_bar_r = joint_density_w_bar_r.cpu().numpy()
            joint_density_w = joint_density_w.cpu().numpy()
            partial_derivatives = partial_derivatives.cpu().numpy()

            partial_deriv_w1 = CopulaPruner.partial_derivative_entropy_w1(joint_density_w, partial_derivatives)
            gong1_g = CopulaPruner.euclidean_distance_and_gradient_torch(joint_density_w, joint_density_w_bar_r,
                                                                         partial_derivatives)

            entropy_w1_l1 = CopulaPruner.compute_copula_entropy(joint_density_w)
            entropy_w2_l1 = CopulaPruner.compute_copula_entropy(joint_density_w_bar_r)

            gong3 = np.abs(entropy_w1_l1 - entropy_w2_l1)
            gong3_g = np.sum(gong3 * -partial_deriv_w1, axis=0)

            gong1_g = torch.from_numpy(gong1_g).to(w.device, dtype=torch.float32)
            gong3_g = torch.from_numpy(gong3_g).to(w.device, dtype=torch.float32)

            # NOTE(review): the quadratic term contributes lam * |w - w_bar|
            # here rather than lam * (w - w_bar) — confirm this is intended.
            grad_w = gong1_g + lam_torch * torch.abs(w - w_bar) + ome_torch * gong3_g

            # Cast b to the same dtype/device as grad_w
            b = b.to(w.device, dtype=torch.float32)

            # Chain rule through w = a - tau * b.
            k = -(b @ grad_w)
            print(f"\t k: {k}")

            min_value = -5e-07

            # Fall back to a small negative slope when the gradient is NaN.
            if torch.isnan(k):
                k = torch.tensor([min_value], dtype=torch.float32).to(w.device)

            grad_tau = (k) * grad_output
            grad_tau = grad_tau.view(-1)

            # One entry per forward() input; only tau (the last) is differentiable.
            return None, None, None, None, None, None, None, grad_tau

    def __find_minimizing_tau(
            self, a, b, tau_c, k, X, y, w_bar, lam, ome, reg, transport, entropy, PI, lr=0.2, eps=8e-10
    ):
        """Run up to 20 Adam steps on the step size ``tau`` of ``a - tau * b``.

        ``a`` is hard-thresholded to its ``k`` largest-magnitude entries and
        ``b`` is zeroed wherever the thresholded ``a`` is zero before
        ``CustomLoss`` is minimised with respect to ``tau``.

        Args:
            a: Base vector.
            b: Update direction. It is never modified in place.
            tau_c: Upper bound for ``tau``.
            k: Number of entries of ``a`` kept by the hard threshold.
            X: 2-D tensor passed to the loss; its device is used for ``tau``.
            w_bar: Reference vector passed to the loss.
            lam, ome: Loss weights passed to ``CustomLoss``.
            y, reg, transport, entropy, PI: Unused; kept for call compatibility.
            lr: Adam learning rate.
            eps: Gradient-magnitude stopping tolerance, also the margin below
                ``tau_c`` for the initial clamp.

        Returns:
            Tensor: ``tau`` clamped to ``[0, tau_c]`` (shape ``(1,)``).

        Raises:
            Exception: If ``tau`` received no gradient after ``backward``.
        """
        tau = torch.tensor([0.5], requires_grad=True, device=X.device)
        tau.data = tau.data.clamp(min=0, max=tau_c - eps)

        optimizer = optim.Adam([tau], lr=lr)

        # Perform hard thresholding on 'a' and move to CPU
        a = self._hard_threshold(a, k).to("cpu")

        # Zero b outside the support of the thresholded a. This is done
        # out-of-place: ``b.to("cpu")`` returns the caller's own tensor when it
        # already lives on the CPU, so an in-place ``*=`` would silently modify
        # the caller's data on CPU runs but not on GPU runs.
        b = b.to("cpu")
        b = b * (a != 0).to(dtype=b.dtype)

        # Ensure data is on the same device as X
        a_device = a.to(X.device)
        b_device = b.to(X.device)

        for _ in range(20):
            optimizer.zero_grad()
            loss = self.CustomLoss.apply(a_device, b_device, X, w_bar, lam, ome, self, tau)

            # Perform backward pass and optimize
            loss.backward()
            optimizer.step()

            if tau.grad is None:
                raise Exception("Sorry, grads seem to be None")

            # Stop once the gradient vanishes or tau leaves (0, tau_c)
            if tau.grad.abs().item() < eps or tau.item() >= tau_c or tau.item() <= 0:
                break

            # Free up unused variables to reduce memory usage
            del loss
            torch.cuda.empty_cache()

        # Clamp tau to be within valid range
        tau.data = tau.data.clamp(min=0, max=tau_c)

        print(f"\t optimized tau = {tau.data}")
        return tau.data

    def _hard_threshold(self, x, k):
        # Set all but the largest k elements in x to zero
        if k < 1:
            return torch.zeros(len(x))
        weights = x.clone()
        threshold = (
            weights.abs().flatten().kthvalue(int(weights.numel() - k + 1), dim=-1)[0]
        )
        weights[weights.abs() < threshold] = 0
        return weights

    def _set_weights(self, weight_updates, module_param_indices_list, set_mask):
        for idx, module in enumerate(self._modules):
            weight = weight_updates[
                module_param_indices_list[idx] : module_param_indices_list[idx + 1]
            ]
            weight = weight.view(module.weight.shape)

            mask = (weight != 0).float()
            if set_mask:
                module.weight_mask = mask

            with torch.no_grad():
                module.weight.data = module.weight.data * mask

    def _pruning_objective(self, X, y, w, w_bar, lam, ome, reg, entropy, transport, PI=None):
        """Evaluate the pruning objective ``Q`` for ``w`` against ``w_bar``.

        Both vectors are clamped to ``[0, 1]``, their copula densities against
        the rows of ``X`` are computed with ``frankc_n_np``, and ``Q`` combines
        the summed row-wise squared density distance with a quadratic penalty
        on ``w - w_bar`` (and, when ``entropy`` is set, an entropy-difference
        term).

        Args:
            X: 2-D tensor; ``X.shape[0]`` is used as the row count ``n``.
            w, w_bar: Candidate and reference vectors.
            lam: Weight of the quadratic penalty.
            ome: Weight of the entropy term (used only when ``entropy``).
            entropy: Selects which of the two objective forms is returned.
            y, reg, transport, PI: Unused; kept for call compatibility.

        Returns:
            tuple: ``(Q, None, None)``.
        """
        n_rows = X.shape[0]
        dev = w.device
        lam_t = torch.tensor(lam, device=dev, dtype=torch.float32)
        ome_t = torch.tensor(ome, device=dev, dtype=torch.float32)

        w = torch.clamp(w, min=0, max=1).to(device=dev, dtype=torch.float32)
        w_bar = torch.clamp(w_bar, min=0, max=1).to(device=dev, dtype=torch.float32)

        density_w = self.frankc_n_np(w, X).astype(np.float32)
        density_w_bar = self.frankc_n_np(w_bar, X).astype(np.float32)

        ent_w_bar = torch.from_numpy(self.compute_copula_entropy(density_w_bar)).to(dev, dtype=torch.float32)
        ent_w = torch.from_numpy(self.compute_copula_entropy(density_w)).to(dev, dtype=torch.float32)

        # Mean over rows of the per-row norm of the entropy difference.
        entropy_gap = torch.linalg.norm(ent_w_bar - ent_w, dim=1).sum(dtype=torch.float32) / n_rows

        torch.cuda.empty_cache()

        # The density-fit term is identical in both objective forms.
        row_sq_dists = self.euclidean_distance_2d(density_w, density_w_bar).astype(np.float32)
        density_fit = torch.from_numpy(np.array([row_sq_dists])).to(dev).sum()

        if entropy:
            Q = 0.5 * density_fit + 0.5 * lam_t * torch.linalg.norm(w - w_bar) ** 2 + 0.5 * ome_t * entropy_gap
        else:
            Q = density_fit + (n_rows / 2) * lam_t * torch.linalg.norm(w - w_bar) ** 2

        torch.cuda.empty_cache()

        return Q, None, None

    def _get_transportation_cost(self, joint_density_w,joint_density_w_bar_r):
        """Return the pairwise squared-Euclidean cost matrix between two 1-D distributions.

        Both inputs are tensors with ``n = joint_density_w.shape[0]`` elements;
        each is treated as ``n`` one-dimensional points and handed to
        ``ot.dist`` with ``metric="sqeuclidean"``.
        """
        n_points = joint_density_w.shape[0]

        source = joint_density_w.detach().cpu().numpy()
        target = joint_density_w_bar_r.detach().cpu().numpy()

        # ot.dist expects one point per row.
        source_points = source.reshape(n_points, 1)
        target_points = target.reshape(n_points, 1)

        return ot.dist(source_points, target_points, metric="sqeuclidean")

    def _get_transportation_plan(self, joint_density_w,joint_density_w_bar_r, reg, transport):

        n = len(joint_density_w)
        if not transport:
            return torch.eye(n) * 1 / n, None

        M = self._get_transportation_cost(joint_density_w,joint_density_w_bar_r)

        original_distr_mass = [1 / n for i in range(n)]
        embedded_distr_mass = [1 / n for i in range(n)]#往后的权重逐步扩大

        if reg == "inf":
            PI = np.ones((n, n)) / n**2
        elif reg <= 1e-10:
            PI = ot.emd(original_distr_mass, embedded_distr_mass, M)
        else:
            PI = ot.bregman.sinkhorn(
                original_distr_mass, embedded_distr_mass, M, reg=reg, numItermax=5000
            )
            # PI = ot.bregman.sinkhorn_epsilon_scaling(
            #     original_distr_mass, embedded_distr_mass, M, reg=reg, numItermax=5000
            # )

        return torch.from_numpy(PI).float().to(joint_density_w.device), torch.from_numpy(
            M
        ).float().to(joint_density_w.device)


    def _compute_local_proportions_1d_numpy(self,tensor, r):

        tensor=tensor.cpu().numpy()
        length = len(tensor)
        result = np.zeros_like(tensor, dtype=np.float32)

        for i in range(length):
            if tensor[i] == 0:
                continue

            start = max(0, i - r)
            end = min(length, i + r + 1)
            window = tensor[start:end]

            window_abs = np.abs(window)
            window_sum = np.sum(window_abs)

            if window_sum == 0:
                result[i] = 0
            else:
                result[i] = np.abs(tensor[i]) / window_sum

        return result

    def _compute_local_proportions_1d_tensor(self,tensor, r):
        # 确保输入 tensor 在 CPU 上，因为我们需要对其进行逐元素操作
        tensor = tensor
        length = tensor.size(0)
        result = torch.zeros_like(tensor, dtype=torch.float32)


        for i in range(length):
            if tensor[i] == 0:
                continue

            start = max(0, i - r)
            end = min(length, i + r + 1)
            window = tensor[start:end]


            window_abs = window.abs()
            window_sum = window_abs.sum()

            if window_sum == 0:
                result[i] = 0
            else:
                result[i] = tensor[i].abs() / window_sum

        return result

    def _compute_local_proportions_w(self,tensor):


        abs_t = torch.abs(tensor)


        sum_abs_t = torch.sum(abs_t)


        t_normalized = abs_t / sum_abs_t

        return t_normalized

    def _compute_local_proportions_2d(self,tensor):

        row_sums = tensor.abs().sum(dim=1, keepdim=True)

        # 每个元素除以其所在行的总和，得到占比
        grads_normalized = tensor.abs() / row_sums

        p=row_sums/row_sums.sum(dim=0, keepdim=True)

        return grads_normalized * p

    def _compute_local_proportions_2d_r(self, tensor, r):

        rows, cols = tensor.shape

        result = torch.zeros_like(tensor, dtype=torch.float32)


        row_sums = tensor.abs().sum(dim=1, keepdim=True)

        p = row_sums / row_sums.sum(dim=0, keepdim=True)


        for row in range(rows):
            for col in range(cols):
                if tensor[row, col] == 0:
                    continue


                start = max(0, col - r)
                end = min(cols, col + r + 1)


                window = tensor[row, start:end]


                window_abs = torch.abs(window)
                window_sum = torch.sum(window_abs)


                if window_sum == 0:
                    result[row, col] = 0
                else:
                    result[row, col] = torch.abs(tensor[row, col]) / window_sum


        result = result * p

        return result



    # Frank-copula density helpers; intermediate arrays are kept in float32.
    def frankc_n_np(self, w, l, batch_size=50):
        """Evaluate a Frank-copula density between ``w`` and every row of ``l``.

        For each row ``l[i]`` Kendall's tau between the flattened ``w`` and the
        row is computed (NaN/inf tau is replaced by 0), a ``Frank`` copula is
        parameterised from it, and its ``probability_density`` is evaluated at
        the pairs ``(w[j], l[i, j])``.

        Args:
            w: Tensor or array-like; NaN/inf entries are replaced by 0. Must
                flatten to ``l.shape[1]`` elements.
            l: 2-D tensor or array of shape ``(rows, cols)``.
            batch_size: Number of rows submitted to the thread pool at a time.

        Returns:
            np.ndarray: float32 array of shape ``(rows, cols)``.
        """
        if isinstance(w, torch.Tensor):
            w = self.replace_nans_with_zero(w).detach().cpu().numpy().astype(np.float32)
        else:
            # Array input is sanitised with NumPy: the tensor-based helper
            # relies on torch.is_floating_point, which rejects ndarrays.
            w = np.asarray(w, dtype=np.float32)
            w = np.where(np.isfinite(w), w, np.float32(0))

        if isinstance(l, torch.Tensor):
            l = l.detach().cpu().numpy().astype(np.float32)

        rows, cols = l.shape
        joint_density = np.zeros((rows, cols), dtype=np.float32)

        # Flatten w once; every row is paired with the same vector.
        w_flat = w.flatten()

        def compute_joint_density(i, w_flat=w_flat):
            # Kendall tau between w and row i; degenerate values fall back to 0.
            tau, _ = kendalltau(w_flat, l[i])
            tau = 0 if np.isnan(tau) or np.isinf(tau) else tau

            frank_copula = Frank()
            frank_copula.tau = tau
            frank_copula.theta = frank_copula.compute_theta()

            ranked_data = np.column_stack((w_flat, l[i]))
            return frank_copula.probability_density(ranked_data).astype(np.float32)

        # One pool for all batches; rows are still submitted batch by batch so
        # only ``batch_size`` row results are held at a time.
        with ThreadPoolExecutor() as executor:
            for batch_start in range(0, rows, batch_size):
                batch_end = min(batch_start + batch_size, rows)
                results = list(executor.map(compute_joint_density, range(batch_start, batch_end)))
                joint_density[batch_start:batch_end] = np.array(results, dtype=np.float32)
                del results  # Free memory

        return joint_density

    def frankc_n(self, w, l, batch_size=50):
        """Evaluate Frank-copula densities and partial derivatives between ``w`` and each row of ``l``.

        For each row ``l[i]`` Kendall's tau between the flattened ``w`` and the
        row is computed (NaN/inf tau is replaced by 0), a ``Frank`` copula is
        parameterised from it, and both ``probability_density`` and
        ``partial_derivative`` are evaluated at the pairs ``(w[j], l[i, j])``.
        Unlike ``frankc_n_np``, NaN/inf entries of ``w`` are not replaced here.

        NOTE(review): which variable ``partial_derivative`` differentiates with
        respect to is defined by the copulas library — confirm it matches the
        gradient w.r.t. ``w`` that callers assume.

        Args:
            w: Tensor or array; must flatten to ``l.shape[1]`` elements.
            l: 2-D tensor or array of shape ``(rows, cols)``.
            batch_size: Number of rows submitted to a thread pool at a time.

        Returns:
            tuple: ``(joint_density, partial_derivatives)``, two float32 arrays
            of shape ``(rows, cols)``.
        """
        if isinstance(w, torch.Tensor):
            w = w.detach().cpu().numpy().astype(np.float32)
        if isinstance(l, torch.Tensor):
            l = l.detach().cpu().numpy().astype(np.float32)

        n_rows, n_cols = l.shape
        densities = np.zeros((n_rows, n_cols), dtype=np.float32)
        derivatives = np.zeros((n_rows, n_cols), dtype=np.float32)

        flat_w = w.flatten()

        def evaluate_row(row_idx, flat_w=flat_w):
            row = l[row_idx]

            # Degenerate Kendall tau values fall back to 0.
            tau, _ = kendalltau(flat_w, row)
            if not np.isfinite(tau):
                tau = 0

            copula = Frank()
            copula.tau = tau
            copula.theta = copula.compute_theta()

            pairs = np.column_stack((flat_w, row))
            density_row = copula.probability_density(pairs).astype(np.float32)
            derivative_row = copula.partial_derivative(pairs).astype(np.float32)
            return density_row, derivative_row

        for start in range(0, n_rows, batch_size):
            stop = min(start + batch_size, n_rows)
            with ThreadPoolExecutor() as pool:
                batch_out = list(pool.map(evaluate_row, range(start, stop)))

            for row_idx, (density_row, derivative_row) in zip(range(start, stop), batch_out):
                densities[row_idx] = density_row
                derivatives[row_idx] = derivative_row

            del batch_out  # Free memory

        return densities, derivatives

    def euclidean_distance_1d(self, x, y):
        return np.linalg.norm(x - y)

    def euclidean_distance_2d(self, x, y):
        return  np.linalg.norm(x - y, axis=1, keepdims=True)**2

    def compute_copula_entropy(self, joint_density):
        res = -joint_density * np.log(joint_density)
        res = np.where(np.isnan(res), 0.0, res)
        return res

    def gradient_w1(self, joint_density, joint_density2, partial_derivatives):
        """Return a Mahalanobis-distance-based gradient of the two densities.

        The two inputs are stacked, their covariance (columns as variables) is
        inverted, and the partial derivatives are multiplied by
        ``diff.T @ inv_cov`` and divided by the Mahalanobis distance between
        the inputs.

        NOTE(review): this method is not called anywhere in the visible code,
        and the shapes expected by ``mahalanobis`` and by the ``einsum`` below
        do not obviously agree — confirm intended input shapes before use.
        """
        data = np.vstack((joint_density, joint_density2))
        cov_matrix = np.cov(data, rowvar=False)
        inv_cov_matrix = np.linalg.inv(cov_matrix)

        diff = joint_density - joint_density2
        mahal_dist = mahalanobis(joint_density, joint_density2, inv_cov_matrix)

        # 'ij,jk->ik' is an ordinary matrix product written as an einsum
        grad = np.einsum('ij,jk->ik', partial_derivatives, diff.T @ inv_cov_matrix) / mahal_dist
        return grad

    def partial_derivative_entropy_w1(self, density, partial_derivatives):
        entropy = partial_derivatives * (np.log(np.abs(density)) + 1)
        entropy = np.where(np.isnan(entropy), partial_derivatives, entropy)
        return entropy

    def euclidean_distance_and_gradient_torch(self,x, y, partial_derivatives):
        """
        使用PyTorch计算两个向量之间的欧氏距离及其对x的导数

        参数:
        x (numpy.ndarray): 第一个向量
        y (numpy.ndarray): 第二个向量

        返回:
        distance (float): 欧氏距离
        grad_x (numpy.ndarray): 欧氏距离对x的导数
        """
        grad = np.subtract(x, y)

        grad = np.sum(partial_derivatives * grad, axis=0)


        return grad

    def replace_nans_with_zero(self, tensor):

        if torch.is_floating_point(tensor):

            device = tensor.device
            tensor = torch.where(
                torch.isnan(tensor) | torch.isinf(tensor),
                torch.zeros(1, dtype=tensor.dtype, device=device),
                tensor
            )
        return tensor

    def _get_weight_update(
            self,
            grads,
            target_weights,
            lam,
            ome,
            r,
            transport,
            entropy,
            reg,
            dset,
            subset_inds,
            device,
            num_workers,
            module_param_indices_list,
            sparsity,
            pruning_stage,
    ):
        # 打印显存使用情况的详细信息
        print(f"\t 开始={torch.cuda.memory_summary(device=None, abbreviated=False)}")
        # Convert all constants and parameters to the target device and float32 precision
        lam_torch = torch.tensor(lam, device=device, dtype=torch.float32)
        ome_torch = torch.tensor(ome, device=device, dtype=torch.float32)

        # Get weights and convert to device and float32 precision
        w, _ = self._get_weights()
        w_bar = copy(target_weights).to(device)
        w = w.to(device)

        print(f"\t w: {w}")
        print(f"\t w_bar: {w_bar}")
        print(f"\t lenw_bar={len(w_bar)}")

        # Compute local proportions and convert to float32
        w_r = self.replace_nans_with_zero(self._compute_local_proportions_w(w)).to(device)
        w_bar_r = self.replace_nans_with_zero(self._compute_local_proportions_w(w_bar)).to(device)

        print(f"\t reg={reg}")

        # Compute gradients
        grads, _, _, _ = self._compute_wgH(dset, subset_inds, device, num_workers, debug=False)
        print(f"\t lengrads={len(grads)}")

        # Process grads in batches
        num_batches = 20  # Split grads into 10 batches
        batch_size = grads.shape[0] // num_batches
        delta_Qw_total = torch.zeros_like(w_r)
        tau_total = torch.tensor(0.0, device=device, dtype=torch.float32)

        non_zero_params_num = int(grads.shape[1] * (1 - sparsity))


        # 打印显存使用情况的详细信息
        print(f"\t batch开始前={torch.cuda.memory_summary(device=None, abbreviated=False)}")


        for batch_idx in range(num_batches):
            print(f"\t batch_idx: {batch_idx}")
            # Get the current batch of grads
            batch_start = batch_idx * batch_size
            batch_end = (batch_idx + 1) * batch_size
            grads_batch = grads[batch_start:batch_end].to(device)  # Move each batch to device and convert to float32

            # Compute local proportions
            grads_batch = self.replace_nans_with_zero(self._compute_local_proportions_2d(grads_batch)).to(device)

            X = grads_batch
            X_t = grads_batch

            print(f"\t X: {X}")
            print(f"\t w_r: {w_r}")
            print(f"\t w_bar_r: {w_bar_r}")
            print(f"\t X.shape: {X.shape}, w_r.shape: {w_r.shape}, w_bar_r.shape: {w_bar_r.shape}")

            # Construct joint distribution and convert to float32
            joint_density_w_bar_r = self.frankc_n_np(w_bar_r, X).astype(np.float32)
            joint_density_w, partial_derivatives = self.frankc_n(w_r, X)
            joint_density_w = joint_density_w.astype(np.float32)
            partial_derivatives = partial_derivatives.astype(np.float32)

            # Compute the objective function value and related variables
            Q, PI, M = self._pruning_objective(
                X=X_t,
                y=joint_density_w_bar_r,
                w=w_r,
                w_bar=w_bar_r,
                lam=lam,
                ome=ome,
                reg=reg,
                transport=transport,
                entropy=entropy,
            )

            print(f"\t Objective function value: {Q}")

            # Compute derivatives and update direction
            partial_deriv_w1 = self.partial_derivative_entropy_w1(joint_density_w, partial_derivatives).astype(
                np.float32)
            gong1_g = self.euclidean_distance_and_gradient_torch(joint_density_w, joint_density_w_bar_r,
                                                                 partial_derivatives)

            # Compute copula entropy differences
            entropy_w1_l1 = self.compute_copula_entropy(joint_density_w).astype(np.float32)
            entropy_w2_l1 = self.compute_copula_entropy(joint_density_w_bar_r).astype(np.float32)
            gong3 = np.abs(entropy_w1_l1 - entropy_w2_l1).astype(np.float32)

            print(f"\t gong3.T.shape: {gong3.T.shape}, partial_deriv_w1.shape: {partial_deriv_w1.shape}")
            gong3_g = torch.from_numpy(np.sum(gong3 * -partial_deriv_w1, axis=0)).to(device).float()

            gong1_g = torch.from_numpy(gong1_g).to(device).float()

            # Choose the update strategy
            if not transport:
                delta_Qw = (
                    gong1_g + lam_torch * (w_r - w_bar_r)
                    if not entropy else
                    gong1_g + lam_torch * torch.abs(w_r - w_bar_r) + ome_torch * gong3_g
                )
            else:
                delta_Qw = PI @ gong1_g + lam_torch * (w_r - w_bar_r) + (ome_torch * gong3_g if entropy else 0)

            print(f"\t delta_Qw={delta_Qw}")
            tau_c = self.__find_tau_c(a=w_r, b=delta_Qw, k=non_zero_params_num)
            print(f"\t tau_c = {tau_c}")

            tau_m = self.__find_minimizing_tau(
                a=w_r,
                b=delta_Qw,
                tau_c=tau_c,
                k=non_zero_params_num,
                X=X_t,
                y=joint_density_w_bar_r,
                w_bar=w_bar_r,
                lam=lam,
                ome=ome,
                reg=reg,
                transport=transport,
                entropy=entropy,
                PI=PI,
            )
            print(f"\t tau_m = {tau_m}")

            tau = tau_m if tau_m < tau_c else self.optimize_tau_with_gamma(
                tau_c, delta_Qw, w_r, w_bar_r, X_t, joint_density_w_bar_r, lam, ome, reg, transport, entropy, PI
            )

            print(f"\t tau = {tau}")
            print(f"\t tau_total = {tau_total}")
            # Accumulate delta_Qw and tau
            delta_Qw_total += delta_Qw

            print(f"\t delta_Qw_total = {delta_Qw_total}")
            tau_total= tau.view(-1).to(device)+tau_total.view(-1).to(device)

            print(f"\t batch结束={torch.cuda.memory_summary(device=None, abbreviated=False)}")
            # Release variables not needed for the current batch
            del grads_batch, X, X_t, joint_density_w_bar_r, joint_density_w, partial_derivatives, partial_deriv_w1, gong1_g, entropy_w1_l1, entropy_w2_l1, gong3, gong3_g, delta_Qw, PI, M
            torch.cuda.empty_cache()

            # 打印显存使用情况的详细信息
            print(f"\t 释放后={torch.cuda.memory_summary(device=None, abbreviated=False)}")

        # Take the average step size
        tau_avg = tau_total / num_batches
        delta_Qw_total= delta_Qw_total / num_batches
        # Update weights using the accumulated delta_Qw and average tau
        w_new_r = w_r - tau_avg * delta_Qw_total
        print(f"\t w_new_r = {w_new_r}")

        print(f"\t Non-zero value num of w_new_r{pruning_stage + 1}: {(w_new_r != 0).sum()}")

        w_new_r_p = self._hard_threshold(w_new_r, non_zero_params_num)
        print(f"\t Non-zero value num of w_{pruning_stage + 1}: {(w != 0).sum()}")
        print(f"\t Non-zero value num of w_new_r_p{pruning_stage + 1}: {(w_new_r_p != 0).sum()}")

        # Update weights
        w = torch.where(w_new_r_p == 0, torch.tensor(0.0, device=device, dtype=torch.float32), w)
        self._set_weights(weight_updates=w, module_param_indices_list=module_param_indices_list, set_mask=True)

        # Save OT files if needed
        if self.args.ot and self.args.dump_ot_files:
            self.save_ot_files(X, w_bar, w, PI)

        # 调用垃圾回收
        gc.collect()

        # 清空缓存的 GPU 显存
        torch.cuda.empty_cache()

        # 打印显存使用情况的详细信息
        print(f"\t 最终={torch.cuda.memory_summary(device=None, abbreviated=False)}")


    def optimize_tau_with_gamma(
            self, tau_c, delta_Qw, w_r, w_bar_r, X_t, y, lam, ome, reg, transport, entropy, PI
    ):
        """Grow the step size geometrically from ``tau_c`` while the objective drops.

        The candidate step is the currently accepted step multiplied by a fixed
        factor of 1.05. The candidate is accepted only if the pruning objective
        evaluated at ``w_r - candidate * delta_Qw`` is strictly lower than the
        objective at the accepted step; the search stops at the first candidate
        that does not improve.

        Args:
            tau_c: Initial step size; returned unchanged when the first
                enlarged step does not lower the objective.
            delta_Qw: Direction that is scaled by the step and subtracted
                from ``w_r``.
            w_r: Weights the trial points are built from.
            w_bar_r: Forwarded to ``self._pruning_objective`` as ``w_bar``.
            X_t: Forwarded to ``self._pruning_objective`` as ``X``.
            y, lam, ome, reg, transport, entropy, PI: Forwarded unchanged to
                ``self._pruning_objective``.

        Returns:
            The last accepted step size.

        Note:
            There is no iteration cap: the loop runs for as long as each
            enlarged step strictly lowers the objective.
        """
        growth = 1.05

        def objective_at(step):
            # The objective returns a 3-tuple; only its first entry is compared.
            value, _, _ = self._pruning_objective(
                X=X_t, y=y, w=w_r - step * delta_Qw, w_bar=w_bar_r, lam=lam, ome=ome,
                reg=reg, transport=transport, entropy=entropy, PI=PI,
            )
            return value

        accepted_step = tau_c
        accepted_value = objective_at(accepted_step)

        while True:
            candidate_step = growth * accepted_step
            candidate_value = objective_at(candidate_step)
            if not accepted_value > candidate_value:
                # The larger step no longer improves: keep the accepted one.
                return accepted_step
            accepted_step = candidate_step
            accepted_value = candidate_value

    def save_ot_files(self, X, w_bar, w, PI):
        """Write sorted projections and the matching reordered ``PI`` to CSV.

        Computes ``X @ w_bar`` and ``X @ w``, sorts each with ``torch.sort``,
        reorders the rows of ``PI`` by the first sort order and its columns by
        the second, and saves the three results as comma-separated files under
        ``./logs/ot_files/std_1_prop_01``. The file names are prefixed ``od``,
        ``ed`` and ``PI`` and embed ``self.args.reg`` and ``self.args.seed``.

        Args:
            X: Matrix multiplied with both weight vectors.
            w_bar: Weights producing the ``od`` values.
            w: Weights producing the ``ed`` values.
            PI: Matrix indexed by the ``od`` sort order on rows and the ``ed``
                sort order on columns (presumably the transport plan — confirm
                against the caller).
        """
        out_dir = "./logs/ot_files/std_1_prop_01"
        # np.savetxt does not create missing parent directories, so make sure
        # the target folder exists instead of failing on a fresh checkout.
        os.makedirs(out_dir, exist_ok=True)
        run_tag = f"reg={self.args.reg}_seed={self.args.seed}"

        sorted_od, order_wbar = torch.sort(X @ w_bar)
        sorted_ed, order_w = torch.sort(X @ w)
        # Rows follow the od ordering, columns follow the ed ordering.
        sorted_PI = PI[order_wbar][:, order_w]

        outputs = (("od", sorted_od), ("ed", sorted_ed), ("PI", sorted_PI))
        for prefix, tensor in outputs:
            np.savetxt(
                f"{out_dir}/{prefix}_{run_tag}.csv",
                tensor.detach().cpu().numpy(),
                delimiter=",",
            )

    def on_epoch_begin(
        self, dset, subset_inds, device, num_workers, epoch_num, **kwargs
    ):
        """Run one pruning step at the start of an epoch when the pruner is active.

        When ``epoch_num`` is 0, ``self._pruner_not_active(epoch_num)`` is true,
        or ``self._end == 1``, no pruning happens: the current weights and mask
        are stored in ``self._target_weights`` / ``self._original_mask`` and
        ``(False, {})`` is returned.

        Otherwise the method records the flat parameter offset of every module
        in ``self._modules``, obtains gradients from ``self._compute_wgH``,
        derives the sparsity for the current pruning stage and delegates the
        actual weight update to ``self._get_weight_update``.

        Args:
            dset: Dataset forwarded to ``_compute_wgH`` and ``_get_weight_update``.
            subset_inds: Sample indices forwarded alongside ``dset``.
            device: Device forwarded to the gradient and update routines.
            num_workers: Worker count forwarded to the gradient and update routines.
            epoch_num: Current epoch; determines whether pruning runs and which
                pruning stage it belongs to.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ``(False, {})`` when the pruner is inactive for this epoch, else
            ``(True, meta)`` where ``meta`` is an empty dict.

        Raises:
            AssertionError: If the model is in training mode or
                ``self._weight_only`` is false while pruning is active.
        """
        meta = {}
        if epoch_num == 0 or self._pruner_not_active(epoch_num) or self._end == 1:
            print("Pruner is not ACTIVEEEE yaa!")
            # Keep a snapshot of the latest weights/mask; the snapshot taken at
            # the last inactive epoch is what the next pruning step receives as
            # its target weights.
            self._target_weights, self._original_mask = self._get_weights()
            return False, {}

        CopulaPruner.have_not_started_pruning = False
        # Ensure that the model is not in training mode. This is important,
        # because otherwise the pruning procedure would interfere with and
        # affect the batch-norm statistics.
        assert not self._model.training

        # Re-initialise attributes if they were deleted during gradual pruning.
        if not hasattr(self, "_all_grads"):
            self._all_grads = []
        if not hasattr(self, "_param_stats"):
            self._param_stats = []

        # Start offset of every module inside the flattened parameter vector,
        # followed by the total parameter count as a closing sentinel. Only the
        # offsets are needed here, so no flattened weight copies are built.
        self._param_idx = 0
        module_param_indices_list = []
        for module in self._modules:
            module_param_indices_list.append(self._param_idx)
            assert self._weight_only
            self._param_idx += module.weight.numel()
        module_param_indices_list.append(self._param_idx)

        grads, _, _, _ = self._compute_wgH(
            dset, subset_inds, device, num_workers, debug=False
        )

        pruning_stage = (epoch_num - self._start) // self._freq
        total_stages = (self._end - self._start) // self._freq
        print(f"PRUNING STAGE {pruning_stage}")

        # Cubic schedule: stage 1 maps to the initial sparsity and the last
        # stage to the target sparsity.
        if total_stages > 1:
            progress = (pruning_stage - 1) / (total_stages - 1)
            # Stage 0 would give a negative fraction and extrapolate the cubic
            # below the initial sparsity (possibly below zero); clamp so the
            # result always lies between the initial and target sparsity.
            progress = min(max(progress, 0.0), 1.0)
            sparsity = (
                self._target_sparsity
                + (self._initial_sparsity - self._target_sparsity)
                * (1 - progress) ** 3
            )
        else:
            sparsity = self._target_sparsity

        print(f"Sparsity={sparsity}")
        # NOTE(review): lam/ome of 1e-300 underflow to 0.0 if they are ever cast
        # to float32 - confirm that effectively disabling these terms is intended.
        # transport/entropy are consumed as truthy/falsy flags by the update.
        self._get_weight_update(
            grads=grads,
            target_weights=self._target_weights,
            lam=1e-300,
            ome=1e-300,
            r=10000,
            transport=0,
            entropy=1,
            reg=self.args.reg,
            dset=dset,
            subset_inds=subset_inds,
            device=device,
            num_workers=num_workers,
            module_param_indices_list=module_param_indices_list,
            sparsity=sparsity,
            pruning_stage=pruning_stage,
        )

        return True, meta
