import os.path

from model import AudioCLIP
from Wavelet_method import get_avg, subtract_wavelet_coefficients
import random
import numpy as np
from tqdm import tqdm
from tools import convert_models_to_fp32, compute_sim_np, norm_np, return_top1_result, return_top5_result, \
    read_features, compute_sim_np_multiview, read_audio_features, set_random_seed, show_settings
from agrs_parser import get_args_parser, get_features_paths
import torch
# Directory where per-class top-1 accuracies are written by conduct_sims_computation.
CLASS_LEVEL_ACC_DIR = "/cache/WaveDNClassLevelAcc"


def _sample_statistic_features(args, text_fs, images_fs):
    """Draw ``args.num_samples`` text and image features for estimating mean statistics.

    If there are fewer text features than ``args.num_samples``, every text
    feature is first repeated ``args.num_samples`` times so that
    ``random.sample`` has a large enough population. Text is sampled before
    images; keep that order so seeded runs draw the same samples.

    Returns:
        ``(sampled_textfs, sampled_imgfs)``: two lists of ``args.num_samples`` features.
    """
    if len(text_fs) < args.num_samples:
        text_population = [item for item in text_fs for _ in range(args.num_samples)]
    else:
        text_population = text_fs
    sampled_textfs = random.sample(text_population, args.num_samples)
    sampled_imgfs = random.sample(images_fs, args.num_samples)
    return sampled_textfs, sampled_imgfs


def _prepare_image_text_features(args, text_fs, images_fs):
    """Build the processed text features and a per-image transform for ``args.sims_mode``.

    Modes:
        "Org":  features are only passed through ``norm_np``.
        "Wave": mean wavelet coefficients (``get_avg``) of a random sample are
                removed with ``subtract_wavelet_coefficients`` before ``norm_np``.
                Text and image statistics are estimated separately.
        "DN":   0.25 times the sample-mean feature is subtracted before
                ``norm_np``. Text and image means are estimated separately.

    Returns:
        ``(text_fs_norm, transform_image)``: the processed text features and a
        callable mapping one raw image feature to its processed form.

    Raises:
        ValueError: if ``args.sims_mode`` is not one of the modes above.
    """
    mode = args.sims_mode
    if mode == "Org":
        return [norm_np(x) for x in text_fs], norm_np

    if mode == "Wave":
        sampled_textfs, sampled_imgfs = _sample_statistic_features(args, text_fs, images_fs)
        text_mean_coefficients = get_avg(sampled_textfs, args)
        img_mean_coefficients = get_avg(sampled_imgfs, args)
        text_fs_de = [subtract_wavelet_coefficients(x, text_mean_coefficients, args) for x in text_fs]
        text_fs_norm = [norm_np(x) for x in text_fs_de]

        def transform_image(a_img_fs):
            # differentiation enhancement, then normalisation
            return norm_np(subtract_wavelet_coefficients(a_img_fs, img_mean_coefficients, args))

        return text_fs_norm, transform_image

    if mode == "DN":
        lamda = 0.25  # fraction of the sampled mean feature that is removed
        sampled_textfs, sampled_imgfs = _sample_statistic_features(args, text_fs, images_fs)
        img_fs_mu = np.mean(sampled_imgfs, axis=0)
        text_fs_mu = np.mean(sampled_textfs, axis=0)
        # NumPy broadcasting subtracts the mean from every text feature at once.
        text_fs_dn = text_fs - lamda * text_fs_mu
        text_fs_norm = [norm_np(x) for x in text_fs_dn]

        def transform_image(a_img_fs):
            return norm_np(a_img_fs - lamda * img_fs_mu)

        return text_fs_norm, transform_image

    raise ValueError(f"Unsupported sims_mode: {mode!r} (expected 'Org', 'Wave' or 'DN')")


