import datetime
import os.path
import random
import time
import re
import numpy as np
import sys


from test_chatgpt import ChatGPTAPI
from query_strategy import random_query, diversity_based_query, similarity_based_query, base_query_strategy
from benchmark import qmsum_interface
# import tiktoken 
from tensorboardX import SummaryWriter
from lib.metrics import Metric


from lib.retrievers import BaseRetriever, CustomBM25Retriever, EnsembleRetriever, EnsembleRerankRetriever
from lib.base import HuggingfaceEmbeddings
from utils import wikimultihop_demo, wikimultihop_demo_zero_shot

class count:
    """Running tally of scores: number of updates, score total, mean score,
    and wall-clock minutes elapsed since the instance was created."""

    def __init__(self):
        # Reference point for the elapsed time reported by summary().
        self.cur_time = time.time()
        self.all_count = 0
        self.score = 0
        self.avg = 0

    def update(self, score):
        """Accumulate one score and recompute the running mean."""
        self.score += score
        self.all_count += 1
        # A zero total keeps the mean as the integer 0 rather than 0.0.
        self.avg = 0 if self.score == 0 else 1.0 * self.score / self.all_count

    def summary(self):
        """Return a dict with the update count, mean score and elapsed minutes."""
        stats = {
            "total_number": self.all_count,
            "avg_score": self.avg,
        }
        stats["time_consuming"] = (time.time() - self.cur_time) / 60
        return stats

class qmsum_solver():
    """Base solver: runs every QA record of a task through ``ask`` and logs scores.

    Subclasses implement ``ask(qa, whole)`` returning a dict with at least
    ``answer``, ``original_answer`` and ``ground_truth`` (used for scoring) and
    ``problem_with_hints`` (used for logging); ``active_explain`` is logged when
    present.
    """

    def __init__(self, args, solver_name="qmsum_solver"):
        self.chatbot = ChatGPTAPI(model_name=args.model)
        # Second model, pinned to gpt-4; subclasses may use it as an oracle.
        self.oracle = ChatGPTAPI(model_name='gpt-4')
        self.args = args
        self.metric = Metric(self.chatbot)
        # Log directory: logs/<solver_name>/<timestamp><suffix>
        timestamp = datetime.datetime.now().strftime("%y-%m-%d-%H:%M:%S")
        self.logger = SummaryWriter(os.path.join("logs", solver_name + "/" + timestamp) + args.suffix)

    def __post_ask_call(self, qa, **kwargs):
        """Log one answered question and the running metric averages.

        :param qa: the QA record; only ``qa['query']`` is read.
        :param kwargs: the dict returned by ``ask`` plus a ``summary`` entry.
        """
        problem = qa['query']
        # ``active_explain`` is optional: an ``ask`` that skips the active step
        # may not provide it, which used to raise KeyError here.
        active_explain = kwargs.get('active_explain')
        self.logger.add_text(f"output_summary/{self.task_name}", f"(Active) Question: {problem} \n\n Ground truth: {kwargs['ground_truth']} \n\n Problem with hints: {kwargs['problem_with_hints']} \n\n DG answer: {kwargs['original_answer']} \n\n Our answer: {kwargs['answer']} \n\n Explaination: {active_explain} \n\nScore Summary: {kwargs['summary']}", self.metric.active_entity_logger.all_count)

        metric = self.metric
        self.logger.add_scalar("score_avg/original_f1", metric.original_f1_logger.avg['f1'], metric.original_f1_logger.all_count)
        self.logger.add_scalar("score_avg/active_f1", metric.active_f1_logger.avg['f1'], metric.active_f1_logger.all_count)

        self.logger.add_scalar("prompt_acc_score_avg/original_acc", metric.original_prompt_acc_logger.avg, metric.original_prompt_acc_logger.all_count)
        self.logger.add_scalar("prompt_acc_score_avg/active_acc", metric.active_prompt_acc_logger.avg, metric.active_prompt_acc_logger.all_count)

        self.logger.add_scalar("exact_match_score/original", metric.original_em_logger.avg['correct'], metric.original_em_logger.all_count)
        self.logger.add_scalar("exact_match_score/active", metric.active_em_logger.avg['correct'], metric.active_em_logger.all_count)

    def iterate_one_task(self, task: "qmsum_interface", max_questions=200):
        """Answer, score and log every QA record in ``task.dataset['qa_list']``.

        :param task: benchmark task exposing ``dataset['qa_list']``.
        :param max_questions: when this many questions have been processed and
            another one remains, print the summary and terminate the process
            with exit code 0. ``None`` disables the cap.
        """
        qa_list = task.dataset['qa_list']
        for question_cnt, qa in enumerate(qa_list):
            if question_cnt == max_questions:
                print(f"\n\nTask: {self.task_name}, all scores summary")
                print(self.metric.get_all_summary())
                # sys.exit instead of the site-injected ``exit`` builtin, which
                # is not guaranteed to exist (e.g. under ``python -S``).
                sys.exit(0)
            res = self.ask(qa, task.dataset)
            # Score against the actual question text of this record.
            summary = self.metric.calculate_all_scores(res['answer'], orginal_ans=res['original_answer'], gt=res['ground_truth'], question=qa['query'])
            res.update({"summary": summary})
            self.__post_ask_call(qa, **res)

    def run(self):
        """Run every task of the benchmark selected by ``self.args.task``."""
        from utils import ES_full_task_list, IS_full_task_list, TS_full_task_list, wikimultihop_list, musique_list
        task_lists = {
            "ES": ES_full_task_list,
            "IS": IS_full_task_list,
            "TS": TS_full_task_list,
            "wikimultihopqa": wikimultihop_list,
            "musique": musique_list,
        }
        if self.args.task not in task_lists:
            sys.exit("Unknow task name")
        full_task_list = task_lists[self.args.task]

        for task_name in full_task_list:
            self.task_name = task_name
            print(task_name)
            task = qmsum_interface(self.args.task, task_name)
            self.iterate_one_task(task)

            print(f"\n\nTask: {self.task_name}, all scores summary")
            print(self.metric.get_all_summary())

    def extract_ans(self, ans):
        """Split a model reply into ``(explanation, answer)``.

        The answer is the text after the first ``the answer is:``; the
        explanation is the text before the first ``So the answer is:``. Each
        part independently falls back to the full reply when its marker is
        missing, so a found answer is never discarded.
        """
        answer_match = re.search(r"the answer is:(.*)", ans, re.DOTALL)
        if answer_match:
            answer_content = answer_match.group(1).strip()
            print("Answer content:", answer_content)
        else:
            answer_content = ans

        explanation_match = re.search(r"(.*?)So the answer is:", ans, re.DOTALL)
        if explanation_match:
            explanation_content = explanation_match.group(1).strip()
            print("Explanation content:", explanation_content)
        else:
            explanation_content = ans

        return explanation_content, answer_content
 




