
import os
import json
from typing import List
from dataclasses import dataclass, field, asdict

import numpy as np
from tqdm import tqdm
from transformers import HfArgumentParser
import datasets
from PIL import Image
from PIL import ImageFile

from index_builder import Indexer
from retriever import DenseRetriever
from retrieval_evalator import Evaluator

from encoder_inference import InferenceModelMultiGPU
import logging 
import random
import time


# Let PIL open arbitrarily large images (no decompression-bomb pixel limit)
# and load truncated image files instead of raising on them.
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Module logger: INFO and above are written both to ./evaluate.log and to the
# console stream handler (stderr by default).
logger = logging.getLogger('evaluate_logger')
logger.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler('evaluate.log')
ch = logging.StreamHandler()
for _handler in (fh, ch):
    _handler.setLevel(logging.INFO)
    _handler.setFormatter(formatter)
    logger.addHandler(_handler)
del _handler

@dataclass
class EvalArgs:
    """Command-line options for the retrieval evaluation pipeline.

    Parsed with ``HfArgumentParser``. Per-task entries in the task-list JSON
    (``tasklist_path``) override the ``doc_*``, ``query_*`` and
    ``image_prefix_path`` defaults below.
    """
    model_name_or_path: str = field(
        default=None,
        metadata={'help': 'Name or path of Embedding encoder (used when the model config has none)'}
    )
    model_config_path: str = field(
        default=None,
        metadata={'help': 'Path to the JSON config used to build the encoder'}
    )
    reranker: str = field(default=None, metadata={'help': 'path to reranker model'})
    rank_result_save_dir: str = field(default=None, metadata={'help': 'path to save rank results'})
    # Read by the pipeline when a reranker is set; previously missing, which made
    # every `--reranker` run fail with AttributeError.
    output_logits: bool = field(
        default=False,
        metadata={'help': 'Use the logits-saving reranker (requires --reranker)'}
    )
    logits_save_path: str = field(
        default=None,
        metadata={'help': 'Where the logits-saving reranker writes its logits'}
    )
    task_name: str = field(
        default=None,
        metadata={'help': 'task name: `alltask`, one or more comma-separated task names, '
                          'or unset for the default task selection'}
    )
    task_type: str = field(
        default=None,
        metadata={'help': 'task type'}
    )
    batch_size: int = field(
        default=8,
        metadata={'help': 'Search batch size.'}
    )
    max_length: int = field(default=512,
        metadata={'help': 'Maximum length of input text.'}
    )
    only_summary: bool = field(
        default=False,
        metadata={'help': 'Only evaluate existing search/rank results and write the summary'}
    )
    overwrite: bool = field(
        default=False,
        metadata={'help': 'Whether to overwrite embedding'}
    )
    doc_id_field: str = field(default="did", metadata={'help': 'Doc id field. Default is `did`'})
    doc_text_field: List[str] = field(
        default_factory=lambda: ["txt"],
        metadata={'help': 'Doc text field(s); multiple fields are joined with newlines. Default is `txt`'}
    )
    doc_image_path_field: str = field(default="img_path", metadata={'help': 'Doc image path field. Default is `img_path`'})
    result_save_dir: str = field(
        default='./search_results',
        metadata={'help': 'Dir to save evaluation results (one JSON per task plus summary.tsv).'}
    )
    # NOTE: despite the `.txt` default this is used as a directory: per-task run
    # files are written to `<search_result_save_dir>/<task>.txt`.
    search_result_save_dir: str = field(
        default='./search_results/dense.txt',
        metadata={'help': 'Dir to save dense search results (one `<task>.txt` per task).'}
    )
    index_save_dir: str = field(
        default='./index/',
        metadata={'help': 'Dir to save dense indexes; each task uses `<index_save_dir>/<task_name>`.'}
    )
    threads: int = field(
        default=1,
        metadata={'help': 'Maximum threads to use during search'}
    )
    query_id_field: str = field(
        default='qid',
        metadata={'help': 'Query id field. Default is `qid`'}
    )
    query_text_field: str = field(
        default='query_txt',
        metadata={'help': 'Query text field. Default is `query_txt`'}
    )
    query_image_path_field: str = field(
        default='query_img_path',
        metadata={'help': 'Query image path field. Default is `query_img_path`'}
    )
    image_prefix_path: str = field(
        default="",
        metadata={'help': 'Prefix joined to relative doc/query image paths'}
    )
    lang: str = field(
        default='zh',
        metadata={'help': 'Language'}
    )
    faiss_gpu: bool = field(
        default=True,
        metadata={'help': 'Whether to use gpu for faiss'}
    )
    topk: int = field(
        default=100,
        metadata={'help': 'topk number'}
    )
    inference_type: str = field(
        default='yes_or_no',
        metadata={'help': 'inference type passed to the reranker'}
    )
    tasklist_path: str = field(default=None, metadata={'help': 'path to the task-list JSON'})


