import string
import random
import sys, os


sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from downstream_datasets.mmlu_pro_loader import load_mmlu_pro_data, mmlu_pro_prompt_transform
from downstream_datasets.mmlu_loader import load_mmlu_data, mmlu_prompt_transform
from downstream_datasets.human_eval_loader import load_humaneval_data, humaneval_prompt_transform
from downstream_datasets.lambada_loader import load_lambada_data, lambada_prompt_transform
from downstream_datasets.arc_loader import load_arc_data, arc_prompt_transform
from downstream_datasets.glue_loader import load_glue_data, glue_prompt_transform
from downstream_datasets.nq_loader import load_nq_data, nq_prompt_transform, get_all_short_answer_texts
from downstream_datasets.drop_loader import load_drop_data, drop_prompt_transform
from downstream_datasets.gsm8k_loader import load_gsm8k_data, gsm8k_prompt_transform
from downstream_datasets.triviaqa_loader import load_triviaqa_data, triviaqa_prompt_transform
from downstream_datasets.boolq_loader import load_boolq_data, boolq_prompt_transform
from downstream_datasets.mbpp_loader import load_mbpp_data, mbpp_prompt_transform
from downstream_datasets.truthfulqa_loader import load_truthfulqa_data, truthfulqa_prompt_transform

import re

import json
import string

import json
import string
import sys
import io

def eval_func_mbpp(answer, pred, prompt, model=None, tokenizer=None, inputs=None, trainer=None, **kwargs):
    """Execute a generated MBPP solution against the tests of its prompt.

    The row is located by searching ``inputs['prompt']`` for ``prompt``
    (row 0 is used when it is not found), and the test snippets are read
    from ``inputs['test_list'][row]``.

    Args:
        answer: Unused; kept for the common eval-function signature.
        pred: Generated source code to execute.
        prompt: Prompt text used to find the matching row in ``inputs``.
        model: Unused.
        tokenizer: Unused.
        inputs: Optional mapping with ``"prompt"`` and ``"test_list"`` entries.
        trainer: Unused.
        **kwargs: Ignored.

    Returns:
        bool: True when the code and every test snippet execute without
        raising; False otherwise, including when no tests are available.
    """
    row = 0
    if inputs is not None and "prompt" in inputs:
        for idx, candidate in enumerate(inputs['prompt']):
            if candidate == prompt:
                row = idx
                break

    test_list = None
    if inputs is not None and "test_list" in inputs:
        all_tests = inputs['test_list']
        if isinstance(all_tests, list) and len(all_tests) > row:
            test_list = all_tests[row]

    # Without tests the verdict is always False, so do not run the code at all.
    if test_list is None:
        return False

    # SECURITY: exec() runs model-generated code inside this process with no
    # sandbox. Only use this on trusted or isolated machines.
    saved_stdout = sys.stdout
    sys.stdout = io.StringIO()  # swallow anything the generated code prints
    try:
        namespace = {}
        exec(pred.strip(), namespace)
        for case in test_list:
            exec(case, namespace)
        return True
    except (Exception, SystemExit):
        # SystemExit is not an Exception subclass; generated code calling
        # exit() must count as a failure instead of killing the evaluation.
        return False
    finally:
        # Always restore stdout, whatever happened above.
        sys.stdout = saved_stdout

def normalize_text(s):
    """Return ``s`` stripped, lower-cased and with ASCII punctuation removed."""
    punctuation_remover = str.maketrans('', '', string.punctuation)
    lowered = s.strip().lower()
    return lowered.translate(punctuation_remover)