class AICL_solver(qmsum_solver):

    """
    Active In-context Learning solver.

    For each question the chatbot first answers on its own. If sampled answers
    disagree with that answer (``_need_active_learning``), the chatbot generates
    helper questions, the query strategy selects some, the oracle answers them
    from the record's relevant text span, and the chatbot re-answers with those
    QA pairs as hints.
    """
    
    _question_gen_prompt = "Here is the task {0}. List your questions about the task with numbers. Please ask task-related questions. Your questions are:"
    _query_oracle_prompt = "Please answer the question as short as possible based on the background information. The background information is: {0}. The question is: {1}"
    _task_prompt = """Here is the question: {0}. You need to answer the question in the following format. 
    For example: 
    Question: When did the director of film Hypocrite (Film) die? 
    Answer: The film Hypocrite was directed by Miguel Morayta. Miguel Morayta died on 19 June 2013. 
    So the answer is: 19 June 2013. 

    Question: Are both Kurram Garhi and Trojkrsti located in the same country?
    Answer: (Think step by step) Kurram Garhi is located in the country of Pakistan. Trojkrsti is located in the country of Republic of Macedonia. Thus, they are not in the same country. 
    So the answer is: no.

    First think step by step, then give your own answer like: 'So the answer is: [your answer]'  """

    _task_prompt_one_shot = """Here is the question: {0}. You need to answer the question in the following format. 
    For example: 
    Question: When did the director of film Hypocrite (Film) die? 
    Answer: The film Hypocrite was directed by Miguel Morayta. Miguel Morayta died on 19 June 2013. 
    So the answer is: 19 June 2013. 

    Give your own answer like: 'So the answer is: [your answer]'  """
    _active_trigger_prompt = "Do you have any "
    _format_prompt = wikimultihop_demo_zero_shot


    def __init__(self, args):
        """Select the query strategy from ``args.strategy`` and init the base solver.

        :raises NotImplementedError: for an unknown strategy name.
        """
        query_strategy = args.strategy
        self.args = args
        # NOTE(review): "random" maps to base_query_strategy although
        # random_query is imported by this file -- confirm this is intended.
        if query_strategy == "random":
            self.query = base_query_strategy(args.k)
        elif query_strategy == "similarity":
            self.query = similarity_based_query(args.k)
        elif query_strategy == "diversity":
            self.query = diversity_based_query(args.k)
        else:
            raise NotImplementedError(f"The query strategy {query_strategy} has not been implemented!")
        super(AICL_solver, self).__init__(args, f"AICLQmsum_solver({query_strategy})")

    def ask(self, qa, whole):
        """Answer one QA record, switching to active querying when uncertain.

        :param qa: QA record; reads ``query``, ``relevant_text_span`` and ``answer``.
        :param whole: the task dataset passed by the base loop (unused here).
        :return: dict with ``original_answer``, ``answer``, ``question_list``,
            ``query_list``, ``QA_pairs``, ``problem_with_hints``,
            ``ground_truth``, ``active_flag`` and ``active_explain``.
        """
        problem = qa['query']
        relevent_text = qa['relevant_text_span']
        # Question generation currently sees the relevant span too; an earlier
        # variant read qa['relevant_text_span_absent'] instead.
        absent_background = relevent_text
        gt = qa['answer']

        # step 1: answer without any hints
        raw_ans = self.chatbot.ask(self._task_prompt_one_shot.format(problem))
        explain, original_ans = self.extract_ans(raw_ans)

        # _task_prompt has a single {0} slot for the question. Passing the
        # background first (as before) put the background text in the question
        # slot and silently dropped the question itself.
        AL_flag = self._need_active_learning(self._task_prompt.format(problem), original_ans)

        if not AL_flag:
            # The model is consistent with itself: keep its original answer as
            # the final answer so scoring and logging have a real value.
            return dict(
                original_answer=original_ans,
                answer=original_ans,
                question_list=None,
                query_list=None,
                QA_pairs=None,
                problem_with_hints=None,
                ground_truth=gt,
                active_flag=AL_flag,
                active_explain=explain,
            )

        # step 2: generate helper questions based on the current question
        question_list = self._generate_question_list(problem, absent_background)

        # step 3: select questions from the generated set
        query_list = self._query_strategy(problem, question_list)

        # step 4: have the oracle answer them from the relevant text span
        QA_pairs = self._actively_asking(problem, relevent_text, query_list)

        # step 5: re-ask the question with the QA pairs as numbered hints
        problem_with_hints = "You should generate your answer based on the following hints (combined by helping questions and their answers). \n"
        problem_with_hints += "".join(
            f"\n Hint {idx}: '{q}? A: {a}\n'" for idx, (q, a) in enumerate(QA_pairs, start=1)
        )
        problem_with_hints += self._format_prompt + problem

        active_explain, active_ans = self.active_ask(problem_with_hints)

        return dict(
            original_answer=original_ans,
            answer=active_ans,
            question_list=question_list,
            query_list=query_list,
            QA_pairs=QA_pairs,
            problem_with_hints=problem_with_hints,
            ground_truth=gt,
            active_flag=AL_flag,
            active_explain=active_explain,
        )

    def active_ask(self, problem_with_hints):
        """Ask the hinted prompt and return ``(explanation, answer)``.

        Falls back to the raw reply for both parts if no answer was extracted.
        """
        ans = self.chatbot.ask(problem_with_hints)
        explain, answer = self.extract_ans(ans)
        if answer is None:
            return ans, ans
        return explain, answer

    def _extract_que(self, questions):
        """
            Extract questions from the original model output.

        Matches numbered items of the form ``<digits>. <text>?`` and returns the
        text parts (without the trailing ``?``).

        :param questions: raw model output.
        :return: list of extracted questions; empty if ``questions`` is not text.
        """
        pattern = r'\d+\.\s(.+?)\?'
        try:
            return re.findall(pattern, questions)
        except TypeError as e:
            # e.g. the chatbot returned None instead of a string. The previous
            # code built a SystemExit without raising it, hiding the failure.
            print("Error occurs when extracting questions:", e, "The questions is:", questions, file=sys.stderr)
            return []

    def _need_active_learning(self, question, original_ans, num_samples=5, threshold=0.015):
        """Decide whether the original answer is uncertain enough to trigger active learning.

        Samples ``num_samples`` answers to ``question``, embeds them and the
        original answer, and returns True when the standard deviation of the
        cosine similarities (sample vs. original) exceeds ``threshold``.

        NOTE(review): ``self._sentence_encoder`` and ``self._cosine_encoder``
        are not assigned anywhere in the visible source; unless they are
        provided elsewhere, this raises AttributeError -- confirm.
        """
        answer_list = [self.chatbot.ask(question) for _ in range(num_samples)]
        embedding_list = [self._sentence_encoder.encode(answer) for answer in answer_list]
        original_ans_embedding = self._sentence_encoder.encode(original_ans)

        cos_similarity_list = []
        for e in embedding_list:
            cos_similarity = self._cosine_encoder(original_ans_embedding, e)
            print(cos_similarity)
            cos_similarity_list.append(cos_similarity)
        standard_deviation = np.std(np.array(cos_similarity_list))
        print(standard_deviation)

        if standard_deviation > threshold:
            print('Activate!')
            return True
        print('Not Activate!')
        return False

    def _generate_question_list(self, problem, background, max_retries=None) -> list:
        """Ask the chatbot for numbered helper questions about ``problem``.

        Re-prompts until at least one question can be extracted. This uses a
        loop rather than recursion, so repeated failures no longer grow the
        stack toward RecursionError. ``max_retries`` bounds the extra attempts
        (``None`` = retry indefinitely); once it is exhausted, an empty list is
        returned. ``background`` is currently unused.
        """
        attempt = 0
        while True:
            res = self.chatbot.ask(self._question_gen_prompt.format(problem))
            print("original questions are:", problem)
            print("generated questions are:", res)
            self.logger.add_text(f"generate questions/{self.task_name}", f"Original questions are: {problem} \n\n Generated questions are: {res}", self.metric.active_f1_logger.all_count)
            question_list = self._extract_que(res)
            if question_list:
                return question_list
            if max_retries is not None and attempt >= max_retries:
                return []
            attempt += 1

    def _query_strategy(self, problem, question_list) -> list:
        """Select which generated questions to ask, via the configured strategy."""
        return self.query.query(problem, question_list)

    def _actively_asking(self, problem, background, query_list) -> list:
        """Have the oracle answer each selected question using ``background``.

        :return: list of ``[question, answer]`` pairs in ``query_list`` order.
        """
        answer_list = [self.oracle.ask(self._query_oracle_prompt.format(background, question)) for question in query_list]
        return [[Q, A] for Q, A in zip(query_list, answer_list)]
    

