'''
Adapted from https://github.com/kojima-takeshi188/zero_shot_cot
'''

from statistics import mean
from torch.utils.data import Dataset
# import openai
import os
import logging
import multiprocessing
import json
import numpy as np
import torch
import re
import random
import time
import datetime
import pandas as pd

from transformers import AutoModelForCausalLM, AutoTokenizer

def shuffleDict(d):
    """Return a new dict holding the items of ``d`` in a random order.

    The key list is shuffled three times in a row with the global ``random``
    module, so the result depends on (and advances) that module's state.
    ``d`` itself is not modified.
    """
    order = list(d.keys())
    # Three consecutive in-place shuffles: the count is kept on purpose so the
    # amount of random state consumed per call stays the same.
    for _ in range(3):
        random.shuffle(order)
    return {name: d[name] for name in order}
  
def fix_seed(seed):
    """Apply ``seed`` to every random number generator used by this module.

    Seeds, in order, Python's ``random``, NumPy's global RNG,
    ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``, then turns on
    cuDNN's deterministic flag.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
    
def print_now(return_flag=0):
    """Report the current time in JST (UTC+9) as ``YYYY/MM/DD HH:MM:SS``.

    ``return_flag == 0`` prints the timestamp, ``return_flag == 1`` returns it
    as a string, and any other value does nothing. Except when the flag is 1
    the function returns ``None``.
    """
    jst = datetime.timezone(datetime.timedelta(hours=9), 'JST')
    stamp = datetime.datetime.now(jst).strftime('%Y/%m/%d %H:%M:%S')
    if return_flag == 1:
        return stamp
    if return_flag == 0:
        print(stamp)
    return None

# # Sentence Generator (Decoder) for GPT-3 ...
# def decoder_for_gpt3(args, input, max_length):
    
#     # GPT-3 API allows each users execute the API within 60 times in a minute ...
#     # time.sleep(1)
#     time.sleep(args.api_time_interval)
    
#     # https://beta.openai.com/account/api-keys
#     # openai.api_key = "[Your OpenAI API Key]"
    
#     # Specify engine ...
#     # Instruct GPT3
#     if args.model == "gpt3":
#         engine = "text-ada-001"
#     elif args.model == "gpt3-medium":
#         engine = "text-babbage-001"
#     elif args.model == "gpt3-large":
#         engine = "text-curie-001"
#     elif args.model == "gpt3-xl":
#         engine = "text-davinci-002"
#     elif args.model == "text-davinci-001":
#         engine = "text-davinci-001"
#     elif args.model == "code-davinci-002":
#         engine = "code-davinci-002"
#     else:
#         raise ValueError("model is not properly defined ...")
        
#     if ("few_shot" in args.method or "auto" in args.method)  and engine == "code-davinci-002":
#         response = openai.Completion.create(
#           engine=engine,
#           prompt=input,
#           max_tokens=max_length,
#           temperature=args.temperature,
#           top_p=1,
#           frequency_penalty=0,
#           presence_penalty=0,
#           stop=["\n"]
#         )
#     else:
#         response = openai.Completion.create(
#             engine=engine,
#             prompt=input,
#             max_tokens=max_length,
#             temperature=args.temperature,
#             top_p=1,
#             frequency_penalty=0,
#             presence_penalty=0,
#             stop=None
#         )

#     return response["choices"][0]["text"]

class Decoder():
    """Text-completion decoder that delegates to ``decoder_for_gpt3``.

    The helper is resolved when ``decode`` is called. Its definition is
    commented out in this module, so calling ``decode`` without restoring it
    fails with an explicit error instead of a bare ``NameError``.
    """

    def __init__(self):
        # print_now()
        pass

    def decode(self, args, input, max_length):
        """Return the completion produced by ``decoder_for_gpt3``.

        Args:
            args: Run configuration, forwarded unchanged to the helper.
            input: Prompt text, forwarded unchanged to the helper.
            max_length: Length limit, forwarded unchanged to the helper.

        Returns:
            Whatever ``decoder_for_gpt3(args, input, max_length)`` returns.

        Raises:
            NotImplementedError: If no callable named ``decoder_for_gpt3``
                exists at module level.
        """
        backend = globals().get("decoder_for_gpt3")
        if not callable(backend):
            raise NotImplementedError(
                "decoder_for_gpt3 is not defined in this module (the GPT-3 "
                "backend is commented out); restore it or use HF_Decoder."
            )
        return backend(args, input, max_length)

class HF_Decoder():
    """Chat-style text generator built on a Hugging Face causal language model.

    The tokenizer and the model are both loaded from
    ``args.model_path + args.model``; the model is loaded in float16 with
    ``device_map="auto"``.
    """

    # System prompts for the two generation modes of ``decode``.
    EXTRACT_SYSTEM_PROMPT = "You are a helpful chatbot"
    CREATIVE_SYSTEM_PROMPT = "You are a creative chatbot who always think outside the box and answer creatively!"

    def __init__(self, args):
        checkpoint = args.model_path + args.model
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float16)

    def decode(self, args, input, max_length, extract=False):
        """Generate a reply to ``input`` and return it with per-token log-likelihoods.

        Args:
            args: Run configuration; ``args.temperature`` is read when sampling.
            input: User message inserted into the chat template.
            max_length: Forwarded as ``max_length`` to ``generate`` (per the
                Hugging Face docs this bounds prompt plus generated tokens).
            extract: If true, decode greedily with the plain "helpful" system
                prompt; otherwise sample (``temperature=args.temperature``,
                ``top_p=0.9``) with the "creative" system prompt.

        Returns:
            ``(response, log_likelihoods)``: the decoded newly generated text
            of the first sequence, and a list with one log-probability per
            generated token.

        Raises:
            ValueError: If no token was generated.
        """
        # Stop on the regular EOS token or on the "<|eot_id|>" end-of-turn
        # token. Ids that resolve to None are dropped so they are never
        # handed to generate().
        terminators = [
            token_id
            for token_id in (
                self.tokenizer.eos_token_id,
                self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
            )
            if token_id is not None
        ]

        if extract:
            system_prompt = self.EXTRACT_SYSTEM_PROMPT
            sampling_kwargs = {"do_sample": False}
        else:
            system_prompt = self.CREATIVE_SYSTEM_PROMPT
            sampling_kwargs = {
                "do_sample": True,
                "temperature": args.temperature,
                "top_p": 0.9,
            }

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": input},
        ]
        # Send the prompt to wherever the model actually lives instead of a
        # hard-coded "cuda": with device_map="auto" the model may sit on the
        # CPU or on a specific GPU.
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(self.model.device)

        outputs = self.model.generate(
            input_ids,
            pad_token_id=self.tokenizer.eos_token_id,
            return_dict_in_generate=True,
            output_scores=True,
            max_length=max_length,
            eos_token_id=terminators,
            **sampling_kwargs,
        )

        # Get log_likelihoods.
        # outputs.scores are the logits for the generated tokens: a tuple of
        # len = n_generated_tokens, each entry of shape (bs, vocabulary size).
        # outputs.sequences is the sequence of all tokens: input and generated.
        transition_scores = self.model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True)

        # Only the first sequence of the batch is used (a single prompt is
        # submitted per call).
        log_likelihoods = transition_scores[0].tolist()

        if len(log_likelihoods) == 0:
            raise ValueError(
                "The model generated no tokens; check that max_length "
                "exceeds the prompt length."
            )

        # Strip the prompt tokens so that only the newly generated text is decoded.
        prompt_length = input_ids.shape[1]
        response = self.tokenizer.batch_decode(
            outputs.sequences[:, prompt_length:], skip_special_tokens=True)[0]

        return response, log_likelihoods
    

def _iter_json_lines(path):
    """Yield one decoded JSON value per non-blank line of the UTF-8 file at ``path``."""
    decoder = json.JSONDecoder()
    with open(path, encoding="utf-8") as f:
        for line in f:
            text = line.strip()
            if not text:
                # Blank lines (e.g. a trailing newline) carry no record.
                continue
            # raw_decode tolerates trailing characters after the JSON value.
            yield decoder.raw_decode(text)[0]


def _load_json(path):
    """Return the parsed content of the UTF-8 JSON file at ``path``."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _drop_trailing_point_zero(value):
    """Return ``str(value)`` without a trailing ``".0"`` (``3.0`` -> ``"3"``)."""
    text = str(value)
    return text[:-2] if text.endswith(".0") else text


