import datetime
import os.path
import random
import time
import re
import tqdm
import numpy as np
import sys
import torch

from test_chatgpt import ChatGPTAPI
from benchmark import qmsum_interface
# import tiktoken 
from tensorboardX import SummaryWriter
from lib.metrics import Metric

from serpapi import GoogleSearch
from IPython.utils import io
import requests
import string


class qmsum_solver():
    def __init__(self, args, solver_name="qmsum_solver"):
        """Create the chat clients, the metric tracker and the TensorBoard writer.

        Args:
            args: Run configuration object. Only ``args.suffix`` is read here;
                it is appended to the log directory name.
            solver_name: Name of the sub-directory under ``logs`` that holds
                the runs of this solver.
        """
        self.args = args

        # Two chat models: the one being evaluated and a stronger "oracle".
        self.chatbot = ChatGPTAPI(model_name="gpt-3.5-turbo")
        self.oracle = ChatGPTAPI(model_name='gpt-4')

        # The metric helper is handed the evaluated chatbot.
        self.metric = Metric(self.chatbot)

        # One log directory per run: logs/<solver_name>/<timestamp><suffix>.
        started_at = datetime.datetime.now().strftime("%y-%m-%d-%H:%M:%S")
        log_dir = os.path.join("logs", solver_name + "/" + started_at) + args.suffix
        self.logger = SummaryWriter(log_dir)
    
    def __post_ask_call(self, qa, **kwargs):
        self.logger.add_text(f"output_summary", f"Question: {kwargs['original_question']} \n\n Ground truth: {kwargs['ground_truth']} \n\n Total DG answer: {kwargs['total_dg_ans']} \n\n  Clean DG answer: {kwargs['original_answer']} \n\n Total Self-Ask answer: {kwargs['total_cur_ans']} \n\n  Clean Self-Ask answer: {kwargs['answer']} \n\n  QA List: {kwargs['qa_list']} \n\n Score Summary: {kwargs['summary']}", self.metric.active_entity_logger.all_count)

        self.logger.add_scalar("score_avg/original_f1", self.metric.original_f1_logger.avg['f1'], self.metric.original_f1_logger.all_count)
        self.logger.add_scalar("score_avg/active_f1", self.metric.active_f1_logger.avg['f1'], self.metric.active_f1_logger.all_count)

        self.logger.add_scalar("prompt_acc_score_avg/original_acc", self.metric.original_prompt_acc_logger.avg, self.metric.original_prompt_acc_logger.all_count)
        self.logger.add_scalar("prompt_acc_score_avg/active_acc", self.metric.active_prompt_acc_logger.avg, self.metric.active_prompt_acc_logger.all_count)

        self.logger.add_scalar("exact_match_score/original", self.metric.original_em_logger.avg['correct'], self.metric.original_em_logger.all_count)
        self.logger.add_scalar("exact_match_score/active", self.metric.active_em_logger.avg['correct'], self.metric.active_em_logger.all_count)

    def ask(self, problem, whole)->dict:
        pass

    def iterate_one_task(self, task:qmsum_interface):
        qa_list = task.dataset['qa_list']
        question_cnt = 0
        for qa in qa_list:
            if question_cnt == 400:
                print(f"\n\nTask: {self.task_name}, all scores summary")
                print(self.metric.get_all_summary())
                exit(0)
            res = self.ask(qa, task.dataset)
            summary = self.metric.calculate_all_scores(res['answer'], orginal_ans=res['original_answer'], gt=res['ground_truth'], question=['query'])
            res.update({"summary": summary})
            self.__post_ask_call(qa, **res)
            question_cnt = question_cnt + 1


    def run(self):
        from utils import ES_full_task_list, IS_full_task_list, TS_full_task_list, wikimultihop_list, musique_list
        if self.args.task == "ES":
            full_task_list = ES_full_task_list
        elif self.args.task == "IS":
            full_task_list = IS_full_task_list
        elif self.args.task == "TS":
            full_task_list = TS_full_task_list
        elif self.args.task == "wikimultihopqa":
            full_task_list = wikimultihop_list
        elif self.args.task == "musique":
            full_task_list = musique_list
        else:
            sys.exit("Unknow task name")
            Exception
        for task_name in full_task_list:
            self.task_name = task_name
            print(task_name)

            task = qmsum_interface(self.args.task, task_name)
            self.iterate_one_task(task)

            print(f"\n\nTask: {self.task_name}, all scores summary")
            print(self.metric.get_all_summary())


