import os
import json
import numpy as np
import random
import torch
import re
import random
import time
import datetime
from torch.utils.data import Dataset
import yaml
from configs import *
from nltk import sent_tokenize

def extract_words(sentence, word_list):
    """Return every whole-word occurrence of any entry of ``word_list`` in ``sentence``.

    Matching is case-insensitive and bounded by regex word boundaries. Every
    entry is escaped, so it is matched literally rather than as a regex.
    Matches come back in the order they occur in ``sentence``, with the casing
    they have in ``sentence``.

    Args:
        sentence: Text to search.
        word_list: Iterable of words / phrases to look for. Empty strings are
            ignored.

    Returns:
        List of matched substrings; ``[]`` if there is nothing to look for.
    """
    # Longest entries first: regex alternation takes the first alternative
    # that matches, so "new" listed before "new york" would otherwise shadow it.
    words = sorted((w for w in word_list if w), key=len, reverse=True)
    if not words:
        # An empty alternation would build r"\b()\b" and match empty strings.
        return []
    pattern = r'\b(' + '|'.join(map(re.escape, words)) + r')\b'
    return re.findall(pattern, sentence, re.IGNORECASE)

def fix_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    # NOTE: torch.backends.cudnn.deterministic is not set here, so cuDNN may
    # still choose nondeterministic kernels.

def get_now():
    """Return the current local time as a ``YYYY/MM/DD HH:MM:SS`` string."""
    return datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S')

def test_data_reader(args):
    """Load question/answer pairs from the JSON file at ``args.dataset_path``.

    The file must contain an object with an ``"examples"`` list whose items
    carry ``"input"`` (the question) and ``"target"`` (the answer) keys.

    Args:
        args: Object exposing a ``dataset_path`` attribute.

    Returns:
        ``{"questions": [...], "answers": [...]}``: parallel lists in file order.

    Raises:
        KeyError: if ``"examples"`` or an example's ``"input"``/``"target"`` is missing.
    """
    # JSON is UTF-8 by specification; without an explicit encoding open() uses
    # the locale default and can fail on non-ASCII datasets.
    with open(args.dataset_path, encoding="utf-8") as f:
        examples = json.load(f)["examples"]
    return {
        "questions": [e["input"] for e in examples],
        "answers": [e["target"] for e in examples],
    }


# The name starts with "test_": keep pytest from collecting this loader as a
# test (and erroring on the missing "args" fixture) when it is star-imported
# into a test module.
test_data_reader.__test__ = False

class AnswerCleaner:
    """Extract a single prediction from LM output and normalise ground truths.

    Behaviour is keyed on ``DATASET_TYPE[self.dataset]``. That mapping is not
    defined in this file (presumably it comes from ``configs``). The cleaners
    below handle the types ``"num"``, ``"yn"``, ``"word"`` and ``"mc"``.
    """

    def __init__(self, args):
        # args must expose ``dataset`` and a string ``framework``.
        self.dataset = args.dataset
        # When the framework name contains "few", the *last* sentence / match
        # is used instead of the first one.
        self.few = "few" in args.framework

    def get_pred(self, lm_output, to_detect=None):
        """Return the cleaned prediction for one LM output (``""`` if none found).

        Args:
            lm_output: Raw text generated by the model.
            to_detect: Target word; required for "word" datasets.
        """
        pred_sentence = self.pred_seperator(lm_output)
        preds = self.pred_cleaner(pred_sentence, to_detect=to_detect)
        pred = self.pred_selector(preds)
        return pred

    def pred_seperator(self, lm_output):
        """Return the sentence of ``lm_output`` to search for an answer.

        This is the last sentence when ``self.few`` is set, otherwise the first.
        Returns ``""`` when sentence tokenisation yields nothing.
        """
        sentences = sent_tokenize(lm_output)
        if not sentences:
            return ""
        return sentences[-1] if self.few else sentences[0]

    def pred_selector(self, preds):
        """Pick one candidate from ``preds``; ``""`` when the list is empty.

        The last candidate is used when ``self.few`` is set or the dataset type
        is "num"/"tf"; otherwise the first.
        """
        if not preds:
            return ""
        # NOTE(review): pred_cleaner has no "tf" branch; confirm whether "tf"
        # datasets can still reach this selector.
        if self.few or DATASET_TYPE[self.dataset] in ("num", "tf"):
            return preds[-1]
        return preds[0]

    def pred_cleaner(self, pred_sentence, to_detect=None):
        """Return every answer candidate found in ``pred_sentence``.

        Raises:
            AssertionError: for "word" datasets when ``to_detect`` is None.
            NotImplementedError: for dataset types other than num/yn/word/mc.
        """
        dtype = DATASET_TYPE[self.dataset]
        if dtype == "num":
            # Drop thousands separators, grab signed ints/decimals, strip a
            # sentence-final period ("42." -> "42"), then normalise the format.
            pred_sentence = pred_sentence.replace(",", "")
            preds = re.findall(r'-?\d+\.?\d*', pred_sentence)
            preds = [s[:-1] if s.endswith(".") else s for s in preds]
            preds = [anynum2str(p) for p in preds]
        elif dtype == "yn":
            # "true"/"false" are mapped onto "yes"/"no".
            pred_sentence = pred_sentence.lower()
            preds = re.findall(r'\b(yes|no|true|false)\b', pred_sentence, re.IGNORECASE)
            preds = ["yes" if p in ("yes", "true") else "no" for p in preds]
        elif dtype == "word":
            assert to_detect is not None, "to_detect should be given"
            # Keep only ASCII letters and spaces before matching the target word.
            pred_sentence = re.sub('[^a-zA-Z ]', '', pred_sentence)
            pred_sentence = pred_sentence.lower().strip()
            preds = extract_words(pred_sentence, [to_detect])
        elif dtype == "mc":
            # NOTE(review): any standalone capital letter matches, so words such
            # as "I" or "A" in the sentence also become candidates.
            preds = re.findall(r'\b[A-Z]\b', pred_sentence)
        else:
            raise NotImplementedError
        return preds

    def gt_cleaning(self, y):
        """Normalise a ground-truth string so it is comparable with predictions.

        Raises:
            NotImplementedError: for dataset types other than num/yn/word/mc.
        """
        dtype = DATASET_TYPE[self.dataset]
        if dtype == "num":
            y = self.arithmetic_gt_cleaner(y)
        elif dtype in ("yn", "word"):
            y = y.lower().strip()
        elif dtype == "mc":
            # Exact comparison: ``in ("mc")`` would be a substring test on a str.
            y = y.replace("(", "").replace(")", "").strip()
        else:
            raise NotImplementedError
        return y

    def arithmetic_gt_cleaner(self, y):
        """Strip thousands separators from ``y`` and format it with two decimals."""
        y = y.replace(",", "")
        y = anynum2str(y)
        return y

