import openai 
import os
import json
import numpy as np
import random
import torch
import re
import time
from typing import Any
import asyncio
import math
import string
from transformers import (AutoModelForCausalLM, logging,
                          AutoModelForSequenceClassification, AutoTokenizer,
                          OPTForCausalLM)



async def dispatch_openai_requests(
    messages_list: list[list[dict[str,Any]]],
    model: str,
    temperature: float,
    max_tokens: int,
    frequency_penalty: float,
    presence_penalty: float
) -> list[Any]:
    """Send one ChatCompletion request per conversation, concurrently.

    Args:
        messages_list: One chat transcript per request; each transcript is a
            list of ``{"role": ..., "content": ...}`` message dicts.
        model: OpenAI chat model name (e.g. ``"gpt-4"``).
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate per request.
        frequency_penalty: Frequency penalty forwarded to the API.
        presence_penalty: Presence penalty forwarded to the API.

    Returns:
        The raw API response objects, in the same order as ``messages_list``.

    Raises:
        Any error raised by ``openai.ChatCompletion.acreate``;
        ``asyncio.gather`` propagates the first failure.
    """
    async_responses = [openai.ChatCompletion.acreate(
            model=model,
            messages=x,
            temperature=temperature,
            max_tokens=max_tokens,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty) for x in messages_list]

    return await asyncio.gather(*async_responses)




async def dispatch_openai_gpt_requests(
    messages_list: list[Any],
    engine: str,
    temperature: float,
    max_tokens: int,
    frequency_penalty: float,
    presence_penalty: float
) -> list[Any]:
    """Send one (legacy) Completion request per prompt, concurrently.

    The async ``acreate`` variant is required here: the synchronous
    ``create`` returns plain response objects, which ``asyncio.gather``
    rejects with ``TypeError``.

    Args:
        messages_list: One prompt per request, passed as the ``prompt`` argument.
        engine: Completion engine name (e.g. ``"text-davinci-003"``).
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate per request.
        frequency_penalty: Frequency penalty forwarded to the API.
        presence_penalty: Presence penalty forwarded to the API.

    Returns:
        The raw API response objects (requested with ``logprobs=2``), in the
        same order as ``messages_list``.

    Raises:
        Any error raised by ``openai.Completion.acreate``; ``asyncio.gather``
        propagates the first failure.
    """
    async_responses = [openai.Completion.acreate(
        engine=engine,
        prompt=x,
        temperature=temperature,
        max_tokens=max_tokens,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        logprobs=2) for x in messages_list]

    return await asyncio.gather(*async_responses)






async def dispatch_embedding_requests(
    messages_list: list[Any],
    model: str
) -> list[Any]:
    """Send one Embedding request per input, concurrently.

    Args:
        messages_list: One item per request, passed unchanged as the
            ``input`` argument of ``openai.Embedding.acreate``.
        model: OpenAI embedding model name (e.g. ``"text-embedding-ada-002"``).

    Returns:
        The raw API response objects, in the same order as ``messages_list``.

    Raises:
        Any error raised by ``openai.Embedding.acreate``; ``asyncio.gather``
        propagates the first failure.
    """
    async_responses = [openai.Embedding.acreate(
            model=model,
            input=x) for x in messages_list]

    return await asyncio.gather(*async_responses)



  
def fix_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also switches cuDNN to deterministic algorithms so repeated runs with the
    same seed reproduce the same results.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
    


