import torch
import torch.nn.functional as F
import math
import os
import pandas as pd
import time
import torch.nn as nn
import wandb
import yaml
import argparse
from collections import defaultdict
from torch.utils.data import DataLoader
from functools import partial
from omegaconf import OmegaConf
from typing import List
from onmt.translate.translator import build_translator
from onmt.utils.parse import ArgumentParser

from multiguide.dataset.molecule_dataset import MoleculeDataset
from multiguide.training.bucket_batch_ddp_sampler import DistributedBucketBatchSampler
from multiguide.training.curriculum_bucket_batch_sampler import CurriculumBucketBatchSampler
from multiguide.property.metrics import compute_picp, compute_miw,\
                                        uncertainty_calibration_plot, \
                                        analyze_uncertainty_by_length, \
                                        expected_calibration_error, \
                                        calculate_ece
from multiguide.dataset.helpers import classifier_data_to_int, simplify_expression, get_vocab_from_trained_model
from multiguide.property.property_predictor import PropertyPredictor
from multiguide.helpers import PROJECT_ROOT
from multiguide.onmt.guided_translator import build_classifier_guided_translator

# Module-wide default device: CUDA when torch reports it as available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def set_property_predictor(config, return_checkpoint=False):
    '''
        Build a PropertyPredictor and restore its weights from disk.

        The alphabet size is the length of the vocabulary read from
        config.classifier_guidance.onmt_checkpoint_path; the weights come from
        PROJECT_ROOT/checkpoints/<config.classifier_guidance.checkpoint_path>.

        Args:
            config: configuration object exposing the classifier_guidance section.
            return_checkpoint: when True, the loaded checkpoint dict is returned too.

        Returns:
            The predictor moved to the module-level device, or the tuple
            (predictor, checkpoint) when return_checkpoint is True.
    '''
    guidance_cfg = config.classifier_guidance
    alphabet = get_vocab_from_trained_model(guidance_cfg.onmt_checkpoint_path)
    predictor = PropertyPredictor(config, len(alphabet))
    weights_file = os.path.join(PROJECT_ROOT, 'checkpoints', guidance_cfg.checkpoint_path)
    print(f'======= loading property checkpoint from {weights_file}')
    state = torch.load(weights_file, map_location=device)
    predictor.load_state_dict(state['model_state_dict'])
    predictor = predictor.to(device)
    return (predictor, state) if return_checkpoint else predictor

# Define the collate function at module level (outside any other function) so that it is picklable
def collate_fn(batch, pad_idx):
    """
    Pad the source sequences of a batch to a common length and stack everything.

    Each sample is a 4-tuple (source sequence, property, full length, source
    length). The batch list is sorted in place by its last field, longest
    first, before padding.

    Returns:
        (padded sources, properties, full lengths, source lengths) as tensors.
    """
    # Longest source first; the sort is done in place on the incoming list.
    batch.sort(key=lambda sample: sample[-1], reverse=True)
    sequences, properties, totals, lengths = zip(*batch)

    # Right-pad every sequence with pad_idx up to the longest one in the batch.
    widest = max(len(sequence) for sequence in sequences)
    padded = []
    for sequence in sequences:
        padded.append(F.pad(sequence, (0, widest - len(sequence)), value=pad_idx))

    return (
        torch.stack(padded),
        torch.stack(properties),
        torch.stack(totals),
        torch.tensor(lengths),
    )

def classifier_evaluation(config, epoch, model, val_loader, run, device, target_mean, target_std):
    '''
        Evaluate a classifier on val_loader.

        For every batch the loss is obtained from compute_loss and the accuracy
        from the argmax of the model output over dim 1. Gradients are disabled
        for the whole pass so no autograd graph is kept alive during evaluation.

        Args:
            config: configuration forwarded to compute_loss.
            epoch: unused here; kept for interface compatibility.
            model: model to evaluate (switched to eval mode).
            val_loader: iterable of (src, prop, full_length, src_length) batches.
            run: wandb-style run used for logging, or None to skip logging.
            device: device the inputs are moved to.
            target_mean, target_std: unused here; kept for interface compatibility.

        Returns:
            dict with the mean 'CCE' and 'Accuracy' over the processed batches.
            If no batch was processed, both entries stay empty lists.
    '''
    model.eval()
    ces = []
    accuracies = []
    metrics = {'CCE': [], 'Accuracy': []}

    # Add memory debugging
    print(f"Memory before evaluation: {torch.cuda.memory_allocated()/1e9:.2f} GB")
    torch.cuda.empty_cache()  # Clear cache before evaluation

    # no_grad keeps evaluation from building autograd graphs (the main source of
    # memory growth here); autocast stays disabled as before.
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
        for i, batch in enumerate(val_loader):
            try:
                src_batch, prop_batch, full_length_batch, src_length_batch = batch
                src_batch = src_batch.to(device)
                prop_batch = prop_batch.to(device)
                pred = model(src_batch)
                prop_pred = torch.argmax(pred, dim=1)
                loss = compute_loss(config, pred, src_length_batch, full_length_batch, prop_batch, device)
                accuracy = (prop_pred == prop_batch.squeeze()).float().mean()
                ces.append(loss.item())
                accuracies.append(accuracy.item())
                # Clean up after each batch
                del src_batch, prop_batch, pred, prop_pred, loss, accuracy
                torch.cuda.empty_cache()
            except Exception as e:
                # Best-effort evaluation: report the failure and stop iterating,
                # keeping whatever was accumulated so far.
                print(f"Error on batch {i}: {e}")
                print(f"Memory at error: {torch.cuda.memory_allocated()/1e9:.2f} GB")
                break
    if ces:
        metrics['CCE'] = sum(ces) / len(ces)
        metrics['Accuracy'] = sum(accuracies) / len(accuracies)
        # run is optional for callers, so only log when one was provided.
        if run is not None:
            run.log({'validation/CCE': metrics['CCE']})
            run.log({'validation/Accuracy': metrics['Accuracy']})
    return metrics
            

