import imp
import os
import torch
import numpy as np
import torch.nn.functional as F
import shutil
import pdb


def adjust_learning_rate(optimizer, epoch, args):
    """Apply step decay: the base LR ``args.lr`` is divided by 10 every 30 epochs.

    The resulting learning rate is printed and then written in place into
    every parameter group of ``optimizer``.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` key), e.g. a ``torch.optim.Optimizer``.
        epoch: current epoch number (integer division by 30 picks the stage).
        args: namespace providing the base learning rate as ``args.lr``.
    """
    decay = 0.1 ** (epoch // 30)
    new_lr = args.lr * decay
    print('lr = ', new_lr)
    for group in optimizer.param_groups:
        group['lr'] = new_lr


def read_directory(directory_name):
    """Return the joined path of every entry directly inside ``directory_name``.

    Entries are returned in ``os.listdir`` order (not sorted) and include
    sub-directories as well as files.
    """
    return [os.path.join(directory_name, entry) for entry in os.listdir(directory_name)]


def KL(quant, float):
    """Return ``F.kl_div(log|quant|, |float|, reduction='mean')`` as a Python float.

    Both tensors are reduced to absolute values. ``float`` (the reference
    output) is the target and the log of ``|quant|`` is the input, so each
    element contributes ``|f| * (log|f| - log|q|)``, averaged over all
    elements.

    NOTE(review):
      * The magnitudes are not normalised to sum to 1, so this is an averaged
        element-wise KL term rather than a KL between probability distributions.
      * A zero in ``quant`` gives ``log(0) = -inf``, so the result may be
        ``inf`` or ``nan``.
      * PyTorch documents ``reduction='mean'`` as not matching the mathematical
        KL definition (``'batchmean'`` does); kept here for metric continuity.
      * The parameter name ``float`` shadows the builtin; it is kept so that
        keyword callers keep working.
    """
    log_quant = torch.abs(quant).log()
    reference = torch.abs(float)
    divergence = F.kl_div(log_quant, reference, reduction='mean')
    return divergence.item()


def KL_savefile(args, samples_id, data_percentage=1):
    """Average the KL metric over matching float/quantized result files of one sample.

    Files sharing the same name in ``<args.float_result_folder_path>/<samples_id>``
    and ``<args.quant_result_folder_path>/<samples_id>`` are paired. For each
    pair, both arrays are loaded with ``np.load``, flattened and converted to
    absolute values. Then ``F.kl_div(log|quant|, |float|, reduction='mean')`` is
    accumulated. The mean over all pairs is appended as
    ``"<samples_id>: <value>"`` to ``args.KL_path`` and returned.

    Args:
        args: object with ``KL_path``, ``float_result_folder_path`` and
            ``quant_result_folder_path`` attributes.
        samples_id: sub-folder name under both result folders; also used as
            the label of the line written to ``args.KL_path``.
        data_percentage: unused; kept for backward compatibility. It was only
            referenced by a disabled sanity check on the number of pairs.

    Returns:
        The mean KL value over all matched file pairs, as a Python float.

    Raises:
        ValueError: if no file name appears in both folders (previously this
            surfaced as a bare ``ZeroDivisionError``).
    """
    float_dir = args.float_result_folder_path + f'/{samples_id}'
    quant_dir = args.quant_result_folder_path + f'/{samples_id}'

    # Index float results by file name once. This replaces the O(n*m) nested
    # scan and avoids splitting on '/', which fails with other OS separators.
    float_files = {name: os.path.join(float_dir, name) for name in os.listdir(float_dir)}

    kl = 0
    count = 0
    for name in os.listdir(quant_dir):
        float_file = float_files.get(name)
        if float_file is None:
            continue
        # NOTE(review): allow_pickle=True can execute code embedded in a crafted
        # .npy file; only use this on trusted result folders.
        x_tensor = torch.abs(torch.from_numpy(np.load(float_file, allow_pickle=True).flatten()))
        y_tensor = torch.abs(torch.from_numpy(
            np.load(os.path.join(quant_dir, name), allow_pickle=True).flatten()))
        kl = kl + F.kl_div(y_tensor.log(), x_tensor, reduction='mean')
        count += 1

    if count == 0:
        raise ValueError(
            f'No matching result files for samples_id={samples_id!r} '
            f'in {float_dir!r} and {quant_dir!r}')

    kl_value = (kl / count).item()
    # Open the log only once there is a result, and always close it.
    with open(args.KL_path, "a") as kl_file:
        kl_file.write('{}: {}\n'.format(samples_id, kl_value))
    return kl_value



def save_npy(folder_path, value, samples_id, img_id, idx):
    '''Save one captured output to ``<folder_path>/<samples_id>/`` as a .npy file.

    The file is named ``samples_id_<samples_id>_imgid_<img_id>_decoder_layer<idx>.npy``.

    Args:
        folder_path: root directory; the ``<samples_id>`` sub-folder is created
            if it does not exist yet.
        value: if it has an ``out`` attribute, ``value.out`` is detached, moved
            to CPU and saved (assumes ``out`` is a torch.Tensor, e.g. stored by a
            forward hook; TODO confirm). Otherwise ``value`` itself is passed to
            ``np.save``.
        samples_id: sample identifier (sub-folder and file-name component).
        img_id: image identifier used in the file name.
        idx: the index of input image (decoder layer index).
    '''
    folder_path = folder_path + f'/{samples_id}'
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(folder_path, exist_ok=True)

    file_path = '{}/samples_id_{}_imgid_{}_decoder_layer{}.npy'.format(
        folder_path, samples_id, img_id, idx)
    if hasattr(value, 'out'):
        # Detach first so the device-to-host copy is made outside autograd.
        np.save(file_path, value.out.detach().cpu().numpy())
    else:
        # The message means: "the passed-in module `value` has no self.out";
        # the object is then saved as-is.
        print('传入的module: value没有self.out')
        np.save(file_path, value)