from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
import pandas as pd
import numpy as np
import glob
import os
import re
import datetime
from tqdm import tqdm
import torch
import yaml

from utils.eval_metrics import is_valid_smiles, clogp, qed
from tdc.generation import MolGen
from pandarallel import pandarallel
from dotenv import load_dotenv


# Load variables from a local .env file into os.environ (e.g. HF_TOKEN).
# HF_TOKEN does not need to be read explicitly here: huggingface_hub picks it up
# from the environment when from_pretrained downloads checkpoints.
load_dotenv()

# NOTE(review): no pandarallel (parallel_*) call is visible in this script —
# confirm the worker pool is still needed.
pandarallel.initialize()

script_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(script_dir, "configs/cav_eval.yaml")

# Start from an empty config so a missing (or empty) YAML file yields defaults /
# None for every key, instead of leaving `config` unbound and raising NameError
# on the lookups below.
config = {}
if os.path.exists(config_path):
    with open(config_path, "r", encoding="utf-8") as file:
        # yaml.safe_load returns None for an empty document.
        config = yaml.safe_load(file) or {}

# Name of the CAV alignment hook to use; see MODIFY_FUNCTIONS.
modify_func_name = config.get("modify_func", "additive")
model_name = config.get("model")
data = config.get("data")
# Index of the decoder layer whose output is steered.
layer_to_modify = config.get("layer_to_modify")
# Scaling factors applied to the QED and LogP CAV vectors in the hook.
qed_alpha = config.get("qed_alpha")
logp_alpha = config.get("logp_alpha")

print("CAV Alignment Function: ", modify_func_name)

print("Eval Dataset: ", data)

# Evaluation set: zinc is used in full; moses and chembl are down-sampled to a
# fixed-size subset with a fixed seed so runs are reproducible. The prompt
# construction later in the script reads the 'QED' and 'LOGP' columns of these rows.
_EVAL_SAMPLE_SIZE = 2000
_EVAL_SAMPLE_SEED = 21

if data == "zinc":
    sampled_data = pd.read_csv("PATH/zinc_eval.csv")
elif data == "moses":
    eval_data = pd.read_parquet("PATH/test_data/moses_eval.parquet")
    sampled_data = eval_data.sample(n=_EVAL_SAMPLE_SIZE, random_state=_EVAL_SAMPLE_SEED)
elif data == "chembl":
    eval_data = pd.read_parquet("PATH/test_data/chembl_eval.parquet")
    sampled_data = eval_data.sample(n=_EVAL_SAMPLE_SIZE, random_state=_EVAL_SAMPLE_SEED)
else:
    # Fail here with a clear message instead of a NameError on `sampled_data` later.
    raise ValueError(
        f"Unknown eval dataset {data!r}; expected one of 'zinc', 'moses', 'chembl'."
    )

# Per-model artifacts: (model checkpoint, tokenizer checkpoint, QED CAV .npy, LogP CAV .npy).
# NOTE(review): each CAV file name encodes a fixed layer (6 / 12 / 9) while the hook
# layer comes from the config's `layer_to_modify` — confirm the two agree.
_MODEL_SPECS = {
    "chemlactica-125m": (
        "PATH/chemlactica-125m",
        "PATH/chemlactica_tokenizer",
        "PATH/cav/chemlactica-125m/qed_layer_6_tdc_cav_new.npy",
        "PATH/cav/chemlactica-125m/logp_layer_6_tdc_cav_new.npy",
    ),
    "chemlactica-1p3b": (
        "yerevann/chemlactica-1.3b",
        "yerevann/chemlactica-1.3b",
        "PATH/cav/chemlactica-1p3b/qed_layer_12_tdc_cav_new.npy",
        "PATH/cav/chemlactica-1p3b/logp_layer_12_tdc_cav_new.npy",
    ),
    "chemma-2b": (
        "yerevann/chemma-2b",
        "yerevann/chemma-2b",
        "PATH/cav/chemma-2b/qed_layer_9_tdc_cav_new.npy",
        "PATH/cav/chemma-2b/logp_layer_9_tdc_cav_new.npy",
    ),
}

if model_name not in _MODEL_SPECS:
    # Fail early with a clear message instead of a NameError on `tokenizer` below.
    raise ValueError(
        f"Unknown model {model_name!r}; expected one of {sorted(_MODEL_SPECS)}."
    )

# Load the model, its tokenizer and the two CAV steering vectors.
model_path, tokenizer_path, qed_cav_path, logp_cav_path = _MODEL_SPECS[model_name]
model = AutoModelForCausalLM.from_pretrained(model_path).eval().to("cuda")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
# Left padding so every prompt in a batch ends at the same position and
# generation continues directly from the last prompt token.
tokenizer.padding_side = "left"
cav = np.load(qed_cav_path)
logp_cav = np.load(logp_cav_path)

# Register the property/SMILES tags used in the prompts as special tokens and
# grow the embedding matrix to cover any that were not already in the vocabulary.
special_tokens = ['[WAVELENGTH]', '[/WAVELENGTH]', '[F_OSC]', '[/F_OSC]', '[QED]', '[/QED]', '[LOGP]', '[/LOGP]', '[START_SMILES]', '[END_SMILES]', '[SEP]']
tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
model.resize_token_embeddings(len(tokenizer))


def _as_steering_tensor(vec, like):
    """Return *vec* (ndarray, tensor or sequence) as a tensor on ``like``'s device with ``like``'s dtype."""
    return torch.as_tensor(vec, device=like.device, dtype=like.dtype)


