from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
import pandas as pd
import glob
import os
import re
import yaml
import datetime
from tqdm import tqdm
import torch

from utils.eval_metrics import is_valid_smiles, clogp, qed

def compute_qed_mae(df, ground_truth):
    """Return the mean absolute QED error over the valid generated samples.

    Args:
        df: Per-sample DataFrame with ``valid_smiles`` and ``qed_value`` columns.
        ground_truth: Row/mapping exposing the target value under ``'QED'``.

    Returns:
        Mean of ``|target - qed_value|`` over rows whose ``valid_smiles``
        equals 1; pandas yields NaN when no row qualifies.
    """
    valid_rows = df.loc[df["valid_smiles"] == 1]
    target = ground_truth["QED"]
    errors = (target - valid_rows["qed_value"]).abs()
    return errors.mean()

def compute_logp_mae(df, ground_truth):
    """Return the mean absolute cLogP error over the valid generated samples.

    Args:
        df: Per-sample DataFrame with ``valid_smiles`` and ``clogp_value`` columns.
        ground_truth: Row/mapping exposing the target value under ``'LOGP'``.

    Returns:
        Mean of ``|target - clogp_value|`` over rows whose ``valid_smiles``
        equals 1; pandas yields NaN when no row qualifies.
    """
    valid_rows = df.loc[df["valid_smiles"] == 1]
    target = ground_truth["LOGP"]
    errors = (target - valid_rows["clogp_value"]).abs()
    return errors.mean()

def validity(df):
    """Return the fraction of generated samples flagged as valid SMILES.

    ``valid_smiles`` is averaged directly, so it is expected to hold boolean
    or 0/1 flags. An empty frame yields NaN (pandas ``mean`` semantics).
    """
    valid_flags = df["valid_smiles"]
    return valid_flags.mean()

def compute_generative_efficiency(df):
    """Return distinct valid SMILES as a fraction of all generated samples.

    Unique ``smiles_string`` values are counted only among rows whose
    ``valid_smiles`` equals 1 (missing strings are not counted), but the
    denominator is the total number of rows, valid or not. An empty frame
    returns 0.
    """
    distinct_valid = df.loc[df["valid_smiles"] == 1, "smiles_string"].nunique()
    total = len(df)
    if total == 0:
        return 0
    return distinct_valid / total

# Evaluation set: a fixed, seeded random subset of the OLED parquet dataset,
# so every model below is scored on exactly the same rows.
OLED_DATA_PATH = "PATH/tdc_data/oled_dataset.parquet"
SAMPLE_SIZE, SAMPLE_SEED = 100, 21

eval_data = pd.read_parquet(OLED_DATA_PATH)
sampled_data = eval_data.sample(n=SAMPLE_SIZE, random_state=SAMPLE_SEED)

# Checkpoint directory for every model to evaluate. Dict insertion order is
# the order in which the models are evaluated.
model_paths = {
    "chemma-2b": "PATH/models/chemma-2b",
    "finetuned_chemlactica_125m": "PATH/finetuned_models/finetuned_chemlactica_125m",
    "finetuned_chemlactica_1.3b": "PATH/finetuned_models/finetuned_chemlactica_1.3b",
    "finetuned_chemma_2b": "PATH/finetuned_models/finetuned_chemma_2b",
}

# Derived from ``model_paths`` so the two can never drift apart; previously a
# name listed here but missing from the dict only surfaced as a KeyError
# part-way through the evaluation run.
model_names = list(model_paths)

# Captures the text after the first [START_SMILES] up to the next [END_SMILES],
# or up to the end of the string when the model never emitted the closing tag.
# Compiled once instead of being looked up for every generated sample.
SMILES_PATTERN = re.compile(r'\[START_SMILES\](.*?)(\[END_SMILES\]|$)')

num_samples = 100  # Number of samples per input