def extract_triviaqa_candidates(answer):
    """Collect the normalised answer strings contained in ``answer``.

    A string is first tried as JSON; if it does not parse, the string itself
    (normalised) is the only candidate. Dicts contribute their ``value``,
    ``normalized_value``, ``aliases`` and ``normalized_aliases`` entries,
    lists are flattened recursively, and anything else is stringified.

    Returns:
        list: De-duplicated, non-empty normalised candidates in arbitrary
        order (except for the unparsable-string case, which returns a
        single-element list without filtering).
    """
    if isinstance(answer, str):
        try:
            answer = json.loads(answer)
        except Exception:
            # Not JSON: treat the raw text as the answer.
            return [normalize_text(answer)]

    if isinstance(answer, dict):
        found = [
            normalize_text(answer[field])
            for field in ('value', 'normalized_value')
            if field in answer and answer[field]
        ]
        for field in ('aliases', 'normalized_aliases'):
            aliases = answer.get(field, [])
            if isinstance(aliases, str):
                aliases = [aliases]
            found.extend(normalize_text(alias) for alias in aliases)
    elif isinstance(answer, list):
        found = []
        for element in answer:
            found.extend(extract_triviaqa_candidates(element))
    else:
        found = [normalize_text(str(answer))]

    return list(set([text for text in found if text]))

def extract_answer_string_from_generation(gen_text):
    """Turn generated text into a list of normalised answer candidates.

    When the lower-cased text contains ``"| answer:"``, only the part after
    the first occurrence is considered, and single quotes are swapped for
    double quotes before attempting to parse it as JSON. Otherwise the whole
    text is tried as JSON. If parsing or candidate extraction fails, the
    considered text itself is normalised and returned as the only candidate.
    """
    marker = "| answer:"
    lowered = gen_text.lower()
    if marker in lowered:
        fallback = lowered.split(marker, 1)[1].strip()
        payload = fallback.replace("'", '"')
    else:
        fallback = gen_text
        payload = gen_text

    try:
        return extract_triviaqa_candidates(json.loads(payload))
    except Exception:
        return [normalize_text(fallback)]
    
def eval_func_triviaqa(answer, pred, prompt, model=None, tokenizer=None, inputs=None, trainer=None):
    """Score a TriviaQA generation by bidirectional substring matching.

    Args:
        answer: Reference answer (string, JSON string, dict or list), passed
            to ``extract_triviaqa_candidates``.
        pred: Generated text, passed to
            ``extract_answer_string_from_generation``.
        prompt: Unused.
        model: Unused.
        tokenizer: Unused.
        inputs: Unused.
        trainer: Unused.

    Returns:
        int: 1 if any non-empty predicted candidate contains, or is contained
        in, any non-empty reference candidate; otherwise 0.
    """
    # Drop empty strings first: "" is a substring of every string, so an
    # empty prediction (or reference) would otherwise always count as a hit.
    gt_candidates = [c for c in extract_triviaqa_candidates(answer) if c]
    pred_candidates = [c for c in extract_answer_string_from_generation(pred) if c]
    matched = any(
        ac in pc or pc in ac
        for pc in pred_candidates
        for ac in gt_candidates
    )
    return 1 if matched else 0

def gsm8k_postprocess(ans: str):
    """Extract the final number from a GSM8K-style answer string.

    Prefers the number following ``####``; otherwise returns the last number
    found in the text, or the stripped text when it holds no number.

    NOTE: a later definition of ``gsm8k_postprocess`` in this module rebinds
    the name, so this version is not the one in effect after import.
    """
    marked = re.search(r"####\s*([-+]?\d*\.?\d+)", ans)
    if marked is not None:
        return marked.group(1).strip()
    numbers = re.findall(r"[-+]?\d*\.?\d+", ans)
    return numbers[-1] if numbers else ans.strip()


import re

import re