class SASE_solver(qmsum_solver):

    """
        Self-Ask with Search Engine solver.
    """
    def __init__(self, args):
        """Store the search credentials and the few-shot prompts, then run the base set-up.

        The API keys are read from the environment variables
        ``SERPAPI_API_KEY`` and ``BING_SEARCH_V7_SUBSCRIPTION_KEY`` so that
        secrets need not be written into the source; when a variable is unset
        the previous placeholder string is used.

        Each prompt attribute is a three-element list: the few-shot
        demonstrations, the text appended after the new question, and a label
        (the label is not read in the code of this class).

        Args:
            args: Run configuration; forwarded to the base class.
        """
        self.args = args
        self.serpapi_key = os.environ.get("SERPAPI_API_KEY", "Enter your key here")
        self.bing_key = os.environ.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", 'Enter your key here')
        
        #self-ask prompts
        self.intermediate = "\nIntermediate answer:"
        self.followup = "Follow up:"
        self.finalans= '\nSo the final answer is:'

        # NOTE: the leading spaces on the continuation lines of the prompt
        # literals below are part of the strings and are sent to the model
        # as-is; do not re-indent them.

        # Self-ask demonstrations for the "wikimultihopqa" task.
        self.bprompt_2wiki = ['''Question: Who lived longer, Theodor Haecker or Harry Vaughan Watkins?
        Are follow up questions needed here: Yes.
        Follow up: How old was Theodor Haecker when he died?
        Intermediate answer: Theodor Haecker was 65 years old when he died.
        Follow up: How old was Harry Vaughan Watkins when he died?
        Intermediate answer: Harry Vaughan Watkins was 69 years old when he died.
        So the final answer is: Harry Vaughan Watkins.

        Question: Are both director of film FAQ: Frequently Asked Questions and director of film The Big Money from the same country?
        Are follow up questions needed here: Yes.
        Follow up: Who directed the film FAQ: Frequently Asked Questions?
        Intermediate answer: Carlos Atanes.
        Follow up: Who directed the film The Big Money?
        Intermediate answer: John Paddy Carstairs.
        Follow up: What is the nationality of Carlos Atanes?
        Intermediate answer: Carlos Atanes is Spanish.
        Follow up: What is the nationality of John Paddy Carstairs?
        Intermediate answer: John Paddy Carstairs is British.
        So the final answer is: No.''',
        '''Are follow up questions needed here: (Here, you can say 'Yes.\nFollow up: ...', like demonstrations mentioned above, or 'No.\nSo the final answer is:...' to provide the answer directly)''', 'myprompt']

        # Direct question -> answer baseline for the "wikimultihopqa" task.
        self.q_a_2wiki = ['''Question: Who lived longer, Theodor Haecker or Harry Vaughan Watkins?
        So the final answer is: Harry Vaughan Watkins.

        Question: Are both director of film FAQ: Frequently Asked Questions and director of film The Big Money from the same country?
        So the final answer is: No.''',
        '''So the final answer is:''', 'baseline']

        # Self-ask demonstrations for the "musique" task.
        self.bprompt_mus = ['''Question: When does monsoon season end in the state the area code 575 is located?
        Are follow up questions needed here: Yes.
        Follow up: Which state is the area code 575 located in?
        Intermediate answer: The area code 575 is located in New Mexico.
        Follow up: When does monsoon season end in New Mexico?
        Intermediate answer: Monsoon season in New Mexico typically ends in mid-September.
        So the final answer is: mid-September.

        Question: The birth country of Jayantha Ketagoda left the British Empire when?
        Are follow up questions needed here: Yes.
        Follow up: What is the birth country of Jayantha Ketagoda?
        Intermediate answer: Sri Lanka.
        Follow up: When did Sri Lanka leave the British Empire?
        Intermediate answer: Sri Lanka left the British Empire on February 4, 1948.
        So the final answer is: February 4, 1948.''',
        '''Are follow up questions needed here: (Here, you can say 'Yes.\nFollow up: ...', like demonstrations mentioned above, or 'No.\nSo the final answer is:...' to provide the answer directly)''', 'cur_prompt_followup_no1shot_noanalysis_fixperiod']

        # Direct question -> answer baseline for the "musique" task.
        self.q_a_mus = ['''Question: When does monsoon season end in the state the area code 575 is located?
        So the final answer is: mid-September.

        Question: The birth country of Jayantha Ketagoda left the British Empire when?
        So the final answer is: February 4, 1948.''',
        '''So the final answer is:''', 'baseline_q_a_mus']

        # Base set-up (chat clients, metric, logger) under its own log folder.
        super().__init__(args, "self_ask_search_engine")
        
    def ask(self, qa, whole):
        if self.args.task == "wikimultihopqa":
            self_ask_prompt = self.bprompt_2wiki
            baseline_prompt = self.q_a_2wiki
        elif self.args.task == "musique":
            self_ask_prompt = self.bprompt_mus
            baseline_prompt = self.q_a_mus

        print('####################################')
        print(f"query: {qa['query']}")
        print(f"answer: {qa['answer']}")

        question = qa['query']

        dg_prompt = baseline_prompt[0] + '\n' + '\n' + 'Question: ' + question + '\n' + baseline_prompt[1]
        total_dg_ans = self.chatbot.ask(question=dg_prompt, temperature=0, stop=None)
        clean_dg_ans = self.normalize_answer(self.extract_answer(total_dg_ans))

        qa_list = []

        if 'no_search' in self.args.suffix:
            cur_prompt = self_ask_prompt[0] + '\n' + '\n' + 'Question: ' + question + '\n' + self_ask_prompt[1]
            total_cur_ans = self.chatbot.ask(question=cur_prompt, temperature=0, stop=None)
            clean_cur_ans = self.normalize_answer(self.extract_answer(total_cur_ans))
        elif 'bing_search' in self.args.suffix:
            cur_prompt = self_ask_prompt[0] + '\n' + '\n' + 'Question: ' + question + '\n' + self_ask_prompt[1]
            ret_text = self.chatbot.ask(question=cur_prompt, temperature=0, stop=self.intermediate)

            while self.followup in self.get_last_line(ret_text):

                cur_prompt += ret_text
                question = self.extract_question(ret_text)
                external_answer = self.bing(question)
                qa_list.append({"question":question, "answer":external_answer})
                print(question, end='')

                if external_answer is not None:
                    cur_prompt += self.intermediate + ' ' + external_answer + '.'
                    print(self.intermediate + ' ' + self.yellowfy(external_answer) + '.')
                    ret_text = self.chatbot.ask(question=cur_prompt, temperature=0, stop=self.intermediate)
                else:
                    cur_prompt += self.intermediate
                    print(self.intermediate + ' ')
                    gpt_answer = self.chatbot.ask(question=cur_prompt, temperature=0, stop=['\n'+self.followup, self.finalans])
                    cur_prompt += gpt_answer

            if self.finalans not in ret_text:
                cur_prompt += self.finalans
                ret_text = self.chatbot.ask(question=cur_prompt, temperature=0, stop='\n')

            total_cur_ans = ret_text
            clean_cur_ans = self.normalize_answer(self.extract_answer(total_cur_ans))

        answer = dict(
            original_question = qa['query'],
            ground_truth = qa['answer'],
            total_dg_ans=total_dg_ans,
            original_answer=clean_dg_ans,
            total_cur_ans=total_cur_ans,
            answer=clean_cur_ans,
            qa_list=qa_list
        )

        print(f'clean_dg_ans: {clean_dg_ans}')
        print(f'clean_cur_ans: {clean_cur_ans}')
        print('####################################')

        return answer
    
