import torch
from datasets import load_dataset
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, TaskType
from peft import PeftModel
from trl import GRPOTrainer
import os
from rdkit import Chem
from rdkit.Chem import QED
import re
from rdkit.Chem import QED, Crippen, Descriptors

# --- Run configuration -------------------------------------------------------
# The three values below are empty placeholders and must be filled in before
# the script can run.
MODEL_NAME = ""  # source passed to AutoTokenizer.from_pretrained
DATASET_NAME = ""  # not referenced anywhere in the visible part of this script
OUTPUT_DIR = ""  # handed to GRPOConfig as output_dir

# Marker tags that appear in the prompts; kept for reference only, they are
# not registered with the tokenizer in this script.
# SPECIAL_TOKENS = ["<SMILES>", "</SMILES>", "<QED>", "</QED>"]

# Instruction text placed in front of every user request by optimize().
SYSTEM_PROMPT = (
    "You love and excel at editing SMILES strings "
    "to make original SMILES meet the required numeric properties.\n"
)

# === Load tokenizer, 8-bit base model, and the trained LoRA adapter ===
from transformers import BitsAndBytesConfig

# Directory (or Hub id) of the trained LoRA adapter checkpoint. This name was
# used below without ever being defined; it is a placeholder like the
# constants above and must be filled in.
model_path = ''
if not model_path:
    # Fail before the expensive base-model load instead of after it.
    raise ValueError(
        "model_path is empty; set it to the trained LoRA adapter checkpoint."
    )

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Reuse the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token

# NOTE(review): the base model id is hard-coded while the tokenizer comes from
# MODEL_NAME; presumably intentional (tokenizer saved separately) -- confirm.
base_model = AutoModelForCausalLM.from_pretrained(
    "OpenDFM/ChemDFM-v1.5-8B",
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
    trust_remote_code=True,
    # Passing load_in_8bit directly to from_pretrained is deprecated; the
    # quantization settings go through a BitsAndBytesConfig instead.
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

# Resize embeddings to match tokenizer (important!)
base_model.resize_token_embeddings(len(tokenizer))

# Attach the adapter weights; is_trainable=True keeps them trainable.
model = PeftModel.from_pretrained(base_model, model_path, is_trainable=True)

from datasets import Dataset

# Example molecule-editing requests used to build a minimal dataset for the
# trainer constructed further down.
test_prompts = [
    "Given the intermediate molecule SMILES <SMILES>C[C@@H]1CN(C(=O)c2cc(Br)cn2C)CC[C@H]1[NH3+]</SMILES>, which is composed of fragments ['N1CC[C@@H]([NH3+])[C@H](C)C1', 'C()=O', 'c1cc(Br)cn1C']. Propose a single replace, add or remove step on fragment level that makes the new molecule's QED <QED>0.136</QED> lower, LogP <LogP>1.413</LogP> lower, and Molecular Weight <MW>64.913</MW> lower.",
    # Add more prompts if needed (the trailing comma above prevents accidental
    # implicit string concatenation when a new entry is appended).
]

# NOTE(review): TRL's GRPOTrainer documentation describes a "prompt" column for
# training data, while this dataset uses "input". Left unchanged because
# nothing visible in this script calls trainer.train() -- confirm before
# training with it.
# The closing parenthesis was missing here, which made the file unparsable.
dataset = Dataset.from_dict({"input": test_prompts})

from trl import GRPOConfig

# GRPO configuration for the trainer built below. Keyword values are unchanged;
# they are only grouped by concern.
training_args = GRPOConfig(
    # --- outputs, logging, checkpoints ---
    output_dir=OUTPUT_DIR,
    logging_dir="./logs-qed-logp-mw-strict",
    report_to="wandb",  # or "none"
    logging_steps=50,
    save_steps=500,
    save_total_limit=20,
    disable_tqdm=False,  # keep progress bars enabled
    # --- batching and data loading ---
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    dataloader_drop_last=True,
    dataloader_num_workers=4,
    remove_unused_columns=False,
    # --- optimisation ---
    # Run length is capped by max_steps; num_train_epochs is left at its default.
    max_steps=1,
    learning_rate=3e-5,
    bf16=torch.cuda.is_bf16_supported(),
    # --- sampling settings used when generating completions ---
    generation_kwargs=dict(
        do_sample=True,
        temperature=1.0,
        top_p=0.9,
        top_k=50,
        max_new_tokens=128,
    ),
)

# Trainer wrapping the adapter-equipped model. No reward functions are
# supplied; in the visible script the trainer is only handed to optimize()
# and train() is never called.
trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    reward_funcs=[],
    # Pass the tokenizer prepared above (pad token set, embeddings resized to
    # its length). Without this the trainer loads its own tokenizer from the
    # model's name, which may not match the resized embeddings.
    processing_class=tokenizer,
)

def optimize(input_text, trainer=trainer):
    """Generate one completion for a molecule-editing request.

    The request is wrapped in the "[Round 0] / Human / Assistant" template
    with SYSTEM_PROMPT in front of it, and one completion is produced.

    Args:
        input_text: The user's editing request (free text).
        trainer: Object providing the model and tokenizer. Defaults to the
            module-level trainer.

    Returns:
        The generated text for the first (only) completion. On the fallback
        path this is the newly generated part only, without the prompt.
    """
    formatted_prompt = f"[Round 0]\nHuman: {SYSTEM_PROMPT}{input_text}\nAssistant:"

    # Upstream TRL's GRPOTrainer does not expose a public generate_completion()
    # method, so calling it unconditionally raises AttributeError. Keep using
    # it when the supplied trainer does provide one (e.g. a custom subclass).
    generate_completion = getattr(trainer, "generate_completion", None)
    if callable(generate_completion):
        return generate_completion(formatted_prompt)["texts"][0]

    # Fallback: generate directly with the trainer's model and tokenizer.
    tok = trainer.processing_class
    model = trainer.model
    device = next(model.parameters()).device
    encoded = tok(formatted_prompt, return_tensors="pt").to(device)

    # Reuse the sampling settings from the trainer config when present.
    gen_kwargs = dict(getattr(trainer.args, "generation_kwargs", None) or {})
    gen_kwargs.setdefault("max_new_tokens", 128)
    gen_kwargs.setdefault("pad_token_id", tok.pad_token_id)

    # Generate in eval mode without gradients, then restore the previous mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output_ids = model.generate(
                input_ids=encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
                **gen_kwargs,
            )
    finally:
        model.train(was_training)

    # Drop the prompt tokens so only the newly generated text is returned.
    # NOTE(review): skip_special_tokens=True removes any tags registered as
    # special tokens from the output -- confirm this suits downstream parsing.
    prompt_len = encoded["input_ids"].shape[1]
    return tok.decode(output_ids[0, prompt_len:], skip_special_tokens=True)

