from torch.utils.data import Dataset
import os
import torch
import numpy as np
from torchvision import transforms, datasets
from PIL import Image
import random
import cv2
import pickle
import math
import time
import scipy.io as scio
import seaborn as sns
from collections import defaultdict
from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt
from tqdm import tqdm
from Sinkhorn_distance import SinkhornDistance

class Argument(object):
    """Plain container for run settings plus one-time global setup.

    Attributes are public and meant to be overwritten by the caller before
    ``init_before_training`` is invoked.
    """

    def __init__(self):
        self.if_allow_break = True
        self.visible_gpu = '0'
        self.num_threads = 8
        self.device = None  # never assigned in this class; left for the caller

        # Arguments for two-stage training
        self.total_epochs = 300
        self.pretrained_epochs = 200
        self.finetune_epochs = 50
        self.learning_rate = 1e-1
        self.batch_size = 128
        self.dataset_name = 'CIFAR100'
        self.imbalance_ratio = 0.01
        self.action_dim = 100
        self.feature_dim = 256
        self.random_seed = 9

    def init_before_training(self):
        """Seed the random generators and apply process-wide torch settings.

        Seeds Python's ``random``, NumPy and torch with ``random_seed``,
        switches cuDNN to deterministic mode, sets the torch thread count and
        default dtype, and exports ``CUDA_VISIBLE_DEVICES``.
        """
        # Python's own RNG is seeded as well so that code relying on the
        # ``random`` module is reproducible together with NumPy and torch.
        random.seed(self.random_seed)
        np.random.seed(self.random_seed)
        torch.manual_seed(self.random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        torch.set_num_threads(self.num_threads)                    # CPU parallelism
        torch.set_default_dtype(torch.float32)
        os.environ['CUDA_VISIBLE_DEVICES'] = str(self.visible_gpu)

def cifar_10(path):
    """Load the CIFAR-10 pickle batches stored in directory ``path``.

    Reads ``data_batch_1`` .. ``data_batch_5`` and ``test_batch``. The
    training batches are concatenated in file order, so the arrays are never
    staged through a large float64 buffer.

    Args:
        path: directory that contains the six batch files.

    Returns:
        ``(train_data, train_gt, test_data, test_gt)`` where the data arrays
        are ``uint8`` with one row per sample and the label arrays are
        integer column vectors of shape ``(num_samples, 1)``.

    Raises:
        AssertionError: if ``path`` does not exist.
    """
    assert os.path.exists(path), 'path error!'

    def _read_batch(file_path):
        # NOTE: pickle can execute arbitrary code; only load trusted dataset files.
        with open(file_path, 'rb') as fo:
            batch = pickle.load(fo, encoding='bytes')
        data = np.asarray(batch[b'data']).astype(np.uint8)
        labels = np.asarray(batch[b'labels']).reshape((-1, 1)).astype(int)
        return data, labels

    train_parts = [_read_batch(os.path.join(path, f'data_batch_{i}')) for i in range(1, 6)]
    train_data = np.concatenate([data for data, _ in train_parts], axis=0)
    train_gt = np.concatenate([labels for _, labels in train_parts], axis=0)

    test_data, test_gt = _read_batch(os.path.join(path, 'test_batch'))

    return train_data, train_gt, test_data, test_gt


def cifar_100(path):
    """Load the CIFAR-100 ``train`` and ``test`` pickles from ``path``.

    Uses the ``fine_labels`` entry of each file as the label. The loaded
    arrays are converted directly, so no float64 staging buffer is needed.

    Args:
        path: directory that contains the ``train`` and ``test`` files.

    Returns:
        ``(train_data, train_gt, test_data, test_gt)`` where the data arrays
        are ``uint8`` with one row per sample and the label arrays are
        integer column vectors of shape ``(num_samples, 1)``.

    Raises:
        AssertionError: if ``path`` does not exist.
    """
    assert os.path.exists(path), 'path error!'

    def _read_split(file_path):
        # NOTE: pickle can execute arbitrary code; only load trusted dataset files.
        with open(file_path, 'rb') as fo:
            entry = pickle.load(fo, encoding='latin1')
        data = np.asarray(entry['data']).astype(np.uint8)
        labels = np.asarray(entry['fine_labels']).reshape((-1, 1)).astype(int)
        return data, labels

    train_data, train_gt = _read_split(os.path.join(path, 'train'))
    test_data, test_gt = _read_split(os.path.join(path, 'test'))

    return train_data, train_gt, test_data, test_gt


def prepare_data(path, dataset_name, imb_ratio, imb=True):
    """Load CIFAR-10/100 and optionally subsample it into a long-tailed set.

    Args:
        path: dataset directory handed to ``cifar_10`` / ``cifar_100``.
        dataset_name: ``'CIFAR10'`` or ``'CIFAR100'``.
        imb_ratio: class ``c`` keeps
            ``int(sample_per_class * imb_ratio ** (c / (class_num - 1)))``
            training samples when ``imb`` is set.
        imb: when truthy, return the subsampled training set; otherwise the
            full training set.

    Returns:
        ``(train_data, train_gt, test_data, test_gt, data_num_N)``. Data
        arrays are reshaped to ``(num_samples, -1, 32, 32)``. ``data_num_N``
        holds the per-class training sample counts of the returned set. In the
        imbalanced case the training labels are a ``uint8`` column vector.

    Raises:
        AssertionError: if ``path`` does not exist.
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    assert os.path.exists(path), 'path error!'

    if dataset_name == 'CIFAR10':
        class_num, sample_per_class = 10, 5000
        train_data, train_gt, test_data, test_gt = cifar_10(path)
    elif dataset_name == 'CIFAR100':
        class_num, sample_per_class = 100, 500
        train_data, train_gt, test_data, test_gt = cifar_100(path)
    else:
        # Previously an unknown name surfaced later as an UnboundLocalError.
        raise ValueError(f"unsupported dataset_name: {dataset_name!r}")

    train_data = np.reshape(train_data, [train_data.shape[0], -1, 32, 32])
    test_data = np.reshape(test_data, [test_data.shape[0], -1, 32, 32])

    if not imb:
        data_num_N = np.bincount(train_gt.ravel(), minlength=class_num).astype(np.int64)
        return train_data, train_gt, test_data, test_gt, data_num_N

    # Exponentially decaying number of kept samples per class.
    data_num_N = np.array([
        int(sample_per_class * (imb_ratio ** (cls_idx / (class_num - 1))))
        for cls_idx in range(class_num)
    ])

    imb_train_data_list = []
    imb_train_gt_list = []
    for i in range(class_num):
        cls_samples = train_data[np.where(train_gt == i)[0]]
        # Sampling positions are drawn from range(sample_per_class), i.e. every
        # class is expected to hold at least sample_per_class samples.
        index = np.random.choice(sample_per_class, data_num_N[i], replace=False)
        imb_train_data_list.append(cls_samples[index])
        imb_train_gt_list.append(np.full((data_num_N[i], 1), i, dtype=np.uint8))

    # One concatenate instead of re-copying the growing array per class.
    imb_train_data = np.concatenate(imb_train_data_list, axis=0)
    imb_train_gt = np.concatenate(imb_train_gt_list, axis=0)

    return imb_train_data, imb_train_gt, test_data, test_gt, data_num_N


def cal_statistic_single(dataloader, model, data_num_N, head_index=None):
    """Run ``model`` over ``dataloader`` and collect per-class feature statistics.

    Side effect: the concatenated features and labels are pickled to
    ``feature-embedding/cifar100lt`` (the directory is created if missing).

    Args:
        dataloader: iterable of batches; ``batch[0]`` is the input and
            ``batch[1]`` the label tensor.
        model: object exposing ``eval()`` and ``forward_with_feature(data)``
            returning ``(output, feature)``. Inputs are moved with ``.cuda()``.
        data_num_N: only its length is used, as the number of classes.
        head_index: optional list / tuple / ndarray / tensor of class ids used
            to select ``head_feature`` and ``head_label``.

    Returns:
        Tuple ``(value_vector_mean_var_cov, feature_per_class, logit_per_class,
        feature, label, head_feature, head_label)``.
        ``value_vector_mean_var_cov[cls]`` is the list ``[top-10 singular
        values, top-10 left singular vectors, mean, variance, covariance]`` of
        that class's features.
    """
    model.eval()
    value_vector_mean_var_cov = defaultdict(list)

    feature_list, label_list, prob_list = [], [], []
    softmax = torch.nn.Softmax(dim=-1)
    with torch.no_grad():
        for batch in dataloader:
            data, gt = batch[0], batch[1]
            data = data.cuda().float()
            output, feature = model.forward_with_feature(data)
            prob = softmax(output)

            feature_list.append(feature.detach().cpu())
            prob_list.append(prob.detach().cpu())
            # Keep the batch dimension: a plain squeeze() collapses a batch of
            # size 1 to a 0-d tensor, which torch.cat below cannot concatenate.
            label_list.append(gt.reshape(gt.shape[0], -1).squeeze(dim=1))

        feature = torch.cat(feature_list)
        label = torch.cat(label_list)
        # NOTE(review): the per-class "logit" statistics are built from softmax
        # probabilities, not raw model outputs -- confirm this is intended.
        logit = torch.cat(prob_list)

        if head_index is not None:
            if isinstance(head_index, (list, tuple)):
                head_index = torch.tensor(head_index, dtype=label.dtype)
            elif isinstance(head_index, np.ndarray):
                head_index = torch.from_numpy(head_index).to(dtype=label.dtype)
            head_mask = torch.isin(label, head_index)
            head_feature = feature[head_mask]
            head_label = label[head_mask]
        else:
            head_feature = None
            head_label = None

    # save feature
    recoder = {}
    recoder['feature'] = feature.detach().cpu()
    recoder['label'] = label.detach().cpu()
    os.makedirs('feature-embedding', exist_ok=True)
    with open('feature-embedding/cifar100lt', 'wb') as file:
        pickle.dump(recoder, file)

    feature_per_class = defaultdict()
    logit_per_class = defaultdict()
    for cls in range(len(data_num_N)):
        index = torch.where(label == cls)[0]
        cls_features = feature[index]
        cls_logit = logit[index]

        feature_per_class[cls] = cls_features
        logit_per_class[cls] = cls_logit.mean(dim=0)

        mean_cls = torch.mean(cls_features, dim=0, keepdim=True)
        var_cls = torch.var(cls_features, dim=0, keepdim=True)
        # covariance matrix of the centred features
        norm_feature = cls_features - mean_cls
        cov_matrix = torch.matmul(norm_feature.T, norm_feature) / (norm_feature.shape[0] - 1)
        u, s, v = np.linalg.svd(np.array(cov_matrix), hermitian=True)
        # save the top ten values and their corresponding vectors
        value_vector_mean_var_cov[cls].append(s[:10])
        value_vector_mean_var_cov[cls].append(u[:, :10])
        value_vector_mean_var_cov[cls].append(mean_cls)
        value_vector_mean_var_cov[cls].append(var_cls)
        value_vector_mean_var_cov[cls].append(cov_matrix)

    return value_vector_mean_var_cov, feature_per_class, logit_per_class, feature, label, head_feature, head_label



''' select topK most similar head classes '''
def cal_sim_topk(cls_statistic, head_class, tail_class, k=10):
    """For every tail class pick the ``k`` head classes with the most similar mean.

    Similarity is the cosine similarity between the class means stored at
    ``cls_statistic[cls][2]``; no other statistic is used.

    Args:
        cls_statistic: mapping ``cls -> sequence`` whose element ``2`` is the
            class mean tensor.
        head_class: candidate head class ids.
        tail_class: tail class ids to find matches for.
        k: number of head classes to keep per tail class.

    Returns:
        ``defaultdict`` mapping each tail class id to a tensor of ``k``
        positions. The positions index the distinct entries of ``head_class``
        in first-seen order; they are not class ids.

    Raises:
        AssertionError: if fewer than ``k`` head classes are given.
    """
    assert len(head_class) >= k, 'the number of head classes must more than k!!'

    # Normalise every head mean once instead of once per (tail, head) pair.
    # Keyed by class id so repeated head ids collapse to a single entry.
    head_unit = {}
    for h_cls in head_class:
        h_mean = cls_statistic[h_cls][2].squeeze()
        head_unit[h_cls] = h_mean / torch.norm(h_mean)

    topk = defaultdict(int)
    for t_cls in tail_class:
        t_mean = cls_statistic[t_cls][2].squeeze()
        t_unit = t_mean / torch.norm(t_mean)

        # only the cosine similarity between means is used
        sim_score = {h_cls: torch.dot(t_unit, h_unit) for h_cls, h_unit in head_unit.items()}

        index = torch.topk(torch.tensor(list(sim_score.values())), k).indices
        topk[t_cls] = index
    return topk

''' select topK most similar opponent '''
def cal_sim_topk_logit(logit_per_class, tail_class, k=1):
    """Return, per tail class, the ``k`` other classes with the largest score.

    Args:
        logit_per_class: mapping ``cls -> 1-D tensor`` of per-class scores.
        tail_class: class ids to process.
        k: number of competing classes to keep.

    Returns:
        ``defaultdict(list)`` mapping each tail class id to a list of ``k``
        class indices, highest score first. The class's own entry is excluded.
    """
    opponents = defaultdict(list)

    for cls_id in tail_class:
        # Work on a copy so the caller's tensor is left untouched.
        scores = logit_per_class[cls_id].clone()  # [C]
        scores[cls_id] = float("-inf")            # never pick the class itself
        best = torch.topk(scores, k=k, largest=True)
        opponents[cls_id] = best.indices.tolist()

    return opponents


def extract_weight_diff(model, tail_class, topk_oppo):
    """Compute classifier weight differences between tail classes and opponents.

    The classifier weight is taken from the first match among ``model.fc_cb``,
    ``model.linear``, ``model.fc``, a ``torch.nn.Linear`` whose module name
    contains ``'classifier'``, or the last ``torch.nn.Linear`` in the model.

    Args:
        model: module holding the linear classifier.
        tail_class: iterable of tail class ids (list, ndarray or tensor).
        topk_oppo: mapping ``tail_cls -> opponent class indices`` (list, tuple,
            ndarray, tensor or a single index).

    Returns:
        ``dict`` mapping each tail class id to a ``(K, feature_dim)`` tensor
        whose rows are ``w[tail_cls] - w[opponent]``. Tail classes missing from
        ``topk_oppo`` or with no opponent index below ``num_classes`` are
        skipped.

    Raises:
        ValueError: if no linear layer can be found in ``model``.
    """
    if hasattr(model, 'fc_cb'):
        weight = model.fc_cb.weight.data  # (num_classes, feature_dim)
    elif hasattr(model, 'linear'):
        weight = model.linear.weight.data  # (num_classes, feature_dim)
    elif hasattr(model, 'fc'):
        weight = model.fc.weight.data  # (num_classes, feature_dim)
    else:
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear) and 'classifier' in name.lower():
                weight = module.weight.data
                break
        else:
            linear_layers = [m for m in model.modules() if isinstance(m, torch.nn.Linear)]
            if linear_layers:
                weight = linear_layers[-1].weight.data
            else:
                raise ValueError("check linear classifier")

    device = weight.device
    num_classes, feature_dim = weight.shape

    # Convert to plain Python ints: tensor elements hash by identity and would
    # never match the keys of ``topk_oppo``.
    if isinstance(tail_class, torch.Tensor):
        tail_class = tail_class.detach().cpu().reshape(-1).tolist()
    elif isinstance(tail_class, np.ndarray):
        tail_class = tail_class.tolist()

    weight_diff = {}

    for tail_cls in tail_class:
        if tail_cls not in topk_oppo:
            continue

        # Accept any index container; always a 1-D long tensor on the weight's device.
        opponent_indices = torch.as_tensor(
            topk_oppo[tail_cls], device=device, dtype=torch.long
        ).reshape(-1)

        # Drop indices that point past the classifier's rows.
        opponent_indices = opponent_indices[opponent_indices < num_classes]
        if opponent_indices.numel() == 0:
            continue

        wc = weight[tail_cls]              # (feature_dim,)
        wk_all = weight[opponent_indices]  # (K, feature_dim)
        weight_diff[tail_cls] = wc.unsqueeze(0) - wk_all  # (K, feature_dim)

    return weight_diff




def project_to_support_batch_per_class(delta0, v_k, iters=50, tol=1e-9):
    """Iteratively push each row of ``delta0`` towards ``<v_k, delta> >= 0``.

    On every iteration the most violated constraint of each row is selected
    and the row is projected onto that constraint's boundary. Rows whose
    constraints all hold (within ``tol``) are left unchanged.

    Args:
        delta0: ``(N, d)`` tensor of initial vectors; not modified.
        v_k: ``(N, K, d)`` tensor of constraint normals per row.
        iters: maximum number of projection sweeps.
        tol: a constraint counts as violated when its inner product is below
            ``-tol``.

    Returns:
        ``(N, d)`` tensor with the adjusted vectors.
    """
    delta = delta0.clone()  # (N, d)
    N, K, d = v_k.shape
    if K == 0:
        # No constraints to enforce (torch.min over an empty dim would raise).
        return delta

    idx_n = torch.arange(N, device=delta.device)
    for _ in range(iters):
        G = (v_k * delta.unsqueeze(1)).sum(dim=-1)  # (N, K) inner products
        min_vals, k_idx = torch.min(G, dim=1)       # (N,), (N,)
        if torch.all(min_vals >= -tol):
            break
        vk = v_k[idx_n, k_idx, :]                   # (N, d) most violated normal
        denom = (vk * vk).sum(dim=1) + 1e-12        # (N,)
        # Only rows that violate their constraint are moved; stepping a
        # feasible row (positive min_vals) would drag it onto the boundary.
        violated = min_vals < -tol
        step = torch.where(violated, min_vals / denom, torch.zeros_like(min_vals))
        delta = delta - step.unsqueeze(1) * vk
    return delta



def soft_project_to_support_batch_per_class(
    delta0,
    v_k,
    lam=5.,
    lr=0.1,
    iters=1,
    tol=None
):
    """Take gradient steps on a penalised version of ``<v_k, delta> >= 0``.

    Each step descends the gradient of
    ``||delta - delta0||^2 + (lam / K) * sum_k max(0, -<v_k, delta>)^2``.

    Args:
        delta0: ``(N, d)`` starting vectors; not modified.
        v_k: ``(N, K, d)`` constraint normals per row.
        lam: weight of the violation penalty.
        lr: gradient step size.
        iters: number of gradient steps.
        tol: if given, stop once the largest violation measured before a step
            is below this value.

    Returns:
        ``(N, d)`` tensor after the gradient steps (a copy of ``delta0`` when
        ``K == 0``).
    """
    current = delta0.clone()  # (N,d)
    N, K, d = v_k.shape
    if K == 0:
        return current

    penalty_scale = 2.0 * lam / K
    for _ in range(iters):
        inner = torch.sum(v_k * current.unsqueeze(1), dim=-1)   # (N,K)
        violation = (-inner).clamp(min=0.0)                      # (N,K)

        pull_back = 2.0 * (current - delta0)                     # (N,d)
        weights = penalty_scale * violation                      # (N,K)
        push_out = -(weights.unsqueeze(-1) * v_k).sum(dim=1)     # (N,d)

        current = current - lr * (pull_back + push_out)

        if tol is not None and violation.max().item() < tol:
            break
    return current


def build_orthonormal_basis(v_k, r_max=None, eps=1e-8):
    """Orthonormalise the rows of ``v_k`` with Gram-Schmidt.

    Args:
        v_k: ``(K, d)`` tensor (or array-like) of row vectors, or ``None``.
        r_max: optional cap on the number of basis vectors; the scan stops as
            soon as that many have been collected.
        eps: a residual whose norm is not above ``eps`` is discarded.

    Returns:
        ``(d, r)`` tensor whose columns are the orthonormal vectors, or
        ``None`` when ``v_k`` is ``None`` or no vector survives.
    """
    if v_k is None:
        return None
    if not isinstance(v_k, torch.Tensor):
        v_k = torch.as_tensor(v_k)

    num_vectors, _ = v_k.shape

    ortho = []
    for row in range(num_vectors):
        candidate = v_k[row].clone()
        # Remove the components along every direction found so far.
        for direction in ortho:
            candidate = candidate - torch.dot(candidate, direction) * direction
        length = candidate.norm()
        if length > eps:
            ortho.append(candidate / length)
        if r_max is not None and len(ortho) >= r_max:
            break

    if not ortho:
        return None
    # Stack as columns: (d, r)
    return torch.stack(ortho, dim=1).to(device=v_k.device, dtype=v_k.dtype)


def compute_base_var_for_tails(
    feature_per_class,
    class_mean_tensor,
    weight_diff,
    tail_index,
    device,
    r_max=4,
    eps=1e-8,
):
    """Measure each tail class's feature variance along its weight-difference basis.

    Args:
        feature_per_class: mapping ``cls -> (N_c, d)`` feature tensor.
        class_mean_tensor: tensor indexed as ``[cls, 0]`` to obtain the
            ``(d,)`` mean of a class.
        weight_diff: mapping ``cls -> (K, d)`` tensor of direction vectors.
        tail_index: iterable of tail class ids.
        device: device the computation is moved to.
        r_max: maximum number of basis vectors per class.
        eps: norm threshold forwarded to ``build_orthonormal_basis``.

    Returns:
        ``dict`` mapping class id to an ``(r,)`` tensor of biased variances of
        the centred features projected on the basis. Classes missing from
        ``weight_diff`` or with an empty basis are omitted.
    """
    variances = {}
    for tail_cls in tail_index:
        class_feats = feature_per_class[tail_cls].to(device)        # (N_c, d)
        class_mean = class_mean_tensor[tail_cls, 0].to(device)      # (d,)

        if tail_cls in weight_diff:
            directions = weight_diff[tail_cls].to(device)           # (K, d)
            basis = build_orthonormal_basis(directions, r_max=r_max, eps=eps)
            if basis is not None:
                centred = class_feats - class_mean.unsqueeze(0)      # (N_c, d)
                coords = centred @ basis                             # (N_c, r)
                variances[tail_cls] = coords.var(dim=0, unbiased=False)  # (r,)

    return variances



def dangerous_shrink_lowrank(
    x_sample,
    mu_c,
    v_k,
    base_var=None,
    tau=None,
    gamma_max=0.2,
    eta=1.0,
    J_indices=None,
    r_max=None,
    eps=1e-8,
):
    """Shrink samples towards ``mu_c`` inside the subspace spanned by ``v_k``.

    The residuals ``x_sample - mu_c`` are projected onto an orthonormal basis
    built from ``v_k``; each projected coordinate is reduced by a factor
    ``gamma`` while the remaining component is kept unchanged.

    Args:
        x_sample: ``(B, d)`` samples (tensor or array-like).
        mu_c: ``(d,)`` centre (tensor or array-like).
        v_k: ``(K, d)`` tensor of direction vectors; ``None`` disables the step.
        base_var: optional ``(r,)`` reference variances. When given, ``gamma``
            is ``eta`` times the relative excess variance, clamped to
            ``[0, gamma_max]``; otherwise ``gamma`` is ``gamma_max`` everywhere.
        tau: with ``base_var``, skip shrinking when the total excess variance
            is at most ``tau * (sum(base_var) + eps)``.
        gamma_max: upper bound on the shrink factor.
        eta: scale applied to the relative excess variance.
        J_indices: optional indices of the basis directions to shrink; all
            other directions get a factor of zero.
        r_max: maximum number of basis vectors.
        eps: numerical guard used for the basis and the variance ratio.

    Returns:
        ``(B, d)`` tensor of adjusted samples, or ``x_sample`` itself when
        nothing is shrunk.
    """
    if v_k is None:
        return x_sample

    if not isinstance(x_sample, torch.Tensor):
        x_sample = torch.as_tensor(x_sample)
    if not isinstance(mu_c, torch.Tensor):
        mu_c = torch.as_tensor(mu_c)

    device, dtype = x_sample.device, x_sample.dtype
    mu_c = mu_c.to(device=device, dtype=dtype)

    basis = build_orthonormal_basis(v_k.to(device=device, dtype=dtype), r_max=r_max, eps=eps)
    if basis is None:
        return x_sample

    _batch, _dim = x_sample.shape                 # x_sample must be 2-D
    rank = basis.shape[1]                         # actual dimension of the dangerous subspace
    centre = mu_c.unsqueeze(0)
    residuals = x_sample - centre                 # (B, d)
    coords = residuals @ basis                    # (B, r)
    var_hat = coords.var(dim=0, unbiased=False)   # (r,)

    if base_var is None:
        gamma = torch.full((rank,), gamma_max, device=device, dtype=dtype)
    else:
        base_var = torch.as_tensor(base_var, device=device, dtype=dtype)
        if tau is not None:
            total_var = var_hat.sum()
            total_base = base_var.sum()
            # Total variance did not grow enough: leave the samples untouched.
            if (total_var - total_base) <= tau * (total_base + eps):
                return x_sample
        excess = torch.clamp((var_hat - base_var) / (base_var + eps), min=0.0)
        gamma = torch.clamp(eta * excess, max=gamma_max)

    if J_indices is not None:
        keep = torch.zeros_like(gamma, dtype=torch.bool)
        keep[J_indices] = True
        gamma = gamma * keep.to(dtype)

    shrink_term = (coords * gamma.unsqueeze(0)) @ basis.t()   # (B, d)
    residuals_new = residuals - shrink_term                    # (B, d)

    return centre + residuals_new


def risk_bounded_calibrate(cls_stastic, topk, feature_per_class, head_index, tail_index,
                           data_num_N, weight_diff, project_iters=50, project_tol=1e-9):
    """Calibrate tail-class Gaussians from head classes and sample new features.

    Steps performed:
      1. Transport plan between head and tail classes via ``SinkhornDistance``.
      2. Per-tail mean shift ``delta0`` (transport-weighted head mean minus
         tail mean), softly projected against ``weight_diff`` normals.
      3. Per tail class: blend its covariance with the transport-weighted head
         covariance, sample ``max_sample_size`` features from the resulting
         Gaussian and shrink them along the ``weight_diff`` subspace.

    Args:
        cls_stastic: mapping ``cls -> [singular values, singular vectors
            (ndarray), mean, variance, covariance]``; must contain the ids
            ``0 .. len(cls_stastic) - 1``.
        topk: unused in this function.
        feature_per_class: mapping ``cls -> (N_c, d)`` feature tensor.
        head_index: head class ids (list, ndarray or tensor).
        tail_index: tail class ids (list, ndarray or tensor).
        data_num_N: forwarded to the Sinkhorn module.
        weight_diff: mapping ``tail_cls -> (K, d)`` tensor of constraint normals.
        project_iters: unused while the soft projection is active.
        project_tol: unused while the soft projection is active.

    Returns:
        ``(finetune_feature, finetune_label, origin_feature, origin_lable)``:
        head features followed by the sampled tail features, their labels, the
        concatenation of all original features, and labels sized by the
        original per-class sample counts.

    Raises:
        ValueError: if ``weight_diff`` is ``None``.
    """
    def _to_list(idx):
        if isinstance(idx, torch.Tensor):
            return idx.detach().cpu().tolist()
        if isinstance(idx, np.ndarray):
            return idx.tolist()
        return list(idx)

    # Checked first: weight_diff is already needed for the base-variance step.
    if weight_diff is None:
        raise ValueError("weight_diff is required to build constraint normals v_k.")

    head_index = _to_list(head_index)
    tail_index = _to_list(tail_index)

    head_mean_list, tail_mean_list = [], []
    head_vector, tail_vector = [], []
    for i in head_index:
        head_mean_list.append(cls_stastic[i][2])
        head_vector.append(torch.from_numpy(cls_stastic[i][1].T).unsqueeze(dim=0))
    for j in tail_index:
        tail_mean_list.append(cls_stastic[j][2])
        tail_vector.append(torch.from_numpy(cls_stastic[j][1].T).unsqueeze(dim=0))
    head_mean = torch.cat(head_mean_list, dim=0)
    tail_mean = torch.cat(tail_mean_list, dim=0)
    head_vector = torch.cat(head_vector, dim=0)
    tail_vector = torch.cat(tail_vector, dim=0)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    head_mean = head_mean.to(device=device, dtype=torch.float32)
    tail_mean = tail_mean.to(device=device, dtype=torch.float32)
    head_vector = head_vector.to(device=device, dtype=torch.float32)
    tail_vector = tail_vector.to(device=device, dtype=torch.float32)

    # Build class_mean_tensor and precompute base_var for every tail class.
    num_classes = len(cls_stastic)
    feature_dim = head_mean.shape[1]
    class_mean_tensor = torch.zeros(num_classes, 1, feature_dim,
                                    device=device, dtype=head_mean.dtype)
    for cls_id in range(num_classes):
        class_mean_tensor[cls_id, 0] = cls_stastic[cls_id][2].to(device=device, dtype=head_mean.dtype)
    r_max = 8  # upper bound on the dangerous-subspace dimension; tunable (4 or 8 are both worth trying)
    base_var_dict = compute_base_var_for_tails(
        feature_per_class=feature_per_class,
        class_mean_tensor=class_mean_tensor,
        weight_diff=weight_diff,
        tail_index=tail_index,
        device=device,
        r_max=r_max,
    )

    # Use the selected device instead of a hard-coded 'cuda' so the CPU
    # fallback chosen above actually works.
    OT = SinkhornDistance(eps=2, max_iter=200, dis='cos', reduction=None).to(device)
    _, prob_T, _ = OT(head_mean, head_vector, tail_mean, tail_vector, data_num_N)

    prob_T = prob_T.to(device=device, dtype=torch.float32)
    transport_sum = prob_T.sum(dim=0, keepdim=True).t()
    transport_sum = torch.clamp(transport_sum, min=1e-12)
    ot_target = prob_T.t() @ head_mean
    ot_target = ot_target / transport_sum
    delta0 = ot_target - tail_mean

    # Stack the constraint normals of all tail classes into (tail_num, max_k, d),
    # zero-padding classes that have fewer (or no) normals.
    tail_num = len(tail_index)
    available_k = [diff.shape[0] for diff in weight_diff.values() if diff is not None]
    max_k = max(available_k) if available_k else 1
    v_list = []
    for tail_cls in tail_index:
        diff = weight_diff.get(tail_cls)
        if diff is None:
            diff_tensor = torch.zeros(max_k, feature_dim, device=device, dtype=head_mean.dtype)
        else:
            diff_tensor = diff.to(device=device, dtype=head_mean.dtype)
            if diff_tensor.dim() == 1:
                diff_tensor = diff_tensor.unsqueeze(0)
            if diff_tensor.shape[0] < max_k:
                pad = torch.zeros(max_k - diff_tensor.shape[0], feature_dim, device=device,
                                  dtype=head_mean.dtype)
                diff_tensor = torch.cat([diff_tensor, pad], dim=0)
            elif diff_tensor.shape[0] > max_k:
                diff_tensor = diff_tensor[:max_k]
        v_list.append(diff_tensor)
    v_k = torch.stack(v_list, dim=0) if v_list else torch.zeros(tail_num, max_k, feature_dim,
                                                               device=device, dtype=head_mean.dtype)
    # hard projection (disabled)
    # delta_projected = project_to_support_batch_per_class(delta0, v_k, iters=project_iters, tol=project_tol)
    # soft projection
    delta_projected = soft_project_to_support_batch_per_class(delta0, v_k)

    # Every tail class is resampled to the size of the first head class.
    max_sample_size = feature_per_class[head_index[0]].shape[0]
    finetune_feature = feature_per_class[head_index[0]].to(device=device, dtype=torch.float32)
    finetune_label = torch.tensor([head_index[0]] * max_sample_size, device=device).unsqueeze(dim=1)
    # Sized by the first head class itself (previously class 0's count was
    # used, which is wrong whenever head_index[0] != 0).
    origin_lable = torch.tensor([head_index[0]] * max_sample_size, device=device).unsqueeze(dim=1)
    # NOTE(review): origin_feature follows feature_per_class's iteration order
    # while origin_lable follows head_index then tail_index -- confirm the
    # two orders agree for the callers.
    origin_feature = torch.cat(
        [(torch.from_numpy(v).to(device=device, dtype=torch.float32) if isinstance(v, np.ndarray) else v.to(device=device, dtype=torch.float32))
         for v in feature_per_class.values()], dim=0)

    if len(head_index) > 1:
        for i in range(1, len(head_index)):
            temp_feature = feature_per_class[head_index[i]].to(device=device, dtype=torch.float32)
            temp_label = torch.tensor([head_index[i]] * feature_per_class[head_index[i]].shape[0], device=device).unsqueeze(dim=1)
            finetune_feature = torch.cat([finetune_feature, temp_feature], dim=0)
            finetune_label = torch.cat([finetune_label, temp_label], dim=0)
            origin_lable = torch.cat([origin_lable, temp_label], dim=0)

    for i, tail_cls in enumerate(tail_index):
        tail_mean_cls = cls_stastic[tail_cls][2].to(device)
        tail_cov = cls_stastic[tail_cls][4].to(device)

        # Transport-weighted sum of the head covariances for this tail class.
        cali_cov = torch.zeros_like(tail_cov)
        for head in range(prob_T.shape[0]):
            head_idx = head_index[head]
            head_cov = cls_stastic[head_idx][4].to(device)
            cali_cov += head_cov * prob_T[head, i]

        new_mean = tail_mean_cls + 0.2 * delta_projected[i]
        new_cov = 0.5 * tail_cov + 0.5 * cali_cov
        # Small diagonal jitter added before building the Gaussian.
        new_cov = new_cov + torch.eye(new_cov.size(-1), device=new_cov.device) * 1e-6

        cali_Gauss = torch.distributions.MultivariateNormal(new_mean, new_cov)
        new_tail_feature = cali_Gauss.sample(sample_shape=(max_sample_size,)).squeeze()

        v_k_single = weight_diff.get(tail_cls, None)
        # NOTE(review): base_var is looked up but the shrink step below is
        # called with base_var=None -- confirm this is intended.
        base_var = base_var_dict.get(tail_cls, None)
        if v_k_single is not None:
            new_tail_feature = dangerous_shrink_lowrank(
                x_sample=new_tail_feature,
                mu_c=new_mean.squeeze(),
                v_k=v_k_single,
                base_var=None,
                tau=0.0,
                gamma_max=0.2,
                eta=1.,
                r_max=r_max)
        temp_label_full = torch.tensor([tail_cls] * max_sample_size, device=device).unsqueeze(dim=1)
        temp_label = torch.tensor([tail_cls] * feature_per_class[tail_cls].shape[0], device=device).unsqueeze(dim=1)
        finetune_feature = torch.cat([finetune_feature, new_tail_feature], dim=0)
        finetune_label = torch.cat([finetune_label, temp_label_full], dim=0)
        origin_lable = torch.cat([origin_lable, temp_label], dim=0)
    return finetune_feature, finetune_label, origin_feature, origin_lable