def _labelled_choices(choices):
    """Format ``[{"label": ..., "text": ...}, ...]`` as ``Answer Choices: (A) x (B) y``."""
    return "Answer Choices:" + "".join(
        " (" + c["label"] + ") " + c["text"] for c in choices
    )


def _read_aqua(args):
    """Read AQuA: one JSON object per line with ``question``, ``options``, ``correct``."""
    questions, answers = [], []
    for json_res in _iter_json_lines(args.dataset_path):
        # Options look like "A)text"; rebuild them as " (A) text (B) text ...".
        choice = "(" + "(".join(json_res["options"])
        choice = choice.replace("(", " (").replace(")", ") ")
        choice = "Answer Choices:" + choice
        questions.append(json_res["question"].strip() + " " + choice)
        answers.append(json_res["correct"])
    return questions, answers


def _read_gsm8k(args):
    """Read GSM8K: the gold answer is the text after the last ``"#### "`` marker."""
    questions, answers = [], []
    for json_res in _iter_json_lines(args.dataset_path):
        questions.append(json_res["question"].strip())
        answers.append(json_res["answer"].split("#### ")[-1])
    return questions, answers


def _read_commonsensqa(args):
    """Read CommonsenseQA: question stem plus labelled answer choices."""
    questions, answers = [], []
    for json_res in _iter_json_lines(args.dataset_path):
        choice = _labelled_choices(json_res["question"]["choices"])
        questions.append(json_res["question"]["stem"].strip() + " " + choice)
        answers.append(json_res["answerKey"])
    return questions, answers


