from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
import pandas as pd
import glob
import os
import re
import datetime
from tqdm import tqdm

from utils.eval_metrics import is_valid_smiles, clogp, qed

# Single source of truth for the fine-tuned checkpoint directory; both the
# tokenizer and the model weights are read from it.
checkpoint_dir = "PATH/finetuned_models/finetuned_chemma_2b"

# Tokenizer first, then the model: switched to inference mode (eval) and
# moved onto the GPU.
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)
model = model.eval().to('cuda')

# Property / sequence delimiters used in the prompts and completions.
# Registering them as special tokens keeps each marker a single token.
special_tokens = [
    '[WAVELENGTH]', '[/WAVELENGTH]',
    '[F_OSC]', '[/F_OSC]',
    '[QED]', '[/QED]',
    '[LOGP]', '[/LOGP]',
    '[START_SMILES]', '[END_SMILES]',
    '[SEP]',
]
tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})

# Make the embedding matrix match the (possibly grown) vocabulary size.
# NOTE(review): if the checkpoint's tokenizer did not already contain these
# tokens, the new embedding rows are untrained — confirm they were present
# during fine-tuning.
model.resize_token_embeddings(len(tokenizer))

# Accumulators populated by the generation loop: one record per sample,
# plus a running sample count.
eval_set, counter = [], 0

# Conditioning prompt: a target LogP value followed by an opened SMILES span
# that the model is expected to complete. (The output file name suggests this
# target is an out-of-distribution "high LogP" value — confirm.)
input_text = "</s>[LOGP]8.1940[/LOGP][START_SMILES]"
prompt = tokenizer(input_text, return_tensors="pt").to(model.device)

# How many molecules to sample, and how many per forward pass
# (tune batch_size to the available GPU memory).
num_samples, batch_size = 2000, 16

# Not read anywhere in this script; kept for compatibility.
all_outputs = []

# Sampling hyperparameters controlling output diversity.
temperature, top_k, top_p = 0.9, 50, 0.9

# Loop-invariant work hoisted out of the batch loop.
# The pattern captures the text after the first [START_SMILES] (the one that
# ends the prompt) up to the first [END_SMILES], or up to the end of the text
# when generation was cut off by max_length.
smiles_regex = re.compile(r'\[START_SMILES\](.*?)(\[END_SMILES\]|$)')
end_smiles_id = tokenizer.convert_tokens_to_ids("[END_SMILES]")

# Generate in batches; the last batch may be smaller than batch_size.
num_batches = (num_samples + batch_size - 1) // batch_size
for start_idx in tqdm(range(0, num_samples, batch_size), total=num_batches, desc="Generating samples"):
    end_idx = min(start_idx + batch_size, num_samples)
    n_in_batch = end_idx - start_idx

    output = model.generate(
        # The same prompt repeated once per sample in this batch.
        prompt.input_ids.repeat(n_in_batch, 1),
        # Pass the mask explicitly rather than letting generate() infer it.
        attention_mask=prompt.attention_mask.repeat(n_in_batch, 1),
        do_sample=True,  # Enable sampling
        temperature=temperature,
        # Previously declared but never forwarded, so top_p had no effect.
        top_k=top_k,
        top_p=top_p,
        # Total length including the prompt tokens.
        max_length=512,
        return_dict_in_generate=True,
        # Stop each sequence once it closes its SMILES span.
        eos_token_id=end_smiles_id,
    )

    # Keep special tokens so the [START_SMILES]/[END_SMILES] markers survive
    # decoding and can be used to delimit the SMILES string.
    decoded_outputs = tokenizer.batch_decode(output.sequences, skip_special_tokens=False)

    for generated_output in decoded_outputs:
        # Extract SMILES string (anything after [END_SMILES], e.g. padding of
        # sequences that finished early, is excluded by the lazy match).
        smiles_match = smiles_regex.search(generated_output)
        smiles_string = smiles_match.group(1) if smiles_match else None

        # Validate SMILES and calculate properties only for valid molecules.
        valid_smiles = is_valid_smiles(smiles_string) if smiles_string else False
        clogp_value = clogp(smiles_string) if valid_smiles else None
        qed_value = qed(smiles_string) if valid_smiles else None

        # Append to evaluation set
        eval_set.append({
            'input_text': input_text,
            'smiles_string': smiles_string,
            'valid_smiles': valid_smiles,
            'clogp_value': clogp_value,
            'qed_value': qed_value,
        })

        counter += 1

# One row per generated sample.
eval_df = pd.DataFrame(eval_set)

# Output location. Create the directory on demand so a missing folder does
# not crash the script after the (expensive) generation step has finished.
output_dir = "./evaluation_samples/finetuned/chemma-2b"
file_name = "chemma_2b_base_generated_smiles_data_ood_logp_high.csv"
os.makedirs(output_dir, exist_ok=True)

eval_df.to_csv(os.path.join(output_dir, file_name), index=False)