import argparse
import sys
import os

# Append the directory one level above this file's own directory to sys.path.
# NOTE(review): presumably this lets the `evaluate.dataset` / `open_clip` imports below
# resolve when the file is executed directly as a script — confirm against repo layout.
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)

import collections
import json
from typing import Union, Tuple, List

import numpy as np
import torch
import torch.nn as nn
import tqdm
from torch.utils.data import DataLoader

from evaluate.dataset import TextDataset, ImgTSVDataset
import open_clip


class ModelWrapper(nn.Module):
    """Expose one named method of a model as ``forward`` (for DataParallel multi-gpu testing).

    Calling the wrapper with keyword arguments looks up ``forward_func`` on the wrapped
    model and invokes it with those same keyword arguments.
    """

    def __init__(self, model: nn.Module, forward_func: str = 'forward'):
        """Store the wrapped ``model`` and the name of the method to dispatch to."""
        super().__init__()
        self.forward_func = forward_func
        self.model = model

    def forward(self, **kwargs):
        """Call ``model.<forward_func>(**kwargs)`` and return its result unchanged."""
        target = getattr(self.model, self.forward_func)
        return target(**kwargs)


class DefaultDataLoader(DataLoader):
    """``DataLoader`` with evaluation-friendly defaults (no shuffling, 4 workers).

    Args:
        dataset: Dataset to load from.
        batch_size: Number of samples per batch.
        shuffle: Whether to shuffle; forced to ``False`` when ``sampler`` is given,
            because ``DataLoader`` rejects a sampler combined with shuffling.
        sampler: Optional sampler that defines the iteration order.
        num_workers: Number of worker processes; ``0`` loads in the main process.
        pin_memory: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.
        prefetch_factor: Batches prefetched per worker. Only forwarded when
            ``num_workers > 0``; it is ignored for in-process loading.
        persistent_workers: Forwarded to ``DataLoader``.
    """

    def __init__(self,
                 dataset,
                 batch_size=32,
                 shuffle=False,
                 sampler=None,
                 num_workers=4,
                 pin_memory=False,
                 drop_last=False,
                 prefetch_factor=2,
                 persistent_workers=False):
        if sampler is not None:
            shuffle = False

        # Recent PyTorch releases raise ValueError when prefetch_factor is set while
        # num_workers == 0, so the non-None default here would make in-process loading
        # impossible. Prefetching only applies to worker processes, so omit it then.
        worker_kwargs = {}
        if num_workers > 0:
            worker_kwargs['prefetch_factor'] = prefetch_factor

        super(DefaultDataLoader, self).__init__(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last,
            persistent_workers=persistent_workers,
            **worker_kwargs
        )