def evaluate_property_predictor(model, val_loader, target_mean, target_std, config, run, epoch, device):
    '''
        Run the evaluation mode selected in config.classifier_guidance.train.

        stratified_evaluation takes precedence over classifier_evaluation when
        both flags are set. When run is not None the metrics are logged to it.

        Returns:
            The metrics returned by the selected evaluation function: a dict
            keyed by bin range for stratified evaluation, or a flat dict for
            classifier evaluation.

        Raises:
            ValueError: if neither evaluation flag is enabled.
    '''
    start_time = time.time()
    train_cfg = config.classifier_guidance.train
    if train_cfg.stratified_evaluation:
        print('======= using stratified evaluation =======')
        metrics = stratified_evaluation(config, epoch, model, val_loader, run, device, target_mean, target_std)
        if run is not None:
            # Use a separate loop name: rebinding `metrics` here would make the
            # function return only the last bin's dict instead of all bins.
            for bin_range, bin_metrics in metrics.items():
                run.log({f'validation-{bin_range[0]}-{bin_range[1]}%/{k}': v for k, v in bin_metrics.items()})
            print(f'===== evaluation time: {time.time() - start_time} for epoch {epoch} =====')
    elif train_cfg.classifier_evaluation:
        metrics = classifier_evaluation(config, epoch, model, val_loader, run, device, target_mean, target_std)
        if run is not None:
            for k, v in metrics.items():
                run.log({f'validation-{k}': v})
            print(f'===== evaluation time: {time.time() - start_time} for epoch {epoch} =====')
    else:
        raise ValueError('Only supports stratified evaluation for now')

    return metrics

def save_checkpoint(config, run, model, optimizer, scheduler, epoch, target_mean, target_std, metrics):
    '''
        Write a training checkpoint for this epoch and optionally upload it.

        The file is saved as
        PROJECT_ROOT/experiments/<experiment_name>/checkpoints/checkpoint_<epoch>.pt
        and holds the epoch, the model/optimizer/scheduler state dicts (scheduler
        entry is None when no scheduler is given), the target statistics and
        the metrics. When config.general.wandb.mode is 'online' the file is also
        logged to run as a wandb artifact with the alias epoch<epoch>.
    '''
    checkpoint_dir = os.path.join(PROJECT_ROOT, 'experiments', config.general.experiment_name, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_file = os.path.join(checkpoint_dir, f'checkpoint_{epoch}.pt')

    payload = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': None if scheduler is None else scheduler.state_dict(),
        'target_mean': target_mean,
        'target_std': target_std,
        'metrics': metrics,
    }
    torch.save(payload, checkpoint_file)
    print(f'======= saved checkpoint to {checkpoint_file}')

    # Artifacts are only uploaded for online wandb runs.
    if config.general.wandb.mode != 'online':
        return
    artifact = wandb.Artifact(name=f'model_{run.id}', type=f'predictor_{config.classifier_guidance.property}')
    artifact.add_file(checkpoint_file)
    run.log_artifact(artifact, aliases=[f'epoch{epoch}'])

def compute_loss(config, pred, src_length_batch, full_length_batch, value_batch_normalized, device):
    '''
        Compute the scalar training/evaluation loss selected in the config.

        Supported values of config.classifier_guidance.train.loss:
            'mse': element-wise squared error between pred and the targets.
            'nll': Gaussian negative log-likelihood; pred must be a
                   (mean, log_var) tuple.
            'ce' : cross-entropy; targets are cast to long and squeezed on dim 1.

        When config.classifier_guidance.train.weighted_loss is set, each
        per-sample loss is weighted by src_length / full_length and the result
        is the weighted mean; otherwise the plain mean is returned.

        Raises:
            ValueError: for 'ce' combined with as_regression, or an unknown loss.
    '''
    if config.classifier_guidance.as_regression and config.classifier_guidance.train.loss=='ce':
        raise ValueError('CE loss is not supported for regression')
    if config.classifier_guidance.train.loss=='mse':
        loss = F.mse_loss(pred, value_batch_normalized, reduction='none')
    elif config.classifier_guidance.train.loss=='nll':
        # TODO: add weighted loss here too
        assert isinstance(pred, tuple) and len(pred) == 2, 'NLL loss requires a tuple of (mean, log_var)'
        loss = nll_loss(pred[0], pred[1], value_batch_normalized)
    elif config.classifier_guidance.train.loss=='ce':
        loss = F.cross_entropy(pred, value_batch_normalized.long().squeeze(1), reduction='none')
    else:
        raise ValueError(f'Invalid loss function: {config.classifier_guidance.train.loss}')

    if config.classifier_guidance.train.weighted_loss:
        weights = (src_length_batch.unsqueeze(1)/full_length_batch).to(device)
        # Align the weights with the per-sample loss before multiplying. The
        # cross-entropy loss is 1-D, and multiplying it by column-shaped weights
        # would broadcast to a square matrix and cancel the weighting entirely.
        if weights.numel() == loss.numel():
            weights = weights.reshape(loss.shape)
        loss = loss * weights
        loss = loss.sum()/weights.sum()
    else:
        loss = loss.mean()
    return loss

