import torch
from transformers import TrainerCallback
import wandb
import random

class LogOutputsCallback(TrainerCallback):
    def __init__(
        self,
        tokenizer,
        *,
        num_examples=3,
        wavelength_range=(200, 400),
        f_osc_range=(0.5, 1),
        seed=None,
    ):
        """Store the tokenizer and sample the prompts reused at every evaluation.

        The prompts are drawn once here and kept in ``self.fixed_examples``, so
        every call to ``on_evaluate`` generates from the same inputs. This lets
        outputs be compared across evaluation steps.

        Args:
            tokenizer: Tokenizer that ``on_evaluate`` uses to encode the
                prompts and decode the generated sequences.
            num_examples: Number of prompts to sample (default 3, as before).
            wavelength_range: ``(low, high)`` bounds passed to ``uniform`` for
                the wavelength value (rounded to 2 decimals).
            f_osc_range: ``(low, high)`` bounds passed to ``uniform`` for the
                f_osc value (rounded to 3 decimals).
            seed: Optional seed for a private ``random.Random``. When ``None``
                (the default), the global ``random`` module is used exactly as
                before. A seed set globally before construction therefore still
                determines the prompts.
        """
        self.tokenizer = tokenizer

        # Single template instead of several identical copy-pasted literals.
        template = (
            "[WAVELENGTH]{wavelength}[/WAVELENGTH]"
            "[F_OSC]{f_osc}[/F_OSC][SEP][START_SMILES]"
        )
        # A private RNG makes prompts reproducible without disturbing global
        # random state; with no seed, keep the original global-RNG behaviour.
        rng = random if seed is None else random.Random(seed)

        examples = []
        for _ in range(num_examples):
            # Draw wavelength before f_osc for each prompt, matching the
            # original call order so seeded runs yield the same prompts.
            wavelength = rng.uniform(*wavelength_range)
            f_osc = rng.uniform(*f_osc_range)
            examples.append(
                template.format(
                    wavelength=round(wavelength, 2), f_osc=round(f_osc, 3)
                )
            )

        # Sampled once; on_evaluate iterates these same prompts every time.
        self.fixed_examples = examples

    def on_evaluate(self, args, state, control, **kwargs):
        """TrainerCallback hook: generate from each prompt in ``self.fixed_examples``.

        Encodes every prompt sampled in ``__init__`` and runs ``model.generate``
        with sampling disabled. It then prints each prompt next to its decoded
        output, so generation quality can be inspected at each evaluation.

        Args:
            args, state, control: Standard ``TrainerCallback`` hook arguments;
                not read in the lines shown here.
            **kwargs: Must contain ``'model'`` (the model to generate with); a
                ``KeyError`` is raised if it is absent.
        """
        model = kwargs['model']

        # NOTE(review): log_data (and the loop index i below) are not used in
        # the lines visible here. Presumably later code collects outputs into
        # log_data and logs them (wandb is imported by this file) — confirm.
        log_data = []
        print("\nLogging model outputs for fixed examples during evaluation:")

        for i, input_text in enumerate(self.fixed_examples):
            # Tokenize one prompt as a PyTorch batch of 1 and move it to the
            # model's device.
            input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)

            # NOTE(review): only input_ids are passed; the tokenizer's
            # attention_mask is discarded. Presumably harmless for a single
            # unpadded prompt, but passing it explicitly would be safer.
            output = model.generate(
                input_ids=input_ids,
                max_new_tokens=128,  # cap on tokens generated after the prompt
                # Stop generation at the "[END_SMILES]" token. NOTE(review): if
                # that token is not in the vocabulary, the lookup may return the
                # unknown-token id instead — confirm it is registered.
                eos_token_id=self.tokenizer.convert_tokens_to_ids("[END_SMILES]"),
                pad_token_id=self.tokenizer.pad_token_id,
                do_sample=False # disable sampling
            )

            # Decode the first (and only) returned sequence, keeping special
            # tokens so marker tokens remain visible in the printed text.
            # Presumably this sequence also contains the prompt tokens
            # (decoder-only model) — confirm.
            generated_text = self.tokenizer.decode(output[0], skip_special_tokens=False)
            
            print(f"Input: {input_text}")
            print(f"Generated Output: {generated_text}\n")
