import argparse
import json

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from experiment.experiment import Experiment
from utils.utils import timestamp

"""
Parse arguments passed to script.
"""
# The first positional argument of ArgumentParser is `prog` (the program name
# printed in usage lines), so this text is passed as `description` instead.
parser = argparse.ArgumentParser(
    description="run config runner for metrics across prompts"
)
# --data, --model and --config are used below as config lookup keys and a file
# path. If they are missing, the script fails later with an opaque KeyError or
# TypeError. Requiring them makes argparse reject the call up front instead.
parser.add_argument(
    "--data", type=str, required=True,
    help="key into the config's DATA_CONFIG section",
)
parser.add_argument(
    "--model", type=str, required=True,
    help="key into the config's MODEL_PATH section",
)
parser.add_argument(
    "--config", type=str, required=True,
    help="path to the JSON config file",
)
# Optional: both are passed unchanged to Experiment.run and are None when omitted.
parser.add_argument("--start", type=int, help="`start` value forwarded to Experiment.run")
parser.add_argument("--end", type=int, help="`end` value forwarded to Experiment.run")
args = parser.parse_args()

"""
Set up model and data configs.
"""

model_name = args.model
data_name = args.data
# Read the JSON config inside a `with` block so the file handle is closed
# straight away rather than left open for the life of the process. JSON is
# UTF-8 by spec, so the locale-dependent default encoding is not used.
with open(args.config, "r", encoding="utf-8") as config_file:
    config = json.load(config_file)
# Per-dataset settings, selected by the --data argument.
data_config = config["DATA_CONFIG"][data_name]

start = args.start
end = args.end

#timestamp("Loaded config " + str(config))

print(start, end)

"""
Load model from specified path.
"""
model_path = config["MODEL_PATH"][model_name]
# Load the model 4-bit quantized and let accelerate place it across devices.
# Passing `load_in_4bit=True` directly to `from_pretrained` is deprecated in
# transformers; the supported form is a BitsAndBytesConfig in
# `quantization_config`.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
timestamp("Loaded model from " + model_path)

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    use_fast=True,
)
timestamp("Loaded tokenizer from " + model_path)

# NOTE(review): the DATA_CONFIG values are passed to load_dataset positionally,
# in the key order of the JSON file. Each entry must therefore list its keys in
# load_dataset's parameter order (presumably path, then name) - confirm against
# the config files.
dataset = load_dataset(*data_config.values())
timestamp("Loaded " + data_name + " from cache.")

"""
Run experiment.
"""
# Pull the experiment settings out of the config: the output directory first,
# then the extra keyword arguments for the Experiment constructor.
experiment_out_dir = config["OUTPUT_PATH"]
experiment_kwargs = config["EXPERIMENT_CONFIG"]

experiment = Experiment(
    model, tokenizer, model_name, out_dir=experiment_out_dir, **experiment_kwargs
)
timestamp("Created experiment " + str(experiment_kwargs))

# Run over the dataset and persist the results (save=True). run returns the
# outputs plus the path of the file they were written to.
run_kwargs = {
    "dataset": dataset,
    "data_name": data_name,
    "start": start,
    "end": end,
    "save": True,
}
outputs, out_file = experiment.run(**run_kwargs)
print("Finished experiment and saved in  " + out_file)