class RetrievalPipeline():
    """Dense-retrieval evaluation driver: index, search, optionally rerank, and score tasks.

    Tasks come from the JSON file at ``args.tasklist_path``, a mapping from task
    name to a config dict. Keys read here are:

    - required: ``corpus_path``, ``query_path``, ``qrel_path``, ``task_type``
      and ``main_metric``;
    - optional: ``prompts``;
    - optional per-task overrides of the ``doc_*`` / ``query_*`` /
      ``image_prefix_path`` options of ``EvalArgs``.
    """

    def __init__(self, args):
        """Load configs and the task list, start the encoder (and optional reranker).

        Args:
            args: parsed ``EvalArgs``.

        Raises:
            ValueError: if no encoder name/path is given in the model config or args.
        """
        self.args = args
        # Read the small config files first so a bad path fails before the
        # encoder is started.
        with open(self.args.model_config_path) as f:
            self.model_config = json.load(f)
        self.model_config['max_length'] = args.max_length
        with open(self.args.tasklist_path) as f:
            self.task_dict = json.load(f)

        model_name_or_path = self.model_config.get('model_name_or_path') or self.args.model_name_or_path
        if not model_name_or_path:
            raise ValueError('`model_name_or_path` must be set in the model config or passed as an argument')
        self.encoder_name = self._derive_encoder_name(model_name_or_path)
        self.index_save_dir = self._dir_or_default(self.args.index_save_dir, 'index')
        self.search_result_save_dir = self._dir_or_default(self.args.search_result_save_dir, 'search_results')
        self.result_save_dir = self._dir_or_default(self.args.result_save_dir, 'results')
        self.rank_result_save_dir = self._dir_or_default(self.args.rank_result_save_dir, 'rank_results')

        self.encoder = InferenceModelMultiGPU(self.model_config)
        # Remember the encoder's own instruction so tasks without 'prompts' do not
        # inherit the prompt chosen for a previous task.
        self._default_instruction = getattr(self.encoder, 'instruction', None)
        logger.info('load encoder finished')

        try:
            self.ranker = None
            if self.args.reranker is not None:
                # getattr: these options may be absent from older EvalArgs versions.
                if getattr(self.args, 'output_logits', False):
                    from ranker import RankerForLogits
                    self.ranker = RankerForLogits(self.args.reranker, self.args.inference_type,
                                                  getattr(self.args, 'logits_save_path', None))
                else:
                    from ranker import Ranker
                    self.ranker = Ranker(self.args.reranker, self.args.inference_type, self.args.max_length)
            self._default_rank_instruction = getattr(getattr(self.ranker, 'reranker', None), 'instruction', None)

            self.index_builder = Indexer(encoder=self.encoder, faiss_gpu=args.faiss_gpu)
            self.evaluater = Evaluator()
        except Exception:
            # Do not leave the already-started encoder running if later setup fails.
            self.encoder.stop()
            raise

    @staticmethod
    def _derive_encoder_name(model_name_or_path):
        """Return a short name for the encoder used to name default output dirs.

        Trailing slashes are ignored; a ``.../<model>/checkpoint-N`` path becomes
        ``<model>_checkpoint-N``.
        """
        path = model_name_or_path.rstrip('/')
        base = os.path.basename(path)
        if base.startswith('checkpoint-'):
            return os.path.basename(os.path.dirname(path)) + '_' + base
        return base

    def _dir_or_default(self, configured, default_root):
        """Return ``configured`` when set, else ``<default_root>/<encoder_name>``."""
        if configured is not None:
            return configured
        return os.path.join(default_root, self.encoder_name)

    @staticmethod
    def _apply_instruction(model, prompt, default):
        """Set ``model.instruction`` to ``prompt``.

        If ``prompt`` is None, restore ``default`` instead, but only when the
        model already has an ``instruction`` attribute.
        """
        if prompt is not None:
            model.instruction = prompt
        elif hasattr(model, 'instruction'):
            model.instruction = default

    def stop(self):
        """Shut down the encoder and, if one was loaded, the reranker."""
        # Short pause before shutdown (kept from the original implementation).
        time.sleep(1)
        self.encoder.stop()
        if self.ranker:
            self.ranker.stop()

    def read_corpus(self, corpus_path: str, image_prefix_path, doc_id_field, doc_text_fields, image_path_field, num_proc: int = 16):
        """Load a JSON/JSON-lines corpus and add normalised columns.

        The added columns are ``_id``, ``text``, ``image`` and ``length``.

        Args:
            corpus_path: file loaded with the ``datasets`` 'json' builder.
            image_prefix_path: prefix joined to each non-null image path.
            doc_id_field: column holding the document id.
            doc_text_fields: column name, or list of names. Non-null values are
                joined with newlines; an empty result becomes None.
            image_path_field: column holding the image path; if absent, all images are None.
            num_proc: worker processes for ``Dataset.map``.

        Returns:
            ``(docids, corpus)``: the list of ids and the mapped dataset.

        Raises:
            FileNotFoundError: if ``corpus_path`` does not exist.
        """
        # A single field name given as a string (e.g. from the task JSON) must not
        # be iterated character by character.
        if isinstance(doc_text_fields, str):
            doc_text_fields = [doc_text_fields]

        # NOTE: this closure is pickled/hashed by `datasets`, so it must not capture `self`.
        def preprocess_function(examples):
            did_list = examples[doc_id_field]
            columns = [examples[name] for name in doc_text_fields if name in examples]
            content_list = []
            for i in range(len(did_list)):
                joined = '\n'.join(col[i] for col in columns if col[i] is not None)
                content_list.append(joined or None)
            length_list = [len(text) if text else 0 for text in content_list]
            if image_path_field in examples:
                image_list = [None if path is None else os.path.join(image_prefix_path, path)
                              for path in examples[image_path_field]]
            else:
                image_list = [None] * len(did_list)
            return {'_id': did_list, 'text': content_list, 'image': image_list, 'length': length_list}

        if not os.path.exists(corpus_path):
            raise FileNotFoundError(f"{corpus_path} not found")
        corpus = datasets.load_dataset('json', data_files=corpus_path)['train']
        corpus = corpus.map(preprocess_function, batched=True, num_proc=num_proc)
        # Column access avoids decoding every full row just to collect the ids.
        docids = list(corpus['_id'])
        return docids, corpus

    def read_queries(self, query_path: str, image_prefix_path, image_path_field, query_id_field: str='_id', query_text_field: str='text', num_proc: int = 8):
        """Load a JSON/JSON-lines query file and add ``_id``, ``text`` and ``image`` columns.

        A missing text or image column yields None for every query.

        Returns:
            ``(qids, queries)``: the list of query ids and the mapped dataset.

        Raises:
            FileNotFoundError: if ``query_path`` does not exist.
        """
        if not os.path.exists(query_path):
            raise FileNotFoundError(f"{query_path} not found")

        queries = datasets.load_dataset('json', data_files=query_path)['train']

        # NOTE: this closure is pickled/hashed by `datasets`, so it must not capture `self`.
        def preprocess_function(examples):
            qid_list = examples[query_id_field]
            if query_text_field in examples:
                text_list = examples[query_text_field]
            else:
                text_list = [None] * len(qid_list)
            if image_path_field in examples:
                image_list = [None if path is None else os.path.join(image_prefix_path, path)
                              for path in examples[image_path_field]]
            else:
                image_list = [None] * len(qid_list)
            return {'_id': qid_list, 'text': text_list, 'image': image_list}

        queries = queries.map(preprocess_function, batched=True, num_proc=num_proc)
        qids = list(queries['_id'])
        logger.info(f'example queries {queries[:5]}')
        return qids, queries

    def get_task_list(self, args):
        """Return the names of the tasks to evaluate.

        The selection depends on ``args.task_name``:

        - ``'alltask'``: every task in the task list.
        - one or more comma-separated names of known tasks: those tasks, in the
          given order. Unknown names are logged and dropped.
        - unset, or no known name: every task whose ``task_type`` is not
          ``'beir'`` and whose name does not contain ``'subtask'``.
        """
        task_name = args.task_name
        if task_name == 'alltask':
            return list(self.task_dict)
        if task_name:
            requested = [name.strip() for name in task_name.split(',') if name.strip()]
            for name in requested:
                if name not in self.task_dict:
                    logger.error(f"task {name} not found")
            known = [name for name in requested if name in self.task_dict]
            if known:
                return known
            logger.warning(f'no known task in task_name={task_name!r}; using the default task list')
        return [key for key, value in self.task_dict.items()
                if value['task_type'] != 'beir' and 'subtask' not in key]

    def _eval_input_path(self, task_name):
        """Return the run file scored for a task.

        This is the reranked results when a reranker is loaded, and the dense
        search results otherwise.
        """
        save_dir = self.rank_result_save_dir if self.ranker is not None else self.search_result_save_dir
        return os.path.join(save_dir, task_name + ".txt")

    def _evaluate_and_save(self, task_dict, run_path, result_save_path):
        """Score ``run_path`` against the task's qrels and write the results.

        The evaluator's result dict is annotated with ``task_dict``,
        ``main_score`` and ``encoder_name``, then written as UTF-8 JSON to
        ``result_save_path``.

        Returns:
            The annotated results dict.
        """
        results = self.evaluater.evaluate(task_dict['qrel_path'], run_path, result_save_path)
        results.update(task_dict)
        results['main_score'] = results[task_dict['main_metric']]
        results['encoder_name'] = self.encoder_name
        result_dir = os.path.dirname(result_save_path)
        if result_dir:
            os.makedirs(result_dir, exist_ok=True)
        # Explicit UTF-8: ensure_ascii=False may emit non-ASCII (e.g. Chinese) text.
        with open(result_save_path, 'w', encoding='utf-8') as fout:
            fout.write(json.dumps(results, ensure_ascii=False, indent=4))
        return results

    def run_summary(self):
        """Re-score existing run files for the selected tasks, then write the summary."""
        task_list = self.get_task_list(self.args)
        for task_name in task_list:
            logger.info(f'evaluating task: {task_name}')
            task_dict = self.task_dict.get(task_name)
            if task_dict is None:
                # Skip just this task rather than abandoning the whole summary.
                logger.error(f"task {task_name} not found")
                continue
            result_save_path = os.path.join(self.result_save_dir, task_name + ".json")
            if not self.args.overwrite and os.path.exists(result_save_path):
                logger.info(f'{result_save_path} has already exists, Skip...')
                continue
            self._evaluate_and_save(task_dict, self._eval_input_path(task_name), result_save_path)
        self.summary(task_list, self.args)

    # Backward-compatible alias for the original (misspelled) name.
    run_sumary = run_summary

    def summary(self, task_list, args):
        """Write ``summary.tsv`` with overall, per-task-type and per-task main scores.

        Scores are read from ``<result_save_dir>/<task>.json``. Here
        ``result_save_dir`` is ``args.result_save_dir``, or the pipeline's
        resolved dir when that is None. Scores are written as percentages
        rounded to 2 decimals. Tasks without a result file are skipped with a
        warning.
        """
        result_save_dir = getattr(args, 'result_save_dir', None)
        if result_save_dir is None:
            result_save_dir = self.result_save_dir

        scores_by_type = {}  # task_type -> [(task, main score)], in first-seen order
        for task in task_list:
            path = os.path.join(result_save_dir, task + '.json')
            if not os.path.exists(path):
                logger.warning(f'{path} not found, excluded from summary')
                continue
            with open(path, encoding='utf-8') as f:
                metrics = json.load(f)
            score = metrics[metrics['main_metric']]
            scores_by_type.setdefault(metrics['task_type'], []).append((task, score))

        all_scores = [score for entries in scores_by_type.values() for _, score in entries]
        if not all_scores:
            logger.warning('no task results found; summary not written')
            return

        names = [f'Avg({len(all_scores)})']
        scores = [sum(all_scores) / len(all_scores)]
        for task_type, entries in scores_by_type.items():
            names.append(f'{task_type}({len(entries)})')
            scores.append(sum(score for _, score in entries) / len(entries))
        for task_type, entries in scores_by_type.items():
            for task, score in entries:
                names.append(f"{task}({task_type})")
                scores.append(score)
        with open(os.path.join(result_save_dir, 'summary.tsv'), 'w', encoding='utf-8') as fout:
            fout.write('\t'.join(names) + '\n')
            fout.write('\t'.join(str(round(score * 100, 2)) for score in scores) + '\n')

    def run(self):
        """Run every selected task, then write the summary."""
        task_list = self.get_task_list(self.args)
        for task in task_list:
            self._run(task)
        self.summary(task_list, self.args)

    def _run(self, task_name):
        """Evaluate one task.

        The steps are: index the corpus, search the queries, optionally rerank,
        then score. The task is skipped if its result JSON already exists,
        unless ``overwrite`` is set.
        """
        logger.info(f'running task: {task_name}')
        task_dict = self.task_dict.get(task_name)
        if task_dict is None:
            logger.error(f"task {task_name} not found")
            return
        result_save_path = os.path.join(self.result_save_dir, task_name + ".json")
        if not self.args.overwrite and os.path.exists(result_save_path):
            logger.info(f'{result_save_path} has already exists, Skip...')
            return

        prompts = task_dict.get('prompts')
        # A task without prompts gets the encoder's default instruction, not the
        # prompt randomly chosen for the previous task.
        self._apply_instruction(self.encoder, random.choice(prompts) if prompts else None,
                                self._default_instruction)

        logger.info('start loading corpus')
        image_prefix_path = task_dict.get('image_prefix_path', self.args.image_prefix_path)
        docids, corpus = self.read_corpus(task_dict['corpus_path'],
            image_prefix_path=image_prefix_path,
            doc_id_field=task_dict.get('doc_id_field', self.args.doc_id_field),
            doc_text_fields=task_dict.get('doc_text_field', self.args.doc_text_field),
            image_path_field=task_dict.get('doc_image_path_field', self.args.doc_image_path_field))
        logger.info('start loading queries')
        qids, queries = self.read_queries(task_dict['query_path'],
            image_prefix_path=image_prefix_path,
            query_id_field=task_dict.get('query_id_field', self.args.query_id_field),
            query_text_field=task_dict.get('query_text_field', self.args.query_text_field),
            image_path_field=task_dict.get('query_image_path_field', self.args.query_image_path_field))

        logger.info('start building index')
        index_save_dir = os.path.join(self.index_save_dir, task_name)
        batch_size = self.args.batch_size
        self.index_builder.build_index(docids, corpus, index_save_dir=index_save_dir, index_type='dense',
                                       batch_size=batch_size, overwrite=self.args.overwrite)

        logger.info('start searching index')
        search_result_save_path = os.path.join(self.search_result_save_dir, task_name + ".txt")
        retriever = DenseRetriever(encoder=self.encoder, faiss_gpu=self.args.faiss_gpu, dense_index_save_dir=index_save_dir)
        search_results = retriever.retrieval(qids, queries, topk=self.args.topk, batch_size=batch_size,
                                             result_save_path=search_result_save_path, overwrite=self.args.overwrite)

        if self.ranker is not None:
            self._rerank(task_name, prompts, docids, corpus, qids, queries, search_results, batch_size)

        self._evaluate_and_save(task_dict, self._eval_input_path(task_name), result_save_path)

    def _rerank(self, task_name, prompts, docids, corpus, qids, queries, search_results, batch_size):
        """Rerank each query's dense hits and write them to the task's rank-result file.

        Does nothing if that file exists and ``overwrite`` is not set.
        """
        logger.info('start ranking')
        # The reranker always uses the task's first prompt (if any).
        self._apply_instruction(self.ranker.reranker, prompts[0] if prompts else None,
                                self._default_rank_instruction)
        rank_result_save_path = os.path.join(self.rank_result_save_dir, task_name + ".txt")
        if not self.args.overwrite and os.path.exists(rank_result_save_path):
            logger.info(f'{rank_result_save_path} has already exists, Skip...')
            return
        corpus_dict = {str(_id): doc for _id, doc in zip(docids, corpus)}
        # Each hit in search_results[qid] carries the doc id at position 0.
        dids_list = [[str(hit[0]) for hit in search_results[qid]] for qid in qids]
        docs_list = [[corpus_dict[did] for did in dids] for dids in dids_list]
        self.ranker.rank(qids, queries, dids_list, docs_list, topk=self.args.topk,
                         result_save_path=rank_result_save_path, batch_size=batch_size)


        
if __name__ == '__main__':
    parser = HfArgumentParser([EvalArgs])
    eval_args = parser.parse_args_into_dataclasses()[0]
    pipeline = RetrievalPipeline(eval_args)
    try:
        if eval_args.only_summary:
            pipeline.run_sumary()
        else:
            pipeline.run()
    finally:
        # Always release the encoder/reranker via stop(), even if a task raises.
        pipeline.stop()