def train_property_predictor_epoch(model,
                                   train_loader,
                                   optimizer,
                                   target_mean,
                                   target_std,
                                   config,
                                   device,
                                   epoch,
                                   run=None):
    '''
        Run one optimisation pass over train_loader.

        For each batch: move tensors to device, run the model, standardise the
        targets with target_mean/target_std when as_regression is set, compute
        the loss through compute_loss, backpropagate, clip the gradient norm to
        1.0 and step the optimizer.

        When run is given, the mean batch loss and the epoch index are logged.
        A summary line is printed every print_every epochs.

        Returns:
            The sum of the per-batch losses (not the mean).
    '''
    started_at = time.time()
    running_loss = 0
    model.train()
    for batch in train_loader:
        sequences, targets, full_lengths, src_lengths = batch
        optimizer.zero_grad()
        sequences = sequences.to(device)
        targets = targets.to(device).float()
        full_lengths = full_lengths.to(device)
        src_lengths = src_lengths.to(device)
        outputs = model(sequences)
        # Regression targets are standardised; class labels are used unchanged.
        if not config.classifier_guidance.as_regression:
            scaled_targets = targets
        else:
            scaled_targets = (targets-target_mean.to(device))/target_std.to(device)
        batch_loss = compute_loss(config, outputs, src_lengths,
                                  full_lengths, scaled_targets,
                                  device)
        running_loss += batch_loss.item()
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    if run is not None:
        run.log({'training/train_loss': running_loss / len(train_loader)})
        run.log({'training/epoch': epoch})

    should_report = epoch % config.classifier_guidance.train.print_every == 0
    if should_report:
        print(f'======= Train Loss: {running_loss / len(train_loader)} for epoch {epoch}, time: {time.time() - started_at} =======')
    return running_loss

def get_data_loaders(config, jitter=1e-8):
    '''
        Build the train/validation DataLoaders and the regression target statistics.

        Both splits come from load_data and are wrapped in MoleculeDataset. Each
        loader uses a bucket batch sampler (the curriculum variant when
        config.classifier_guidance.train.curriculum_learning is set), with
        shuffling for training only, rank 0 and a single replica.

        For regression, the mean and std (plus jitter) of the training
        'property' column are computed, saved as target_mean.pt / target_std.pt
        next to the data, and returned; otherwise both are None.

        Returns:
            (train_loader, val_loader, target_mean, target_std)
    '''
    train_df, val_df = load_data(config)
    print(f'Data sizes: train: {len(train_df)}, val: {len(val_df)}')

    def make_dataset(frame):
        # One MoleculeDataset per split, fed from the same three columns.
        return MoleculeDataset(frame['rxn'].to_list(),
                               frame['property'].to_list(),
                               frame['full_length'].to_list(),
                               config)

    def make_loader(dataset, batch_sizes, shuffle):
        # Sampler arguments shared by both sampler classes.
        sampler_args = {
            'dataset': dataset,
            'batch_sizes': batch_sizes,  # List of batch sizes for each bucket
            'num_buckets': config.classifier_guidance.dataset.num_buckets,
            'shuffle': shuffle,
            'rank': 0,
            'num_replicas': 1,
        }
        if config.classifier_guidance.train.curriculum_learning:
            if hasattr(config.classifier_guidance.train, 'seed'):
                sampler_args['seed'] = config.classifier_guidance.train.seed
            else:
                sampler_args['seed'] = None
            batch_sampler = CurriculumBucketBatchSampler(**sampler_args)
        else:
            batch_sampler = DistributedBucketBatchSampler(**sampler_args)

        worker_count = config.classifier_guidance.dataset.num_workers
        # prefetch_factor is only passed when worker processes are used.
        if worker_count > 0:
            prefetch = config.classifier_guidance.dataset.prefetch_factor
        else:
            prefetch = None
        return DataLoader(
            dataset=dataset,
            batch_sampler=batch_sampler,
            collate_fn=partial(collate_fn, pad_idx=dataset.pad_idx),
            prefetch_factor=prefetch,
            num_workers=config.classifier_guidance.dataset.num_workers,
            pin_memory=True
        )

    train_dataset = make_dataset(train_df)
    val_dataset = make_dataset(val_df)
    train_loader = make_loader(train_dataset,
                               config.classifier_guidance.dataset.batch_sizes,
                               shuffle=True)
    val_loader = make_loader(val_dataset,
                             config.classifier_guidance.dataset.val_batch_sizes,
                             shuffle=False)

    if not config.classifier_guidance.as_regression:
        return train_loader, val_loader, None, None

    target_mean = torch.tensor(train_df['property'].mean(), device='cpu')
    target_std = torch.tensor(train_df['property'].std(), device='cpu')+jitter

    # Persist the statistics alongside the predictor data for this dataset.
    stats_dir = os.path.join(PROJECT_ROOT,
                             'data',
                             'predictors',
                             config.classifier_guidance.property,
                             str(config.classifier_guidance.dataset.dataset_name))
    torch.save(target_mean, os.path.join(stats_dir, 'target_mean.pt'))
    torch.save(target_std, os.path.join(stats_dir, 'target_std.pt'))

    return train_loader, val_loader, target_mean, target_std

