import numpy as np
import pandas as pd
import torch
import random
from preprocessing import *
import sys
from prompt import *

# Fixed seed: it drives the RNGs below, is passed to load_data, and is written
# next to every label in the output CSV.
seed = 9
machine = 'svm'  # NOTE(review): not referenced anywhere in the visible script

# The dataset name is the first CLI argument; fail with a usage message
# instead of a bare IndexError when it is missing.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} <dataset> [device]")
dataset = sys.argv[1]

np.random.seed(seed)
random.seed(seed)
# Generation further down samples tokens (do_sample=True) from torch's RNG,
# so torch has to be seeded as well or the recorded seed reproduces nothing.
torch.manual_seed(seed)

# Only the feature frame X, the target y and `categorical` are kept from the
# tuple returned by load_data; the remaining positions are discarded.
_, X, _, y, _, categorical, _ = load_data(dataset, seed=seed)

# Explanation scores for this dataset: one row per instance, keyed by "id".
file = f'./explanations/{dataset}.csv'
shaps = pd.read_csv(file).set_index("id")
# Drop the first three remaining columns; only the columns after them are
# used as per-feature scores.
shaps = shaps.iloc[:, 3:]
# Scores outside [-1, 1] are replaced by 0.
shaps = shaps.mask((shaps > 1) | (shaps < -1), 0)

# Keep only instances that have no missing feature values AND an explanation,
# with both frames ordered by id.
X = X.dropna()
indexes = np.intersect1d(X.index, shaps.index)

X = X.loc[indexes].sort_index()
shaps = shaps.loc[indexes].sort_index()

# Instruction prompt for each supported dataset. The PROMPT_* constants are
# not defined in this file; presumably they come from the star-import of the
# `prompt` module.
_PROMPTS_BY_DATASET = {
    'compas': PROMPT_COMPAS,
    'parkinson': PROMPT_PARKINSON,
    'creditcard': PROMPT_CREDITCARD,
    'adult': PROMPT_ADULT,
    'churn': PROMPT_CHURN,
    'wine': PROMPT_WINE,
    'bike': PROMPT_BIKE,
    'power': PROMPT_POWER,
}

# Reject anything that has no prompt.
if dataset not in _PROMPTS_BY_DATASET:
    raise ValueError(f"Unknown dataset: {dataset}")
prompt = _PROMPTS_BY_DATASET[dataset]

import os

# Access token handed to from_pretrained below. Read it from the HF_TOKEN
# environment variable so the secret never has to be written into this file;
# the placeholder stays as the fallback for anyone who edits it in place.
token = os.environ.get('HF_TOKEN', 'YOUR_TOKEN_HERE')

# Point every Hugging Face cache location at a local ./llama3 directory.
folder = 'llama3'
os.makedirs(folder, exist_ok=True)
os.environ['HF_HOME'] = folder
os.environ['HF_HUB_CACHE'] = folder
os.environ['HUGGINGFACE_HUB_CACHE'] = folder

# Load model directly
# NOTE: transformers is imported only after the cache variables above are
# set -- presumably so the hub library picks them up when it is first
# imported. Keep this ordering.
from transformers import AutoTokenizer, AutoModelForCausalLM

import torch

# Device precedence: second CLI argument, then the LLM_DEVICE environment
# variable, then CUDA when available, otherwise CPU.
if len(sys.argv) > 2:
    device = sys.argv[2]
else:
    device = os.environ.get("LLM_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")

model_id = 'meta-llama/Llama-3.1-8B-Instruct'

tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(model_id, token=token).to(device)
def _parse_response(output, n_features):
    """Turn a raw model reply into ``(label, flags)``.

    The reply is read as ``"<label> <idx>,<idx>,..."``. Only the text before
    the first ``<`` (where a special token would start) and before the first
    newline is considered. Comma-separated items that are not plain digits
    are ignored.

    Args:
        output: Decoded text generated by the model.
        n_features: Number of feature columns; sets the length of ``flags``.

    Returns:
        A tuple ``(label, flags)`` where ``label`` is an int and ``flags`` is
        a list of ``n_features`` ints, 1 at every index listed in the reply
        and 0 elsewhere.

    Raises:
        ValueError: If the first space-separated token is not an integer.
    """
    # Leading whitespace carries no information and would otherwise leave an
    # empty label token.
    first_line = output.split('<')[0].lstrip().split('\n')[0]
    label_str, _, features_str = first_line.partition(' ')
    label = int(label_str)
    flagged = set()
    for item in features_str.split(','):
        item = item.strip()
        if item.isdigit():
            flagged.add(int(item))
    flags = [1 if k in flagged else 0 for k in range(n_features)]
    return label, flags


# Results are appended one row at a time, so make sure the target directory
# exists before the first write.
output_path = f'./llm_labels/{dataset}.csv'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
columns = ['id', 'seed', 'label'] + X.columns.tolist()
n_features = X.shape[1]

# One single-turn chat per instance: the prompt, then one line per feature
# with its value and its explanation score, then the closing instruction.
for i, row in shaps.iterrows():
    x = X.loc[i]
    # The j-th score column is paired with the j-th feature by position.
    # Use .iloc: the score columns are labelled, and plain row[j] with an
    # integer key is deprecated positional access.
    feature_lines = [
        f"{feature_id}  : {feature_value} = {row.iloc[j]}\n"
        for j, (feature_id, feature_value) in enumerate(zip(x.index, x.values))
    ]
    message = (
        prompt + "\n"
        + "".join(feature_lines)
        + "\nEvaluate the explanation quality and provide your response."
    )
    messages = [{"role": "user", "content": message}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device)
    outputs = model.generate(**inputs, max_new_tokens=10, do_sample=True, temperature=0.7, top_p=0.9)
    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_length = inputs["input_ids"].shape[-1]
    output = tokenizer.decode(outputs[0][prompt_length:])

    try:
        label, feature_values = _parse_response(output, n_features)
    except ValueError:
        # Sampled replies are occasionally malformed; skip that instance
        # (loudly) instead of aborting the whole run.
        print(f"Skipping id {i}: could not parse model output {output!r}", file=sys.stderr)
        continue

    data = pd.DataFrame([[i, seed, label] + feature_values], columns=columns)
    # Write the header only when the file is being created.
    data.to_csv(output_path, mode='a', header=not os.path.exists(output_path), index=False)
