import torch
import transformers
import json
from k_means_constrained import KMeansConstrained
import numpy as np
from tqdm import tqdm




# def matrix_mask(param, sparsity_ratio):
#     with torch.no_grad():
#         param_values = param.data.view(-1)
#         param_abs = (torch.abs(param_values))

#         total_num = param_abs.numel()
        
    
#         k = int((1 - sparsity_ratio) * total_num)
#         threshold = torch.topk(param_abs, k).values[-1]
        
        
#         mask = (torch.abs(param.data) >= threshold)
#         param *= mask
      


def kmeans(matrix0, num_clusters=4, top_k=16, name = None, log_data = None):
    """Cluster the rows of ``matrix0`` and gather the rows nearest each centre.

    The rows are clustered with ``KMeansConstrained`` using ``top_k`` as the
    minimum cluster size. For every cluster, the ``top_k`` member rows with the
    smallest Euclidean distance to that cluster's centre are selected.

    Args:
        matrix0: 2-D tensor whose rows are the vectors to cluster.
        num_clusters: Number of clusters to form.
        top_k: Number of rows selected per cluster; also passed as ``size_min``.
        name: Key under which the selected row indices are recorded.
        log_data: Optional dict; when given, ``log_data[name]`` is set to a
            list (one entry per cluster) of the selected row indices. When
            ``None`` nothing is recorded.

    Returns:
        Tensor of shape ``(num_clusters, top_k, dim)`` made of rows of
        ``matrix0``.
    """
    # Clustering runs in NumPy. ``detach`` is required because ``cpu()`` and
    # ``to(float32)`` are no-ops for a float32 CPU tensor, and ``numpy()``
    # refuses tensors that still require grad.
    matrix_np = matrix0.detach().cpu().to(torch.float32).numpy()

    # Named ``clusterer`` so the local does not shadow this function's name.
    clusterer = KMeansConstrained(
        n_clusters=num_clusters, size_min=top_k, max_iter=50, random_state=42
    )
    clusterer.fit(matrix_np)

    cluster_centers = clusterer.cluster_centers_
    labels = clusterer.labels_

    print("聚类标签分布:", np.bincount(labels))

    # Distance of every row to the centre of the cluster it was assigned to.
    distances = np.linalg.norm(matrix_np - cluster_centers[labels], axis=1)

    selected_vectors = []
    indices = []
    for cluster_idx in range(num_clusters):
        cluster_indices = np.where(labels == cluster_idx)[0]
        # Keep the members of this cluster that sit closest to its centre.
        nearest = cluster_indices[np.argsort(distances[cluster_indices])[:top_k]]

        # Index with a tensor on the same device as ``matrix0`` so the rows are
        # taken from the original tensor (original dtype and device).
        row_ids = torch.as_tensor(nearest, dtype=torch.long, device=matrix0.device)
        selected_vectors.append(matrix0[row_ids])  # shape: (top_k, dim)
        indices.append(nearest.tolist())

    # (num_clusters, top_k, dim)
    result = torch.stack(selected_vectors)
    if log_data is not None:
        log_data[name] = indices
    print("结果形状:", result.shape)

    return result


import os

def copyparam(model, file):
    """Fill LoRA-A expert weights with base-layer rows listed in a JSON file.

    ``file`` must contain a JSON object mapping each base-layer weight name to
    a list of row-index lists, one list per expert (the structure that
    ``copyfromdense`` dumps). Parameters are visited in ``named_parameters()``
    order: each non-bias ``base_layer`` parameter defines the candidate rows,
    and every following ``lora_A.default.expert_<i>`` parameter is overwritten
    with the rows of expert ``i`` and frozen.

    Args:
        model: Module exposing ``named_parameters()``.
        file: Path to the JSON index file.

    Raises:
        KeyError: A base-layer parameter name is missing from the index file.
        ValueError: A LoRA-A parameter is met before any base layer, or its
            name carries no numeric expert index.
    """
    import re

    with open(file, 'r') as f:
        paramidx = json.load(f)

    expert_rows = None  # rows per expert for the most recent base layer
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'base_layer' in name and 'bias' not in name:
                # One entry per expert stored in the file, rather than a
                # hard-coded count of four.
                expert_rows = [param[rows, :] for rows in paramidx[name]]

            elif 'lora_A.default' in name:
                if expert_rows is None:
                    raise ValueError(
                        "LoRA parameter {} appears before any base_layer weight".format(name)
                    )
                # Read every leading digit so that expert_10 is expert 10,
                # not expert 1.
                suffix = name.split("lora_A.default.expert_")[1]
                digits = re.match(r"\d+", suffix)
                if digits is None:
                    raise ValueError("No expert index found in parameter name {}".format(name))
                selected_columns = int(digits.group())

                param.copy_(expert_rows[selected_columns])
                param.requires_grad = False
               
            

