"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import json
import os

import numpy as np
import torch
from lavis.common.registry import registry
from lavis.tasks.retrieval import RetrievalTask


@registry.register_task("retrieval_mllmu_mixed")
class RetrievalMllmuMixedTask(RetrievalTask):
    """Retrieval task reporting recall overall and on two index-based subsets.

    Besides the overall recall@{1,5,10} numbers, the queries are partitioned
    by position into a leading "dr" subset and a trailing "df" subset, and the
    same metrics are reported for each.
    NOTE(review): "dr"/"df" presumably denote two data splits concatenated in
    the evaluation set (e.g. retain / forget) -- confirm against the dataset.
    """

    @staticmethod
    @torch.no_grad()
    def _report_metrics(
        scores_i2t,
        scores_t2i,
        txt2img,
        img2txt,
        dr_num_images=1000,
        dr_num_texts=5000,
    ):
        """Compute recall metrics, append them to ``evaluate.txt`` and return them.

        Args:
            scores_i2t: 2-D score array iterated row by row; row ``i`` holds the
                scores of image query ``i`` against every text (higher is better).
            scores_t2i: 2-D score array iterated row by row; row ``t`` holds the
                scores of text query ``t`` against every image.
            txt2img: Indexable mapping from a text index to its single matching
                image index.
            img2txt: Indexable mapping from an image index to an iterable of
                matching text indices.
            dr_num_images: Number of leading image queries assigned to the "dr"
                subset; the remaining image queries form the "df" subset.
            dr_num_texts: Number of leading text queries assigned to the "dr"
                subset; the remaining text queries form the "df" subset.

        Returns:
            dict: Overall metrics (``txt_r*``, ``img_r*``, ``r_mean``,
            ``agg_metrics``) plus nested ``dr_metrics`` and ``df_metrics`` dicts
            with the same recall keys. All recalls are percentages. A subset
            with no queries reports 0.0 for every recall instead of raising
            ``ZeroDivisionError``.
        """

        def recall_at(ranks, k):
            # An empty subset has nothing to score; report 0.0 rather than
            # dividing by zero (or by a size the slice does not actually have).
            if ranks.size == 0:
                return 0.0
            return 100.0 * int(np.count_nonzero(ranks < k)) / ranks.size

        def summarize(ranks):
            # Returns (R@1, R@5, R@10, mean of the three) for the given ranks.
            r1, r5, r10 = (recall_at(ranks, k) for k in (1, 5, 10))
            return r1, r5, r10, (r1 + r5 + r10) / 3

        # Images->Text: rank of the best-placed ground-truth text per image.
        i2t_ranks = np.zeros(scores_i2t.shape[0])
        for index, score in enumerate(scores_i2t):
            order = np.argsort(score)[::-1]  # descending by score
            # 1e20 acts as "never retrieved" when an image has no ground truth.
            i2t_ranks[index] = min(
                (np.where(order == gt)[0][0] for gt in img2txt[index]),
                default=1e20,
            )

        # Text->Images: rank of the single ground-truth image per text.
        t2i_ranks = np.zeros(scores_t2i.shape[0])
        for index, score in enumerate(scores_t2i):
            order = np.argsort(score)[::-1]
            t2i_ranks[index] = np.where(order == txt2img[index])[0][0]

        # Overall metrics.
        tr1, tr5, tr10, tr_mean = summarize(i2t_ranks)
        ir1, ir5, ir10, ir_mean = summarize(t2i_ranks)
        r_mean = (tr_mean + ir_mean) / 2

        # Leading ("dr") subset; slicing clamps to the available length, so the
        # denominator is always the true subset size.
        dr_tr1, dr_tr5, dr_tr10, dr_tr_mean = summarize(i2t_ranks[:dr_num_images])
        dr_ir1, dr_ir5, dr_ir10, dr_ir_mean = summarize(t2i_ranks[:dr_num_texts])
        dr_r_mean = (dr_tr_mean + dr_ir_mean) / 2

        # Trailing ("df") subset: everything after the leading subset.
        df_tr1, df_tr5, df_tr10, df_tr_mean = summarize(i2t_ranks[dr_num_images:])
        df_ir1, df_ir5, df_ir10, df_ir_mean = summarize(t2i_ranks[dr_num_texts:])
        df_r_mean = (df_tr_mean + df_ir_mean) / 2

        # The aggregate metric is the mean text-retrieval recall (not r_mean).
        agg_metrics = tr_mean

        eval_result = {
            "txt_r1": tr1,
            "txt_r5": tr5,
            "txt_r10": tr10,
            "txt_r_mean": tr_mean,
            "img_r1": ir1,
            "img_r5": ir5,
            "img_r10": ir10,
            "img_r_mean": ir_mean,
            "r_mean": r_mean,
            "agg_metrics": agg_metrics,
            "dr_metrics": {
                "txt_r1": dr_tr1,
                "txt_r5": dr_tr5,
                "txt_r10": dr_tr10,
                "txt_r_mean": dr_tr_mean,
                "img_r1": dr_ir1,
                "img_r5": dr_ir5,
                "img_r10": dr_ir10,
                "img_r_mean": dr_ir_mean,
                "r_mean": dr_r_mean,
            },
            "df_metrics": {
                "txt_r1": df_tr1,
                "txt_r5": df_tr5,
                "txt_r10": df_tr10,
                "txt_r_mean": df_tr_mean,
                "img_r1": df_ir1,
                "img_r5": df_ir5,
                "img_r10": df_ir10,
                "img_r_mean": df_ir_mean,
                "r_mean": df_r_mean,
            },
        }

        # Append one JSON line per evaluation so successive runs accumulate.
        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(eval_result) + "\n")
        return eval_result