def setup_wandb(config):
    '''
        Start (or resume) a wandb run from the general.wandb config section.

        The whole config, resolved to a plain container, is attached to the run.
        When general.wandb.resume_run_id is set, that id is passed with
        resume='must'; otherwise resume is False.

        Returns:
            The object returned by wandb.init.
    '''
    # TODO: figure out if it's better to upload the new config file?
    # or use the one saved on wandb?
    resolved_config = OmegaConf.to_container(config, resolve=True)
    wandb_cfg = config.general.wandb

    init_kwargs = {
        'project': wandb_cfg.project,
        'entity': wandb_cfg.entity,
        'name': wandb_cfg.name,
        'mode': wandb_cfg.mode,
        'config': resolved_config,
        # wandb handles the resume itself once it is given the run id.
        'resume': False if wandb_cfg.resume_run_id is None else 'must',
        'id': wandb_cfg.resume_run_id,
    }
    return wandb.init(**init_kwargs)

def load_data(config):
    '''
        Read the training and validation CSV files for the configured property.

        Both files are looked up under
        PROJECT_ROOT/data/predictors/<property>/<dataset_name>/.

        Returns:
            (train_df, val_df) as pandas DataFrames.
    '''
    dataset_cfg = config.classifier_guidance.dataset
    data_dir = os.path.join(PROJECT_ROOT,
                            'data',
                            'predictors',
                            config.classifier_guidance.property,
                            str(dataset_cfg.dataset_name))
    train_df = pd.read_csv(os.path.join(data_dir, dataset_cfg.train_file))
    val_df = pd.read_csv(os.path.join(data_dir, dataset_cfg.val_file))
    print(f'Data sizes: train: {len(train_df)}, val: {len(val_df)}')
    return train_df, val_df

def construct_model(config, alphabet_size):
    '''
        Instantiate the predictor named by config.classifier_guidance.predictor_type.

        Only 'neural_network' is supported; it maps to PropertyPredictor.

        Raises:
            ValueError: for any other predictor type.
    '''
    predictor_type = config.classifier_guidance.predictor_type
    if predictor_type != 'neural_network':
        raise ValueError(f'Invalid predictor type: {predictor_type}')
    return PropertyPredictor(config, alphabet_size)

def print_metrics(metrics, bin_range):
    '''
        Print all metrics of one completion bin on a single line.

        Every value is shown with four decimals; tensors are converted with
        .item() first. Each entry is followed by ", " (including the last one).
    '''
    def render(value):
        number = value.item() if isinstance(value, torch.Tensor) else value
        return f"{number:.4f}"

    header = f"\nCompletion {bin_range[0]}-{bin_range[1]}%: "
    entries = "".join(f"{name} = {render(value)}, " for name, value in metrics.items())
    print(header + entries)

def nll_loss(pred_value, pred_log_var, target):
    """Element-wise Gaussian negative log-likelihood with a predicted log-variance.

    The log-variance is clamped to [-10, 10] for numerical stability. No
    reduction is applied; the result has the broadcast shape of the inputs.
    """
    log_var = pred_log_var.clamp(min=-10, max=10)
    squared_residual = (pred_value - target)**2
    half_precision = 0.5 * torch.exp(-log_var)
    return 0.5 * log_var + half_precision * squared_residual
    