def modify_layer_output(module, input, output):
    """Forward hook that adds the QED and LogP CAV steering vectors to a layer's output.

    Intended for ``register_forward_hook``: PyTorch calls it as
    ``hook(module, input, output)`` and uses a non-None return value as the
    layer's output. Reads the module-level globals ``cav``, ``logp_cav``,
    ``qed_alpha`` and ``logp_alpha`` and computes
    ``hidden + float(qed_alpha) * cav + float(logp_alpha) * logp_cav``.

    Args:
        module: The hooked layer (unused).
        input: The layer's positional inputs (unused).
        output: The layer output; either a tensor or a tuple whose first
            element is the hidden-state tensor.

    Returns:
        The modified tensor, or, when ``output`` is a tuple, a tuple with the
        modified tensor in place of the first element and the rest unchanged.

    Raises:
        TypeError / ValueError: if ``qed_alpha`` or ``logp_alpha`` cannot be
            converted with ``float``.
    """
    hidden = output[0] if isinstance(output, tuple) else output

    alpha_val = float(qed_alpha)
    logp_alpha_val = float(logp_alpha)

    # Convert each vector on its own (their types may differ) and always cast to
    # the hidden-state dtype, so an fp32 CAV cannot silently upcast a bf16/fp16
    # residual stream.
    cav_tensor = _as_steering_tensor(cav, hidden)
    logp_cav_tensor = _as_steering_tensor(logp_cav, hidden)

    modified_tensor = hidden + alpha_val * cav_tensor + logp_alpha_val * logp_cav_tensor

    if isinstance(output, tuple):
        return (modified_tensor,) + output[1:]
    return modified_tensor

# Available CAV alignment hooks, keyed by the config's "modify_func" value.
MODIFY_FUNCTIONS = {
    "additive": modify_layer_output,
}

# An unrecognised name falls back to the additive hook instead of failing.
selected_modify_func = MODIFY_FUNCTIONS.get(modify_func_name, modify_layer_output)

# chemma-2b keeps its blocks under model.model.layers; the other supported
# checkpoints keep them under model.model.decoder.layers.
decoder_layers = (
    model.model.layers
    if model_name == "chemma-2b"
    else model.model.decoder.layers
)
layer = decoder_layers[layer_to_modify]

# Keep the handle so the hook can be detached later with hook_handle.remove().
hook_handle = layer.register_forward_hook(selected_modify_func)

batch_size = 32
eval_set = []

# One conditioning prompt per row: the target QED and LogP values wrapped in
# their tags, ending with [START_SMILES] so the model continues with a molecule.
input_texts = [
    f"</s>[QED]{row['QED']}[/QED][LOGP]{row['LOGP']}[/LOGP][SEP][START_SMILES]"
    for _, row in sampled_data.iterrows()
]

# Loop invariants, hoisted out of the batch loop.
end_smiles_id = tokenizer.convert_tokens_to_ids("[END_SMILES]")
smiles_re = re.compile(r'\[START_SMILES\](.*?)(\[END_SMILES\]|$)')
# Ceiling division: exactly the number of batches the range below yields.
num_batches = (len(input_texts) + batch_size - 1) // batch_size

for i in tqdm(range(0, len(input_texts), batch_size), total=num_batches):
    batch_inputs = input_texts[i:i + batch_size]

    # Tokenize the whole batch (left-padded, see tokenizer.padding_side).
    prompt = tokenizer(batch_inputs, return_tensors="pt", padding=True, truncation=True).to(model.device)

    # Greedy decoding that stops at [END_SMILES].
    # - attention_mask is passed explicitly so left-pad positions are masked
    #   rather than relying on generate() to infer the mask from pad_token_id.
    # - output_scores is not requested: per-step logits are never used and
    #   would hold (batch x vocab) floats for every generated step.
    output = model.generate(
        prompt.input_ids,
        attention_mask=prompt.attention_mask,
        do_sample=False,
        max_length=512,
        return_dict_in_generate=True,
        eos_token_id=end_smiles_id,
    )

    # Decoded text includes the prompt and special tokens; the regex relies on them.
    batch_outputs = tokenizer.batch_decode(output.sequences)

    # Extract the first SMILES span (up to [END_SMILES] or end of text) and score it.
    for input_text, out in zip(batch_inputs, batch_outputs):
        smiles_match = smiles_re.search(out)
        smiles_string = smiles_match.group(1) if smiles_match else None

        valid_smiles = is_valid_smiles(smiles_string)
        clogp_value = clogp(smiles_string)
        qed_value = qed(smiles_string)

        eval_set.append({
            'input_text': input_text,
            'output': out,
            'smiles_string': smiles_string,
            'valid_smiles': valid_smiles,
            'clogp_value': clogp_value,
            'qed_value': qed_value
        })

eval_df = pd.DataFrame(eval_set)

# NOTE(review): current_time is computed but not used in file_name, so reruns
# (e.g. with different qed_alpha / logp_alpha / layer) overwrite the same CSV —
# confirm whether the timestamp was meant to be part of the name.
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

file_name = f"generated_smiles_{data}_cav_{modify_func_name}_tdc_new.csv"

# to_csv does not create missing parent directories, so ensure they exist first.
output_dir = f"./evaluation_samples/eval/cav/{data}/{model_name}"
os.makedirs(output_dir, exist_ok=True)
eval_df.to_csv(f"{output_dir}/{file_name}", index=False)