class ImageTextRetrievalEvaluator(object):
    """Compute image-to-text and text-to-image recall@k from feature matrices.

    Three JSONL annotation files are loaded at construction time:

    * image records, each read for its ``'img_id'`` key,
    * text records, each read for its ``'text_id'`` key,
    * relation records, each read for both ``'img_id'`` and ``'text_id'``.

    Row ``i`` of the image / text feature matrix is matched to record ``i`` of the
    corresponding annotation list.

    Args:
        img_ann_path: Path to the image JSONL annotation file.
        text_ann_path: Path to the text JSONL annotation file.
        rel_ann_path: Path to the image-text relation JSONL annotation file.
        topk: The ``k`` values for which recall@k is reported.
    """

    # Ranked lists were historically always cut at 10 entries. Keep that as a floor so
    # the returned lists never get shorter than before, while still growing the depth
    # when ``topk`` asks for more than 10.
    _MIN_RANK_DEPTH = 10

    def __init__(self,
                 img_ann_path: str,
                 text_ann_path: str,
                 rel_ann_path: str,
                 topk: Tuple = (1, 5, 10)):
        super(ImageTextRetrievalEvaluator, self).__init__()

        self.img_info_list = self.load_jsonl_annotation(img_ann_path)
        self.text_info_list = self.load_jsonl_annotation(text_ann_path)
        self.rel_info_list = self.load_jsonl_annotation(rel_ann_path)
        self.topk = topk

    def load_jsonl_annotation(self, ann_path: str) -> List[dict]:
        """Read a UTF-8 JSONL file and return one parsed object per non-blank line."""
        ann_list = []
        with open(ann_path, 'r', encoding='utf8') as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                if not line.strip():
                    continue
                ann_list.append(json.loads(line))
        return ann_list

    @staticmethod
    def _l2_normalize(feat: np.ndarray) -> np.ndarray:
        """Scale each row to unit L2 norm, flooring the norm at 1e-7 to avoid dividing by 0."""
        norm = np.sqrt(np.sum(feat ** 2, axis=-1, keepdims=True))
        return feat / np.clip(norm, a_min=1e-7, a_max=np.inf)

    @torch.no_grad()
    def compute_image_text_similarity(self,
                                      image_feat: Union[torch.Tensor, np.ndarray],
                                      text_feat: Union[torch.Tensor, np.ndarray]):
        """Return the cosine-similarity matrix between image rows and text rows.

        Args:
            image_feat: One feature row per entry of ``img_info_list``.
            text_feat: One feature row per entry of ``text_info_list``.

        Returns:
            ``np.ndarray`` whose entry ``[i, j]`` is the similarity of image ``i`` and text ``j``.

        Raises:
            AssertionError: If a feature matrix's row count differs from the length of
                its annotation list.
        """
        if isinstance(image_feat, torch.Tensor):
            image_feat = image_feat.cpu().numpy()
        if isinstance(text_feat, torch.Tensor):
            text_feat = text_feat.cpu().numpy()
        image_feat = image_feat.astype(np.float32)
        text_feat = text_feat.astype(np.float32)

        assert image_feat.shape[0] == len(self.img_info_list), \
            f"{image_feat.shape[0]} vs {len(self.img_info_list)}"
        assert text_feat.shape[0] == len(self.text_info_list), \
            f"{text_feat.shape[0]} vs {len(self.text_info_list)}"

        image_feat = self._l2_normalize(image_feat)
        text_feat = self._l2_normalize(text_feat)

        if torch.cuda.is_available():
            # Do the matrix product on the GPU, then bring the result back as numpy.
            text_feat_tensor = torch.FloatTensor(text_feat).cuda().permute(1, 0).contiguous()
            image_feat_tensor = torch.FloatTensor(image_feat).cuda()
            sim = image_feat_tensor @ text_feat_tensor
            sim = sim.cpu().numpy()
        else:
            sim = image_feat @ text_feat.T
        return sim

    def get_ground_truth_relation(self, query_key: str, doc_key: str):
        """Group relation records into ``{record[query_key]: [record[doc_key], ...]}``."""
        gt_relation = collections.defaultdict(list)
        for info in self.rel_info_list:
            gt_relation[info[query_key]].append(info[doc_key])
        return gt_relation

    def get_predict_relation(self, sim_score: np.ndarray):
        """Rank texts for every image and images for every text by descending similarity.

        Args:
            sim_score: Similarity matrix indexed as ``[image_index, text_index]``.

        Returns:
            ``(pred_relation_i2t, pred_relation_t2i)``: the first maps each ``img_id`` to
            its best-matching ``text_id`` list, the second maps each ``text_id`` to its
            best-matching ``img_id`` list. Lists hold ``max(max(topk), 10)`` entries, or
            fewer when there are not that many candidates.
        """
        # Rank deep enough for the largest requested k; a fixed cut at 10 would make
        # recall@k for k > 10 silently identical to recall@10.
        depth = max(max(self.topk, default=0), self._MIN_RANK_DEPTH)

        pred_relation_i2t = dict()
        pred_relation_t2i = dict()
        for index, img_info in enumerate(self.img_info_list):
            arginds = np.argsort(sim_score[index, :])[::-1][:depth]
            pred_relation_i2t[img_info['img_id']] = [self.text_info_list[i]['text_id'] for i in arginds]
        for index, text_info in enumerate(self.text_info_list):
            arginds = np.argsort(sim_score[:, index])[::-1][:depth]
            pred_relation_t2i[text_info['text_id']] = [self.img_info_list[i]['img_id'] for i in arginds]

        return pred_relation_i2t, pred_relation_t2i

    def calculate_recall(self, pred_relation, gt_relation):
        """Compute recall@k for every ``k`` in ``self.topk``.

        A query counts as a hit at ``k`` when at least one of its ground-truth ids
        appears among its first ``k`` predictions.

        Args:
            pred_relation: Mapping from query id to its ranked list of predicted ids.
            gt_relation: Mapping from query id to its list of ground-truth ids.

        Returns:
            Dict keyed ``'recall@k=<k>'`` with the fraction of hit queries; ``0.0`` for
            every ``k`` when ``gt_relation`` is empty.

        Raises:
            KeyError: If a ground-truth query id has no entry in ``pred_relation``.
        """
        key_list = list(gt_relation.keys())
        total = len(key_list)
        recall = {}
        for k in self.topk:
            hit = 0
            for key in key_list:
                candidates = pred_relation[key][:k]
                if any(gt_id in candidates for gt_id in gt_relation[key]):
                    hit += 1
            # Guard against an empty relation file instead of raising ZeroDivisionError.
            recall[f'recall@k={k}'] = hit / total if total else 0.0
        return recall

    def evaluate(self,
                 text_feat: Union[torch.Tensor, np.ndarray],
                 image_feat: Union[torch.Tensor, np.ndarray]):
        """Run the full evaluation and return recall in both retrieval directions.

        Note the argument order: text features first, image features second.

        Returns:
            Dict with keys ``'i2t_recall@k=<k>'`` and ``'t2i_recall@k=<k>'`` for each
            ``k`` in ``self.topk``.
        """
        image2text_sim = self.compute_image_text_similarity(image_feat, text_feat)
        pred_relation_i2t, pred_relation_t2i = self.get_predict_relation(image2text_sim)
        gt_relation_i2t = self.get_ground_truth_relation('img_id', 'text_id')
        gt_relation_t2i = self.get_ground_truth_relation('text_id', 'img_id')

        i2t_recall = self.calculate_recall(pred_relation_i2t, gt_relation_i2t)
        t2i_recall = self.calculate_recall(pred_relation_t2i, gt_relation_t2i)

        result = dict()
        for k, v in i2t_recall.items():
            result[f'i2t_{k}'] = v
        for k, v in t2i_recall.items():
            result[f't2i_{k}'] = v
        return result