def async_gpt4_generate(prompt, temp):
    """Generate GPT-4 chat replies for a batch of prompts, retrying until success.

    Each prompt string becomes a single user message. All requests of one
    attempt are dispatched concurrently via ``dispatch_openai_requests``
    (``max_tokens=512``, no penalties).

    Args:
        prompt: Iterable of prompt strings.
        temp: Sampling temperature.

    Returns:
        List of reply strings, one per prompt, in input order.

    Note:
        Retries indefinitely. A failed request is retried after 20 s; a batch
        whose replies cannot be read is re-requested immediately.
        NOTE(review): ``asyncio.run`` raises if an event loop is already
        running (e.g. in Jupyter), which would make this loop forever.
    """
    ml = [[{"role": "user", "content": p}] for p in prompt]

    while True:
        try:
            predictions = asyncio.run(dispatch_openai_requests(
                messages_list=ml,
                model='gpt-4',
                temperature=temp,
                max_tokens=512,
                frequency_penalty=0,
                presence_penalty=0,
            ))
        except Exception as e:
            print(e)
            print("Retrying....")
            time.sleep(20)
            # Don't fall through: there is no fresh response to parse.
            continue

        try:
            return [x['choices'][0]['message']['content'] for x in predictions]
        except Exception:
            print("Please Wait!")






def async_chatgpt_generate(prompt, temp):
    """Generate gpt-3.5-turbo chat replies for a batch of prompts, retrying until success.

    Each prompt string becomes a single user message. All requests of one
    attempt are dispatched concurrently via ``dispatch_openai_requests``
    (``max_tokens=512``, no penalties).

    Args:
        prompt: Iterable of prompt strings.
        temp: Sampling temperature.

    Returns:
        List of reply strings, one per prompt, in input order.

    Note:
        Retries indefinitely. A failed request is retried after 20 s; a batch
        whose replies cannot be read is re-requested immediately.
        NOTE(review): ``asyncio.run`` raises if an event loop is already
        running (e.g. in Jupyter), which would make this loop forever.
    """
    ml = [[{"role": "user", "content": p}] for p in prompt]

    while True:
        try:
            predictions = asyncio.run(dispatch_openai_requests(
                messages_list=ml,
                model='gpt-3.5-turbo',
                temperature=temp,
                max_tokens=512,
                frequency_penalty=0,
                presence_penalty=0,
            ))
        except Exception as e:
            print(e)
            print("Retrying....")
            time.sleep(20)
            # Don't fall through: there is no fresh response to parse.
            continue

        try:
            return [x['choices'][0]['message']['content'] for x in predictions]
        except Exception:
            print("Please Wait!")




def async_gpt3_generate(prompt, temp=0.0, max_token=256):
    """Complete each prompt with text-davinci-003, one request at a time.

    Despite the name, requests are synchronous and sequential.

    Args:
        prompt: Iterable of prompt strings.
        temp: Sampling temperature.
        max_token: Maximum number of tokens to generate per prompt.

    Returns:
        List of completion texts, aligned with ``prompt``.

    Note:
        Retries each prompt indefinitely. API errors back off for 20 s; a
        reply whose text cannot be read is re-requested immediately.
    """
    answer_list = []
    for x in prompt:
        while True:
            try:
                reply = openai.Completion.create(
                    engine="text-davinci-003",
                    prompt=x,
                    temperature=temp,
                    max_tokens=max_token,
                    frequency_penalty=0,
                    presence_penalty=0,
                    logprobs=5)
            except Exception as e:
                print(e)
                print("Retrying....")
                time.sleep(20)
                # Never parse a stale ``reply`` left over from a previous prompt.
                continue

            try:
                answer = reply['choices'][0]['text']
            except Exception:
                print("Please Wait!")
                continue

            answer_list.append(answer)
            break
    return answer_list




