#!/usr/bin/env python3
# -*- coding: utf-8 -*-


"""
Evaluation tools adapted from https://github.com/fartashf/vsepp/blob/master/evaluation.py
"""

import numpy as np
import torch
import random
from sentence_transformers import util
from loguru import logger
import torch.nn.functional as F
from gensim.models.word2vec import Word2Vec
import ot
from tqdm import tqdm
import pylab as pl
import seaborn as sns


import pygmtools as pygm


def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every random number generator used here.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seed the current device as well as every other visible GPU.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


class AverageMeter(object):
    """
    Running statistics for a scalar metric.

    Tracks the most recent value (``val``), the weighted running total
    (``sum``), the total weight seen so far (``count``) and their ratio
    (``avg``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count




def l2norm(X):
    """L2-normalize X along dim 1 (i.e. each row of a 2-D tensor).

    Rows whose norm is exactly zero are returned as zeros instead of turning
    into NaN through a 0/0 division.

    Args:
        X: tensor with at least two dimensions; the norm is taken over dim 1.

    Returns:
        Tensor with the same shape as X whose slices along dim 1 have unit
        Euclidean length (or are all zero).
    """
    norm = torch.pow(X, 2).sum(dim=1, keepdim=True).sqrt()
    # Dividing an all-zero row by 1 keeps it at zero rather than NaN.
    norm = torch.where(norm == 0, torch.ones_like(norm), norm)
    return torch.div(X, norm)


# evaluation tools
def a2t(audio_embs, cap_embs, return_ranks=False):
    """Audio-to-text retrieval metrics based on cosine similarity.

    ``audio_embs`` is expected to hold five consecutive copies of every clip
    (one per caption): row ``5 * i`` is the query for clip ``i`` and caption
    rows ``5 * i .. 5 * i + 4`` are its ground truth. A query's rank is the
    best (smallest) 0-based position of any of its ground-truth captions.

    Args:
        audio_embs: array-like of shape [5 * N, dim].
        cap_embs: array-like of shape [#captions, dim].
        return_ranks: when True, also return the per-query ranks and the index
            of each query's top-1 caption.

    Returns:
        (r1, r5, r10, r50, medr, meanr), followed by (ranks, top1) when
        ``return_ranks`` is True. Recalls are percentages; medr and meanr are
        1-based.
    """
    num_audios = int(audio_embs.shape[0] / 5)

    # One query per clip: every fifth audio row.
    queries = torch.Tensor(np.array([audio_embs[5 * i] for i in range(num_audios)]))
    # Same computation as sentence_transformers.util.cos_sim, but done once for
    # all queries instead of once per query.
    sims = torch.mm(
        F.normalize(queries, p=2, dim=1),
        F.normalize(torch.Tensor(cap_embs), p=2, dim=1).t(),
    ).numpy()  # [N, #captions]

    num_caps = sims.shape[1]
    ranks = np.zeros(num_audios)
    top1 = np.zeros(num_audios)
    for index in range(num_audios):
        inds = np.argsort(sims[index])[::-1]  # caption indices, best first
        # position[c] is the 0-based rank of caption c for this query.
        position = np.empty(num_caps, dtype=np.int64)
        position[inds] = np.arange(num_caps)
        ranks[index] = position[5 * index:5 * index + 5].min()
        top1[index] = inds[0]

    # compute metrics
    r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
    r50 = 100.0 * len(np.where(ranks < 50)[0]) / len(ranks)
    medr = np.floor(np.median(ranks)) + 1
    meanr = ranks.mean() + 1
    if return_ranks:
        return r1, r5, r10, r50, medr, meanr, ranks, top1
    return r1, r5, r10, r50, medr, meanr
    

def t2a(audio_embs, cap_embs, return_ranks=False):
    """Text-to-audio retrieval metrics based on cosine similarity.

    ``audio_embs`` is expected to hold five consecutive copies of every clip;
    every fifth row forms the candidate pool. Caption ``q`` is a query whose
    ground-truth clip is ``q // 5``; its rank is the 0-based position of that
    clip when candidates are sorted by similarity.

    Args:
        audio_embs: array-like of shape [5 * N, dim].
        cap_embs: array-like of shape [5 * N, dim].
        return_ranks: when True, also return the per-caption ranks and the
            index of each caption's top-1 clip.

    Returns:
        (r1, r5, r10, r50, medr, meanr), followed by (ranks, top1) when
        ``return_ranks`` is True. Recalls are percentages; medr and meanr are
        1-based.
    """
    num_audios = int(audio_embs.shape[0] / 5)
    # Candidate pool: every fifth audio row (one per clip).
    audios = np.array([audio_embs[i] for i in range(0, audio_embs.shape[0], 5)])

    # Same computation as sentence_transformers.util.cos_sim, but done once for
    # all caption queries instead of once per group of five.
    queries = torch.Tensor(cap_embs[:5 * num_audios])
    sims = torch.mm(
        F.normalize(queries, p=2, dim=1),
        F.normalize(torch.Tensor(audios), p=2, dim=1).t(),
    ).numpy()  # [#queries, #audios]

    ranks = np.zeros(5 * num_audios)
    top1 = np.zeros(5 * num_audios)
    for query, row in enumerate(sims):
        inds = np.argsort(row)[::-1]  # audio indices, best first
        ranks[query] = np.where(inds == query // 5)[0][0]
        top1[query] = inds[0]

    # compute metrics
    r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
    r50 = 100.0 * len(np.where(ranks < 50)[0]) / len(ranks)
    medr = np.floor(np.median(ranks)) + 1
    meanr = ranks.mean() + 1
    if return_ranks:
        return r1, r5, r10, r50, medr, meanr, ranks, top1
    return r1, r5, r10, r50, medr, meanr


def a2t_ot(audio_embs, cap_embs, M, train_data=False, use_float=False):
    """Audio-to-text retrieval ranked by an entropic optimal-transport plan.

    A pairwise cost between audio rows and captions is built from the weight
    tensor ``M``. For a per-dimension weight vector it equals
    ``sqrt(sum_k M_k * (a_k - c_k) ** 2)``, scaled so the maximum is 1. A
    transport plan is then solved with ``ot.sinkhorn`` under uniform
    marginals, and each audio ranks captions by the mass it sends to them.

    Args:
        audio_embs: audio embeddings. When ``train_data`` is False, every clip
            is expected to appear five times in a row and only every fifth row
            is used as a query.
        cap_embs: caption embeddings. With ``train_data`` False, captions
            ``5*i .. 5*i+4`` are the ground truth of clip ``i``; otherwise
            caption ``i`` matches audio ``i``.
        M: tensor multiplied element-wise with the embeddings. It is presumably
            a per-dimension weight vector of length ``dim`` (TODO confirm
            against the caller).
        train_data: selects the ground-truth layout described above.
        use_float: unused; kept for interface compatibility.

    Returns:
        (r1, r5, r10, r50, medr, meanr): recalls in percent, and median and
        mean rank (1-based).
    """
    if not train_data:
        audio = [audio_embs[i] for i in range(0, len(audio_embs), 5)]
    else:
        audio = audio_embs

    a = torch.ones(len(audio)) / len(audio)
    b = torch.ones(len(cap_embs)) / len(cap_embs)

    cap_embs = torch.tensor(np.array(cap_embs)).float()
    audio = torch.tensor(np.array(audio)).float()

    # Weighted squared distance via |a|_M^2 + |c|_M^2 - 2 <a, c>_M.
    M = M.float().cpu()
    audio_norm = (audio ** 2 * M).sum(dim=1, keepdim=True)
    text_norm = (cap_embs ** 2 * M).sum(dim=1)
    cross = 2 * torch.matmul(audio * M, cap_embs.T)
    # Cancellation can leave tiny negative values for near-identical pairs;
    # clamp them so sqrt does not produce NaN costs that break Sinkhorn.
    M_dist = (text_norm.unsqueeze(0) + audio_norm - cross).clamp_min(0)
    M_dist = torch.sqrt(M_dist)
    M_dist = M_dist / M_dist.max()

    d = ot.sinkhorn(a, b, M_dist, reg=0.03, numItermax=10).cpu().numpy()  # [#audio, #captions]

    rank_list = []
    for index in range(len(audio)):
        inds = np.argsort(d[index])[::-1]  # caption indices, highest mass first
        targets = range(5 * index, 5 * index + 5) if not train_data else (index,)
        # Best (smallest) position among this audio's ground-truth captions.
        rank_list.append(min(np.where(inds == t)[0][0] for t in targets))
    preds = np.array(rank_list)

    medr = np.floor(np.median(preds)) + 1
    meanr = preds.mean() + 1
    r1 = np.mean(preds < 1) * 100
    r5 = np.mean(preds < 5) * 100
    r10 = np.mean(preds < 10) * 100
    r50 = np.mean(preds < 50) * 100
    return r1, r5, r10, r50, medr, meanr


def t2a_ot(audio_embs, cap_embs, M, train_data=False, use_float=False):
    """Text-to-audio retrieval ranked by an entropic optimal-transport plan.

    The cost between every caption and candidate audio is built from ``M``.
    For a per-dimension weight vector it equals
    ``sqrt(sum_k M_k * (c_k - a_k) ** 2)``, divided by ``max + 0.1``. A plan
    is solved with ``ot.sinkhorn`` under uniform marginals, and each caption
    ranks the audio clips by the mass it sends to them.

    Args:
        audio_embs: audio embeddings. When ``train_data`` is False, every clip
            is expected to appear five times in a row and only every fifth row
            is a candidate.
        cap_embs: caption embeddings. With ``train_data`` False, captions
            ``5*i .. 5*i+4`` query clip ``i``; otherwise caption ``i`` matches
            audio ``i``.
        M: tensor multiplied element-wise with the embeddings. It is presumably
            a per-dimension weight vector (TODO confirm against the caller).
        train_data: selects the ground-truth layout described above.
        use_float: unused; kept for interface compatibility.

    Returns:
        (r1, r5, r10, r50, medr, meanr): recalls in percent, and median and
        mean rank (1-based).
    """
    if not train_data:
        audio = [audio_embs[i] for i in range(0, len(audio_embs), 5)]
    else:
        audio = audio_embs
    a = torch.ones(len(cap_embs)) / len(cap_embs)
    b = torch.ones(len(audio)) / len(audio)

    cap_embs = torch.from_numpy(np.array(cap_embs)).float()
    audio = torch.from_numpy(np.array(audio)).float()

    # Weighted (Mahalanobis-style) squared distance via the norm expansion.
    M = M.float().cpu()
    text_norm = (cap_embs ** 2 * M).sum(dim=1, keepdim=True)
    audio_norm = (audio ** 2 * M).sum(dim=1)
    cross = 2 * torch.matmul(cap_embs * M, audio.T)
    # Cancellation can leave tiny negative values; clamp before sqrt to avoid NaN.
    M_dist = (text_norm + audio_norm.unsqueeze(0) - cross).clamp_min(0)
    M_dist = torch.sqrt(M_dist)
    M_dist = M_dist / (M_dist.max() + 0.1)

    d = ot.sinkhorn(a, b, M_dist, reg=0.03, numItermax=10).cpu().numpy()  # [#captions, #audio]

    rank_list = []
    for index in range(len(audio)):
        # Caption rows whose ground truth is audio `index`.
        queries = range(5 * index, 5 * index + 5) if not train_data else (index,)
        for i in queries:
            inds = np.argsort(d[i])[::-1]  # audio indices, highest mass first
            rank_list.append(np.where(inds == index)[0][0])
    preds = np.array(rank_list)

    medr = np.floor(np.median(preds)) + 1
    meanr = preds.mean() + 1
    r1 = np.mean(preds < 1) * 100
    r5 = np.mean(preds < 5) * 100
    r10 = np.mean(preds < 10) * 100
    r50 = np.mean(preds < 50) * 100

    return r1, r5, r10, r50, medr, meanr

def a2t_t2a_ot(audio_embs, cap_embs, M, train_data=False, use_float=False):
    """Audio-to-text retrieval using an OT plan over ALL audio rows and captions.

    Unlike ``a2t_ot``, every row of ``audio_embs`` takes part in the transport
    problem. The cost is built from ``M``; for a per-dimension weight vector it
    equals ``sqrt(sum_k M_k * (a_k - c_k) ** 2)``, scaled so the maximum is 1.

    Args:
        audio_embs: audio embeddings. When ``train_data`` is False, rows
            ``5*i .. 5*i+4`` are expected to be copies of clip ``i``.
        cap_embs: caption embeddings. With ``train_data`` False, captions
            ``5*i .. 5*i+4`` are the ground truth of clip ``i``; otherwise
            caption ``i`` matches audio row ``i``.
        M: tensor multiplied element-wise with the embeddings. It is presumably
            a per-dimension weight vector (TODO confirm against the caller).
        train_data: selects the ground-truth layout described above.
        use_float: unused; kept for interface compatibility.

    Returns:
        (r1, r5, r10, r50, medr, meanr): recalls in percent, and median and
        mean rank (1-based).
    """
    a = torch.ones(len(audio_embs)) / len(audio_embs)
    b = torch.ones(len(cap_embs)) / len(cap_embs)

    audio = torch.tensor(np.array(audio_embs)).float()
    cap_embs = torch.tensor(np.array(cap_embs)).float()

    # Weighted (Mahalanobis-style) squared distance via the norm expansion.
    M = M.float().cpu()
    audio_norm = (audio ** 2 * M).sum(dim=1, keepdim=True)
    text_norm = (cap_embs ** 2 * M).sum(dim=1)
    cross = 2 * torch.matmul(audio * M, cap_embs.T)
    # Cancellation can leave tiny negative values; clamp before sqrt to avoid NaN.
    M_dist = (text_norm.unsqueeze(0) + audio_norm - cross).clamp_min(0)
    M_dist = torch.sqrt(M_dist)
    M_dist = M_dist / M_dist.max()

    d = ot.sinkhorn(a, b, M_dist, reg=0.03, numItermax=10).cpu().numpy()  # [#audio, #captions]

    rank_list = []
    if not train_data:
        for index in range(audio.size(0) // 5):
            # Sort each of the clip's five rows in descending order, so
            # inds[r, p] is the original index of the caption ranked p-th.
            inds = np.argsort(d[5 * index:5 * index + 5])[:, ::-1]
            # Best position of any ground-truth caption across the five rows.
            hits = np.isin(inds, np.arange(5 * index, 5 * index + 5))
            rank_list.append(np.where(hits)[1].min())
    else:
        # Use the converted tensor: audio_embs may be a NumPy array or list,
        # where `.size` is an int attribute rather than a method.
        for i in range(audio.size(0)):
            inds = np.argsort(d[i])[::-1]
            rank_list.append(np.where(inds == i)[0][0])
    preds = np.array(rank_list)

    medr = np.floor(np.median(preds)) + 1
    meanr = preds.mean() + 1
    r1 = np.mean(preds < 1) * 100
    r5 = np.mean(preds < 5) * 100
    r10 = np.mean(preds < 10) * 100
    r50 = np.mean(preds < 50) * 100
    return r1, r5, r10, r50, medr, meanr
            


def t2a_ot_bilinear(audio_embs, cap_embs, M=None, train_data=False):
    """Text-to-audio retrieval from a Sinkhorn plan over a bilinear or Euclidean cost.

    With ``M`` given, the cost matrix is ``cap_embs @ M @ audio.T``; otherwise
    it is the pairwise Euclidean distance. Either way it is divided by its
    maximum before ``ot.sinkhorn`` is solved with uniform marginals. Each
    caption then ranks the audio clips by transport mass.

    Args:
        audio_embs: audio embeddings. When ``train_data`` is False, every clip
            is expected to appear five times in a row and only every fifth row
            is a candidate.
        cap_embs: caption embeddings. With ``train_data`` False, captions
            ``5*i .. 5*i+4`` query clip ``i``; otherwise caption ``i`` matches
            audio ``i``.
        M: optional [dim, dim] bilinear matrix (moved to CPU before use).
        train_data: selects the ground-truth layout described above.

    Returns:
        (r1, r5, r10, r50, medr, meanr): recalls in percent, and median and
        mean rank (1-based).
    """
    if not train_data:
        audio = [audio_embs[i] for i in range(0, len(audio_embs), 5)]
    else:
        audio = audio_embs

    cap_embs = torch.tensor(np.array(cap_embs))
    audio = torch.tensor(np.array(audio))
    if M is not None:
        M = M.to("cpu")
        # NOTE(review): the bilinear score is used directly as a transport
        # cost, so larger values are penalized. Confirm M was trained with
        # that sign convention.
        M_dist = torch.matmul(torch.matmul(cap_embs, M), audio.transpose(0, 1))
    else:
        M_dist = torch.cdist(cap_embs, audio, p=2)
    M_dist = M_dist / M_dist.max()

    a = torch.ones(len(cap_embs)) / len(cap_embs)
    b = torch.ones(len(audio)) / len(audio)
    a = a.to(audio.device)
    b = b.to(audio.device)
    s = ot.sinkhorn(a, b, M_dist, reg=0.03, numItermax=10).cpu().numpy()  # [#captions, #audio]

    rank_list = []
    for index in range(len(audio)):
        # Caption rows whose ground truth is audio `index`.
        queries = range(5 * index, 5 * index + 5) if not train_data else (index,)
        for i in queries:
            # Both branches read the plan `s`; the train branch previously
            # referenced an undefined name `d` and raised NameError.
            inds = np.argsort(s[i])[::-1]  # audio indices, highest mass first
            rank_list.append(np.where(inds == index)[0][0])
    preds = np.array(rank_list)

    medr = np.floor(np.median(preds)) + 1
    meanr = preds.mean() + 1
    r1 = np.mean(preds < 1) * 100
    r5 = np.mean(preds < 5) * 100
    r10 = np.mean(preds < 10) * 100
    r50 = np.mean(preds < 50) * 100
    return r1, r5, r10, r50, medr, meanr

def a2t_ot_bilinear(audio_embs, cap_embs, M=None, train_data=False):
    """Audio-to-text retrieval from a Sinkhorn plan over a bilinear or Euclidean cost.

    With ``M`` given, the cost matrix is ``audio @ M @ cap_embs.T``; otherwise
    it is the pairwise Euclidean distance. Either way it is divided by its
    maximum before ``ot.sinkhorn`` is solved with uniform marginals. Each
    audio then ranks captions by transport mass.

    Args:
        audio_embs: audio embeddings. When ``train_data`` is False, every clip
            is expected to appear five times in a row and only every fifth row
            is used as a query.
        cap_embs: caption embeddings. With ``train_data`` False, captions
            ``5*i .. 5*i+4`` are the ground truth of clip ``i``; otherwise
            caption ``i`` matches audio ``i``.
        M: optional [dim, dim] bilinear matrix (moved to CPU before use).
        train_data: selects the ground-truth layout described above.

    Returns:
        (r1, r5, r10, r50, medr, meanr): recalls in percent, and median and
        mean rank (1-based).
    """
    if not train_data:
        audio = [audio_embs[i] for i in range(0, len(audio_embs), 5)]
    else:
        audio = audio_embs

    cap_embs = torch.tensor(np.array(cap_embs))
    audio = torch.tensor(np.array(audio))
    if M is not None:
        M = M.to("cpu")
        # NOTE(review): the bilinear score is used directly as a transport
        # cost; confirm the sign convention M was trained with.
        M_dist = torch.matmul(torch.matmul(audio, M), cap_embs.transpose(0, 1))
    else:
        M_dist = torch.cdist(audio, cap_embs, p=2)
    M_dist = M_dist / M_dist.max()

    a = torch.ones(len(audio)) / len(audio)
    b = torch.ones(len(cap_embs)) / len(cap_embs)
    a = a.to(audio.device)
    b = b.to(audio.device)
    s = ot.sinkhorn(a, b, M_dist, reg=0.03, numItermax=10).cpu().numpy()  # [#audio, #captions]

    rank_list = []
    for index in range(len(audio)):
        # Descending sort: inds[p] is the original index of the caption ranked p-th.
        inds = np.argsort(s[index])[::-1]
        targets = range(5 * index, 5 * index + 5) if not train_data else (index,)
        # Best (smallest) position among this audio's ground-truth captions.
        rank_list.append(min(np.where(inds == t)[0][0] for t in targets))
    preds = np.array(rank_list)

    medr = np.floor(np.median(preds)) + 1
    meanr = preds.mean() + 1
    r1 = np.mean(preds < 1) * 100
    r5 = np.mean(preds < 5) * 100
    r10 = np.mean(preds < 10) * 100
    r50 = np.mean(preds < 50) * 100
    return r1, r5, r10, r50, medr, meanr

def a2t_and_t2a(audio_embs, cap_embs, M=None, train_data=False):
    """Audio-to-text retrieval using a Sinkhorn plan over ALL audio rows and captions.

    With ``M`` given, the cost matrix is ``audio_embs @ M @ cap_embs.T``;
    otherwise it is the pairwise Euclidean distance. Either way it is divided
    by its maximum before ``ot.sinkhorn`` is solved with uniform marginals.

    Args:
        audio_embs: audio embeddings. When ``train_data`` is False, rows
            ``5*i .. 5*i+4`` are expected to be copies of clip ``i``.
        cap_embs: caption embeddings. With ``train_data`` False, captions
            ``5*i .. 5*i+4`` are the ground truth of clip ``i``; otherwise
            caption ``i`` matches audio row ``i``.
        M: optional [dim, dim] bilinear matrix (moved to CPU before use).
        train_data: selects the ground-truth layout described above.

    Returns:
        (r1, r5, r10, r50, medr, meanr): recalls in percent, and median and
        mean rank (1-based).
    """
    audio_embs, cap_embs = torch.tensor(np.array(audio_embs)), torch.tensor(np.array(cap_embs))
    if M is not None:
        M = M.to("cpu")
        # NOTE(review): the bilinear score is used directly as a transport
        # cost; confirm the sign convention M was trained with.
        M_dist = torch.matmul(torch.matmul(audio_embs, M), cap_embs.transpose(0, 1))
    else:
        M_dist = torch.cdist(audio_embs, cap_embs, p=2)
    M_dist = M_dist / M_dist.max()

    n_audio, n_caps = audio_embs.size(0), cap_embs.size(0)
    # Uniform marginals. b is sized by the caption count so the problem stays
    # well-formed when the numbers of audio rows and captions differ.
    a = torch.ones(n_audio) / n_audio
    b = torch.ones(n_caps) / n_caps
    a = a.to(audio_embs.device)
    b = b.to(audio_embs.device)
    s = ot.sinkhorn(a, b, M_dist, reg=0.03, numItermax=10).cpu().numpy()  # [#audio, #captions]

    rank_list = []
    if not train_data:
        for index in range(n_audio // 5):
            # Sort each of the clip's five rows in descending order, so
            # inds[r, p] is the original index of the caption ranked p-th.
            inds = np.argsort(s[5 * index:5 * index + 5])[:, ::-1]
            # Best position of any ground-truth caption across the five rows.
            hits = np.isin(inds, np.arange(5 * index, 5 * index + 5))
            rank_list.append(np.where(hits)[1].min())
    else:
        for i in range(n_audio):
            inds = np.argsort(s[i])[::-1]
            rank_list.append(np.where(inds == i)[0][0])
    preds = np.array(rank_list)

    medr = np.floor(np.median(preds)) + 1
    meanr = preds.mean() + 1
    r1 = np.mean(preds < 1) * 100
    r5 = np.mean(preds < 5) * 100
    r10 = np.mean(preds < 10) * 100
    r50 = np.mean(preds < 50) * 100
    return r1, r5, r10, r50, medr, meanr

def visual_plan(d):
    """Save a heat map of the top-left 30x30 corner of ``d`` to ``cosine-ot.png``.

    Rows are labelled "Audio" and columns "Caption".

    Args:
        d: 2-D array-like matrix (e.g. a transport plan or similarity matrix).
    """
    # matshow opens its own figure. The previous extra pl.subplots() call
    # created a second, never-used figure on every invocation.
    pl.matshow(d[:30, :30])
    pl.xlabel("Caption", fontsize=20)
    pl.ylabel("Audio", fontsize=20)
    pl.savefig("cosine-ot.png")
    # Release the figure so repeated calls do not accumulate open figures.
    pl.close()

def visual_true_plan(d):
    """Save a heat map of the ground-truth audio/caption plan to ``cosine-true-plan.png``.

    Row ``i`` (audio ``i``) is marked as matching caption columns
    ``5*i .. 5*i+4``. The matrix is divided by its number of entries, and its
    top-left 30x30 corner is plotted.

    Args:
        d: matrix whose ``shape`` sets the plan size; its values are not read.
    """
    n_rows, n_cols = d.shape[0], d.shape[1]
    true_pi = torch.zeros(n_rows, n_cols)
    for i in range(n_rows):
        # Column slices past the last caption are empty, so late rows stay zero.
        true_pi[i, i * 5:i * 5 + 5] = 1
    true_pi = true_pi / (n_rows * n_cols)

    pl.matshow(true_pi.cpu().numpy()[:30, :30])
    pl.xlabel("Caption", fontsize=20)
    pl.ylabel("Audio", fontsize=20)
    pl.savefig("cosine-true-plan.png")
    # Release the figure so repeated calls do not accumulate open figures.
    pl.close()


if __name__ == '__main__':
    # Smoke run: render both plots from one random 32x32 matrix.
    demo_matrix = torch.rand(32, 32)
    for plot in (visual_plan, visual_true_plan):
        plot(demo_matrix)