def _read_math_word_problems(args):
    """Read AddSub / MultiArith / SingleEq (``sQuestion`` / ``lSolutions``)."""
    questions, answers = [], []
    for line in _load_json(args.dataset_path):
        questions.append(line["sQuestion"].strip())
        answers.append(_drop_trailing_point_zero(line["lSolutions"][0]))
    return questions, answers


def _read_strategyqa(args):
    """Read StrategyQA as yes/no questions, keeping the first 500 examples."""
    questions, answers = [], []
    for line in _load_json(args.dataset_path)["examples"]:
        questions.append(line["input"].strip())
        answers.append("yes" if int(line["target_scores"]["Yes"]) == 1 else "no")
    return questions[:500], answers[:500]


def _read_sarcasm(args):
    """Read sarcasm headlines (JSON lines), keeping the first 1000 examples."""
    # The prompt is appended directly after the headline, without a separator.
    prompt = (
        'Task: Detect sarcasm, help me identify whether this sentence is sarcastic.\n'
        'First, we need to understand what sarcasm is. Sarcasm is a form of verbal irony, \n'
        'where the intended meaning of the words is the opposite of the literal meaning. \n'
        'In other words, the speaker is saying one thing but meaning the opposite. '
    )
    questions, answers = [], []
    for data in _iter_json_lines(args.dataset_path):
        questions.append(data["headline"] + prompt)
        answers.append("yes" if data["is_sarcastic"] == 1 else "no")
    return questions[:1000], answers[:1000]


def _read_svamp(args):
    """Read SVAMP: ``Body`` and ``Question`` joined by a space."""
    questions, answers = [], []
    for line in _load_json(args.dataset_path):
        questions.append(line["Body"].strip() + " " + line["Question"].strip())
        answers.append(_drop_trailing_point_zero(line["Answer"]))
    return questions, answers


def _read_bigbench_choices(args):
    """Read the ``bigbench_date`` / ``object_tracking`` multiple-choice tasks."""
    is_date = args.dataset == "bigbench_date"
    if is_date:
        choice_index = ['A','B','C','D','E','F']
    else:
        choice_index = ['A','B','C']

    questions, answers = [], []
    for line in _load_json(args.dataset_path)["examples"]:
        q = line["input"].strip()
        if is_date:
            choice = "Answer Choices:"
            # Randomly shuffle the answer choice dictionary because the original answer is always A ...
            choice_dic = shuffleDict(line["target_scores"])
        else:
            choice = "\nWhich choice is true ? Answer Choices:"
            choice_dic = line["target_scores"]

        # Reset per example so a missing gold label can never silently reuse
        # the label of the previous example.
        a = None
        for i, (key, value) in enumerate(choice_dic.items()):
            choice += " (" + choice_index[i] + ") " + key
            if value == 1:
                a = choice_index[i]
        if a is None:
            raise ValueError("no choice with target score 1 for question: " + q)

        questions.append(q + " " + choice)
        answers.append(a)
    return questions, answers


def _read_riddlesense(args):
    """Read RiddleSense (JSON lines): question stem plus labelled choices."""
    questions, answers = [], []
    for data in _iter_json_lines(args.dataset_path):
        choice = _labelled_choices(data["question"]["choices"])
        # Unlike commonsensqa, the stem is used as-is (not stripped).
        questions.append(data["question"]['stem'] + " " + choice)
        answers.append(data["answerKey"])
    return questions, answers