def anynum2str(num):
    """Normalise a number (or numeric string) to a string with two decimals.

    Args:
        num: Anything ``float()`` accepts, e.g. ``3``, ``2.5`` or ``"-1.25"``.

    Returns:
        The value formatted as ``"{:.2f}"``, e.g. ``"3.00"``. Negative zero is
        rendered as ``"0.00"`` so it compares equal to zero.

    Raises:
        ValueError / TypeError: propagated from ``float()`` for non-numeric input.
    """
    # Format directly. An intermediate round(x, 3) double-rounds:
    # 1.2449 -> 1.245 -> "1.25" instead of "1.24".
    text = "{:.2f}".format(float(num))
    if text == "-0.00":
        text = "0.00"
    return text

class ReportManager:
    """Accumulate predictions for one experiment run and write a text report.

    ``update_log`` appends batches of results. ``save_report`` then:

    * opens the log file;
    * dumps ``config`` to a sibling YAML file;
    * evaluates the accumulated predictions;
    * writes the header, the per-sample I/O and the evaluation summary.
    """

    def __init__(self, args):
        # args must expose framework, model, dataset and log_dir.
        self.args = args
        self.framework = args.framework
        self.model = args.model
        self.dataset = args.dataset

        # for logging
        self.start_time = time.time()
        self.start_datetime = get_now()
        self.whole_process = []
        self.preds = []
        self.gts = []
        self.prompt_examples = []
        self.config = {}
        self.api_usage = 0
        # File handle opened by connect_log_file(); None until then.
        self.log = None

        # for evaluation
        self.total = 0
        self.correct = 0
        self.correct_idx = []
        self.incorrect_idx = []
        self.unknown_idx = []

        # for saving
        self.line_header = []
        self.line_ios = []
        self.line_eval = []

    def update_log(self, whole_process, cleaned_preds, labels, api_usage, prompt_examples, config):
        """Append one batch of results.

        ``whole_process``, ``cleaned_preds`` and ``labels`` are extended.
        ``api_usage`` is added to the running total. ``prompt_examples`` and
        ``config`` replace any previously stored values.
        """
        self.whole_process.extend(whole_process)
        self.preds.extend(cleaned_preds)
        self.gts.extend(labels)
        self.api_usage += api_usage
        self.prompt_examples = prompt_examples
        self.config = config
        self.total += len(labels)

    def connect_log_file(self):
        """Open ``<log_dir>/<model>/<dataset>/<framework>/<MMDD>/<framework>-<model>-<HHMM>.log``.

        Missing directories are created and an existing file with the same
        name is overwritten. Sets ``self.log_to`` and ``self.log``.
        """
        date = time.strftime("%m%d")
        logpath = os.path.join(self.args.log_dir, self.model, self.args.dataset, self.framework, date)
        os.makedirs(logpath, exist_ok=True)

        clock = time.strftime("%H%M")
        logfile = "-".join([self.framework, self.model])
        logfile += "-" + clock

        self.log_to = os.path.join(logpath, logfile + ".log")
        # Explicit encoding: LM outputs may contain non-ASCII text.
        self.log = open(self.log_to, "w", encoding="utf-8")

        print(f"> EXPERIMENT DATE: {get_now()}")
        print(f"> MODEL: {self.model}, DATASET: {self.args.dataset}")
        print(f"> FRAMEWORK: {self.framework}, LOG_FILE: {self.log_to}")

    def evaluate(self):
        """Classify every (pred, gt) pair as correct, unknown or incorrect.

        A pair is correct when ``str(pred)`` and ``str(gt)`` match
        case-insensitively. A non-matching empty-string prediction is unknown.
        Everything else is incorrect. Counters are recomputed from scratch, so
        repeated calls do not double-count.
        """
        print(f"> EVALUATION TIME: {get_now()}")
        print(f"> LENGTH OF PREDICTIONS: {len(self.preds)}")
        print(f"> LENGTH OF GROUND TRUTHS: {len(self.gts)}")
        assert len(self.preds) == len(self.gts), "preds and gts should have the same length"
        self.correct = 0
        self.correct_idx = []
        self.incorrect_idx = []
        self.unknown_idx = []
        for idx, (pred, gt) in enumerate(zip(self.preds, self.gts)):
            if str(pred).lower() == str(gt).lower():
                self.correct += 1
                self.correct_idx.append(idx)
            elif pred == "":
                self.unknown_idx.append(idx)
            else:
                self.incorrect_idx.append(idx)

    def _accuracy(self):
        """Return accuracy in percent, rounded to 2 decimals; 0.0 when nothing was logged."""
        if self.total == 0:
            return 0.0
        return round(self.correct / self.total * 100, 2)

    def write_header(self):
        """Append the run arguments, experiment metadata and prompt examples to ``line_header``."""
        self.line_header.append("=" * 50)
        self.line_header.append(str(self.args))
        self.line_header.append("=" * 50)
        self.line_header.append(f"> EXPERIMENT DATE: {get_now()}")
        self.line_header.append(f"> MODEL: {self.model}, DATASET: {self.args.dataset}")
        self.line_header.append(f"> FRAMEWORK: {self.framework}, LOG_FILE: {self.log_to}")
        self.line_header.append("=" * 50)
        for si in self.prompt_examples:
            self.line_header.append(si)

    def write_sample_io(self):
        """Append one section per sample (full process text, prediction, ground truth) to ``line_ios``."""
        for idx, (x, pred, y) in enumerate(zip(self.whole_process, self.preds, self.gts)):
            self.line_ios.append("=" * 50)
            # NOTE(review): the suffix is always "st" ("0st", "2st"); left as-is
            # in case downstream tooling parses these lines.
            self.line_ios.append(f"{idx}st data")
            self.line_ios.append(x)
            self.line_ios.append(f"\npred : {pred}")
            self.line_ios.append(f"GT : {y}")
            self.line_ios.append("=" * 50)

    def write_eval_result(self):
        """Run ``evaluate`` and append the accuracy summary and index lists to ``line_eval``."""
        self.evaluate()
        acc = self._accuracy()
        self.line_eval.append("=" * 50)
        self.line_eval.append(f"accuracy : {acc}")
        self.line_eval.append(f"num_correct : {self.correct}")
        self.line_eval.append(f"num_incorrect : {len(self.incorrect_idx)}")
        self.line_eval.append(f"num_unknown : {len(self.unknown_idx)}")
        self.line_eval.append("=" * 50)
        self.line_eval.append(f"correct_idx : {self.correct_idx}")
        self.line_eval.append(f"incorrect_idx : {self.incorrect_idx}")
        self.line_eval.append(f"unknown_idx : {self.unknown_idx}")
        self.line_eval.append("=" * 50)
        self.line_eval.append(f"total_time : {time.time() - self.start_time}")
        self.line_eval.append(f"total_api_usage : {self.api_usage}")
        self.line_eval.append("=" * 50)

    def save_report(self):
        """Open the log file, dump ``config`` as YAML, evaluate, and write the full report."""
        self.connect_log_file()
        # Replace only the ".log" extension; str.replace would also rewrite
        # ".log" occurring in directory names.
        config_path = os.path.splitext(self.log_to)[0] + "-config.yaml"
        with open(config_path, "w", encoding="utf-8") as f:
            yaml.dump(self.config, f, default_flow_style=False)
        self.write_header()
        self.write_sample_io()
        self.write_eval_result()

        to_write = self.line_header + self.line_ios + self.line_eval
        self.log.write("\n".join(to_write))
        # The handle stays open until __del__; flush so the report is on disk now.
        self.log.flush()

        print("Evaluation Completed.")
        print(f"> Model: {self.model}, Framework: {self.framework}")
        print(f"> Total time: {time.time() - self.start_time}")
        print(f"> Total API usage: {self.api_usage}")
        print(f"> Accuracy: {self._accuracy()}")

    def __del__(self):
        # connect_log_file() may never have run (or __init__ may have failed
        # part-way), in which case there is no handle to close.
        log = getattr(self, "log", None)
        if log is not None:
            log.close()