from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
import pandas as pd
import glob
import os
import re
import yaml
import datetime
from tqdm import tqdm

from utils.eval_metrics import is_valid_smiles, clogp, qed

# Resolve the evaluation config relative to this script so the run does not
# depend on the current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(script_dir, "configs/eval.yaml")

# Fail immediately with a clear message: without this check a missing file
# would leave `config` unbound and surface as a NameError further down.
if not os.path.exists(config_path):
    raise FileNotFoundError(f"Evaluation config not found: {config_path}")

with open(config_path, "r") as file:
    # safe_load returns None for an empty document; fall back to an empty
    # mapping so the .get() lookups below still work.
    config = yaml.safe_load(file) or {}

# Keys read from the config: which evaluation set to use and which model
# to load. Both are None when the key is absent.
dataset = config.get("dataset")
model_name = config.get("model")

# Supported evaluation sets: config name -> (reader, path, sample size).
# A sample size of None means the file is used in full.
_DATASET_SOURCES = {
    "zinc": (pd.read_csv, "PATH/tdc_data/zinc_eval.csv", None),
    "moses": (pd.read_parquet, "PATH/tdc_data/moses_eval.parquet", 2000),
    "chembl": (pd.read_parquet, "PATH/tdc_data/chembl_eval.parquet", 2000),
    "oled": (pd.read_parquet, "PATH/test_data/oled_dataset.parquet", 100),
}

# Reject unknown names up front; previously an unrecognised dataset fell
# through every branch and only failed later with a NameError.
if dataset not in _DATASET_SOURCES:
    raise ValueError(
        f"Unsupported dataset {dataset!r}; expected one of {sorted(_DATASET_SOURCES)}"
    )

_read_dataset, _dataset_path, _sample_size = _DATASET_SOURCES[dataset]
eval_data = _read_dataset(_dataset_path)

if _sample_size is None:
    sampled_data = eval_data
else:
    # Fixed seed so every run evaluates the same subset.
    sampled_data = eval_data.sample(n=_sample_size, random_state=21)

# Supported models: config name -> Hugging Face hub id or local checkpoint
# directory passed to from_pretrained.
_MODEL_SOURCES = {
    "chemlactica-125m": "yerevann/chemlactica-125m",
    "chemlactica-1p3b": "yerevann/chemlactica-1.3b",
    "finetuned_chemlactica_125m": "PATH/finetuned_models/finetuned_chemlactica_125m",
    "finetuned_chemlactica_1.3b": "PATH/finetuned_models/finetuned_chemlactica_1.3b",
    "finetuned_chemma_2b": "PATH/finetuned_models/finetuned_chemma_2b",
}

# Reject unknown names up front; previously an unrecognised model fell
# through every branch and left `model`/`tokenizer` undefined.
if model_name not in _MODEL_SOURCES:
    raise ValueError(
        f"Unsupported model {model_name!r}; expected one of {sorted(_MODEL_SOURCES)}"
    )

_model_source = _MODEL_SOURCES[model_name]

tokenizer = AutoTokenizer.from_pretrained(_model_source)
# Prompts are batched and padded; pad on the left so each prompt's last
# real token sits directly before the generated continuation.
tokenizer.padding_side = "left"

# Inference only: switch to eval mode and move the weights to the GPU.
model = AutoModelForCausalLM.from_pretrained(_model_source).eval().to("cuda")

# Number of prompts sent through the model per generate() call.
batch_size = 32
eval_set = []

# One conditioning prompt per sampled row: target QED and LOGP (rounded to
# two decimals) followed by the tag that opens the SMILES string.
input_texts = [
    f"</s>[QED]{round(row['QED'], 2)}[/QED][CLOGP]{round(row['LOGP'], 2)}[/CLOGP][SEP][START_SMILES]"
    for _, row in sampled_data.iterrows()
]

# Loop invariants: resolve the stop token and compile the extraction
# pattern once instead of on every batch / every sample.
end_smiles_token_id = tokenizer.convert_tokens_to_ids("[END_SMILES]")
smiles_regex = re.compile(r'\[START_SMILES\](.*?)(\[END_SMILES\]|$)')

# Process in batches. tqdm takes its total from len(range(...)), which is
# the exact batch count (the old `len // batch_size + 1` overcounted by one
# whenever the prompt count was a multiple of batch_size).
for i in tqdm(range(0, len(input_texts), batch_size)):
    batch_inputs = input_texts[i:i + batch_size]

    # Tokenize entire batch (padded to the longest prompt in the batch)
    prompt = tokenizer(batch_inputs, return_tensors="pt", padding=True, truncation=True).to(model.device)

    # Generate batch output. The attention mask is passed explicitly so
    # padding positions are ignored; relying on generate() to infer it is
    # not dependable for padded batches. Per-step scores are not requested
    # because nothing downstream reads them.
    output = model.generate(
        prompt.input_ids,
        attention_mask=prompt.attention_mask,
        do_sample=True,
        max_length=512,
        return_dict_in_generate=True,
        eos_token_id=end_smiles_token_id,
    )

    # Decode outputs in batch
    batch_outputs = tokenizer.batch_decode(output.sequences)

    # Extract SMILES and evaluate properties
    for input_text, out in zip(batch_inputs, batch_outputs):
        # Capture the text between [START_SMILES] and [END_SMILES]; if the
        # closing tag is missing, capture up to the end of the output.
        smiles_match = smiles_regex.search(out)
        # NOTE(review): smiles_string is None when the pattern does not
        # match; this assumes the utils.eval_metrics helpers accept None --
        # confirm against their implementation.
        smiles_string = smiles_match.group(1) if smiles_match else None

        valid_smiles = is_valid_smiles(smiles_string)
        clogp_value = clogp(smiles_string)
        qed_value = qed(smiles_string)

        eval_set.append({
            'input_text': input_text,
            'output': out,
            'smiles_string': smiles_string,
            'valid_smiles': valid_smiles,
            'clogp_value': clogp_value,
            'qed_value': qed_value
        })


# Collect the per-sample records into a table, one row per generated output.
eval_df = pd.DataFrame(eval_set)

# Timestamp of this run. NOTE(review): not currently used in the output
# path, so repeated runs overwrite the same CSV.
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

# Results are written to a per-model directory, one CSV per dataset/model pair.
output_dir = f"./evaluation_samples/eval/{model_name}/"
file_name = f"generated_smiles_{dataset}_{model_name}.csv"

os.makedirs(output_dir, exist_ok=True)
eval_df.to_csv(f"{output_dir}{file_name}", index=False)