def async_gpt3_generate_prob(prompt, temp=0.0, max_token=256):
    """Complete each prompt with text-davinci-003 and score it by token probability.

    Despite the name, requests are synchronous and sequential. Each reply's
    score is ``exp(mean(token_logprobs))``, i.e. the geometric mean of the
    generated tokens' probabilities.

    Args:
        prompt: Iterable of prompt strings.
        temp: Sampling temperature.
        max_token: Maximum number of tokens to generate per prompt.

    Returns:
        ``(answer_list, answer_prob_list)``: completion texts and their
        scores, both aligned with ``prompt``. A reply with no tokens gets a
        score of ``nan`` (the mean is undefined).

    Note:
        Retries each prompt indefinitely. API errors back off for 20 s; a
        reply that cannot be parsed is re-requested immediately.
    """
    answer_list = []
    answer_prob_list = []
    for x in prompt:
        while True:
            try:
                reply = openai.Completion.create(
                    engine="text-davinci-003",
                    prompt=x,
                    temperature=temp,
                    max_tokens=max_token,
                    frequency_penalty=0,
                    presence_penalty=0,
                    logprobs=5)
            except Exception as e:
                print(e)
                print("Retrying....")
                time.sleep(20)
                # Never parse a stale ``reply`` left over from a previous prompt.
                continue

            # Parse everything before recording anything, so the two output
            # lists always stay the same length and aligned with ``prompt``.
            try:
                choice = reply['choices'][0]
                answer = choice['text']
                token_logprobs = choice['logprobs']['token_logprobs']
                if token_logprobs:
                    answer_prob = math.exp(sum(token_logprobs) / len(token_logprobs))
                else:
                    answer_prob = math.nan
            except Exception:
                print("Please Wait!")
                continue

            answer_list.append(answer)
            answer_prob_list.append(answer_prob)
            break
    return answer_list, answer_prob_list






def async_embedding_generate(prompt):
    """Embed each input with text-embedding-ada-002, retrying until success.

    All requests of one attempt are dispatched concurrently via
    ``dispatch_embedding_requests``.

    Args:
        prompt: Iterable of inputs; each is sent as one Embedding request.

    Returns:
        List of embedding vectors (``data[0].embedding`` of each response),
        in input order.

    Note:
        Retries indefinitely. A failed request is retried after 20 s; a batch
        whose embeddings cannot be read is re-requested immediately.
    """
    while True:
        try:
            predictions = asyncio.run(dispatch_embedding_requests(
                messages_list=prompt,
                model='text-embedding-ada-002',
            ))
        except Exception as e:
            print(e)
            print("Retrying....")
            time.sleep(20)
            # Don't fall through: there is no fresh response to parse.
            continue

        try:
            return [x['data'][0]['embedding'] for x in predictions]
        except Exception:
            print("Please Wait!")



def answer_cleansing(args, pred):
    """Extract the final answer from a model prediction.

    Behaviour depends on ``args.question_type``:

    - ``"open_form"``: ``pred`` is returned unchanged.
    - ``"multiple_choice"``: the last capital letter A-E in ``pred``.
    - ``"mathqa"``: the last number in ``pred`` after removing commas
      (thousand separators); a trailing period is dropped.

    The type is compared with ``==``. The previous ``x in ("open_form")``
    was a substring test on a plain string, not tuple membership.

    Args:
        args: Object with a ``question_type`` attribute.
        pred: Raw model output.

    Returns:
        The extracted answer, or ``""`` when no candidate is found.

    Raises:
        ValueError: If ``args.question_type`` is not one of the types above.
    """
    question_type = args.question_type
    if question_type == "open_form":
        return pred
    if question_type == "multiple_choice":
        candidates = re.findall(r'A|B|C|D|E', pred)
    elif question_type == "mathqa":
        candidates = re.findall(r'-?\d+\.?\d*', pred.replace(",", ""))
    else:
        raise ValueError("question type is not properly defined ...")

    if not candidates:
        return ""
    answer = candidates[-1]
    # The number regex can capture a sentence-ending period ("42.").
    if answer.endswith("."):
        answer = answer[:-1]
    return answer





def ans_uq_cleansing(args, completion):
    """Extract the answer part of a completion written to the answer template.

    The expected template is ``explanation: [step-by-step analysis]``
    followed by a newline and ``answer: [answer]``. Progressively looser
    patterns are tried in order (case-insensitive, ``.`` matching newlines).
    Because each pattern begins with a greedy ``(?:.*)``, the capture starts
    after the LAST occurrence of the answer marker. The captured text is
    returned verbatim, leading whitespace included.

    Args:
        args: Unused; kept for interface compatibility.
        completion: Raw model output.

    Returns:
        The captured answer text, or the whole ``completion`` (after printing
        a warning) when no pattern matches.
    """
    templates = (
        # ... explanation: ... \nanswer: ...
        r"(?:.*)explanation:(?:.*)(\nanswer:)(?P<answer>.*)",
        # ... explanation ... \nanswer ...
        r"(?:.*)explanation(?:.*)(\nanswer)(?P<answer>.*)",
        # ... \nanswer: ...
        r"(?:.*)(\nanswer:)(?P<answer>.*)",
        # ... \nanswer ...
        r"(?:.*)(\nanswer)(?P<answer>.*)",
        # ... answer ...
        r"(?:.*)answer(?P<answer>.*)",
    )
    for template in templates:
        hit = re.match(template, completion, flags=re.IGNORECASE | re.DOTALL)
        if hit:
            return hit.group("answer")

    print(f"Could not parse observed consistency answer from completion: {completion}")
    return completion



