import argparse
import json
import os
import random
import warnings
import numpy as np
import torch
from torchmetrics import Accuracy
from pathlib import Path
import os

import clip
import torch
import numpy as np
from tqdm import tqdm
from Wavelet_method import get_avg, subtract_wavelet_coefficients
from zs_datasets.CLIPImageDataset import CLIPImageDataset, CLIPCapDataset
from model import AudioCLIP
from utils.transforms import ToTensor1D
from zs_datasets.load_the_test_datasets import load_test_datasets
from tools import convert_models_to_fp32, norm_np, read_audio_features, set_random_seed, compute_sim_np, \
    return_top1_result, return_top5_result, return_top10_result, return_top1_result_query, return_top5_result_query, \
    return_top10_result_query
from agrs_parser import get_args_parser, get_features_paths
from agrs_parser import get_args_parser
from tools import extract_all_images, extract_all_captions, show_settings
# Weight applied to the mean of the sampled features before that mean is
# subtracted from the full feature matrices in the "DN" mode of
# compute_imagetext_retrieval.
LAMBDA = 0.5

def _read_flickr_captions(ref_path):
    """Return the captions stored in one Flickr30k Entities sentence file.

    Every line is split on single spaces; tokens containing ``[`` are dropped,
    and ``]`` and newline characters are removed from the remaining tokens
    before they are re-joined with single spaces.
    """
    captions = []
    with open(ref_path, 'r') as sentence_file:
        for raw in sentence_file:
            tokens = [
                token.replace(']', '').replace('\n', '')
                for token in raw.split(' ')
                if '[' not in token
            ]
            captions.append(' '.join(tokens))
    return captions


def load_imageretrival_datas(args, data_path):
    """Load the test images and captions of an image-text retrieval dataset.

    Args:
        args: Namespace whose ``dataset`` attribute selects the loader
            (``'flickr30k'`` or ``'mscoco'``).
        data_path: Root directory of the selected dataset.

    Returns:
        A tuple ``(images, all_refs, labels)``: ``images`` is the list of
        unique image paths (as strings), ``all_refs`` is the flat list of
        captions, and ``labels[k]`` is the index into ``images`` of the image
        that caption ``all_refs[k]`` belongs to. Any other ``args.dataset``
        value yields three empty lists.

    Raises:
        AssertionError: For Flickr30k, when the images listed in ``test.txt``
            plus the remaining ``.jpg`` files do not add up to the number of
            ``.jpg`` files in the ``test`` directory.

    Note:
        Captions of the dev split are not parsed because they are not part of
        the returned value.
    """
    images = []
    refs = []

    # load flickr30k dataset
    if args.dataset == 'flickr30k':
        dataset_path = Path(data_path, "test")
        sentences_path = Path(data_path, "30k-entities", "Sentences")
        all_images = [name.replace('.jpg', '')
                      for name in os.listdir(dataset_path)
                      if name.endswith('.jpg')]
        with open(Path(data_path, 'test.txt'), 'r') as fb:
            for line in fb:
                image = line.strip()
                if not image:
                    # A blank line would otherwise map to a bogus ".jpg" path.
                    continue
                images.append(str(Path(dataset_path, f"{image}.jpg")))
                refs.append(_read_flickr_captions(
                    Path(sentences_path, f"{image}.txt")))

        # Set membership keeps the split consistency check linear.
        test_paths = set(images)
        num_dev_images = sum(
            1 for name in all_images
            if str(Path(dataset_path, f"{name}.jpg")) not in test_paths)
        assert num_dev_images + len(images) == len(all_images)

    # load mscoco dataset
    elif args.dataset == 'mscoco':
        annotations_dir = Path(data_path, "annotations")
        with open(Path(annotations_dir, "captions_val2014.json"), "r") as fb:
            caption_dicts = json.load(fb)['annotations']
        with open(Path(annotations_dir, "coco_test_ids.npy"), "rb") as fb:
            test_ids = set(np.load(fb))

        # Group captions by image, keeping first-seen order of the images.
        image2caption = {}
        for d in caption_dicts:
            if d['id'] not in test_ids:
                continue
            image2caption.setdefault(d['image_id'], []).append(
                d['caption'].strip())

        for image, captions in image2caption.items():
            img_path = Path(data_path, "val2014",
                            f"COCO_val2014_{str(image).rjust(12, '0')}.jpg")
            images.append(str(img_path))
            refs.append(captions)

    # Keep only the first occurrence of every image path.
    unique_images = []
    unique_refs = []
    seen = set()
    for image, ref in zip(images, refs):
        if image in seen:
            continue
        seen.add(image)
        unique_images.append(image)
        unique_refs.append(ref)

    # Flatten the captions; each caption is labelled with its image index.
    all_refs = []
    labels = []
    for label, captions in enumerate(unique_refs):
        all_refs.extend(captions)
        labels.extend([label] * len(captions))
    return unique_images, all_refs, labels

