import os
import numpy as np
import random
import numpy as np
import torch
import datetime
import tqdm
from zs_datasets.CLIPImageDataset import CLIPImageDataset, CLIPCapDataset
import torch.nn.functional as F

def extract_all_images(images, model, device, args, preprocess,normalize=True):
    """Encode every image with ``model.encode_image`` and stack the results.

    Args:
        images: Passed to ``CLIPImageDataset`` together with ``preprocess``.
        model: Object exposing ``encode_image(batch)``.
        device: Target handed to ``.to(...)`` for each batch.
        args: Namespace providing ``sims_mode``, ``batch_size`` and
            ``num_workers``.
        preprocess: Forwarded unchanged to ``CLIPImageDataset``.
        normalize: When truthy, L2-normalise each batch along ``dim=1``.
            Ignored (treated as off) when ``args.sims_mode == "Wave"``.

    Returns:
        The per-batch features concatenated along dimension 0.
    """
    # The "Wave" mode always switches normalisation off, whatever the caller asked for.
    apply_l2 = args.sims_mode != "Wave" and normalize

    loader = torch.utils.data.DataLoader(
        CLIPImageDataset(images, preprocess),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,  # keep output rows in dataset order
    )

    chunks = []
    with torch.no_grad():
        for batch in tqdm.tqdm(loader):
            encoded = model.encode_image(batch['image'].to(device))
            chunks.append(F.normalize(encoded, p=2, dim=1) if apply_l2 else encoded)
    return torch.cat(chunks, dim=0)


def extract_all_captions(captions, model, device,args, normalize=True):
    """Encode every caption with ``model.encode_text`` and stack the results.

    Args:
        captions: Passed to ``CLIPCapDataset``.
        model: Object exposing ``encode_text(batch)``.
        device: Target handed to ``.to(...)`` for each batch.
        args: Namespace providing ``sims_mode``, ``batch_size`` and
            ``num_workers``.
        normalize: When truthy, L2-normalise each batch along ``dim=1``.
            Ignored (treated as off) when ``args.sims_mode == "Wave"``.

    Returns:
        The per-batch features concatenated along dimension 0.
    """
    # The "Wave" mode always switches normalisation off, whatever the caller asked for.
    apply_l2 = args.sims_mode != "Wave" and normalize

    loader = torch.utils.data.DataLoader(
        CLIPCapDataset(captions),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,  # keep output rows in dataset order
    )

    chunks = []
    with torch.no_grad():
        for batch in loader:
            encoded = model.encode_text(batch['caption'].to(device))
            chunks.append(F.normalize(encoded, p=2, dim=1) if apply_l2 else encoded)
    return torch.cat(chunks, dim=0)


def set_random_seed(seed=None):
    """Seed the Python, NumPy and PyTorch random generators with one value.

    Args:
        seed: Seed to apply. When ``None`` (the default, which keeps the
            original zero-argument call working) a seed is drawn with
            ``random.randint(1, 10000)``.

    Returns:
        The seed that was applied, so a run can be logged and reproduced by
        passing the same value back in.
    """
    if seed is None:
        seed = random.randint(1, 10000)
    print('Random seed: {}'.format(seed))
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    return seed


def show_settings(args):
    """Print a banner-delimited summary of the run configuration.

    Reads ``VLM_Base``, ``dataset``, ``sims_mode`` and ``num_experiments``
    from ``args``; when ``sims_mode`` is ``"Wave"`` it additionally reads
    ``hyper_level``, ``wavelet_base`` and ``hyper_strength``. The second
    printed line is the current ``datetime.datetime.utcnow()`` value.

    Args:
        args: Namespace-like object carrying the attributes listed above.
    """
    banner = "$" * 30
    print(banner)
    print(datetime.datetime.utcnow())
    print(f"VLM_Base:{args.VLM_Base}")
    print(f"dataset:{args.dataset}")
    print(f"sims_mode:{args.sims_mode}")
    # These three settings are only reported in "Wave" mode.
    if args.sims_mode == "Wave":
        print(f"hyper_level:{args.hyper_level}")
        print(f"wavelet_base:{args.wavelet_base}")
        print(f"hyper_strength:{args.hyper_strength}")
    print(f"num_experiments:{args.num_experiments}")
    print(banner)