def _read_macgyver(args):
    """Read the MacGyver problems from sheet ``Sheet1`` of an Excel workbook."""
    instruction_vanilla = 'Give a valid (feasible and efficient) solution very concisely. Use step1, step2, etc, and mention the tools to achieve each step. Use as few steps as possible and the answer should ideally be less than 100 words. When there is not a feasible solution given the constraint and provided tools, just say that it is not possible and give a very short justification.'
    # The context manager closes the workbook once the sheet is parsed.
    with pd.ExcelFile(args.dataset_path) as xl:
        df_problems = xl.parse('Sheet1')

    questions, answers = [], []
    for _, row in df_problems.iterrows():
        problem = row['Problem']
        questions.append(f"{problem}\n\n{instruction_vanilla}")
        answers.append("yes" if row['Solvable?'] == 'Yes' else "no")
    return questions, answers


def _read_brainteaser(args):
    """Read BrainTeaser sentence and word puzzles from two ``.npy`` files.

    ``args.dataset_path`` is used as a string prefix, so a directory path
    must end with a path separator.
    """
    # Kept character-for-character identical to the original prompt layout.
    template = (
        "\n"
        "    \n"
        "            Question: {}\n"
        "            Choice:\n"
        "            (A) {}\n"
        "            (B) {}\n"
        "            (C) {}\n"
        "            (D) {}\n"
        "\n"
        "            Answer:"
    )
    choice_list = ['A', 'B', 'C', 'D']

    samples = []
    for file_name in ('sentence_puzzle.npy', 'word_puzzle.npy'):
        # SECURITY NOTE: allow_pickle=True unpickles arbitrary objects; only
        # load dataset files that come from a trusted source.
        samples.extend(list(np.load(args.dataset_path + file_name, allow_pickle=True)))

    questions, answers = [], []
    for sample in samples:
        options = sample['choice_list']
        questions.append(template.format(
            sample['question'], options[0], options[1], options[2], options[3]))
        answers.append(choice_list[sample['label']])
    return questions, answers


def _read_question_answer_examples(args):
    """Read ``coin_flip`` / ``last_letters`` (``examples`` with question/answer)."""
    questions, answers = [], []
    for line in _load_json(args.dataset_path)["examples"]:
        questions.append(line["question"])
        answers.append(line["answer"])
    return questions, answers


# Maps every supported ``args.dataset`` value to the helper that reads it.
_DATASET_READERS = {
    "aqua": _read_aqua,
    "gsm8k": _read_gsm8k,
    "commonsensqa": _read_commonsensqa,
    "addsub": _read_math_word_problems,
    "multiarith": _read_math_word_problems,
    "singleeq": _read_math_word_problems,
    "strategyqa": _read_strategyqa,
    "sarcasm": _read_sarcasm,
    "svamp": _read_svamp,
    "bigbench_date": _read_bigbench_choices,
    "object_tracking": _read_bigbench_choices,
    "riddlesense": _read_riddlesense,
    "macgyver": _read_macgyver,
    "brainteaser": _read_brainteaser,
    "coin_flip": _read_question_answer_examples,
    "last_letters": _read_question_answer_examples,
}


def data_reader(args):
    """Load the questions and gold answers of the dataset named by ``args.dataset``.

    Args:
        args: Configuration object. ``args.dataset`` selects the format and
            ``args.dataset_path`` is the file to read (for ``brainteaser`` it
            is a path prefix to which the ``.npy`` file names are appended).

    Returns:
        ``(questions, answers)``: two parallel lists.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name, or if
            a ``bigbench_date`` / ``object_tracking`` example has no choice
            with target score 1.

    Text files are read as UTF-8. A few summary statistics are printed before
    returning.
    """
    try:
        reader = _DATASET_READERS[args.dataset]
    except (KeyError, TypeError):
        raise ValueError("dataset is not properly defined ...") from None

    questions, answers = reader(args)

    q_len_list = [len(q.split(" ")) for q in questions]
    # statistics.mean raises on an empty sequence; report 0 for an empty dataset.
    q_len_mean = mean(q_len_list) if q_len_list else 0

    print("dataset : {}".format(args.dataset))
    print("data size : {}".format(len(answers)))
    print("average num of words for each sample : {}".format(q_len_mean))

    return questions, answers

