import torch
from encodec import EncodecModel
import librosa
import os
import fire
import json
import numpy as np

from valle.data import (
    AudioTokenConfig,
    AudioTokenExtractor,
    AudioTokenConfig_16k,
    AudioTokenExtractor_16k,
    AudioTokenExtractor_16k_tfcodec,
    TextTokenizer,
    tokenize_text,
    ApplyKmeans,
    HubertFeatureReader
)

# HuBERT checkpoint and the layer index handed to HubertFeatureReader; features
# from this layer are what get quantized into semantic tokens.
ckpt_path = "/dev_huaying/zhijun/models/hubert/hubert_base_ls960.pt"
layer = 9
# k-means model that maps layer features to discrete token ids
# (presumably 500 clusters fit on layer-9 features, per the file name — confirm).
km_path = "/dev_huaying/zhijun/models/hubert/hubert_base_ls960_L9_km500.bin"
# Feature reader and quantizer shared by computer_semantic(). Both are
# constructed at import time, so merely importing this module needs these
# paths to be valid (assumes the constructors load the files eagerly — confirm).
reader = HubertFeatureReader(ckpt_path, layer)
apply_kmeans = ApplyKmeans(km_path)

# Filled by get_dic(): wav file name -> {speaker: list of token ids}.
# Despite the name, it holds the semantic tokens from computer_semantic().
acoustics_dic = {}

def computer_semantic(wav_path, depup=False):
    """Extract discrete semantic token ids for one audio file.

    Features are read with the module-level ``reader`` and quantized with the
    module-level ``apply_kmeans``. The resulting id list is printed and
    returned.

    Args:
        wav_path: Path of the audio file, passed to ``reader.get_feats``.
        depup: If truthy, drop repeated token ids so that each id appears
            only once, keeping the order of first appearance.

    Returns:
        list: The token ids produced by ``apply_kmeans(...).tolist()``.
    """
    feat = reader.get_feats(wav_path)
    lab = apply_kmeans(feat).tolist()

    # Truthiness rather than ``is True`` so non-bool CLI values (e.g. 1) work
    # and agree with get_dic's choice of output file.
    if depup:
        # dicts preserve insertion order, so this keeps the first occurrence
        # of each id in O(n) instead of the previous O(n^2) list scan.
        # Assumes a 1-D label array, i.e. hashable ints in ``lab``.
        # NOTE(review): this removes *all* repeats, not just consecutive runs
        # ([1, 1, 2, 1] -> [1, 2]). Unit "dedup" usually collapses only
        # consecutive repeats ([1, 2, 1]); confirm which one is intended.
        lab = list(dict.fromkeys(lab))
    print(lab)

    return lab


def get_dic(folder_path, depup=False, output_path=None):
    """Tokenize every ``.wav`` file under *folder_path* and dump the tokens to JSON.

    The directory tree is walked recursively. For each ``.wav`` file, the
    speaker id is the name of the parent of the directory holding the file
    (e.g. ``<corpus>/<spk>/<subdir>/x.wav`` -> ``<spk>``). Tokens from
    ``computer_semantic`` are stored as ``acoustics_dic[file_name][spk]``.

    Results accumulate in the module-level ``acoustics_dic``, which is not
    cleared between calls. The whole dict is written out at the end.

    Args:
        folder_path: Root directory to search for ``.wav`` files.
        depup: If truthy, deduplicate tokens (see ``computer_semantic``) and
            write to the ``_depup`` output file by default.
        output_path: JSON file to write. Defaults to the original hard-coded
            location, with a ``_depup`` suffix when *depup* is truthy.
    """
    if output_path is None:
        suffix = "_depup" if depup else ""
        output_path = (
            "/mnt/zhijun/Accents/combine_L1_L2/acoustic_tokens_dic/"
            f"native_l1_l2_arctic_semantic_dic_v2{suffix}.json"
        )

    for root, _dirs, files in os.walk(folder_path):
        # normpath drops a trailing separator (possible when folder_path
        # ends with "/"), which would otherwise shift the [-2] component.
        path_parts = os.path.normpath(root).split(os.sep)
        for file_name in files:
            if not file_name.endswith(".wav"):
                continue
            file_path = os.path.join(root, file_name)
            spk = path_parts[-2]
            print(file_path)

            if file_name not in acoustics_dic:
                print(f"{file_name} not in acoustics_dic")
            # The same file name can occur under several speakers, so each
            # file maps to a per-speaker sub-dict.
            acoustics_dic.setdefault(file_name, {})[spk] = computer_semantic(
                file_path, depup
            )

    with open(output_path, "w", encoding="utf-8") as json_file:
        json.dump(acoustics_dic, json_file)

if __name__ == "__main__":
    # Command-line entry point: python-fire maps CLI arguments onto get_dic's
    # parameters, e.g. `python <script> /path/to/wavs --depup`.
    fire.Fire(get_dic)