def return_top1_result(sims, groundtruth):
    """Report whether the highest-scoring index equals ``groundtruth``.

    Args:
        sims: Sequence with an ``index`` method (e.g. a list). On ties the
            first occurrence of the maximum is used.
        groundtruth: Index to compare against.

    Returns:
        ``True`` if the position of ``max(sims)`` equals ``groundtruth``,
        otherwise ``False``.
    """
    best_index = sims.index(max(sims))
    return bool(best_index == groundtruth)

def return_top5_result(sims, groundtruth):
    """Report whether ``groundtruth`` is among the five highest-scoring indices.

    Args:
        sims: Array-like of values; converted with ``np.array``.
        groundtruth: Index to look for.

    Returns:
        ``True`` if ``groundtruth`` is one of the (up to) five indices with
        the largest values, otherwise ``False``.
    """
    # Ascending argsort reversed gives indices from best to worst.
    best_first = np.argsort(np.array(sims))[::-1]
    return groundtruth in best_first[:5]

def return_top10_result(sims, groundtruth):
    """Report whether ``groundtruth`` is among the ten highest-scoring indices.

    Args:
        sims: Array-like of values; converted with ``np.array``.
        groundtruth: Index to look for.

    Returns:
        ``True`` if ``groundtruth`` is one of the (up to) ten indices with
        the largest values, otherwise ``False``.
    """
    # Ascending argsort reversed gives indices from best to worst.
    best_first = np.argsort(np.array(sims))[::-1]
    return groundtruth in best_first[:10]


def return_top1_result_query(sims, groundtruth):
    """Report whether the highest-scoring index is one of the accepted answers.

    Args:
        sims: Sequence with an ``index`` method (e.g. a list). On ties the
            first occurrence of the maximum is used.
        groundtruth: Container of accepted indices, tested with ``in``.

    Returns:
        ``True`` if the position of ``max(sims)`` is contained in
        ``groundtruth``, otherwise ``False``.
    """
    return sims.index(max(sims)) in groundtruth

def return_top5_result_query(sims, groundtruth):
    """Report whether any accepted index is among the five highest-scoring ones.

    Args:
        sims: Array-like of values; converted with ``np.array``.
        groundtruth: Iterable of accepted indices.

    Returns:
        ``True`` as soon as one element of ``groundtruth`` is found among the
        (up to) five indices with the largest values, otherwise ``False``.
    """
    top_indices = np.argsort(np.array(sims))[::-1][:5]
    return any(candidate in top_indices for candidate in groundtruth)

def return_top10_result_query(sims, groundtruth):
    """Report whether any accepted index is among the ten highest-scoring ones.

    Args:
        sims: Array-like of values; converted with ``np.array``.
        groundtruth: Iterable of accepted indices.

    Returns:
        ``True`` as soon as one element of ``groundtruth`` is found among the
        (up to) ten indices with the largest values, otherwise ``False``.
    """
    top_indices = np.argsort(np.array(sims))[::-1][:10]
    return any(candidate in top_indices for candidate in groundtruth)


def norm_np(input_np):
    """Divide ``input_np`` by its ``np.linalg.norm`` value.

    The norm is taken with NumPy's defaults (no ``ord`` or ``axis``), so a
    single scalar divides the whole input. An all-zero input is not
    special-cased.

    Args:
        input_np: Array-like accepted by ``np.linalg.norm``.

    Returns:
        ``input_np`` scaled by the reciprocal of its norm.
    """
    return input_np / np.linalg.norm(input_np)


def compute_sim_np(img_fs_input, text_fs_input,scale = 100):
    """Score one image feature against each text feature.

    Each score is the sum of the element-wise product of ``img_fs_input``
    and one text feature, multiplied by ``scale``.

    Args:
        img_fs_input: Image feature array.
        text_fs_input: Iterable of text feature arrays.
        scale: Multiplier applied to every score.

    Returns:
        A list with one score per entry of ``text_fs_input``, in order.
    """
    return [np.sum(img_fs_input * text_feature) * scale for text_feature in text_fs_input]

def compute_sim_np_multiview(img_fs_input, text_fs_input):
    """Score multi-view image features against each multi-view text feature.

    For every text entry, view ``i`` of the text is multiplied element-wise
    with view ``i`` of the image, summed, scaled by 100, and the per-view
    results are added together.

    Args:
        img_fs_input: Indexable collection of per-view image features.
        text_fs_input: Iterable whose items are collections of per-view text
            features.

    Returns:
        A list with one accumulated score per entry of ``text_fs_input``.
    """
    scores = []
    for text_views in text_fs_input:
        total = 0
        # The image side is indexed explicitly so a missing view still raises.
        for view_idx, text_view in enumerate(text_views):
            total += np.sum(text_view * img_fs_input[view_idx]) * 100
        scores.append(total)
    return scores