def compute_imagetext_retrieval(model, images, refs, labels, image_features, text_features, device, args, preprocess):
    """Compute top-1/5/10 retrieval accuracy between image and text features.

    Args:
        model: Model handed to ``extract_all_images`` / ``extract_all_captions``
            when features of a random sample have to be extracted.
        images: Image inputs from which ``args.num_samples`` are drawn.
        refs: Captions from which ``args.num_samples`` are drawn.
        labels: For every caption, the index of its image. Indexed with
            integer tensors here, so a tensor or list of ints is expected.
        image_features: Features of all images (a CPU tensor is assumed,
            since ``.numpy()`` is called on it in "Wave" mode).
        text_features: Features of all captions (same assumption).
        device: Device forwarded to the feature extractors.
        args: Namespace providing ``sims_mode`` ("Org", "DN" or "Wave"),
            ``retrival_type`` and ``num_samples``.
        preprocess: Image preprocessing forwarded to ``extract_all_images``.

    Returns:
        ``(top1, top5, top10)`` accuracies as 0-dim tensors.

    Raises:
        ValueError: If ``args.sims_mode`` is not one of the supported modes
            (previously this surfaced as an ``UnboundLocalError``).
    """
    mode = args.sims_mode
    if mode not in ("Org", "DN", "Wave"):
        raise ValueError(
            f"Unsupported sims_mode {mode!r}; expected 'Org', 'DN' or 'Wave'.")

    if mode == "Org":
        _image_features = image_features
        _text_features = text_features
    else:
        # Images are sampled before captions so the random stream is consumed
        # in the same order for both "DN" and "Wave".
        sampled_image_features = extract_all_images(
            random.sample(images, args.num_samples), model, device, args, preprocess).cpu()
        sampled_text_features = extract_all_captions(
            random.sample(refs, args.num_samples), model, device, args).cpu()

        if mode == "DN":
            # Subtract the scaled mean of the sampled features.
            _image_features = image_features - LAMBDA * torch.mean(sampled_image_features, dim=0)
            _text_features = text_features - LAMBDA * torch.mean(sampled_text_features, dim=0)
        else:
            # "Wave": normalise every vector, then subtract the averaged
            # wavelet coefficients of the normalised samples.
            text_fs_norm = [norm_np(x) for x in text_features.numpy()]
            img_fs_norm = [norm_np(x) for x in image_features.numpy()]
            text_sample_norm = [norm_np(x) for x in sampled_text_features.numpy()]
            img_sample_norm = [norm_np(x) for x in sampled_image_features.numpy()]

            text_mean_coefficients = get_avg(text_sample_norm, args)
            img_mean_coefficients = get_avg(img_sample_norm, args)

            text_fs_DE = [subtract_wavelet_coefficients(x, text_mean_coefficients, args)
                          for x in text_fs_norm]
            img_fs_DE = [subtract_wavelet_coefficients(x, img_mean_coefficients, args)
                         for x in img_fs_norm]

            _image_features = torch.from_numpy(np.array(img_fs_DE))
            _text_features = torch.from_numpy(np.array(text_fs_DE))

    # Text features are the left operand, so rows of ``sim`` follow the
    # captions and columns follow the images.
    sim = (_text_features @ _image_features.T).cpu()
    if args.retrival_type == "image2text":
        sim = sim.T
        indexes = torch.argsort(sim, dim=1, descending=True)[:, :10]
        # Map every retrieved caption index to the label of its image in one
        # gather instead of a Python double loop.
        index_labels = torch.as_tensor(labels).long().cpu()[indexes]
        targets = torch.arange(indexes.size(0)).reshape(-1, 1)
        hits = index_labels == targets
        # A query counts as correct if any of its first k results matches.
        top1 = hits[:, :1].any(dim=1).float().mean()
        top5 = hits[:, :5].any(dim=1).float().mean()
        top10 = hits.any(dim=1).float().mean()
    else:
        num_classes_img = sim.size(1)
        top1 = Accuracy(top_k=1, task="multiclass",
                        num_classes=num_classes_img)(sim, labels)
        top5 = Accuracy(top_k=5, task="multiclass",
                        num_classes=num_classes_img)(sim, labels)
        top10 = Accuracy(top_k=10, task="multiclass",
                         num_classes=num_classes_img)(sim, labels)
    return top1, top5, top10