##################################################################################################################################
##################################################################################################################################
    def normalize_answer(self, s):
        """Lower text and remove punctuation, articles and extra whitespace."""
        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)

        def white_space_fix(text):
            return ' '.join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)

        def lower(text):
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(s))))

    def extract_answer(self, generated):
        if '\n' not in generated:
            last_line =  generated
        else:
            last_line = generated.split('\n')[-1]

        if ':' not in last_line:
            after_colon = last_line
        else:
            after_colon = generated.split(':')[-1]

        if ' ' == after_colon[0]:
            after_colon = after_colon[1:]
        if '.' == after_colon[-1]:
            after_colon = after_colon[:-1]

        return after_colon

    def extract_question(self, generated):
        if '\n' not in generated:
            last_line =  generated
        else: 
            last_line = generated.split('\n')[-1]

        if 'Follow up:' not in last_line:
            print('we probably should never get here...' + generated)

        if ':' not in last_line:
            after_colon = last_line
        else:
            after_colon = generated.split(':')[-1]
        
        if ' ' == after_colon[0]:
            after_colon = after_colon[1:]
        if '?' != after_colon[-1]:
            print('we probably should never get here...' + generated)

        return after_colon
    
    def get_last_line(self, generated):
        if '\n' not in generated:
            last_line =  generated
        else: 
            last_line = generated.split('\n')[-1]

        return last_line
    
    def greenify(self, input):
        return "\x1b[102m" + input + "\x1b[0m"

    def yellowfy(self, input):
        return "\x1b[106m" + input + "\x1b[0m"
    
    def google(self, question):
        """Search Google through SerpApi and return a short answer text.

        The first available of these fields of the result dict is returned:
        ``answer_box.answer``, ``answer_box.snippet``, the first entry of
        ``answer_box.snippet_highlighted_words``, the ``snippet`` of the
        first organic result.

        Args:
            question: The query string.

        Returns:
            The selected text, or ``None`` when none of the fields is present
            (including a response without organic results, which previously
            raised ``KeyError``/``IndexError``).
        """
        params = {
            "api_key": self.serpapi_key,
            "engine": "google",
            "q": question,
            "google_domain": "google.com",
            "gl": "us",
            "hl": "en"
        }

        with io.capture_output(): #disables prints from GoogleSearch
            res = GoogleSearch(params).get_dict()

        answer_box = res.get('answer_box') or {}
        if 'answer' in answer_box:
            return answer_box['answer']
        if 'snippet' in answer_box:
            return answer_box['snippet']
        highlighted = answer_box.get('snippet_highlighted_words')
        if highlighted:
            return highlighted[0]

        # Fall back to the top organic hit, if there is one.
        organic_results = res.get("organic_results") or []
        if organic_results and 'snippet' in organic_results[0]:
            return organic_results[0]['snippet']

        return None
    
    def bing(self, question, timeout=30):
        """Answer ``question`` from Bing web-search snippets, condensed by the chatbot.

        Up to five web results are requested from the Bing Web Search v7
        endpoint. Their snippets are handed to ``self.chatbot`` together with
        the question, and the chatbot's short answer is returned. When the
        snippets cannot be read from the response, or the condensing call
        fails, the chatbot is asked the bare question instead.

        Args:
            question: The query string.
            timeout: Seconds to wait for the HTTP response before giving up.

        Returns:
            The chatbot's answer.

        Raises:
            requests.RequestException: If the HTTP request fails, times out or
                returns an error status. This is not masked by the fallback.
        """
        endpoint = "https://api.bing.microsoft.com/v7.0/search"
        params = {
            'q': question,
            'mkt': 'en-US',
            'count':5,
            'responseFilter': 'Webpages',
        }
        headers = { 'Ocp-Apim-Subscription-Key': self.bing_key }

        # Without a timeout a stalled connection would block the run forever.
        response = requests.get(endpoint, headers=headers, params=params, timeout=timeout)
        response.raise_for_status()
        payload = response.json()

        try:
            web_list = payload['webPages']['value']
            ans = "".join(f"[Answer{idx}: {web['snippet']}]" for idx, web in enumerate(web_list))
            # simplify answers
            pmpt =  f'For the question: "{question}", there are some answers: "{ans}". Please answer the question in one sentence or one word based on the answers provided. (ps: Some answers may have errors, please identify and ignore them.)'
            return self.chatbot.ask(question=pmpt, temperature=0, stop=None)
        except Exception:
            # Deliberate best effort: e.g. the response has no 'webPages'
            # entry. ``Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt and SystemExit are no longer swallowed.
            return self.chatbot.ask(question=question, temperature=0, stop=None)
##################################################################################################################################
##################################################################################################################################