def copyfromdense(model, randomseed = 42, indexfile = './copy_from_dense_log.json', rank = 128, use_cache = False):
    """Initialise LoRA-A expert weights from rows of the dense base layers.

    With ``use_cache=True`` the row indices are read from ``indexfile`` and
    applied through ``copyparam``. Otherwise every non-bias ``base_layer``
    parameter is clustered with ``kmeans`` (``top_k=rank``), each following
    ``lora_A.default.expert_<i>`` parameter is overwritten with the rows
    selected for cluster ``i`` and frozen, and the selected indices are
    written to ``indexfile`` as JSON.

    Args:
        model: Module exposing ``named_parameters()``.
        randomseed: Seed passed to ``torch.manual_seed`` before clustering.
        indexfile: JSON file read (cached mode) or written (clustering mode).
        rank: Number of rows selected per cluster; passed to ``kmeans`` as
            ``top_k``.
        use_cache: Reuse indices from ``indexfile`` instead of clustering.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        ValueError: ``use_cache`` is set but ``indexfile`` does not exist, a
            LoRA-A parameter is met before any base layer, or a LoRA-A
            parameter name carries no numeric expert index.
    """
    import re

    if use_cache:
        if not os.path.exists(indexfile):
            raise ValueError("The cluster cache file does not exist: {}".format(indexfile))
        print("copy lora params from {}, rank = {}".format(indexfile, str(rank)))
        copyparam(model, indexfile)
        return model

    log_data = {}
    cweights = None  # selected rows for the most recent base layer

    with torch.no_grad():
        torch.manual_seed(randomseed)

        for name, param in tqdm(model.named_parameters()):
            if 'base_layer' in name and 'bias' not in name:
                print(name, param.size())
                cweights = kmeans(param, top_k=rank, name=name, log_data=log_data)

            elif 'lora_A.default' in name:
                print(name)
                if cweights is None:
                    raise ValueError(
                        "LoRA parameter {} appears before any base_layer weight".format(name)
                    )
                # Read every leading digit so that expert_10 is expert 10,
                # not expert 1.
                suffix = name.split("lora_A.default.expert_")[1]
                digits = re.match(r"\d+", suffix)
                if digits is None:
                    raise ValueError("No expert index found in parameter name {}".format(name))
                selected_columns = int(digits.group())
                print(selected_columns)

                # cweights[i] is already (top_k, dim); no squeeze is needed.
                param.copy_(cweights[selected_columns])
                param.requires_grad = False

    # Persist the selected row indices so later runs can pass use_cache=True.
    with open(indexfile, 'w') as f:
        json.dump(log_data, f)

    return model



if __name__ == '__main__':
    # Script entry point: load the causal LM in bfloat16, wrap it with LoRA
    # adapters, then fill the adapters from the cached index file.
    base_model = transformers.AutoModelForCausalLM.from_pretrained(
        '/hy-tmp/mistral',
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    from peft import LoraConfig, get_peft_model
    from peft.tuners.lora import LoraLayer

    # Projection modules that receive a LoRA adapter.
    lora_targets = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "down_proj",
        "up_proj",
    ]
    config = LoraConfig(
        r=128,
        lora_alpha=256,
        target_modules=lora_targets,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # The model is already loaded in bfloat16, so no extra cast is applied.
    model = get_peft_model(base_model, config)

    copyfromdense(
        model,
        indexfile='/hy-tmp/copy_from_dense_log.json',
        use_cache=True,
    )
    
