import os, sys
path = os.path.abspath(os.getcwd())
sys.path.append(path)
from model_arithmetic import Evaluation, ModelArithmetic, load_model, LLMPrompt, Max, KL_indicator, enable_logging
import torch
from loguru import logger
from transformers import set_seed

import tensorflow as tf

enable_logging()

# Let TensorFlow grow its GPU memory usage on demand; otherwise the small
# BLEURT model would take up all GPU memory.
gpus = tf.config.list_physical_devices("GPU")
try:
    # Memory growth currently needs to be the same across GPUs, so it is
    # enabled on every device. With no GPU present the loop does nothing.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    # Memory growth must be set before the GPUs have been initialized;
    # report the failure and carry on.
    print(e)

# Root directory under which the evaluation outputs of this script are written.
BASE_EVAL_PATH = "eval/performance"

def evaluate(task_name, formula, save_path, default_model, num_fewshot=0, limit=None, no_cache=False, batch_size=1, dtype=torch.float16, output_folder=None):
    """Run one lm-eval task on a model-arithmetic formula and save the result.

    The random seed is fixed to 42 (via ``transformers.set_seed``) before
    anything else happens.

    Args:
        task_name: Task identifier; passed both to ``ModelArithmetic`` (as
            ``lm_eval_task``) and to ``Evaluation.evaluate_lm_eval``.
        formula: Either the formula itself, or a tuple whose first element is
            the formula and whose second element is a single retroactive
            operator for ``ModelArithmetic``.
        save_path: Forwarded to ``evaluate_lm_eval`` and then to
            ``Evaluation.save``.
        default_model: Forwarded to ``ModelArithmetic`` as ``default_model``.
        num_fewshot: Forwarded to ``evaluate_lm_eval``.
        limit: Forwarded to ``evaluate_lm_eval``.
        no_cache: Forwarded to ``evaluate_lm_eval``.
        batch_size: Forwarded to ``evaluate_lm_eval``.
        dtype: Forwarded to ``ModelArithmetic`` as ``dtype``.
        output_folder: Forwarded to ``evaluate_lm_eval``.

    Returns:
        None.
    """
    set_seed(42)
    runner = Evaluation()

    # A tuple bundles the formula with exactly one retroactive operator.
    retroactive_operators = []
    if isinstance(formula, tuple):
        retroactive_operators.append(formula[1])
        formula = formula[0]

    model = ModelArithmetic(
        formula,
        default_model=default_model,
        retroactive_operators=retroactive_operators,
        dtype=dtype,
        needs_input_tokens_lm_eval=True,
        lm_eval_task=task_name,
    )

    runner.evaluate_lm_eval(
        model=model,
        model_args=None,  # no extra model arguments are supplied
        task_name=task_name,
        batch_size=batch_size,
        num_fewshot=num_fewshot,
        limit=limit,
        no_cache=no_cache,
        save_path=save_path,
        write_out=True,
        output_folder=output_folder,
    )
    runner.save(save_path)
    
def eval_multiple(formula, datasets, name, limit=None, num_fewshot=0, batch_size=1, default_model="meta-llama/Llama-2-13b-hf"):
    """Evaluate one formula on several datasets, storing everything under one run folder.

    The run folder is ``BASE_EVAL_PATH/name``. It is created if missing, the
    string form of ``formula`` is written to ``formula.txt`` inside it, and
    ``evaluate`` is then called once per dataset with ``no_cache=True`` and
    ``dtype=torch.bfloat16``. Each dataset's result path is
    ``<dataset>_eval.json`` inside the run folder.

    Args:
        formula: Formula passed unchanged to ``evaluate``.
        datasets: Iterable of task names; one ``evaluate`` call is made per entry.
        name: Sub-directory of ``BASE_EVAL_PATH`` used for this run.
        limit: Forwarded to ``evaluate``.
        num_fewshot: Forwarded to ``evaluate``.
        batch_size: Forwarded to ``evaluate``.
        default_model: Forwarded to ``evaluate`` as ``default_model``. Defaults
            to the value that used to be hard-coded here.

    Returns:
        None.
    """
    run_dir = os.path.join(BASE_EVAL_PATH, name)
    os.makedirs(run_dir, exist_ok=True)

    # Explicit encoding so the file does not depend on the platform locale.
    with open(os.path.join(run_dir, "formula.txt"), "w", encoding="utf-8") as outfile:
        outfile.write(str(formula))

    for dataset in datasets:
        evaluate(
            formula=formula,
            default_model=default_model,
            task_name=dataset,
            num_fewshot=num_fewshot,
            limit=limit,
            no_cache=True,
            save_path=os.path.join(run_dir, f"{dataset}_eval.json"),
            batch_size=batch_size,
            dtype=torch.bfloat16,
            output_folder=run_dir,
        )



if __name__ == "__main__":
    with logger.catch():
        # Two prompt variants per model: one whose template returns its second
        # argument as the prompt, and a "no_context" one whose template always
        # returns the empty string.
        gpt2xl = LLMPrompt("", prompt_template=lambda e, f: f"{f}", model="gpt2-xl")
        gpt2xl_no_context = LLMPrompt("", prompt_template=lambda e, f: f"", model="gpt2-xl")
        gpt2 = LLMPrompt("", prompt_template=lambda e, f: f"{f}", model="gpt2")
        gpt2_no_context = LLMPrompt("", prompt_template=lambda e, f: f"", model="gpt2")

        # Weights applied to the Max(...) term in the last four formulas.
        max_term_weights = (0.5, 1, 5, 100)

        formulas = [
            gpt2xl,
            1.5 * gpt2xl - 0.5 * gpt2xl_no_context,
            1.5 * gpt2xl - 0.5 * gpt2_no_context,
            1.5 * gpt2xl - 0.5 * (gpt2xl - gpt2 + gpt2_no_context),
            *(
                gpt2xl - weight * Max(gpt2xl_no_context - gpt2xl, 0)
                for weight in max_term_weights
            ),
        ]

        # Tasks every formula is evaluated on.
        # (A commented-out alternative in the original was ["crows_pairs_english"].)
        benchmark_tasks = (
            "hellaswag",
            "lambada_openai",
            "winogrande",
            "arc_easy",
            "boolq",
            "arc_challenge",
            "piqa",
            "sciq",
        )

        # Formulas whose index is below this value are skipped; 0 runs all of them.
        first_index = 0

        for index, formula in enumerate(formulas):
            if index >= first_index:
                # The formula's index doubles as the name of its output folder.
                eval_multiple(
                    formula=formula,
                    datasets=list(benchmark_tasks),
                    name=str(index),
                    limit=1000,
                )