# Fall back to CPU so the script still runs (slowly) on hosts without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _generate_completions(model, tokenizer, input_text, end_smiles_id):
    """Sample ``num_samples`` completions of ``input_text`` and decode them.

    Special tokens are deliberately kept in the decoded strings because SMILES
    extraction relies on the [START_SMILES]/[END_SMILES] markers. The prompt and
    output tensors are local to this function, so they are dropped on return
    instead of lingering on the device after the evaluation loop ends.
    """
    batch_inputs = [input_text] * num_samples  # repeat input
    prompt = tokenizer(
        batch_inputs, return_tensors="pt", padding=True, truncation=True
    ).to(model.device)
    outputs = model.generate(
        prompt.input_ids,
        # Pass the mask/pad id explicitly rather than letting generate() infer them.
        attention_mask=prompt.attention_mask,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        top_k=50,
        temperature=1.0,
        max_length=512,  # prompt + generated tokens
        eos_token_id=end_smiles_id,
    )
    return tokenizer.batch_decode(outputs)


def _score_completions(decoded_outputs):
    """Build the per-sample DataFrame for one prompt's generations.

    Columns: ``smiles_string`` (None when no [START_SMILES] tag was found),
    ``valid_smiles`` (the value returned by ``is_valid_smiles``, presumably a
    bool/0-1 flag since the metrics compare it with ``== 1``), ``clogp_value``
    and ``qed_value``. Properties are computed only for valid strings (None
    otherwise): the MAE metrics filter to valid rows first, so skipping invalid
    strings saves work without changing any reported number.
    """
    smiles_list, valid_list, clogp_list, qed_list = [], [], [], []
    for out in decoded_outputs:
        match = SMILES_PATTERN.search(out)
        smiles_string = match.group(1) if match else None
        valid = is_valid_smiles(smiles_string)
        smiles_list.append(smiles_string)
        valid_list.append(valid)
        clogp_list.append(clogp(smiles_string) if valid else None)
        qed_list.append(qed(smiles_string) if valid else None)

    return pd.DataFrame({
        "smiles_string": smiles_list,
        "valid_smiles": valid_list,
        "clogp_value": clogp_list,
        "qed_value": qed_list,
    })


def _evaluate_row(model, tokenizer, row, end_smiles_id):
    """Generate samples conditioned on one dataset row and score them against it."""
    input_text = f"[QED]{row['QED']}[/QED][LOGP]{row['LOGP']}[/LOGP][START_SMILES]"
    decoded = _generate_completions(model, tokenizer, input_text, end_smiles_id)
    df = _score_completions(decoded)
    return {
        "QED_MAE": compute_qed_mae(df, row),
        "LogP_MAE": compute_logp_mae(df, row),
        "Validity": validity(df),
        "GenEff": compute_generative_efficiency(df),
        "input_text": input_text,
        "gt_qed": row['QED'],
        "gt_logp": row['LOGP'],
    }


for model_name in model_names:
    print(f"Evaluating {model_name}...")
    path = model_paths[model_name]
    tokenizer = AutoTokenizer.from_pretrained(path)
    tokenizer.padding_side = "left"
    model = AutoModelForCausalLM.from_pretrained(path).eval().to(device)
    # Resolved once per model rather than once per prompt.
    end_smiles_id = tokenizer.convert_tokens_to_ids("[END_SMILES]")

    all_eval_rows = []
    for _, row in tqdm(sampled_data.iterrows(), total=len(sampled_data)):
        all_eval_rows.append(_evaluate_row(model, tokenizer, row, end_smiles_id))

    eval_df = pd.DataFrame(all_eval_rows)
    summary = {
        "mean_QED_MAE": eval_df['QED_MAE'].mean(),
        "mean_LogP_MAE": eval_df['LogP_MAE'].mean(),
        "mean_Validity": eval_df['Validity'].mean(),
        "mean_GenEff": eval_df['GenEff'].mean(),
    }
    print(summary)

    output_path = f"evaluation_samples/oled/oled_eval_metrics_{model_name}.csv"
    # Create the output directory up front; otherwise to_csv raises OSError
    # after the whole (expensive) evaluation of this model has finished.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    eval_df.to_csv(output_path, index=False)

    del model
    torch.cuda.empty_cache()
    print(f"Done with {model_name}, GPU memory released.\n")