def option_uq_cleansing(completion: str) -> float:
    """Parse a self-reflection verdict of the form "answer: ... A | B ...".

    - If the chosen option is "A" (positive self reflection), 1.0 is returned.
    - If the chosen option is "B" (negative self reflection), 0.0 is returned.
    - Otherwise 0.5 is returned and a warning is printed.

    Markers are tried in order: a newline followed by "answer:", then
    "answer:", then "answer" (case-insensitive). Each pattern uses the last
    occurrence of its marker. After the marker, the FIRST standalone "a"/"b"
    (optionally in parentheses) is taken. Taking the first one (lazy ``.*?``)
    keeps later words such as the article "a" in the explanation from
    overriding the actual verdict.
    """
    patterns = (
        r"(?:.*)(\nanswer:).*?\b\(?(?P<answer>[ab])\)?\b.*",
        r"(?:.*)answer:.*?\b\(?(?P<answer>[ab])\)?\b.*",
        r"(?:.*)answer.*?\b\(?(?P<answer>[ab])\)?\b.*",
    )
    for pattern in patterns:
        answer_match = re.match(pattern, completion, flags=re.IGNORECASE | re.DOTALL)
        if answer_match:
            return float(answer_match.group("answer").lower() == "a")

    print(f"Could not parse self reflection answer from completion: {completion}")
    return 0.5


# def option_uq_cleansing(pred_before):

#     # print("pred_before : " + pred_before)
    
#     pred_before = pred_before.replace(".", "")
#     pred_before = pred_before.replace(",", "")
#     pred_before = pred_before.replace("%", "")
#     pred_before = pred_before.replace("explanation:", "")
#     pred_before = pred_before.replace(":", "")
#     pred_before = pred_before.replace("Confidence", "confidence")
#     pred_before = pred_before.replace("Score", "score")
#     pred_before = pred_before.replace("Answer", "answer")

#     pred = pred_before.split('\nanswer')[-1]
#     pred = re.findall(r'A|B|C', pred)
#     if len(pred) == 0:
#         pred = 'C'
#     else:
#         pred = pred[0].replace(" ", "") 


#     if pred == 'A':
#         confidence = 1.0
#     elif pred == 'B':
#         confidence = 0.0
#     elif pred == 'C':
#         confidence = 0.5

#     return confidence


def eval_cleansing(pred_before):
    """Map a free-text judgement to ``'Incorrect'``, ``'Correct'`` or ``None``.

    Only the spellings "Correct"/"correct" and "Incorrect"/"incorrect" are
    recognised, and only as whole words (other casings such as "CORRECT" are
    not). "incorrect" is checked first, so it wins whenever both appear.
    """
    normalized = pred_before.replace("Correct", "correct").replace("Incorrect", "incorrect")

    verdicts = (
        (r'\bincorrect\b', 'Incorrect'),
        (r'\bcorrect\b', 'Correct'),
    )
    for pattern, label in verdicts:
        if re.search(pattern, normalized):
            return label
    return None