def gsm8k_postprocess(ans):
    """Extract the final numeric answer from GSM8K text.

    Lookup order:
      1. the number following the first ``####`` marker;
      2. the last number that directly follows ``>>``;
      3. the last number anywhere in the text;
      4. the stripped text itself when it contains no number.

    Numbers written with thousands separators (``1,080``) are recognised as
    one number and returned without the commas, so the result can be passed
    to ``float``. This definition replaces the earlier function of the same
    name defined above it in this module.

    Args:
        ans: Answer or generation; non-strings are converted with ``str``.

    Returns:
        str: The extracted number as text, or the stripped input.
    """
    if not isinstance(ans, str):
        ans = str(ans)

    # Either a properly grouped number (1,234,567) or a plain digit run,
    # followed by an optional decimal part. The lookahead rejects groupings
    # such as "1,0805" so they fall back to the plain-digit alternative.
    number = r'-?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)\.?\d*'

    def _without_separators(token):
        return token.replace(',', '')

    marked = re.search(r'####\s*(' + number + r')', ans)
    if marked:
        return _without_separators(marked.group(1))
    after_arrows = re.findall(r'>>\s*(' + number + r')', ans)
    if after_arrows:
        return _without_separators(after_arrows[-1])
    nums = re.findall(number, ans)
    if nums:
        return _without_separators(nums[-1])
    return ans.strip()

def gsm8k_eval_func(answer, pred, prompt, **kwargs):
    """Return True when reference and prediction hold the same number.

    Both texts are reduced with ``gsm8k_postprocess`` and compared as floats;
    if either result cannot be converted to ``float`` the answer is False.
    ``prompt`` and ``**kwargs`` are ignored.
    """
    expected_text = gsm8k_postprocess(answer)
    predicted_text = gsm8k_postprocess(pred)
    try:
        expected = float(expected_text)
        predicted = float(predicted_text)
    except Exception:
        return False
    return expected == predicted

import sys
import io

def eval_func_humaneval(answer, pred, prompt, test_list=None, test_code=None, **kwargs):
    """Execute generated code, then its tests, and report whether all passed.

    Args:
        answer: Unused; kept for the common eval-function signature.
        pred: Generated source code to execute.
        prompt: Unused.
        test_list: Optional iterable of test snippets, each executed in the
            namespace created by ``pred``. Takes precedence over ``test_code``.
        test_code: Optional single test script, used when ``test_list`` is None.
        **kwargs: Ignored.

    Returns:
        bool: True when the code and the tests run without raising; False
        otherwise, including when neither ``test_list`` nor ``test_code``
        is given.
    """
    # Without tests the verdict is always False, so do not run the code at all.
    if test_list is None and test_code is None:
        return False

    # NOTE(review): if test_code only defines a checker function without
    # calling it, executing it asserts nothing. Confirm the caller supplies
    # a script that actually invokes the checks.
    # SECURITY: exec() runs model-generated code inside this process with no
    # sandbox. Only use this on trusted or isolated machines.
    saved_stdout = sys.stdout
    sys.stdout = io.StringIO()  # swallow anything the generated code prints
    try:
        namespace = {}
        exec(pred.strip(), namespace)
        if test_list is not None:
            for case in test_list:
                exec(case, namespace)
        else:
            exec(test_code, namespace)
        return True
    except (Exception, SystemExit):
        # SystemExit is not an Exception subclass; generated code calling
        # exit() must count as a failure instead of killing the evaluation.
        return False
    finally:
        # Always restore stdout, whatever happened above.
        sys.stdout = saved_stdout


 

def substring_until(s, split_strs):
    """Return the prefix of ``s`` that precedes the earliest delimiter.

    Every entry of ``split_strs`` is looked up with ``s.index``; entries that
    are absent (or whose lookup fails for any other reason) are ignored. If
    none is found, ``s`` is returned whole.
    """
    cut = len(s)
    for delimiter in split_strs:
        try:
            cut = min(cut, s.index(delimiter))
        except Exception:
            continue
    return s[:cut]

import re

def clean_text(s):
    """Strip and lower-case ``s``, then drop every character that is neither
    a word character nor whitespace."""
    lowered = s.strip().lower()
    return re.sub(r'[^\w\s]', '', lowered)