def convert_models_to_fp32(model):
    """Cast every parameter of ``model``, and any existing gradient, to float32.

    The conversion is done in place by reassigning ``.data`` on each
    parameter and on its ``.grad`` when one is present.

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).
    """
    for p in model.parameters():
        p.data = p.data.float()
        # Test for presence, not truthiness: ``if p.grad:`` raises for any
        # multi-element gradient and skips a single zero-valued one.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


def read_features(input_path):
    """Load text features plus per-sample image features and labels.

    ``input_path`` must contain ``text_features.npy``; every other entry in
    the directory is loaded as one sample. The label is the integer between
    the last ``_`` of the file name and the first ``.`` that follows it.
    Samples are returned in ``os.listdir`` order.

    Args:
        input_path: Directory holding the ``.npy`` files.

    Returns:
        A tuple ``(text_features, image_features, labels)``:
        ``text_features`` is a list of the rows (axis 0) of the text array,
        ``image_features`` is a list of sample arrays with their leading
        axis squeezed away, and ``labels`` is the matching list of ints.

    Raises:
        ValueError: If ``text_features.npy`` is not listed in the directory,
            or a sample's leading axis is not of length 1.
    """
    text_file = "text_features.npy"
    sample_names = os.listdir(input_path)
    sample_names.remove(text_file)

    text_matrix = np.load(os.path.join(input_path, text_file))
    text_features = [text_matrix[row] for row in range(text_matrix.shape[0])]

    image_features, labels = [], []
    for sample_name in sample_names:
        sample = np.load(os.path.join(input_path, sample_name))
        image_features.append(np.squeeze(sample, axis=0))
        # e.g. "<prefix>_<label>.npy" -> "<label>"
        label_text = sample_name.rsplit("_", 1)[-1].partition(".")[0]
        labels.append(int(label_text))
    return text_features, image_features, labels

def read_audio_features(input_path):
    """Load text features plus audio features and labels grouped by fold.

    ``input_path`` must contain ``text_features.npy``; every other entry is
    loaded as one sample. The fold number is the integer in the second
    ``_``-separated field of the file name, and the label is the integer
    between the last ``_`` and the first ``.`` that follows it. Up to ten
    folds are supported; fold ``k`` is stored at position ``k - 1``.

    Only folds that received at least one sample are returned, in ascending
    fold order. Previously the lists were truncated to the *count* of
    non-empty folds, which silently dropped data whenever the populated
    folds were not contiguous from fold 1.

    Args:
        input_path: Directory holding the ``.npy`` files.

    Returns:
        A tuple ``(text_features, audio_features_fold, labels_fold,
        fold_num)``: ``text_features`` is a list of the rows (axis 0) of the
        text array, ``audio_features_fold`` and ``labels_fold`` are parallel
        lists with one inner list per populated fold, and ``fold_num`` is
        the number of populated folds.

    Raises:
        ValueError: If ``text_features.npy`` is not listed in the directory.
        IndexError: If a file name encodes a fold number above ten.
    """
    feature_files = os.listdir(input_path)
    feature_files.remove("text_features.npy")
    text_array = np.load(os.path.join(input_path, "text_features.npy"))
    text_features = [text_array[i] for i in range(text_array.shape[0])]

    max_folds = 10
    audio_features_fold = [[] for _ in range(max_folds)]
    labels_fold = [[] for _ in range(max_folds)]

    for a_sample_name in feature_files:
        fold_no = int(a_sample_name.split("_")[1])
        feature_path = os.path.join(input_path, a_sample_name)
        # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary
        # objects; only point this at feature directories you trust.
        a_audio_feature = np.load(feature_path, allow_pickle=True)
        audio_features_fold[fold_no - 1].append(np.squeeze(a_audio_feature, axis=0))

        a_label_name = a_sample_name.split("_")[-1].split(".")[0]
        labels_fold[fold_no - 1].append(int(a_label_name))

    # Features and labels are appended together, so one emptiness check
    # selects the same folds from both lists.
    populated = [idx for idx, fold in enumerate(audio_features_fold) if fold]
    audio_features_fold = [audio_features_fold[idx] for idx in populated]
    labels_fold = [labels_fold[idx] for idx in populated]
    fold_num = len(populated)

    return text_features, audio_features_fold, labels_fold, fold_num