def stratified_evaluation(config, epoch, model, dataloader, run, device, target_mean, target_std):
    """Evaluate the model separately for each completion-percentage bin.

    The completion percentage of a sample is src_length / full_length * 100.
    Samples are grouped into five 20%-wide bins; the last bin also includes
    exactly 100% so fully completed sequences are not dropped.

    Args:
        config: configuration; reads classifier_guidance.as_regression and
            classifier_guidance.train.loss, and is forwarded to compute_loss
            and the plotting helpers.
        epoch: epoch index, used only in plot names.
        model: model to evaluate (switched to eval mode).
        dataloader: iterable of (partial_seqs, true_values, full_length_batch,
            src_length_batch) batches.
        run: wandb-style run forwarded to the plotting helpers.
        device: device the inputs are moved to.
        target_mean, target_std: statistics used to (de)normalise regression
            targets; not used for classification.

    Returns:
        dict mapping each (lower, upper) bin to its metrics dict. Empty bins
        get None for every metric and a 'count' of 0.
    """
    # Define bins as percentage ranges
    bins = [(0, 20), (20, 40), (40, 60), (60, 80), (80, 100)]
    bin_results = {bin_range: {'pred_values': [],
                               'log_var': [],
                               'true_values': [],
                               'error': [],
                               'variance': [],
                               'full_length': [],
                               'loss': [],
                               'picp': [],
                               'miw': [],
                               'accuracy': [],
                               'confidence': [],
                               'bin_count': []} for bin_range in bins}
    as_regression = config.classifier_guidance.as_regression
    use_nll = config.classifier_guidance.train.loss == 'nll'
    last_bin = bins[-1]
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            partial_seqs, true_values, full_length_batch, src_length_batch = batch
            # Move to device if using GPU
            partial_seqs = partial_seqs.to(device)
            true_values = true_values.to(device)
            if as_regression:
                true_values_normalized = (true_values - target_mean.to(device)) / target_std.to(device)
            else:
                # Class labels are passed to the loss unchanged.
                true_values_normalized = true_values
            # Get predictions
            pred_normalized = model(partial_seqs)
            loss = compute_loss(config, pred_normalized, src_length_batch,
                                full_length_batch, true_values_normalized, device)
            if as_regression:
                if use_nll:
                    pred_values_normalized, log_var_normalized = pred_normalized
                else:
                    pred_values_normalized = pred_normalized
                    log_var_normalized = None
                # Map predictions (and log-variance) back to the original scale.
                pred_values = pred_values_normalized * target_std.to(device) + target_mean.to(device)
                log_var = log_var_normalized + 2 * torch.log(target_std.to(device)) if log_var_normalized is not None else None
            else:
                # Predicted class = argmax over dim 1 of the model output.
                pred_values = torch.argmax(pred_normalized, dim=1)
            # Get completion percentages for the batch
            completion_pcts = (src_length_batch / full_length_batch.squeeze(1) * 100).tolist()
            # Process each sample in the batch
            for i, completion_pct in enumerate(completion_pcts):
                for bin_range in bins:
                    lower, upper = bin_range
                    in_bin = lower <= completion_pct < upper
                    # The final bin is closed on the right so 100% is counted.
                    if bin_range == last_bin and completion_pct == upper:
                        in_bin = True
                    if not in_bin:
                        continue
                    results = bin_results[bin_range]
                    if as_regression:
                        error = (pred_values[i] - true_values[i])**2
                        variance = (true_values[i] - torch.mean(true_values))**2
                        results['error'].append(error)
                        results['variance'].append(variance)
                        # The batch-level loss is recorded once per sample.
                        results['loss'].append(loss)
                        if use_nll:
                            picp = compute_picp(pred_values[i], log_var[i], true_values[i])
                            results['picp'].append(picp)
                            miw = compute_miw(log_var[i])
                            results['miw'].append(miw)
                            results['pred_values'].append(pred_values[i])
                            # Stored under 'log_var' but holds exp(log_var).
                            results['log_var'].append(torch.exp(log_var[i]))
                            results['true_values'].append(true_values[i])
                            results['full_length'].append(full_length_batch[i])
                    else:
                        accuracy = (pred_values[i] == true_values[i]).float()
                        results['accuracy'].append(accuracy)
                    # Bins are disjoint, so a sample belongs to one bin only.
                    break
    # Calculate metrics per bin
    metrics = {}
    for bin_range, results in bin_results.items():
        if as_regression:
            errors = results['error']
            variances = results['variance']
            if errors:
                if use_nll:
                    calibration_error = uncertainty_calibration_plot(config=config,
                                                                    predictions=results['pred_values'],
                                                                    variances=results['log_var'],
                                                                    errors=results['error'],
                                                                    targets=results['true_values'],
                                                                    plot_name=f'calibration_plot_epoch_{epoch}_{bin_range[0]}-{bin_range[1]}%',
                                                                    wandb_panel=f'validation-{bin_range[0]}-{bin_range[1]}%',
                                                                    run=run,
                                                                    num_bins=min(len(results['pred_values']), 10))
                    uncertainty_length_corr, error_length_corr = analyze_uncertainty_by_length(config=config,
                                                                                                errors=results['error'],
                                                                                                variances=results['log_var'],
                                                                                                lengths=results['full_length'],
                                                                                                plot_name=f'uncertainty_length_corr_epoch_{epoch}_{bin_range[0]}-{bin_range[1]}%',
                                                                                                wandb_panel=f'validation-{bin_range[0]}-{bin_range[1]}%',
                                                                                                run=run)
                metrics[bin_range] = {
                    'mse': sum(errors) / len(errors),
                    'rmse': torch.sqrt(sum(errors) / len(errors)),
                    'r_squared': 1 - (sum(errors) / sum(variances)),
                    'loss': sum(results['loss']) / len(results['loss']),
                    'count': len(errors),
                }
                if use_nll:
                    metrics[bin_range]['picp'] = sum(results['picp']) / len(results['picp'])
                    metrics[bin_range]['miw'] = sum(results['miw']) / len(results['miw'])
                    metrics[bin_range]['calibration_error'] = calibration_error
                    metrics[bin_range]['uncertainty_length_corr'] = uncertainty_length_corr
                    metrics[bin_range]['error_length_corr'] = error_length_corr
                print_metrics(metrics[bin_range], bin_range)
            else:
                metrics[bin_range] = {'mse': None, 'rmse': None, 'r_squared': None, 'loss': None, 'picp': None, 'miw': None, 'count': 0}
                print(f"No samples in {bin_range[0]}-{bin_range[1]}% bin")
        else:
            accuracies = results['accuracy']
            if accuracies:
                metrics[bin_range] = {
                    'accuracy': sum(accuracies) / len(accuracies),
                    'count': len(accuracies),
                }
            else:
                # An empty bin has no accuracy to average.
                metrics[bin_range] = {'accuracy': None, 'count': 0}
                print(f"No samples in {bin_range[0]}-{bin_range[1]}% bin")

    return metrics