class RAG_solver(qmsum_solver):

    """
    RAG solver.

    Retrieves context for each question with the retriever chosen by
    ``args.retriever_name`` and asks the chatbot once with that context. No
    active-learning step runs: the ``answer`` and hint-related fields of the
    returned dict are fixed placeholders, and the RAG prediction is reported
    as ``original_answer``.
    """
    _task_prompt = """Here is the question: {0}. And here is the context information {1}.
    You need to answer the question based on the context information provided in the following format. 
    For example: 
    Question: When did the director of film Hypocrite (Film) die? 
    Answer: The film Hypocrite was directed by Miguel Morayta. Miguel Morayta died on 19 June 2013. 
    So the answer is: 19 June 2013. 

    Give your own answer like: 'So the answer is: [your answer]'  """

    def __init__(self, args):
        """Build the retriever named by ``args.retriever_name`` and init the base solver.

        :raises ValueError: for an unknown retriever name.
        """
        self.args = args
        embed_model = HuggingfaceEmbeddings()
        if args.retriever_name == "base":
            retriever = BaseRetriever(
                args.docs_path, embed_model=embed_model,
                construct_index=True, add_index=False, similarity_top_k=args.retrieve_top_k
            )
        elif args.retriever_name == "bm25":
            retriever = CustomBM25Retriever(
                args.docs_path, embed_model=embed_model,
                similarity_top_k=args.retrieve_top_k
            )
        elif args.retriever_name == "hybrid":
            retriever = EnsembleRetriever(
                args.docs_path, embed_model=embed_model, construct_index=True, add_index=False,
                similarity_top_k=args.retrieve_top_k
            )
        elif args.retriever_name == "hybrid-rerank":
            retriever = EnsembleRerankRetriever(
                args.docs_path, embed_model=embed_model,
                construct_index=True, similarity_top_k=args.retrieve_top_k
            )
        else:
            raise ValueError(f"Unknown retriever: {args.retriever_name}")

        print(args.retriever_name)
        self.retriever = retriever
        super(RAG_solver, self).__init__(args)

    def ask(self, qa, whole):
        """Answer one QA record using retrieved context.

        :param qa: QA record; reads ``query`` and ``answer``.
        :param whole: the task dataset passed by the base loop (unused here).
        :return: dict in the shape the base solver scores and logs; only
            ``original_answer`` and ``active_explain`` come from the model.
        """
        problem = qa['query']
        gt = qa['answer']

        context = self.retriever.search_docs(problem)
        print(context)
        raw_ans = self.chatbot.ask(self._task_prompt.format(problem, context))
        explain, original_ans = self.extract_ans(raw_ans)

        # Placeholders keep the result dict compatible with the active solver.
        return dict(
            original_answer=original_ans,
            answer="Dummy answer",
            question_list=['Dummy question'],
            query_list=['query_list'],
            QA_pairs=[['Dummy question', 'Dummpy answer']],
            problem_with_hints=['Dummy content'],
            ground_truth=gt,
            active_flag=True,
            active_explain=explain,
        )

   
   








  
