import itertools
import time
import torch as t

from src.mapping import (
    MODELS,
    OPT_MODELS,
    LIBRARIES,
    EXPERIMENT_FUNCTIONS
)
from src.utils import measure_execution_time


# Registry of available performance metrics: metric name -> benchmarking function.
# run_benchmark_pipeline looks entries up by name, and measure_performance calls
# the selected function as performance_fn(experiment_fn, env_dict).
PERFORMANCE_METRICS = {
    "speed": measure_execution_time,
}


def measure_performance(performance_fn, experiment_fn, environment_fn, hf_model_name, n_runs=128):
    """
    Take n_runs measurements of one experiment inside a single prepared environment.

    Inputs:
    - performance_fn: callable invoked as performance_fn(experiment_fn, env_dict); its return value is one measurement
    - experiment_fn: the function being benchmarked; handed to performance_fn unchanged
    - environment_fn: callable invoked once as environment_fn(hf_model_name), before any measurement is taken
    - hf_model_name: argument forwarded to environment_fn
    - n_runs: number of measurements to take

    Returns:
    - t.Tensor of shape (n_runs,) holding the measurements in run order
    """
    # the environment is built a single time and shared by every run
    environment = environment_fn(hf_model_name)

    # pre-allocate the result and fill one slot per run
    measurements = t.zeros(n_runs)
    for run_idx in range(n_runs):
        measurements[run_idx] = performance_fn(experiment_fn, environment)

    return measurements


def run_benchmark_pipeline():
    """
    Benchmark every (experiment, model, library, metric) combination.

    For each combination the matching environment, experiment and performance
    functions are looked up in the mapping tables, the measurement is taken via
    measure_performance, and the mean / standard deviation are printed.

    Returns:
    - performances. dict: nested as performances[exp_name][model_name][lib_name][perf_name];
      each leaf is the full value returned by measure_performance (first run included)
    """

    # define experiments: cartesian product over all registered names
    benchmarks = itertools.product(
        EXPERIMENT_FUNCTIONS.keys(),
        MODELS.keys(),
        LIBRARIES.keys(),
        PERFORMANCE_METRICS.keys(),
    )

    performances = {}
    for exp_name, model_name, lib_name, perf_name in benchmarks:
        print("---")
        print(f"Running: {exp_name}, {model_name}, {lib_name}, {perf_name}")

        # create the nested output slots on first use and grab the innermost one
        lib_results = (
            performances
            .setdefault(exp_name, {})
            .setdefault(model_name, {})
            .setdefault(lib_name, {})
        )

        # extract functions
        hf_model_name = MODELS[model_name]
        environment_fn = LIBRARIES[lib_name]
        experiment_fn = EXPERIMENT_FUNCTIONS[exp_name][lib_name]
        performance_fn = PERFORMANCE_METRICS[perf_name]

        # measure performance; the stored value keeps every run
        performance = measure_performance(performance_fn, experiment_fn, environment_fn, hf_model_name)
        lib_results[perf_name] = performance

        # print results; exclude first run as it is usually an outlier
        steady_state = performance[1:]
        print(f'Mean runtime: {steady_state.mean()}')
        print(f'Std. of runtime: {steady_state.std()}')

    # hand the collected measurements back instead of discarding them
    return performances