def jaccard_similarity(doc1, doc2):
    """Return the Jaccard similarity of the word sets of two documents.

    Words are obtained by deleting '.' and ',', lowercasing, and splitting on
    whitespace. The score is |intersection| / |union|, in [0, 1].

    Two documents with no words at all (e.g. "" or "...") are treated as
    identical and score 1.0, instead of raising ZeroDivisionError.
    """
    def _words(doc):
        # Unique, lowercased words with periods and commas stripped.
        return set(doc.replace(".", "").replace(",", "").lower().split())

    words_doc1 = _words(doc1)
    words_doc2 = _words(doc2)

    union = words_doc1 | words_doc2
    if not union:
        return 1.0
    return len(words_doc1 & words_doc2) / len(union)





def data_reader(args):
    """Load ``[question, answer]`` pairs from the JSON file at ``args.question_json_dir``.

    The file must hold a JSON array of objects, each with a ``"question"``
    string. An optional ``"answer"`` value is passed through unchanged.

    Args:
        args: Object with a ``question_json_dir`` attribute (path to the file).

    Returns:
        List of ``[question.strip(), answer_or_None]`` lists, in file order.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
    with open(args.question_json_dir, encoding="utf-8") as f:
        records = json.load(f)
    return [[record["question"].strip(), record.get("answer")] for record in records]





class ClassifyWrapper():
    """Score agreement between sentences with a sequence-classification (NLI) model.

    Logit columns are read as [contradiction, neutral, entailment], which
    matches the default ``microsoft/deberta-large-mnli``.
    NOTE(review): confirm the label order before using another checkpoint.
    """

    def __init__(self, dir, model_name='microsoft/deberta-large-mnli', device='cuda:0') -> None:
        """Load the tokenizer and model.

        Args:
            dir: Hugging Face cache directory (passed as ``cache_dir``).
            model_name: Model identifier or local path.
            device: Torch device the model is moved to.
        """
        logging.set_verbosity_error()
        self.model_device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=dir)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, cache_dir=dir).to(self.model_device)

    def _non_contradiction(self, inputs: list, max_batch_size: int):
        """Return P(neutral) + P(entailment) for each input string, as a NumPy array."""
        encoded = self.tokenizer(inputs, padding=True, truncation=True)
        input_ids = torch.tensor(encoded['input_ids']).to(self.model.device)
        attention_mask = torch.tensor(encoded['attention_mask']).to(self.model.device)
        logits = []
        # Inference only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            for st in range(0, len(input_ids), max_batch_size):
                ed = st + max_batch_size
                logits.append(self.model(input_ids=input_ids[st:ed],
                                         attention_mask=attention_mask[st:ed])['logits'])
        # logits: [contradiction, neutral, entailment]
        prob = torch.softmax(torch.cat(logits, dim=0), 1).cpu().numpy()
        return prob[:, 1] + prob[:, 2]

    def batch_pred(self, prompt: str, sen_ref: str, sen: list, max_batch_size=128):
        """Score each sentence in ``sen`` for consistency with ``sen_ref``.

        Each pair is classified in both directions: ``prompt + sen_ref`` vs
        ``prompt + s``, and the reverse, joined with ``' [SEP] '``. The
        non-contradiction probabilities (neutral + entailment) of the two
        directions are averaged.

        Args:
            prompt: Text prefixed to both sides of every pair.
            sen_ref: Reference sentence.
            sen: Candidate sentences.
            max_batch_size: Maximum number of pairs per forward pass.

        Returns:
            NumPy array with one score in [0, 1] per element of ``sen``
            (empty when ``sen`` is empty).
        """
        if len(sen) == 0:
            return np.zeros(0, dtype=np.float32)
        forward = self._non_contradiction(
            [prompt + sen_ref + ' [SEP] ' + prompt + s for s in sen], max_batch_size)
        backward = self._non_contradiction(
            [prompt + s + ' [SEP] ' + prompt + sen_ref for s in sen], max_batch_size)
        return (forward + backward) / 2




def normalize_answer(s):
    """Normalize an answer string for comparison.

    Steps, in order:
    1. replace underscores with spaces and lowercase;
    2. turn ASCII punctuation and the quote-like characters ‘ ’ ´ ` into spaces;
    3. replace the whole words "a", "an", "the" with spaces;
    4. collapse runs of whitespace to single spaces and trim both ends.
    """
    text = s.replace('_', ' ').lower()

    punctuation = set(string.punctuation + "".join([u"‘", u"’", u"´", u"`"]))
    text = ''.join(' ' if ch in punctuation else ch for ch in text)

    text = re.sub(r'\b(a|an|the)\b', ' ', text)

    return ' '.join(text.split()).strip()


