import numpy as np
import torch
from torchvision import transforms
from stylegan2.makegan import make_gan
# from pytorch_pretrained_gans import make_gan
import PIL.Image
import sys
import os
import json
from pytorch_lightning.utilities.seed import seed_everything
from contra_single import CNNEncoder,test
import torch.nn as nn
sys.path.append('/home/byzeng/project/pytorch-pretrained-gans-main/pytorch_pretrained_gans')
def load_normalize_testdata(path):
    """Load a ``.npy`` array and z-score every slice along its first axis.

    Each slice ``x[i]`` is shifted to zero mean and scaled to unit standard
    deviation (``torch.std`` default, i.e. Bessel-corrected), computed over all
    elements of that slice.

    Args:
        path: path of a ``.npy`` file with at least one dimension.

    Returns:
        Tensor of shape ``(1, *x.shape)`` (a new leading axis of size 1).
        Integer/bool input is promoted to ``float32``; floating and complex
        input keeps its dtype.

    Note:
        A slice with zero variance (or a single element) divides by zero and
        yields NaN/inf, exactly as before.
    """
    data = torch.from_numpy(np.load(path))
    if not (data.is_floating_point() or data.is_complex()):
        # torch.mean/std reject integer tensors, and writing normalized values
        # back into an integer tensor would truncate them.
        data = data.float()
    # View each first-axis slice as one flat row so all rows are normalized in
    # a single vectorized reduction instead of a Python loop.
    rows = data.unsqueeze(1) if data.dim() == 1 else data.flatten(start_dim=1)
    mean = rows.mean(dim=1, keepdim=True)
    std = rows.std(dim=1, keepdim=True)
    normalized = ((rows - mean) / std).reshape(data.shape)
    return normalized.unsqueeze(0)

def convert_to_images(obj):
    """ Turn a batch of generator outputs into Pillow images.
        Params:
            obj: tensor or numpy array laid out as (batch_size, channels, height, width).
                 Values are mapped as if they lie in [-1, 1]; anything outside
                 that range is clipped.
        Output:
            list holding one PIL image per batch entry, each height x width pixels
    """
    try:
        import PIL
    except ImportError:
        raise ImportError("Please install Pillow to use images: pip install Pillow")

    array = obj if isinstance(obj, np.ndarray) else obj.detach().cpu().numpy()

    # Channels-first (NCHW) -> channels-last (NHWC), the layout PIL expects.
    channels_last = np.transpose(array, (0, 2, 3, 1))
    # Stretch [-1, 1] onto [0, 256] and clip to 255; combined with the
    # truncating uint8 cast this gives 256 equal-width quantisation bins.
    scaled = np.clip((channels_last + 1) / 2.0 * 256, 0, 255)

    return [PIL.Image.fromarray(np.uint8(frame)) for frame in scaled]
# Keys compared against ``llama-7b-hf`` and scored under 'llama'.
_LLAMA_DERIVATIVES = (
    'koala-7b', 'baize-v2-7b', 'medalpaca-7b', 'alpaca-native',
    'MiniGPT-4-LLaMA-7B', 'wizardLM-7B-HF', 'chinese-alpaca-7b-merged',
    'chinese-llama-7b-merged', 'alpaca-lora-7b', 'vicuna-7b-v1', 'beaver-7b-v1',
)
# Key pairs scored under 'base_sft' (by name, a base model and its tuned/chat
# counterpart; order within a pair does not matter, cosine is symmetric).
_BASE_SFT_PAIRS = (
    ('falcon-40b-instruct', 'falcon-40b'),
    ('Qwen-7B-Chat', 'Qwen-7B'),
    ('mpt-30b-chat', 'mpt-30b'),
    ('Llama-2-7B-fp16', 'Llama-2-7b-chat-fp16'),
    ('Baichuan-13B-Chat', 'Baichuan-13B-Base'),
    ('internlm-7b', 'internlm-chat-7b'),
)
# Keys scored under 'diff'; every unordered pair among them is compared.
_UNRELATED_MODELS = (
    'THUDM_chatglm2-6b', 'Cerebras-GPT-1', 'gpt-j-6b', 'Baichuan-13B-Base',
    'internlm-7b', 'gptneox_seed3', 'galactica-30b', 'huggyllama_llama-65b',
    'gpt-neo-2', 'open_llama_7b', 'THUDM_chatglm-6b', 'gptneox_seed2',
    'gptneox_seed4', 'llama-7b-hf', 'Qwen-7B', 'gpt2-large', 'falcon-40b',
    'huggyllama_llama-30b', 'opt-30b', 'gpt-neox-20b', 'pythia-12b',
    'huggyllama_llama-13b', 'gptneox_seed1', 'Llama-2-7B-fp16', 'bloom-7b1',
    'baichuan-7B', 'mpt-30b',
)