def _encode_dataset(encoder, loader):
    """Run ``encoder`` over every batch of ``loader`` and return unit-normalised features.

    Each batch is passed to the encoder as keyword arguments. The outputs are divided by
    their L2 norm along the last dimension, moved to the CPU and concatenated along
    axis 0 into a single ``np.ndarray``.
    """
    feat_chunks = []
    # Using the bar as a context manager closes it even if a batch raises.
    with torch.no_grad(), tqdm.tqdm(total=len(loader)) as bar:
        for batch in loader:
            feat = encoder(**batch)
            feat /= feat.norm(dim=-1, keepdim=True)
            feat_chunks.append(feat.cpu().numpy())
            bar.update()
    return np.concatenate(feat_chunks, axis=0)


def do_eval_for_one_dataset(model, preprocess, tokenizer,
                            text_info_path, img_info_path, relation_info_path,
                            img_tsv_path, batch_size=32):
    """Encode one dataset's texts and images with ``model`` and compute retrieval recall.

    The model is switched to eval mode and moved to CUDA, so a CUDA device is required.
    Its ``encode_text`` and ``encode_image`` methods are each wrapped in
    ``nn.DataParallel`` over all visible GPUs.

    Args:
        model: Model exposing ``encode_text`` and ``encode_image``.
        preprocess: Image preprocessing callable handed to ``ImgTSVDataset``.
        tokenizer: Tokenizer handed to ``TextDataset``.
        text_info_path: Text annotation file (used by the dataset and the evaluator).
        img_info_path: Image annotation file (used by the dataset and the evaluator).
        relation_info_path: Image-text relation annotation file for the evaluator.
        img_tsv_path: Image TSV file handed to ``ImgTSVDataset``.
        batch_size: Batch size for both data loaders.

    Returns:
        The dict returned by ``ImageTextRetrievalEvaluator.evaluate``.
    """
    # Step 1, prepare the model and one multi-GPU encoder per modality
    model.eval()
    model.cuda()

    device_ids = list(range(torch.cuda.device_count()))
    text_model = nn.DataParallel(
        ModelWrapper(model, forward_func='encode_text').cuda(), device_ids=device_ids)
    img_model = nn.DataParallel(
        ModelWrapper(model, forward_func='encode_image').cuda(), device_ids=device_ids)

    # Step 2, build dataloader
    print("Build text dataloader")
    text_dataset = TextDataset(text_info_path, tokenizer=tokenizer)
    text_loader = DefaultDataLoader(text_dataset, batch_size=batch_size)

    print("Build image dataloader")
    img_dataset = ImgTSVDataset(img_info_path, img_tsv_path, preprocess)
    img_loader = DefaultDataLoader(img_dataset, batch_size=batch_size)

    # Step 3, inference
    text_feat = _encode_dataset(text_model, text_loader)
    img_feat = _encode_dataset(img_model, img_loader)

    # Step 4, evaluate
    print("Evaluate results...")
    evaluator = ImageTextRetrievalEvaluator(
        img_info_path,
        text_info_path,
        relation_info_path
    )
    return evaluator.evaluate(text_feat, img_feat)


