import json
import argparse
import os
import copy
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

from experiment.experiment import Experiment

from utils.utils import timestamp, get_remove_mlp, get_remove_mha, get_remove_ln

"""
Parse arguments passed to script.
"""
parser = argparse.ArgumentParser(
    description="run config runner for metrics across prompts"
)
parser.add_argument(
    "--model", action="store", type=str, required=True,
    help="Key into the config's MODEL_PATH mapping selecting the model/tokenizer to load.",
)
parser.add_argument(
    "--untrained-model", action="store", type=str,
    help="Currently unused by this script.",
)
parser.add_argument(
    "--config", action="store", type=str, required=True,
    help="Path to the JSON run config (MODEL_PATH, DATA_CONFIG, OUTPUT_PATH).",
)
parser.add_argument(
    "--start", action="store", type=int,
    help="Start index forwarded to Experiment.run.",
)
parser.add_argument(
    "--end", action="store", type=int,
    help="End index forwarded to Experiment.run.",
)
# NOTE: --quantize and --trained are store_false flags: they default to True and
# passing the flag turns the feature OFF. Kept as-is for CLI compatibility.
parser.add_argument(
    "--quantize", action="store_false", default=True,
    help="Pass to DISABLE 4-bit quantization of the trained model (quantized by default).",
)
parser.add_argument(
    "--untrained", action="store_true", default=False,
    help="Pass to ALSO load a randomly initialised (untrained) copy of the model.",
)
parser.add_argument(
    "--trained", action="store_false", default=True,
    help="Pass to SKIP loading the trained model (loaded by default).",
)

args = parser.parse_args()

print("trained = ", args.trained)
print("untrained = ", args.untrained)


"""
Set up model and data configs.
"""

# Context manager so the config file handle is closed after parsing.
with open(args.config, "r") as config_file:
    config = json.load(config_file)


"""
Load model from specified path.
"""

model_name = args.model
model_path = config["MODEL_PATH"][model_name]

timestamp("Loading tokenizer from " + model_path)
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    use_fast=True
)
timestamp(f"{model_name} tokenizer has been successfully loaded from {model_path}")

# Parallel lists: models[i] is labelled by model_types[i]. Created unconditionally so
# the untrained branch below works even when the trained model is skipped.
models = []
model_types = []

if args.trained:
    timestamp(f"Loading {model_name} from {model_path}. Quantize = {args.quantize}.")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        load_in_4bit=args.quantize
    )
    models.append(model)
    model_types.append("standard")
    timestamp(f"{model_name} has been successfully loaded from {model_path}")

if args.untrained:
    timestamp(f"Loading untrained {model_name}")
    # from_config builds the architecture from the config only; no pretrained
    # weights are loaded.
    config_untrained = AutoConfig.from_pretrained(model_path)
    untrained_model = AutoModelForCausalLM.from_config(config_untrained)
    timestamp(f"Untrained {model_name} has been successfully loaded")
    models.append(untrained_model)
    model_types.append("untrained")

if not models:
    # Both model kinds were disabled; fail clearly instead of running on nothing.
    raise SystemExit(
        "No models selected: drop --trained and/or pass --untrained."
    )


for data_name, data_config in config["DATA_CONFIG"].items():
    timestamp(f"Running over the {data_name} dataset")

    # The entry's values are splatted positionally into load_dataset in insertion
    # order, so key order in the JSON config matters (presumably path, then
    # subset name — confirm against the config files).
    dataset = load_dataset(*data_config.values())
    timestamp("Loaded " + data_name + " from cache.")

    # Run experiment: toggle which metrics Experiment computes.
    metric_flags = dict(
        hidden_reps=False,
        pca=False,
        lss=True,
        diffs=True,
        head_entropy=False,
        equidistance=True,
        expodistance=False,
        expodistance2=True,
        attention_var=False,
        attention_var2=False,
        norm=False,
        similarity=False,
        similarity_norm=False,
        classifier=False,
        rep_var=False,
        rep_cs=False,
        rep_classifier=False,
        rep_class_cs=False,
    )
    experiment = Experiment(
        models,
        model_types,
        tokenizer,
        model_name,
        out_dir=config["OUTPUT_PATH"],
        **metric_flags,
    )

    outputs, out_file = experiment.run(
        dataset=dataset,
        data_name=data_name,
        start=args.start,
        end=args.end,
        save=True,
    )
    done_msg = f"Finished experiment on {model_name} over {data_name} dataset and saved in  "
    timestamp(done_msg + out_file)