def regular_evaluation(model, val_loader, target_mean, target_std, scheduler=None, epoch=None, device=None):
    '''
        Evaluate a regression model on val_loader without stratification.

        The validation loss is the length-weighted (src_length / full_length)
        squared error on normalised targets, summed within each batch and
        averaged over the number of batches. RMSE and R-squared are computed on
        the original target scale; the variance term of R-squared uses each
        batch's own mean. Gradients are disabled during the pass.

        Args:
            model: model to evaluate (switched to eval mode).
            val_loader: loader of (seq, value, full_length, src_length) batches;
                must support len() and expose .dataset.
            target_mean, target_std: tensors used to (de)normalise the targets.
            scheduler: optional scheduler stepped with the average loss.
            epoch: epoch index, used only in the printed summary.
            device: device the tensors are moved to.

        Returns:
            (val_avg, rmse, r_squared)
    '''
    model.eval()
    total_val_loss = 0
    total_variance = 0
    total_sse = 0
    # The statistics do not change between batches, so move them once.
    mean_on_device = target_mean.to(device)
    std_on_device = target_std.to(device)
    # Evaluation only needs forward passes; skip autograd bookkeeping.
    with torch.no_grad():
        for seq_batch, value_batch, full_length_batch, src_length_batch in val_loader:
            seq_batch = seq_batch.to(device)
            value_batch = value_batch.to(device)
            value_batch_normalized = (value_batch - mean_on_device) / std_on_device
            pred = model(seq_batch)
            weights = (src_length_batch.unsqueeze(1) / full_length_batch).to(device)
            val_loss = F.mse_loss(pred, value_batch_normalized, reduction='none') * weights
            val_loss = val_loss.sum()
            total_val_loss += val_loss.item()
            # Denormalize predictions for metric calculation
            pred_unnormalized = pred * std_on_device + mean_on_device
            # Calculate metrics on the original scale
            sse = torch.sum((pred_unnormalized - value_batch) ** 2).item()
            total_sse += sse
            # Calculate variance on original scale
            variance = torch.sum((value_batch - torch.mean(value_batch)) ** 2).item()
            total_variance += variance

    val_avg = total_val_loss / len(val_loader)
    if scheduler is not None:
        scheduler.step(val_avg)
    rmse = math.sqrt(total_sse / len(val_loader.dataset))
    r_squared = 1 - (total_sse / total_variance) if total_variance != 0 else 0
    print(f'======= Val Loss: {val_avg}, RMSE: {rmse}, R-squared: {r_squared} for epoch {epoch} =======')
    return val_avg, rmse, r_squared


def setup_translator(classifier_config: dict):
    '''
        Build an ONMT translator, with or without classifier guidance.

        Options are read from PROJECT_ROOT/configs/toy_experiment_onmt/translate.yml
        and copied onto an argparse.Namespace. The model path, GPU index and
        output path are then overridden from classifier_config and the
        availability of CUDA, and the options are validated.

        A guidance_scale > 0 selects the classifier-guided translator, a scale
        of exactly 0 the plain translator.

        Returns:
            (translator, opt)

        Raises:
            ValueError: if guidance_scale is neither positive nor zero.
    '''
    yaml_path = os.path.join(PROJECT_ROOT,
                             "configs",
                             "toy_experiment_onmt",
                             "translate.yml")
    with open(yaml_path, "r") as handle:
        translator_config = yaml.safe_load(handle)

    # Mirror every YAML entry as an attribute of the options namespace.
    opt = argparse.Namespace()
    for option_name, option_value in translator_config.items():
        setattr(opt, option_name, option_value)

    guidance_cfg = classifier_config.classifier_guidance
    opt.models = [guidance_cfg.onmt_checkpoint_path]
    # GPU 0 when CUDA is available, otherwise -1 (CPU).
    opt.gpu = 0 if torch.cuda.is_available() else -1
    opt.output = os.path.join(PROJECT_ROOT,
                              "experiments",
                              "toy_experiment",
                              guidance_cfg.experiment_name,
                              "translations.txt")
    ArgumentParser.validate_translate_opts(opt)

    scale = guidance_cfg.guidance_scale
    if scale > 0:
        print(f'Using classifier guidance. Guidance scale: {scale}')
        translator = build_classifier_guided_translator(opt,
                                                        report_score=False,
                                                        logger=None,
                                                        out_file=None,
                                                        config=classifier_config)
    elif scale == 0:
        print(f'Not using classifier guidance. Guidance scale: {scale}')
        translator = build_translator(opt, report_score=False)
    else:
        raise ValueError(f'Invalid guidance scale: {scale}')

    return translator, opt

