from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
import pandas as pd
import numpy as np
import glob
import os
import re
import datetime
from tqdm import tqdm
import torch
import yaml

from utils.eval_metrics import is_valid_smiles, clogp, qed
from tdc.generation import MolGen
from pandarallel import pandarallel
from dotenv import load_dotenv


load_dotenv()

# NOTE(review): the value is read but discarded, so this statement has no
# effect on its own; presumably the Hugging Face libraries pick HF_TOKEN up
# from the environment populated by load_dotenv() — confirm.
os.environ.get('HF_TOKEN')

script_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(script_dir, "configs/cav_logp.yaml")

pandarallel.initialize()

# Load the experiment configuration. Previously a missing config file left
# `config` unbound (NameError on the first `config.get` below) and an empty
# YAML file (safe_load -> None) crashed on `.get`; both now yield {}.
config = {}
if os.path.exists(config_path):
    with open(config_path, "r") as file:
        config = yaml.safe_load(file) or {}

modify_func_name = config.get("modify_func", "additive")  # Default to 'additive'

logp_prop_value = config.get("logp")  # target logP inserted into the prompt
layer_to_modify = config.get("layer_to_modify")  # index of the layer to hook
alpha = config.get("alpha")  # scale applied to the CAV inside the hook
model_name = config.get("model_name")  # selects tokenizer/model/CAV below

# None of these has a usable default. Without them the script fails later
# (a missing "logp" only after the whole generation run), so fail fast here.
_missing_keys = [
    key
    for key, value in (
        ("logp", logp_prop_value),
        ("layer_to_modify", layer_to_modify),
        ("alpha", alpha),
        ("model_name", model_name),
    )
    if value is None
]
if _missing_keys:
    raise ValueError(
        f"Missing required config key(s) {_missing_keys} (config file: {config_path})"
    )

print("CAV Alignment Function: ", modify_func_name)

# Hugging Face hub id and precomputed CAV file for each supported model.
# NOTE(review): the "PATH/" prefix looks like a placeholder for the local CAV
# directory. The file names encode a layer (6 / 12 / 9); presumably the
# configured `layer_to_modify` should match it — confirm.
_MODEL_SPECS = {
    "chemlactica-125m": (
        "yerevann/chemlactica-125m",
        "PATH/cav/chemlactica-125m/logp_layer_6_tdc_cav_new.npy",
    ),
    "chemlactica-1p3b": (
        "yerevann/chemlactica-1.3b",
        "PATH/cav/chemlactica-1p3b/logp_layer_12_tdc_cav_new.npy",
    ),
    "chemma-2b": (
        "yerevann/chemma-2b",
        "PATH/cav/chemma-2b/logp_layer_9_tdc_cav_new.npy",
    ),
}

# An unknown name used to fall through every branch and crash later with a
# NameError on `tokenizer`; report it clearly instead.
if model_name not in _MODEL_SPECS:
    raise ValueError(
        f"Unsupported model_name {model_name!r}; expected one of {sorted(_MODEL_SPECS)}"
    )

_hub_id, _cav_path = _MODEL_SPECS[model_name]
tokenizer = AutoTokenizer.from_pretrained(_hub_id)
tokenizer.padding_side = "left"  # left padding for batched decoder-only generation
model = AutoModelForCausalLM.from_pretrained(_hub_id).eval().to('cuda')
cav = np.load(_cav_path)

# Register the prompt-format markers as special tokens and grow the embedding
# matrix to cover any ids newly added to the vocabulary.
special_tokens = ['[WAVELENGTH]', '[/WAVELENGTH]', '[F_OSC]', '[/F_OSC]', '[QED]', '[/QED]', '[LOGP]', '[/LOGP]', '[START_SMILES]', '[END_SMILES]', '[SEP]']
tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
model.resize_token_embeddings(len(tokenizer))


def modify_layer_output(module, input, output):
    """Forward hook that shifts a layer's output along the CAV direction.

    Adds ``alpha * cav`` (both module-level globals) to the layer's main
    output tensor. Because the hook returns a non-None value, PyTorch uses it
    in place of the module's original output.

    Args:
        module: The hooked module (unused).
        input: The module's positional inputs (unused).
        output: The module's output: either a tensor, or a tuple whose first
            element is the main tensor.

    Returns:
        The modified output, in the same form (tensor or tuple) as ``output``.
    """
    # If output is a tuple, assume the first element is the main tensor.
    hidden = output[0] if isinstance(output, tuple) else output

    # Ensure alpha is a float (it may arrive from YAML as an int).
    alpha_val = float(alpha)

    # Match the hidden tensor's device AND dtype for every kind of `cav`.
    # Previously a torch.Tensor cav was only moved to the device, so a float32
    # cav on a half-precision model silently promoted the outputs to float32.
    # NOTE(review): presumably cav has shape (hidden_dim,) so it broadcasts
    # over batch and sequence positions — confirm.
    cav_tensor = torch.as_tensor(cav, device=hidden.device, dtype=hidden.dtype)

    modified = hidden + alpha_val * cav_tensor

    # Rebuild the original container so downstream code sees the same structure.
    if isinstance(output, tuple):
        return (modified,) + output[1:]
    return modified

    