def parse_args():
    """Define the evaluation script's command-line options and parse the process arguments.

    Returns:
        ``argparse.Namespace`` with ``model_name``, ``model_checkpoint``,
        ``text_info_path``, ``img_info_path``, ``img_tsv_path``,
        ``relation_info_path`` and ``batch_size``.
    """
    model_choices = [
        'YouCLIP-Base', 'YouCLIP-Base-CN-ENG', 'YouCLIP-Base-512', 'YouCLIP-Base-512-CN-ENG',
        'YouCLIP-Large', 'YouCLIP-Large-CN-ENG', 'YouCLIP-Huge', 'YouCLIP-Huge-CN-ENG',
    ]
    # (flag, help text) for the string options, which all default to None.
    string_options = (
        ('--model_checkpoint', 'checkpoint path. '),
        ('--text_info_path', 'text info path. '),
        ('--img_info_path', 'img info path. '),
        ('--img_tsv_path', 'img tsv path. '),
        ('--relation_info_path', 'relation info path. '),
    )

    cli = argparse.ArgumentParser(description='Zero-shot Image-Text Retrieval. ')
    cli.add_argument('--model_name', default='YouCLIP-Base', choices=model_choices,
                     help='Model size. ')
    for flag, help_text in string_options:
        cli.add_argument(flag, default=None, type=str, help=help_text)
    cli.add_argument('--batch_size', default=32, type=int, help='batch_size for eval')
    return cli.parse_args()


if __name__ == "__main__":
    # Script entry point: parse the CLI, load the model, evaluate one dataset, print recall.
    args = parse_args()

    print('build model...')
    model, preprocess, tokenizer = open_clip.load_YouCLIP(
        model_name=args.model_name,
        model_file_path=args.model_checkpoint,
    )

    # Dataset locations and batch size come straight from the command line.
    dataset_kwargs = dict(
        text_info_path=args.text_info_path,
        img_info_path=args.img_info_path,
        relation_info_path=args.relation_info_path,
        img_tsv_path=args.img_tsv_path,
        batch_size=args.batch_size,
    )
    result = do_eval_for_one_dataset(
        model=model,
        tokenizer=tokenizer,
        preprocess=preprocess,
        **dataset_kwargs
    )
    print(result)