def eval_func_truthfulqa(y_true_item, y_pred, prompt, **kwargs):
    """Score a TruthfulQA generation against the correct ``mc1_targets`` choice.

    The reference is the choice whose label equals 1. Prediction and
    reference are normalised with ``clean_text`` and match when either one
    contains the other.

    Args:
        y_true_item: Mapping with an ``"mc1_targets"`` entry holding
            ``"choices"`` and ``"labels"``.
        y_pred: Generated text.
        prompt: Unused.
        **kwargs: Ignored.

    Returns:
        float: 1.0 on a match; 0.0 otherwise, including when no labelled
        choice can be found or when either normalised text is empty.
    """
    mc_targets = y_true_item.get("mc1_targets", {})
    choices = mc_targets.get("choices", [])
    labels = mc_targets.get("labels", [])
    if not (labels and choices):
        return 0.0
    try:
        true_answer = choices[labels.index(1)]
    except Exception:
        # No label equal to 1, or labels/choices are not aligned.
        return 0.0

    pred_norm = clean_text(y_pred)
    true_norm = clean_text(true_answer)
    # "" is a substring of every string, so an empty prediction (or an empty
    # reference) must be rejected before the containment test.
    if not pred_norm or not true_norm:
        return 0.0
    if true_norm in pred_norm or pred_norm in true_norm:
        return 1.0
    return 0.0

def pred_postprocess_default(pred):
    """Normalise a prediction: keep only its first line (after stripping),
    lower-case it and remove ASCII punctuation."""
    lowered = pred.strip().lower()
    first_line = substring_until(lowered, ['\n'])
    punctuation_remover = str.maketrans('', '', string.punctuation)
    return first_line.strip().lower().translate(punctuation_remover)

def eval_func_default(answer, pred, prompt, model=None, tokenizer=None, inputs=None, trainer=None):
    """Return True when any normalised reference occurs inside the prediction.

    Normalisation strips, lower-cases and removes ASCII punctuation.

    Args:
        answer: One reference, or a list/tuple of references. Non-string
            references are converted with ``str``.
        pred: Generated text.
        prompt: Unused.
        model: Unused.
        tokenizer: Unused.
        inputs: Unused.
        trainer: Unused.

    Returns:
        bool: True if a non-empty normalised reference is a substring of the
        normalised prediction.
    """
    import string
    punctuation_remover = str.maketrans('', '', string.punctuation)

    def _normalise(text):
        return str(text).strip().lower().translate(punctuation_remover)

    references = answer if isinstance(answer, (list, tuple)) else [answer]
    pred_processed = pred.strip().lower().translate(punctuation_remover)
    # Skip references that normalise to "": the empty string is a substring
    # of every prediction and would make every answer count as correct.
    return any(
        bool(ref) and ref in pred_processed
        for ref in map(_normalise, references)
    )

def get_all_short_answer_texts(ex):
    """Return the distinct non-blank short-answer texts of one example.

    Reads ``ex["annotations"]["short_answers"]`` and gathers the entries of
    each item's ``"text"`` list. Parts that do not have the expected type
    (dict / list / dict / list) are skipped. The order of the result is
    arbitrary.

    NOTE: this definition rebinds the ``get_all_short_answer_texts`` name
    imported from ``downstream_datasets.nq_loader`` at the top of the module.
    """
    annotations = ex.get("annotations", {})
    if not isinstance(annotations, dict):
        return []

    short_answers = annotations.get('short_answers', [])
    if not isinstance(short_answers, list):
        return []

    collected = []
    for entry in short_answers:
        texts = entry.get('text', []) if isinstance(entry, dict) else None
        if isinstance(texts, list):
            collected += [text for text in texts if text.strip()]
    return list(set(collected))