# Create dataset object before dataloader ...
class MyDataset(Dataset):
    """Map-style dataset over the (question, answer) pairs from ``data_reader``."""

    def __init__(self, args):
        """Load every question/answer pair described by ``args`` up front."""
        super().__init__()
        loaded_questions, loaded_answers = data_reader(args)
        self.questions = loaded_questions
        self.answers = loaded_answers
        # Size is fixed at construction time from the number of questions.
        self.len = len(loaded_questions)

    def __len__(self):
        """Return the number of examples recorded at construction time."""
        return self.len

    def __getitem__(self, index):
        """Return the ``(question, answer)`` tuple stored at ``index``."""
        return self.questions[index], self.answers[index]

def setup_data_loader(args):
    """Build a shuffling, reproducibly seeded ``DataLoader`` over ``MyDataset``.

    Args:
        args: Configuration object. Reads ``random_seed``, ``max_num_worker``
            and ``minibatch_size``; ``args`` is also passed on to ``MyDataset``.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding shuffled minibatches.
    """
    # fix randomness of dataloader to ensure reproducibility
    # https://pytorch.org/docs/stable/notes/randomness.html
    fix_seed(args.random_seed)
    worker_seed = torch.initial_seed() % 2**32
    print("worker_seed : {}".format(worker_seed))

    def seed_worker(worker_id):
        # Derive the seed inside the worker process, as in the PyTorch
        # reproducibility recipe: there torch.initial_seed() differs per
        # worker, so workers no longer all share one identical random/NumPy
        # stream, while the result is still determined by the generator below.
        per_worker_seed = torch.initial_seed() % 2**32
        np.random.seed(per_worker_seed)
        random.seed(per_worker_seed)

    g = torch.Generator()
    g.manual_seed(worker_seed)

    # Use every available core, capped by the configured maximum.
    dataloader_num_workers = multiprocessing.cpu_count()
    dataloader_num_workers = min(dataloader_num_workers, args.max_num_worker)
    print("dataloader_num_workers: " + str(dataloader_num_workers))

    dataset = MyDataset(args)

    dataloader = torch.utils.data.DataLoader(dataset,
                  shuffle=True,
                  batch_size=args.minibatch_size,
                  drop_last=False,
                  num_workers=dataloader_num_workers,
                  worker_init_fn=seed_worker,
                  generator=g,
                  pin_memory=True)

    return dataloader

# ver 0.2
def answer_cleansing(args, pred, must_choice=False):
    """Extract the final answer token from a raw model prediction.

    Args:
        args: Configuration object. Reads ``args.method`` and ``args.dataset``;
            for the few-shot methods it also reads
            ``args.direct_answer_trigger_for_fewshot``.
        pred: Raw text produced by the model.
        must_choice: For the arithmetic/brainteaser datasets, extract a choice
            letter (A-D) instead of a number.

    Returns:
        The cleaned answer string, or ``""`` when no candidate was found.

    Raises:
        ValueError: If ``args.dataset`` is unknown, or if a candidate was found
            but ``args.method`` is unknown.

    The prediction is printed before and after cleansing.
    """
    print("pred_before : " + pred)

    is_few_shot = args.method in ("few_shot", "few_shot_cot", "auto_cot")
    answer_flag = False
    if is_few_shot:
        # Keep only the text after the last answer trigger; answer_flag
        # records whether the trigger occurred at all.
        preds = pred.split(args.direct_answer_trigger_for_fewshot)
        answer_flag = len(preds) > 1
        pred = preds[-1]

    dataset = args.dataset
    if dataset in ("aqua", "commonsensqa", "riddlesense"):
        candidates = re.findall(r'A|B|C|D|E', pred)
    elif dataset == "bigbench_date":
        candidates = re.findall(r'A|B|C|D|E|F', pred)
    elif dataset == "object_tracking":
        # Exact comparison: `in ("object_tracking")` is a substring test on a
        # plain string and would also accept e.g. "tracking" or "".
        candidates = re.findall(r'A|B|C', pred)
    elif dataset in ("gsm8k", "addsub", "multiarith", "svamp", "singleeq", "brainteaser"):
        if must_choice:
            candidates = re.findall(r'A|B|C|D', pred)
        else:
            # Drop thousands separators before looking for numbers.
            candidates = re.findall(r'-?\d+\.?\d*', pred.replace(",", ""))
    elif dataset in ("strategyqa", "coin_flip", "sarcasm", "macgyver"):
        # Turn quotes, periods, whitespace, colons and commas into spaces,
        # then keep the standalone yes/no words.
        words = re.sub(r"[\"'.\s:,]", " ", pred.lower()).split(" ")
        candidates = [w for w in words if w in ("yes", "no")]
    elif dataset == "last_letters":
        candidates = [re.sub(r"[\"'.\s]", "", pred)]
    else:
        raise ValueError("dataset is not properly defined ...")

    # If there is no candidate in list, null is set.
    if len(candidates) == 0:
        pred = ""
    elif is_few_shot:
        # With a trigger the answer follows it directly (first candidate);
        # without one, fall back to the last candidate in the text.
        pred = candidates[0] if answer_flag else candidates[-1]
    elif args.method in ("zero_shot", "zero_shot_cot"):
        pred = candidates[0]
    else:
        raise ValueError("method is not properly defined ...")

    # (For arithmetic tasks) if a word ends with period, it will be omitted ...
    if pred != "" and pred[-1] == ".":
        pred = pred[:-1]

    print("pred_after : " + pred)

    return pred

