import argparse
import json
import os
import traceback

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from experiment.experiment import Experiment
from utils.utils import timestamp

# Parse the command-line arguments passed to the script.
#
# The human-readable text is given as ``description``; passing it positionally
# would make it the program name shown in the usage line.
parser = argparse.ArgumentParser(
    description="run config runner for metrics across prompts"
)
# parser.add_argument("--data", action="store", type=str)
# parser.add_argument("--model", action="store", type=str)
parser.add_argument(
    "--config",
    action="store",
    type=str,
    # The config path is opened unconditionally below, so fail early with a
    # clear usage error instead of a TypeError from open(None).
    required=True,
    help="Path to the JSON run configuration file.",
)
parser.add_argument(
    "--start",
    action="store",
    type=int,
    help="Value forwarded to Experiment.run as 'start' (default: None).",
)
parser.add_argument(
    "--end",
    action="store",
    type=int,
    help="Value forwarded to Experiment.run as 'end' (default: None).",
)
# NOTE(review): with action="store_false" and default=True, *passing* --all
# sets args.all to False. The value is only printed by this script, so the
# existing behavior is preserved; confirm whether the inversion is intended.
parser.add_argument(
    "--all",
    action="store_false",
    default=True,
    help="Passing this flag sets the 'all' value to False (default: True).",
)
args = parser.parse_args()

# Set up model and data configs.

# Read the JSON config through a context manager so the file handle is closed
# as soon as parsing finishes instead of being left to the garbage collector.
with open(args.config, "r", encoding="utf-8") as config_file:
    config = json.load(config_file)

print("Running all models = ", args.all)

# NOTE: a former `experiment_files = os.listdir(experiment_dir)` statement was
# removed from this spot. `experiment_dir` is not defined anywhere in the
# script, so the call raised NameError before any model was loaded, and the
# resulting listing was never used.

# Run every configured model over every configured dataset.
#
# `start` and `end` are forwarded unchanged to each Experiment.run call, so
# they are read from the parsed arguments once rather than per model.
start = args.start
end = args.end

for model_name, model_path in config["MODEL_PATH"].items():

    print(f"Loading {model_name}")

    #timestamp("Loaded config " + str(config))

    # Load the model from the path configured for this model name.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        load_in_4bit=True
    )

    timestamp("Loaded model from " + model_path)

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        use_fast=True
    )
    timestamp("Loaded tokenizer from " + model_path)

    for data_name, data_config in config["DATA_CONFIG"].items():

        print(f"Running over the {data_name} dataset")

        # Best-effort sweep: a failure on one dataset is reported and the
        # loop moves on to the next dataset instead of aborting the run.
        try:
            # The config entry's values are passed positionally, so their
            # order in the JSON file determines which load_dataset parameter
            # each one binds to.
            dataset = load_dataset(*data_config.values())
            timestamp("Loaded " + data_name + " from cache.")

            # Run the experiment for this (model, dataset) pair.
            experiment = Experiment(
                model,
                tokenizer,
                model_name,
                out_dir=config["OUTPUT_PATH"],
                **config["EXPERIMENT_CONFIG"]
            )
            timestamp("Created experiment " + str(config["EXPERIMENT_CONFIG"]))

            outputs, out_file = experiment.run(
                dataset=dataset,
                data_name=data_name,
                start=start,
                end=end,
                save=True
            )
            print(f"Finished experiment on {model_name} over {data_name} dataset and saved in  " + out_file)

        except Exception as e:
            print(f"An error occurred: {e}")
            # The message alone rarely identifies the failing call; emit the
            # full traceback (to stderr) so the failure can be diagnosed.
            traceback.print_exc()