def get_cosine_similarity(features, *, llama_base='llama-7b-hf',
                          llama_variants=_LLAMA_DERIVATIVES,
                          base_sft_pairs=_BASE_SFT_PAIRS,
                          unrelated_models=_UNRELATED_MODELS):
    """Score cosine similarity between model fingerprint features.

    Args:
        features: mapping from model name to feature tensor. Similarity is
            computed along ``dim=1`` and read with ``.item()``, so each
            comparison must reduce to a single element (e.g. tensors of
            shape ``(1, D)``).
        llama_base: key that every entry of ``llama_variants`` is compared to.
        llama_variants: keys scored under ``'llama'``.
        base_sft_pairs: ``(a, b)`` key pairs scored under ``'base_sft'``.
        unrelated_models: keys whose unordered pairs are scored under
            ``'diff'`` as the ABSOLUTE cosine similarity.

    Returns:
        A new dict with, in this order: ``'llama'`` (list), ``'mean_llama'``,
        ``'base_sft'`` (list), ``'mean_base_sft'``, ``'diff'`` (list), one
        entry per unrelated pair keyed by the two names concatenated without a
        separator, and ``'mean_diff'``. A mean over an empty list is NaN.

    Raises:
        KeyError: if a required model name is missing from ``features``.
    """
    from itertools import combinations

    def similarity(a, b):
        return torch.cosine_similarity(features[a], features[b], dim=1).item()

    def average(values):
        return sum(values) / len(values) if values else float('nan')

    # Build a fresh dict per call rather than appending into a module-level
    # ``cos``: that global only exists once the ``__main__`` section has run,
    # and repeated calls would keep growing its lists.
    cos = {}
    cos['llama'] = [similarity(llama_base, key) for key in llama_variants]
    cos['mean_llama'] = average(cos['llama'])
    cos['base_sft'] = [similarity(a, b) for a, b in base_sft_pairs]
    cos['mean_base_sft'] = average(cos['base_sft'])
    cos['diff'] = []
    for a, b in combinations(unrelated_models, 2):
        value = abs(similarity(a, b))
        cos['diff'].append(value)
        # Per-pair key: the two names concatenated with no separator.
        cos[a + b] = value
    cos['mean_diff'] = average(cos['diff'])
    return cos
def save_as_images(obj, file_name='output'):
    """ Convert a generator output batch to images and save each one as a PNG.
        Params:
            obj: tensor or numpy array of shape (batch_size, channels, height, width);
                 converted with ``convert_to_images``.
            file_name: path prefix of the output files.
                Image ``i`` is saved as `{file_name}_{i}.png`
    """
    import logging

    # Use a locally obtained logger: the original referenced a module-level
    # ``logger`` that this script never defines, raising NameError on the
    # first image.
    log = logging.getLogger(__name__)
    for index, image in enumerate(convert_to_images(obj)):
        current_file_name = f"{file_name}_{index}.png"
        log.info("Saving image to %s", current_file_name)
        image.save(current_file_name, 'png')