def _sample_text_and_audio_features(text_fs, audios_fs, num_samples):
    """Draw ``num_samples`` text features, then ``num_samples`` audio features.

    When fewer text features than ``num_samples`` exist, every text feature is
    repeated ``num_samples`` times before sampling so the draw can succeed.
    """
    if len(text_fs) < num_samples:
        text_pool = [item for item in text_fs for _ in range(num_samples)]
    else:
        text_pool = text_fs
    sampled_textfs = random.sample(text_pool, num_samples)
    sampled_audiofs = random.sample(audios_fs, num_samples)
    return sampled_textfs, sampled_audiofs


def _score_text_queries(text_fs, audio_features, transform_query, text_label, scale):
    """Return the top-1/5/10 hit rates of all text queries against the audio features."""
    top1 = 0
    top5 = 0
    top10 = 0
    for index, a_text_fs in tqdm(enumerate(text_fs), desc="Processing"):
        # The position of a text feature is used as its class label.
        ground_truth = text_label[index]
        result = compute_sim_np(transform_query(a_text_fs), audio_features, scale)
        if return_top1_result_query(result, ground_truth):
            top1 += 1
        if return_top5_result_query(result, ground_truth):
            top5 += 1
        if return_top10_result_query(result, ground_truth):
            top10 += 1
    num_queries = len(text_fs)
    return top1 / num_queries, top5 / num_queries, top10 / num_queries


def conduct_text_audio_retrival(args, text_fs, audios_fs, labels):
    """Evaluate text-to-audio retrieval on one set of audio features.

    Args:
        args: Namespace providing ``VLM_Base`` (selects the AudioCLIP
            checkpoint), ``sims_mode`` ("Org", "DN" or "Wave") and
            ``num_samples``.
        text_fs: Sequence of text features; element ``i`` is the query whose
            ground truth is every audio clip labelled ``i``.
        audios_fs: Sequence of audio features.
        labels: Label of every audio feature, aligned with ``audios_fs``.

    Returns:
        ``(top1, top5, top10)`` as fractions of the text queries.

    Raises:
        ValueError: If ``args.sims_mode`` is not supported. Previously such a
            value silently produced ``(0, 0, 0)``.
        KeyError: If some text index has no audio clip with that label.
    """
    mode = args.sims_mode
    if mode not in ("Org", "DN", "Wave"):
        # Fail before the checkpoint is loaded instead of reporting 0 accuracy.
        raise ValueError(
            f"Unsupported sims_mode {mode!r}; expected 'Org', 'DN' or 'Wave'.")

    if args.VLM_Base == "audio_full":
        pt_path = "/cache/AudioClip/AudioCLIP-Full-Training.pt"
    else:
        # audio_partial
        pt_path = "/cache/AudioClip/AudioCLIP-Partial-Training.pt"
    # NOTE(review): the checkpoint is loaded on every call only to read the
    # audio-text logit scale; consider loading it once in the caller.
    aclp = AudioCLIP(pretrained=pt_path)
    aclp.eval()
    scale_audio_text = torch.clamp(aclp.logit_scale_at.exp(), min=1.0, max=100.0)
    scale_audio_text = scale_audio_text.detach().numpy()

    # Map every label to the positions of the audio clips carrying it.
    text_label = {}
    for index, (_, label) in enumerate(zip(audios_fs, labels)):
        text_label.setdefault(label, []).append(index)

    if mode == "Org":
        audio_features = [norm_np(x) for x in audios_fs]
        transform_query = norm_np

    elif mode == "DN":
        # NOTE(review): same value as the module-level LAMBDA used by
        # compute_imagetext_retrieval; kept local to leave this block
        # self-contained.
        lamda = 0.5
        sampled_textfs, sampled_audiofs = _sample_text_and_audio_features(
            text_fs, audios_fs, args.num_samples)
        audio_fs_mu = norm_np(np.mean(sampled_audiofs, axis=0))
        text_fs_mu = norm_np(np.mean(sampled_textfs, axis=0))
        audio_features = np.asarray([norm_np(x) for x in audios_fs]) - lamda * audio_fs_mu

        def transform_query(a_text_fs):
            return norm_np(a_text_fs) - lamda * text_fs_mu

    else:
        # "Wave"
        sampled_textfs, sampled_audiofs = _sample_text_and_audio_features(
            text_fs, audios_fs, args.num_samples)
        sampled_audiofs_normed = [norm_np(x) for x in sampled_audiofs]
        sampled_textfs_normed = [norm_np(x) for x in sampled_textfs]

        text_mean_coefficients = get_avg(sampled_textfs_normed, args)
        audio_mean_coefficients = get_avg(sampled_audiofs_normed, args)

        audio_features = [
            subtract_wavelet_coefficients(norm_np(x), audio_mean_coefficients, args)
            for x in audios_fs
        ]

        def transform_query(a_text_fs):
            return subtract_wavelet_coefficients(
                norm_np(a_text_fs), text_mean_coefficients, args)

    return _score_text_queries(
        text_fs, audio_features, transform_query, text_label, scale_audio_text)