def conduct_sims_computation(args, text_fs, images_fs, labels):
    """Zero-shot classify image features against text features and report accuracy.

    Each processed image feature is compared with every processed text feature
    via ``compute_sim_np``. ``labels`` gives the index of the correct text
    feature, so ``text_fs`` is expected to hold one feature per class. Per-class
    top-1 accuracies are also saved to
    ``CLASS_LEVEL_ACC_DIR/classlevelacc1_<dataset>_<sims_mode>.npy``.

    Args:
        args: run configuration. Reads ``sims_mode``, ``num_samples`` (Wave/DN),
            ``dataset`` and ``acc_type``; it is also passed to ``get_avg`` /
            ``subtract_wavelet_coefficients``.
        text_fs: text features, one per class.
        images_fs: image features.
        labels: integer class index for each entry of ``images_fs``.

    Returns:
        ``(acc1, acc5)``. These are sample-averaged top-1/top-5 accuracies when
        ``args.acc_type == "acc"``, otherwise class-averaged accuracies. Classes
        without images are excluded from the class average, and their saved
        per-class entry is NaN.

    Raises:
        ValueError: if ``images_fs`` is empty or ``args.sims_mode`` is unknown.
    """
    if len(images_fs) == 0:
        raise ValueError("images_fs is empty; cannot compute accuracy")

    text_fs_norm, transform_image = _prepare_image_text_features(args, text_fs, images_fs)

    num_classes = len(text_fs)
    per_class_length = [0] * num_classes
    per_class_hits1 = [0] * num_classes
    per_class_hits5 = [0] * num_classes
    hits1 = 0
    hits5 = 0
    for a_img_fs, ground_truth in tqdm(zip(images_fs, labels), desc="Processing"):
        result = compute_sim_np(transform_image(a_img_fs), text_fs_norm)
        per_class_length[ground_truth] += 1
        if return_top1_result(result, ground_truth):
            hits1 += 1
            per_class_hits1[ground_truth] += 1
        if return_top5_result(result, ground_truth):
            hits5 += 1
            per_class_hits5[ground_truth] += 1

    acc1 = hits1 / len(images_fs)
    acc5 = hits5 / len(images_fs)

    per_class_acc1 = []
    per_class_acc5 = []
    for count, class_hits1, class_hits5 in zip(per_class_length, per_class_hits1, per_class_hits5):
        if count:
            per_class_acc1.append(class_hits1 / count)
            per_class_acc5.append(class_hits5 / count)
        else:
            # No images of this class were seen: its accuracy is undefined.
            per_class_acc1.append(float("nan"))
            per_class_acc5.append(float("nan"))

    # Average only over classes that actually occurred (all of them in the usual case).
    observed = [i for i, count in enumerate(per_class_length) if count]
    acc1_per_class = sum(per_class_acc1[i] for i in observed) / len(observed)
    acc5_per_class = sum(per_class_acc5[i] for i in observed) / len(observed)

    # save
    os.makedirs(CLASS_LEVEL_ACC_DIR, exist_ok=True)
    filename = "classlevelacc1_{}_{}.npy".format(args.dataset, args.sims_mode)
    np.save(os.path.join(CLASS_LEVEL_ACC_DIR, filename), np.array(per_class_acc1))
    print("done")

    if args.acc_type == "acc":
        return acc1, acc5
    return acc1_per_class, acc5_per_class


# Memoised AudioCLIP audio-text logit scales, keyed by checkpoint path.
_AUDIO_TEXT_SCALE_CACHE = {}


def _audio_text_logit_scale(vlm_base):
    """Return AudioCLIP's audio-text logit scale, clamped to [1, 100], as a NumPy value.

    ``vlm_base == "audio_full"`` selects the full-training checkpoint; any
    other value selects the partial-training checkpoint. Each checkpoint is
    loaded at most once and its scale is memoised. Without this, the model is
    rebuilt for every fold of every experiment just to read one scalar. This
    assumes the checkpoint files are not modified mid-run.
    """
    if vlm_base == "audio_full":
        pt_path = "/cache/AudioClip/AudioCLIP-Full-Training.pt"
    else:
        # audio_partial
        pt_path = "/cache/AudioClip/AudioCLIP-Partial-Training.pt"
    if pt_path not in _AUDIO_TEXT_SCALE_CACHE:
        aclp = AudioCLIP(pretrained=pt_path)
        aclp.eval()
        scale_audio_text = torch.clamp(aclp.logit_scale_at.exp(), min=1.0, max=100.0)
        _AUDIO_TEXT_SCALE_CACHE[pt_path] = scale_audio_text.detach().numpy()
    return _AUDIO_TEXT_SCALE_CACHE[pt_path]