def generate_model_images(model_matrix, png_path, seed, generator=None):
    """Render an encoder feature tensor as an image and save it as PNG.

    The tensor is z-scored over all its elements and fed to the generator's
    mapping network with the wrapper's truncation settings. The result is
    passed to the generator, and the first image of the output batch is
    written to ``png_path + '.png'``.

    Args:
        model_matrix: feature tensor used as the mapping network's input
            (assumed to already be on the generator's device -- TODO confirm).
        png_path: output path without the ``.png`` extension.
        seed: passed to ``seed_everything`` before generation, presumably so
            any stochastic parts of the generator are reproducible.
        generator: GAN wrapper exposing ``.G.mapping``, ``.truncation_psi``
            and ``.truncation_cutoff`` and callable as ``generator(z=...)``.
            Defaults to the module-level ``G`` created in the ``__main__``
            section, which is what this function always used implicitly.
    """
    seed_everything(seed)
    if generator is None:
        # Backward-compatible fallback to the script-level global generator.
        generator = G
    mean = torch.mean(model_matrix)
    std = torch.std(model_matrix)
    latent = (model_matrix - mean) / std
    mapped = generator.G.mapping(latent, None,
                                 truncation_psi=generator.truncation_psi,
                                 truncation_cutoff=generator.truncation_cutoff)
    # The mapped latents are passed back through the wrapper's ``z`` argument;
    # the original noted an output of torch.Size([1, 3, 128, 128]).
    images = generator(z=mapped)
    convert_to_images(images)[0].save(png_path + '.png')
