from flask import Flask, render_template, send_from_directory
from pretraining_attribution import Analyze
from werkzeug.serving import WSGIRequestHandler

# Upper bound on request-handling threads; only read by
# CustomWSGIRequestHandler (used by the commented-out app.run fallback).
max_threads: int = 32

# Extend WSGIRequestHandler to carry a thread-count setting.
class CustomWSGIRequestHandler(WSGIRequestHandler):
    """Werkzeug request handler exposing ``thread_pool`` (set to ``max_threads``).

    Only referenced by the commented-out ``app.run(...)`` fallback at the
    bottom of this file.
    """

    # Defined on the class rather than assigned after super().__init__():
    # socketserver.BaseRequestHandler.__init__ calls setup(), handle() and
    # finish() itself, so an attribute set after super().__init__() only
    # appears once the request has already been served. A class attribute is
    # visible for the whole lifetime of every instance, including handle().
    # NOTE(review): nothing in this file reads thread_pool; confirm Werkzeug
    # consults it at all (threading is normally controlled by app.run(threaded=...)).
    thread_pool = max_threads


# Flask application object: every route below registers on it, and the
# __main__ block hands it to gunicorn.
app = Flask(import_name=__name__)

# Image format requested from every Analyze plotting call in the routes.
# NOTE: this name shadows the builtin format(); the route handlers pass it as
# the `format=` keyword, so the name is kept.
format = 'png'
# Directory passed as `dir=` to Analyze and served by the /images/ route.
# file_dir = "/tmp/<name>/pretraining-attribution/figs"
file_dir = "/users/<name>/figs_disposable"

# Model family -> parameter-count labels to build experiment configs for.
models_dict = {
    "llama": ["7b", "13b"],
    "llama3": ["8B"],
    "vicuna": ["7b", "13b"],
    "gemma": ["2b", "7b"],
    "zephyr_gemma_v01": ["7b"],
    "mistralv01": ["7B"],
    "mistralv02": ["7B"],
}

# Flag value -> fragment used when composing an experiment key.
formatting_names = {True: "with_formatting", False: "no_formatting"}
sysprompt_names = {True: "with_sysprompt", False: "no_sysprompt"}
debug_names = {True: "_debug", False: ""}

# Datasets shown on the index page and iterated by separation_experiment.
# An entry is either a single dataset name or a two-element list
# (presumably a contrasted pair -- confirm with Analyze.generate_visualization_data).
datasets = [
    "conjugate_prompting",
    "honest_llama",
    ["natural_qa", "hh_rlhf"],
    ["openwebtext", "hh_rlhf"],
    "big_bench",
    "mmlu",
    "fewshot_mmlu",
]


def _build_experiments():
    """Return the experiment registry, keyed by a descriptive config name.

    Keys have the form ``<model>_<n_params>_<formatting>_<sysprompt>[_debug]``.
    Values hold the settings that the route handlers forward to ``Analyze``.
    The formatting-on / sysprompt-off combination is omitted. Entries are
    inserted debug-off first, then model by model, in ``models_dict`` order.
    """
    registry = {}
    for debug in (False, True):
        for model, sizes in models_dict.items():
            for n_params in sizes:
                for formatting in (True, False):
                    for sysprompt in (True, False):
                        if formatting and not sysprompt:
                            continue
                        name = (
                            f"{model}_{n_params}_{formatting_names[formatting]}"
                            f"_{sysprompt_names[sysprompt]}{debug_names[debug]}"
                        )
                        registry[name] = {
                            "model": model,
                            "n_params": n_params,
                            "no_formatting": not formatting,
                            "use_sysprompt": sysprompt,
                            "debug": debug,
                        }
    return registry


# Registry of every runnable experiment configuration (name -> settings dict).
experiments = _build_experiments()

# datasets = [
#     "big_bench",
#     "mmlu",
#     "fewshot_mmlu"
# ]

@app.route('/images/<path:filename>')
def images(filename):
    """Serve the figure ``filename`` from the ``file_dir`` directory."""
    figure_root = file_dir
    return send_from_directory(figure_root, filename)

@app.route('/')
def index():
    """Render index.html with the experiment registry and dataset list."""
    context = {"experiments": experiments, "datasets": datasets}
    return render_template('index.html', **context)

@app.route('/separation_experiment/<experiment_name>')
def separation_experiment(experiment_name):
    """Render separation visualizations for one experiment across all datasets.

    Calls ``Analyze.generate_visualization_data`` once per entry of the
    module-level ``datasets`` list. Every non-None result is kept, keyed by
    the name the call returns alongside it.

    Args:
        experiment_name: key into the module-level ``experiments`` registry.

    Returns:
        The rendered ``separation_experiment.html`` page, or the plain string
        ``"Experiment not found."`` for an unknown name.
    """
    analyze = Analyze()
    experiment_args = experiments.get(experiment_name)
    if not experiment_args:
        return "Experiment not found."

    results = {}
    for dataset in datasets:
        payload, result_name = analyze.generate_visualization_data(
            dataset=dataset, format=format, dir=file_dir, **experiment_args)
        if payload is None:
            continue
        results[result_name] = payload

    return render_template(
        'separation_experiment.html',
        experiment_name=experiment_name,
        experiment_data=results,
    )
    
