import json
import argparse
import os
import copy
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

from experiment.experiment import Experiment

from utils.utils import timestamp, get_remove_mlp, get_remove_mha, get_remove_ln
from utils.output import plot_pca
from utils.output import get_representations

"""
Parse arguments passed to script.
"""
# Command-line interface for the run.
parser = argparse.ArgumentParser("run config runner for metrics across prompts")

# --data, --model and --config are dereferenced unconditionally later in the
# script (config file path and config dictionary keys), so a missing value is
# reported here by argparse instead of surfacing later as a TypeError/KeyError.
parser.add_argument("--data", action="store", type=str, required=True)
parser.add_argument("--model", action="store", type=str, required=True)
parser.add_argument("--untrained-model", action="store", type=str)
parser.add_argument("--config", action="store", type=str, required=True)
parser.add_argument("--start", action="store", type=int)
parser.add_argument("--end", action="store", type=int)

# Quantize model when loading? Quantization stays on by default. Previously
# `--quantize` was a store_false flag, so passing it *disabled* quantization;
# now `--quantize` keeps it enabled and `--no-quantize` switches it off.
parser.add_argument("--quantize", dest="quantize", action="store_true", default=True)
parser.add_argument("--no-quantize", dest="quantize", action="store_false")

parser.add_argument("--untrained", action="store_true", default=False) # Load untrained model?
parser.add_argument("--trained", action="store_true", default=False) # Load trained model?

args = parser.parse_args()

# device = torch.device("cuda:0")


"""
Set up model and data configs.
"""

# Read the JSON run configuration. The context manager guarantees the file
# handle is closed even if parsing fails.
with open(args.config, "r") as config_file:
    config = json.load(config_file)

# NOTE(review): the original statement listed `experiment_dir`, a name that is
# never defined in this file, so the script always aborted here with a
# NameError. The directory is now read from an optional "EXPERIMENT_DIR" entry
# in the config (assumed key name -- confirm against the config schema); when
# it is absent, no experiment files are listed.
experiment_dir = config.get("EXPERIMENT_DIR")
experiment_files = os.listdir(experiment_dir) if experiment_dir else []


"""
Load model from specified path.
"""


# Resolve the checkpoint location registered in the config for the model
# requested on the command line.
model_name = args.model
model_path = config["MODEL_PATH"][model_name]

# Parallel lists: each loaded model and the label describing its kind.
models, model_types = [], []

# The tokenizer is loaded from the same path as the model weights.
timestamp(f"Loading tokenizer from {model_path}")
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
timestamp(
    f"{model_name} tokenizer has been successfully loaded from {model_path}"
)


# Load the pretrained checkpoint when --trained is given. The model is appended
# to `models` with the label "standard".
if args.trained:
    timestamp(f"Loading {model_name} from {model_path}")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        # device_map="auto",
        trust_remote_code=True,
        load_in_4bit=args.quantize,
    )
    models.append(model)
    model_types.append("standard")

    timestamp(f"{model_name} has been successfully loaded from {model_path}")


# Build a randomly initialised model with the same architecture when
# --untrained is given. It is appended to `models` with the label "untrained".
if args.untrained:
    timestamp(f"Loading untrained {model_name}")

    # trust_remote_code mirrors the trained path above: checkpoints that ship
    # their own modelling code cannot otherwise be resolved from the config.
    config_untrained = AutoConfig.from_pretrained(
        model_path,
        trust_remote_code=True,
    )
    untrained_model = AutoModelForCausalLM.from_config(
        config_untrained,
        trust_remote_code=True,
    )
    models.append(untrained_model)
    model_types.append("untrained")

    timestamp(f"Untrained {model_name} has been successfully loaded")
    

data_name = args.data

# Look up the dataset entry in the config; a KeyError here means the dataset
# name passed via --data is not registered.
data_config = config["DATA_CONFIG"][data_name]
timestamp(f"Running over the {data_name} dataset")

# Run one fixed prompt through a loaded model and plot a PCA of its hidden
# states. The step is best-effort: any failure is reported, not raised.
try:
    if not models:
        raise RuntimeError("no model was loaded; pass --trained and/or --untrained")

    # Use the first loaded model: the trained one when --trained was given,
    # otherwise the untrained one. The original referenced `model`, which is
    # unbound when only --untrained is passed.
    active_type, active_model = model_types[0], models[0]
    # Keep the historical plot name for the trained model; tag any other kind
    # so its plot is distinguishable.
    plot_name = model_name if active_type == "standard" else f"{model_name}-{active_type}"

    # No-op for from_pretrained models (already in eval mode); makes sure a
    # from_config model does not run with training-mode layers active.
    active_model.eval()

    prompt = ". What is the capital of France? The capital is"

    # Put the inputs on whatever device the model lives on instead of assuming
    # CUDA, so the forward pass also works on CPU.
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(active_model.device)

    with torch.no_grad():
        outputs = active_model(input_ids, output_hidden_states=True, output_attentions=True)

    labels = [tokenizer.decode(t) for t in input_ids[0]]

    print(labels)

    # Locate the "What" token, ignoring surrounding whitespace: decoded single
    # tokens may or may not carry the leading space depending on the tokenizer.
    first_token = next(
        (i for i, label in enumerate(labels) if label.strip() == "What"),
        None,
    )
    if first_token is None:
        raise ValueError(f"could not locate the 'What' token in {labels}")

    representations = get_representations(outputs.hidden_states)

    # Same slicing as before: drop the last four entries along the first axis
    # and everything before the "What" token along the second.
    plot_pca(
        representations[:-4, first_token:],
        name=plot_name,
        dim=2,
        labels=labels[first_token:],
    )

except Exception as e:
    # Top-level boundary of the script: report the failure and carry on.
    print(f"An error occurred: {e}")