# hiden_size =512
# for i in range(15):
#     CNNencoder = CNNEncoder().cuda()
#     data_parallel_model = torch.load('/home/byzeng/project/weights-search/encoder_512_'+str(i)+'.pth')
#     if isinstance(data_parallel_model, nn.DataParallel):
#         data_parallel_model = data_parallel_model.module.state_dict()
#     CNNencoder.load_state_dict(data_parallel_model)
#     test(CNNencoder)
if __name__ == '__main__':
    # Restrict the run to one GPU; these must be set before the first CUDA call.
    os.environ['CUDA_VISIBLE_DEVICES'] = '7'
    os.environ['NCCL_P2P_DISABLE'] = '1'
    seed_everything(100)
    # Candidate encoder checkpoints tried previously (kept for reference):
    # CNNencoder = CNNEncoder().cuda()
    # model_name = 'random'
    # CNNencoder = torch.load("/home/byzeng/project/weights-search/encoder/encoder_new3_512_16.pth")#6,4,13,13,16 is good
    # path="/home/byzeng/project/weights-search/encoder/encoder_new3_512_16.pth"
    # path="/home/byzeng/project/weights-search/encoder/encoder_gan8_512_21.pth"
    # path="/home/byzeng/project/weights-search/encoder/encoder_gan8_512_25.pth"
    # path="/home/byzeng/project/weights-search/encoder/encoder_gan8_512_28.pth"
    # path="/home/byzeng/project/weights-search/encoder/encoder_gan2_512_0.pth"
    # path="/home/byzeng/project/weights-search/encoder/encoder_gan_512_0.pth"
    # path="/home/byzeng/project/weights-search/encoder/encoder_5gan4096_512_0.pth"
    # path="/home/byzeng/project/weights-search/encoder/encoder_6gan4096_512_9.pth"
    # path="/home/byzeng/project/weights-search/encoder/encoder_6gan4096_512_11.pth"
    # path="/home/byzeng/project/weights-search/goodencoders/encoder_k48_l16_7.pth"
    # path="/home/byzeng/project/weights-search/goodencoders/encoder_n0.2_k48_p1.0_lr0.0001_0_26.pth"
    # path="/home/byzeng/project/weights-search/goodencoders/encoder_n0.4_k48_p1.3_lr0.0001_4_26.pth"
    # path="/home/byzeng/project/weights-search/goodencoders/encoder_n0.4_k48_p1.3_lr0.0001_3_26.pth" best
    # path="/home/byzeng/project/weights-search/goodencoders/encoder_n0.35_k48_p1.3_lr0.0001_0_26.pth" best
    # path="/home/byzeng/project/weights-search/goodencoders/encoder_n0.4_k48_p1.3_lr0.0001_9_26.pth"#bestbest
    # path="/home/byzeng/project/weights-search/newencoder/encoder_n0.4_k48_p1.4_lr5e-05_2_21_10.pth"
    # path="/home/byzeng/project/weights-search/goodencodersnew/encoder_n0.4_k48_p1.4_lr0.0001_7_21_5.pth"#best
    path="/home/byzeng/project/weights-search/newencoder/encoder_n0.4_k48_p1.4_lr0.0001_7_21_10.pth"#best
    # path="/home/byzeng/project/weights-search/goodencodersnew/encoder_n0.4_k48_p1.4_lr0.0001_10_21_5.pth"
    # path="/home/byzeng/project/weights-search/goodencodersnew/encoder_n0.4_k48_p1.3_lr0.0001_16_21_10.pth"
    model_name = os.path.basename(path)  # file name, including extension
    model_name = os.path.splitext(model_name)[0]  # strip the extension
    # NOTE(review): torch.load unpickles a whole module object, so only load
    # trusted checkpoints. PyTorch >= 2.6 defaults to weights_only=True, which
    # rejects full-module pickles; pass weights_only=False there (TODO confirm
    # the installed torch version first -- older releases lack the argument).
    CNNencoder = torch.load(path)#6,4,13,13,16 is good
    # Keep the generator as the module-level name ``G``: generate_model_images
    # uses this global by default.
    # G=make_gan(gan_type='stylegan2',model_name='afhq').to('cuda')
    G = make_gan(gan_type='stylegan2', model_name='afhqdog').to('cuda')
    # G=make_gan(gan_type='stylegan2',model_name='afhqcat').to('cuda')
    # Folder holding one .npy weight file per model.
    folder_path = "/home/byzeng/project/weights-search/inputweightsxxnew/"
    # List the folder once (sorted for a deterministic order) so each path is
    # paired with its own file name; zipping a filtered list with a second,
    # unfiltered os.listdir() misaligns them when non-.npy files are present.
    npy_names = sorted(name for name in os.listdir(folder_path) if name.endswith("npy"))
    model_files = [os.path.join(folder_path, name) for name in npy_names]
    features = {}
    CNNencoder.eval()
    # Module-level results dict; initialised here for any get_cosine_similarity
    # implementation that reads a global ``cos``, then rebound to its result.
    cos = {}
    with torch.no_grad():
        for filename, modelname in zip(model_files, npy_names):
            # The key is the text before the FIRST dot, so e.g. "gpt-neo-2.7B.npy"
            # becomes 'gpt-neo-2' (presumably intended: the scoring lists use such keys).
            model_key = modelname.split('.')[0]
            model_matrix = load_normalize_testdata(filename)
            # One more leading axis: encoder input is (1, 1, *weights.shape).
            model_matrix = CNNencoder(model_matrix.unsqueeze(0).cuda())
            mean = torch.mean(model_matrix)
            std = torch.std(model_matrix)
            features[model_key] = ((model_matrix - mean) / std).cpu()
            generate_model_images(model_matrix,
                                  '/home/byzeng/project/weights-search/fingerprints/' + model_key,
                                  seed=100)
        # To also save the features:
        # np.save(f"/home/byzeng/project/weights-search/{model_name}_features.npy",
        #         [t.numpy() for t in features.values()])
        cos = get_cosine_similarity(features)
        print('cos_llama:%.3f   cos_sft:%.3f    cos_diff:%.3f'%(cos['mean_llama'],cos['mean_base_sft'],cos['mean_diff']))
        with open(f"/home/byzeng/project/weights-search/{model_name}_cos.json", 'w') as cos_file:
            json.dump(cos, cos_file)