import math
import time
import torch
import torch.nn.functional as F
import numpy as np
import random
from .dataselectionstrategy import DataSelectionStrategy
from ..helpers import OrthogonalMP_REG_Parallel, OrthogonalMP_REG, OrthogonalMP_REG_Parallel_V1, select_closest_vectors
from torch.utils.data import Subset, DataLoader
import pandas as pd
import numpy as np
import logging
from sklearn.cluster import KMeans
import torch.distributed as dist
import os
class GradMatchStrategy(DataSelectionStrategy):
    """
    Implementation of GradMatch Strategy from the paper :footcite:`pmlr-v139-killamsetty21a` for supervised learning frameworks.

    GradMatch strategy tries to solve the optimization problem given below:

    .. math::
        \\min_{\\mathbf{w}, S: |S| \\leq k} \\Vert \\sum_{i \\in S} w_i \\nabla_{\\theta}L_T^i(\\theta) -  \\nabla_{\\theta}L(\\theta)\\Vert

    In the above equation, :math:`\\mathbf{w}` denotes the weight vector that contains the weights for each data instance, :math:`\\mathcal{U}` denotes the training set where :math:`(x^i, y^i)` denotes the :math:`i^{th}` training data point and label respectively,
    :math:`L_T` denotes the training loss, :math:`L` denotes either training loss or validation loss depending on the parameter valid,
    :math:`S` denotes the data subset selected at each round, and :math:`k` is the budget for the subset.

    The above optimization problem is solved using the Orthogonal Matching Pursuit(OMP) algorithm.

    Parameters
    ----------
    trainloader: class
        Loading the training data using pytorch DataLoader
    valloader: class
        Loading the validation data using pytorch DataLoader
    model: class
        Model architecture used for training
    loss: class
        PyTorch loss function for training
    eta: float
        Learning rate. Step size for the one step gradient update
    device: str
        The device being utilized - cpu | cuda
    num_classes: int
        The number of target classes in the dataset
    linear_layer: bool
        Apply linear transformation to the data
    selection_type: str
        Type of selection -
        - 'PerClass': PerClass method is where OMP algorithm is applied on each class data points separately.
        - 'PerBatch': PerBatch method is where OMP algorithm is applied on each minibatch data points.
        - 'PerClassPerGradient': PerClassPerGradient method is same as PerClass but we use the gradient corresponding to classification layer of that class only.
    logger : class
        - logger object for logging the information
    valid : bool
        If valid==True, we use validation dataset gradient sum in OMP otherwise we use training dataset (default: False)
    v1 : bool
        If v1==True, we use newer version of OMP solver that is more accurate
    lam : float
        Regularization constant of OMP solver
    eps : float
        Epsilon parameter to which the above optimization problem is solved using OMP algorithm
    """

    def __init__(self, trainloader, valloader, model, loss,
                 eta, device, num_classes, linear_layer,
                 selection_type, logger, valid=False, v1=True, lam=0, eps=1e-4):
        """
        Build the strategy.

        Shared state (loaders, model, loss, ...) is handed to the base-class
        constructor; the GradMatch-specific options are stored on the instance.
        """
        super().__init__(trainloader, valloader, model, num_classes, linear_layer, loss, device, logger)

        # Handles that are also kept directly on this instance.
        self.device, self.logger = device, logger

        # Selection configuration.
        self.selection_type = selection_type
        self.valid = valid
        self.v1 = v1

        # Solver hyper-parameters.
        self.eta = eta  # step size for the one step gradient update
        self.lam = lam
        self.eps = eps

        # Bookkeeping containers / counters.
        self.init_out = []
        self.init_l1 = []
        self.save = 0

    def ompwrapper(self, X, Y, bud, method=None):
        """
        Run the configured solver and return the selected entries with their weights.

        Parameters
        ----------
        X:
            Matrix handed to the solver. Callers in this class pass the transposed
            gradient matrix.
        Y:
            Target vector handed to the solver. Callers in this class pass a sum of gradients.
        bud: int
            Maximum number of non-zero weights (forwarded as ``nnz`` / budget).
        method: str, optional
            Only ``"gradmatch_cos"`` is treated specially: on the non-CPU ``v1`` path it
            switches the solver to ``select_closest_vectors``. Any other value (including
            the default ``None``) uses the OMP solver. The default keeps the three-argument
            call sites in this class working.

        Returns
        ----------
        ind: list
            Positions of the non-zero entries of the solver output
        weights: list
            Solver output values at those positions
        """
        if self.device == "cpu":
            # NOTE(review): the CPU solver is called with lam=0 instead of self.lam and
            # ignores `method`/`eps` -- confirm this is intentional.
            reg = OrthogonalMP_REG(X.numpy(), Y.numpy(), nnz=bud, positive=True, lam=0)
            ind = np.nonzero(reg)[0]
        else:
            if self.v1:
                if method == "gradmatch_cos":
                    reg = select_closest_vectors(X, bud, lam=self.lam)
                else:
                    reg = OrthogonalMP_REG_Parallel_V1(X, Y, nnz=bud,
                                                       positive=True, lam=self.lam,
                                                       tol=self.eps, device=self.device)
            else:
                reg = OrthogonalMP_REG_Parallel(X, Y, nnz=bud,
                                                positive=True, lam=self.lam,
                                                tol=self.eps, device=self.device)
            ind = torch.nonzero(reg).view(-1)
        return ind.tolist(), reg[ind].tolist()
    
    def select(self, budget, model_params, last_para1, last_para2, method):
        """
        Apply OMP Algorithm for data selection

        Parameters
        ----------
        budget: int
            The number of data points to be selected
        model_params: OrderedDict
            Python dictionary object containing models parameters
        last_para1, last_para2:
            Currently unused.
            NOTE(review): presumably earlier model parameters meant to feed the
            ``gradmatch_theta`` variants -- confirm with the author.
        method: str
            Solver variant, forwarded to :meth:`ompwrapper`.

        Returns
        ----------
        idxs: list
            List containing indices of the best datapoints
        gammas: list or numpy.ndarray
            Weight of each selected instance (an ndarray when random padding was
            added in the 'PerBatch' mode, a list otherwise)
        idxs_temp: list
            Raw indices returned by the last solver call (batch indices for 'PerBatch')
        gammas_temp:
            Raw weights of the last solver call; for 'PerBatch' a ``(k, 1)`` tensor
        residual: float
            ``1 - cosine_similarity`` between the weighted sum of selected batch gradients
            and the target gradient sum. Only computed for 'PerBatch'; NaN otherwise.
        residual_random: float
            Same measure for a random 10% of the batches. Only computed for 'PerBatch';
            NaN otherwise.

        On non-zero ranks in distributed mode the method returns
        ``[], torch.tensor([]), [], [], 0, 0``.

        Raises
        ----------
        ValueError
            If ``self.selection_type`` is not one of the supported modes.
        NotImplementedError
            For 'PerBatch' with ``method`` in ``gradmatch_theta`` /
            ``gradmatch_threshold_theta``: the extra gradient matrices these
            variants concatenate are never computed in this class.
        """
        if self.selection_type not in ('PerClass', 'PerBatch', 'PerClassPerGradient'):
            raise ValueError(
                "Unknown selection_type %r. Choose from 'PerClass', 'PerBatch', "
                "'PerClassPerGradient'." % (self.selection_type,))
        if self.selection_type == 'PerBatch' and method in ("gradmatch_theta",
                                                            "gradmatch_threshold_theta"):
            # Fail before the expensive gradient pass instead of with a NameError after it.
            raise NotImplementedError(
                "method %r needs additional gradient matrices that are not computed" % (method,))

        omp_start_time = time.time()
        if self.is_distributed:
            for param in self.model.parameters():
                dist.broadcast(param.data, src=0)

        self.update_model(model_params)
        self.compute_gradients(self.valid, perBatch=True, perClass=False)
        trn_gradients = self.grads_per_elem

        idxs = []
        gammas = []
        # Defaults so every selection type can reach the final return statement.
        idxs_temp, gammas_temp = [], []
        residual = residual_random = float("nan")

        if self.selection_type == 'PerClass':
            self.get_labels(valid=self.valid)
            for i in range(self.num_classes):
                trn_subset_idx = self._make_per_class_loaders(i)
                self.compute_gradients(self.valid, perBatch=False, perClass=True)
                trn_gradients = self.grads_per_elem
                if self.valid:
                    sum_val_grad = torch.sum(self.val_grads_per_elem, dim=0)
                else:
                    sum_val_grad = torch.sum(trn_gradients, dim=0)
                # Each class receives a share of the budget proportional to its size.
                idxs_temp, gammas_temp = self.ompwrapper(torch.transpose(trn_gradients, 0, 1),
                                                         sum_val_grad,
                                                         math.ceil(budget * len(trn_subset_idx) / self.N_trn),
                                                         method)
                idxs.extend(list(np.array(trn_subset_idx)[idxs_temp]))
                gammas.extend(gammas_temp)

        elif self.selection_type == 'PerBatch':
            if self.valid:
                sum_val_grad = torch.sum(self.val_grads_per_elem, dim=0)
            else:
                sum_val_grad = torch.sum(trn_gradients, dim=0)
            num_rows = trn_gradients.shape[0]

            idxs_temp, batch_gammas = self.ompwrapper(torch.transpose(trn_gradients, 0, 1),
                                                      sum_val_grad,
                                                      math.ceil(budget / self.trainloader.batch_size),
                                                      method)
            self.logger.info(
                "Subset indices: %s (Length: %d), Gammas: %s (Min: %.4f, Max: %.4f)",
                idxs_temp, len(idxs_temp), batch_gammas,
                min(batch_gammas, default=float("nan")), max(batch_gammas, default=float("nan"))
            )

            # Expand every selected batch into its sample indices; all samples of a
            # batch share that batch's weight.
            # NOTE(review): assumes batch_sampler yields batches in the same order in
            # which compute_gradients saw them -- confirm for shuffling samplers.
            batch_wise_indices = list(self.trainloader.batch_sampler)
            for batch_id, weight in zip(idxs_temp, batch_gammas):
                members = batch_wise_indices[batch_id]
                idxs.extend(members)
                gammas.extend(list(weight * np.ones(len(members))))

            # Direction mismatch between the weighted sum of selected gradients and the target.
            gammas_temp = torch.tensor(batch_gammas).to(trn_gradients.device).view(-1, 1)
            reconstructed_gradients = torch.sum(trn_gradients[idxs_temp, :] * gammas_temp, dim=0)
            residual = (1 - F.cosine_similarity(reconstructed_gradients, sum_val_grad, dim=0)).item()

            # Baseline: the same measure for an unweighted random 10% of the rows.
            num_random_samples = int(0.1 * num_rows)
            random_indices = torch.randint(0, num_rows, (num_random_samples,), dtype=torch.long)
            random_sum_index = torch.sum(trn_gradients[random_indices, :], dim=0)
            residual_random = (1 - F.cosine_similarity(random_sum_index, sum_val_grad, dim=0)).item()

        else:  # 'PerClassPerGradient'
            self.get_labels(valid=self.valid)
            embDim = self.model.get_embedding_dim()
            for i in range(self.num_classes):
                trn_subset_idx = self._make_per_class_loaders(i)
                self.compute_gradients(self.valid, perBatch=False, perClass=True)
                trn_gradients = self.grads_per_elem
                # Keep column i plus the embDim-wide slice belonging to class i.
                lo = self.num_classes + (embDim * i)
                hi = self.num_classes + (embDim * (i + 1))
                trn_gradients = torch.cat((trn_gradients[:, i].view(-1, 1),
                                           trn_gradients[:, lo:hi]), dim=1)

                if self.valid:
                    val_gradients = self.val_grads_per_elem
                    val_gradients = torch.cat((val_gradients[:, i].view(-1, 1),
                                               val_gradients[:, lo:hi]), dim=1)
                    sum_val_grad = torch.sum(val_gradients, dim=0)
                else:
                    sum_val_grad = torch.sum(trn_gradients, dim=0)

                idxs_temp, gammas_temp = self.ompwrapper(torch.transpose(trn_gradients, 0, 1),
                                                         sum_val_grad,
                                                         math.ceil(budget * len(trn_subset_idx) / self.N_trn),
                                                         method)
                idxs.extend(list(np.array(trn_subset_idx)[idxs_temp]))
                gammas.extend(gammas_temp)

        # Pad with uniformly drawn, unit-weight points if the solver returned fewer than budget.
        diff = int(budget - len(idxs))
        self.logger.debug("Random points added: %d ", diff)
        if diff > 0:
            remainList = set(np.arange(self.N_trn)).difference(set(idxs))
            new_idxs = np.random.choice(list(remainList), size=diff, replace=False)
            idxs.extend(new_idxs)
            gammas.extend([1 for _ in range(diff)])
            idxs = np.array(idxs)
            gammas = np.array(gammas)

        if self.selection_type in ["PerClass", "PerClassPerGradient"]:
            # Per-class results are grouped by class; shuffle so they are interleaved.
            rand_indices = np.random.permutation(len(idxs))
            idxs = list(np.array(idxs)[rand_indices])
            gammas = list(np.array(gammas)[rand_indices])

        idxs = [int(x) for x in idxs]
        omp_end_time = time.time()
        self.logger.debug("GradMatch algorithm Subset Selection time is: %.4f", omp_end_time - omp_start_time)
        if self.is_distributed and self.rank != 0:
            return [], torch.tensor([]), [], [], 0, 0
        return idxs, gammas, idxs_temp, gammas_temp, residual, residual_random

    def _make_per_class_loaders(self, class_idx):
        """Point ``self.pctrainloader`` (and ``self.pcvalloader`` when ``self.valid``) at the
        samples labelled ``class_idx`` and return the training indices of that class."""
        trn_subset_idx = torch.where(self.trn_lbls == class_idx)[0].tolist()
        trn_data_sub = Subset(self.trainloader.dataset, trn_subset_idx)
        self.pctrainloader = DataLoader(trn_data_sub, batch_size=self.trainloader.batch_size,
                                        shuffle=False, pin_memory=True,
                                        collate_fn=self.trainloader.collate_fn)
        if self.valid:
            val_subset_idx = torch.where(self.val_lbls == class_idx)[0].tolist()
            val_data_sub = Subset(self.valloader.dataset, val_subset_idx)
            self.pcvalloader = DataLoader(val_data_sub, batch_size=self.trainloader.batch_size,
                                          shuffle=False, pin_memory=True,
                                          collate_fn=self.trainloader.collate_fn)
        return trn_subset_idx

    def compute_gradient_variability(self, trn_gradients: torch.Tensor, method: str = 'var') -> torch.Tensor:
        """
        Compute gradient variability using different methods.

        Parameters
        ----------
        trn_gradients : torch.Tensor
            Shape (N, D) where N is number of elements and D is gradient dimension.
        method : str
            Method to compute variability. Options:
                - 'var': Mean variance across dimensions
                - 'rel_var': Variance divided by mean abs gradient per dim
                - 'cv': Coefficient of Variation (std / mean abs) per dim
                - 'normalized_var': Normalize gradients then compute variance

        Returns
        -------
        torch.Tensor
            A single scalar representing the variability score.
        """
        eps = 1e-8  # prevent divide-by-zero

        if method == 'var':
            grad_var = torch.var(trn_gradients, dim=0, unbiased=False)
            return torch.mean(grad_var)

        elif method == 'rel_var':
            grad_var = torch.var(trn_gradients, dim=0, unbiased=False)
            grad_mean_abs = torch.mean(torch.abs(trn_gradients), dim=0)
            rel_var = grad_var / (grad_mean_abs + eps)
            return torch.mean(rel_var)

        elif method == 'snr':
            grad_var = torch.var(trn_gradients, dim=0, unbiased=False)
            # self.logger.info(f"Gradient shape: {trn_gradients.shape}")
            # self.logger.info(f"Gradient_var shape: {grad_var.shape}")
            grad_std = torch.std(trn_gradients, dim=0, unbiased=False)
            grad_mean_abs = torch.mean(torch.abs(trn_gradients), dim=0)
            cv = grad_std / (grad_mean_abs + eps)
            var_over_mean = grad_var / (grad_mean_abs + eps)
            return torch.mean(cv)

        elif method == 'normalized_var':
            # Normalize each dimension, then compute variance
            grad_mean = torch.mean(trn_gradients, dim=0)
            grad_std = torch.std(trn_gradients, dim=0, unbiased=False)
            normed_grad = (trn_gradients - grad_mean) / (grad_std + eps)
            grad_var = torch.var(normed_grad, dim=0, unbiased=False)
            return torch.mean(grad_var)

        else:
            raise ValueError(f"Unknown method '{method}'. Choose from 'var', 'rel_var', 'cv', 'normalized_var'.")
        
    def evaluate_CV(self, budget, model_params):
        """
        Load ``model_params`` into the model, recompute gradients through
        ``compute_gradients_small`` and return their 'snr' variability score.

        ``budget`` is accepted but not used.
        """
        self.update_model(model_params)
        self.compute_gradients_small(self.valid, perBatch=True, perClass=False)
        return self.compute_gradient_variability(self.grads_per_elem, method='snr')
       
    def evaluate_grad(self, budget, model_params, subset_indices, subset_weights, random_indices):
        """
        Measure how well a weighted subset and a random subset match the target gradient.

        Parameters
        ----------
        budget:
            Accepted but not used.
        model_params: OrderedDict
            Parameters loaded into the model before gradients are recomputed
        subset_indices:
            Row indices into the gradient matrix for the weighted subset
        subset_weights:
            One weight per entry of ``subset_indices`` (tensor or anything
            ``torch.as_tensor`` accepts)
        random_indices:
            Row indices into the gradient matrix for the unweighted baseline

        Returns
        ----------
        residual: float
            ``1 - cosine_similarity`` between the weighted subset sum and the target sum
        residual_random: float
            Same measure for the plain sum over ``random_indices``
        mean_var: float
            Mean (over columns) of the population variance of the gradient rows
        trn_gradients:
            The gradient matrix (``self.grads_per_elem``) the metrics were computed on
        """
        self.update_model(model_params)
        self.compute_gradients(self.valid, perBatch=True, perClass=False)
        trn_gradients = self.grads_per_elem

        # Target direction: validation gradient sum when self.valid, else training gradient sum.
        if self.valid:
            sum_val_grad = torch.sum(self.val_grads_per_elem, dim=0)
        else:
            sum_val_grad = torch.sum(trn_gradients, dim=0)

        # Column vector of weights so it broadcasts across gradient dimensions.
        weights = torch.as_tensor(subset_weights).to(trn_gradients.device).reshape(-1, 1)
        reconstructed_gradients = torch.sum(trn_gradients[subset_indices, :] * weights, dim=0)
        residual = 1 - F.cosine_similarity(reconstructed_gradients, sum_val_grad, dim=0)

        random_sum_index = torch.sum(trn_gradients[random_indices, :], dim=0)
        residual_random = 1 - F.cosine_similarity(random_sum_index, sum_val_grad, dim=0)

        variability_metrics = {
            'mean_variance': self.compute_gradient_variability(trn_gradients, method='var'),
            'relative_variance': self.compute_gradient_variability(trn_gradients, method='rel_var'),
            # 'snr' is the name under which compute_gradient_variability implements std / mean-abs.
            'coefficient_of_variation': self.compute_gradient_variability(trn_gradients, method='snr'),
            'normalized_variance': self.compute_gradient_variability(trn_gradients, method='normalized_var')
        }

        self.logger.info("📊 Gradient Variability Metrics:")
        for name, val in variability_metrics.items():
            self.logger.info("  - %s: %.6e", name, val.item())

        mean_var = variability_metrics['mean_variance']
        return residual.item(), residual_random.item(), mean_var.item(), trn_gradients

    def record_gradiant(self, model_params, output_path=None):
        """
        Recompute gradients for ``model_params`` and dump them to a CSV file.

        Parameters
        ----------
        model_params: OrderedDict
            Parameters loaded into the model before gradients are recomputed
        output_path: str, optional
            Destination CSV file. Defaults to the historical hard-coded location.

        Returns
        ----------
        int
            Always 0
        """
        if output_path is None:
            output_path = '/root/cords_project_retry/cords-main/benchmarks/SL/results/grad_of_all_samples_in_full.csv'

        self.update_model(model_params)
        self.compute_gradients(self.valid, perBatch=True, perClass=False)
        trn_gradients = self.grads_per_elem
        # pandas cannot ingest CUDA / grad-tracking tensors directly; move to host memory first.
        if torch.is_tensor(trn_gradients):
            trn_gradients = trn_gradients.detach().cpu().numpy()

        # Save the DataFrame to the requested path, creating the folder if needed.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        pd.DataFrame(trn_gradients).to_csv(output_path, index=False)
        return 0