@app.route('/jailbreak_prediction_experiment/<experiment_name>')
def jailbreak_prediction_experiment(experiment_name):
    """Render jailbreak-prediction plots for one experiment configuration.

    Args:
        experiment_name: key into the module-level ``experiments`` registry.

    Returns:
        The rendered ``jailbreak_prediction_experiment.html`` page, or the
        plain string ``"Experiment not found."`` for an unknown name.
    """
    analyze = Analyze()
    experiment_args = experiments.get(experiment_name)
    if not experiment_args:
        return "Experiment not found."

    plot_data, plot_name = analyze.generate_jailbreak_prediction_plots(
        format=format, dir=file_dir, **experiment_args)
    # A None result means nothing to show; the template gets an empty mapping.
    experiment_data = {} if plot_data is None else {plot_name: plot_data}

    return render_template(
        'jailbreak_prediction_experiment.html',
        experiment_name=experiment_name,
        experiment_data=experiment_data,
    )
    
@app.route('/linearity_testing_experiment/<experiment_name>')
def linearity_testing(experiment_name):
    """Render linearity-testing plots on the ``model_written_evals`` dataset.

    Args:
        experiment_name: key into the module-level ``experiments`` registry.

    Returns:
        The rendered ``linearity_testing_experiment.html`` page, or the plain
        string ``"Experiment not found."`` for an unknown name.
    """
    analyze = Analyze()
    registry_args = experiments.get(experiment_name)
    if not registry_args:
        return "Experiment not found."

    # Work on a per-request copy. The registry dicts are shared module state
    # that lives as long as the worker process. Writing "dataset" into them
    # would leak into later requests: separation_experiment would then pass
    # dataset= both explicitly and via **experiment_args (TypeError).
    experiment_args = {**registry_args, "dataset": "model_written_evals"}

    experiment_data = {}
    out, dataset_experiment_name = analyze.generate_linearity_testing_visualization(
        experiment_args,
        format=format,
        dir=file_dir,
        metric="y_diff_matching",
        log=False,
    )
    if out is not None:
        experiment_data[dataset_experiment_name] = out

    return render_template('linearity_testing_experiment.html', experiment_name=experiment_name, experiment_data=experiment_data)
    
@app.route('/alpha_scaling_experiment/<experiment_name>/<dataset_name>')
def alpha_scaling(experiment_name, dataset_name):
    """Render alpha-scaling plots for one experiment on a chosen dataset.

    Args:
        experiment_name: key into the module-level ``experiments`` registry.
        dataset_name: dataset identifier taken from the URL path.

    Returns:
        The rendered ``alpha_scaling_experiment.html`` page, or the plain
        string ``"Experiment not found."`` for an unknown experiment name.
    """
    analyze = Analyze()
    registry_args = experiments.get(experiment_name)
    # Check before touching the entry. Subscripting None for an unknown name
    # raises TypeError (an HTTP 500) instead of returning the message below.
    if not registry_args:
        return "Experiment not found."

    # Per-request copy, so "dataset" never leaks into the shared registry
    # entry. Other handlers splat these dicts as **kwargs next to dataset=.
    # NOTE(review): dataset_name comes straight from the URL; confirm Analyze
    # validates it before using it to locate files.
    experiment_args = {**registry_args, "dataset": dataset_name}

    experiment_data = {}
    out, dataset_experiment_name = analyze.generate_linearity_testing_visualization(
        experiment_args,
        format=format,
        dir=file_dir,
        metric="y_diff_matching",
        log=False,
    )
    if out is not None:
        experiment_data[dataset_experiment_name] = out

    return render_template('alpha_scaling_experiment.html', experiment_name=experiment_name, experiment_data=experiment_data)


if __name__ == '__main__':
    # Run the Flask app using Gunicorn rather than Flask's built-in server
    # (the built-in alternative is left commented out at the bottom).
    # The number of worker processes is set by 'workers' in `options` below
    # (currently 16); adjust it as needed.
    import multiprocessing
    multiprocessing.set_start_method('fork')  # for Unix-like systems
    from gunicorn.app.base import BaseApplication

    class StandaloneApplication(BaseApplication):
        """Embed an existing WSGI app in Gunicorn, configured from a dict."""

        def __init__(self, app, options: dict | None = None):
            # options/application are assigned before super().__init__(),
            # presumably because BaseApplication.__init__ calls load_config().
            # NOTE(review): verify against gunicorn's BaseApplication.
            self.options = options or {}
            self.application = app
            super().__init__()

        def load_config(self):
            """Copy recognised, non-None entries of self.options into self.cfg."""
            # Keys absent from self.cfg.settings are dropped without warning.
            # Membership is tested on the key exactly as given, while set()
            # receives key.lower().
            config = {key: value for key, value in self.options.items()
                      if key in self.cfg.settings and value is not None}
            for key, value in config.items():
                self.cfg.set(key.lower(), value)

        def load(self):
            """Return the WSGI application object Gunicorn should serve."""
            return self.application

    # NOTE(review): 'debug' and 'max_memory_per_child' do not look like
    # Gunicorn setting names. If they are absent from cfg.settings,
    # load_config silently ignores them. Confirm against Gunicorn's settings list.
    # max_memory_per_child = 8000 MiB if the unit is bytes (TODO confirm).
    options = {
        'bind': '0.0.0.0:3124',  # Adjust host and port as needed
        'workers': 16,  # Number of workers (processes)
        'debug': True,
        'timeout': 600,
        'max_memory_per_child': 8000 * 1024 * 1024,
        "reload": True,
    }

    StandaloneApplication(app, options).run()
    # Alternative (disabled): Flask's built-in server with the custom handler.
    # app.run(debug=True, threaded=True, request_handler=CustomWSGIRequestHandler, port=3124)
