import os
from typing import Dict
from task_tracker.models.model import Model
from task_tracker.CONFIG import current_risk
import torch

# Dataset this configuration targets; used below to build the per-dataset
# activation and dataset directories.
current_dataset = "hotpotqa"

# Hugging Face cache location: both the transformers-specific variable and
# HF_HOME point at the same directory.
# NOTE(review): transformers typically reads these variables when it is first
# imported. If task_tracker.models.model (imported above) already imports
# transformers, these assignments may come too late to take effect — confirm.
# Newer transformers releases also deprecate TRANSFORMERS_CACHE in favour of
# HF_HOME.
cache_dir = "/guardrail/TaskTracker/store/model/"
os.environ.update({"TRANSFORMERS_CACHE": cache_dir, "HF_HOME": cache_dir})

# Directory where model activation data will be stored:
#   /guardrail/TaskTracker/store/activations/<current_risk>/<current_dataset>
activation_parent_dir = os.path.join(
    "/guardrail/TaskTracker/store/activations", current_risk, current_dataset
)

# Directory where the dataset text files are stored:
#   /guardrail/TaskTracker/store/output_datasets/<current_risk>/<current_dataset>
text_dataset_parent_dir = os.path.join(
    "/guardrail/TaskTracker/store/output_datasets", current_risk, current_dataset
)

# Path of the database JSON file. Despite its name this is a file path, not a
# directory; the name is kept for existing importers. It lives inside
# text_dataset_parent_dir, so derive it from there to keep the two in sync.
database_dir = os.path.join(text_dataset_parent_dir, "database.json")

# Paths to dataset files, keyed "<split>_<category>" and stored in
# text_dataset_parent_dir as "dataset_<category>_<split>.json".
# Generating the mapping from these tables (instead of spelling out every
# entry) prevents copy-paste mismatches between a key and its file; the old
# explicit table mapped "train_financial" to the *case* training file.
if current_risk == 'Unauthorized_Access':
    # Insertion order: all train entries, then val, then test.
    _dataset_splits = ("train", "val", "test")
    _dataset_categories = ("case", "employee", "financial", "goods")
else:
    # Insertion order for this branch: train, then test, then val.
    _dataset_splits = ("train", "test", "val")
    _dataset_categories = ("clean", "poisoned")

data = {
    f"{split}_{category}": os.path.join(
        text_dataset_parent_dir, f"dataset_{category}_{split}.json"
    )
    for split in _dataset_splits
    for category in _dataset_categories
}
def _build_model(hf_name, activation_subdir, dtype):
    """Create a Model that reads the module-level ``data`` mapping.

    Activations are written under ``activation_parent_dir/<activation_subdir>``,
    and every model uses the "train" subset.
    """
    return Model(
        name=hf_name,
        output_dir=os.path.join(activation_parent_dir, activation_subdir),
        data=data,
        subset="train",
        torch_dtype=dtype,
    )


# Initialize models with specific configurations (construction order preserved).
# NOTE(review): the 8B / 7B models load in float32 while larger ones use
# bf16/fp16 — presumably deliberate for precision; confirm memory budget.
llama_3_70B = _build_model(
    "meta-llama/Meta-Llama-3-70B-Instruct", "llama3_70b", torch.bfloat16
)
llama_3_8B = _build_model(
    "meta-llama/Meta-Llama-3-8B-Instruct", "llama3_8b", torch.float32
)
mistral_7B = _build_model(
    "mistralai/Mistral-7B-Instruct-v0.2", "mistral", torch.float32
)
phi3 = _build_model(
    "microsoft/Phi-3-mini-4k-instruct", "phi3", torch.bfloat16
)
mixtral = _build_model(
    "mistralai/Mixtral-8x7B-Instruct-v0.1", "mixtral", torch.float16
)
vicuna = _build_model(
    "lmsys/vicuna-7b-v1.5", "vicuna", torch.float16
)

# Registry of the configured models, keyed by the same short identifier used
# for each model's activation sub-directory.
models: Dict[str, Model] = dict(
    llama3_70b=llama_3_70B,
    llama3_8b=llama_3_8B,
    mistral=mistral_7B,
    phi3=phi3,
    mixtral=mixtral,
    vicuna=vicuna,
)