import os
import clip
import torch
import numpy as np
from tqdm import tqdm
from zs_datasets.CLIPImageDataset import CLIPImageDataset, CLIPCapDataset
from WaveDiffer.backup.load_querying_dataset import load_querying_dataset
from model import AudioCLIP
from utils.transforms import ToTensor1D
from zs_datasets.load_the_test_datasets import load_test_datasets
from tools import convert_models_to_fp32
from agrs_parser import get_args_parser, get_features_paths

def record_logits(val_loader, texts, model, output_path):
    """Extract text and audio/image features with ``model`` and save them as ``.npy`` files.

    The text features are written once to ``<output_path>/text_features.npy``.
    Every sample produced by ``val_loader`` is then encoded and written to its
    own file, whose name carries the running index and the sample's label
    (plus the fold for the audio datasets).

    Note: this function reads the module-level globals ``args`` (for
    ``args.dataset``) and ``device``; both are created in the ``__main__``
    section of this script, so they must exist before the function is called.

    Args:
        val_loader: For ``esc50`` / ``urbansound8k`` an iterable of loaders,
            each yielding ``(audios, target, fold)``; otherwise a single loader
            yielding ``(images, target)``. ``target`` (and ``fold``) are read
            with ``.item()``, so each must hold exactly one element
            (presumably batch size 1 -- confirm against the loader).
        texts: Text prompts handed to the model; passed as ``text=texts`` to the
            audio model, or tokenized with ``clip.tokenize`` for the image model.
        model: The audio model (called as ``model(text=...)`` /
            ``model(audio=...)``) or a CLIP-style model exposing
            ``encode_text`` / ``encode_image``.
        output_path: Directory the ``.npy`` files are written into.
    """
    with torch.no_grad():

        if args.dataset in ("esc50", "urbansound8k"):
            # Audio feature extraction: the model returns nested tuples; only
            # the text slot is needed for the text call.
            ((_, _, text_feature), _), _ = model(text=texts)
            text_feature = text_feature.cpu().detach().numpy()
            save_filename = os.path.join(output_path, "text_features.npy")
            np.save(save_filename, text_feature)

            for a_val_loader in val_loader:
                # NOTE(review): ``i`` restarts for every loader, so file names
                # are only unique if each loader covers a distinct fold -- confirm.
                for i, (audios, target, fold) in enumerate(tqdm(a_val_loader)):
                    audios = audios.to(device)
                    target_name = str(target.item())
                    fold = str(fold.item())
                    ((audio_features, _, _), _), _ = model(audio=audios)
                    # Convert to a CPU NumPy array before saving, exactly like
                    # the text and image paths do; handing a tensor that lives
                    # on a CUDA device to np.save raises a TypeError.
                    audio_features = audio_features.cpu().detach().numpy()
                    save_filename = "fold_{}_no_{}_label_{}.npy".format(fold, i, target_name)
                    save_path = os.path.join(output_path, save_filename)
                    np.save(save_path, audio_features)
        else:
            # Image feature extraction with a CLIP-style model.
            text_tokens = clip.tokenize(texts).to(device)
            text_feature = model.encode_text(text_tokens)
            text_feature = text_feature.cpu().detach().numpy()
            save_filename = os.path.join(output_path, "text_features.npy")
            np.save(save_filename, text_feature)

            for i, (images, target) in enumerate(tqdm(val_loader)):
                images = images.to(device)
                target_name = str(target.item())
                image_feature = model.encode_image(images).cpu().detach().numpy()
                save_filename = "no_{}_label_{}.npy".format(i, target_name)
                save_path = os.path.join(output_path, save_filename)
                np.save(save_path, image_feature)
        print("done!")

if __name__ == '__main__':
    # Script entry point: parse the arguments, build the model, load the test
    # data and dump the extracted features to disk.
    # `args` and `device` stay module-level names because record_logits reads them.
    args = get_args_parser()
    output_path = get_features_paths(args)

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    # Build the backbone: any VLM_Base containing "audio" selects AudioCLIP,
    # everything else is loaded through clip.load.
    if "audio" in args.VLM_Base:
        # "audio_full" picks the fully trained checkpoint; any other audio
        # variant falls back to the partially trained one.
        checkpoint = (
            "/cache/AudioClip/AudioCLIP-Full-Training.pt"
            if args.VLM_Base == "audio_full"
            else "/cache/AudioClip/AudioCLIP-Partial-Training.pt"
        )
        model = AudioCLIP(pretrained=checkpoint)
        model.eval()
        preprocess = ToTensor1D()  # transform applied to the audio input
    else:
        model, preprocess = clip.load(args.VLM_Base, device, jit=False)
        convert_models_to_fp32(model)
    print("successfully load the model")

    # Load the evaluation data together with its text prompts, then run the
    # feature extraction.
    loaders, prompts = load_test_datasets(args, preprocess)
    record_logits(loaders, prompts, model, output_path)