def translate(src_lines, translator, opt):
    '''
    Run the translator on a list of source strings.

        src_lines: list of strings; each is stripped and split on single spaces
        translator: object exposing translate(...)
        opt: options providing batch_size, attn_debug and align_debug

    Returns (tokenised source lines, predictions, scores).
    '''
    # Whole sequence as one input (matching training format)
    tokenised = []
    for line in src_lines:
        tokenised.append(line.strip().split(' '))

    # TODO: replace with a perfect classifier
    call_kwargs = {
        'src': tokenised,
        'src_feats': defaultdict(list),
        'tgt': None,
        'batch_size': opt.batch_size,
        'batch_type': 'sents',
        'attn_debug': opt.attn_debug,
        'align_debug': opt.align_debug,
    }
    scores, predictions = translator.translate(**call_kwargs)

    return tokenised, predictions, scores

def _save_classifier_checkpoint(checkpoint_dir, model, optimizer,
                                target_mean, target_std, epoch, loss):
    """Write model/optimizer state and target statistics to ``classifier_{epoch}.pt``."""
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'target_mean': target_mean,
        'target_std': target_std,
        'epoch': epoch,
        'loss': loss
    }, os.path.join(checkpoint_dir, f"classifier_{epoch}.pt"))


def train_classifier(src_file: str, 
                     lengths_file: str, 
                     onmt_checkpoint_path: str,
                     config: dict):
    """Train a PropertyPredictor classifier and checkpoint it to disk.

    Args:
        src_file: passed to ``classifier_data_to_int`` to build the dataset.
        lengths_file: passed to ``classifier_data_to_int`` to build the dataset.
        onmt_checkpoint_path: ONMT checkpoint used to obtain the vocabulary
            (its size is the classifier's alphabet size) and to build the
            dataset.
        config: configuration object; reads
            ``classifier_guidance.experiment_name`` and
            ``classifier_guidance.train.{learning_rate, batch_size,
            num_epochs, print_every}`` via attribute access.

    Checkpoints are written to
    ``PROJECT_ROOT/experiments/toy_experiment/<experiment_name>/classifier``:
    one every ``print_every`` epochs and one after the last epoch.
    """
    # generate classifier data
    vocab = get_vocab_from_trained_model(onmt_checkpoint_path)
    checkpoint_dir = os.path.join(PROJECT_ROOT,
                                  "experiments",
                                  "toy_experiment",
                                  config.classifier_guidance.experiment_name,
                                  "classifier")
    os.makedirs(checkpoint_dir, exist_ok=True)

    data_classifier = classifier_data_to_int(src_file, lengths_file, onmt_checkpoint_path)
    print(f'Vocab from classifier: {vocab}')
    # The original re-loaded the vocabulary from the very same checkpoint for
    # this second print; reuse the one already in memory instead.
    print(f'Vocab from onmt: {vocab}')

    model = PropertyPredictor(config=config, alphabet_size=len(vocab))
    model = model.to(device)
    print(f'===== Training on device {device}, with vocab size {len(vocab)}')

    train_cfg = config.classifier_guidance.train
    optimizer = torch.optim.Adam(model.parameters(), lr=train_cfg.learning_rate)
    dataloader = DataLoader(data_classifier, batch_size=train_cfg.batch_size, shuffle=True)

    # Target statistics do not change during training: compute them once
    # instead of re-scanning the dataset at every checkpoint.
    targets = torch.tensor([d[1] for d in data_classifier]).float()
    target_mean = torch.mean(targets)
    target_std = torch.std(targets)

    # reduction='none' keeps one loss value per sample so the length-based
    # weights below actually take effect (with the default mean reduction the
    # weighted average collapsed back to the unweighted mean).
    loss_fn = nn.CrossEntropyLoss(reduction='none')
    num_epochs = train_cfg.num_epochs
    last_loss = float('nan')  # stays NaN if the dataloader yields no batch
    # TODO: maybe normalize the data before regression??
    for epoch in range(num_epochs):
        total_loss = 0
        for x, y, src_length, full_length in dataloader:
            optimizer.zero_grad()
            x = x.to(device).long()
            y = y.to(device).float()
            y_pred = model(x)
            # Binary label: 1 where the target value equals 7, else 0.
            y_classes = (y == 7.).long()
            # One weight per sample: src_length / full_length. Flattening
            # both avoids an accidental (B, B) broadcast.
            # NOTE(review): assumes the model emits one logit row per sample
            # and both length tensors hold one entry per sample -- confirm.
            weight = (src_length.reshape(-1).float()
                      / full_length.reshape(-1).float()).to(device)
            per_sample_loss = loss_fn(y_pred, y_classes).reshape(-1)
            loss = (per_sample_loss * weight).sum() / weight.sum()
            loss.backward()
            # Step once per batch; stepping once per epoch (as before) applied
            # only the last batch's gradient because zero_grad() runs per batch.
            optimizer.step()
            last_loss = loss.item()
            total_loss += last_loss
        if epoch % train_cfg.print_every == 0:
            print(f'epoch {epoch} total loss: {total_loss}')
            _save_classifier_checkpoint(checkpoint_dir, model, optimizer,
                                        target_mean, target_std, epoch, last_loss)

    # save checkpoint and optimizer state (nothing to save if no epoch ran)
    if num_epochs > 0:
        _save_classifier_checkpoint(checkpoint_dir, model, optimizer,
                                    target_mean, target_std, epoch, last_loss)