# Registry of available hook implementations, keyed by the config's
# "modify_func" value.
MODIFY_FUNCTIONS = {
    "additive": modify_layer_output,
}

# Unknown names fall back to the additive hook; say so instead of silently
# running a different function than the one printed at startup.
if modify_func_name not in MODIFY_FUNCTIONS:
    print(f"Warning: unknown modify_func {modify_func_name!r}; falling back to 'additive'.")
selected_modify_func = MODIFY_FUNCTIONS.get(modify_func_name, modify_layer_output)

# chemma-2b exposes its blocks at model.model.layers; the chemlactica models
# at model.model.decoder.layers.
if model_name == "chemma-2b":
    layer = model.model.layers[layer_to_modify]
else:
    layer = model.model.decoder.layers[layer_to_modify]

# Keep the handle so the hook can be removed later if needed.
hook_handle = layer.register_forward_hook(selected_modify_func)

eval_set = []
counter = 0

input_text = f"</s>[LOGP]{logp_prop_value}[/LOGP][START_SMILES]"
prompt = tokenizer(input_text, return_tensors="pt").to(model.device)

# Generation settings
num_samples = 2000
batch_size = 32
all_outputs = []  # NOTE(review): never filled below

# Diversity parameters
temperature = 0.9

# Loop invariants, computed once instead of on every batch / sample.
end_smiles_id = tokenizer.convert_tokens_to_ids("[END_SMILES]")
# Text after [START_SMILES], up to [END_SMILES] or end of the decoded string.
smiles_regex = re.compile(r'\[START_SMILES\](.*?)(\[END_SMILES\]|$)')
# Pass the tokenizer's own mask explicitly instead of letting generate()
# infer one from the pad token.
prompt_mask = prompt.get("attention_mask")

# Generate in batches
num_batches = (num_samples + batch_size - 1) // batch_size  # ceil(num_samples / batch_size)
for start_idx in tqdm(range(0, num_samples, batch_size), total=num_batches, desc="Generating samples"):
    # The last batch may be smaller than batch_size.
    n_in_batch = min(start_idx + batch_size, num_samples) - start_idx
    output = model.generate(
        prompt.input_ids.repeat(n_in_batch, 1),
        attention_mask=None if prompt_mask is None else prompt_mask.repeat(n_in_batch, 1),
        do_sample=True,
        temperature=temperature,
        max_length=512,  # includes the prompt tokens
        return_dict_in_generate=True,
        eos_token_id=end_smiles_id,  # stop each sequence at [END_SMILES]
    )
    # Keep special tokens so the SMILES delimiters survive decoding.
    decoded_outputs = tokenizer.batch_decode(output.sequences, skip_special_tokens=False)

    for generated_output in decoded_outputs:
        # Extract SMILES string
        smiles_match = smiles_regex.search(generated_output)
        smiles_string = smiles_match.group(1) if smiles_match else None

        # Validate SMILES and calculate properties only for valid molecules.
        valid_smiles = is_valid_smiles(smiles_string) if smiles_string else False
        clogp_value = clogp(smiles_string) if valid_smiles else None
        qed_value = qed(smiles_string) if valid_smiles else None

        eval_set.append({
            'input_text': input_text,
            'smiles_string': smiles_string,
            'valid_smiles': valid_smiles,
            'clogp_value': clogp_value,
            'qed_value': qed_value
        })

        counter += 1
eval_df = pd.DataFrame(eval_set)

# Label the run by the sign of the requested logP target.
logp_prop = "low" if logp_prop_value < 0 else "high"

# NOTE(review): currently unused; presumably meant for a timestamped file name.
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

file_name = f"ood_logp_{logp_prop}_cav_alpha_{alpha}_tdc.csv"

# Relative to the current working directory (not script_dir). Create it up
# front: a missing directory would otherwise make to_csv fail after the
# entire generation run has finished.
output_dir = f"./evaluation_samples/cav/logp/{model_name}"
os.makedirs(output_dir, exist_ok=True)
output_path = f"{output_dir}/{file_name}"

eval_df.to_csv(output_path, index=False)
print(f" Saved File at {output_path}")