def _print_topk_summary(top1s, top5s, top10s):
    """Print mean and standard deviation of the top-1/5/10 scores.

    Scores are converted to Python floats first so that 0-dim tensors and
    plain numbers are aggregated the same way.
    """
    for k, scores in ((1, top1s), (5, top5s), (10, top10s)):
        values = [float(score) for score in scores]
        print(f'Top-{k} Accuracy: {np.mean(values)}')
        print(f'Top-{k} Std {np.std(values)}')


def _run_image_text_retrieval(args, device, input_path):
    """Run the image-text retrieval experiments and print their summary."""
    model, preprocess = clip.load(args.VLM_Base, device, jit=False)
    convert_models_to_fp32(model)
    print("successfully load the model")
    images, refs, labels = load_imageretrival_datas(args, input_path)
    print("successfully load the data")
    image_features = extract_all_images(images, model, device, args, preprocess).cpu()
    text_features = extract_all_captions(refs, model, device, args).cpu()
    labels = torch.Tensor(labels).long()
    print("successfully extract the features")

    top1s = []
    top5s = []
    top10s = []
    for _ in range(args.num_experiments):
        set_random_seed()
        top1, top5, top10 = compute_imagetext_retrieval(
            model, images, refs, labels, image_features, text_features,
            device, args, preprocess)
        top1s.append(top1 * 100)
        top5s.append(top5 * 100)
        top10s.append(top10 * 100)
    _print_topk_summary(top1s, top5s, top10s)


def _run_text_audio_retrieval(args, input_path):
    """Run the text-audio retrieval experiments and print their summary."""
    text_fs, audios_fs_folds, labels_folds, fold_num = read_audio_features(input_path)
    print("Successfully Load the Features!")

    top1s = []
    top5s = []
    top10s = []
    for _ in range(args.num_experiments):
        set_random_seed()
        top1 = 0
        top5 = 0
        top10 = 0
        for i in range(fold_num):
            top1_fold, top5_fold, top10_fold = conduct_text_audio_retrival(
                args, text_fs, audios_fs_folds[i], labels_folds[i])
            top1 += top1_fold
            top5 += top5_fold
            top10 += top10_fold
        # Average over the folds and express the result in percent.
        top1s.append(top1 / fold_num * 100)
        top5s.append(top5 / fold_num * 100)
        top10s.append(top10 / fold_num * 100)
    _print_topk_summary(top1s, top5s, top10s)


def main():
    """Parse the arguments and run the retrieval task matching ``args.dataset``."""
    args = get_args_parser()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    input_path = get_features_paths(args)

    args.batch_size = 64
    # print exp settings
    print("cross modal retrival:")
    print(args.retrival_type)
    show_settings(args)

    if args.dataset in ("flickr30k", "mscoco"):
        _run_image_text_retrieval(args, device, input_path)
    elif args.dataset in ("esc50", "urbansound8k"):
        # audio retrival tasks
        _run_text_audio_retrieval(args, input_path)


if __name__ == "__main__":
    main()