def create_demo_text(args, cot_flag):
    """Build the few-shot demonstration prompt from the JSON file at ``args.demo_path``.

    Each entry of the file's ``"demo"`` list contributes one paragraph:
    question, then (only when ``cot_flag`` is true) the rationale, then
    ``args.direct_answer_trigger_for_fewshot`` and the predicted answer,
    closed by ``".\\n\\n"``. Returns the concatenation of all paragraphs.
    """
    with open(args.demo_path, encoding="utf-8") as f:
        demos = json.load(f)["demo"]

    # Pull out all three fields of every demo first, whatever cot_flag says.
    triples = [
        (demo["question"], demo["rationale"], demo["pred_ans"])
        for demo in demos
    ]

    paragraphs = []
    for question, rationale, answer in triples:
        pieces = [question]
        if cot_flag:
            pieces.append(rationale)
        pieces.append(args.direct_answer_trigger_for_fewshot)
        pieces.append(answer + ".\n\n")
        paragraphs.append(" ".join(pieces))
    return "".join(paragraphs)

def answer_cleansing_zero_shot(args, pred, must_choice=False):
    """Extract the answer token from a zero-shot model prediction.

    Args:
        args: Configuration object; only ``args.dataset`` is read.
        pred: Raw text produced by the model.
        must_choice: For the arithmetic/brainteaser datasets, extract a choice
            letter (A-D) instead of a number.

    Returns:
        The first answer candidate found in ``pred`` (without a trailing
        period), or ``""`` when there is none.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    pred = pred.strip()
    dataset = args.dataset
    # "riddlesense" and "macgyver" are accepted here as well, matching the
    # dataset names handled by answer_cleansing.
    if dataset in ("aqua", "commonsensqa", "riddlesense"):
        candidates = re.findall(r'A|B|C|D|E', pred)
    elif dataset == "bigbench_date":
        candidates = re.findall(r'A|B|C|D|E|F', pred)
    elif dataset == "object_tracking":
        # Exact comparison: `in ("object_tracking")` is a substring test on a
        # plain string and would also accept e.g. "tracking" or "".
        candidates = re.findall(r'A|B|C', pred)
    elif dataset in ("gsm8k", "addsub", "multiarith", "svamp", "singleeq", "brainteaser"):
        if must_choice:
            candidates = re.findall(r'A|B|C|D', pred)
        else:
            # Drop thousands separators before looking for numbers.
            candidates = re.findall(r'-?\d+\.?\d*', pred.replace(",", ""))
    elif dataset in ("strategyqa", "coin_flip", "sarcasm", "macgyver"):
        # Turn quotes, periods, whitespace, colons and commas into spaces,
        # then keep the standalone yes/no words.
        words = re.sub(r"[\"'.\s:,]", " ", pred.lower()).split(" ")
        candidates = [w for w in words if w in ("yes", "no")]
    elif dataset == "last_letters":
        candidates = [re.sub(r"[\"'.\s]", "", pred)]
    else:
        raise ValueError("dataset is not properly defined ...")

    # If there is no candidate in list, null is set; otherwise the first
    # candidate is chosen.
    pred = candidates[0] if candidates else ""

    # (For arithmetic tasks) if a word ends with period, it will be omitted ...
    if pred != "" and pred[-1] == ".":
        pred = pred[:-1]

    return pred
