# bayes_opt_lambda.py
# Perform Bayesian Optimization to find the best Lambda for mixed diffusion

import argparse
import json
import os
import subprocess
import sys
import tempfile

from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args


# Search space handed to the optimizer: one real-valued dimension bounded by
# 0.0 and 1.0. Its name must stay 'Lambda' because `objective` is wrapped with
# `use_named_args(space)` and receives the value as the keyword `Lambda`.
space = [
    Real(low=0.0, high=1.0, name='Lambda'),
]


def _find_mean_score(output_lines):
    """Return ``mean_score`` from the last JSON object line that carries one.

    Lines are scanned from the end. A line that starts with ``{`` but is not
    valid JSON (for example a printed Python dict), or that has no
    ``mean_score`` entry, is skipped so that stray log output does not abort
    the run.

    Raises:
        RuntimeError: no line starting with ``{`` exists at all.
        ValueError: such lines exist, but none yields a ``mean_score``.
    """
    saw_brace_line = False
    for raw in reversed(output_lines):
        candidate = raw.strip()
        if not candidate.startswith('{'):
            continue
        saw_brace_line = True
        try:
            payload = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        score = payload.get('mean_score')
        if score is not None:
            return float(score)

    if not saw_brace_line:
        raise RuntimeError(
            "DeRaDiff_SDXL.py produced no JSON line (no output line starts with '{')"
        )
    raise ValueError(
        "no JSON line in the DeRaDiff_SDXL.py output contains a 'mean_score' entry"
    )


@use_named_args(space)
def objective(Lambda):
    """Evaluate one Lambda by running DeRaDiff_SDXL.py and return the negated score.

    The script is launched as a subprocess with ``--Lambda`` set to the
    candidate value; its combined stdout/stderr is searched for a JSON line
    holding ``mean_score``.

    Args:
        Lambda: candidate value supplied by the optimizer.

    Returns:
        ``-mean_score`` as a float (negated because skopt minimizes).

    Raises:
        subprocess.CalledProcessError: the script exited with a non-zero code.
        RuntimeError: the script printed no JSON line.
        ValueError: no JSON line contained ``mean_score``.
    """
    # Pin the child process to a single GPU without touching our own environment.
    child_env = os.environ.copy()
    child_env["CUDA_VISIBLE_DEVICES"] = "6"

    command = [
        # Use the interpreter running this script rather than whatever
        # 'python' happens to resolve to on PATH.
        sys.executable, 'DeRaDiff_SDXL.py',
        '--Lambda', str(Lambda),
        '--model_name', 'demo-tmp-sdxl-500',
    ]
    try:
        proc = subprocess.run(
            command,
            stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            check=True, text=True, env=child_env,
        )
    except subprocess.CalledProcessError as exc:
        # The output is captured, so the child's traceback would otherwise be
        # lost; show the tail of its stderr before propagating the failure.
        if exc.stderr:
            tail = exc.stderr.splitlines()[-40:]
            print("\n".join(tail), file=sys.stderr)
        raise

    output_lines = proc.stdout.splitlines() + proc.stderr.splitlines()
    mean_score = _find_mean_score(output_lines)

    # Log the score itself; only the returned objective value is negated.
    print("Mean score is", mean_score)
    # Return negative because skopt minimizes
    return -mean_score



def main():
    """Run Bayesian optimization over Lambda and save the results as JSON.

    Command-line options:
        --n_calls    total number of objective evaluations (default 20).
        --n_initial  number of initial random evaluations (default 5).
        --output     path of the JSON results file
                     (default 'bo_extended_results_500.json').

    The objective returns the negated mean score, so the optimizer's values
    are negated again before they are printed and stored.
    """
    parser = argparse.ArgumentParser(description='Bayesian optimize Lambda')
    parser.add_argument('--n_calls', type=int, default=20,
                        help='Total BO iterations (including initial points)')
    parser.add_argument('--n_initial', type=int, default=5,
                        help='Number of initial random evaluations')
    parser.add_argument('--output', default='bo_extended_results_500.json',
                        help='Path of the JSON file the results are written to')
    args = parser.parse_args()

    res = gp_minimize(
        objective,
        space,
        n_calls=args.n_calls,
        n_initial_points=args.n_initial,
        acq_func='EI',
        random_state=42,  # fixed seed so repeated runs propose the same points
    )

    best_lambda = res.x[0]
    # Undo the sign flip applied by the objective.
    best_score = -res.fun

    print(f"Best Lambda: {best_lambda:.4f}")
    print(f"Best mean PickScore: {best_score:.4f}")

    summary = {
        'best_lambda': best_lambda,
        'best_score': best_score,
        'all_lambdas': list(res.x_iters),
        'all_scores': [-v for v in res.func_vals],
    }
    with open(args.output, 'w') as f:
        json.dump(summary, f, indent=2)

# Start the optimization only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