def conduct_sims_computation_audio(args, text_fs, images_fs, labels):
    """Zero-shot classify one fold of audio features against text features.

    This supports the same similarity modes as the image pipeline ("Org",
    "Wave", "DN"). Similarities are additionally scaled by AudioCLIP's
    audio-text logit scale. Only sample-averaged accuracy is computed, because
    the caller weights each fold's result by its sample count. The result is
    therefore returned regardless of ``args.acc_type``.

    Args:
        args: run configuration. Reads ``VLM_Base``, ``sims_mode`` and
            ``num_samples`` (Wave/DN); it is also passed to ``get_avg`` /
            ``subtract_wavelet_coefficients``.
        text_fs: text features, one per class.
        images_fs: audio features of this fold (named like the image pipeline).
        labels: integer class index for each entry of ``images_fs``.

    Returns:
        ``(acc1, acc5, num_samples)`` for this fold.

    Raises:
        ValueError: if ``args.sims_mode`` is unknown or ``images_fs`` is empty.
    """
    if args.sims_mode not in ("Org", "Wave", "DN"):
        raise ValueError(f"Unsupported sims_mode: {args.sims_mode!r} (expected 'Org', 'Wave' or 'DN')")
    if len(images_fs) == 0:
        raise ValueError("images_fs is empty; cannot compute accuracy")

    scale_audio_text = _audio_text_logit_scale(args.VLM_Base)

    def sample_statistic_features():
        # Repeat text features when there are too few to draw num_samples from.
        # Text is sampled before audio to keep the random-number consumption order.
        if len(text_fs) < args.num_samples:
            text_population = [item for item in text_fs for _ in range(args.num_samples)]
        else:
            text_population = text_fs
        sampled_textfs = random.sample(text_population, args.num_samples)
        sampled_imgfs = random.sample(images_fs, args.num_samples)
        return sampled_textfs, sampled_imgfs

    if args.sims_mode == "Org":
        text_fs_norm = [norm_np(x) for x in text_fs]

        def transform_audio(a_img_fs):
            return norm_np(a_img_fs)

    elif args.sims_mode == "Wave":
        sampled_textfs, sampled_imgfs = sample_statistic_features()
        text_mean_coefficients = get_avg(sampled_textfs, args)
        img_mean_coefficients = get_avg(sampled_imgfs, args)
        text_fs_de = [subtract_wavelet_coefficients(x, text_mean_coefficients, args) for x in text_fs]
        text_fs_norm = [norm_np(x) for x in text_fs_de]

        def transform_audio(a_img_fs):
            # differentiation enhancement, then normalisation
            return norm_np(subtract_wavelet_coefficients(a_img_fs, img_mean_coefficients, args))

    else:  # "DN"
        lamda = 0.25  # fraction of the sampled mean feature that is removed
        sampled_textfs, sampled_imgfs = sample_statistic_features()
        img_fs_mu = np.mean(sampled_imgfs, axis=0)
        text_fs_mu = np.mean(sampled_textfs, axis=0)
        # NumPy broadcasting subtracts the mean from every text feature at once.
        text_fs_dn = text_fs - lamda * text_fs_mu
        text_fs_norm = [norm_np(x) for x in text_fs_dn]

        def transform_audio(a_img_fs):
            return norm_np(a_img_fs - lamda * img_fs_mu)

    acc1 = 0
    acc5 = 0
    for a_img_fs, ground_truth in tqdm(zip(images_fs, labels), desc="Processing"):
        result = compute_sim_np(transform_audio(a_img_fs), text_fs_norm, scale_audio_text)
        if return_top1_result(result, ground_truth):
            acc1 += 1
        if return_top5_result(result, ground_truth):
            acc5 += 1

    num_samples = len(images_fs)
    return acc1 / num_samples, acc5 / num_samples, num_samples


def _run_audio_experiment(args, input_path):
    """Evaluate every audio fold once and return sample-weighted (acc1, acc5)."""
    text_fs, audios_fs_folds, labels_folds, fold_num = read_audio_features(input_path)
    print("Successfully Load the Features!")

    weighted_acc1 = 0
    weighted_acc5 = 0
    total_num = 0
    for fold in range(fold_num):
        fold_acc1, fold_acc5, fold_size = conduct_sims_computation_audio(
            args, text_fs, audios_fs_folds[fold], labels_folds[fold]
        )
        # Weight each fold by its size so the result is a per-sample average.
        weighted_acc1 += fold_acc1 * fold_size
        weighted_acc5 += fold_acc5 * fold_size
        total_num += fold_size
    return weighted_acc1 / total_num, weighted_acc5 / total_num


def main():
    """Run ``args.num_experiments`` zero-shot evaluations and print mean top-1/top-5 accuracy."""
    args = get_args_parser()
    # These datasets are reported with class-averaged accuracy.
    if args.dataset in ("pets", "Flowers102"):
        args.acc_type = "acc_per_class"

    print("cross modal zero-shot classification:")
    show_settings(args)

    input_path = get_features_paths(args)

    total_acc1 = 0
    total_acc5 = 0
    for _ in range(args.num_experiments):
        set_random_seed()
        if "audio" in args.VLM_Base:
            exp_acc1, exp_acc5 = _run_audio_experiment(args, input_path)
        else:
            text_fs, images_fs, labels = read_features(input_path)
            print("Successfully Load the Features!")
            exp_acc1, exp_acc5 = conduct_sims_computation(args, text_fs, images_fs, labels)
        total_acc1 += exp_acc1
        total_acc5 += exp_acc5

    mean_acc1 = total_acc1 / args.num_experiments
    mean_acc5 = total_acc5 / args.num_experiments
    print("Avg Accuracy@1:{:.2f}% in {} exps".format(mean_acc1 * 100, args.num_experiments))
    print("Avg Accuracy@5:{:.2f}% in {} exps".format(mean_acc5 * 100, args.num_experiments))


if __name__ == '__main__':
    main()