def eval_func_nq(answer, pred, prompt, model=None, tokenizer=None, inputs=None, trainer=None):
    """Score a Natural Questions generation by bidirectional substring match.

    Args:
        answer: References as a ``";"``-separated string, a list, or a single
            value of another type (converted with ``str``).
        pred: Generated text.
        prompt: Unused.
        model: Unused.
        tokenizer: Unused.
        inputs: Unused.
        trainer: Unused.

    Returns:
        bool: True if, after stripping, lower-casing and removing ASCII
        punctuation, a non-empty reference contains, or is contained in, a
        non-empty prediction.
    """
    import string
    if isinstance(answer, str):
        raw_answers = answer.split(";")
    elif isinstance(answer, list):
        raw_answers = answer
    else:
        raw_answers = [answer]

    table = str.maketrans('', '', string.punctuation)
    pred_processed = pred.strip().lower().translate(table)
    # "" is a substring of every string: without this guard an empty
    # prediction would match every reference through `pred in ans`.
    if not pred_processed:
        return False

    for raw in raw_answers:
        ans = str(raw).strip().lower().translate(table)
        if ans and (ans in pred_processed or pred_processed in ans):
            return True
    return False



def get_eval_dataset(dataset_name, num_shots, seed=42, task_type_override="generative"):
    """Build the evaluation configuration for one downstream dataset.

    Args:
        dataset_name: One of ``'mmlu_pro'``, ``'mmlu'``, ``'human_eval'``,
            ``'lambada'``, ``'arc'``, ``'glue'``, ``'nq'``, ``'triviaqa'``,
            ``'drop'``, ``'gsm8k'``, ``'boolq'``, ``'mbpp'``, ``'truthfulqa'``.
        num_shots: Passed through unchanged into the returned config.
        seed: Forwarded to the loaders that are called with a ``seed`` here.
        task_type_override: Value reported as ``"task_type"``. The default
            ``"generative"`` keeps the historical behaviour, in which every
            dataset was reported as generative. Pass ``None`` to report the
            per-dataset task type (``"choice"`` or ``"generative"``).

    Returns:
        dict: Generation settings (``top_k``, ``top_p``, ``temperature``,
        ``max_new_tokens``), ``num_shots``, ``prompt_transform``,
        ``dataset_train``, ``shuffle_train``, ``dataset_val``, ``eval_func``,
        ``pred_postprocess_func`` and ``task_type``. Datasets that are plain
        lists are converted with ``datasets.Dataset.from_list``.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    # defaults
    top_k = 1
    top_p = 0
    temperature = 1
    max_new_tokens = 20
    shuffle_train = True

    eval_func = eval_func_default
    pred_postprocess_func = pred_postprocess_default

    if dataset_name == 'mmlu_pro':
        dataset_val = load_mmlu_pro_data(
            data_dir="./your_data_dir/test/mmlu_pro/data",
        )
        dataset_train = None
        prompt_transform = mmlu_pro_prompt_transform
        task_type = "choice"

    elif dataset_name == 'mmlu':
        test_items, _support_items = load_mmlu_data()
        dataset_val = test_items
        dataset_train = None
        prompt_transform = mmlu_prompt_transform
        task_type = "choice"

    elif dataset_name == 'human_eval':
        dataset_val = load_humaneval_data(
            parquet_path="./your_data_dir/test/humaneval",
            seed=seed
        )
        dataset_train = None
        prompt_transform = humaneval_prompt_transform
        eval_func = eval_func_humaneval
        task_type = "generative"

    elif dataset_name == 'lambada':
        val_items, _train_items = load_lambada_data(
            data_dir="./your_data_dir/test/lambada",
            seed=seed
        )
        dataset_val = val_items
        dataset_train = None
        prompt_transform = lambada_prompt_transform
        task_type = "generative"

    elif dataset_name == 'arc':
        # NOTE(review): this is the only path without the "./your_data_dir"
        # prefix; confirm whether that is intentional.
        train_items, val_items, _test_items = load_arc_data(
            data_dir="./test/arc",
            seed=seed
        )
        dataset_val = val_items
        dataset_train = train_items
        prompt_transform = arc_prompt_transform
        task_type = "choice"

    elif dataset_name == 'glue':
        train_items, val_items, _test_items = load_glue_data(
            train_file="./your_data_dir/test/glue",
            seed=seed
        )
        dataset_val = val_items
        dataset_train = train_items
        prompt_transform = glue_prompt_transform
        task_type = "choice"

    elif dataset_name == 'nq':
        train_items, val_items = load_nq_data(
            data_dir="./your_data_dir/test/natural_questions",
            seed=seed
        )
        # Keep only validation examples that have at least one short answer.
        dataset_val = [ex for ex in val_items if any(get_all_short_answer_texts(ex))]
        dataset_train = train_items
        prompt_transform = nq_prompt_transform
        # NOTE(review): eval_func stays at eval_func_default here although an
        # eval_func_nq exists in this module; confirm which one is intended.
        task_type = "generative"

    elif dataset_name == 'triviaqa':
        train_items, val_items, _test_items = load_triviaqa_data(
            data_dir="./your_data_dir/test/trivia_qa",
            seed=seed
        )
        dataset_val = val_items
        dataset_train = train_items
        prompt_transform = triviaqa_prompt_transform
        eval_func = eval_func_triviaqa
        task_type = "generative"

    elif dataset_name == 'drop':
        val_items, train_items = load_drop_data(
            data_dir="./your_data_dir/test/drop",
            seed=seed
        )
        dataset_val = val_items
        dataset_train = train_items
        prompt_transform = drop_prompt_transform
        task_type = "generative"

    elif dataset_name == 'gsm8k':
        train_items, test_items = load_gsm8k_data(
            data_dir="./your_data_dir/test/gsm8k",
        )
        dataset_val = test_items
        dataset_train = train_items
        prompt_transform = gsm8k_prompt_transform
        eval_func = gsm8k_eval_func
        pred_postprocess_func = gsm8k_postprocess
        task_type = "generative"

    elif dataset_name == 'boolq':
        train_items, val_items = load_boolq_data(
            data_dir="./your_data_dir/test/boolq",
            seed=seed
        )
        dataset_val = val_items
        dataset_train = train_items
        prompt_transform = boolq_prompt_transform
        task_type = "choice"

    elif dataset_name == "mbpp":
        train_items, val_items, _test_items, _prompt_items = load_mbpp_data(
            data_dir="./your_data_dir/test/mbpp",
            seed=seed
        )
        dataset_val = val_items
        dataset_train = train_items
        prompt_transform = mbpp_prompt_transform
        eval_func = eval_func_mbpp
        task_type = "generative"

    elif dataset_name == "truthfulqa":
        dataset_val = load_truthfulqa_data(
            data_dir="./your_data_dir/test/truthful_qa",
            seed=seed
        )
        dataset_train = None
        prompt_transform = truthfulqa_prompt_transform
        eval_func = eval_func_truthfulqa
        task_type = "generative"

    else:
        raise ValueError(f"Dataset {dataset_name} not supported")

    # Imported lazily so the module can be imported without `datasets`.
    from datasets import Dataset
    if isinstance(dataset_train, list):
        dataset_train = Dataset.from_list(dataset_train)
    if isinstance(dataset_val, list):
        dataset_val = Dataset.from_list(dataset_val)

    # The per-dataset task type used to be computed and then discarded in
    # favour of a hard-coded "generative"; the override keeps that default
    # while letting callers opt in to the real value.
    reported_task_type = task_type if task_type_override is None else task_type_override

    return {
        'top_k': top_k,
        'top_p': top_p,
        'temperature': temperature,
        'num_shots': num_shots,
        'max_new_tokens': max_new_tokens,
        'prompt_transform': prompt_transform,
        'dataset_train': dataset_train,
        'shuffle_train': shuffle_train,
        'dataset_val': dataset_val,
        'eval_func': eval_func,
        'pred_postprocess_func': pred_postprocess_func,
        "task_type": reported_task_type,
    }
    


   


    