def _ordinal(n):
    """Return ``n`` with its English ordinal suffix (1st, 2nd, 3rd, 4th, 11th, 21st...)."""
    if 11 <= n % 100 <= 13:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return f"{n}{suffix}"


def print_result(args, save_dict, answer):
    """Print candidate answers with their confidence, highest first.

    Args:
        args: Object whose ``engine`` attribute is shown when ``answer`` is None.
        save_dict: Iterable (e.g. a dict, iterated over its keys) of sequences
            whose element 0 is the confidence and element 1 the answer text.
            Elements are sorted in descending order before printing.
        answer: The user-supplied answer, or None if answers were generated.
    """
    for rank, key in enumerate(sorted(save_dict, reverse=True), start=1):
        print(f"{_ordinal(rank)} answer")
        flat_text = key[1].replace("\n", " ")
        if answer is None:
            print(f"Since you do not provide the answer, the answer is generated by {args.engine}:", flat_text)
        else:
            print("Since you provide the answer, we will return the confidence for your provided answer:", flat_text)
        print("Confidence:", key[0])
        print('***************')



def save_resuls(args, save_list, key_list):
    """Write selected fields of each record to ``<args.save_log>/log.json``.

    The output is a JSON array with one object per record in ``save_list``.
    Each object holds the keys of ``key_list`` in order, with values
    converted to strings. ``json`` escapes the values, so quotes, backslashes
    and newlines still produce valid JSON, and an empty ``save_list`` writes
    ``[]``. ``key_list`` is not modified.

    Args:
        args: Object with a ``save_log`` attribute (output directory, created
            if missing).
        save_list: Records supporting ``record[key]`` for every key.
        key_list: Keys to export, in output order.

    Raises:
        KeyError: If a record lacks one of the keys. This is raised before the
            file is opened, so an existing log is not truncated.
    """
    records = [{f"{k}": f"{line[k]}" for k in key_list} for line in save_list]
    os.makedirs(f'{args.save_log}', exist_ok=True)
    with open(f'{args.save_log}/log.json', 'w', encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False, indent=0)


class llama_initial():
    """Wrapper around a local LLaMA checkpoint for plain prompt completion."""

    def __init__(self, dir, device='cuda:0') -> None:
        """Load tokenizer and model from the checkpoint directory ``dir``.

        Args:
            dir: Path to a Hugging Face-format LLaMA checkpoint.
            device: Stored as ``self.model_device``. Placement is actually
                decided by ``device_map="auto"``.
        """
        # These classes were used but never imported at module level. Importing
        # them here also keeps the module importable with transformers builds
        # that lack LLaMA support.
        from transformers import LlamaForCausalLM, LlamaTokenizer

        logging.set_verbosity_error()
        self.model_device = device
        self.tokenizer = LlamaTokenizer.from_pretrained(dir)
        self.model = LlamaForCausalLM.from_pretrained(dir, device_map="auto")

    def batch_pred(self, prompts: list, temp=0.0, max_new_tokens=64):
        """Generate a continuation for each prompt, one prompt at a time.

        Args:
            prompts: Prompt strings.
            temp: If > 0, sample with this temperature. Otherwise the model's
                default generation settings are used, as before ``temp`` was
                honoured.
            max_new_tokens: Maximum number of tokens generated per prompt.

        Returns:
            List of decoded outputs (prompt followed by its continuation), one
            per prompt.
        """
        gen_kwargs = {"max_new_tokens": max_new_tokens}
        if temp > 0:
            gen_kwargs.update(do_sample=True, temperature=temp)

        text = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.device)
            generate_ids = self.model.generate(input_ids, **gen_kwargs)
            text.append(self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])

        return text
    
    
    