def evaluate_translations(source_lines: List[str], 
                          all_predictions: List[List[str]],
                          all_scores: List[List[float]],
                          config: dict):
    """Score top-1 predictions against the simplifications of each source.

    For every source, ``simplify_expression(src, degree=2)`` gives the
    reference outputs; per the original comments index 0 is the longest and
    index 1 the shortest. Only the first prediction of each beam is checked.

    Args:
        source_lines: tokenized sources (each entry is joined with spaces).
        all_predictions: per-source predictions; only ``preds[0]`` is used.
        all_scores: per-source log-scores; ``exp(scores[0])`` is printed.
            Tensors and plain floats are both accepted.
        config: configuration object; reads
            ``classifier_guidance.{guidance_scale, target_class_index}`` and
            ``general.wandb.*``.

    Returns:
        dict with keys ``accuracy``, ``longest_accuracy`` and
        ``shortest_accuracy``, each a fraction of ``len(source_lines)``
        (0.0 when ``source_lines`` is empty). The same dict is logged to
        wandb when ``config.general.wandb.mode == 'online'``.

    Note:
        With a negative guidance scale no example is counted, matching the
        previous behaviour.
    """
    guidance_cfg = config.classifier_guidance
    guidance_scale = guidance_cfg.guidance_scale
    guided = guidance_scale > 0.
    unguided = guidance_scale == 0.

    accuracy = 0
    longest_accuracy = 0
    shortest_accuracy = 0
    for src, preds, scores in zip(source_lines, all_predictions, all_scores):
        if not (guided or unguided):
            continue
        # len(preds) = beam_size
        # for now we only consider the first prediction, check that it's correct
        # and print the score
        src_string = " ".join(src)
        simplifications = simplify_expression(src_string, degree=2)
        top_pred = preds[0]
        # as_tensor lets plain float scores through torch.exp as well.
        prob = torch.exp(torch.as_tensor(scores[0]))

        is_longest = top_pred == simplifications[0]  # first one is longest
        is_shortest = (not is_longest) and top_pred == simplifications[1]  # second one is shortest

        if guided:
            if top_pred in simplifications:
                accuracy += 1
        elif is_longest or is_shortest:
            accuracy += 1

        if is_longest:
            longest_accuracy += 1
        elif is_shortest:
            shortest_accuracy += 1

        if is_longest or is_shortest:
            print(f'correct, src: {src_string.strip()}, pred: {top_pred} with score {prob}')
        else:
            if guided:
                # Same element the original picked via a boolean index.
                correct = simplifications[int(not int(guidance_cfg.target_class_index))]
            else:
                correct = simplifications[0]
            print(f'incorrect, src: {src_string.strip()}, correct: {correct}, pred: {top_pred} with score {prob}')

    # Guard against empty input instead of raising ZeroDivisionError.
    total = len(source_lines)
    metrics = {
        'accuracy': accuracy / total if total else 0.0,
        'longest_accuracy': longest_accuracy / total if total else 0.0,
        'shortest_accuracy': shortest_accuracy / total if total else 0.0,
    }

    print(f'results for {guidance_cfg.target_class_index} with guidance scale {guidance_scale}')
    print(f'Accuracy: {metrics["accuracy"]}')
    print(f'Longest accuracy: {metrics["longest_accuracy"]}')
    print(f'Shortest accuracy: {metrics["shortest_accuracy"]}')
    # log to wandb
    # start wandb run
    if config.general.wandb.mode=='online':
        config_dict = OmegaConf.to_container(config, resolve=True)
        run = wandb.init(project=config.general.wandb.project,
                        entity=config.general.wandb.entity,
                        name=config.general.wandb.name,
                        mode=config.general.wandb.mode,
                        config=config_dict,
                        resume='must' if config.general.wandb.resume_run_id is not None else False,
                        id=config.general.wandb.resume_run_id
                    )
        run.log(dict(metrics), step=run.step)

    return metrics

